# Assignment 3: Transformer is All You Need

Federico Giorgi (fg2617)

## Basics

In [13]:
# Import all the libraries
import os, math, torch, urllib.request
from dataclasses import dataclass

from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

## 1 Data Preparation

### 1.1 Load the Tiny Shakespeare text

In [12]:
# Define local path for dataset
TINY_PATH = "tiny_shakespeare.txt"

# Download the Tiny Shakespeare corpus from Karpathyâ€™s repo
urllib.request.urlretrieve(
    "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
    TINY_PATH
)

# Read the corpus into memory as a single string
with open(TINY_PATH, "r", encoding="utf-8") as f:
    corpus_text = f.read()

# Print the first 100 characters
print(corpus_text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


### 1.2 Tokenization

In [10]:
# Define special tokens for padding, unknown words, and sequence boundaries
special_tokens = ["[PAD]", "[UNK]", "[BOS]", "[EOS]"]

# Initialize a Byte Pair Encoding (BPE) tokenizer with an unknown token
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))

# Split text initially on whitespace before BPE merges are applied
tokenizer.pre_tokenizer = Whitespace()

# Train the BPE tokenizer on the corpus, limiting vocabulary size to 500 tokens
trainer = BpeTrainer(vocab_size=500, min_frequency=2, special_tokens=special_tokens)
tokenizer.train_from_iterator([corpus_text], trainer=trainer)

# Retrieve token IDs for the special tokens and vocabulary size
pad_id  = tokenizer.token_to_id("[PAD]")
bos_id  = tokenizer.token_to_id("[BOS]")
eos_id  = tokenizer.token_to_id("[EOS]")
vocab_size = tokenizer.get_vocab_size()

# Encode entire corpus to integer IDs
ids = tokenizer.encode(corpus_text).ids

### 1.3 Sequence formatting

In [None]:
# For next-token prediction: Input = first N tokens, Target = same sequence shifted by 1
SEQ_LEN = 50

def make_windows(token_ids, seq_len):
    # produce (inp, tgt) pairs with stride=1, overlapping
    # last complete window ends at len-1 to allow shift
    L = len(token_ids)
    # Need inp length = seq_len, tgt length = seq_len, so we need i .. i+seq_len for inp and i+1 .. i+seq_len+1 for tgt
    # That means i must go until L - (seq_len + 1)
    limit = L - (seq_len + 1)
    inputs = []
    targets = []
    for i in range(0, max(0, limit + 1)):
        seq = token_ids[i : i + seq_len + 1]
        inp = seq[:-1]
        tgt = seq[1:]
        inputs.append(inp)
        targets.append(tgt)
    return inputs, targets

inputs, targets = make_windows(ids, SEQ_LEN)

### 1. 4 Data split

In [None]:
# ========= 3) 80/20 split (on sequence-pairs) =========
dataset_size = len(inputs)
train_size = int(0.8 * dataset_size)
val_size   = dataset_size - train_size

class NextTokenDataset(Dataset):
    def __init__(self, X, Y):
        self.X = [torch.tensor(x, dtype=torch.long) for x in X]
        self.Y = [torch.tensor(y, dtype=torch.long) for y in Y]
    def __len__(self): return len(self.X)
    def __getitem__(self, i): return self.X[i], self.Y[i]

full_ds = NextTokenDataset(inputs, targets)
train_ds, val_ds = random_split(full_ds, [train_size, val_size], generator=torch.Generator().manual_seed(42))

### 1.5 Token embedding

In [None]:
BATCH_SIZE = 128
def collate_batch(batch):
    # All sequences are already length SEQ_LEN, so simple stack
    X = torch.stack([b[0] for b in batch], dim=0)  # (B, T)
    Y = torch.stack([b[1] for b in batch], dim=0)  # (B, T)
    return X, Y

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=collate_batch)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, drop_last=False, collate_fn=collate_batch)

# ========= 5) Token embeddings + positional encodings =========
# Option A (learned positions): nn.Embedding for both tokens and positions
class TokenPosEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
    def forward(self, x):
        # x: (B, T) token IDs
        B, T = x.size()
        pos = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)  # (B, T)
        return self.tok(x) + self.pos(pos)  # (B, T, d_model)

# Option B (sinusoidal positions): classic transformer-style fixed encodings
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)  # (max_len, d_model)
    def forward(self, x):
        # x: (B, T, d_model) token embeddings
        T = x.size(1)
        return x + self.pe[:T].unsqueeze(0)

# Example: build embeddings and run one batch through
device = "cuda" if torch.cuda.is_available() else "cpu"
d_model = 256
max_len = SEQ_LEN  # since our sequences are fixed-length windows

tokpos = TokenPosEmbedding(vocab_size, d_model, max_len).to(device)
sinpos = SinusoidalPositionalEncoding(d_model, max_len).to(device)

xb, yb = next(iter(train_loader))  # (B, T), (B, T)
xb = xb.to(device)
emb_tokpos = tokpos(xb)            # (B, T, d_model) learned positions
emb_sin    = sinpos(tokpos.tok(xb))# (B, T, d_model) sinusoidal positions added to token embeddings

print("Vocab size:", vocab_size)
print("Train batches:", len(train_loader), "Val batches:", len(val_loader))
print("Embedded shapes (learned / sinusoidal):", emb_tokpos.shape, emb_sin.shape)

# Your model can now consume emb_tokpos (or emb_sin). For next-token prediction,
# typical loss is cross-entropy over logits shaped (B, T, vocab_size) vs targets (B, T).


Vocab size: 500
Train batches: 2800 Val batches: 701
Embedded shapes (learned / sinusoidal): torch.Size([128, 50, 256]) torch.Size([128, 50, 256])


## 2 Tiny Transformer Implementation

In [None]:
# ===== RMSNorm =====
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-8):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps
    def forward(self, x):
        # x: (..., d_model)
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        x_norm = x / rms
        return self.weight * x_norm

# ===== Sinusoidal Positional Encoding =====
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe)   # (max_len, d_model)
    def forward(self, x):                # x: (B,T,D)
        T = x.size(1)
        return x + self.pe[:T].unsqueeze(0)

# ===== Causal Self-Attention =====
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, attn_dropout=0.0, resid_dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.resid_drop = nn.Dropout(resid_dropout)
        self.register_buffer("mask_cache", None, persistent=False)

    def _causal_mask(self, T, device):
        # Cache an upper-triangular mask of shape (1,1,T,T) with -inf above diagonal
        if self.mask_cache is None or self.mask_cache.size(-1) < T:
            size = max(T, 1)
            m = torch.full((1, 1, size, size), float("-inf"))
            m = torch.triu(m, diagonal=1)
            self.mask_cache = m.to(device)
        return self.mask_cache[:, :, :T, :T]

    def forward(self, x):
        B, T, D = x.shape
        qkv = self.qkv(x).view(B, T, 3, self.n_heads, self.head_dim).transpose(1, 2)  # (B,3,T,H,Hd) -> (B,3,H,T,Hd)
        q, k, v = qkv[:,0], qkv[:,1], qkv[:,2]                                       # each: (B,H,T,Hd)

        # scaled dot-product attention with causal mask
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)                   # (B,H,T,T)
        att = att + self._causal_mask(T, x.device)                                   # prevent attending to future
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v                                                                  # (B,H,T,Hd)
        y = y.transpose(1, 2).contiguous().view(B, T, D)                             # (B,T,D)
        y = self.proj(y)
        return self.resid_drop(y)

# ===== MLP / Feed-Forward =====
class FeedForward(nn.Module):
    def __init__(self, d_model, expansion=4, dropout=0.0):
        super().__init__()
        hidden = expansion * d_model
        self.net = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )
    def forward(self, x): return self.net(x)

# ===== Transformer Block (Pre-norm + RMSNorm) =====
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, attn_dropout=0.0, resid_dropout=0.0, ff_dropout=0.0, expansion=4):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, attn_dropout, resid_dropout)
        self.norm2 = RMSNorm(d_model)
        self.ffn  = FeedForward(d_model, expansion, ff_dropout)
    def forward(self, x):
        x = x + self.attn(self.norm1(x))   # residual + pre-norm
        x = x + self.ffn(self.norm2(x))    # residual + pre-norm
        return x

# ===== Tiny Transformer Language Model =====
class TinyTransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_layers=4, n_heads=4, max_seq_len=SEQ_LEN, dropout=0.1):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = SinusoidalPositionalEncoding(d_model, max_len=max_seq_len)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, attn_dropout=dropout, resid_dropout=dropout, ff_dropout=dropout, expansion=4)
            for _ in range(n_layers)
        ])
        self.norm_f = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # tie weights (optional, common trick)
        self.lm_head.weight = self.tok.weight

    def forward(self, idx, targets=None):
        # idx: (B,T) token ids
        x = self.tok(idx)                      # (B,T,D)
        x = self.pos(x)                        # add sinusoidal positions
        x = self.drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm_f(x)
        logits = self.lm_head(x)               # (B,T,V)

        loss = None
        if targets is not None:
            # Cross-entropy over vocabulary at each time step
            # Flatten to (B*T,V) vs (B*T,)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

# ===== Instantiate model and train =====
torch.manual_seed(0)
model = TinyTransformerLM(vocab_size=vocab_size, d_model=256, n_layers=4, n_heads=4, max_seq_len=SEQ_LEN, dropout=0.1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.01)

def evaluate(loader):
    model.eval()
    total_loss, n_tokens = 0.0, 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            _, loss = model(xb, yb)
            total = xb.numel()
            total_loss += loss.item() * total
            n_tokens += total
    ppl = math.exp(total_loss / n_tokens)
    return total_loss / n_tokens, ppl

EPOCHS = 2  # raise for better quality
for epoch in range(1, EPOCHS + 1):
    model.train()
    tot, count = 0.0, 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad(set_to_none=True)
        _, loss = model(xb, yb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        tot += loss.item()
        count += 1
    val_loss, val_ppl = evaluate(val_loader)
    print(f"Epoch {epoch:02d} | train_loss {tot/count:.4f} | val_loss {val_loss:.4f} | val_ppl {val_ppl:.2f}")

# ===== Quick sampling (greedy) =====
@torch.no_grad()
def generate(prefix_ids, max_new_tokens=50):
    model.eval()
    x = torch.tensor(prefix_ids, dtype=torch.long, device=device).unsqueeze(0)  # (1,T)
    for _ in range(max_new_tokens):
        x_in = x[:, -SEQ_LEN:]                      # crop to block size
        logits, _ = model(x_in)
        next_id = logits[:, -1, :].argmax(dim=-1)   # greedy
        x = torch.cat([x, next_id.unsqueeze(1)], dim=1)
    return x.squeeze(0).tolist()

# Example: take the first training sequence prefix and generate a few tokens
sample_inp, _ = next(iter(train_loader))
seed = sample_inp[0].tolist()[:20]
gen_ids = generate(seed, max_new_tokens=20)

print("Seed IDs:    ", seed)
print("Generated IDs", gen_ids)


RuntimeError: The size of tensor a (4) must match the size of tensor b (50) at non-singleton dimension 3