In [1]:
import torch

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

Using cuda device


In [3]:
data_file_name = 'tiny_shakespeare'
data_file_path = f'data/{data_file_name}.txt'
with open(data_file_path, 'r', encoding = 'utf-8') as f:
    text = f.read()
len(text)

1115394

### Vocab

In [4]:
unique_symbols = sorted(list(set(text)))
vocab_size = len(unique_symbols)

print(''.join(unique_symbols))
print(f'{vocab_size=}')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab_size=65


### Encoder/decoder

In [5]:
stoi = {s: i for i, s in enumerate(unique_symbols)}
itos = {i: s for s, i in stoi.items()}

encode = lambda s: [stoi[ch] for ch in s]  # symbols to tokens
decode = lambda l: ''.join([itos[i] for i in l])  # tokens to symbols

### Token-Pairs Encoding

In [6]:
def get_stats(ids):
    stat = {}
    for ch1, ch2 in zip(ids, ids[1:]):
        stat[(ch1, ch2)] = stat.get((ch1, ch2), 0) + 1
    return stat

tokens = encode(text)
stats = get_stats(tokens)

In [7]:
sorted(stats.items(), key=lambda x: x[1], reverse=True)[:10]

[((43, 1), 27643),
 ((1, 58), 23837),
 ((58, 46), 22739),
 ((46, 43), 18203),
 ((58, 1), 16508),
 ((57, 1), 15364),
 ((42, 1), 14165),
 ((6, 1), 14098),
 ((1, 39), 13541),
 ((53, 59), 12730)]

In [8]:
# decode((43, 1)), decode((1, 58)), decode((58, 46))

In [9]:
new_pairs = {}

def merge(ids, pair, new_token):
    new_tokens = []
    i = 0
    while i < len(ids):
        if (i < len(ids) - 1) and (ids[i] == pair[0]) and (ids[i+1] == pair[1]):
            new_tokens.append(new_token)
            i += 2
        else:
            new_tokens.append(ids[i])
            i += 1

    return new_tokens

merge([1, 2, 3, 4, 1, 2, 4, 5], (1, 2), 99)

[99, 3, 4, 99, 4, 5]

In [10]:
# first merge
print(f'initial len: {len(tokens):_}')

pair_to_merge = max(stats, key=stats.get)
tokens2 = merge(tokens, pair_to_merge, vocab_size)

print(pair_to_merge, stats[pair_to_merge])
print(f'len after first merge: {len(tokens2):_}')
print(f'diff: {len(tokens) - len(tokens2):_}')

initial len: 1_115_394
(43, 1) 27643
len after first merge: 1_087_751
diff: 27_643


In [11]:
new_vocab_size = vocab_size * 2
tokens_merged = list(tokens)
merges = {}

for new_token in range(vocab_size, new_vocab_size):
    stats = get_stats(tokens_merged)
    top_pair = max(stats, key = stats.get)

    tokens_merged = merge(tokens_merged, top_pair, new_token)
    merges[new_token] = top_pair

merges

{65: (43, 1),
 66: (58, 46),
 67: (58, 1),
 68: (57, 1),
 69: (42, 1),
 70: (6, 1),
 71: (53, 59),
 72: (43, 56),
 73: (47, 52),
 74: (63, 1),
 75: (39, 52),
 76: (10, 0),
 77: (53, 56),
 78: (53, 1),
 79: (43, 52),
 80: (0, 0),
 81: (39, 56),
 82: (1, 66),
 83: (53, 52),
 84: (50, 50),
 85: (46, 39),
 86: (6, 0),
 87: (8, 80),
 88: (47, 68),
 89: (43, 57),
 90: (63, 71),
 91: (1, 57),
 92: (58, 78),
 93: (75, 69),
 94: (53, 61),
 95: (43, 39),
 96: (1, 51),
 97: (1, 61),
 98: (53, 44),
 99: (1, 46),
 100: (73, 45),
 101: (53, 51),
 102: (1, 39),
 103: (41, 46),
 104: (66, 65),
 105: (57, 58),
 106: (1, 40),
 107: (52, 53),
 108: (47, 56),
 109: (44, 77),
 110: (60, 65),
 111: (43, 70),
 112: (47, 66),
 113: (82, 65),
 114: (57, 43),
 115: (50, 47),
 116: (32, 46),
 117: (84, 1),
 118: (56, 43),
 119: (57, 67),
 120: (39, 67),
 121: (13, 52),
 122: (21, 1),
 123: (43, 81),
 124: (47, 51),
 125: (47, 58),
 126: (53, 53),
 127: (45, 46),
 128: (39, 58),
 129: (47, 57)}

In [12]:
print(f'initial len: {len(tokens):_}')
print(f'len after merges: {len(tokens_merged):_}')
print(f'comparation ratio: {len(tokens) / len(tokens_merged):.2f}')

initial len: 1_115_394
len after merges: 741_815
comparation ratio: 1.50


### Save vocab

In [13]:
# save decoder
import json

vocab = {'initial': itos, 'merges': {f'{k}': v for k, v in merges.items()}}

vocab_file_path = f'vocabs/{data_file_name}_vocab.json'
with open(vocab_file_path, 'w', encoding ='utf8') as f: 
    json.dump(vocab, f,  indent=4, ensure_ascii = False)

### Train/val split

In [14]:
data = torch.tensor(encode(text), dtype = torch.long)
n = int(0.9 * len(data))

train_data = data[:n]
val_data = data[n:]

print(train_data.shape, val_data.shape)

torch.Size([1003854]) torch.Size([111540])


### Build a dataset

In [15]:
import transformer.transformer
import transformer.config

config = transformer.config.config_default
config.vocab_size = new_vocab_size
config.device = device

In [16]:
torch.manual_seed(1337)
batch_size = 64

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - config.block_size, (batch_size,))
    x_batch = torch.stack([data[i:i+config.block_size] for i in ix])
    y_batch = torch.stack([data[i+1:i+config.block_size+1] for i in ix])
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)

    return x_batch, y_batch

## Model

In [17]:
m = transformer.transformer.Decoder(config).to(device)

In [18]:
num_eval_batches = 100

@torch.no_grad()
def estimate_loss():
    out = {}
    m.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(num_eval_batches)
        for n_batch in range(num_eval_batches):
            x_batch, y_batch = get_batch(split)
            _, loss = m(x_batch, y_batch)
            losses[n_batch] = loss.item()
        out[split] = losses.mean().item()
    m.train()
    return out

In [19]:
max_iters = 6_000
scheduler_steps = [4_000, 5_000, 5_700]

optimizer = torch.optim.AdamW(m.parameters(), lr = 1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 
                                                 milestones = scheduler_steps, 
                                                 gamma = 0.1)

In [None]:
for step in range(max_iters):
    x_batch, y_batch = get_batch('train')
    logits, loss = m(x_batch, y_batch)

    if step % 1000 == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, lr {scheduler.get_last_lr()[0]}")

    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()  
    scheduler.step()

# tiny shakespeare:
# added just blocks: train loss 2.1716, val loss 2.2226
# added skip connections: train loss 1.9001, val loss 2.0237
# added layer norm: train loss 1.8880, val loss 2.0194
# after model was scaled up: train loss 1.2350, val loss 1.4911

step 0: train loss 4.3826, val loss 4.3819, lr 0.001
step 1000: train loss 1.5468, val loss 1.7271, lr 0.001
step 2000: train loss 1.3874, val loss 1.5989, lr 0.001
step 3000: train loss 1.3218, val loss 1.5566, lr 0.001
step 4000: train loss 1.2793, val loss 1.5202, lr 0.0001
step 5000: train loss 1.2431, val loss 1.5052, lr 1e-05
step 5999: train loss 1.2350, val loss 1.4911, lr 1.0000000000000002e-06


In [23]:
start_token = torch.zeros((1, 1), dtype= torch.int, device = device)
gen_text = m.generate(start_token, 500)
print(decode(unmerge(gen_text[0].tolist())))


First, thy case.

ANTIGONUS:
I know 'tis well?
'Tis good Montague doef Northumberland pity?'
And go,
And fear'st sharp him atternony; you are all.
And to confusions, my sentence worse to thence;
On some to a voice must blushing to do be patient,
To Both her band up Burgunding but frottunes
Tell for one several thandsmen how you have
The gods you seen?

WARWICK:
O he break, Mistoner, is bounded harb'd put to put the.

FoRD:

ROMEO:
I will not I do remember in stand functly
The creek morn of good,
