In [1]:
import torch
import torch.nn as nn
import time
from tqdm import tqdm

from model.transformer import TLM
from model.utils import n_params, flatten_list
from tokens import get_word_freqs, get_v, train, tokenize, detokenize, encode, decode, f_stoi, f_itos

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.manual_seed(42)  # fixed seed so weight init and batch sampling are reproducible

# Use the first backend that is available, in priority order: CUDA GPU, Apple MPS, CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f'using device: {device}')

using device: mps


In [3]:
# `Tokens` is used below but was never imported by this notebook; it previously
# resolved only through leftover kernel state, so a fresh kernel raised NameError.
# NOTE(review): assumed to come from the local `tokens` module like the other
# tokenizer helpers -- confirm, then move this into the import cell.
from tokens import Tokens

# Read the corpus one line per entry; the file handle is only needed for the read.
with open('data/truths.txt', 'r', encoding='utf-8') as f:
    corpus = f.read().split('\n')

# Second argument presumably the target vocab size (matches the "999/1000"
# progress output) -- TODO confirm against Tokens.
tks = Tokens(corpus, 1000)
vocab_size = len(tks.vocab)

# Tokenize and encode each line, then flatten everything into one 1-D id stream.
tokenized = [tks.tokenize(line) for line in corpus]
encoded = [tks.encode(line_tokens) for line_tokens in tokenized]
data = torch.tensor(flatten_list(encoded), dtype=torch.long, device=device)

# Contiguous 90% / 10% train / validation split (no shuffling).
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

999/1000

In [4]:
def get_batch(split: str):
    """Sample a random batch of next-token-prediction windows from one split.

    Reads the notebook globals ``train_data``, ``val_data``, ``batch_size``,
    ``block_size`` and ``device``.

    Args:
        split: ``'train'`` samples from ``train_data``; ``'val'`` from ``val_data``.

    Returns:
        ``(x, y)`` with ``batch_size`` rows of ``block_size`` token ids each, moved
        to ``device``; ``y[b, t]`` is the token that follows ``x[b, t]`` in the split.

    Raises:
        ValueError: if ``split`` is not ``'train'``/``'val'``, or the split holds
            ``block_size`` tokens or fewer (no room for one shifted window).
    """
    if split not in ('train', 'val'):
        # Reject typos instead of silently sampling from val_data.
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    source = train_data if split == 'train' else val_data
    n_starts = len(source) - block_size
    if n_starts <= 0:
        raise ValueError(
            f"{split} split has {len(source)} tokens; need more than block_size={block_size}"
        )
    # Drawn from the default CPU generator, exactly as before, so seeded runs pick
    # the same windows; then moved next to the data for indexing.
    starts = torch.randint(n_starts, (batch_size,)).to(source.device)
    # One gather per tensor instead of 2 * batch_size slices + stack. Advanced
    # indexing yields fresh contiguous tensors, so callers may safely .view() them.
    idx = starts[:, None] + torch.arange(block_size, device=source.device)
    x = source[idx]
    y = source[idx + 1]
    return x.to(device), y.to(device)

In [5]:
@torch.no_grad()
def estimate_loss(m, eval_iters: int = 200):
    """Estimate the mean loss of ``m`` on the train and val splits.

    For each split, averages the loss over ``eval_iters`` random batches drawn
    with ``get_batch``, with ``m`` in eval mode and autograd disabled.

    Args:
        m: model called as ``m(X, Y)`` and returning ``(logits, loss)``.
        eval_iters: number of batches averaged per split (default 200).

    Returns:
        ``{'train': ..., 'val': ...}`` mapping each split to a 0-d CPU float
        tensor holding the mean loss.

    Raises:
        ValueError: if ``eval_iters`` is less than 1 (the mean would be NaN).
    """
    if eval_iters < 1:
        raise ValueError(f'eval_iters must be >= 1, got {eval_iters}')
    was_training = m.training
    m.eval()
    try:
        out = {}
        for split in ('train', 'val'):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = m(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
        return out
    finally:
        # Restore the caller's mode even if a batch raises or the cell is
        # interrupted mid-evaluation; otherwise the model would stay in eval mode.
        m.train(was_training)

In [21]:
# hyperparameters

# -- batching (consumed by get_batch)
batch_size: int = 32    # number of windows sampled per batch
block_size: int = 8     # tokens per window (context length)

# -- model shape (passed straight to TLM)
n_embd: int = 16
n_blocks: int = 4
n_heads: int = 4

# -- optimisation
lr: float = 1e-2        # AdamW learning rate
epochs: int = 5_000     # NOTE: counts optimizer steps (one random batch each), not passes over the data
epoch_eval: int = 500   # run estimate_loss every this many steps

In [24]:
model = TLM(
    block_size=block_size,
    n_embd=n_embd,
    vocab_size=vocab_size,
    n_blocks=n_blocks,
    n_heads=n_heads,
    device=device,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# For scale: the largest GPT-2 release has ~1.5B (1,500,000,000) parameters.
print(f'num of params: {n_params(model)}')

num of params: 46056


In [23]:
# Train for `epochs` optimizer steps (one random batch per step), printing an
# estimate of train/val loss every `epoch_eval` steps.
# LR schedule: `lr` up to step `lr_decay_step`, then `lr / 10`. It is written into
# optimizer.param_groups: AdamW stored the LR at construction, so rebinding a
# Python `lr` variable has no effect on training and would leak into later cells.
lr_decay_step = 3000
st = time.perf_counter()
for epoch in tqdm(range(epochs)):
    step_lr = lr / 10 if epoch > lr_decay_step else lr
    for group in optimizer.param_groups:
        group['lr'] = step_lr

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if epoch % epoch_eval == 0:
        tv_loss = estimate_loss(model)
        # tqdm.write keeps the progress bar intact; outer double quotes keep the
        # nested ['train'] keys valid before Python 3.12.
        tqdm.write(f"step {epoch}: train loss {tv_loss['train']:.4f} val loss {tv_loss['val']:.4f}")
# Taken once after the loop, so it is defined even when epochs == 0.
et = time.perf_counter()
print(f'training took: {et-st:.2f}s or {(et-st)/60:.2f}m')

  0%|                                                                                                                                      | 0/5000 [00:00<?, ?it/s]

step 0: train loss 7.0940 val loss 7.0964


  2%|█▉                                                                                                                         | 78/5000 [03:59<4:12:24,  3.08s/it]


KeyboardInterrupt: 

In [None]:
print('-- After Training')
tv_loss = estimate_loss(model)
# Outer double quotes keep the nested ['train'] keys valid before Python 3.12.
print(f"train loss: {tv_loss['train']:.4f} val loss: {tv_loss['val']:.4f}")

# Sample 50 new tokens, starting from a single token with id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
with torch.no_grad():  # inference only; no autograd graph needed
    out = model.generate(context, max_new_tokens=50)

# `itos` is defined nowhere in the cells above, so `decode(..., itos)` raised NameError.
# Decode with the same `tks` instance that encoded the training data.
# NOTE(review): assumes Tokens has decode()/detokenize() mirroring encode()/tokenize() -- confirm.
dout = tks.decode(out[0].tolist())
print(tks.detokenize(dout))