In [5]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from matplotlib import pyplot as plt
import time

device = torch.device("mps")

In [2]:
# simple tokenization by characters

lines = open('./input.txt', 'r').read()

vocab = list(set(lines))
itos = {i:ch for i, ch in enumerate(vocab)}
stoi = {ch:i for i, ch in enumerate(vocab)}

def encode(s):
    return [stoi[ch] for ch in s]

def decode(l):
    return ''.join([itos[i] for i in l])

print('vocab size:', len(vocab))

vocab size: 65


In [3]:
dataset = torch.tensor(encode(lines), dtype=torch.int8)
dataset

tensor([62, 25, 41,  ...,  5, 27,  9], dtype=torch.int8)

In [4]:
config = {
    "d_model": 100,
    "vocab_size": len(vocab),
    'batch_size': 32,
    'context_window': 10,
}


def get_batches(data, split, batch_size, context_window):
    train = data[:int(.8 * len(data))]
    val = data[int(.8 * len(data)): int(.9 * len(data))]
    test = data[int(.9 * len(data)):]
    
    batch_data = train
    if split == 'val':
        batch_data = val
    
    # pick random starting points
    ix = torch.randint(0, batch_data.size(0) - context_window - 1, (batch_size,))
    x = torch.stack([batch_data[i:i+context_window] for i in ix]).long()
    y = torch.stack([batch_data[i+1:i+context_window+1] for i in ix]).long()
    return x, y

xs, ys = get_batches(dataset, 'train', config['batch_size'], config['context_window'])

[(decode(xs[i].tolist()), decode(ys[i].tolist())) for i in range(len(xs))]

[('\nGood lord', 'Good lords'),
 ('re fit to ', 'e fit to b'),
 ('uous branc', 'ous branch'),
 (';\nAnd I wi', '\nAnd I wil'),
 ('uld.\n\nDUKE', 'ld.\n\nDUKE '),
 ('e loving c', ' loving ci'),
 ('zes not su', 'es not suc'),
 (' purpose i', 'purpose in'),
 ('the tuft o', 'he tuft of'),
 ('y day, my ', ' day, my l'),
 ('INGHAM:\nAh', 'NGHAM:\nAh,'),
 ('other\nWith', 'ther\nWithi'),
 ("guess'd, b", "uess'd, be"),
 ('ply me wit', 'ly me with'),
 ('le kiss th', 'e kiss the'),
 ('oal-black,', 'al-black,\n'),
 ('n to touch', ' to touch '),
 (' join our ', 'join our l'),
 (' Thursday ', 'Thursday e'),
 ('was his fa', 'as his fat'),
 (' with a su', 'with a sud'),
 ('atery beam', 'tery beams'),
 ('o through ', ' through t'),
 ("'Twas odds", 'Twas odds,'),
 ('or both: u', 'r both: up'),
 (' lie unto ', 'lie unto h'),
 ('spoke: but', 'poke: but '),
 ("'s good wi", 's good wit'),
 ('nds.\nOurse', 'ds.\nOursel'),
 (' ere now;\n', 'ere now;\nA'),
 ('od husband', 'd husbandr'),
 ('rvant:\nMy ', 'vant:

In [14]:
class LLama(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = nn.Embedding(config['vocab_size'], config['d_model'])

        # simple feed forward network
        self.ff = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            nn.ReLU(),
            nn.Linear(config['d_model'], config['vocab_size'])
        )

    def forward(self, idx, targets=None):
        x = self.embeddings(idx)
        logits = self.ff(x)

        if targets is None:
            return logits
        
        else:
            loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.view(-1))
            return logits, loss

In [15]:
config = {
    "n_heads": 8,
    "d_model": 128,
    "batch_size": 32,
    "n_layers": 4,
    "context_window": 16,
    "vocab_size": len(vocab),
    "epochs": 1000,
    "log_interval": 10,
}
model = LLama(config)

optimizer = torch.optim.Adam(
    model.parameters(), 
    # betas=(.9, .95), 
    # weight_decay=.1, 
    # eps=1e-9, 
    lr=3e-4
)

start_time = time.time()
for epoch in range(config['epochs']):
    optimizer.zero_grad()
    
    xs, ys = get_batches(dataset, 'train', config['batch_size'], config['context_window'])
    logits, loss = model(xs, targets=ys)
    loss.backward()

    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    batch_time = time.time() - start_time
    print(f"Epoch {epoch} | Loss {loss.item():.3f} | Time {batch_time:.3f} | ETA in seconds {batch_time * (config['epochs'] - epoch)/config['log_interval'] :.3f}")
    start_time = time.time()

Epoch 0 | Loss 4.166 | Time 0.003 | ETA in seconds 0.286
Epoch 1 | Loss 4.159 | Time 0.002 | ETA in seconds 0.197
Epoch 2 | Loss 4.138 | Time 0.002 | ETA in seconds 0.194
Epoch 3 | Loss 4.135 | Time 0.002 | ETA in seconds 0.228
Epoch 4 | Loss 4.114 | Time 0.002 | ETA in seconds 0.197
Epoch 5 | Loss 4.070 | Time 0.002 | ETA in seconds 0.244
Epoch 6 | Loss 4.057 | Time 0.002 | ETA in seconds 0.227
Epoch 7 | Loss 4.024 | Time 0.002 | ETA in seconds 0.196
Epoch 8 | Loss 4.038 | Time 0.002 | ETA in seconds 0.184
Epoch 9 | Loss 3.986 | Time 0.002 | ETA in seconds 0.158
Epoch 10 | Loss 3.993 | Time 0.002 | ETA in seconds 0.180
Epoch 11 | Loss 3.976 | Time 0.002 | ETA in seconds 0.197
Epoch 12 | Loss 3.959 | Time 0.002 | ETA in seconds 0.218
Epoch 13 | Loss 3.955 | Time 0.002 | ETA in seconds 0.190
Epoch 14 | Loss 3.903 | Time 0.002 | ETA in seconds 0.194
Epoch 15 | Loss 3.925 | Time 0.002 | ETA in seconds 0.207
Epoch 16 | Loss 3.892 | Time 0.002 | ETA in seconds 0.203
Epoch 17 | Loss 3.875 | 