#  BPE vocabulary

In [14]:
import torch
from torch.utils.data import Dataset, DataLoader
import math
from tqdm.auto import tqdm

from transformer import build_transformer

import torch.nn as nn
from torch.optim import AdamW

import tiktoken

## Parameters

In [None]:
# Reserved special-token ids.
PAD = 0
BOS = 1
EOS = 2

# Speaker ids.
# NOTE(review): encode_str shifts raw bytes by +3, so ids 3..5 are also the
# ids of bytes 0x00..0x02; these speaker constants are not referenced by the
# visible cells — confirm before using them as real tokens.
MICHAEL = 3
JIM = 4
PAM = 5

# Sequence-length limits (in tokens).
MAX_SRC_LEN = 256  # max tokens for dialogue context
MAX_TGT_LEN = 128  # max tokens for Michael's response

## Helpers

In [4]:
def read_lines(p):
    """Read a UTF-8 text file and return its non-blank lines, stripped.

    Args:
        p: Path of the text file to read.

    Returns:
        A list of strings, one per line of the file, with surrounding
        whitespace removed and empty / whitespace-only lines dropped.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(p, encoding="utf-8") as fh:
        raw_lines = fh.read().splitlines()
    stripped = (line.strip() for line in raw_lines)
    return [line for line in stripped if line]

`to_bytes` helper:
- Convert a string to a list of byte values (integers 0–255).
- Each character becomes its UTF-8 byte representation.
- This is how a byte-level tokenizer works — every possible byte is one token, so an ASCII character is a single token while other characters span several tokens.

In [5]:
# --- helpers: boolean masks where True = keep, False = pad/blocked ---

def pad_to(ids, L, pad=PAD):
    """Return a copy of `ids` forced to length `L`.

    Longer inputs are cut to their first `L` items; shorter inputs are
    right-padded with `pad`.
    """
    n_missing = L - len(ids)
    filler = [pad] * n_missing if n_missing > 0 else []
    return ids[:L] + filler

def enc_pad_mask(x, pad_id=PAD):
    """(B,S) → (B,1,1,S) keep-mask: True for real tokens, False for padding."""
    is_real = x.ne(pad_id)
    # Insert two singleton axes right after the batch axis.
    return is_real[:, None, None]

def dec_masks(tgt_in, pad_id=PAD):
    """Build the decoder keep-masks for a (B,T) batch of decoder inputs.

    Returns a 3-tuple:
        tgt_pad    (B,1,1,T)  True where the token is not `pad_id`
        tgt_causal (1,1,T,T)  lower-triangular True matrix
        tgt_mask   (B,1,T,T)  logical AND of the two
    """
    _, seq_len = tgt_in.size()  # unpacking requires a 2-D input
    pad_keep = tgt_in.ne(pad_id)[:, None, None]
    lower_tri = torch.ones(
        seq_len, seq_len, dtype=torch.bool, device=tgt_in.device
    ).tril()
    causal_keep = lower_tri[None, None]
    return pad_keep, causal_keep, pad_keep & causal_keep

**`pad_to`** (padding function):
- if the list ids (token sequence) is longer than $L$, it will truncate it: `ids[:L]`
- if shorter, append enough [pad] tokens to reach length L
  
**`enc_pad_mask`**:
- `x`: input token of shape (batch_size, seq_len)
- `(x != pad_id)`: boolean tensor — False when the token is a padding token PAD, else True
- `unsqueeze(1).unsqueeze(2)`: add two singleton dimensions → (B, 1, 1, S), so the mask broadcasts against the attention scores of shape (B, heads, T, S)

**`dec_masks`**:
- `tgt_in`: decoder input of shape (B, T)
- `tgt_pad`: same idea as before, marks non-PAD positions → shape (B,1,1,T)
- `tgt_causal`: a lower-triangular matrix of Trues → (1,1,T,T), ensures that position $t$ can only attend to positions $\le t$ (causal mask)
- `tgt_mask = tgt_pad & tgt_causal`: combine both → (B,1,T,T), only keeps real tokens in the past or present

## Tokenisation (BPE)

https://sebastianraschka.com/blog/2025/bpe-from-scratch.html

In [22]:
# Parallel corpora: one dialogue context / one response per non-blank line.
src_lines, tgt_lines = (read_lines(path) for path in ("src.txt", "tgt.txt"))

# GPT-2 BPE tokenizer from tiktoken.
# NOTE(review): the visible encoding cells below work on raw bytes and do not
# use this tokenizer — confirm whether it is still needed.
gpt2_tokenizer = tiktoken.get_encoding("gpt2")

# Quick probe: decode a single token id back to text.
gpt2_tokenizer.decode([3392])

' ones'

In [None]:
# Build vocab as the special ids plus the 256 byte values shifted by +3.
# encode_str maps byte b to id b+3, so the ids actually produced are
# 0..2 (specials) and 3..258 (bytes). Without the shift the set collapses to
# range(256) and vocab_size would come out as 256, too small for ids 256..258.
vocab = {PAD, BOS, EOS} | {b + 3 for b in range(256)}
vocab_size = max(vocab) + 1  # = 259 (256 byte ids + 3 specials)

def encode_str(s):
    """Encode a string as byte-level token ids wrapped in BOS ... EOS.

    Args:
        s: Text to encode.

    Returns:
        A list of ints: BOS, then one id per UTF-8 byte of `s` (byte value
        + 3), then EOS.
    """
    # Iterating a bytes object yields ints 0..255; this replaces the
    # `to_bytes` helper, which is not defined in the visible notebook.
    # Shift bytes by +3 to leave 0..2 for the special tokens.
    return [BOS] + [b + 3 for b in s.encode("utf-8")] + [EOS]

def decode_ids(ids):
    """Turn byte-level token ids back into text.

    Ids below 3 (the special tokens) are skipped; every other id is shifted
    down by 3 to recover the byte value. Invalid UTF-8 sequences are decoded
    with the replacement character rather than raising.
    """
    raw = bytearray()
    for tok in ids:
        if tok >= 3:
            raw.append(tok - 3)
    return bytes(raw).decode("utf-8", errors="replace")

# Byte-encode every source and target line (one id list per line).
enc_src = list(map(encode_str, src_lines))
enc_tgt = list(map(encode_str, tgt_lines))

## Load Data

In [8]:
class OfficeSeq2Seq(Dataset):
    """Fixed-length (source, decoder-input, labels) triples for seq2seq training.

    Args:
        enc_src: List of encoded source sequences (lists of token ids).
        enc_tgt: List of encoded target sequences, parallel to `enc_src`.
        max_src: Length every source sequence is padded / truncated to.
        max_tgt: Length every target sequence is padded / truncated to
            before the teacher-forcing shift.
    """

    def __init__(self, enc_src, enc_tgt, max_src: int=MAX_SRC_LEN, max_tgt: int=MAX_TGT_LEN) -> None:
        super().__init__()
        self.src = enc_src
        self.tgt = enc_tgt
        self.max_src = max_src
        self.max_tgt = max_tgt

    def __len__(self):
        """Number of (source, target) pairs, taken from the source list."""
        return len(self.src)

    @staticmethod
    def _fit(ids, max_len):
        """Pad or truncate `ids` to `max_len`, keeping a trailing EOS when truncating."""
        # Plain truncation would cut off the closing EOS of an over-length
        # sequence, so the labels for long targets would never contain the
        # stop token. Re-attach it in the last kept slot.
        if len(ids) > max_len > 0 and ids[-1] == EOS:
            ids = ids[:max_len - 1] + [EOS]
        return pad_to(ids, max_len)

    def __getitem__(self, i):
        """Return (src, dec_in, labels) as long tensors for sample `i`.

        `src` has length `max_src`; `dec_in` and `labels` both have length
        `max_tgt - 1`.
        """
        s = self._fit(self.src[i], self.max_src)
        t = self._fit(self.tgt[i], self.max_tgt)
        dec_in = t[:-1]  # decoder input: everything but the last token
        labels = t[1:]   # labels: the same sequence shifted one step ahead
        # Explicit dtype keeps token ids integral even for an empty slice.
        return (
            torch.tensor(s, dtype=torch.long),
            torch.tensor(dec_in, dtype=torch.long),
            torch.tensor(labels, dtype=torch.long),
        )

`__len__`: tells PyTorch how many samples are in the dataset

`__getitem__`: when DataLoader asks for the i-th item in our dataset:
- take the i-th source (self.src[i]) and pad it to max_src
- take the i-th target (self.tgt[i]) and pad it to max_tgt

**Teacher Forcing & Sequence Shifting**:

In seq2seq models (like Transformers), we train the decoder to predict the next token
given all previous true tokens, not its own predictions.
This is called Teacher Forcing.

We prepare the target sequence `t` like this:

| Token role | Example | Explanation |
|-------------|----------|-------------|
| `t` | `[BOS, H, e, l, l, o, EOS]` | the full target sequence |
| Decoder input `dec_in` | `[BOS, H, e, l, l, o]` | shifted right (starts with BOS) |
| Labels `labels` | `[H, e, l, l, o, EOS]` | shifted left (the "next" tokens) |

During training, at each time step *t*, the model sees the real tokens up to and including `dec_in[t]` and learns to predict the next token (`labels[t]`).

At inference time, we feed back the model’s *own* predictions instead
(one token at a time).

In [9]:
# 90 / 10 train–validation split, taken in file order.
N = len(enc_src)
split = int(0.9 * N)

train_ds = OfficeSeq2Seq(enc_src[:split], enc_tgt[:split])
val_ds = OfficeSeq2Seq(enc_src[split:], enc_tgt[split:])

# Training batches are shuffled and the ragged final batch is dropped;
# validation keeps file order and every sample.
train_dl = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    drop_last=True,
)
val_dl = DataLoader(
    val_ds,
    batch_size=32,
    shuffle=False,
    drop_last=False,
)

## Build Model

In [10]:
# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

vocab_size = 3 + 256  # specials (PAD,BOS,EOS) + bytes

model = build_transformer(
    src_vocab_size=vocab_size,
    tgt_vocab_size=vocab_size,
    src_seq_len=MAX_SRC_LEN,
    tgt_seq_len=MAX_TGT_LEN - 1,  # decoder inputs are one token shorter after the shift
    d_model=256,
    N=3,
    h=4,
    dropout=0.1,
    d_ff=1024,
)
model = model.to(device)

  nn.init.xavier_uniform(p)


## Loss / Optimizer

In [11]:
# Cross-entropy over the vocabulary; positions labelled PAD are excluded.
crit = nn.CrossEntropyLoss(ignore_index=PAD)

# AdamW over every model parameter.
opt = AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),
    weight_decay=0.05,
)

## Train / Eval loops

In [23]:
# --- train/eval loops with tqdm progress bar ---

def run_epoch(loader, train=True):
    """Run one pass over `loader` and return the mean per-batch loss.

    Args:
        loader: Iterable yielding (src, dec_in, labels) batches.
        train: If True, the model is put in training mode and the optimizer
            is stepped after every batch. If False, the model is put in eval
            mode and no gradients are tracked.

    Returns:
        The average of the per-batch losses as a float (0.0 for an empty
        loader). Each batch counts equally regardless of its size.
    """
    model.train(train)
    total = 0.0
    steps = 0

    phase = "train" if train else "val"

    # tqdm as a context manager closes the bar even if a batch raises.
    # set_grad_enabled(train) skips building the autograd graph for
    # evaluation passes even when this function is called directly.
    with tqdm(loader, desc=f"{phase} epoch", leave=False) as pbar, \
            torch.set_grad_enabled(train):
        for src, dec_in, labels in pbar:
            src, dec_in, labels = src.to(device), dec_in.to(device), labels.to(device)

            # Masks
            src_mask = enc_pad_mask(src)
            _, _, tgt_mask = dec_masks(dec_in)

            # Forward
            enc_out = model.encode(src, src_mask)
            dec_out = model.decode(enc_out, src_mask, dec_in, tgt_mask)
            logits = model.project(dec_out)

            # Flatten (B,T,V) → (B*T,V) and (B,T) → (B*T,); reshape also
            # handles non-contiguous inputs, which view does not.
            loss = crit(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

            if train:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()

            # Read the scalar once; each .item() forces a device sync.
            loss_value = loss.item()
            total += loss_value
            steps += 1

            # update progress bar
            pbar.set_postfix({"loss": f"{loss_value:.3f}"})

    return total / max(1, steps)


def evaluate(loader):
    """Return the mean loss over `loader` without training or gradient tracking."""
    with torch.no_grad():
        return run_epoch(loader, train=False)


# --- main training loop with progress and stats ---
# Ten epochs: one training pass, then one validation pass, then a summary line.
for epoch in range(10):
    tr = run_epoch(train_dl, train=True)
    va = evaluate(val_dl)
    # Perplexity is the exponential of the mean validation cross-entropy.
    ppl = math.exp(va)
    print(f"epoch {epoch+1:02d} | train {tr:.3f} | val {va:.3f} | ppl {ppl:.2f}")

NameError: name 'train_dl' is not defined

**`for src, dec_in, labels in pbar`** batch loop:
- `src`: source sequence (the dialogue before Michael speaks)
- `dec_in`: decoder input (shifted Michael response)
- `labels`: expected next tokens
- `src_mask` &  `tgt_mask`: builds proper masks for src & tgt
- `enc_out`: encode the src sequence
- `dec_out`: decode given src sequence & previous outputs
- `logits`: projects to vocab size
- `loss`: computes the CrossEntropyLoss
- `if train`: backpropagates the loss, clips the gradient norm, and updates the weights

## Generation

In [None]:
def dec_mask_from(dec_in, pad_id=PAD):
    """(B,T) -> (B,1,T,T) keep-mask (padding ∧ causal) for the decoder.

    Args:
        dec_in: Decoder input ids of shape (B,T).
        pad_id: Token id treated as padding.

    Returns:
        Boolean tensor of shape (B,1,T,T); True where attention is allowed.
    """
    # Reuse the training-time helper so generation and training cannot drift
    # apart; its third return value is the combined padding-and-causal mask.
    _, _, tgt_mask = dec_masks(dec_in, pad_id)
    return tgt_mask

@torch.no_grad()
def generate_reply(context_text, max_new_tokens=80):
    """Greedily decode a reply for `context_text`.

    Args:
        context_text: Dialogue context to condition on.
        max_new_tokens: Upper bound on generated tokens. It is additionally
            capped at MAX_TGT_LEN - 1, the decoder length the model was
            built with.

    Returns:
        The decoded reply text (special tokens are dropped by decode_ids).
    """
    model.eval()
    # encode & pad the source once
    src_ids = pad_to(encode_str(context_text), MAX_SRC_LEN)
    src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)   # (1,S)
    src_mask = enc_pad_mask(src)                                                # (1,1,1,S)
    enc_out = model.encode(src, src_mask)                                       # (1,S,d)

    # The model is built with tgt_seq_len = MAX_TGT_LEN - 1, so never feed the
    # decoder a longer sequence than that.
    # NOTE(review): presumably longer inputs overflow the positional table —
    # confirm in transformer.py.
    budget = min(max_new_tokens, MAX_TGT_LEN - 1)

    # start decoder with BOS
    dec = torch.tensor([[BOS]], dtype=torch.long, device=device)                # (1,1)

    # generate until EOS or the token budget is reached
    while (dec.size(1) - 1) < budget:
        tgt_mask = dec_mask_from(dec)                                           # (1,1,T,T)
        dec_out = model.decode(enc_out, src_mask, dec, tgt_mask)                # (1,T,d)
        logits  = model.project(dec_out)                                        # (1,T,V)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)              # greedy → (1,1)
        dec = torch.cat([dec, next_token], dim=1)                               # append
        if next_token.item() == EOS:
            break

    return decode_ids(dec[0].tolist())

# example: ask the model to answer a line from Jim
example_context = "[JIM] Is Dwight the assistant regional manager?"
print(generate_reply(example_context))

[MICHAEL] Wel, that that t that t that t t thathat t t thathe t t the thathe ath
