In [14]:
import torch
import torch.nn as nn
import time

from model.transformer import TLM
from model.utils import n_params

In [2]:
# Seed torch's global RNG so weight init and batch sampling are repeatable.
torch.manual_seed(42)

# Pick the first available backend in preference order CUDA > Apple MPS > CPU.
# `torch.backends.mps` only exists in torch >= 1.12, so probe it defensively
# instead of crashing with AttributeError on older builds.
_mps_backend = getattr(torch.backends, 'mps', None)
d_opts = [
    ('cuda', torch.cuda.is_available()),
    ('mps', _mps_backend is not None and _mps_backend.is_available()),
    ('cpu', True),  # always available, guarantees `next` below finds a match
]
device = next(name for name, available in d_opts if available)
print(f'using device: {device}')

using device: mps


In [3]:
# Load the whole corpus as a single string (character-level language model).
with open('data/truths.txt', 'r', encoding='utf-8') as f:
    data_txt = f.read()

# Vocabulary = every distinct character. Sorting makes the char <-> id mapping
# deterministic across runs (iteration order of a raw `set` is not stable).
chars = sorted(set(data_txt))
vocab_size = len(chars)

# Lookup tables between characters and integer token ids.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Map a string to a list of token ids (KeyError for characters not in the corpus)."""
    return [stoi[c] for c in s]


def decode(e):
    """Map an iterable of token ids back to the corresponding string."""
    return ''.join(itos[i] for i in e)


# Encode the full corpus once as a 1-D long tensor placed directly on the
# compute device, so sampled batches need no host->device copy.
data = torch.tensor(encode(data_txt), dtype=torch.long, device=device)

# Contiguous 90% / 10% split: validation is the tail of the file, not a shuffle.
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [4]:
def get_batch(split):
    """Sample a random batch of (input, target) windows for next-token prediction.

    Args:
        split: 'train' or 'val' -- which part of the corpus to sample from.

    Returns:
        (x, y): long tensors of shape (batch_size, block_size) on `device`.
        y is x shifted one token to the right in the corpus, i.e. y[b, t] is the
        token that follows x[b, t].

    Raises:
        ValueError: if `split` is neither 'train' nor 'val' (any other value used
            to fall back silently to the validation set).

    Reads the notebook globals train_data, val_data, batch_size, block_size, device.
    """
    if split == 'train':
        data = train_data
    elif split == 'val':
        data = val_data
    else:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    # Start indices lie in [0, len(data) - block_size - 1], so the target window
    # data[i+1 : i+block_size+1] never runs past the end of `data`.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # Gather all windows with a single advanced-indexing op instead of a Python
    # loop of `batch_size` slices followed by torch.stack.
    positions = (ix.unsqueeze(1) + torch.arange(block_size)).to(data.device)  # (batch_size, block_size)
    x = data[positions]
    y = data[positions + 1]
    return x.to(device), y.to(device)

In [15]:
@torch.no_grad()
def estimate_loss(m, eval_iters=200, batch_fn=None):
    """Estimate the mean loss on the train and val splits over many random batches.

    Averaging over `eval_iters` batches gives a far less noisy number than the
    loss of the single batch the optimizer just stepped on. Autograd is disabled
    for the whole call.

    Args:
        m: model called as `m(X, Y)` and returning `(logits, loss)` with `loss`
            a scalar tensor.
        eval_iters: number of random batches averaged per split (default 200).
        batch_fn: callable `split -> (X, Y)`; defaults to the notebook's `get_batch`.

    Returns:
        dict {'train': mean_loss, 'val': mean_loss}; each value is a 0-d tensor.

    Raises:
        ValueError: if eval_iters < 1 (the mean of zero losses would be NaN).

    The model is switched to eval mode for the measurement and afterwards put
    back into whichever train/eval mode it was in before the call -- also when
    an exception is raised.
    """
    if eval_iters < 1:
        raise ValueError(f'eval_iters must be >= 1, got {eval_iters}')
    if batch_fn is None:
        batch_fn = get_batch

    was_training = m.training
    m.eval()
    try:
        out = {}
        for split in ('train', 'val'):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = batch_fn(split)
                _, loss = m(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        m.train(was_training)
    return out

In [16]:
# ---- hyperparameters ------------------------------------------------------
# batching: read by get_batch
block_size = 8        # tokens per training window (context length)
batch_size = 32       # windows sampled per optimizer step

# model size: passed straight through to TLM
n_embd = 32
n_heads = 4
n_blocks = 6

# optimisation
lr = 0.001            # AdamW learning rate
epochs = 5_000        # NOTE: number of optimizer steps (one random batch each), not full passes over the data
epoch_eval = 500      # estimate train/val loss every this many steps

In [17]:
# Build the transformer LM from the hyperparameters above, place it on the
# selected device, and attach an AdamW optimizer over all of its parameters.
model = TLM(
    block_size=block_size,
    n_embd=n_embd,
    vocab_size=vocab_size,
    n_blocks=n_blocks,
    n_heads=n_heads,
    device=device,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# For scale: the largest GPT-2 variant has roughly 1.5B (1,500,000,000) parameters.
print(f'num of params: {n_params(model)}')

num of params: 79024


In [11]:
# Main optimisation loop: one random training batch per step, plus a periodic
# (comparatively expensive, multi-batch) train/val loss estimate.
st = time.perf_counter()  # monotonic high-resolution clock, suited to elapsed time
for epoch in range(epochs):  # each iteration is one optimizer step, not a pass over the data
    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # NOTE: reassigning the `lr` variable here would NOT change the optimizer's
    # learning rate; use torch.optim.lr_scheduler or optimizer.param_groups instead.
    if epoch % epoch_eval == 0:
        tv_loss = estimate_loss(model)
        # Double quotes inside the f-string: reusing ' is a SyntaxError before Python 3.12.
        print(f"step {epoch}: train loss {tv_loss['train']:.4f} val loss {tv_loss['val']:.4f}")
et = time.perf_counter()
print(f'training took: {et-st:.1f}s')

step 0: train loss 3.9734 val loss 3.9806
step 500: train loss 2.2068 val loss 2.2016
step 1000: train loss 2.0046 val loss 1.9909
step 1500: train loss 1.9008 val loss 1.9133
step 2000: train loss 1.8286 val loss 1.8374
step 2500: train loss 1.7858 val loss 1.8190
step 3000: train loss 1.7484 val loss 1.7496
step 3500: train loss 1.7122 val loss 1.7286
step 4000: train loss 1.6871 val loss 1.7052
step 4500: train loss 1.6630 val loss 1.7049
training took: 389.6s


In [12]:
print('-- After Training')
tv_loss = estimate_loss(model)
# Double quotes inside the f-string (reusing ' is a SyntaxError before Python 3.12);
# val loss now gets the same .4f formatting as train loss.
print(f"train loss: {tv_loss['train']:.4f} val loss: {tv_loss['val']:.4f}")

# Sample 250 new tokens, starting from token id 0 = chars[0], the smallest
# character in the sorted vocabulary (often '\n' for text files -- confirm).
context = torch.zeros((1, 1), dtype=torch.long, device=device)
# NOTE(review): the model is in train mode at this point; switch to eval for
# sampling in case TLM uses dropout (not visible from here), skip autograd
# bookkeeping, then return to train mode for any further training cells.
model.eval()
with torch.no_grad():
    sample_ids = model.generate(context, max_new_tokens=250)[0].tolist()
model.train()
print(decode(sample_ids))

-- After Training
train loss: 1.6433 val loss: 1.665024757385254

mactiparctioxid hif cend
branphict a plevion is rease
a peroces is wilcome equirigquirs
near
tho sodbicans posito
coloon orvitatal anite
pood combent
coll is a wasits
a poperistion
morved
if earting rasuren revay cuse"
reverignefriains to somethinm a
