In [1]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

Using cuda device


In [2]:
data_file_name = 'tiny_shakespeare'
data_file_path = f'data/{data_file_name}.txt'
with open(data_file_path, 'r', encoding = 'utf-8') as f:
    text = f.read()
len(text)

1115394

### Vocab

In [3]:
unique_symbols = sorted(list(set(text)))
vocab_size = len(unique_symbols)

print(''.join(unique_symbols))
print(f'{vocab_size=}')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab_size=65


### Encoder/decoder

In [4]:
stoi = {s: i for i, s in enumerate(unique_symbols)}
itos = {i: s for s, i in stoi.items()}

encode = lambda s: [stoi[ch] for ch in s]  # symbols to tokens
decode = lambda l: ''.join([itos[i] for i in l])  # tokens to symbols

### Byte Pair Encoding

In [5]:
tokens = list(text.encode('utf-8'))
print('text len:', len(text))
print('tokens len', len(tokens))

# for tiny-shakespeare it is the same number 
# because this dataset consists of only eng chars, punctuation and numbers 
# but for reddit-jokes 'tokens len' should be a bit bigger!

text len: 1115394
tokens len 1115394


In [6]:
def get_stats(ids):
    stat = {}
    for ch1, ch2 in zip(ids, ids[1:]):
        stat[(ch1, ch2)] = stat.get((ch1, ch2), 0) + 1
    return stat

stats = get_stats(tokens)
stats

{(70, 105): 308,
 (105, 114): 2700,
 (114, 115): 2649,
 (115, 116): 6823,
 (116, 32): 16508,
 (32, 67): 693,
 (67, 105): 122,
 (105, 116): 6114,
 (116, 105): 2977,
 (105, 122): 208,
 (122, 101): 285,
 (101, 110): 7568,
 (110, 58): 587,
 (58, 10): 8762,
 (10, 66): 1669,
 (66, 101): 367,
 (101, 102): 903,
 (102, 111): 3325,
 (111, 114): 8458,
 (114, 101): 9843,
 (101, 32): 27643,
 (32, 119): 10546,
 (119, 101): 2982,
 (32, 112): 4490,
 (112, 114): 1720,
 (114, 111): 3613,
 (111, 99): 373,
 (99, 101): 3354,
 (101, 101): 3807,
 (101, 100): 3676,
 (100, 32): 14165,
 (32, 97): 13541,
 (97, 110): 10197,
 (110, 121): 484,
 (121, 32): 10283,
 (32, 102): 6563,
 (102, 117): 563,
 (117, 114): 5313,
 (114, 116): 2278,
 (116, 104): 22739,
 (104, 101): 18203,
 (101, 114): 11771,
 (114, 44): 1540,
 (44, 32): 14098,
 (32, 104): 11925,
 (101, 97): 6288,
 (97, 114): 7081,
 (114, 32): 10516,
 (32, 109): 10786,
 (109, 101): 6135,
 (32, 115): 12287,
 (115, 112): 1136,
 (112, 101): 1989,
 (97, 107): 1561,
 (

In [7]:
def merge(ids, pair, new_token):
    new_tokens = []
    i = 0
    while i < len(ids):
        if (i < len(ids) - 1) and (ids[i] == pair[0]) and (ids[i+1] == pair[1]):
            new_tokens.append(new_token)
            i += 2
        else:
            new_tokens.append(ids[i])
            i += 1

    return new_tokens

merge([1, 2, 3, 4, 1, 2, 4, 5], (1, 2), 99)

[99, 3, 4, 99, 4, 5]

In [8]:
# first merge
print(f'initial len: {len(tokens):_}')

pair_to_merge = max(stats, key=stats.get)
tokens2 = merge(tokens, pair_to_merge, vocab_size)

print('top pair:', pair_to_merge, stats[pair_to_merge])
print(f'decoded pair: "{bytes(pair_to_merge).decode('utf-8')}"')
print(f'len after first merge: {len(tokens2):_}')
print(f'diff: {len(tokens) - len(tokens2):_}')

initial len: 1_115_394
top pair: (101, 32) 27643
decoded pair: "e "
len after first merge: 1_087_751
diff: 27_643


In [9]:
INIT_VOCAB_MAX = 256
NO_OF_MERGES = 20
new_vocab_size = INIT_VOCAB_MAX + NO_OF_MERGES

In [10]:
tokens_merged = list(tokens)
merges = {}

for new_token in range(INIT_VOCAB_MAX, new_vocab_size):
    stats = get_stats(tokens_merged)
    top_pair = max(stats, key = stats.get)

    tokens_merged = merge(tokens_merged, top_pair, new_token)
    merges[new_token] = top_pair

merges

{256: (101, 32),
 257: (116, 104),
 258: (116, 32),
 259: (115, 32),
 260: (100, 32),
 261: (44, 32),
 262: (111, 117),
 263: (101, 114),
 264: (105, 110),
 265: (121, 32),
 266: (97, 110),
 267: (58, 10),
 268: (111, 114),
 269: (111, 32),
 270: (101, 110),
 271: (10, 10),
 272: (97, 114),
 273: (32, 257),
 274: (111, 110),
 275: (108, 108)}

In [11]:
print(f'initial len: {len(tokens):_}')
print(f'len after merges: {len(tokens_merged):_}')
print(f'comparation ratio: {len(tokens) / len(tokens_merged):.2f}')

initial len: 1_115_394
len after merges: 882_737
comparation ratio: 1.26


### Save vocab

In [12]:
# save decoder
import json

vocab = {'merges': {f'{k}': v for k, v in merges.items()}}

vocab_file_path = f'vocabs/{data_file_name}_bpe_vocab.json'
with open(vocab_file_path, 'w', encoding ='utf8') as f: 
    json.dump(vocab, f,  indent = 4, ensure_ascii = False)

### Train/val split

In [13]:
data = torch.tensor(tokens_merged, dtype = torch.long)
n = int(0.9 * len(data))

train_data = data[:n]
val_data = data[n:]

print(train_data.shape, val_data.shape)

torch.Size([794463]) torch.Size([88274])


### Build a dataset

In [14]:
import transformer.transformer
import transformer.config

config = transformer.config.config_default
config.vocab_size = new_vocab_size
config.device = device

In [15]:
torch.manual_seed(1337)
batch_size = 64

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - config.block_size, (batch_size,))
    x_batch = torch.stack([data[i:i+config.block_size] for i in ix])
    y_batch = torch.stack([data[i+1:i+config.block_size+1] for i in ix])
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)

    return x_batch, y_batch

## Model

In [16]:
m = transformer.transformer.Decoder(config).to(device)

In [17]:
num_eval_batches = 100

@torch.no_grad()
def estimate_loss():
    out = {}
    m.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(num_eval_batches)
        for n_batch in range(num_eval_batches):
            x_batch, y_batch = get_batch(split)
            _, loss = m(x_batch, y_batch)
            losses[n_batch] = loss.item()
        out[split] = losses.mean().item()
    m.train()
    return out

In [18]:
max_iters = 6_000
scheduler_steps = [4_000, 5_000, 5_700]

optimizer = torch.optim.AdamW(m.parameters(), lr = 1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 
                                                 milestones = scheduler_steps, 
                                                 gamma = 0.1)

In [19]:
for step in range(max_iters):
    x_batch, y_batch = get_batch('train')
    logits, loss = m(x_batch, y_batch)

    if step % 1000 == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, lr {scheduler.get_last_lr()[0]}")

    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()  
    scheduler.step()

# tiny shakespeare:
# without TPE: train loss 1.2350, val loss 1.4911

step 0: train loss 5.8114, val loss 5.8051, lr 0.001
step 1000: train loss 1.8710, val loss 2.1210, lr 0.001
step 2000: train loss 1.6902, val loss 1.9856, lr 0.001
step 3000: train loss 1.6107, val loss 1.9409, lr 0.001
step 4000: train loss 1.5591, val loss 1.9173, lr 0.0001
step 5000: train loss 1.5104, val loss 1.8824, lr 1e-05
step 5999: train loss 1.5051, val loss 1.8896, lr 1.0000000000000002e-06


In [20]:
save_model_path = f'models/{data_file_name}_bpe_model.pth'
torch.save(m.state_dict(), save_model_path)

In [25]:
def unmerge(ids):
    new_ids = []
    for i in ids:
        if i < INIT_VOCAB_MAX:
            new_ids.append(i)
        else:
            merg_pair = merges[i]
            unmerged_ids = unmerge(merg_pair)
            new_ids.extend(unmerged_ids)
    return new_ids

# tokens1 = [0, 16, 33, 23, 87, 18, 53, 56, 126]
# print(merges[87], merges[126], merges[80])
# print(unmerge(tokens1))

In [26]:
idx_to_bytes = {i: bytes([i]) for i in range(256)}
for merge_id, (p0, p1) in merges.items():
    idx_to_bytes[merge_id] = idx_to_bytes[p0] + idx_to_bytes[p1] # '+' here is a bytes concat!

def decode_tokens(byte_tokens):
    unmerged_bytes = b''.join([idx_to_bytes[b] for b in byte_tokens])
    return unmerged_bytes.decode('utf-8', errors = 'replace')

In [28]:
start_token = torch.tensor([[int(ord('\n'))]], device = device)
gen_text = m.generate(start_token, 200)

In [29]:
print(bytes(unmerge(gen_text[0].tolist())).decode('utf-8'))


As so to'st his messenger hardly noble
To bound his lady's lipens: I'll give him words a
Inderive no rih. New be gone; ware, The hour into ell.
Do thurs, Ext Oither sensity Clarence!

WARURE IV:
Come, Edward of Wills, thou slainlock now to my grieve.

S


In [30]:
print(decode_tokens(gen_text[0].tolist()))


As so to'st his messenger hardly noble
To bound his lady's lipens: I'll give him words a
Inderive no rih. New be gone; ware, The hour into ell.
Do thurs, Ext Oither sensity Clarence!

WARURE IV:
Come, Edward of Wills, thou slainlock now to my grieve.

S
