# Deep Attention 1


# In [None]:
# micro_abc_sorter_e6_deepqkv.py
# ABC sorter with "deeper" Q/K/V: each head's Q/K/V weights are produced from
# a factorized pair (A,B) as G = softmax(A@B / SOFTMAX_TEMP), which by default
# elementwise-gates a full-capacity base weight: W = base * G.
# Set STRICT_FACTORIZE=True to use W = G alone, with no base weight.

import math, random, itertools
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----- Repro & device -----
# Seed torch (module initialisation below) and Python's `random` (used by
# make_batch to sample training sequences). Must run before any module is built.
torch.manual_seed(42); random.seed(42)
# GPU if available; modules and batches below are all moved to `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ----- Config -----
USE_POS_EMBED   = True      # add a learned positional embedding to the token embedding
STEPS           = 15000     # training loop runs steps 0..STEPS inclusive
BATCH           = 128       # sequences sampled per training step
BASE_LR         = 1e-3      # peak LR, reached at the end of warmup
WARMUP_STEPS    = 500       # linear warmup length; cosine decay to 0.1*BASE_LR afterwards

# New knobs for the deeper-QKV module (passed to MHADeepQKV by Block)
RANK            = 3         # factorization rank r: per head A is (E x r), B is (r x D)
SOFTMAX_TEMP    = 1.0       # A@B is divided by this (clamped to >= 1e-6) before the row-wise softmax
STRICT_FACTORIZE= False     # False: gate a base weight (recommended). True: pure softmax(A@B).

# ----- Vocab -----
# Index 0 is the separator token; itos renders it as the empty string.
vocab = ['', 'a', 'b', 'c']
# Token ids: SEP=0, A=1, B=2, C=3. NOTE: several functions below reuse the
# names A / B as locals (factor matrices, batch size), shadowing these ids.
SEP, A, B, C = range(len(vocab))
stoi = {ch: i for i, ch in enumerate(vocab)}   # char -> id
itos = {i: ch for ch, i in stoi.items()}       # id -> char
V = len(vocab)                                 # vocabulary size (4)

# ----- Shapes -----
# T also sizes the positional embedding table and the causal mask buffer,
# so no sequence may be longer than T.
T = 7   # 3 input + SEP + 3 output
E = 6   # embedding dim
H = 2   # heads
assert E % H == 0
D = E // H  # per-head dim (=3)

# ----- Modules (no biases except LayerNorm) -----
# Construction order matters for reproducibility: each module draws its
# initial weights from the torch RNG seeded above.
token_embed = nn.Embedding(V, E).to(device)   # weight is also the LM head (tied below)
pos_embed   = nn.Embedding(T, E).to(device) if USE_POS_EMBED else None

class MHADeepQKV(nn.Module):
    """
    Causal multi-head self-attention whose per-head Q/K/V weights are rebuilt
    on every forward pass from learned factor matrices.

    For each projection p in {q, k, v} and each head h:

        G = softmax((pA[h] @ pB[h]) / temp, dim=-1)     pA: (E, r), pB: (r, D)
        W = G                  if strict
        W = p_base[h] * G      otherwise (elementwise gate on a free weight)

    The softmax normalises each of the E rows over the D output columns;
    `temp` is clamped below at 1e-6 before dividing.

    Args:
        E: model width (must be divisible by H).
        H: number of heads; per-head width is D = E // H.
        T: maximum sequence length covered by the causal-mask buffer.
        rank: inner dimension r of the pA @ pB factorisation.
        temp: softmax temperature.
        strict: drop the free base weights and use the softmax factor alone.

    Shapes:
      x: (B, T, E) -> output (B, T, E); attention weights (B, H, T, T)
      factors pA: (H, E, r), pB: (H, r, D); bases and built weights: (H, E, D)
    """

    _PROJECTIONS = ("q", "k", "v")

    def __init__(self, E, H, T, rank=3, temp=1.0, strict=False):
        super().__init__()
        self.H, self.D, self.E = H, E // H, E
        self.rank, self.temp, self.strict = rank, temp, strict

        # Factor pairs registered as qA, qB, kA, kB, vA, vB (this fixes the
        # parameter order seen by .parameters() and hence the optimizer).
        for p in self._PROJECTIONS:
            setattr(self, f"{p}A", nn.Parameter(torch.empty(H, E, rank)))
            setattr(self, f"{p}B", nn.Parameter(torch.empty(H, rank, self.D)))

        # Free full-capacity weights gated by the factors; in strict mode they
        # are registered as None so the attributes still exist.
        for p in self._PROJECTIONS:
            if strict:
                self.register_parameter(f"{p}_base", None)
            else:
                setattr(self, f"{p}_base", nn.Parameter(torch.empty(H, E, self.D)))

        # Output projection applied after the heads are concatenated.
        self.o = nn.Linear(E, E, bias=False)

        # Lower-triangular causal mask, shaped (1, 1, T, T) to broadcast over (B, H).
        self.register_buffer("mask", torch.tril(torch.ones(T, T))[None, None])

        self.reset_parameters()

    def _factors(self, p):
        """Return (A, B, base) for projection `p`; base is None in strict mode."""
        return getattr(self, f"{p}A"), getattr(self, f"{p}B"), getattr(self, f"{p}_base")

    def reset_parameters(self):
        # Initialisation order: qA, qB, kA, kB, vA, vB, then the bases, then o.
        for p in self._PROJECTIONS:
            fa, fb, _ = self._factors(p)
            nn.init.kaiming_uniform_(fa, a=math.sqrt(5))
            nn.init.kaiming_uniform_(fb, a=math.sqrt(5))
        if self.q_base is not None:
            for p in self._PROJECTIONS:
                nn.init.xavier_uniform_(self._factors(p)[2])
        nn.init.xavier_uniform_(self.o.weight)

    def _make_weight(self, A, B, base):
        # (H, E, r) @ (H, r, D) -> (H, E, D), temperature-scaled.
        scores = torch.matmul(A, B) / max(self.temp, 1e-6)
        gate = torch.softmax(scores, dim=-1)  # each of the E rows sums to 1 over D
        # Strict: the gate itself is the weight; otherwise it rescales the base.
        return gate if base is None else base * gate

    def forward(self, x, return_attn=False):
        n_batch, seq_len, width = x.shape

        # Build the three (H, E, D) weights, then project (B,T,E) -> (B,H,T,D).
        Wq, Wk, Wv = (self._make_weight(*self._factors(p)) for p in self._PROJECTIONS)
        q, k, v = (torch.einsum('bte,hed->bthd', x, W).transpose(1, 2) for W in (Wq, Wk, Wv))

        # Scaled dot-product scores with future positions blocked.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.D)
        future = self.mask[:, :, :seq_len, :seq_len] == 0
        att = F.softmax(scores.masked_fill(future, float('-inf')), dim=-1)

        # Merge heads back to (B, T, E) and apply the output projection.
        merged = (att @ v).transpose(1, 2).contiguous().view(n_batch, seq_len, width)
        y = self.o(merged)
        return (y, att) if return_attn else y

class FF(nn.Module):
    """Position-wise feed-forward layer without biases.

    Computes l2(relu(l1(x))) with widths E -> mult*E -> E.
    """

    def __init__(self, E, mult=1):
        super().__init__()
        hidden = mult * E
        self.l1 = nn.Linear(E, hidden, bias=False)
        self.l2 = nn.Linear(hidden, E, bias=False)

    def forward(self, x):
        activated = F.relu(self.l1(x))
        return self.l2(activated)

class Block(nn.Module):
    """Single pre-LayerNorm transformer block.

    x -> x + att(ln1(x)) -> x + ff(ln2(x)). The attention sub-layer is an
    MHADeepQKV configured from the module-level RANK / SOFTMAX_TEMP /
    STRICT_FACTORIZE settings.
    """

    def __init__(self, E, H, T):
        super().__init__()
        # Submodules are created in this order (ln1, att, ln2, ff); att and ff
        # draw their initial weights from the torch RNG during construction.
        self.ln1 = nn.LayerNorm(E)
        self.att = MHADeepQKV(E, H, T, rank=RANK, temp=SOFTMAX_TEMP, strict=STRICT_FACTORIZE)
        self.ln2 = nn.LayerNorm(E)
        self.ff = FF(E, mult=1)

    def forward(self, x, return_attn=False):
        """Apply the block; with return_attn=True also return the attention weights."""
        normed = self.ln1(x)
        att = None
        if return_attn:
            attn_out, att = self.att(normed, return_attn=True)
        else:
            attn_out = self.att(normed)
        residual = x + attn_out
        out = residual + self.ff(self.ln2(residual))
        return (out, att) if return_attn else out

# Remaining modules, built in a fixed order after the embeddings so the
# seeded initialisation is reproducible.
block      = Block(E, H, T).to(device)
final_norm = nn.LayerNorm(E).to(device)

# Weight-tied LM head (no extra params)
# nn.Linear(E, V) still initialises its own (V, E) weight (consuming torch RNG);
# the assignment below replaces it with the token embedding's (V, E) Parameter,
# so both modules share one tensor and the optimizer only lists it once.
lm_head = nn.Linear(E, V, bias=False).to(device)
lm_head.weight = token_embed.weight

# ----- Parameter counting -----
def count_params():
    """Return per-component parameter counts of the model.

    Counts are taken from the live modules instead of being re-derived from
    the config constants, so they stay correct if MHADeepQKV or Block is
    constructed differently (e.g. another rank or strict setting).

    Returns:
        (tok, pos, att, mlp, lns, total):
          tok   -- token embedding (V*E); doubles as the tied LM head
          pos   -- positional embedding (T*E), or 0 when disabled
          att   -- attention: factor pairs, base weights (if any), output proj
          mlp   -- the block's feed-forward layer
          lns   -- the block's two LayerNorms plus the final LayerNorm
          total -- sum of the above (lm_head adds nothing: it is weight-tied)
    """
    def n(params):
        return sum(p.numel() for p in params)

    tok = n(token_embed.parameters())
    pos = n(pos_embed.parameters()) if pos_embed is not None else 0
    # Includes qA/qB/kA/kB/vA/vB, the *_base weights unless strict (None
    # placeholders are skipped by .parameters()), and the output projection.
    # The causal mask is a buffer, not a parameter, so it is not counted.
    att = n(block.att.parameters())
    mlp = n(block.ff.parameters())
    lns = sum(n(m.parameters()) for m in (block.ln1, block.ln2, final_norm))
    total = tok + pos + att + mlp + lns
    return tok, pos, att, mlp, lns, total

# ----- Data -----
abc_ids = [stoi[ch] for ch in 'abc']

def make_batch(B=128):
    """Sample B training sequences, returned as a (B, T) LongTensor on `device`.

    Each row is three ids drawn uniformly with replacement from {a, b, c},
    then SEP, then the same three ids in ascending order.
    """
    batch = torch.empty((B, T), dtype=torch.long)
    for row in range(B):
        prompt = random.choices(abc_ids, k=3)   # duplicates allowed
        batch[row] = torch.tensor([*prompt, SEP, *sorted(prompt)])
    return batch.to(device)

# ----- Forward -----
def forward(x, return_attn=False):
    """Run the full model on a batch of token ids.

    Args:
        x: LongTensor of ids, shape (batch, seq_len); seq_len may not exceed T.
        return_attn: if True, also return the block's attention weights.

    Returns:
        logits of shape (batch, seq_len, V), or (logits, att) with att of
        shape (batch, H, seq_len, seq_len) when return_attn is set.
    """
    n_batch, seq_len = x.shape
    hidden = token_embed(x)
    if pos_embed is not None:
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(n_batch, seq_len)
        hidden = hidden + pos_embed(positions)
    att = None
    if return_attn:
        hidden, att = block(hidden, return_attn=True)
    else:
        hidden = block(hidden)
    logits = lm_head(final_norm(hidden))
    return (logits, att) if return_attn else logits

# ----- Optimizer + simple LR schedule -----
# One parameter group over every trainable module. lm_head is not listed: its
# weight is token_embed.weight (tied above). The lr given here is overwritten
# every step by get_lr() inside train().
# NOTE(review): AdamW's default weight decay applies to all of these,
# including LayerNorm parameters and embeddings — confirm that is intended.
opt = torch.optim.AdamW(
    list(token_embed.parameters()) +
    # presumably equivalent to `pos_embed is not None` (modules are truthy)
    (list(pos_embed.parameters()) if pos_embed else []) +
    list(block.parameters()) +
    list(final_norm.parameters()),
    lr=BASE_LR
)

def get_lr(step):
    """Learning rate for optimizer step `step`.

    Linear warmup from BASE_LR/WARMUP_STEPS up to BASE_LR over the first
    WARMUP_STEPS steps, then cosine decay from BASE_LR down to 0.1*BASE_LR
    at step STEPS.
    """
    if step >= WARMUP_STEPS:
        decay_span = max(1, STEPS - WARMUP_STEPS)
        progress = (step - WARMUP_STEPS) / decay_span
        cosine = 0.9 * 0.5 * (1 + math.cos(math.pi * progress))
        return BASE_LR * (0.1 + cosine)
    return BASE_LR * (step + 1) / WARMUP_STEPS

# ----- Training -----
def train():
    """Print a parameter breakdown, then train for steps 0..STEPS.

    Each step sets the learning rate from get_lr, samples a fresh batch,
    minimises next-token cross-entropy over every position, clips the global
    gradient norm to 1.0, and logs every 1000 steps.
    """
    def _print_summary():
        tok, pos, att, mlp, lns, total = count_params()
        print(f"Architecture: blocks=1, heads={H}, emb={E}, head_dim={D}, mlp_mult=1, context={T}")
        print(f"  token_embed: {tok}")
        if pos_embed is not None:
            print(f"  pos_embed: {pos}")
        print(f"  attn(deep QKV + proj): {att}  (rank={RANK}, strict={STRICT_FACTORIZE})")
        print(f"  mlp: {mlp}")
        print(f"  layer_norms (2 block + final): {lns}")
        print("  lm_head: weight-tied (0 extra)")
        print(f"Trainable parameters: {total}")

    _print_summary()

    clip_targets = list(opt.param_groups[0]['params'])
    for step in range(STEPS + 1):
        lr = get_lr(step)
        for group in opt.param_groups:
            group['lr'] = lr

        x = make_batch(BATCH)
        logits = forward(x)
        # Shift by one: position t predicts token t+1.
        pred = logits[:, :-1, :].reshape(-1, V)
        target = x[:, 1:].reshape(-1)
        loss = F.cross_entropy(pred, target)

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(clip_targets, 1.0)
        opt.step()

        if step % 1000 == 0:
            print(f"Step {step:05d} | LR: {opt.param_groups[0]['lr']:.5f} | Loss: {loss.item():.4f}")
    print("✅ Training complete!")

# ----- Generation (greedy) + TRACE -----
@torch.no_grad()
def generate_sorted(chars):
    """Greedily decode the model's sorted version of `chars`.

    The prompt is `chars` followed by SEP; the model then emits one token per
    input character, with SEP masked out at every decoding step. The decode
    length is derived from the input instead of being hard-coded, so the
    3-character case (the one the model is trained on) behaves as before and
    other lengths no longer return a wrongly sized result.

    Args:
        chars: sequence of characters from the vocabulary (training uses 3
            characters from {'a', 'b', 'c'}).

    Returns:
        list of len(chars) decoded characters.

    Raises:
        ValueError: if prompt + SEP + output (2*len(chars) + 1 tokens) exceeds T.
        KeyError: for characters not in `stoi`.
    """
    ids = [stoi[c] for c in chars]
    n_out = len(ids)
    total_len = 2 * n_out + 1          # prompt + SEP + one output per input
    if total_len > T:
        raise ValueError(f"generate_sorted needs {total_len} positions but context is T={T}")
    x = torch.tensor([[*ids, SEP]], dtype=torch.long, device=device)
    while x.size(1) < total_len:
        logits = forward(x)[:, -1, :]
        logits[:, SEP] = -float('inf')  # forbid SEP after separator
        next_id = int(torch.argmax(logits, dim=-1))
        x = torch.cat([x, torch.tensor([[next_id]], device=device)], 1)
    # Slice by start index (not [-n_out:]) so n_out == 0 yields [].
    return [itos[i] for i in x[0].tolist()[x.size(1) - n_out:]]

def _fmt(ids):
    """Render token ids as a bracketed, comma-separated list of characters."""
    return "[{}]".format(", ".join(itos[i] for i in ids))

@torch.no_grad()
def trace(chars):
    """Greedily decode `chars` step by step, printing what the model sees.

    For each generated token this prints the softmax probabilities of 'a',
    'b' and 'c', the chosen token, and, per head, the attention weights of
    the newest query position over all current tokens. The choice masks out
    SEP (matching generate_sorted, which produces the evaluated outputs) so
    the printed trace follows the same decode that evaluate() scores.

    Args:
        chars: sequence of characters from the vocabulary.

    Raises:
        ValueError: if prompt + SEP + output (2*len(chars) + 1 tokens) exceeds T.
        KeyError: for characters not in `stoi`.
    """
    ids = [stoi[c] for c in chars]
    n_out = len(ids)
    total_len = 2 * n_out + 1          # prompt + SEP + one output per input
    if total_len > T:
        raise ValueError(f"trace needs {total_len} positions but context is T={T}")
    x = torch.tensor([[*ids, SEP]], dtype=torch.long, device=device)
    print("\n— AR trace —")
    print(f"start {_fmt(x[0].tolist())}  (heads={H}, head_dim={D})")
    n_steps = 0
    while x.size(1) < total_len:
        logits, att = forward(x, return_attn=True)
        last = logits[:, -1, :]
        probs = torch.softmax(last, -1)[0]          # displayed unmasked
        masked = last.clone()
        masked[:, SEP] = -float('inf')              # SEP is never a valid output
        choice = int(torch.argmax(masked, dim=-1))
        n_steps += 1
        print(f"  step{n_steps}: p(a)={float(probs[stoi['a']]):.3f}, "
              f"p(b)={float(probs[stoi['b']]):.3f}, p(c)={float(probs[stoi['c']]):.3f} -> '{itos[choice]}'")
        # Attention of the last query position over all current tokens, per head.
        n_heads, t_cur = att.shape[1], x.size(1)
        w = att[0, :, -1, :t_cur]  # (H, t_cur)
        toks = [itos[i] for i in x[0].tolist()]
        for head in range(n_heads):
            weights = ", ".join(f"{toks[t]}:{w[head, t].item():.2f}" for t in range(t_cur))
            print(f"    head{head}: [{weights}]")
        x = torch.cat([x, torch.tensor([[choice]], device=device)], 1)
    full = x[0].tolist()
    print(f"  full {_fmt(full)} -> outputs {_fmt(full[len(full) - n_out:])}")

# ----- Evaluation -----
@torch.no_grad()
def evaluate():
    """Score the model on every length-3 string over {a, b, c}.

    Prints one line per case plus the overall accuracy, then prints AR
    traces for a few fixed sample inputs.
    """
    correct = 0
    for triple in itertools.product(['a','b','c'], repeat=3):
        chars = list(triple)
        prediction = generate_sorted(chars)
        target = sorted(chars)
        hit = prediction == target
        correct += hit
        mark = '✓' if hit else '✗'
        print(f"{chars} -> pred {prediction} | tgt {target} {mark}")
    print(f"Model accuracy on 27 triples: {correct}/27")
    for sample in ("caa", "bac", "ccb"):
        trace(list(sample))

# ----- Minimal text box (Jupyter) / CLI fallback -----
def launch_textbox():
    """Offer interactive tracing of a 3-letter input.

    Shows an ipywidgets text box whose Enter key runs trace(). If anything
    in that widget path raises (e.g. ipywidgets/IPython not installed), it
    falls back to one stdin prompt; EOF on stdin is ignored.
    """
    def _valid(s):
        return len(s) == 3 and set(s).issubset({'a','b','c'})

    def _handle(s):
        if _valid(s):
            trace(list(s))
        else:
            print("Please enter exactly 3 chars from {a,b,c}.")

    try:
        import ipywidgets as widgets
        from IPython.display import display
        box = widgets.Text(
            value='caa',
            placeholder='Type 3 letters (e.g., cba) then press Enter',
            description=''
        )
        # on_submit passes the Text widget itself to the callback.
        box.on_submit(lambda widget: _handle(widget.value.strip().lower()))
        display(box)
        print("Type in the box and press Enter.")
    except Exception:
        try:
            _handle(input("\nType three letters a-c (e.g., aca), then Enter: ").strip().lower())
        except EOFError:
            pass

# ----- Run -----
# Script entry point: train, print accuracy and sample traces, then offer an
# interactive prompt (Jupyter text box, or a single stdin read as fallback).
if __name__ == "__main__":
    train()
    evaluate()
    launch_textbox()
