In [1]:
import torch
import torch.nn as nn
import time
from tqdm import tqdm

from model.transformer import TLM
from model.utils import n_params
from tokens import Tokens

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.manual_seed(42)
d_opts = [('cuda', torch.cuda.is_available()), ('mps', torch.backends.mps.is_available()), ('cpu', True)]
device = next(device for device, available in d_opts if available)
print(f'using device: {device}')

using device: mps


In [3]:
vocab_size = 5000
with open ('data/tiny_shakespeare.txt', 'r', encoding='utf-8') as f:
    corpus = f.read() # 15,057 unique words
    tks = Tokens(corpus, vocab_size) # 50,000 in gpt-2
tokenized = tks.tokenize(corpus)
encoded = tks.encode(tokenized)

4999/5000

In [4]:
data = torch.tensor(encoded, dtype=torch.long, device=device)
n = int(0.85*len(data))
train_data = data[:n]
val_data = data[n:]

In [5]:
def get_batch(split: str):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)

In [6]:
@torch.no_grad()
def estimate_loss(m):
    out = {}
    m.eval()
    eval_iters = 200
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = m(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    m.train()
    return out

In [11]:
# hyperparameters
batch_size = 32
block_size = 144 # 1024 in gpt-2
n_embd = 96 # 768 in gpt-2
n_blocks = 6
n_heads = 4

lr = 1e-2
iters = 10000
i_eval = 625

In [12]:
model = TLM(block_size=block_size, n_embd=n_embd, vocab_size=vocab_size,
            n_blocks=n_blocks, n_heads=n_heads, device=device).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

print(f'num of params: {n_params(model)}') # gpt-2 has 1,500,000,000 (1.5B)

num of params: 1648136


In [13]:
st = time.time()
for i in tqdm(range(iters)):
    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    lr = 1e-3 if i > 6500 else 1e-2
    if i % i_eval == 0:
        tv_loss = estimate_loss(model)
        print(f"step {i}: train loss {tv_loss['train']:.4f} val loss {tv_loss['val']:.4f}")
et = time.time()
print()
print(f'training took: {et-st:.2f}s or {(et-st)/60:.2f}m')

  0%|                                                                                                                        | 3/10000 [00:10<7:17:25,  2.63s/it]

step 0: train loss 17.7879 val loss 17.6634


  6%|███████▍                                                                                                              | 627/10000 [01:13<3:49:54,  1.47s/it]

step 625: train loss 3.8196 val loss 5.1356


 13%|██████████████▋                                                                                                      | 1253/10000 [02:18<2:48:07,  1.15s/it]

step 1250: train loss 3.3460 val loss 5.2896


 19%|█████████████████████▉                                                                                               | 1877/10000 [03:37<7:16:52,  3.23s/it]

step 1875: train loss 2.9930 val loss 5.5515


 25%|█████████████████████████████▎                                                                                       | 2502/10000 [04:57<5:15:49,  2.53s/it]

step 2500: train loss 2.7473 val loss 5.8028


 31%|████████████████████████████████████▌                                                                                | 3127/10000 [06:15<4:42:37,  2.47s/it]

step 3125: train loss 2.5955 val loss 5.9678


 38%|███████████████████████████████████████████▉                                                                         | 3752/10000 [07:41<4:22:27,  2.52s/it]

step 3750: train loss 2.4413 val loss 6.1392


 44%|███████████████████████████████████████████████████▏                                                                 | 4377/10000 [08:56<2:40:00,  1.71s/it]

step 4375: train loss 2.3332 val loss 6.3541


 50%|██████████████████████████████████████████████████████████▌                                                          | 5002/10000 [10:05<2:16:06,  1.63s/it]

step 5000: train loss 2.2518 val loss 6.4904


 56%|█████████████████████████████████████████████████████████████████▊                                                   | 5627/10000 [11:20<2:03:41,  1.70s/it]

step 5625: train loss 2.1702 val loss 6.5579


 63%|█████████████████████████████████████████████████████████████████████████▏                                           | 6252/10000 [12:37<2:02:22,  1.96s/it]

step 6250: train loss 2.1227 val loss 6.6757


 69%|█████████████████████████████████████████████████████████████████████████████████▊                                     | 6878/10000 [13:53<59:47,  1.15s/it]

step 6875: train loss 2.0604 val loss 6.7585


 75%|███████████████████████████████████████████████████████████████████████████████████████▊                             | 7502/10000 [14:56<1:00:09,  1.44s/it]

step 7500: train loss 2.0164 val loss 6.8601


 81%|████████████████████████████████████████████████████████████████████████████████████████████████▋                      | 8128/10000 [16:00<34:59,  1.12s/it]

step 8125: train loss 1.9570 val loss 6.9410


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏              | 8752/10000 [17:07<31:45,  1.53s/it]

step 8750: train loss 1.9214 val loss 6.9566


 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▌       | 9378/10000 [18:11<11:12,  1.08s/it]

step 9375: train loss 1.9062 val loss 7.0632


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [19:08<00:00,  8.71it/s]


training took: 1148.40s or 19.14m





In [17]:
print('-- After Training')
tv_loss = estimate_loss(model)
print(f"train loss: {tv_loss['train']:.4f} val loss: {tv_loss['val']:.4f}")
out = model.generate(torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=2500)
dout = tks.decode(out[0].tolist())
print(tks.detokenize(dout))

-- After Training
train loss: 1.8865 val loss: 7.1341
