In [129]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

device = torch.device("mps")

In [143]:
# simple tokenization by characters

lines = open('./input.txt', 'r').read()

vocab = list(set(lines))
itos = {i:ch for i, ch in enumerate(vocab)}
stoi = {ch:i for i, ch in enumerate(vocab)}

def encode(s):
    return [stoi[ch] for ch in s]

def decode(l):
    return ''.join([itos[i] for i in l])

print('vocab size:', len(vocab))

vocab size: 65


In [144]:
dataset = torch.tensor(encode(lines), dtype=torch.int8)
dataset

tensor([23, 64, 21,  ..., 42, 31, 18], dtype=torch.int8)

We want to run training with minibatches of size 8.

In [145]:
config = {
    "d_model": 100,
    "vocab_size": len(vocab),
    'batch_size': 32,
    'context_window': 10,
}


def get_batches(data, split, batch_size, context_window):
    train, val, test = data[:int(.8 * len(data))], data[int(.8 * len(data)): int(.9 * len(data))], data[int(.9 * len(data)):]
    
    if split == 'train':
        data = train
    elif split == 'val':
        data = val
    
    xs = torch.zeros(batch_size, context_window, dtype=torch.long)
    ys = torch.zeros(batch_size, context_window, dtype=torch.long)

    # pick random starting points
    starts = torch.randint(0, len(data) - context_window - 1, (batch_size,))
    for item, start in enumerate(starts):
        
        xs[item] += data[start:start+context_window]
        ys[item] += data[start+1:start+context_window+1]
    return xs, ys
    
xs, ys = get_batches(dataset, 'train', config['batch_size'], config['context_window'])

In [146]:
class MaskedRotarySelfAttentionHead(nn.Module):
    """
    Rotary embeddings.

    Input: (BATCH_SIZE x CONTEXT_WINDOW)
    Output: (BATCH_SIZE x CONTEXT_WINDOW x EMBEDDING_DIM)
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        if config['d_model'] % 2 != 0:
            raise ValueError("d_model must be divisible by 2")
        self.w_q = torch.randn(config['d_model'], config['d_model'])
        self.w_k = torch.randn(config['d_model'], config['d_model'])
        self.w_v = torch.randn(config['d_model'], config['d_model'])
        self.register_buffer("tril", torch.tril(torch.ones(config['context_window'], config['context_window']))) # mask
        
        self.R = self.get_rotary_matrix() # [d_model x d_model x context_window]

    def get_rotary_matrix(self):
        context_window = self.config['context_window']
        d = self.config['d_model']
        R = torch.zeros((d, d, context_window))
        for m in range(context_window):
            for i in range(d//2):
                theta = 10000. ** (-2.*(i - 1) / d)
                R[2*i,2*i, m] = np.cos(m * theta)
                R[2*i,2*i+1, m] = - np.sin(m * theta)
                R[2*i+1,2*i, m] = np.sin(m * theta)
                R[2*i+1,2*i+1, m] = np.cos(m * theta)
        return R

    def forward(self, x):
        """
        x: [BATCH_SIZE x CONTEXT_WINDOW x EMBEDDING_DIM]
        out: [BATCH_SIZE x CONTEXT_WINDOW x EMBEDDING_DIM]
        """
        b = self.config['batch_size']
        d = self.config['d_model']
        m = self.config['context_window']

        q_rotated_query_weight = self.w_q @ self.R # [d x d x m]
        q = q_rotated_query_weight.view(m, d, d) @ x.view(m, d, b)
        q = q.view(b, m, d)

        k_rotated_query_weight = self.w_k @ self.R # [d x d x m]
        k = k_rotated_query_weight.view(m, d, d) @ x.view(m, d, b)
        k = k.view(b, m, d)
        
        B = (q @ k.transpose(1,2)) / np.sqrt(self.config['d_model'])
        mask = B.masked_fill(self.tril[:m, :m] == 0, float("-inf"))
        a = F.softmax(mask, dim=-1) # attention
        v = x @ self.w_v 
        out = a @ v
        return out


config = {
    "batch_size": 3,
    "context_window": 10,
    "d_model": 128,
}
r = RotarySelfAttention(config)
batch = torch.randn((config['batch_size'], config['context_window'], config['d_model']))

r(batch).shape

# ==== check ====
# s = batch[0,:,:]
# (r.w_q @ r.R)

torch.Size([3, 10, 128])

In [147]:
class MaskedRotaryNultiHeadedAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.heads = nn.ModuleList([
            MaskedRotarySelfAttentionHead(config) for _ in range(config["n_heads"])
        ])
        self.linear = nn.Linear(config["n_heads"] * config["d_model"], config["d_model"])
        self.dropout = nn.Dropout(0.1)


    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.dropout(self.linear(out))
        return out

config = {
    "n_heads": 8,
    "d_model": 128,
    "batch_size": 32,
    "context_window": 16,
}

batch = torch.randn((config['batch_size'], config['context_window'], config['d_model']))
m = MaskedRotaryNultiHeadedAttention(config)

m(batch).shape

torch.Size([32, 16, 128])

In [148]:
class SwiGLU(nn.Module):
    """
    Swish-Gated Linear Unit
    https://arxiv.org/pdf/2002.05202v1.pdf
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.linear = nn.Linear(config['d_model'], config['d_model'])
        self.beta = nn.Linear(config['d_model'], config['d_model'], bias=False)

    def forward(self, x):
        out = x * torch.sigmoid(self.beta(x)) + (1 - torch.sigmoid(self.beta(x))) * self.linear(x)
        return out

In [149]:
config = {
    "n_heads": 8,
    "d_model": 512,
    "batch_size": 3,
    "context_window": 10,
}

class LlamaLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.prenorm = nn.LayerNorm(config['d_model'])
        self.multihead = MaskedRotaryNultiHeadedAttention(config)
        self.ffn = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            SwiGLU(config),
        )
        self.postnorm = nn.LayerNorm(config['d_model'])

    def forward(self, x):
        mid = x + self.multihead(self.prenorm(x))
        out = mid + self.ffn(mid)

        return out

batch = torch.randn((config['batch_size'], config['context_window'], config['d_model']))
m = LlamaLayer(config)

In [187]:
from collections import OrderedDict

class Llama(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config['vocab_size'], config['d_model'])
        self.linear = nn.Linear(config['d_model'], config['vocab_size'])
        self.layers = nn.Sequential(OrderedDict([
            (f"llama_layer_{i}", LlamaLayer(config)) for i in range(config['n_layers'])
        ]))

    def forward(self, idx, targets=None):
        """
        idx: [batch_size, seq_len]
        targets: [batch_size, seq_len]
        """

        # logits = x # todo
        embeds = self.embedding(idx) # [batch_size, seq_len, hidden_size]
        logits = self.layers(embeds) # [batch_size, seq_len, vocab_size]
        logits = F.softmax(self.linear(logits), dim=-1) # [batch_size, seq_len, vocab_size]

        if targets is None:
            return logits
        
        else:
            loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.view(-1))
            return logits, loss

    def generate(self, idx, max_len=20):
        """
        idx: [batch_size, seq_len]
        """
        for i in range(max_len):
            logits = self.forward(idx)
            idx = torch.cat([idx, torch.argmax(logits, dim=-1)], dim=-1)

        return idx

In [188]:
import time

config = {
    "n_heads": 8,
    "d_model": 128,
    "batch_size": 32,
    "n_layers": 4,
    "context_window": 16,
    "vocab_size": len(vocab),
    "epochs": 1000,
    "log_interval": 10,
}

model = Llama(config)
sum([p.numel() for p in model.parameters()])


741185

In [196]:
idx = torch.tensor([encode("Hello")])
model(idx)

RuntimeError: shape '[16, 128, 32]' is invalid for input of size 640

In [186]:
from matplotlib import pyplot as plt

@torch.no_grad()  # don't compute gradients for this function
def evaluate_loss():
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = []
        for _ in range(10):
            xb, yb = get_batches(dataset, split, config['batch_size'], config['context_window'])
            _, loss = model(xb, yb)
            losses.append(loss.item())
        out[split] = np.mean(losses)
    model.train()
    return out

# train
optimizer = torch.optim.Adam(model.parameters(), betas=(.9, .95), weight_decay=.1, eps=1e-9, lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 300, eta_min=1e-5)

start_time = time.time()
losses = []

for epoch in range(config['epochs']):
    optimizer.zero_grad()
    
    xs, ys = get_batches(dataset, 'train', config['batch_size'], config['context_window'])
    _, loss = model(xs, targets=ys)
    loss.backward()

    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()

    if epoch % config['log_interval'] == 0:
        batch_time = time.time() - start_time
        losses += [evaluate_loss()]
        print(f"Epoch {epoch} | Loss {loss.item():.3f} | Time {batch_time:.3f} | ETA in seconds {batch_time * (config['epochs'] - epoch)/config['log_interval'] :.3f}")
        start_time = time.time()

        # inspecting to make sure no gradients are exploding
        # for name, param in model.named_parameters():
        #     if param.requires_grad:
        #         if param.grad is not None:
        #             print(f"{name} | grad norm {param.grad.norm().item():.3f} | param norm {param.norm().item():.3f}")



Epoch 0 | Loss 4.043 | Time 0.087 | ETA in seconds 8.721
Epoch 10 | Loss 4.051 | Time 0.887 | ETA in seconds 87.798
Epoch 20 | Loss 4.046 | Time 0.770 | ETA in seconds 75.492
Epoch 30 | Loss 4.056 | Time 0.820 | ETA in seconds 79.512


KeyboardInterrupt: 