In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F

import os
import time
import numpy as np

# Model

### Layer Normalization

$y = \frac{x - E[x]}{\sqrt{E[(x - E[x])^2] + \epsilon}} \odot \gamma + \beta$

In [2]:
class LayerNorm(nn.Module):
    """Layer normalization with a learnable scale and an optional learnable shift.

    Normalizes over the trailing ``normalized_shape`` dimensions:
    ``y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta``.

    Args:
        normalized_shape: size (int or shape tuple) of the trailing dimensions
            to normalize over.
        bias: if False, no shift is learned and ``beta`` is None.
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, normalized_shape, bias=True, eps=1e-5):
        super().__init__()
        # gamma starts at 1 and beta at 0 so that, at initialisation, the layer
        # is a pure normalization (same defaults as torch.nn.LayerNorm).
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape)) if bias else None
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its trailing ``gamma.shape`` dimensions."""
        return F.layer_norm(x, self.gamma.shape, self.gamma, self.beta, self.eps)

### Self Attention

$[q_{i, j}, k_{i, j}, v_{i, j}] = x_{i, j}[W_q, W_k, W_v] + [b_q, b_k, b_v]$ 
where all $q, k, v, x, b$ are row vectors

$[q_{i, j}, k_{i, j}, v_{i, j}]$ are computed for every $x_{i, j}$ in the 3D tensor 
$x = \begin{bmatrix}
x_{1, 1} & \dots & x_{1, T} \\
\vdots & \ddots & \vdots \\
x_{B, 1} & \dots & x_{B, T}
\end{bmatrix}$ resulting in tensors $q, k, v$

$x$ has shape $(B, T, C)$ where $B$ is the batch size, $T$ is the sequence length,
and $C$ is the number of embedding dimensions

$q, k, v$ each have shape $(B, T, ND)$ where $B$ is the batch size, 
$T$ is the sequence length, $N$ is the number of attention heads, and $D = C / N$ is the
number of query/key/value dimensions per head (so $ND = C$)

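$q, k, v$ are reshaped to $(B, N, T, D)$; from here on $q_{i, j}, k_{i, j}, v_{i, j}$ denote the $T \times D$ matrices for batch element $i$ and head $j$

$a_{i, j} = \frac{q_{i, j} k_{i, j}^T}{\sqrt{D}}$, a $T \times T$ matrix of attention scores

$(a_{i, j})_{s, t} = -\infty$ for all $s < t$ (causal mask: position $s$ may not attend to a later position $t$)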

$a_{i, j} = \text{softmax}(a_{i, j})$ where softmax is computed rowwise

$y_{i, j} = a_{i, j} v_{i, j}$

$y$ has shape $(B, N, T, D)$

$y$ is reshaped to $(B, T, ND)$, so $y_{i, j}$ is a row vector

$y_{i, j} = y_{i, j}W_p + b_p$ 

In [3]:
class SelfAttention(nn.Module):
    """Multi-head causal self-attention.

    Args:
        C: embedding width of the input and output; must be divisible by N.
        N: number of attention heads.
        blk_size: maximum sequence length covered by the causal mask.
        dropout: dropout probability for the attention weights and the output.
        bias: whether the two linear projections carry bias terms.
    """

    def __init__(self, C, N, blk_size, dropout, bias=True):
        super().__init__()
        assert C % N == 0
        self.C, self.N = C, N
        self.D = C // N  # width of each head

        # a single fused projection produces q, k and v side by side
        self.c_attn = nn.Linear(C, 3 * C, bias=bias)
        self.c_proj = nn.Linear(C, C, bias=bias)

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.res_dropout = nn.Dropout(dropout)

        self.attn_scale = 1.0 / np.sqrt(self.D)

        # lower-triangular ones, shaped (1, 1, blk_size, blk_size) so it
        # broadcasts over batch and heads; zeros mark future positions
        causal = torch.ones(blk_size, blk_size).tril()
        self.register_buffer('bias', causal[None, None])

    def forward(self, x):
        """Attend causally over ``x`` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.shape
        assert C == self.C

        def to_heads(t):
            # (B, T, C) -> (B, N, T, D)
            return t.view(B, T, self.N, self.D).transpose(1, 2)

        q, k, v = map(to_heads, self.c_attn(x).split(self.C, dim=2))

        # scaled scores, (B, N, T, D) @ (B, N, D, T) -> (B, N, T, T)
        scores = (q * self.attn_scale) @ k.transpose(-2, -1)
        future = self.bias[:, :, :T, :T] == 0
        scores = scores.masked_fill(future, -torch.inf)
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # weighted values, (B, N, T, T) @ (B, N, T, D) -> (B, N, T, D)
        heads_out = weights @ v
        # put heads back side by side: (B, N, T, D) -> (B, T, C)
        merged = heads_out.transpose(1, 2).contiguous().view(B, T, C)
        return self.res_dropout(self.c_proj(merged))

### MLP

$h_{i, j} = \text{GELU}(x_{i, j} W_{fc} + b_{fc})$, where $W_{fc}$ has shape $(C, 4C)$

$y_{i, j} = h_{i, j} W_{p} + b_{p}$, where $W_{p}$ has shape $(4C, C)$

In [4]:
class MLP(nn.Module):
    """Position-wise feed-forward sub-layer of a transformer block.

    Expands each C-dimensional vector to 4*C features, applies GELU,
    projects back to C features and applies dropout.

    Args:
        C: input/output width.
        dropout: dropout probability applied to the output.
        bias: whether the linear layers carry bias terms.
    """

    def __init__(self, C, dropout, bias=True):
        super().__init__()
        hidden_width = 4 * C  # expansion factor of 4
        self.c_fc = nn.Linear(C, hidden_width, bias=bias)
        self.c_proj = nn.Linear(hidden_width, C, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply Linear -> GELU -> Linear -> Dropout along the last dimension."""
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))

### Block

A pre-norm residual block: $x = x + \text{Attn}(\text{LN}_1(x))$, followed by $x = x + \text{MLP}(\text{LN}_2(x))$

In [5]:
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each in a residual branch.

    Args:
        C: embedding width.
        N: number of attention heads.
        blk_size: maximum sequence length for the causal mask.
        dropout: dropout probability passed to the sub-layers.
        bias: whether the sub-layers carry bias terms.
    """

    def __init__(self, C, N, blk_size, dropout, bias=True):
        super().__init__()
        # keep this registration order: it determines the state_dict key order
        self.layer_norm1 = LayerNorm(C, bias=bias)
        self.attn = SelfAttention(C, N, blk_size, dropout, bias=bias)
        self.layer_norm2 = LayerNorm(C, bias=bias)
        self.mlp = MLP(C, dropout, bias=bias)

    def forward(self, x):
        """Return ``x`` updated by the attention and MLP residual branches."""
        after_attn = x + self.attn(self.layer_norm1(x))
        return after_attn + self.mlp(self.layer_norm2(after_attn))

### GPT

In [6]:
class GPT(nn.Module):
    """Decoder-only GPT language model.

    Token embeddings and learned position embeddings are summed, passed through
    ``n_layer`` ``Block``s and a final ``LayerNorm``, then projected to
    vocabulary logits by ``fc``. ``fc`` and the token embedding ``wte`` share
    one weight matrix (weight tying).

    Args:
        C: embedding width.
        N: number of attention heads in each block.
        blk_size: maximum context length; also the size of the position table.
        vocab_size: number of token ids.
        n_layer: number of transformer blocks.
        dropout: dropout probability for the embeddings and inside the blocks.
        bias: whether linear and layer-norm layers carry bias terms.
    """

    def __init__(self, C, N, blk_size, vocab_size, n_layer, dropout, bias=True):
        super().__init__()
        self.blk_size = blk_size
        self.wte = nn.Embedding(vocab_size, C)
        self.wpe = nn.Embedding(blk_size, C)
        self.dropout = nn.Dropout(dropout)

        self.blks = nn.ModuleList(
            [Block(C, N, blk_size, dropout, bias) for _ in range(n_layer)])

        self.layer_norm = LayerNorm(C, bias=bias)
        self.fc = nn.Linear(C, vocab_size, bias=False)
        # weight tying: the token embedding reuses the output head's Parameter
        self.wte.weight = self.fc.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x_idx):
        """Return logits of shape (B, T, vocab_size) for token ids ``x_idx`` of shape (B, T).

        Raises:
            AssertionError: if T exceeds ``blk_size``.
        """
        device = x_idx.device
        _, T = x_idx.size()

        assert T <= self.blk_size, \
            f'Cannot forward sequence of length {T}, block size is only {self.blk_size}'

        pos = torch.arange(0, T, dtype=torch.long, device=device)

        tok_emb = self.wte(x_idx) # shape (B, T, C)
        pos_emb = self.wpe(pos) # shape (T, C)

        # (B, T, C) + (T, C) = (B, T, C): the position table broadcasts over the batch
        x = self.dropout(tok_emb + pos_emb)
        for blk in self.blks:
            x = blk(x)
        x = self.layer_norm(x)
        x = self.fc(x)
        return x

    def configure_optimizers(self, lr, betas, weight_decay):
        """Build an AdamW optimizer that decays only parameters with >= 2 dims.

        Matmul and embedding weights (2-D) get ``weight_decay``; 1-D tensors
        such as biases and layer-norm scales/shifts get no decay.
        """
        params = [p for p in self.parameters() if p.requires_grad]

        decay_params = [p for p in params if p.dim() >= 2]
        nodecay_params = [p for p in params if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0},
        ]

        optimizer = optim.AdamW(optim_groups, lr=lr, betas=betas)
        return optimizer

    @torch.no_grad()
    def generate(self, x_idx, max_new_tokens, temperature=1.0, top_k=None):
        '''
        Take a conditioning sequence of indices x_idx (LongTensor of shape (B, T)) and
        complete the sequence max_new_tokens times, feeding the predictions back into the
        model each time. Most likely you'll want to make sure to be in model.eval() mode
        of operation for this.

        Args:
            x_idx: LongTensor of shape (B, T) with the conditioning tokens.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this before sampling.
            top_k: if given, sample only among the top_k most likely tokens.

        Returns:
            LongTensor of shape (B, T + max_new_tokens).
        '''
        for _ in range(max_new_tokens):
            # the model only sees the last blk_size tokens
            if x_idx.size(1) <= self.blk_size:
                x_idx_cropped = x_idx
            else:
                x_idx_cropped = x_idx[:, -self.blk_size:]

            # forward() returns a single logits tensor of shape (B, T, vocab_size)
            logits = self(x_idx_cropped)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                kth_best, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < kth_best[:, [-1]]] = -float('inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            x_idx = torch.cat((x_idx, idx_next), dim=1)

        return x_idx

# Training

In [7]:
out_dir = 'out'
eval_interval = 1000
log_interval = 100
eval_iters = 50
always_save_checkpoint = True

data_dir = 'data'

batch_size = 4

# model
n_layer = 12
n_head = 12
n_embd = 768
blk_size = 1024
vocab_size = 50304
dropout = 0.0
bias = False

# adamw optimizer
lr = 6e-4 # max learning rate
max_iters = 10000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95

# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 2000 # how many steps to warm up for
# decay over the whole run: should be ~= max_iters per Chinchilla
# (a much larger value leaves the cosine schedule almost flat)
lr_decay_iters = max_iters
min_lr = 6e-5 # minimum learning rate, should be ~= lr/10 per Chinchilla

# prefer Apple MPS (the original target), then CUDA, then CPU
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

dtype = torch.float32  # NOTE(review): not read by any cell shown here
compile = False  # NOTE(review): shadows the builtin compile(); not read by any cell shown here

# fix the RNG so weight init and batch sampling are reproducible
seed = 1337
torch.manual_seed(seed)

In [8]:
def get_batch(split):
    """Sample a random batch of (input, target) token blocks.

    Reads ``train.bin`` from ``data_dir`` when ``split == 'train'`` and
    ``val.bin`` for any other value. Tokens are stored as uint16 and are
    returned as int64 tensors of shape (batch_size, blk_size) on ``device``.
    The targets are the inputs shifted one token to the right.
    """
    fname = 'train.bin' if split == 'train' else 'val.bin'
    # the memmap is opened afresh on every call
    tokens = np.memmap(os.path.join(data_dir, fname), dtype=np.uint16, mode='r')

    starts = torch.randint(len(tokens) - blk_size, (batch_size,))
    inputs = [torch.from_numpy(tokens[s:s + blk_size].astype(np.int64)) for s in starts]
    targets = [torch.from_numpy(tokens[s + 1:s + 1 + blk_size].astype(np.int64)) for s in starts]
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

In [9]:
# build the model from the config cell and move it to the training device
model = GPT(
    C=n_embd,
    N=n_head,
    blk_size=blk_size,
    vocab_size=vocab_size,
    n_layer=n_layer,
    dropout=dropout,
    bias=bias,
).to(device)

In [10]:
# AdamW with weight decay applied only to >= 2-D parameters
optimizer = model.configure_optimizers(
    lr=lr,
    betas=(beta1, beta2),
    weight_decay=weight_decay,
)

In [11]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean cross-entropy loss on the train and val splits.

    Averages the loss over ``eval_iters`` random batches per split, with the
    model in eval mode. The model's previous train/eval mode is restored
    afterwards, even if evaluation is interrupted or raises.

    Returns:
        dict mapping 'train' and 'val' to the mean loss as a Python float.
    """
    was_training = model.training
    model.eval()
    out = {}
    try:
        for split in ('train', 'val'):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                x, y = get_batch(split)
                logits = model(x)
                # flatten (B, T, V) -> (B*T, V) and (B, T) -> (B*T,)
                losses[k] = F.cross_entropy(logits.flatten(0, 1), y.flatten()).item()
            out[split] = losses.mean().item()
    finally:
        model.train(was_training)
    return out

In [12]:
def get_lr(iter_num):
    """Learning rate for ``iter_num``: linear warmup, then cosine decay to ``min_lr``.

    Reads the globals ``lr`` (the peak rate), ``min_lr``, ``warmup_iters`` and
    ``lr_decay_iters``. The global ``lr`` must therefore stay equal to the peak
    rate and must not be overwritten with the scheduled value.
    """
    if iter_num < warmup_iters:
        # linear ramp from 0 up to lr
        return lr * iter_num / warmup_iters

    if iter_num <= lr_decay_iters:
        # cosine interpolation from lr (progress 0) down to min_lr (progress 1)
        progress = (iter_num - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= progress <= 1
        weight = 0.5 * (1.0 + np.cos(np.pi * progress))
        return min_lr + weight * (lr - min_lr)

    # past the decay horizon: hold at the floor
    return min_lr

In [13]:
iter_num = 1
best_val_loss = 1e9

x, y = get_batch('train')
t0 = time.time()

while True:
    # Determine and set the learning rate for this iteration. Use a separate
    # name: get_lr() reads the global `lr` as the *peak* rate, so assigning
    # the scheduled value back to `lr` would compound the schedule every step
    # and drive the learning rate toward zero.
    lr_now = get_lr(iter_num) if decay_lr else lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_now

    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1))

    # fetch the next batch before the backward pass
    x, y = get_batch('train')

    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0:
        print(f'iter {iter_num}: loss {loss.item():.4f}, time {dt:.2f}s')

    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0:
        losses = estimate_loss()
        print(f"train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        val_loss = float(losses['val'])
        improved = val_loss < best_val_loss
        if improved:
            # only track a genuine improvement, even when always saving
            best_val_loss = val_loss
        if improved or always_save_checkpoint:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                # constructor arguments, so the model can be rebuilt on load
                'model_args': {
                    'C': n_embd,
                    'N': n_head,
                    'blk_size': blk_size,
                    'vocab_size': vocab_size,
                    'n_layer': n_layer,
                    'dropout': dropout,
                    'bias': bias,
                },
                'iter_num': iter_num,
                'best_val_loss': best_val_loss,
            }
            os.makedirs(out_dir, exist_ok=True)  # torch.save does not create directories
            print(f'saving checkpoint to {out_dir}')
            torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))

    iter_num += 1
    if iter_num > max_iters:
        break

Note: the `time` printed per logged iteration was about 1.69–1.70 s on this run.

KeyboardInterrupt: 