# In [None]:
# micro_abc_sorter_levels.py
# Transformer variants ("forks") of a tiny ABC sorter, probing length generalization (train len=3, test len=3..8).
# Variants: vanilla, lora (W = W0 + A@B), softmax_gate (scaled), tanh_gate (zero-centered),
# and strict_softmax (no base). Accurate param counts, same training loop.

import math, random, itertools
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------ Repro & device ------------------------
# Seed both RNGs the script draws from: Python's `random` (batch sampling and
# evaluation sampling) and torch (parameter initialisation).
for _seed_fn in (random.seed, torch.manual_seed):
    _seed_fn(42)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# ------------------------ Config ------------------------
# Task
# Training targets are sorted by token id (ALPHABET order) while evaluation
# sorts the characters themselves; the two agree only if ALPHABET is listed in
# sorted order.
ALPHABET = ['a','b','c']  # keep 3 for now
TRAIN_LEN = 3             # train on length-3 inputs only
TEST_LENS = [3,4,5,6,7,8] # evaluate length generalization
MAX_LEN   = max(TEST_LENS)
DUPLICATES_ALLOWED = True # training inputs may repeat letters

# Model/opt
USE_POS_EMBED   = True
E = 6                  # embedding dim
H = 2                  # heads
if E % H != 0:
    # Explicit check rather than `assert`, so it still fires under `python -O`.
    raise ValueError(f"embedding dim E={E} must be divisible by number of heads H={H}")
D = E // H             # head dim
MLP_MULT = 1           # feed-forward hidden width = MLP_MULT * E
BATCH   = 128
STEPS   = 15000        # training loop runs STEPS + 1 iterations
BASE_LR = 1e-3
WARMUP_STEPS = 500     # linear LR warmup length
GRAD_CLIP = 1.0        # max global gradient norm

# Deeper-QKV knobs (shared across variants unless overridden)
RANK         = 3        # factorization rank r of A (E×r) @ B (r×D)
SOFTMAX_TEMP = 1.0      # temperature τ in AB/τ
LORA_ALPHA   = 4.0      # LoRA scaling (applied as alpha/r)
GATE_ALPHA_INIT = D     # initial per-head alpha for softmax_gate / strict_softmax:
                        # softmax over the D head-dim entries sums to 1, so alpha=D gives mean gate 1
GATE_BETA_INIT  = 1.0   # initial per-head beta in tanh_gate

# Context / embeddings: sized for the longest test length
T_MAX = 2*MAX_LEN + 1   # input L + SEP + output L

# Runs to perform (edit this list to try more/less)
RUNS = [
    {"name": "vanilla"},
    {"name": "lora", "rank": RANK, "lora_alpha": LORA_ALPHA},
    {"name": "softmax_gate", "rank": RANK, "temp": SOFTMAX_TEMP, "alpha_init": GATE_ALPHA_INIT},
    # You can uncomment to try these too:
    # {"name": "tanh_gate", "rank": RANK, "temp": SOFTMAX_TEMP, "beta_init": GATE_BETA_INIT},
    # {"name": "strict_softmax", "rank": RANK, "temp": SOFTMAX_TEMP, "alpha_init": GATE_ALPHA_INIT},
]

# ------------------------ Vocab ------------------------
# Token id 0 is the separator (the empty string); the letters follow in
# ALPHABET order as ids 1..len(ALPHABET) (assuming distinct entries).
vocab = [''] + ALPHABET     # '' is SEP
stoi = {tok: idx for idx, tok in enumerate(vocab)}
itos = {idx: tok for tok, idx in stoi.items()}
V = len(vocab)
SEP = 0

# ------------------------ Data ------------------------
def make_batch(B, L):
    """Sample a batch of sorting examples, right-padded with SEP to T_MAX.

    Each row is ``[x1..xL, SEP, y1..yL, SEP, ...]`` where ``y`` is the input
    token ids sorted ascending (i.e. in ALPHABET order). Inputs are drawn with
    Python's ``random`` module: with replacement when DUPLICATES_ALLOWED,
    otherwise without replacement. Padding positions are never read by the
    training loss, which only uses the first ``2*L + 1`` positions.

    Args:
        B: batch size.
        L: number of input symbols per example.

    Returns:
        A ``(B, T_MAX)`` long tensor on ``device``.

    Raises:
        ValueError: if ``2*L + 1`` exceeds T_MAX (the model's context), or if
            duplicates are disallowed and ``L`` exceeds the alphabet size.
    """
    seq_len = 2 * L + 1
    if seq_len > T_MAX:
        # Without this check an over-long batch is returned and the model
        # fails later with an opaque indexing/shape error.
        raise ValueError(f"length {L} needs {seq_len} positions but T_MAX is {T_MAX}")
    ids = list(range(1, len(vocab)))  # letter ids 1..V-1; id 0 is SEP
    if not DUPLICATES_ALLOWED and L > len(ids):
        raise ValueError(f"cannot draw {L} distinct symbols from an alphabet of {len(ids)}")
    pad = [SEP] * (T_MAX - seq_len)
    rows = []
    for _ in range(B):
        seq = random.choices(ids, k=L) if DUPLICATES_ALLOWED else random.sample(ids, k=L)
        rows.append(seq + [SEP] + sorted(seq) + pad)
    # reshape keeps the (0, T_MAX) shape when B == 0
    return torch.tensor(rows, dtype=torch.long, device=device).reshape(B, T_MAX)

# ------------------------ Modules ------------------------
class MHAFork(nn.Module):
    """
    Causal multi-head self-attention whose per-head Q/K/V matrices W ∈ R^{E×D}
    are assembled from learnable pieces according to ``variant``:
      - 'vanilla':         W = W_base
      - 'lora':            W = W_base + (lora_alpha / max(1, rank)) * (A @ B)
      - 'softmax_gate':    W = W_base ⊙ (alpha * softmax(A @ B / temp))      # alpha per head
      - 'tanh_gate':       W = W_base ⊙ (1 + beta * tanh(A @ B / temp))      # beta per head
      - 'strict_softmax':  W = alpha * softmax(A @ B / temp)                  # no base
    The softmax runs over the last (head-dim) axis, so each row of the gate
    sums to 1 before scaling by alpha. ``temp`` is clamped below at 1e-6.

    Shapes:
      x: (B,T,E), W: (H,E,D), q/k/v: (B,H,T,D), attention: (B,H,T,T)

    Args:
        E: model width; must be divisible by H.
        H: number of heads; head dim D = E // H.
        T_max: longest sequence the causal mask supports.
        variant: one of ``MHAFork.VARIANTS``.
        rank: inner dim r of the factors A (H,E,r) and B (H,r,D).
        temp: temperature for the softmax/tanh gates.
        lora_alpha: LoRA scale numerator.
        alpha_init: initial per-head alpha for the softmax variants; ``None``
            (default) uses this module's head dim D, giving a mean gate of 1.
        beta_init: initial per-head beta for 'tanh_gate'.

    Raises:
        ValueError: unknown ``variant``, E not divisible by H, or (in
            ``forward``) a sequence longer than ``T_max``.
    """

    VARIANTS = ("vanilla", "lora", "softmax_gate", "tanh_gate", "strict_softmax")

    def __init__(self, E, H, T_max, variant="vanilla",
                 rank=3, temp=1.0, lora_alpha=4.0,
                 alpha_init=None, beta_init=1.0):
        super().__init__()
        if variant not in self.VARIANTS:
            raise ValueError(f"Unknown variant {variant!r}; expected one of {self.VARIANTS}")
        if E % H != 0:
            raise ValueError(f"E={E} must be divisible by H={H}")
        self.H, self.D, self.E = H, E // H, E
        self.variant = variant
        self.rank = rank
        self.temp = temp
        self.lora_alpha = lora_alpha
        # Resolve against *this* module's head dim (not a global), so that
        # alpha * softmax(...) over D entries has mean 1 at init.
        self.alpha_init = self.D if alpha_init is None else alpha_init
        self.beta_init = beta_init

        has_base = variant != "strict_softmax"
        has_factors = variant != "vanilla"
        has_alpha = variant in ("softmax_gate", "strict_softmax")
        has_beta = variant == "tanh_gate"

        # Parameters are registered in a fixed order (bases, A/B factors,
        # alphas, betas). Pieces a variant does not use are registered as None,
        # so reading the attribute returns None and it adds no parameters.
        for p in "qkv":
            self.register_parameter(
                f"{p}_base", nn.Parameter(torch.empty(H, E, self.D)) if has_base else None)
        for p in "qkv":
            self.register_parameter(
                f"{p}A", nn.Parameter(torch.empty(H, E, rank)) if has_factors else None)
            self.register_parameter(
                f"{p}B", nn.Parameter(torch.empty(H, rank, self.D)) if has_factors else None)
        for p in "qkv":
            # learnable alpha per head, broadcast over (E, D)
            self.register_parameter(
                f"alpha_{p}",
                nn.Parameter(torch.full((H, 1, 1), float(self.alpha_init))) if has_alpha else None)
        for p in "qkv":
            self.register_parameter(
                f"beta_{p}",
                nn.Parameter(torch.full((H, 1, 1), float(beta_init))) if has_beta else None)

        # Output projection
        self.o = nn.Linear(E, E, bias=False)

        # Lower-triangular causal mask for the maximum context, shape (1,1,T_max,T_max)
        self.register_buffer("mask", torch.tril(torch.ones(T_max, T_max))
                             .unsqueeze(0).unsqueeze(0))

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise all weights; call order fixes RNG consumption."""
        # NOTE(review): on these 3-D (H, E, ·) tensors torch's xavier/kaiming fan
        # computation treats dim 1 as fan_in, dim 0 as fan_out and trailing dims
        # as a receptive field, so the bounds differ from per-head E×D fans —
        # confirm this scale is intended.
        for W in (self.q_base, self.k_base, self.v_base):
            if W is not None:
                nn.init.xavier_uniform_(W)
        if self.qA is not None:
            for A, Bm in ((self.qA, self.qB), (self.kA, self.kB), (self.vA, self.vB)):
                nn.init.kaiming_uniform_(A, a=math.sqrt(5))
                nn.init.kaiming_uniform_(Bm, a=math.sqrt(5))
        nn.init.xavier_uniform_(self.o.weight)

    def _AB(self, A, B):
        """Batched per-head product (H,E,r) @ (H,r,D) -> (H,E,D); None if A is None."""
        if A is None:
            return None
        return torch.matmul(A, B)

    def _make_weight_set(self, A, B, base, alpha=None, beta=None):
        """Build one (H,E,D) projection weight according to ``self.variant``."""
        var = self.variant
        if var == "vanilla":
            return base
        if var == "lora":
            scale = (self.lora_alpha / max(1, self.rank))
            return base + scale * self._AB(A, B)
        if var == "softmax_gate":
            G = torch.softmax(self._AB(A, B) / max(self.temp, 1e-6), dim=-1)
            return base * (alpha * G)
        if var == "tanh_gate":
            G = 1.0 + beta * torch.tanh(self._AB(A, B) / max(self.temp, 1e-6))
            return base * G
        if var == "strict_softmax":
            G = torch.softmax(self._AB(A, B) / max(self.temp, 1e-6), dim=-1)
            return alpha * G
        raise ValueError(f"Unknown variant {var}")

    def _weights(self, proj):
        """Return the (H,E,D) weight for projection ``proj`` in {'q','k','v'}.

        Unused pieces read as None, so one call covers every variant.
        """
        return self._make_weight_set(
            getattr(self, f"{proj}A"), getattr(self, f"{proj}B"), getattr(self, f"{proj}_base"),
            alpha=getattr(self, f"alpha_{proj}"), beta=getattr(self, f"beta_{proj}"))

    def forward(self, x, return_attn=False):
        """Apply causal attention to ``x`` (B,T,E); optionally also return weights (B,H,T,T)."""
        B, Lt, E = x.shape
        t_max = self.mask.size(-1)
        if Lt > t_max:
            raise ValueError(f"sequence length {Lt} exceeds T_max={t_max}")
        Wq, Wk, Wv = (self._weights(p) for p in "qkv")

        # Projections: (B,T,E) x (H,E,D) -> (B,H,T,D)
        q = torch.einsum('bte,hed->bthd', x, Wq).transpose(1, 2)
        k = torch.einsum('bte,hed->bthd', x, Wk).transpose(1, 2)
        v = torch.einsum('bte,hed->bthd', x, Wv).transpose(1, 2)

        # Scaled dot-product attention with the causal mask
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.D)
        att = att.masked_fill(self.mask[:, :, :Lt, :Lt] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v
        # Concatenate heads back to width E
        y = y.transpose(1, 2).contiguous().view(B, Lt, E)
        y = self.o(y)
        if return_attn:
            return y, att
        return y

class FF(nn.Module):
    """Bias-free position-wise MLP: Linear(E -> mult*E), ReLU, Linear(mult*E -> E)."""

    def __init__(self, E, mult=1):
        super().__init__()
        hidden = mult * E
        self.l1 = nn.Linear(E, hidden, bias=False)
        self.l2 = nn.Linear(hidden, E, bias=False)

    def forward(self, x):
        activated = F.relu(self.l1(x))
        return self.l2(activated)

class Block(nn.Module):
    """Pre-LayerNorm residual block: x + Attn(LN1(x)), then x + FF(LN2(x)).

    Extra keyword arguments are forwarded to ``MHAFork``.
    """

    def __init__(self, E, H, T_max, variant="vanilla", **kw):
        super().__init__()
        self.ln1 = nn.LayerNorm(E)
        self.att = MHAFork(E, H, T_max, variant=variant, **kw)
        self.ln2 = nn.LayerNorm(E)
        self.ff  = FF(E, mult=MLP_MULT)

    def forward(self, x, return_attn=False):
        weights = None
        attended = self.att(self.ln1(x), return_attn=return_attn)
        if return_attn:
            attended, weights = attended
        x = x + attended
        x = x + self.ff(self.ln2(x))
        return (x, weights) if return_attn else x

# ------------------------ Model wrapper ------------------------
class SorterModel(nn.Module):
    """One-block causal transformer LM over the sorter vocabulary.

    Token (+ optional learned positional) embeddings -> Block -> LayerNorm ->
    LM head whose weight is tied to the token embedding. Extra keyword
    arguments are forwarded to the attention variant.
    """

    def __init__(self, variant="vanilla", **kw):
        super().__init__()
        self.variant = variant
        self.token_embed = nn.Embedding(V, E).to(device)
        self.pos_embed = nn.Embedding(T_MAX, E).to(device) if USE_POS_EMBED else None
        self.block = Block(E, H, T_MAX, variant=variant, **kw).to(device)
        self.final_norm = nn.LayerNorm(E).to(device)
        head = nn.Linear(E, V, bias=False).to(device)
        head.weight = self.token_embed.weight  # weight tying: output projection shares the embedding
        self.lm_head = head

    def forward(self, x, return_attn=False):
        batch, length = x.shape
        hidden = self.token_embed(x)
        if self.pos_embed is not None:
            positions = torch.arange(length, device=x.device).expand(batch, length)
            hidden = hidden + self.pos_embed(positions)
        block_out = self.block(hidden, return_attn=return_attn)
        hidden, att = block_out if return_attn else (block_out, None)
        logits = self.lm_head(self.final_norm(hidden))
        return (logits, att) if return_attn else logits

# ------------------------ Utilities ------------------------
def count_params(model):
    """Return the total element count over ``model.parameters()``.

    Tied (shared) parameters are yielded once by ``parameters()`` and so are
    counted once; parameters are counted regardless of ``requires_grad``.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

def get_lr(step):
    """Learning rate for 0-based ``step``.

    Linear warmup from BASE_LR/WARMUP_STEPS up to BASE_LR over the first
    WARMUP_STEPS steps, then cosine decay from BASE_LR (at step WARMUP_STEPS)
    down to 0.1 * BASE_LR (at step STEPS).
    """
    if step >= WARMUP_STEPS:
        progress = (step - WARMUP_STEPS) / max(1, STEPS - WARMUP_STEPS)
        # 0.5 * (1 + cos) falls from 1 to 0; the 0.1 term is the LR floor
        decay = 0.9 * 0.5 * (1 + math.cos(math.pi * progress))
        return BASE_LR * (0.1 + decay)
    return BASE_LR * (step + 1) / WARMUP_STEPS

def train_one(model):
    """Train ``model`` in place with AdamW on fresh length-TRAIN_LEN batches.

    Uses a next-token cross-entropy over the unpadded prefix (positions
    0..2*TRAIN_LEN). Note this includes predicting the random input tokens
    x2..xL and the first SEP, not just the sorted outputs. Logs every 1000
    steps and returns the same model object.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR)
    seq_len = 2 * TRAIN_LEN + 1  # x1..xL, SEP, y1..yL
    for step in range(STEPS + 1):
        lr = get_lr(step)
        for group in optimizer.param_groups:
            group['lr'] = lr
        batch = make_batch(BATCH, TRAIN_LEN)
        logits = model(batch)
        # position t predicts token t+1
        predictions = logits[:, :seq_len - 1, :].reshape(-1, V)
        targets = batch[:, 1:seq_len].reshape(-1)
        loss = F.cross_entropy(predictions, targets)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        if not step % 1000:
            print(f"Step {step:05d} | LR: {lr:.5f} | Loss: {loss.item():.4f}")
    return model

@torch.no_grad()
def generate_sorted(model, chars):
    """Greedily decode the model's sorted version of ``chars``.

    Feeds ``[chars..., SEP]`` (no padding; the prompt grows by one token per
    step) and appends the argmax token L times, never allowing SEP to be
    emitted.

    Returns:
        The L generated characters as a list (``[]`` for empty input).
    """
    L = len(chars)
    prompt_len = L + 1
    x = torch.tensor([[*(stoi[c] for c in chars), SEP]], dtype=torch.long, device=device)
    while x.size(1) < 2 * L + 1:
        logits = model(x)[:, -1, :]
        logits[:, SEP] = -float('inf')  # forbid SEP after the first separator
        next_id = logits.argmax(dim=-1, keepdim=True)  # (1, 1); stays on device
        x = torch.cat([x, next_id], dim=1)
    # Slice after the prompt rather than x[-L:], which returns everything when L == 0.
    return [itos[i] for i in x[0, prompt_len:].tolist()]

def _nth_product(alphabet, length, index):
    """Return the ``index``-th tuple of ``itertools.product(alphabet, repeat=length)``.

    ``product`` enumerates with the last position varying fastest, so ``index``
    is read as a base-``len(alphabet)`` number whose least significant digit is
    the last position.
    """
    base = len(alphabet)
    digits = []
    for _ in range(length):
        index, d = divmod(index, base)
        digits.append(alphabet[d])
    return tuple(reversed(digits))


@torch.no_grad()
def evaluate_lengths(model, lens=TEST_LENS, max_exact=7000, samples=3000):
    """Exact-match greedy-decode accuracy of ``model`` for each length in ``lens``.

    For each length L all len(ALPHABET)**L inputs are scored when there are at
    most ``max_exact`` of them (or fewer than ``samples``); otherwise
    ``samples`` distinct inputs are drawn with ``random.sample``. A prediction
    is correct only if it equals ``sorted(input)``.

    Returns:
        dict mapping L -> (accuracy, exact, number_of_inputs_scored);
        accuracy is NaN if there are no inputs (empty ALPHABET).
    """
    results = {}
    for L in lens:
        total = len(ALPHABET) ** L
        exact = total <= max_exact or samples > total
        if exact:
            inputs = list(itertools.product(ALPHABET, repeat=L))
        else:
            # random.sample picks positions from len(population) and k only (CPython),
            # so sampling indices of range(total) yields the same draw as sampling the
            # materialised product list, without building all total tuples.
            inputs = [_nth_product(ALPHABET, L, i) for i in random.sample(range(total), samples)]
        correct = sum(generate_sorted(model, list(t)) == sorted(t) for t in inputs)
        acc = correct / len(inputs) if inputs else float('nan')
        results[L] = (acc, exact, len(inputs))
    return results

@torch.no_grad()
def trace(model, chars):
    """Print a step-by-step greedy decode of ``chars`` with attention weights.

    Decoding follows the same rule as ``generate_sorted`` (argmax with SEP
    excluded), so the trace shows what evaluation actually produces. For each
    step it prints every letter's softmax probability, the chosen token, and
    each head's attention from the last position over the current prefix.
    """
    L = len(chars)
    prompt_len = L + 1
    x = torch.tensor([[*(stoi[c] for c in chars), SEP]], dtype=torch.long, device=device)
    print(f"\n— AR trace ({model.variant}) —")
    print(f"start {[itos[i] for i in x[0].tolist()]}  (heads={H}, head_dim={D})")
    n_steps = 0
    while x.size(1) < 2 * L + 1:
        logits, att = model(x, return_attn=True)
        last = logits[0, -1, :]
        probs = torch.softmax(last, -1)
        pick = last.clone()
        pick[SEP] = -float('inf')  # same restriction as generate_sorted
        choice = int(torch.argmax(pick))
        n_steps += 1
        print(f"  step{n_steps}: " + ", ".join([f"p({ch})={float(probs[stoi[ch]]):.3f}" for ch in ALPHABET])
              + f" -> '{itos[choice]}'")
        n_heads, t_cur = att.shape[1], x.size(1)
        w = att[0, :, -1, :t_cur]  # (H, t_cur): last query's weights per head
        toks = [itos[i] for i in x[0].tolist()]
        for h in range(n_heads):
            weights = ", ".join(f"{toks[t]}:{w[h, t].item():.2f}" for t in range(t_cur))
            print(f"    head{h}: [{weights}]")
        x = torch.cat([x, torch.tensor([[choice]], device=device)], 1)
    full = x[0].tolist()
    # Slice after the prompt: full[-L:] would print the whole sequence when L == 0.
    print(f"  full {[itos[i] for i in full]} -> outputs {[itos[i] for i in full[prompt_len:]]}")

# ------------------------ Experiment harness ------------------------
def run_variant(variant_name, **kw):
    """Build, train, evaluate and trace one attention variant.

    ``kw`` is forwarded to ``SorterModel``. Prints the configuration,
    per-length accuracy and three example decode traces.

    Returns:
        ``(variant_name, results)`` with ``results`` from ``evaluate_lengths``.
    """
    print("\n" + "=" * 70)
    print(f"Variant: {variant_name}")
    model = SorterModel(variant=variant_name, **kw).to(device)
    arch = (f"Architecture: blocks=1, heads={H}, emb={E}, head_dim={D}, "
            f"mlp_mult={MLP_MULT}, context(T_MAX)={T_MAX}")
    print(arch)
    print(f"Trainable parameters: {count_params(model)}")
    print("Training...")
    train_one(model)
    print("Evaluating...")
    results = evaluate_lengths(model)
    for length in TEST_LENS:
        acc, was_exact, n_scored = results[length]
        tag = "exact" if was_exact else f"sample({n_scored})"
        print(f"  len={length}: acc={acc:.3f} [{tag}]")
    # A few fixed inputs to eyeball the decode and the attention pattern
    for sample in ("caa", "bac", "ccb"):
        trace(model, list(sample))
    return variant_name, results

def leaderboard(all_results, lens=None):
    """Print an accuracy table with one row per variant and one column per length.

    Args:
        all_results: iterable of ``(name, results)`` pairs as returned by
            ``run_variant``; ``results[L][0]`` is the accuracy at length L.
        lens: lengths to show as columns (default: TEST_LENS).
    """
    lens = TEST_LENS if lens is None else list(lens)
    name_w, col_w = 18, 6
    print("\n" + "#" * 70)
    print("Leaderboard (accuracy)")
    header = "variant".ljust(name_w) + " | " + " ".join([f"L{L}".rjust(col_w) for L in lens])
    print(header)
    print("-" * len(header))
    for name, res in all_results:
        # Join cells with the header's single-space separator so the columns line up.
        cells = " ".join(f"{res[L][0]:.3f}".rjust(col_w) for L in lens)
        print(name.ljust(name_w) + " | " + cells)

# ------------------------ Main ------------------------
if __name__ == "__main__":
    # Train/evaluate every configured variant in order, then print the summary table.
    all_results = []
    for cfg in RUNS:
        options = dict(cfg)          # copy so RUNS itself is left untouched
        variant = options.pop("name")
        all_results.append(run_variant(variant, **options))
    leaderboard(all_results)
