# ZeptoGPT Colab Notebook

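One of the smallest GPTs in the universe: a one-block, two-head transformer with a 6-dimensional embedding that learns to sort three-letter strings over {a, b, c}.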

In [None]:
# micro_abc_sorter_e6.py
# Minimal-but-reliable ABC sorter (duplicates allowed), still "GPT-style"
# 1 block, 2 heads, embedding E=6 (head_dim=3), MLP×1, optional pos-emb.
# Includes AR trace logs + a very simple text box (press Enter).

import math, random, itertools
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----- Repro & device -----
# Seed both RNGs: `random` drives training-data sampling, torch drives weight init.
random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# ----- Config -----
# Learned position embeddings add T*E = 42 params (318 total vs 276 without);
# keeping them on makes training more robust.
USE_POS_EMBED = True
STEPS = 15_000         # number of training steps; a small CPU is still fine
BATCH = 128            # sequences sampled per step
BASE_LR = 1e-3         # peak learning rate (reached at the end of warmup)
WARMUP_STEPS = 500     # linear warmup; helps stability at this tiny scale

# ----- Vocab -----
# Token 0 is the separator; the three letters follow in sorted order, so sorting
# token ids and sorting letters agree.
vocab = ['<sep>', 'a', 'b', 'c']
V = len(vocab)
stoi = dict(zip(vocab, range(V)))   # char -> id
itos = dict(enumerate(vocab))       # id -> char
SEP, A, B, C = range(V)

# ----- Shapes -----
T = 3 + 1 + 3   # context length: 3 input tokens, <sep>, 3 output tokens
E = 6           # embedding width
H = 2           # attention heads
assert E % H == 0
D = E // H      # width of each head (=3)

# ----- Modules (no biases except LayerNorm) -----
# Construction order matters: each constructor draws initial weights from the torch
# RNG seeded above, so reordering these lines changes the starting model.
token_embed = nn.Embedding(V, E).to(device)   # one E-dim vector per vocab entry (V*E params)
# Learned absolute position vectors for the T positions, or None when disabled.
pos_embed   = nn.Embedding(T, E).to(device) if USE_POS_EMBED else None

class MHA(nn.Module):
    """Bias-free causal multi-head self-attention.

    Args:
        E: model width, split evenly across heads.
        H: number of heads.
        T: maximum sequence length covered by the causal mask.
    """

    def __init__(self, E, H, T):
        super().__init__()
        self.H = H
        self.D = E // H
        # q, k, v, o are created in this order so seeded initialisation is unchanged.
        self.q = nn.Linear(E, E, bias=False)
        self.k = nn.Linear(E, E, bias=False)
        self.v = nn.Linear(E, E, bias=False)
        self.o = nn.Linear(E, E, bias=False)
        # Lower-triangular ones: position i may attend to positions 0..i. Shape (1,1,T,T).
        causal = torch.ones(T, T).tril()[None, None]
        self.register_buffer("mask", causal)

    def _split_heads(self, t, batch, length):
        # (batch, length, E) -> (batch, H, length, D)
        return t.view(batch, length, self.H, self.D).transpose(1, 2)

    def forward(self, x, return_attn=False):
        """Attend over ``x`` of shape (batch, length, E); optionally return weights."""
        batch, length, width = x.shape
        queries = self._split_heads(self.q(x), batch, length)
        keys = self._split_heads(self.k(x), batch, length)
        values = self._split_heads(self.v(x), batch, length)

        scores = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.D)   # (batch,H,length,length)
        future = self.mask[:, :, :length, :length] == 0
        weights = F.softmax(scores.masked_fill(future, float('-inf')), dim=-1)

        mixed = weights @ values                                           # (batch,H,length,D)
        merged = mixed.transpose(1, 2).contiguous().view(batch, length, width)
        out = self.o(merged)
        if not return_attn:
            return out
        return out, weights

class FF(nn.Module):
    """Bias-free two-layer ReLU MLP mapping E -> mult*E -> E."""

    def __init__(self, E, mult=1):
        super().__init__()
        hidden = mult * E
        self.l1 = nn.Linear(E, hidden, bias=False)   # expand
        self.l2 = nn.Linear(hidden, E, bias=False)   # project back to E

    def forward(self, x):
        h = self.l1(x)
        h = F.relu(h)
        return self.l2(h)

class Block(nn.Module):
    """Pre-LayerNorm transformer block: causal self-attention then MLP, each with a residual.

    Args:
        E: model width.
        H: number of attention heads.
        T: maximum sequence length (size of the causal mask).
    """

    def __init__(self, E, H, T):
        super().__init__()
        # Submodules are created in the order ln1, att, ln2, ff so that seeded
        # initialisation and parameter ordering stay the same.
        self.ln1 = nn.LayerNorm(E)
        self.att = MHA(E, H, T)
        self.ln2 = nn.LayerNorm(E)
        self.ff = FF(E, mult=1)

    def forward(self, x, return_attn=False):
        """Apply the block; with ``return_attn`` also return the attention weights."""
        normed = self.ln1(x)
        weights = None
        if return_attn:
            attn_out, weights = self.att(normed, return_attn=True)
        else:
            attn_out = self.att(normed)
        x = x + attn_out
        x = x + self.ff(self.ln2(x))
        if not return_attn:
            return x
        return x, weights

block      = Block(E, H, T).to(device)   # the single transformer block
final_norm = nn.LayerNorm(E).to(device)  # LayerNorm applied to the block output before lm_head

# Weight-tied LM head (no extra params)
# The Linear is kept only for its matmul; its freshly initialised weight is replaced
# by the (V, E) embedding matrix, so logits = h @ token_embed.weight.T.
lm_head = nn.Linear(E, V, bias=False).to(device)
lm_head.weight = token_embed.weight

# ----- Parameter counting -----
def count_params():
    """Count trainable parameters per model component.

    Returns:
        ``(tok, pos, att, mlp, lns, total)``:

        - ``tok``: the token embedding (V*E).
        - ``pos``: the position embedding (T*E), or 0 when disabled.
        - ``att``: the q/k/v/output projections (4*E*E).
        - ``mlp``: the feed-forward layers (2*E*E for mult=1).
        - ``lns``: the three LayerNorms (3*2E).
        - ``total``: the sum of the above.

        ``lm_head`` shares ``token_embed.weight``, so it is not counted again.
    """
    def n(modules):
        # Total element count over all parameters of the given modules.
        return sum(p.numel() for m in modules for p in m.parameters())

    tok = n([token_embed])
    # Explicit None check, matching forward()/train(), instead of relying on the
    # truthiness of an nn.Module.
    pos = n([pos_embed]) if pos_embed is not None else 0
    att = n([block.att.q, block.att.k, block.att.v, block.att.o])
    mlp = n([block.ff])
    lns = n([block.ln1, block.ln2, final_norm])
    return tok, pos, att, mlp, lns, tok + pos + att + mlp + lns

# ----- Data -----
abc_ids = [stoi[ch] for ch in 'abc']
def make_batch(B=128):
    """Sample ``B`` training sequences laid out as ``x1 x2 x3 <sep> sorted(x1..x3)``.

    Letters are drawn independently (duplicates allowed) with ``random.choices``.
    Returns a long tensor of shape (B, T) on ``device``.
    """
    rows = []
    for _ in range(B):
        triple = random.choices(abc_ids, k=3)
        rows.append([*triple, SEP, *sorted(triple)])
    batch = torch.tensor(rows, dtype=torch.long).view(B, T)
    return batch.to(device)

# ----- Forward -----
def forward(x, return_attn=False):
    """Compute next-token logits for a batch of token-id sequences.

    Args:
        x: integer token ids of shape (batch, length). ``length`` should not exceed
            T: positions index ``pos_embed`` and the block's causal mask.
        return_attn: if True, also return the block's attention weights.

    Returns:
        ``logits`` of shape (batch, length, V), or ``(logits, att)`` when
        ``return_attn`` is set. ``att`` has shape (batch, H, length, length).
    """
    n_batch, n_pos = x.shape
    hidden = token_embed(x)
    if pos_embed is not None:
        # Every row of the batch uses the same position indices 0..n_pos-1.
        positions = torch.arange(n_pos, device=x.device).unsqueeze(0).expand(n_batch, n_pos)
        hidden = hidden + pos_embed(positions)

    attn = None
    if return_attn:
        hidden, attn = block(hidden, return_attn=True)
    else:
        hidden = block(hidden)

    logits = lm_head(final_norm(hidden))
    return (logits, attn) if return_attn else logits

# ----- Optimizer + simple LR schedule -----
# AdamW over every trainable tensor. lm_head is not listed separately: its weight
# *is* token_embed.weight (tied), so it is already covered by token_embed.
# pos_embed is tested with `is not None`, like the rest of the file, rather than
# relying on the truthiness of an nn.Module.
opt = torch.optim.AdamW(
    list(token_embed.parameters())
    + (list(pos_embed.parameters()) if pos_embed is not None else [])
    + list(block.parameters())
    + list(final_norm.parameters()),
    lr=BASE_LR,  # initial value; train() overwrites lr every step via get_lr
)

def get_lr(step, base_lr=None, warmup_steps=None, total_steps=None):
    """Return the learning rate for (0-based) optimizer step ``step``.

    The schedule has two phases:

    - Warmup: linear, reaching ``base_lr`` at step ``warmup_steps - 1``.
    - Decay: cosine, falling to 10% of ``base_lr`` at ``total_steps``.

    Steps past ``total_steps`` stay at that 10% floor rather than climbing back up
    the cosine.

    Args:
        step: current step index.
        base_lr: peak learning rate; defaults to the module-level ``BASE_LR``.
        warmup_steps: warmup length; defaults to ``WARMUP_STEPS``.
        total_steps: step at which the decay ends; defaults to ``STEPS``.
    """
    if base_lr is None:
        base_lr = BASE_LR
    if warmup_steps is None:
        warmup_steps = WARMUP_STEPS
    if total_steps is None:
        total_steps = STEPS

    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    # Cosine over the remaining steps; clamp so the LR never rises again after the end.
    t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    t = min(t, 1.0)
    return base_lr * (0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * t)))

# ----- Training -----
def _report_architecture():
    """Print the architecture line and per-component parameter counts."""
    tok, pos, att, mlp, lns, total = count_params()
    lines = [
        f"Architecture: blocks=1, heads={H}, emb={E}, head_dim={D}, mlp_mult=1, context={T}",
        f"  token_embed: {tok}",
    ]
    if pos_embed is not None:
        lines.append(f"  pos_embed: {pos}")
    lines += [
        f"  attn(q,k,v,proj): {att}",
        f"  mlp: {mlp}",
        f"  layer_norms (2 block + final): {lns}",
        "  lm_head: weight-tied (0 extra)",
        f"Trainable parameters: {total}",
    ]
    for line in lines:
        print(line)


def train():
    """Train the model for STEPS + 1 steps on freshly sampled batches.

    Each step sets the LR from ``get_lr`` and minimises next-token cross-entropy over
    every position, clipping the gradient norm at 1.0. It logs step, LR and loss
    every 1000 steps.
    """
    _report_architecture()
    clip_params = list(opt.param_groups[0]['params'])

    for step in range(STEPS + 1):
        lr = get_lr(step)
        for group in opt.param_groups:
            group['lr'] = lr

        batch = make_batch(BATCH)
        logits = forward(batch)
        # Position i predicts token i+1. This includes the randomly sampled input
        # letters, so the loss has an irreducible floor.
        predictions = logits[:, :-1, :].reshape(-1, V)
        targets = batch[:, 1:].reshape(-1)
        loss = F.cross_entropy(predictions, targets)

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
        opt.step()

        if step % 1000 == 0:
            print(f"Step {step:05d} | LR: {opt.param_groups[0]['lr']:.5f} | Loss: {loss.item():.4f}")
    print("✅ Training complete!")

# ----- Generation (greedy) + TRACE -----
@torch.no_grad()
def generate_sorted(chars):
    """Greedily decode the sorted form of ``chars`` with the trained model.

    Args:
        chars: sequence of characters from {'a', 'b', 'c'}. Its length must fill the
            model's fixed layout ``input + <sep> + output == T`` (i.e. 3 for T=7).

    Returns:
        List of the generated characters (same length as ``chars``).

    Raises:
        KeyError: a character is not in the vocabulary.
        ValueError: ``chars`` has the wrong length for the context size T.
    """
    ids = [stoi[c] for c in chars]
    # Any other length would either decode garbage or overflow the T position slots.
    if 2 * len(ids) + 1 != T:
        raise ValueError(f"expected {(T - 1) // 2} characters, got {len(ids)}")
    x = torch.tensor([[*ids, SEP]], dtype=torch.long, device=device)
    while x.size(1) < T:
        logits = forward(x)[:, -1, :]
        logits[:, SEP] = -float('inf')         # forbid <sep> after separator
        next_id = int(torch.argmax(logits, dim=-1))
        x = torch.cat([x, torch.tensor([[next_id]], device=device)], 1)
    # Everything after the separator is model output.
    return [itos[i] for i in x[0, len(ids) + 1:].tolist()]

def _fmt(ids):
    """Render token ids as a bracketed, comma-separated list of vocab strings."""
    names = [itos[i] for i in ids]
    return f"[{', '.join(names)}]"

@torch.no_grad()
def trace(chars):
    """Greedily decode ``chars`` and print a step-by-step autoregressive trace.

    For each generated position the trace prints three things:

    - p(a), p(b) and p(c) from the full softmax.
    - The chosen token.
    - Every head's attention weights from the last position over the current prefix.

    Finally it prints the full sequence. Tokens are chosen exactly as in
    ``generate_sorted`` (<sep> forbidden), so the trace shows the decode that
    ``evaluate`` scores.

    Args:
        chars: three characters from {'a', 'b', 'c'} (callers in this file validate this).
    """
    ids = [stoi[c] for c in chars]
    x = torch.tensor([[*ids, SEP]], dtype=torch.long, device=device)
    print("\n— AR trace —")
    print(f"start {_fmt(x[0].tolist())}  (heads={H}, head_dim={D})")
    steps = []
    while x.size(1) < T:
        logits, att = forward(x, return_attn=True)
        last = logits[:, -1, :]
        probs = torch.softmax(last, -1)[0]
        # Previously argmax ran over all tokens, including <sep>, which could make the
        # trace diverge from generate_sorted; mask <sep> the same way here.
        masked = last.clone()
        masked[:, SEP] = -float('inf')
        choice = int(torch.argmax(masked, dim=-1))
        steps.append(itos[choice])
        print(f"  step{len(steps)}: p(a)={float(probs[stoi['a']]):.3f}, "
              f"p(b)={float(probs[stoi['b']]):.3f}, p(c)={float(probs[stoi['c']]):.3f} -> '{itos[choice]}'")
        # per-head attention from last position
        n_heads, cur_len = att.shape[1], x.size(1)
        w = att[0, :, -1, :cur_len]  # (H, cur_len)
        toks = [itos[i] for i in x[0].tolist()]
        for h in range(n_heads):
            weights = ", ".join(f"{toks[t]}:{w[h,t].item():.2f}" for t in range(cur_len))
            print(f"    head{h}: [{weights}]")
        x = torch.cat([x, torch.tensor([[choice]], device=device)], 1)
    full = x[0].tolist()
    print(f"  full {_fmt(full)} -> outputs {_fmt(full[len(ids) + 1:])}")

# ----- Evaluation -----
@torch.no_grad()
def evaluate():
    """Score greedy decoding on every input in {a,b,c}^3, then trace three examples.

    Prints one line per triple (prediction, target and a check mark), the overall
    accuracy, and AR traces for "caa", "bac" and "ccb".
    """
    n_correct = 0
    for triple in itertools.product(['a', 'b', 'c'], repeat=3):
        chars = list(triple)
        predicted = generate_sorted(chars)
        expected = sorted(chars)
        matched = predicted == expected
        n_correct += int(matched)
        mark = '✓' if matched else '✗'
        print(f"{chars} -> pred {predicted} | tgt {expected} {mark}")
    print(f"Model accuracy on 27 triples: {n_correct}/27")
    for demo in ("caa", "bac", "ccb"):
        trace(list(demo))

# ----- Minimal text box (Jupyter) / CLI fallback -----
def launch_textbox():
    """Read a 3-letter string over {a,b,c} from the user and print its AR trace.

    In a Jupyter/Colab kernel with ipywidgets available, this shows a text box; Enter
    runs ``trace``. The trace is rendered into an ``Output`` widget so it stays
    visible on front-ends (e.g. JupyterLab) that drop print output from widget
    callbacks.

    Anywhere else it falls back to a single ``input()`` prompt: plain ``python``,
    ipywidgets missing, or no kernel attached. EOF on stdin is ignored.
    """
    def _handle(text):
        # Validate one entry and trace it (shared by the widget and CLI paths).
        s = text.strip().lower()
        if len(s) == 3 and set(s).issubset({'a', 'b', 'c'}):
            trace(list(s))
        else:
            print("Please enter exactly 3 chars from {a,b,c}.")

    try:
        import ipywidgets as widgets
        from IPython import get_ipython
        from IPython.display import display

        shell = get_ipython()
        # Widgets are only interactive with a kernel attached. Without one (e.g. the
        # script run via __main__) display() would just print a repr and nobody
        # could type into it, so use the CLI prompt instead.
        if shell is None or getattr(shell, "kernel", None) is None:
            raise RuntimeError("no interactive kernel")

        tb = widgets.Text(
            value='caa',
            placeholder='Type 3 letters (e.g., cba) then press Enter',
            description=''
        )
        out = widgets.Output()

        def _on_submit(widget):
            with out:
                _handle(widget.value)

        tb.on_submit(_on_submit)
        display(tb, out)
        print("Type in the box and press Enter.")
    except Exception:
        # Best-effort UI: any widget/setup failure degrades to the CLI prompt.
        try:
            s = input("\nType three letters a-c (e.g., aca), then Enter: ")
        except EOFError:
            return
        _handle(s)

# ----- Run -----
if __name__ == "__main__":
    # Train, score on all 27 triples, then offer interactive traces.
    for stage in (train, evaluate, launch_textbox):
        stage()
