In [None]:
import os
import time
from typing import Generator, Sequence

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn, optim

# Load and Process Data

In [None]:
# Read the whole training corpus into memory as one string.
with open('../data/input.txt', mode='r', encoding='utf-8') as f:
    text = f.read()

Tokens are individual characters, so the vocabulary size is the number of unique characters in the text.

In [None]:
# Vocabulary: every distinct character of the corpus in sorted order, so a given
# text always produces the same character -> id assignment.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables in both directions between characters and integer token ids.
int2char = dict(enumerate(chars))
char2int = {ch: idx for idx, ch in int2char.items()}


def encode(s: str) -> list[int]:
    """Map a string to token ids, silently skipping characters not in the vocabulary."""
    ids = []
    for ch in s:
        if ch in char2int:
            ids.append(char2int[ch])
    return ids


def decode(y: list[int] | np.ndarray | Tensor) -> str:
    """Map a sequence of token ids back to a string, skipping ids not in the vocabulary."""
    pieces = []
    for token in y:
        key = int(token)
        if key in int2char:
            pieces.append(int2char[key])
    return "".join(pieces)

The input text is encoded as a 1-D `Tensor` of token ids, then split into training and
validation sets (the first 10% of the tokens are held out for validation)

In [None]:
# Encode the whole corpus once as a 1-D int64 tensor of token ids.
full_data = torch.tensor(encode(text), dtype=torch.int64)

# Hold out the first 10% of the tokens for validation and train on the rest.
val_size = len(full_data) // 10
val_data, train_data = full_data[:val_size], full_data[val_size:]

### Convert data into blocks

For block size $b$ and encoded data $d$, every start index $i$ gives an input/target pair:

$x_i = [d_i, d_{i + 1}, ..., d_{i + b - 1}]$

$y_i = [d_{i + 1}, d_{i + 2}, ..., d_{i + b}]$

In [None]:
def block_data(data: Tensor, block_size: int) -> tuple[Tensor, Tensor]:
    """Cut a 1-D token tensor into overlapping (input, target) training blocks.

    Row ``i`` of ``x`` is ``data[i : i + block_size]`` and row ``i`` of ``y`` is the
    same window shifted one token to the right, ``data[i + 1 : i + block_size + 1]``,
    so ``y[i, t]`` is the token that follows ``x[i, t]``. Every valid start index
    ``0 <= i <= len(data) - block_size - 1`` is used, giving
    ``len(data) - block_size`` rows.

    Args:
        data: 1-D tensor of token ids.
        block_size: window length; must be positive.

    Returns:
        ``(x, y)``, each of shape ``(max(len(data) - block_size, 0), block_size)``
        with ``data``'s dtype and device. Both are overlapping views into ``data``
        (built with ``Tensor.unfold``), so no ``block_size``-fold copy is made;
        do not write into them in place.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if len(data) <= block_size:
        # Not enough tokens for even one (input, target) pair.
        empty = data.new_empty((0, block_size))
        return empty, empty.clone()
    # windows[j] == data[j : j + block_size]; consecutive windows form x/y pairs.
    windows = data.unfold(0, block_size, 1)
    return windows[:-1], windows[1:]

### Generate random batches for dataset

In [None]:
def batch_iterate(
    x: Tensor,
    y: Tensor,
    batch_size: int,
) -> Generator[tuple[Tensor, Tensor], None, None]:
    """Yield ``(x[idx], y[idx])`` mini-batches that cover every row exactly once.

    Rows are visited in a random order drawn from NumPy's global RNG when
    iteration starts. Every batch holds ``batch_size`` rows except possibly the
    last one, which holds the remainder.
    """
    n_rows = y.shape[0]
    order = torch.tensor(np.random.permutation(n_rows), dtype=torch.int64)
    chunks = (order[start:start + batch_size] for start in range(0, n_rows, batch_size))
    for chosen in chunks:
        yield x[chosen], y[chosen]

# Model

### Layer Normalization

$y = \frac{x - E[x]}{\sqrt{E[(x - E[x])^2] + \epsilon}} \odot \gamma + \beta$

In [None]:
class LayerNorm(nn.Module):
    def __init__(self, normalized_shape: Sequence[int], bias=True, eps=1e-5) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.ones(normalized_shape)) if bias else None
        self.eps = eps
    

    def forward(self, x: Tensor) -> Tensor: 
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, self.eps) 

### Self Attention

$[q_{i, j}, k_{i, j}, v_{i, j}] = x_{i, j}[W_q, W_k, W_v] + [b_q, b_k, b_v]$ 
where all $q, k, v, x, b$ are row vectors

$[q_{i, j}, k_{i, j}, v_{i, j}]$ are computed for every row vector $x_{i, j}$ of the 3D tensor 
$x = \begin{bmatrix}
x_{1, 1} & \dots & x_{1, T} \\
\vdots & \ddots & \vdots \\
x_{B, 1} & \dots & x_{B, T}
\end{bmatrix}$ resulting in tensors $q, k, v$

$x$ has shape $(B, T, C)$ where $B$ is the batch size, $T$ is the sequence length,
and $C$ is the number of embedding dimensions

$q, k, v$ have shape $(B, T, ND)$ where $B$ is the batch size, 
$T$ is the sequence length, $N$ is the number of attention heads, and $D$ is the
number of query/key dimensions

$q, k, v$ are reshaped to $(B, N, T, D)$

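Here $i$ indexes the batch and $j$ the head, so $q_{i, j}, k_{i, j}, v_{i, j}$ are $T \times D$ matrices

$a_{i, j} = \frac{q_{i, j} k_{i, j}^T}{\sqrt{D}}$, a $T \times T$ matrix of attention scores

Causal mask: $(a_{i, j})_{s, t} = -\infty$ for all $s < t$, so position $s$ cannot attend to any later position $t$

$a_{i, j} = \text{softmax}(a_{i, j})$ where softmax is computed rowwise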

$y_{i, j} = a_{i, j} v_{i, j}$

$y$ has shape $(B, N, T, D)$

$y$ is reshaped to $(B, T, ND)$, so $y_{i, j}$ is a row vector

$y_{i, j} = y_{i, j}W_p + b_p$ 

In [None]:
class SelfAttention(nn.Module):
    def __init__(self, n_embed: int, n_head: int, dropout: float, bias=True) -> None:
        super().__init__()
        self.n_embed = n_embed
        self.n_head = n_head
        assert n_embed % n_head == 0 
        self.D = n_embed // n_head

        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=bias)
        self.c_proj = nn.Linear(n_embed, n_embed, bias=bias)

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.res_dropout = nn.Dropout(dropout)

        self.attn_scale = 1.0 / np.sqrt(self.D)
    

    def forward(self, x: Tensor) -> Tensor:
        B, T, n_embed = x.shape
        assert n_embed == self.n_embed

        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)

        q = q.view(B, T, self.n_head, self.D).transpose(1, 2) # gives shape (B, N, T, D)
        k = k.view(B, T, self.n_head, self.D).transpose(1, 2) # gives shape (B, N, T, D)
        v = v.view(B, T, self.n_head, self.D).transpose(1, 2) # gives shape (B, N, T, D)

        y = F.scaled_dot_product_attention(
            q, k, v, 
            dropout_p=self.dropout, 
            is_causal=True,
            scale=self.attn_scale,
        )
        
        y = y.transpose(1, 2).contiguous().view(B, T, self.n_embed) # concat head outputs 
        y = self.c_proj(y)
        y = self.res_dropout(y)
        return y

### MLP

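$h_{i, j} = \text{GELU}(x_{i, j} W_{c} + b_{c})$, where $W_c$ has shape $(C, 4C)$

$y_{i, j} = h_{i, j} W_{p} + b_{p}$, where $W_p$ has shape $(4C, C)$, followed by dropout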

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network.

    Expands each embedding to ``4 * n_embed`` features, applies GELU, projects
    back to ``n_embed`` and finishes with dropout.
    """

    def __init__(self, n_embed: int, dropout: float, bias=True):
        super().__init__()
        hidden = 4 * n_embed
        self.c_fc = nn.Linear(n_embed, hidden, bias=bias)
        self.c_proj = nn.Linear(hidden, n_embed, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))

### Block

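A pre-norm residual block composed of layer normalization, self attention, and an MLP: $x = x + \text{attn}(\text{ln}_1(x))$ followed by $x = x + \text{mlp}(\text{ln}_2(x))$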

In [None]:
class Block(nn.Module):
    """Pre-norm transformer block.

    Causal self-attention followed by an MLP, each preceded by its own layer norm
    and wrapped in a residual connection.
    """

    def __init__(self, n_embed: int, n_head: int, dropout: float, bias=True) -> None:
        super().__init__()
        # Keep this creation order: it fixes parameter registration order (and thus
        # state_dict keys) and the order in which initialization draws random numbers.
        self.ln_1 = LayerNorm(n_embed, bias=bias)
        self.attn = SelfAttention(n_embed, n_head, dropout=dropout, bias=bias)
        self.ln_2 = LayerNorm(n_embed, bias=bias)
        self.mlp = MLP(n_embed, dropout=dropout, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))

### Generative Transformer
- Input is an array of token indexes
- Computes token embeddings from the input
- Computes position embeddings from the sequence $[0, 1, ..., T - 1]$
- $x$ is the sum of the token and position embeddings
- $x$ is forwarded through all the blocks
- $x$ is layer normalized one more time
- $x$ is forwarded through a linear layer to transform it from the embedding dimension 
    to the vocab size
- If generating, the logits $x$ of the last position are divided by a temperature $\tau$,
    $p = \text{softmax}(x / \tau)$, and the next index is drawn from the distribution $p$

In [None]:
class GenerativeTransformer(nn.Module):
    """Decoder-only (GPT-style) language model over token ids.

    Token and learned position embeddings are summed and passed through dropout,
    then ``n_layer`` transformer ``Block``s and a final layer norm. ``lm_head``
    projects the result to vocabulary logits. The token-embedding matrix is tied
    to the ``lm_head`` weight.

    Args:
        n_embed: embedding size C.
        n_head: attention heads per block.
        block_size: maximum sequence length (size of the position embedding table).
        vocab_size: number of distinct tokens.
        n_layer: number of transformer blocks.
        dropout: dropout probability after the embeddings and inside the blocks.
        bias: whether the blocks' linear layers and all layer norms have biases.
    """

    def __init__(
        self, n_embed: int, n_head: int, block_size: int,
        vocab_size: int, n_layer: int, dropout: float, bias=True,
    ) -> None:
        super().__init__()
        self.block_size = block_size

        args = (n_embed, n_head, dropout, bias)

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(vocab_size, n_embed),
            wpe = nn.Embedding(block_size, n_embed),
            drop = nn.Dropout(dropout),
            h = nn.ModuleList([Block(*args) for _ in range(n_layer)]),
            ln_f = LayerNorm(n_embed, bias=bias),
        ))

        self.lm_head = nn.Linear(n_embed, vocab_size, bias=False)
        # Weight tying: input embedding and output projection share one
        # (vocab_size, n_embed) parameter.
        self.transformer.wte.weight = self.lm_head.weight

        def init_weights(module: nn.Module) -> None:
            # N(0, 0.02) weights for Linear/Embedding, zero Linear biases; other
            # modules (e.g. LayerNorm) keep their constructor initialization.
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

        self.apply(init_weights)

    def forward(self, x_idx: Tensor) -> Tensor:
        """Compute next-token logits.

        Args:
            x_idx: int64 tensor of token ids with shape (B, T), T <= block_size.

        Returns:
            Logits of shape (B, T, vocab_size).
        """
        device = x_idx.device
        _, T = x_idx.shape

        assert T <= self.block_size, \
            f"cannot forward sequence of length {T}, block size is only {self.block_size}"

        pos = torch.arange(0, T, dtype=torch.int64, device=device)

        tok_emb = self.transformer.wte(x_idx)  # shape (B, T, C)
        pos_emb = self.transformer.wpe(pos)  # shape (T, C)

        # (B, T, C) + (T, C) broadcasts the position embeddings over the batch
        x = self.transformer.drop(tok_emb + pos_emb)
        for blk in self.transformer.h:
            x = blk(x)
        x = self.transformer.ln_f(x)
        x = self.lm_head(x)
        return x

    @torch.no_grad()
    def generate(self, x_idx: Tensor, max_new_tokens: int, temperature=1.0) -> Tensor:
        """Autoregressively extend ``x_idx`` by ``max_new_tokens`` tokens.

        Each step feeds at most the last ``block_size`` tokens back into the model
        and appends one token chosen from the logits of the final position.

        Args:
            x_idx: int64 tensor of shape (B, T) with the conditioning token ids.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this before the softmax; values
                below 1 sharpen the distribution and values above 1 flatten it.
                ``0`` picks the most likely token (greedy decoding).

        Returns:
            int64 tensor of shape (B, T + max_new_tokens).

        Raises:
            ValueError: if ``temperature`` is negative.

        This method does not change the train/eval mode; call ``model.eval()``
        first so dropout is disabled during generation.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be non-negative, got {temperature}")

        for _ in range(max_new_tokens):
            # The position table only covers block_size positions.
            x_idx_cropped = x_idx[:, -self.block_size:]

            logits = self(x_idx_cropped)[:, -1, :]  # (B, vocab_size)

            if temperature == 0:
                idx_next = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            x_idx = torch.cat((x_idx, idx_next), dim=1)

        return x_idx

# Training

In [None]:
# Logging / evaluation cadence, in optimizer steps
EVAL_INTERVAL = 2500
LOG_INTERVAL = 500

# Context length (tokens per block) and number of blocks per mini-batch
BLOCK_SIZE = 32
BATCH_SIZE = 16

# Use the Apple-silicon GPU when present, otherwise fall back to CUDA or the CPU,
# so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DTYPE = torch.float32

MAX_ITERS = 10000

# Learning-rate schedule (see get_lr): linear warmup to MAX_LR over WARMUP_ITERS
# steps, cosine decay to MIN_LR until LR_DECAY_ITERS, then constant MIN_LR.
MAX_LR = 1e-4
WARMUP_ITERS = 100
LR_DECAY_ITERS = 2500
MIN_LR = 1e-5

# Seed the RNGs used for weight initialization (torch) and batch shuffling (numpy).
SEED = 1337
torch.manual_seed(SEED)
np.random.seed(SEED)

### Convert data to blocks

For block size $b$ and encoded data $d$, every start index $i$ gives an input/target pair:

$x_i = [d_i, d_{i + 1}, ..., d_{i + b - 1}]$

$y_i = [d_{i + 1}, d_{i + 2}, ..., d_{i + b}]$

In [None]:
# Build every (input, target) block pair for both splits.
x_train, y_train = block_data(train_data, block_size=BLOCK_SIZE)
x_val, y_val = block_data(val_data, block_size=BLOCK_SIZE)

### Initialize model and optimizer

In [None]:
model = GenerativeTransformer(
    n_embed=640, # changed so n_embed % n_head == 0
    n_head=4, 
    block_size=BLOCK_SIZE, 
    vocab_size=vocab_size,
    n_layer=4, 
    dropout=0.0, 
    bias=True,
).to(device=DEVICE, dtype=DTYPE)

# The training loop overwrites the lr every step with get_lr(); configure the
# optimizer with MAX_LR so its initial setting agrees with the schedule constants.
optimizer = optim.AdamW(model.parameters(), lr=MAX_LR)

In [None]:
# Optionally resume from a saved checkpoint: set RESUME_PATH to the file to load
# (e.g. "checkpoints/checkpoint.pt"). Left as None, this cell does nothing.
# NOTE: the training loop below does not read checkpoint["iter_num"]; it still
# counts iterations from 1.
RESUME_PATH = None
if RESUME_PATH is not None:
    checkpoint = torch.load(RESUME_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])

Estimate the average cross-entropy loss on tensors $x, y$ from up to `max_iters` randomly shuffled batches

In [None]:
@torch.no_grad()
def evaluate_loss(x, y, max_iters=100):
    """Estimate the mean per-token cross-entropy of the global ``model`` on (x, y).

    Draws up to ``max_iters`` shuffled mini-batches of ``BATCH_SIZE`` rows via
    ``batch_iterate``. Each batch's loss is weighted by the number of target
    tokens it contains, because the last batch may be smaller. The model is
    put in eval mode (dropout off) for the estimate and restored to its previous
    mode afterwards.

    Args:
        x: input token blocks, e.g. the first output of ``block_data``.
        y: target token blocks matching ``x`` row for row.
        max_iters: maximum number of batches to evaluate.

    Returns:
        The estimated loss as a Python float, or ``nan`` if no batch was evaluated.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    n_tokens = 0
    try:
        for i, (bx, by) in enumerate(batch_iterate(x, y, BATCH_SIZE)):
            if i >= max_iters:
                break

            bx = bx.to(DEVICE)
            by = by.to(DEVICE)

            logits = model(bx)
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), by.view(-1))
            # Weight by this batch's own size (not the full dataset's length).
            batch_tokens = by.numel()
            loss_sum += loss.item() * batch_tokens
            n_tokens += batch_tokens
    finally:
        model.train(was_training)
    return loss_sum / n_tokens if n_tokens else float("nan")

### Change learning rate over time

$\eta_i = \begin{cases}
    \frac{\eta_{\text{max}} \cdot i}{N_{\text{warmup}}} & i < N_{\text{warmup}} \\
    \eta_{\text{min}} + \left(
        \frac{1}{2} + \frac{1}{2}\cos\left(
            \pi \frac{i - N_{\text{warmup}}}{N_{\text{decay}} - N_{\text{warmup}}}
        \right)
    \right)(\eta_{\text{max}} - \eta_{\text{min}}) & N_{\text{warmup}} \leq i < N_{\text{decay}} \\
    \eta_{\text{min}} & N_{\text{decay}} \leq i
\end{cases}$

In [None]:
def get_lr(iter_num: int) -> float:
    """Learning rate for optimizer step ``iter_num``.

    Linear warmup from 0 to MAX_LR over the first WARMUP_ITERS steps, cosine decay
    from MAX_LR down to MIN_LR until LR_DECAY_ITERS, then constant MIN_LR.
    """
    if iter_num < WARMUP_ITERS:
        return MAX_LR * iter_num / WARMUP_ITERS

    if iter_num <= LR_DECAY_ITERS:
        progress = (iter_num - WARMUP_ITERS) / (LR_DECAY_ITERS - WARMUP_ITERS)
        assert 0 <= progress <= 1
        # The cosine factor falls from 1 at the start of decay to 0 at LR_DECAY_ITERS.
        return MIN_LR + 0.5 * (1.0 + np.cos(np.pi * progress)) * (MAX_LR - MIN_LR)

    return MIN_LR

In [None]:
# Visualize the learning-rate schedule over the whole run. Plot against the
# actual iteration numbers (1..MAX_ITERS) rather than list indices.
lr_iters = list(range(1, MAX_ITERS + 1))
fig, ax = plt.subplots()
ax.plot(lr_iters, [get_lr(it) for it in lr_iters])
ax.set(title="Learning-rate schedule", xlabel="Iteration", ylabel="Learning Rate")
plt.show()

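### Cross-entropy loss

$l(x, y, \theta) = -\frac{1}{N}\sum_{i=1}^{N} \log \text{softmax}(f(x_i, \theta))_{y_i}$, averaged over all $N$ predicted tokens, where $f$ returns logits and $y_i$ is the index of the target token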

<br>

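### Train Step with AdamW Optimizer

$g_t = \nabla_{\theta_{t - 1}} l(x, y, \theta_{t - 1})$

$\alpha = \eta \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}$

$m_t = \beta_1 m_{t - 1} + (1 - \beta_1)g_t$

$v_t = \beta_2 v_{t - 1} + (1 - \beta_2)g_t^2$

$\theta_t = \theta_{t - 1} - \alpha \frac{m_t}{\sqrt{v_t} + \epsilon}$

AdamW additionally applies decoupled weight decay $\lambda$ before this update: $\theta_{t - 1} \leftarrow \theta_{t - 1} - \eta \lambda \theta_{t - 1}$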

In [None]:
# Train for MAX_ITERS optimizer steps, re-shuffling the training blocks every
# epoch. Every EVAL_INTERVAL steps, estimate train/val loss and checkpoint the
# model whenever the validation loss improves.
assert len(x_train) > 0, "no training blocks: the training split is shorter than BLOCK_SIZE"
# Create the output directory up front so torch.save can write the checkpoint.
os.makedirs("checkpoints", exist_ok=True)

model.train()
iter_num = 1
t0 = time.time()
best_val_loss = float('inf')

while iter_num <= MAX_ITERS:
    # One pass over a fresh random permutation of the training blocks.
    for x, y in batch_iterate(x_train, y_train, batch_size=BATCH_SIZE):
        if iter_num > MAX_ITERS:
            break

        # The schedule is applied by hand, so overwrite every group's lr each step.
        lr = get_lr(iter_num)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        x = x.to(DEVICE)
        y = y.to(DEVICE)

        logits = model(x)
        # Flatten (B, T, vocab) logits and (B, T) targets for per-token cross-entropy.
        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        if iter_num % LOG_INTERVAL == 0:
            t1 = time.time()
            dt = t1 - t0
            t0 = t1
            print(f"[{iter_num:4}] loss: {loss.item():.3f}, time: {dt:.3f}s")

        if iter_num % EVAL_INTERVAL == 0:
            train_loss = evaluate_loss(x_train, y_train)
            val_loss = evaluate_loss(x_val, y_val)
            print(f"    train loss: {train_loss:.4f}, val loss: {val_loss:.4f}")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                checkpoint = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "iter_num": iter_num,
                    "best_val_loss": best_val_loss,
                }
                torch.save(checkpoint, "checkpoints/checkpoint1.pt")
                print("    saved checkpoint to checkpoints/checkpoint1.pt")
            # Make sure training continues in train mode after evaluation.
            model.train()

        iter_num += 1

# float32:
# 5m 47.4s
# train loss: 1.5059, val loss: 1.6205

# bfloat16:
# 4m 28.9s
# train loss: 2.0559, val loss: 2.1174

# 

# Testing

In [None]:
# Fresh model with the same architecture as the trained one; its weights are
# replaced from the saved checkpoint in the next cell.
model = GenerativeTransformer(
    vocab_size=vocab_size,
    block_size=BLOCK_SIZE,
    n_layer=4,
    n_head=4,
    n_embed=640,  # changed so n_embed % n_head == 0
    dropout=0.0,
    bias=True,
)
model = model.to(device=DEVICE, dtype=DTYPE)

In [None]:
# Load the best weights saved during training into the fresh model.
# map_location lets a checkpoint written on one device be loaded on another.
checkpoint = torch.load("checkpoints/checkpoint1.pt", map_location=DEVICE)
model.load_state_dict(checkpoint["model"])
# Evaluation and sampling below should run with dropout disabled.
model.eval()
# The optimizer state is intentionally not restored: `optimizer` was built around
# the parameters of the training-time model, not this freshly constructed one,
# and no training happens in this section.

In [None]:
# Larger train-set sample (500 batches) than the default used during training.
train_loss = evaluate_loss(x_train, y_train, max_iters=500)
val_loss = evaluate_loss(x_val, y_val)
print("train loss: {train:.4f}, val loss: {val:.4f}".format(train=train_loss, val=val_loss))

In [None]:
# Sample a continuation of a prompt. A freshly constructed nn.Module is in train
# mode, so switch to eval mode first to disable dropout during generation.
# Characters not in the vocabulary are dropped by `encode`.
model.eval()
context = torch.tensor(encode("Hello, my name is "), dtype=torch.int64, device=DEVICE)
output = model.generate(context[None], max_new_tokens=100, temperature=1.0)
# Move ids to a Python list once instead of reading device scalars one by one.
print(decode(output[0].tolist()))