In [15]:
import torch
from torch import nn, optim
import torch.nn.functional as F

import os
import time
import numpy as np
import tiktoken

# Model

### Layer Normalization

$y = \frac{x - E[x]}{\sqrt{E[(x - E[x])^2] + \epsilon}} \odot \gamma + \beta$

In [89]:
class LayerNorm(nn.Module):
    """Layer normalization over the trailing dimension(s) with an optional bias.

    Computes ``(x - mean) / sqrt(var + eps) * weight + bias`` where the
    statistics are taken over the last ``len(normalized_shape)`` dimensions.

    Args:
        normalized_shape: size (int or tuple) of the trailing dimension(s) to
            normalize over; also the shape of ``weight`` and ``bias``.
        bias: when False, no learnable shift is created and ``self.bias`` is None.
        eps: constant added to the variance for numerical stability.
    """

    def __init__(self, normalized_shape, bias=True, eps=1e-5):
        super().__init__()
        # Scale starts at 1 and shift at 0 so a freshly built layer is a plain
        # normalization. (The shift was previously initialized to ones, which
        # added a constant offset of 1 to every normalized activation.)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape)) if bias else None
        self.eps = eps

    def forward(self, x):
        """Return ``x`` normalized over the dimensions described by ``self.weight.shape``."""
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, self.eps)

### Self Attention

$[q_{i, j}, k_{i, j}, v_{i, j}] = x_{i, j}[W_q, W_k, W_v] + [b_q, b_k, b_v]$ 
where all $q, k, v, x, b$ are row vectors

$[q_{i, j}, k_{i, j}, v_{i, j}]$ are computed for every row vector $x_{i, j}$ (batch element $i$, position $j$) of the 3D tensor 
$x = \begin{bmatrix}
x_{1, 1} & \dots & x_{1, T} \\
\vdots & \ddots & \vdots \\
x_{B, 1} & \dots & x_{B, T}
\end{bmatrix}$ resulting in tensors $q, k, v$

$x$ has shape $(B, T, C)$ where $B$ is the batch size, $T$ is the sequence length,
and $C$ is the number of embedding dimensions

$q, k, v$ each have shape $(B, T, ND)$ where $B$ is the batch size, 
$T$ is the sequence length, $N$ is the number of attention heads, and $D$ is the
number of query/key/value dimensions per head (so $ND = C$)

$q, k, v$ are reshaped to $(B, N, T, D)$

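From here on, $q_{i, j}, k_{i, j}, v_{i, j}$ denote the $(T, D)$ matrices for batch element $i$ and head $j$

$a_{i, j} = \frac{q_{i, j} k_{i, j}^T}{\sqrt{D}}$, a $(T, T)$ matrix of attention scores

Causal mask: entry $(s, t)$ of $a_{i, j}$ is set to $-\infty$ for all $s < t$, so a position cannot attend to later positions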

$a_{i, j} = \text{softmax}(a_{i, j})$ where softmax is computed rowwise

$y_{i, j} = a_{i, j} v_{i, j}$

$y$ has shape $(B, N, T, D)$

$y$ is reshaped to $(B, T, ND)$, so $y_{i, j}$ is a row vector

$y_{i, j} = y_{i, j}W_p + b_p$ 

In [90]:
class SelfAttention(nn.Module):
    """Causal multi-head self-attention.

    Args:
        n_embed: embedding width C of the input; must be divisible by ``n_head``.
        n_head: number of attention heads N.
        blk_size: largest sequence length supported by the causal mask.
        dropout: dropout probability for attention weights and the output.
        bias: whether the two linear projections carry a bias term.
    """

    def __init__(self, n_embed, n_head, blk_size, dropout, bias=True):
        super().__init__()
        self.n_embed, self.n_head = n_embed, n_head
        assert n_embed % n_head == 0
        # per-head width D
        self.D = n_embed // n_head

        # one fused projection produces q, k and v; a second maps heads back to C
        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=bias)
        self.c_proj = nn.Linear(n_embed, n_embed, bias=bias)

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.res_dropout = nn.Dropout(dropout)

        self.attn_scale = 1.0 / np.sqrt(self.D)

        # lower-triangular matrix of ones, stored as a buffer named 'bias';
        # zeros mark the (row, col) pairs that must be masked out
        lower_tri = torch.tril(torch.ones(blk_size, blk_size))
        self.register_buffer('bias', lower_tri.view(1, 1, blk_size, blk_size))

    def forward(self, x):
        """Map ``x`` of shape (B, T, C) to an output of the same shape."""
        B, T, C = x.shape
        assert C == self.n_embed

        def to_heads(t):
            # (B, T, N*D) -> (B, N, T, D)
            return t.view(B, T, self.n_head, self.D).transpose(1, 2)

        q, k, v = (to_heads(t) for t in self.c_attn(x).split(self.n_embed, dim=2))

        # (B, N, T, D) @ (B, N, D, T) -> (B, N, T, T)
        scores = self.attn_scale * q @ k.transpose(2, 3)
        # positions where the mask is 0 get -inf so softmax assigns them zero weight
        scores.masked_fill_(self.bias[:, :, :T, :T] == 0, -torch.inf)
        weights = self.attn_dropout(F.softmax(scores, dim=3))

        # (B, N, T, T) @ (B, N, T, D) -> (B, N, T, D)
        per_head = weights @ v

        # concatenate the heads again: (B, N, T, D) -> (B, T, N*D)
        merged = per_head.transpose(1, 2).contiguous().view(B, T, self.n_embed)

        return self.res_dropout(self.c_proj(merged))

### MLP

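$h_{i, j} = \text{GELU}(x_{i, j} W_{c} + b_{c})$, where $W_c$ expands the $C$ embedding dimensions to $4C$

$y_{i, j} = h_{i, j} W_{p} + b_{p}$, where $W_p$ projects back to $C$ dimensions; dropout is applied to $y$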

In [91]:
class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x width, GELU, project back.

    Args:
        n_embed: input and output width.
        dropout: dropout probability applied to the output.
        bias: whether both linear layers carry a bias term.
    """

    def __init__(self, n_embed, dropout, bias=True):
        super().__init__()
        hidden_width = 4 * n_embed
        self.c_fc = nn.Linear(n_embed, hidden_width, bias=bias)
        self.c_proj = nn.Linear(hidden_width, n_embed, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward network along the last dimension of ``x``."""
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))

### Block

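Pre-norm transformer block: layer normalization is applied before self attention and before the MLP, and each sub-layer is wrapped in a residual connection

$x = x + \text{Attn}(\text{LN}_1(x))$

$x = x + \text{MLP}(\text{LN}_2(x))$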

In [92]:
class Block(nn.Module):
    """One transformer block: pre-norm self-attention and pre-norm MLP, each with a residual.

    Args:
        n_embed: embedding width.
        n_head: number of attention heads.
        blk_size: maximum sequence length handled by the attention layer.
        dropout: dropout probability passed to the attention and MLP sub-layers.
        bias: whether linear layers and layer norms carry a bias term.
    """

    def __init__(self, n_embed, n_head, blk_size, dropout, bias=True):
        super().__init__()
        self.ln_1 = LayerNorm(n_embed, bias=bias)
        self.attn = SelfAttention(
            n_embed=n_embed,
            n_head=n_head,
            blk_size=blk_size,
            dropout=dropout,
            bias=bias,
        )
        self.ln_2 = LayerNorm(n_embed, bias=bias)
        self.mlp = MLP(n_embed=n_embed, dropout=dropout, bias=bias)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out

### GPT

In [119]:
class GPT(nn.Module):
    """Decoder-only transformer language model with GPT-2 compatible parameter names.

    Args:
        n_embed: embedding width C.
        n_head: number of attention heads per block.
        blk_size: maximum sequence length (also the number of position embeddings).
        vocab_size: number of token ids.
        n_layer: number of transformer blocks.
        dropout: dropout probability used throughout the model.
        bias: whether linear layers and layer norms carry a bias term.
    """

    def __init__(self, n_embed, n_head, blk_size, 
                 vocab_size, n_layer, dropout, bias=True):
        
        super().__init__()
        self.blk_size = blk_size

        args = (n_embed, n_head, blk_size, dropout, bias)

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(vocab_size, n_embed),  # token embeddings
            wpe = nn.Embedding(blk_size, n_embed),    # position embeddings
            drop = nn.Dropout(dropout),
            h = nn.ModuleList([Block(*args) for _ in range(n_layer)]),
            ln_f = LayerNorm(n_embed, bias=bias),
        ))

        self.lm_head = nn.Linear(n_embed, vocab_size, bias=False)
        # weight tying: the token embedding and the output projection share one tensor
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)
    

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02^2) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)


    def forward(self, x_idx):
        """Return logits of shape (B, T, vocab_size) for token ids ``x_idx`` of shape (B, T).

        Raises:
            AssertionError: if T exceeds ``self.blk_size``.
        """
        device = x_idx.device
        _, T = x_idx.size()

        assert T <= self.blk_size, \
            f"Cannot forward sequence of length {T}, block size is only {self.blk_size}"
        
        pos = torch.arange(0, T, dtype=torch.long, device=device)

        tok_emb = self.transformer.wte(x_idx) # shape (B, T, C)
        pos_emb = self.transformer.wpe(pos) # shape (T, C)

        # (B, T, C) + (T, C) = (B, T, C)
        # elementwise addition for each batch
        x = self.transformer.drop(tok_emb + pos_emb)
        for blk in self.transformer.h:
            x = blk(x)
        x = self.transformer.ln_f(x)
        x = self.lm_head(x)
        return x


    def configure_optimizers(self, lr, betas, weight_decay):
        """Build an AdamW optimizer with weight decay applied only to >=2-D parameters.

        Args:
            lr: learning rate for all parameter groups.
            betas: (beta1, beta2) tuple passed to AdamW.
            weight_decay: decay coefficient for the decayed group.

        Returns:
            The configured ``optim.AdamW`` instance; group 0 is decayed, group 1 is not.
        """
        params = [p for p in self.parameters() if p.requires_grad]
        # any parameter that is at least 2D is weight decayed, all others are not:
        # weight tensors in matmuls and embeddings have weight decay,
        # biases and layernorms don't have weight decay
        decay_params = [p for p in params if p.dim() >= 2]
        nodecay_params = [p for p in params if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]

        optimizer = optim.AdamW(optim_groups, lr=lr, betas=betas)
        return optimizer
    

    @torch.no_grad()
    def generate(self, x_idx, max_new_tokens, temperature=1.0):
        """Autoregressively sample ``max_new_tokens`` tokens after ``x_idx``.

        Takes a conditioning sequence of indices ``x_idx`` (int64 tensor of shape
        (B, T)) and extends it one token at a time, feeding each prediction back
        into the model. Most likely you'll want to be in ``model.eval()`` mode
        when calling this.

        Args:
            x_idx: conditioning token ids, shape (B, T).
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this value before the softmax.

        Returns:
            Tensor of shape (B, T + max_new_tokens) holding the prompt and the samples.
        """
        for _ in range(max_new_tokens):
            # the model can only attend over the last blk_size tokens
            if x_idx.shape[1] <= self.blk_size:
                x_idx_cropped = x_idx 
            else:
                x_idx_cropped = x_idx[:, -self.blk_size:]

            logits = self(x_idx_cropped)
            # keep only the prediction for the final position
            logits = logits[:, -1, :] / temperature

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            x_idx = torch.cat((x_idx, idx_next), dim=1)

        return x_idx    
    

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        """Build a model and copy in the weights of a Hugging Face GPT-2 checkpoint.

        Args:
            model_type: one of 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'.
            override_args: optional dict; only the key 'dropout' is accepted.

        Returns:
            An instance of ``cls`` holding the pretrained weights.

        Raises:
            AssertionError: on an unknown ``model_type``, an unsupported override
                key, or a mismatch in parameter count/shape between the two models.
        """
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        override_args = override_args or {} # default to empty dict
        # only dropout can be overridden
        assert all(k == 'dropout' for k in override_args)
        # imported lazily so the rest of the class works without `transformers`
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embed are determined from model_type
        config_args = {
            "gpt2":         dict(n_layer=12, n_head=12, n_embed=768),  # 124M params
            "gpt2-medium":  dict(n_layer=24, n_head=16, n_embed=1024), # 350M params
            "gpt2-large":   dict(n_layer=36, n_head=20, n_embed=1280), # 774M params
            "gpt2-xl":      dict(n_layer=48, n_head=25, n_embed=1600), # 1558M params
        }[model_type]
        
        print("forcing vocab_size=50257, blk_size=1024, bias=True")
        config_args["vocab_size"] = 50257 # always 50257 for GPT model checkpoints
        config_args["blk_size"] = 1024 # always 1024 for GPT model checkpoints
        config_args["bias"] = True # always True for GPT model checkpoints
        
        if "dropout" in override_args:
            print(f"overriding dropout rate to {override_args['dropout']}")
            config_args["dropout"] = override_args["dropout"]
        else:
            config_args["dropout"] = 0.0

        # use cls (not the hard-coded class name) so subclasses get their own type back
        model = cls(**config_args)
        sd = model.state_dict()
        sd_keys = sd.keys()
        # discard this mask / buffer, not a param
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] 

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # drop the attention mask buffers on the Hugging Face side as well
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.masked_bias")]
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.bias")]
        # these checkpoint weights have the reverse shape of the matching
        # nn.Linear weight here, so they are transposed while copying
        transposed = [
            "attn.c_attn.weight", 
            "attn.c_proj.weight", 
            "mlp.c_fc.weight", 
            "mlp.c_proj.weight",
        ]

        assert len(sd_keys_hf) == len(sd_keys), \
            f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].T)
            else:
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

# Training

In [85]:
# I/O and logging
out_dir = "out"  # directory checkpoints are written to
eval_interval = 1000  # evaluate (and checkpoint) every this many iterations
log_interval = 100  # print the training loss every this many iterations
eval_iters = 50  # number of batches averaged per split when estimating the loss
always_save_checkpoint = True  # save at every evaluation, not only on improvement

data_dir = "data"  # directory holding train.bin / val.bin

batch_size = 4

# model
n_layer = 12
n_head = 12
n_embed = 768
blk_size = 1024
vocab_size = 50257
dropout = 0.0
bias = True

# adamw optimizer
lr = 6e-4 # max learning rate (read by get_lr as the peak of the schedule)
max_iters = 1000000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95

# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 2000 # how many steps to warm up for
lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
min_lr = 6e-5 # minimum learning rate, should be ~= lr/10 per Chinchilla

# Prefer Apple's MPS backend, but fall back to CUDA or the CPU so the notebook
# also runs on machines where MPS is not available.
_mps = getattr(torch.backends, "mps", None)
if _mps is not None and _mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
dtype = torch.float32
compile = False  # NOTE: shadows the builtin `compile`; not read anywhere in the cells shown

In [8]:
def get_batch(split):
    """Sample a random batch of token windows and their next-token targets.

    Reads ``train.bin`` when ``split == "train"`` and ``val.bin`` for any other
    value, both interpreted as flat uint16 arrays inside ``data_dir``.

    Returns:
        ``(x, y)``: two int64 tensors of shape (batch_size, blk_size) on
        ``device``; ``y`` holds the same windows as ``x`` shifted one token ahead.
    """
    fname = "train.bin" if split == "train" else "val.bin"
    tokens = np.memmap(os.path.join(data_dir, fname), dtype=np.uint16, mode="r")

    # random window starts; the upper bound leaves room for the shifted target window
    starts = torch.randint(len(tokens) - blk_size, (batch_size,))

    def stack_windows(shift):
        windows = []
        for start in starts:
            lo = start + shift
            chunk = tokens[lo:lo + blk_size].astype(np.int64)
            windows.append(torch.from_numpy(chunk))
        return torch.stack(windows).to(device)

    return stack_windows(0), stack_windows(1)

In [9]:
# Build the model from the hyperparameters configured above and move it to the target device.
model = GPT(
    n_embed=n_embed,
    n_head=n_head,
    blk_size=blk_size,
    vocab_size=vocab_size,
    n_layer=n_layer,
    dropout=dropout,
    bias=bias,
).to(device)

In [10]:
# AdamW over the model's parameters, using the learning rate, betas and weight decay configured above.
optimizer = model.configure_optimizers(
    lr=lr,
    betas=(beta1, beta2),
    weight_decay=weight_decay,
)

In [11]:
# checkpoint = torch.load('out/checkpoint.pt')
# model.load_state_dict(checkpoint['model'])
# optimizer.load_state_dict(checkpoint['optimizer'])

In [12]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean cross-entropy loss on the train and val splits.

    Puts ``model`` in eval mode, averages the loss of ``eval_iters`` random
    batches per split, then switches ``model`` back to train mode.

    Returns:
        Dict with keys "train" and "val", each a 0-d tensor with the mean loss.
    """
    def mean_split_loss(split):
        per_batch = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = get_batch(split)
            logits = model(inputs)
            # flatten (B, T, V) logits and (B, T) targets for cross_entropy
            flat_logits = logits.view(-1, logits.shape[-1])
            per_batch[step] = F.cross_entropy(flat_logits, targets.view(-1)).item()
        return per_batch.mean()

    model.eval()
    out = {split: mean_split_loss(split) for split in ("train", "val")}
    model.train()
    return out

In [13]:
def get_lr(iter_num):
    """Learning rate for iteration ``iter_num``: linear warmup, cosine decay, then a floor.

    Reads the module-level ``lr`` (peak rate), ``min_lr``, ``warmup_iters`` and
    ``lr_decay_iters``.
    """
    # phase 1: ramp linearly from 0 to lr over the first warmup_iters steps
    in_warmup = iter_num < warmup_iters
    if in_warmup:
        return lr * iter_num / warmup_iters

    # phase 3: past the decay horizon the rate stays at its minimum
    past_decay = iter_num > lr_decay_iters
    if past_decay:
        return min_lr

    # phase 2: cosine interpolation from lr down to min_lr
    decay_span = lr_decay_iters - warmup_iters
    progress = (iter_num - warmup_iters) / decay_span
    assert 0 <= progress <= 1
    cosine_weight = 0.5 * (1.0 + np.cos(np.pi * progress))  # goes from 1 to 0
    return min_lr + cosine_weight * (lr - min_lr)

In [14]:
# Training loop: one optimizer step per iteration, with periodic logging,
# evaluation and checkpointing, until max_iters iterations have run.
iter_num = 1
best_val_loss = 1e9

# torch.save does not create missing directories, so make sure the target exists
os.makedirs(out_dir, exist_ok=True)

x, y = get_batch("train")
t0 = time.time()

while True:
    # determine and set the learning rate for this iteration
    # NOTE: this must not be stored in `lr`. get_lr reads the global `lr` as the
    # peak learning rate, so overwriting it here would shrink the schedule on
    # every step until the learning rate collapsed to ~0.
    cur_lr = get_lr(iter_num) if decay_lr else lr
    for param_group in optimizer.param_groups:
        param_group["lr"] = cur_lr
    
    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1))
    
    # fetch the next batch before the backward pass of the current one
    x, y = get_batch("train")

    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0:
        print(f"iter {iter_num}: loss {loss.item():.4f}, time {dt:.2f}s")
    
    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0:
        losses = estimate_loss()
        # single quotes inside the f-string keep this valid on Python < 3.12
        print(f"train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        
        if losses["val"] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses["val"]
            if iter_num > 0:
                checkpoint = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "iter_num": iter_num,
                    "best_val_loss": best_val_loss,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, "checkpoint.pt"))
   
    iter_num += 1
    if iter_num > max_iters:
        break

# Prompt

In [124]:
# Load the pretrained "gpt2-xl" weights and move the model to `device`.
# This rebinds `model`, replacing the model object created in the training section.
model = GPT.from_pretrained("gpt2-xl").to(device)

loading weights from pretrained gpt: gpt2-xl
forcing vocab_size=50257, blk_size=1024, bias=True


config.json:   0%|          | 0.00/689 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/6.43G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [126]:
# Tokenizer (tiktoken's "gpt2" encoding) used below to encode prompts and decode generated ids.
encoder = tiktoken.get_encoding("gpt2")

In [127]:
@torch.no_grad()
def generate(prompt, max_new_tokens=500, temperature=1.0):
    """Continue the text ``prompt`` with tokens sampled from ``model``.

    Args:
        prompt: text to condition on.
        max_new_tokens: number of tokens to sample.
        temperature: sampling temperature forwarded to ``model.generate``.

    Returns:
        The decoded string containing the prompt followed by the sampled tokens.
    """
    token_ids = encoder.encode(prompt)
    # add a leading batch dimension of size 1
    batch = torch.tensor(token_ids, dtype=torch.int64, device=device).unsqueeze(0)
    completed = model.generate(batch, max_new_tokens, temperature)
    return encoder.decode(completed[0].tolist())

In [131]:
# Text the model is asked to continue.
prompt = 'plants are very important, they'

In [132]:
# Sample 10 new tokens after the prompt; sampling is random, so the output varies between runs.
generate(prompt, max_new_tokens=10)

"plants are very important, they're the basis of your move, so we need"