
# P4_transformers_correction

**Practical session: Transformers vs CNNs (Dyck validation, Addition, Parity)**  
Colab-ready, modular forward pass, timing & FLOPs, and controlled CNN baselines.



### **I.A.** Dyck-1 (balanced parentheses) dataset & validation target

Goal: binary classification (valid vs invalid) of sequences over the alphabet `{ '(', ')' }`, with optional distractor tokens `x` that are ignored when checking validity.
- We generate balanced (valid) sequences of variable length, and corrupt about half of them by flipping one bracket to obtain invalid ones.
- Baselines: naive 1D CNN vs dilated CNN vs small Transformer.
- Expectation: naive CNN struggles on long-range nesting; dilation helps; Transformer handles it cleanly.


In [None]:

#@title Imports & utilities
import math, random, time, os, sys
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Pick the compute device once; later cells move models and batches onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# Simple seed control
def set_seed(seed=42):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed); 
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
set_seed(0)

# Vocab and helpers
# Index 0 is reserved for padding. pad_collate fills batches with index 0 by
# default; if index 0 were '(' a padded valid word such as "(())" + "((" would be
# the exact same tensor as the corrupted word "(())((" with the opposite label
# (unless every model masked padding out), making the Dyck task ill-posed.
PAD = '<pad>'
VOCAB = [PAD, '(', ')', 'x']  # 'x' optional distractor
stoi = {c:i for i,c in enumerate(VOCAB)}
itos = {i:c for c,i in stoi.items()}

def gen_dyck1_seq(n_pairs, p_distractor=0.0, max_distractors=0):
    """Generate a Dyck-1 string and its label (1 = valid, 0 = invalid).

    A balanced word with ``n_pairs`` bracket pairs is built by a random push/pop walk
    that never goes below depth 0 and is forced back to depth 0 at the end. With
    probability 0.5 a single bracket is then flipped, which makes the word invalid.
    Finally ``k ~ U{0..max_distractors}`` insertion attempts are made, each inserting
    an 'x' at a random position with probability ``p_distractor``.

    Note: one flip always leaves #'(' - #')' = +/-2, so every invalid sample is also
    count-unbalanced (a counting shortcut exists besides the nesting structure).

    Args:
        n_pairs: number of '(' / ')' pairs in the underlying word (<= 0 gives '').
        p_distractor: probability of keeping each distractor insertion attempt.
        max_distractors: upper bound on the number of insertion attempts.

    Returns:
        (string, label) with label an int in {0, 1}.
    """
    seq = []
    depth = 0
    total = 2 * n_pairs
    for _ in range(total):
        if depth == 0:
            # Cannot close at depth 0: must open.
            seq.append('('); depth += 1
        elif depth == total - len(seq):
            # Exactly enough slots left to close every open bracket: must close.
            seq.append(')'); depth -= 1
        elif random.random() < 0.5:
            seq.append('('); depth += 1
        else:
            seq.append(')'); depth -= 1
    assert depth == 0

    # with 50% probability, corrupt to make an invalid example
    is_valid = (random.random() < 0.5)
    if not is_valid:
        if seq:
            # A single flip changes the '(' / ')' balance by 2, so the result is
            # guaranteed to be invalid; no re-validation is needed.
            pos = random.randrange(len(seq))
            seq[pos] = '(' if seq[pos] == ')' else ')'
        else:
            # The empty word cannot be corrupted by a flip (randrange(0) would
            # raise); it is genuinely valid, so label it as such.
            is_valid = True

    # inject distractors
    if p_distractor > 0 and max_distractors > 0:
        k = random.randint(0, max_distractors)
        for _ in range(k):
            if random.random() < p_distractor:
                j = random.randrange(len(seq)+1)
                seq.insert(j, 'x')

    return ''.join(seq), int(is_valid)

def validate_dyck1(s):
    """Return True iff the brackets in ``s`` form a balanced Dyck-1 word.

    Any character other than '(' or ')' (e.g. the 'x' distractor) is skipped.
    """
    balance = 0
    for ch in s:
        if ch == ')':
            balance -= 1
            if balance < 0:
                # A closing bracket with nothing open: invalid immediately.
                return False
        elif ch == '(':
            balance += 1
    return balance == 0

class DyckDataset(Dataset):
    """In-memory Dyck-1 classification dataset of (token-id tensor, label tensor) pairs.

    All samples are generated eagerly in ``__init__``. For each one, the number of
    pairs is drawn uniformly from ``n_pairs_range`` (inclusive) and passed to
    ``gen_dyck1_seq`` together with the distractor settings.
    """
    def __init__(self, n_samples=20000, n_pairs_range=(2, 40), p_distractor=0.0, max_distractors=0):
        def draw_one():
            pairs = random.randint(*n_pairs_range)
            text, label = gen_dyck1_seq(pairs, p_distractor, max_distractors)
            ids = torch.tensor([stoi[ch] for ch in text], dtype=torch.long)
            return ids, torch.tensor(label, dtype=torch.long)

        self.samples = [draw_one() for _ in range(n_samples)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def pad_collate(batch, pad_idx=0):
    """Right-pad a list of (ids, label) samples into a batch.

    Returns:
        X: (B, max_len) long tensor, positions past each sequence filled with ``pad_idx``.
        lengths: (B,) long tensor with the true sequence lengths.
        Y: (B,) stacked labels.
    """
    seqs, labels = zip(*batch)
    lengths = torch.tensor(list(map(len, seqs)), dtype=torch.long)
    X = torch.full((len(seqs), int(lengths.max())), pad_idx, dtype=torch.long)
    for row, seq in zip(X, seqs):
        # `row` is a view into X, so this writes the batch tensor in place.
        row[:len(seq)] = seq
    return X, lengths, torch.stack(labels)



### **I.B.** Naive 1D CNN baseline (no dilation)

A small temporal CNN with limited receptive field; we pool over time for the final binary decision.


In [None]:

class NaiveCNN(nn.Module):
    """Plain 1D CNN (stride 1, no dilation) + global average pooling + linear head.

    For odd ``kernel``, each output position only sees ``n_layers * (kernel - 1) + 1``
    neighbouring tokens. That limited receptive field is what this baseline illustrates.

    Args:
        vocab_size: number of token ids (rows of the embedding table).
        emb_dim: embedding width.
        hidden: number of channels of every convolution.
        n_layers: number of Conv1d + ReLU layers.
        kernel: convolution width (``padding=kernel//2`` keeps length for odd widths).
        num_classes: number of output classes.
    """
    def __init__(self, vocab_size=len(VOCAB), emb_dim=32, hidden=64, n_layers=2, kernel=3, num_classes=2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        layers = []
        c_in = emb_dim
        for _ in range(n_layers):
            layers += [nn.Conv1d(c_in, hidden, kernel_size=kernel, padding=kernel//2),
                       nn.ReLU()]
            c_in = hidden
        self.net = nn.Sequential(*layers)
        self.head = nn.Linear(hidden, num_classes)

    def forward(self, x, lengths=None):
        """Return (B, num_classes) logits for token ids ``x`` of shape (B, L).

        If ``lengths`` (B,) is given, positions ``>= lengths[b]`` (padding) are left
        out of the average pool; otherwise every position is averaged as before.
        """
        e = self.emb(x).transpose(1, 2)   # (B, C, L)
        h = self.net(e)                   # (B, H, L)
        if lengths is None:
            h = h.mean(dim=-1)            # global average pool over time
        else:
            # Masked mean: `lengths` may arrive on the CPU (from the collate fn).
            L = h.size(-1)
            lengths = torch.as_tensor(lengths, device=h.device).clamp(min=1)
            keep = (torch.arange(L, device=h.device).unsqueeze(0) < lengths.unsqueeze(1)).to(h.dtype)  # (B, L)
            h = (h * keep.unsqueeze(1)).sum(dim=-1) / keep.sum(dim=-1, keepdim=True)
        return self.head(h)               # (B, num_classes)



### **I.C.** Training loop (shared)

We'll reuse the same training/eval utilities for all models.


In [None]:

@dataclass
class TrainConfig:
    """Hyper-parameters shared by the experiments in this notebook."""
    epochs: int = 5
    batch_size: int = 128
    lr: float = 2e-3
    max_batches_per_epoch: int = 200  # cap on *training* batches per epoch (speed in class); 0/None = no cap
    clip_grad: float = 1.0  # max global grad-norm; None disables clipping


def accuracy(logits, y):
    """Fraction of rows whose arg-max class equals the target label."""
    return (logits.argmax(dim=-1) == y).float().mean().item()


def run_epoch(model, loader, opt=None, cfg=None):
    """Run one pass over ``loader``; train when ``opt`` is given, otherwise evaluate.

    Args:
        model: module called as ``model(X, lengths)`` and returning (B, C) logits.
        loader: iterable of ``(X, lengths, Y)`` batches (see ``pad_collate``).
        opt: optimizer; ``None`` means evaluation (no gradient tracking, no updates).
        cfg: TrainConfig providing ``clip_grad`` and ``max_batches_per_epoch``.
            Defaults to ``TrainConfig()``. It used to be read from a global ``cfg``,
            which silently ignored the config passed to ``train_eval``.

    Returns:
        (mean loss, mean accuracy), weighted by batch size.

    Raises:
        ValueError: if the loader yields no batches.
    """
    if cfg is None:
        cfg = TrainConfig()
    is_train = opt is not None
    model.train(is_train)
    # Move batches to wherever the model lives instead of relying on a global device.
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device('cpu')
    total_loss, total_acc, n = 0.0, 0.0, 0
    # No autograd graph during evaluation: saves memory and time.
    with torch.set_grad_enabled(is_train):
        for i, (X, lengths, Y) in enumerate(loader):
            X, Y = X.to(dev), Y.to(dev)
            logits = model(X, lengths)
            loss = F.cross_entropy(logits, Y)
            if is_train:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                if cfg.clip_grad is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad)
                opt.step()
            bs = X.size(0)
            total_loss += loss.item() * bs
            total_acc += accuracy(logits, Y) * bs
            n += bs
            if is_train and cfg.max_batches_per_epoch and i + 1 >= cfg.max_batches_per_epoch:
                break
    if n == 0:
        raise ValueError("run_epoch: loader yielded no batches")
    return total_loss / n, total_acc / n


def train_eval(model, train_loader, val_loader, cfg: TrainConfig):
    """Train with AdamW for ``cfg.epochs`` epochs and keep the best validation checkpoint.

    After every epoch the model is evaluated on ``val_loader``. The weights of the epoch
    with the highest validation accuracy are restored before returning.

    Returns:
        (model, best validation accuracy); the accuracy is 0.0 if no epoch ran.
    """
    model = model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    best_acc, best_state = 0.0, None
    for ep in range(cfg.epochs):
        tr_loss, tr_acc = run_epoch(model, train_loader, opt, cfg)
        va_loss, va_acc = run_epoch(model, val_loader, None, cfg)
        if best_state is None or va_acc > best_acc:
            best_acc = va_acc
            # .cpu() returns the *same* storage for CPU tensors, so clone: otherwise
            # later optimizer steps would overwrite the "best" snapshot in place.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        print(f"Epoch {ep+1}/{cfg.epochs} | train acc {tr_acc:.3f} | val acc {va_acc:.3f}")
    # load best
    if best_state is not None:
        model.load_state_dict(best_state)
    return model, best_acc



### **II.A.** Minimal Transformer (modular forward pass)

We decompose the forward pass to mirror the step-by-step style used in *BE_session2_exercice*:
- `token_embed` + `pos_embed`
- Linear projections to Q, K, V
- Scaled dot-product attention
- Multi-head split/merge
- MLP (feed-forward) per token
- Residual + LayerNorm


In [None]:

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a (B, L, D) embedding tensor.

    Even feature columns hold sin(pos * w_i) and odd columns hold cos(pos * w_i), with
    w_i = 10000^(-2i/d_model). The table is precomputed for ``max_len`` positions and
    stored as a non-trainable buffer of shape (1, max_len, d_model).
    """
    def __init__(self, d_model, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, L, D)

    def forward(self, x):
        """Return ``x + pe[:, :L]`` for ``x`` of shape (B, L, d_model).

        Raises:
            ValueError: if L exceeds the ``max_len`` the table was built for.
        """
        L = x.size(1)
        if L > self.pe.size(1):
            raise ValueError(f"sequence length {L} exceeds PositionalEncoding max_len {self.pe.size(1)}")
        return x + self.pe[:, :L, :]

def attention_scores(Q, K, mask=None):
    """Scaled dot-product attention weights softmax(Q K^T / sqrt(Dh)).

    Args:
        Q, K: (B, H, L, Dh) query / key tensors.
        mask: optional tensor broadcastable to (B, H, L, L); entries equal to 0
            (or False) are blocked by setting their score to -inf before the softmax.

    Returns:
        (B, H, L, L) attention probabilities (each row sums to 1 over the keys).
    """
    head_dim = Q.shape[-1]
    logits = (Q @ K.transpose(-2, -1)) / math.sqrt(head_dim)
    if mask is not None:
        blocked = mask == 0
        logits = logits.masked_fill(blocked, float('-inf'))
    return logits.softmax(dim=-1)

def attention_apply(P, V):
    """Mix values with attention weights: (B,H,L,L) @ (B,H,L,Dh) -> (B,H,L,Dh)."""
    return P @ V

class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention written as explicit steps: project, split heads, attend, merge.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of heads, each working in ``d_model // n_heads`` dimensions.
        dropout: dropout probability applied to the attention probabilities.
        causal: if True, query position ``i`` may only attend to key positions ``<= i``.
    """
    def __init__(self, d_model=128, n_heads=4, dropout=0.0, causal=False):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model, self.n_heads, self.d_head = d_model, n_heads, d_model // n_heads
        self.Wq = nn.Linear(d_model, d_model, bias=False)
        self.Wk = nn.Linear(d_model, d_model, bias=False)
        self.Wv = nn.Linear(d_model, d_model, bias=False)
        self.Wo = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.causal = causal

    def forward(self, x, key_padding_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: (B, L, D) input sequence.
            key_padding_mask: optional (B, L) boolean tensor, True for real tokens and
                False for padding. Padded positions are never used as attention keys.

        Returns:
            (B, L, D) tensor.
        """
        B, L, D = x.shape
        q = self.Wq(x).view(B, L, self.n_heads, self.d_head).transpose(1,2)  # (B,H,L,Dh)
        k = self.Wk(x).view(B, L, self.n_heads, self.d_head).transpose(1,2)
        v = self.Wv(x).view(B, L, self.n_heads, self.d_head).transpose(1,2)
        # Boolean "may attend" mask broadcast against the (B,H,L,L) scores;
        # attention_scores sets the score to -inf wherever mask == 0 (False).
        mask = None
        if self.causal:
            mask = torch.ones(L, L, dtype=torch.bool, device=x.device).tril().view(1, 1, L, L)
        if key_padding_mask is not None:
            keep = key_padding_mask.to(device=x.device, dtype=torch.bool).view(B, 1, 1, L)
            mask = keep if mask is None else (mask & keep)  # (B,1,L,L)
        P = attention_scores(q, k, mask)        # (B,H,L,L)
        P = self.dropout(P)
        z = attention_apply(P, v)               # (B,H,L,Dh)
        z = z.transpose(1,2).contiguous().view(B, L, D)  # merge heads
        return self.Wo(z)


class TransformerBlock(nn.Module):
    """Pre-LayerNorm block: x + Attn(LN(x)), then x + MLP(LN(x)).

    The MLP expands to ``d_model * mlp_ratio`` with a GELU and projects back.
    """
    def __init__(self, d_model=128, n_heads=4, mlp_ratio=4, dropout=0.1, causal=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, n_heads, dropout=dropout, causal=causal)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model*mlp_ratio),
            nn.GELU(),
            nn.Linear(d_model*mlp_ratio, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x, key_padding_mask=None):
        """x: (B, L, D); key_padding_mask: optional (B, L) bool, True = real token."""
        x = x + self.attn(self.ln1(x), key_padding_mask=key_padding_mask)
        x = x + self.mlp(self.ln2(x))
        return x


class TinyTransformer(nn.Module):
    """Sequence classifier with two stages.

    1. Token embedding plus sinusoidal positions, followed by ``n_layers``
       non-causal blocks.
    2. Mean-pooling over time, followed by a linear head.

    When ``lengths`` is given, padded positions are masked both as attention keys and
    in the mean pool, so the prediction does not depend on the padding tokens.
    """
    def __init__(self, vocab_size=len(VOCAB), d_model=128, n_layers=2, n_heads=4, num_classes=2, max_len=4096):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model, max_len=max_len)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, causal=False) for _ in range(n_layers)])
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x, lengths=None):
        """Return (B, num_classes) logits for token ids ``x`` of shape (B, L).

        ``lengths`` (B,) holds the true sequence lengths, e.g. from ``pad_collate``.
        If omitted, every position is treated as a real token.
        """
        h = self.tok(x)
        h = self.pos(h)
        keep = None
        if lengths is not None:
            L = x.size(1)
            # clamp(min=1) keeps at least one key per row (avoids an all -inf softmax).
            lengths = torch.as_tensor(lengths, device=x.device).clamp(min=1)
            keep = torch.arange(L, device=x.device).unsqueeze(0) < lengths.unsqueeze(1)  # (B, L)
        for blk in self.blocks:
            h = blk(h, key_padding_mask=keep)
        # classification: mean-pool over (real) time steps
        if keep is None:
            h = h.mean(dim=1)
        else:
            w = keep.unsqueeze(-1).to(h.dtype)        # (B, L, 1)
            h = (h * w).sum(dim=1) / w.sum(dim=1)
        return self.head(h)



### **II.B.** Addition and Parity tasks

Two synthetic sequence-to-label tasks:
- **Addition**: strings like `"123+45="` → the label is the last digit of the sum, i.e. `(a + b) mod 10`. This is a 10-way classification; predicting every digit of the sum would be a harder, sequence-output variant.
- **Parity**: binary strings → label 0/1 is the parity (XOR) of ones.
We keep them as **sequence classification** to compare architectures fairly.


In [None]:

# Addition (predict last digit of the sum) over base-10 characters
DIGITS = list("0123456789")
ADD_VOCAB = [*DIGITS, '+', '=']  # ids 0-9 are the digits themselves, then '+' and '='
add_stoi = {ch: idx for idx, ch in enumerate(ADD_VOCAB)}
add_itos = dict(zip(add_stoi.values(), add_stoi.keys()))

def gen_add_sample(n1_digits=3, n2_digits=3):
    """Draw a random addition problem "a+b=" and return (token ids, last digit of a+b).

    ``a`` has at most ``n1_digits`` digits and ``b`` at most ``n2_digits`` digits.
    Both are drawn uniformly, with ``a`` drawn first.
    """
    lhs = random.randint(0, 10**n1_digits - 1)
    rhs = random.randint(0, 10**n2_digits - 1)
    expr = str(lhs) + "+" + str(rhs) + "="
    label = torch.tensor((lhs + rhs) % 10, dtype=torch.long)
    ids = torch.tensor([add_stoi[ch] for ch in expr], dtype=torch.long)
    return ids, label

class AddDataset(Dataset):
    """In-memory dataset of ``n`` addition problems labelled with the sum's last digit."""
    def __init__(self, n, n1_digits=3, n2_digits=3):
        self.samples = []
        for _ in range(n):
            self.samples.append(gen_add_sample(n1_digits, n2_digits))
        self.num_classes = 10  # labels are (a + b) % 10

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        return self.samples[i]

def add_pad_collate(batch, pad_idx=0):
    """Collate Addition samples into ``(X, lengths, Y)``; same contract as ``pad_collate``.

    This was a verbatim copy of ``pad_collate``; it now delegates so the two cannot drift.
    Note: the default ``pad_idx=0`` coincides with the digit '0' (``add_stoi['0'] == 0``).
    Every sample ends with '=', so padding only ever follows '=' and stays unambiguous.
    """
    return pad_collate(batch, pad_idx=pad_idx)

# Parity over binary strings
BIN_VOCAB = list('01')
bin_stoi = dict(zip(BIN_VOCAB, range(len(BIN_VOCAB))))  # '0' -> 0, '1' -> 1

def gen_parity_sample(L=64):
    """Draw ``L`` random bits and return (token ids, parity of the number of '1's)."""
    bits = [random.choice(BIN_VOCAB) for _ in range(L)]
    parity = bits.count('1') % 2
    ids = torch.tensor([bin_stoi[b] for b in bits], dtype=torch.long)
    return ids, torch.tensor(parity, dtype=torch.long)

class ParityDataset(Dataset):
    """In-memory dataset of ``n`` binary strings, all of length ``L``, with parity labels."""
    def __init__(self, n, L=64):
        self.samples = []
        for _ in range(n):
            self.samples.append(gen_parity_sample(L))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        return self.samples[i]

def bin_pad_collate(batch, pad_idx=0):
    """Collate Parity samples into ``(X, lengths, Y)``; same contract as ``pad_collate``.

    This was a verbatim copy of ``pad_collate``; it now delegates so the two cannot drift.
    The default pad id 0 is the bit '0', which never changes the parity of a string.
    """
    return pad_collate(batch, pad_idx=pad_idx)



### **II.C.** Dilated CNN + experiment harness

We compare:
1) NaiveCNN (no dilation)  
2) DilatedCNN (growing dilation)  
3) TinyTransformer

Same training harness; you can switch datasets.


In [None]:

class DilatedCNN(nn.Module):
    """Stack of dilated Conv1d + ReLU layers, global average pooling, linear head.

    ``layers_cfg`` is a sequence of ``(kernel, dilation)`` pairs. For odd kernels the
    receptive field is ``1 + sum((k - 1) * d)``. With the default config that is 31
    tokens, versus 5 for a two-layer kernel-3 undilated CNN.

    Args:
        vocab_size: number of token ids (rows of the embedding table).
        emb_dim: embedding width.
        hidden: number of channels of every convolution.
        layers_cfg: ``(kernel, dilation)`` for each layer.
        num_classes: number of output classes.
    """
    def __init__(self, vocab_size, emb_dim=64, hidden=96, layers_cfg=((3,1),(3,2),(3,4),(3,8)), num_classes=2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        blocks = []
        c_in = emb_dim
        for k, d in layers_cfg:
            pad = (k-1)//2 * d  # "same" length for odd k
            blocks += [nn.Conv1d(c_in, hidden, kernel_size=k, dilation=d, padding=pad),
                       nn.ReLU()]
            c_in = hidden
        self.net = nn.Sequential(*blocks)
        self.head = nn.Linear(hidden, num_classes)

    def forward(self, x, lengths=None):
        """Return (B, num_classes) logits for token ids ``x`` of shape (B, L).

        If ``lengths`` (B,) is given, positions ``>= lengths[b]`` (padding) are left
        out of the average pool; otherwise every position is averaged as before.
        """
        e = self.emb(x).transpose(1,2)   # (B, C, L)
        h = self.net(e)                  # (B, H, L)
        if lengths is None:
            h = h.mean(dim=-1)
        else:
            # Masked mean: `lengths` may arrive on the CPU (from the collate fn).
            L = h.size(-1)
            lengths = torch.as_tensor(lengths, device=h.device).clamp(min=1)
            keep = (torch.arange(L, device=h.device).unsqueeze(0) < lengths.unsqueeze(1)).to(h.dtype)  # (B, L)
            h = (h * keep.unsqueeze(1)).sum(dim=-1) / keep.sum(dim=-1, keepdim=True)
        return self.head(h)

# Example: run a quick sanity experiment on Dyck
# Training configuration for this experiment; it is passed to every train_eval call below.
cfg = TrainConfig(epochs=3, batch_size=128, lr=2e-3, max_batches_per_epoch=150)

# Validation uses longer words (up to 40 pairs vs 30) and more distractor attempts
# (up to 6 vs 4) than training, so it also probes length generalization.
train_ds = DyckDataset(n_samples=8000, n_pairs_range=(2,30), p_distractor=0.1, max_distractors=4)
val_ds   = DyckDataset(n_samples=2000, n_pairs_range=(2,40), p_distractor=0.1, max_distractors=6)
train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, collate_fn=pad_collate, num_workers=0)
val_loader   = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, collate_fn=pad_collate, num_workers=0)

# Receptive field: 2 layers of kernel 3 -> each output position sees 5 tokens.
print("Naive CNN on Dyck")
model = NaiveCNN(vocab_size=len(VOCAB), emb_dim=32, hidden=64, n_layers=2, kernel=3, num_classes=2)
model, acc_naive = train_eval(model, train_loader, val_loader, cfg)

# Receptive field: 1 + 2*(1+2+4+8) = 31 tokens.
print("\nDilated CNN on Dyck")
model = DilatedCNN(vocab_size=len(VOCAB), emb_dim=64, hidden=96, layers_cfg=((3,1),(3,2),(3,4),(3,8)), num_classes=2)
model, acc_dil = train_eval(model, train_loader, val_loader, cfg)

# Self-attention: every position can attend to the whole sequence in one layer.
print("\nTransformer on Dyck")
model = TinyTransformer(vocab_size=len(VOCAB), d_model=128, n_layers=2, n_heads=4, num_classes=2)
model, acc_tx = train_eval(model, train_loader, val_loader, cfg)

# Each acc_* is the best validation accuracy reached over the epochs (as returned by train_eval).
print(f"Val acc: naive {acc_naive:.3f} | dilated {acc_dil:.3f} | transformer {acc_tx:.3f}")



### **II.D.** FLOPs and duration vs number of attention heads

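We provide **(i)** a closed-form FLOPs estimate for a single Transformer block (self-attention + MLP) and **(ii)** a micro-benchmark that measures wall-clock time as a function of the number of heads `H`, for fixed `d_model` and sequence length `L`. Since `H · d_head = d_model`, the FLOPs estimate does not depend on `H`. Any change in measured time with `H` therefore reflects implementation and hardware efficiency (more, smaller per-head matrix products), not extra arithmetic.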


In [None]:

def estimate_block_flops(L, d_model, n_heads, mlp_ratio=4):
    """Closed-form forward-pass FLOPs of one Transformer block for a single sequence.

    Every matrix product is counted as 2 FLOPs per multiply-accumulate (one multiply,
    one add). This is applied consistently to the Q/K/V/O projections, the two
    attention matmuls and the MLP. The previous version counted 1 FLOP per MAC for
    the linear layers but 2 for attention, which under-weighted the projections.
    Softmax, LayerNorm, GELU, biases and residual adds are ignored (lower-order terms).

    Because ``n_heads * d_head == d_model`` (when divisible), the number of heads
    only changes how the same work is split, not the total.

    Args:
        L: sequence length.
        d_model: model width.
        n_heads: number of attention heads.
        mlp_ratio: MLP hidden width relative to ``d_model``.

    Returns:
        Estimated FLOPs for one sequence (multiply by the batch size for a batch).
    """
    d_head = d_model // n_heads
    d_attn = n_heads * d_head  # total width used by the attention matmuls (== d_model)
    # Q, K, V projections: 3 matmuls of (L x d) @ (d x d)
    flops_qkv = 2 * 3 * L * d_model * d_model
    # Scores Q @ K^T, summed over heads: (L x Dh) @ (Dh x L) per head
    flops_scores = 2 * L * L * d_attn
    # Weighted sum P @ V, summed over heads: (L x L) @ (L x Dh) per head
    flops_weighted = 2 * L * L * d_attn
    # Output projection Wo: (L x d) @ (d x d)
    flops_wo = 2 * L * d_model * d_model
    # MLP: two linears, d -> d*r and d*r -> d
    flops_mlp = 2 * 2 * L * d_model * (d_model * mlp_ratio)
    return flops_qkv + flops_scores + flops_weighted + flops_wo + flops_mlp

@torch.no_grad()
def time_block(B=16, L=256, d_model=256, n_heads_list=(1,2,4,8), iters=20, warmup=10):
    """Measure the mean forward time (seconds per call) of one TransformerBlock per head count.

    For each ``h`` in ``n_heads_list`` a fresh block (dropout 0, eval mode) is built.
    It is run ``warmup`` times untimed and then ``iters`` times timed on the same
    (B, L, d_model) random input. CUDA work is synchronized before reading the clock,
    since kernels launch asynchronously.

    Returns:
        dict mapping number of heads -> seconds per forward pass.

    Raises:
        ValueError: if ``iters < 1``.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")
    times = {}
    x = torch.randn(B, L, d_model, device=device)
    for h in n_heads_list:
        blk = TransformerBlock(d_model=d_model, n_heads=h, dropout=0.0, causal=False).to(device).eval()
        # warmup (kernel selection, caches, allocator)
        for _ in range(warmup):
            _ = blk(x)
        if device.type == 'cuda':
            torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution; time.time() can jump and is coarse on some OSes.
        t0 = time.perf_counter()
        for _ in range(iters):
            _ = blk(x)
        if device.type == 'cuda':
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        times[h] = (t1 - t0) / iters
    return times

# Example numbers
# FLOPs estimate (per sequence, L = D = 256) and measured forward time per head count.
# Note: estimate_block_flops is for a single sequence, while the timing below runs a
# batch of B=8, so the two are not directly comparable without multiplying by B.
L, D = 256, 256
for h in [1,2,4,8]:
    print(f"Heads={h}: est FLOPs/block ~ {estimate_block_flops(L, D, h)/1e9:.2f} GFLOPs")
print("Timing (s/iter):", time_block(B=8, L=L, d_model=D, n_heads_list=(1,2,4,8)))



### **III.** Optional tasks: train on **Addition** and **Parity**

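The cell below runs **Parity**. For **Addition**, swap in `AddDataset` with `add_pad_collate`, and build the models with `vocab_size=len(ADD_VOCAB)` and `num_classes=10`. You should typically observe that: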
- Naive CNN is weaker (especially on long dependencies),  
- Dilated CNN closes much of the gap,  
- Transformer remains strong with the right depth/heads.


In [None]:

# Parity (binary, 0/1)
cfg = TrainConfig(epochs=3, batch_size=128, lr=2e-3, max_batches_per_epoch=150)
# Every string in a split has the same length, so batches need no padding.
# Validation strings are twice as long as training strings (256 vs 128): a length-generalization test.
train_ds = ParityDataset(n=8000, L=128)
val_ds   = ParityDataset(n=2000, L=256)  # harder generalization
train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, collate_fn=bin_pad_collate)
val_loader   = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, collate_fn=bin_pad_collate)

# vocab_size=2: the only tokens are '0' and '1' (BIN_VOCAB).
print("Naive CNN on Parity")
model = NaiveCNN(vocab_size=2, emb_dim=16, hidden=64, n_layers=2, kernel=3, num_classes=2)
model, acc_naive = train_eval(model, train_loader, val_loader, cfg)

print("\nDilated CNN on Parity")
model = DilatedCNN(vocab_size=2, emb_dim=32, hidden=96, layers_cfg=((3,1),(3,2),(3,4),(3,8)), num_classes=2)
model, acc_dil = train_eval(model, train_loader, val_loader, cfg)

print("\nTransformer on Parity")
model = TinyTransformer(vocab_size=2, d_model=64, n_layers=2, n_heads=4, num_classes=2)
model, acc_tx = train_eval(model, train_loader, val_loader, cfg)

# Each acc_* is the best validation accuracy reached over the epochs (as returned by train_eval).
print(f"Val acc: naive {acc_naive:.3f} | dilated {acc_dil:.3f} | transformer {acc_tx:.3f}")



### **Appendix A.** Simple training tips
- Use `AdamW`, small number of epochs for classroom speed.
- Clip gradients to stabilize early training.
- Monitor validation accuracy; keep best checkpoint in RAM.



### **Appendix B.** GPU in Colab
Ensure **GPU** is enabled: *Runtime → Change runtime type → Hardware accelerator: GPU*.
