In [1]:
import numpy as np
from mlx import core as mx, nn, optimizers

import os
import time
from functools import partial
import tiktoken

# Model

### Layer Normalization

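$y = \frac{x - E[x]}{\sqrt{E[(x - E[x])^2] + \epsilon}} \odot \gamma + \beta$

where the expectations are taken over the last (embedding) dimension of $x$, and $\gamma$ (`weight`) and $\beta$ (`bias`) are learned per-feature scale and shift vectors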

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization with a learnable scale and an optional learnable shift.

    Thin wrapper that owns the affine parameters and delegates the actual
    normalization to ``mx.fast.layer_norm``.

    Args:
        normalized_shape: Shape of the ``weight`` / ``bias`` parameters;
            presumably the size of the input's embedding axis (the callers
            in this notebook pass ``n_embed``).
        bias: If ``False`` no shift parameter is created and ``None`` is
            passed to ``mx.fast.layer_norm`` instead.
        eps: Small constant added inside the square root for stability.
    """

    def __init__(self, normalized_shape, bias=True, eps=1e-5):
        super().__init__()
        # Identity affine transform at initialization: scale 1, shift 0.
        # (A shift of ones would offset every normalized activation by +1,
        # and nothing later in the model re-initializes LayerNorm params.)
        self.weight = mx.ones(normalized_shape)
        self.bias = mx.zeros(normalized_shape) if bias else None
        self.eps = eps

    def __call__(self, x):
        return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)

### Self Attention

$[q_{i, j}, k_{i, j}, v_{i, j}] = x_{i, j}[W_q, W_k, W_v] + [b_q, b_k, b_v]$ 
where all $q, k, v, x, b$ are row vectors

$[q_{i, j}, k_{i, j}, v_{i, j}]$ are computed for every $x_{i, j}$ in the 3D tensor 
$x = \begin{bmatrix}
x_{1, 1} & \dots & x_{1, T} \\
\vdots & \ddots & \vdots \\
x_{B, 1} & \dots & x_{B, T}
\end{bmatrix}$ resulting in tensors $q, k, v$

$x$ has shape $(B, T, C)$ where $B$ is the batch size, $T$ is the sequence length,
and $C$ is the number of embedding dimensions

$q, k, v$ have shape $(B, T, ND)$ where $B$ is the batch size, 
$T$ is the sequence length, $N$ is the number of attention heads, and $D = C / N$ is the
number of query/key/value dimensions per head (so $ND = C$)

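$q, k, v$ are reshaped to $(B, N, T, D)$, so from here on $q_{i, j}$, $k_{i, j}$, $v_{i, j}$ denote the
$T \times D$ matrices for batch element $i$ and head $j$

$a_{i, j} = \frac{q_{i, j} k_{i, j}^T}{\sqrt{D}}$, a $T \times T$ matrix of attention scores

$(a_{i, j})_{s, t} = -\infty$ for all $s < t$, so position $s$ cannot attend to any later position $t$ (causal mask)

$a_{i, j} = \text{softmax}(a_{i, j})$ where softmax is computed rowwise

$y_{i, j} = a_{i, j} v_{i, j}$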

$y$ has shape $(B, N, T, D)$

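$y$ is reshaped back to $(B, T, ND) = (B, T, C)$, which concatenates the head outputs, so $y_{i, j}$ is
again the row vector for batch element $i$ at position $j$

$y_{i, j} = y_{i, j}W_p + b_p$, followed by dropout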

In [None]:
class SelfAttention(nn.Module):
    """Causal multi-head self-attention.

    A single linear layer projects the input to queries, keys and values;
    masked scaled dot-product attention runs per head; the head outputs are
    concatenated, projected back to ``n_embed`` and passed through dropout.

    Args:
        n_embed: Embedding size ``C`` of the input; must be divisible by
            ``n_head``.
        n_head: Number of heads ``N``; each head has ``D = n_embed // n_head``
            dimensions.
        mask: Additive attention mask; its top-left ``(T, T)`` slice is used
            for a length-``T`` input, so it must be at least that large.
            ``-inf`` entries block attention.
        dropout: Dropout probability applied to the projected output.
        bias: Whether the linear layers have bias terms.
    """

    def __init__(
        self, n_embed: int, n_head: int, 
        mask: mx.array, dropout: float, bias=True
    ) -> None:
        super().__init__()
        assert n_embed % n_head == 0, "n_embed must be divisible by n_head"
        self.n_embed = n_embed
        self.n_head = n_head
        self.D = n_embed // n_head

        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=bias)
        self.c_proj = nn.Linear(n_embed, n_embed, bias=bias)
        self.dropout = nn.Dropout(dropout)

        # Underscore-prefixed attributes are excluded from an MLX module's
        # parameters. Stored as plain ``self.mask`` the mask would be a
        # trainable parameter: the optimizer would update it (turning the
        # causal mask into a learned bias), ``set_dtype`` would cast it and
        # ``save_weights`` would write it into every checkpoint.
        self._mask = mask
        self.scale = 1.0 / np.sqrt(self.D)

    def __call__(self, x: mx.array):
        B, T, n_embed = x.shape
        assert n_embed == self.n_embed

        # One projection produces q, k and v side by side on the last axis.
        q, k, v = mx.split(self.c_attn(x), 3, axis=2)

        # (B, T, N * D) -> (B, N, T, D) so attention runs independently per head.
        q = q.reshape((B, T, self.n_head, self.D)).transpose((0, 2, 1, 3))
        k = k.reshape((B, T, self.n_head, self.D)).transpose((0, 2, 1, 3))
        v = v.reshape((B, T, self.n_head, self.D)).transpose((0, 2, 1, 3))

        # The mask is not a parameter, so it is not cast along with the model
        # (e.g. to bfloat16); match the activation dtype explicitly.
        mask = self._mask[:T, :T].astype(q.dtype)
        y = mx.fast.scaled_dot_product_attention(
            q, k, v,
            mask=mask,
            scale=self.scale,
        )

        # (B, N, T, D) -> (B, T, N * D): concatenate the head outputs.
        y = y.transpose((0, 2, 1, 3)).reshape((B, T, self.n_embed))
        y = self.c_proj(y)
        y = self.dropout(y)
        return y

### MLP

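$h_{i, j} = \text{GELU}(x_{i, j} W_{c} + b_{c})$, where $W_c$ has shape $(C, 4C)$

$y_{i, j} = h_{i, j} W_{p} + b_{p}$, where $W_p$ has shape $(4C, C)$, followed by dropout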

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network.

    Linear(n_embed -> 4 * n_embed), GELU, Linear(4 * n_embed -> n_embed),
    then dropout on the output.
    """

    def __init__(self, n_embed, dropout, bias=True):
        super().__init__()
        hidden_dim = 4 * n_embed  # 4x expansion of the hidden layer
        self.c_fc = nn.Linear(n_embed, hidden_dim, bias=bias)
        self.c_proj = nn.Linear(hidden_dim, n_embed, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, x):
        activated = nn.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(activated))

### Block

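Composition of layer normalization, self attention, and MLP as two pre-norm residual sublayers:
$x = x + \text{SelfAttention}(\text{LN}_1(x))$, then $x = x + \text{MLP}(\text{LN}_2(x))$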

In [None]:
class Block(nn.Module):
    """Pre-norm transformer block.

    Applies ``x + attn(ln_1(x))`` and then ``x + mlp(ln_2(x))`` to the result.
    """

    def __init__(
        self, n_embed: int, n_head: int, 
        mask: mx.array, dropout: float, bias=True,
    ) -> None:
        super().__init__()
        self.ln_1 = LayerNorm(n_embed, bias=bias)
        self.attn = SelfAttention(
            n_embed=n_embed, n_head=n_head, mask=mask, dropout=dropout, bias=bias,
        )
        self.ln_2 = LayerNorm(n_embed, bias=bias)
        self.mlp = MLP(n_embed=n_embed, dropout=dropout, bias=bias)

    def __call__(self, x: mx.array):
        after_attn = x + self.attn(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))

### Generative Transformer
- Input is an array of token indexes
- Computes token embeddings from the input
- Computes position embeddings from the sequence $[0, 1, ..., T - 1]$
- $x$ is the sum of the token and position embeddings
- $x$ is forwarded through all the blocks
- $x$ is layer normalized one more time
- $x$ is forwarded through a linear layer to transform it from the embedding dimension 
    to the vocab size
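- If generating, the logits $x$ at the last position are divided by a temperature $\tau$,
    $p = \text{softmax}(x / \tau)$, and the next index is drawn from the distribution $p$
    and appended to the input; this repeats once per new token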

In [None]:
class GenerativeTransformer(nn.Module):
    """GPT-style decoder-only language model.

    Token and learned position embeddings are summed, passed through
    ``n_layer`` pre-norm transformer blocks and a final layer norm, then
    projected to vocabulary logits with the token-embedding matrix itself
    (weight tying via ``Embedding.as_linear``).

    Args:
        n_embed: Embedding size ``C``.
        n_head: Number of attention heads per block.
        block_size: Maximum sequence length the model accepts.
        vocab_size: Number of token embeddings / output logits.
        n_layer: Number of transformer blocks.
        dropout: Dropout probability used throughout the model.
        bias: Whether linear and layer-norm layers have bias terms.
    """

    def __init__(
        self, n_embed: int, n_head: int, block_size: int, 
        vocab_size: int, n_layer: int, dropout: float, bias=True,
    ) -> None:
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embed)
        self.wpe = nn.Embedding(block_size, n_embed)
        self.drop = nn.Dropout(dropout)

        # Additive causal mask shared by every block: 0 on and below the
        # diagonal, -inf above it (a query may not attend to later keys).
        mask = np.zeros((block_size, block_size), dtype=np.float32)
        mask[np.tril(np.ones((block_size, block_size))) == 0] = -np.inf
        mask = mx.array(mask)

        self.h = [Block(n_embed, n_head, mask, dropout, bias) for _ in range(n_layer)]
        self.ln_f = LayerNorm(n_embed, bias=bias)

        # No separate lm_head: MLX arrays are immutable values, so assigning
        # one module's weight to another does not tie them. The copy is lost
        # as soon as either weight is re-initialized below or updated by the
        # optimizer. The output projection instead reuses ``wte`` directly in
        # ``__call__``. (Checkpoints saved by the untied version contain an
        # extra ``lm_head.weight`` entry.)

        def init_weights(_, m: nn.Module):
            """Re-init Linear/Embedding weights to N(0, 0.02) and zero any bias.

            The first argument (the module's name) is unused.
            """
            if isinstance(m, (nn.Linear, nn.Embedding)):
                m.weight = nn.init.normal(0.0, 0.02)(m.weight)
                if hasattr(m, "bias") and m.bias is not None:
                    m.bias = mx.zeros_like(m.bias)

        self.apply_to_modules(init_weights)

    def __call__(self, x_idx: mx.array):
        """Return logits of shape ``(B, T, vocab_size)`` for token indices ``x_idx`` of shape ``(B, T)``."""
        _, T = x_idx.shape

        assert T <= self.block_size, \
            f"cannot forward sequence of length {T}, block size is only {self.block_size}"

        pos = mx.arange(0, T, dtype=mx.int64)

        tok_emb = self.wte(x_idx)  # shape (B, T, C)
        pos_emb = self.wpe(pos)  # shape (T, C)

        # (B, T, C) + (T, C) = (B, T, C): position embeddings broadcast over the batch
        x = self.drop(tok_emb + pos_emb)
        for blk in self.h:
            x = blk(x)
        x = self.ln_f(x)
        # Tied output projection: logits = x @ wte.weight.T
        return self.wte.as_linear(x)

    def generate(self, x_idx: mx.array, max_new_tokens: int, temperature=1.0):
        """Autoregressively extend ``x_idx`` with ``max_new_tokens`` sampled tokens.

        Each step feeds (at most) the last ``block_size`` tokens to the model,
        divides the last position's logits by ``temperature``, samples one
        token per batch row and appends it. If dropout should be disabled
        while sampling, the caller should put the model in eval mode
        (``model.eval()``) first.

        Args:
            x_idx: Integer token indices of shape ``(B, T)`` used as the prompt.
            max_new_tokens: Number of tokens to append.
            temperature: Positive divisor applied to the logits before sampling.

        Returns:
            Array of shape ``(B, T + max_new_tokens)``.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        for _ in range(max_new_tokens):
            # Crop the context to what the position embedding can represent.
            if x_idx.shape[1] <= self.block_size:
                x_idx_cropped = x_idx
            else:
                x_idx_cropped = x_idx[:, -self.block_size:]

            logits = self(x_idx_cropped)
            logits = logits[:, -1, :] / temperature
            # categorical samples one index per row -> (B,); add the time axis
            # to get (B, 1). ``[None]`` would give (1, B), wrong for B > 1.
            next_idx = mx.random.categorical(logits)[:, None]
            x_idx = mx.concatenate((x_idx, next_idx.astype(x_idx.dtype)), axis=1)
            # Evaluate per step so the lazy graph does not grow with every token.
            mx.eval(x_idx)
        return x_idx

# Training

In [None]:
# Checkpointing and logging
OUT_DIR = "out"
EVAL_INTERVAL = 1_000  # iterations between train/val loss estimates
LOG_INTERVAL = 1  # iterations between training-loss log lines

# Data
DATA_DIR = "../data"
BATCH_SIZE = 3
BLOCK_SIZE = 1_024  # tokens per training sequence

# Learning-rate schedule: linear warmup to `lr`, cosine decay down to MIN_LR
lr = 1e-3  # peak learning rate
WARMUP_ITERS = 2_000
LR_DECAY_ITERS = 600_000
MIN_LR = 1e-4

In [None]:
def get_batch(split):
    """Sample a random batch of (input, target) token blocks from disk.

    ``split == "train"`` reads ``train.bin``; any other value reads
    ``val.bin``. Both are read from ``DATA_DIR`` as flat ``uint16`` arrays.

    Returns:
        ``(x, y)``: int64 arrays of shape ``(BATCH_SIZE, BLOCK_SIZE)``, where
        ``y`` is the same window as ``x`` shifted one token to the right.
    """
    filename = "train.bin" if split == "train" else "val.bin"
    # The memmap is opened fresh on every call rather than cached.
    tokens = np.memmap(os.path.join(DATA_DIR, filename), dtype=np.uint16, mode="r")

    starts = np.random.randint(0, len(tokens) - BLOCK_SIZE, [BATCH_SIZE])
    inputs = np.stack([tokens[s:s + BLOCK_SIZE] for s in starts]).astype(np.int64)
    targets = np.stack([tokens[s + 1:s + 1 + BLOCK_SIZE] for s in starts]).astype(np.int64)
    return mx.array(inputs), mx.array(targets)

In [None]:
def get_lr(iter_num):
    """Learning rate for iteration ``iter_num``.

    Three phases: linear warmup from 0 to ``lr`` over ``WARMUP_ITERS``
    iterations, cosine decay from ``lr`` to ``MIN_LR`` up to
    ``LR_DECAY_ITERS``, then a constant ``MIN_LR``.
    """
    # Warmup: ramp linearly from 0.
    if iter_num < WARMUP_ITERS:
        return lr * iter_num / WARMUP_ITERS

    # Past the decay horizon: hold the floor.
    if iter_num > LR_DECAY_ITERS:
        return MIN_LR

    # Cosine anneal: progress goes 0 -> 1 across the decay window.
    progress = (iter_num - WARMUP_ITERS) / (LR_DECAY_ITERS - WARMUP_ITERS)
    assert 0 <= progress <= 1
    cosine_weight = 0.5 * (1.0 + np.cos(np.pi * progress))
    return MIN_LR + cosine_weight * (lr - MIN_LR)

In [None]:
model = GenerativeTransformer(
    n_embed=768,
    n_head=12,
    block_size=BLOCK_SIZE,
    vocab_size=50304,  # 50257 for gpt2, padded up to a multiple of 64
    n_layer=12,
    dropout=0.0,
    bias=True,
)

# To resume from the last saved checkpoint, uncomment:
# model.load_weights(f"{OUT_DIR}/model.npz")

# Cast all model parameters to bfloat16.
model.set_dtype(mx.bfloat16)

optimizer = optimizers.AdamW(
    learning_rate=lr,
    betas=(0.9, 0.95),
    eps=1e-7,
    weight_decay=0.1,
)

# Model and optimizer state, captured by the compiled step functions below.
state = [model.state, optimizer.state]
mx.eval(state)

In [None]:
def loss_fn(model, x, y):
    """Mean cross-entropy between ``model(x)`` logits and target indices ``y``."""
    logits = model(x)
    return nn.losses.cross_entropy(logits, y, reduction="mean")


@partial(mx.compile, inputs=state, outputs=state)
def train_step(x, y):
    """Run one optimizer step on the batch ``(x, y)`` and return its loss.

    ``state`` is declared as both an implicit input and an implicit output of
    the compiled function, so the parameter and optimizer updates made
    inside are visible outside.
    """
    loss_and_grad = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)
    return loss


@partial(mx.compile, inputs=state)
def eval_step(x, y):
    """Compiled loss on ``(x, y)``; reads ``state`` but declares no outputs."""
    loss = loss_fn(model, x, y)
    return loss


def estimate_loss(n_iters=50):
    """Average the loss over ``n_iters`` random batches for each data split.

    The model is switched to eval mode (so dropout is inactive) while
    measuring, and its previous train/eval mode is restored afterwards,
    even if an error occurs.

    Args:
        n_iters: Number of batches to average per split.

    Returns:
        Dict mapping ``"train"`` and ``"val"`` to the mean loss (numpy float32).
    """
    out = {}
    was_training = model.training
    model.eval()
    try:
        for split in ["train", "val"]:
            losses = np.zeros(n_iters, dtype=np.float32)
            for k in range(n_iters):
                x, y = get_batch(split)
                losses[k] = eval_step(x, y).item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out

In [None]:
# Stop once the cosine schedule has reached its floor (get_lr returns MIN_LR
# after LR_DECAY_ITERS) so that "Run All" eventually reaches the cells below.
# Interrupt the kernel to stop earlier.
MAX_ITERS = LR_DECAY_ITERS

# Make sure the checkpoint directory exists before the first save.
os.makedirs(OUT_DIR, exist_ok=True)

iter_num = 1
best_val_loss = float("inf")
t0 = time.time()

while iter_num <= MAX_ITERS:
    optimizer.learning_rate = get_lr(iter_num)

    x, y = get_batch("train")

    loss = train_step(x, y)
    # Force the lazily evaluated step (updated state and its loss) to run now.
    mx.eval(loss, state)

    if iter_num % LOG_INTERVAL == 0:
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        print(f"iter {iter_num}: loss {loss.item():.4f}, time {dt:.2f}s")

    if iter_num % EVAL_INTERVAL == 0:
        losses = estimate_loss()
        print(f"train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if losses["val"] < best_val_loss:
            best_val_loss = losses["val"]
            model.save_weights(f"{OUT_DIR}/model.npz")
            print(f"saved checkpoint to {OUT_DIR}")

    iter_num += 1

# Prompt

In [None]:
# GPT-2 BPE tokenizer from tiktoken: encodes prompts and decodes sampled tokens.
encoder = tiktoken.get_encoding("gpt2")

In [None]:
def generate(prompt, max_new_tokens=500, temperature=1.0):
    """Continue ``prompt`` with ``max_new_tokens`` tokens sampled from ``model``.

    Args:
        prompt: Text to condition on. An empty prompt is seeded with the
            tokenizer's end-of-text token, which is removed again from the output.
        max_new_tokens: Number of tokens to sample.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The decoded prompt followed by the sampled continuation.
    """
    prompt_ids = encoder.encode(prompt)
    # The model needs at least one token to condition on.
    start_ids = prompt_ids if prompt_ids else [encoder.eot_token]
    x = mx.array(start_ids, dtype=mx.int64)[None, :]

    was_training = model.training
    model.eval()  # sample without dropout
    try:
        y = model.generate(x, max_new_tokens, temperature)
    finally:
        model.train(was_training)

    token_ids = y[0].tolist()
    if not prompt_ids:
        token_ids = token_ids[1:]  # drop the placeholder start token
    # The model's vocabulary (50304 above) is larger than the tokenizer's, so a
    # sampled id may have no text; skip such ids instead of letting decode() fail.
    token_ids = [t for t in token_ids if t < encoder.n_vocab]
    return encoder.decode(token_ids)

In [None]:
# Prompt text used for the sample generation below.
prompt = "I like to play violin, I play in a"

In [None]:
# Sample a short, 10-token continuation of the prompt.
generate(prompt, max_new_tokens=10)