In [None]:
# ABOUTME: Route B training notebook for TRM with KDA layers.
# ABOUTME: Tracks recursion settings and deep supervision steps.


# route b: trm + kda

use this notebook for small-scale runs and recursion tuning.

all model implementations are in pytorch.


## run checklist
- select dataset snapshot tag and tokenizer version
- set recursion steps (`h_cycles`, `l_cycles`) and deep supervision count (`deep_steps`)
- keep kda/full attention ratio at 3:1 in the core
- run a short smoke train and confirm loss decreases

## config sketch
- start with a small recursion depth for stability
- log arc-agi evals at each checkpoint
- compare with route a at matched token budgets


In [None]:
# Placeholder cell (no-op). The TRM core, the recursion settings
# (h_cycles / l_cycles / deep_steps) and a short training loop are
# implemented in the cells below.
pass


## model + training config (route b)
this config targets roughly 100m params but comes to ~138m as written (about 77m of that is the untied token embedding and lm head); it uses full recursion + deep supervision.


In [None]:
model_config = {
    "vocab_size": 50257,     # gpt-2 bpe vocabulary size
    "block_size": 1024,      # max sequence length (size of the learned position table)
    "n_layer": 8,
    "n_head": 12,
    "n_embd": 768,
    "kda_chunk_size": 64,    # tokens per chunk in chunk_kda
    "kda_ratio": 3,          # kda layers per full-attention layer
    "h_cycles": 2,           # all but the last h cycle run without gradients
    "l_cycles": 2,           # recursion steps per h cycle
    "deep_steps": 2,         # deep-supervision steps; their losses are averaged
}

train_config = {
    "batch_size": 8,
    "max_steps": 200,
    "max_tokens": 200_000,   # cap on tokens loaded from the processed shards
    "lr": 3e-4,
    "log_interval": 10,      # print every N steps (and on step 1)
}

# Fail fast on settings the model code cannot run with.
assert model_config["n_embd"] % model_config["n_head"] == 0, "n_embd must be divisible by n_head"
assert model_config["deep_steps"] >= 1, "deep_steps must be >= 1 (logits come from the last deep step)"
model_config, train_config


## kda attention + trm model (route b)


In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def chunk_kda(q, k, v, g, beta, chunk_size):
    """Chunked Kimi-Delta-Attention (gated delta rule, per-channel decay), naive PyTorch.

    Token by token this evaluates (with q pre-scaled by kdim ** -0.5):
        S'  = diag(exp(g_t)) S
        S   = S' + beta_t * k_t (v_t - k_t^T S')^T
        o_t = q_t^T S
    but processes `chunk_size` tokens at a time: an intra-chunk triangular
    solve builds w / u for every chunk at once, then a sequential loop over
    chunks carries the (kdim, vdim) state. All math runs in float32.

    Args:
        q, k: (b, t, h, kdim) queries and keys.
        v: (b, t, h, vdim) values.
        g: (b, t, h, kdim) per-token log-decay (KDAAttention passes -softplus(.), i.e. <= 0).
        beta: (b, t, h) per-token write strength.
        chunk_size: tokens per chunk; t does not need to be a multiple of it.

    Returns:
        (b, t, h, vdim) tensor in v's dtype.
    """
    dtype = v.dtype
    b, t, h, kdim = q.shape
    vdim = v.shape[-1]
    c = chunk_size
    # Right-pad to a whole number of chunks instead of dropping the tail tokens.
    # Padded steps have k = v = beta = g = 0, so they never write to the state.
    # They sit after every real token, so (causally) they cannot affect real
    # outputs, and their own outputs are sliced off before returning.
    pad = (-t) % c
    if pad:
        q = F.pad(q, (0, 0, 0, 0, 0, pad))
        k = F.pad(k, (0, 0, 0, 0, 0, pad))
        v = F.pad(v, (0, 0, 0, 0, 0, pad))
        g = F.pad(g, (0, 0, 0, 0, 0, pad))
        beta = F.pad(beta, (0, 0, 0, pad))
    t_pad = t + pad
    n = t_pad // c
    # (b, t_pad, h, d) -> (b, h, n, c, d); beta -> (b, h, n, c).
    q = q.reshape(b, n, c, h, kdim).permute(0, 3, 1, 2, 4).to(torch.float32)
    k = k.reshape(b, n, c, h, kdim).permute(0, 3, 1, 2, 4).to(torch.float32)
    v = v.reshape(b, n, c, h, vdim).permute(0, 3, 1, 2, 4).to(torch.float32)
    g = g.reshape(b, n, c, h, kdim).permute(0, 3, 1, 2, 4).to(torch.float32)
    beta = beta.reshape(b, n, c, h).permute(0, 3, 1, 2).to(torch.float32)
    q = q * (kdim ** -0.5)
    g = g.cumsum(-2)  # cumulative log-decay within each chunk
    pos = torch.arange(c, device=q.device)
    eye = torch.eye(c, dtype=torch.float32, device=q.device)

    # Intra-chunk key/key system for all chunks at once. Column i holds
    # k_r . (exp(g_r - g_i) * k_i) for each row r. Rows are then scaled by
    # beta_r, entries on/above the diagonal are zeroed, and the result is negated.
    diag_and_above = torch.triu(torch.ones(c, c, dtype=torch.bool, device=q.device), diagonal=0)
    attn = torch.zeros(b, h, n, c, c, dtype=torch.float32, device=q.device)
    for i in range(c):
        k_i = k[..., i, :]
        # Rows r <= i are masked out below. Zero their exponent first: for
        # g <= 0 it is >= 0 there, exp() can overflow to inf, and inf * 0
        # turns into NaN gradients.
        log_decay = (g - g[..., i:i+1, :]).masked_fill((pos <= i)[:, None], 0.0)
        attn[..., i] = torch.einsum("... c d, ... d -> ... c", k * log_decay.exp(), k_i)
    attn = attn * beta[..., None]
    attn = -attn.masked_fill(diag_and_above, 0)
    # Row-by-row forward substitution. Every operand read from `attn` is cloned:
    # autograd saves the factors of the product, and the in-place row write
    # would otherwise invalidate them ("modified by an inplace operation").
    for i in range(1, c):
        attn[..., i, :i] = attn[..., i, :i].clone() + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
    attn = (attn + eye) * beta[..., None, :]
    w = torch.einsum("... i j, ... j d -> ... i d", attn, g.exp() * k)
    u = torch.einsum("... i j, ... j d -> ... i d", attn, v)

    # Sequential pass over chunks carrying the state s: (b, h, kdim, vdim).
    s = k.new_zeros(b, h, kdim, vdim)
    o = torch.zeros_like(v)
    above_diag = torch.triu(torch.ones(c, c, dtype=torch.bool, device=q.device), diagonal=1)
    for i in range(n):
        q_i, k_i, u_i, g_i, w_i = q[:, :, i], k[:, :, i], u[:, :, i], g[:, :, i], w[:, :, i]
        # Causal, decayed query/key scores inside the chunk: scores[r, j] kept for j <= r.
        scores = torch.zeros(b, h, c, c, dtype=torch.float32, device=q.device)
        for j in range(c):
            k_j = k_i[:, :, j]
            # Rows r < j are masked out below; zero their exponent to avoid overflow.
            log_decay = (g_i - g_i[:, :, j:j+1, :]).masked_fill((pos < j)[:, None], 0.0)
            scores[..., j] = torch.einsum("... c d, ... d -> ... c", q_i * log_decay.exp(), k_j)
        scores = scores.masked_fill(above_diag, 0)
        v_i = u_i - torch.einsum("... i j, ... j d -> ... i d", w_i, s)
        o[:, :, i] = torch.einsum("... c d, ... d v -> ... c v", q_i * g_i.exp(), s) + torch.einsum("... i j, ... j d -> ... i d", scores, v_i)
        # Decay the state across the whole chunk, then add this chunk's writes.
        s = s * g_i[:, :, -1].exp().unsqueeze(-1)
        s = s + torch.einsum("... c d, ... c v -> ... d v", (g_i[:, :, -1:] - g_i).exp() * k_i, v_i)
    o = o.permute(0, 2, 3, 1, 4).reshape(b, t_pad, h, vdim)
    # Drop padded positions; contiguous so callers can .view() the result.
    return o[:, :t].contiguous().to(dtype)

class KDAAttention(nn.Module):
    """Kimi-Delta-Attention token mixer built on chunk_kda.

    Projects x (b, t, n_embd) to per-head q, k, v, a per-channel log-decay
    g = -softplus(.) and a per-head write strength beta = sigmoid(.), runs the
    chunked gated delta rule, and projects the result back to n_embd.
    """

    def __init__(self, n_embd, n_head, chunk_size):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.chunk_size = chunk_size
        self.qkv = nn.Linear(n_embd, 3 * n_embd)
        self.g_proj = nn.Linear(n_embd, n_embd)
        self.beta_proj = nn.Linear(n_embd, n_head)
        self.out_proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        b, t, c = x.shape
        heads = (b, t, self.n_head, self.head_dim)
        q, k, v = self.qkv(x).split(c, dim=-1)
        # Unit-normalise q and k per head. The delta-rule state update applies
        # (I - beta k k^T), whose eigenvalue along k is 1 - beta * ||k||^2. With
        # beta in (0, 1) and ||k|| = 1 it stays in (0, 1), so the recurrent state
        # cannot blow up. Unnormalised keys give no such bound.
        q = F.normalize(q.view(heads), dim=-1)
        k = F.normalize(k.view(heads), dim=-1)
        v = v.view(heads)
        g = -F.softplus(self.g_proj(x)).view(heads)
        beta = torch.sigmoid(self.beta_proj(x)).view(b, t, self.n_head)
        o = chunk_kda(q, k, v, g, beta, self.chunk_size)
        return self.out_proj(o.reshape(b, t, c))

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    The hidden layer is `expansion` times wider than n_embd; the input and
    output widths are both n_embd.
    """

    def __init__(self, n_embd, expansion=4):
        super().__init__()
        hidden_dim = n_embd * expansion
        self.fc1 = nn.Linear(n_embd, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_embd)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = F.gelu(hidden)
        return self.fc2(hidden)

class Block(nn.Module):
    """Pre-norm residual block: x + mixer(LN(x)), then x + MLP(LN(x)).

    The token mixer is KDAAttention when `use_kda` is true and
    CausalSelfAttention otherwise (`chunk_size` is only used by KDA).
    """

    def __init__(self, n_embd, n_head, use_kda, chunk_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        if use_kda:
            self.attn = KDAAttention(n_embd, n_head, chunk_size)
        else:
            self.attn = CausalSelfAttention(n_embd, n_head)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        mixed = x + self.attn(self.ln1(x))
        return mixed + self.mlp(self.ln2(mixed))

class CausalSelfAttention(nn.Module):
    """Multi-head softmax attention with a causal (no look-ahead) mask.

    Input and output are (b, t, n_embd); scores are scaled by head_dim ** -0.5.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd)
        self.out_proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        b, t, c = x.shape
        # Split the fused projection and move heads in front: (b, n_head, t, head_dim).
        q, k, v = [
            part.view(b, t, self.n_head, self.head_dim).transpose(1, 2)
            for part in self.qkv(x).split(c, dim=-1)
        ]
        scores = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        # True above the diagonal, i.e. for keys that lie in the query's future.
        future = torch.ones(t, t, device=x.device).triu(diagonal=1).bool()
        weights = scores.masked_fill(future, float("-inf")).softmax(dim=-1)
        merged = torch.matmul(weights, v).transpose(1, 2).contiguous().view(b, t, c)
        return self.out_proj(merged)

class TRMLanguageModel(nn.Module):
    """Recursive LM that reuses one block stack for low- and high-level latent updates.

    Config keys read: vocab_size, block_size, n_layer, n_head, n_embd, kda_ratio,
    kda_chunk_size, h_cycles, l_cycles, deep_steps.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["n_embd"])
        self.pos_emb = nn.Embedding(cfg["block_size"], cfg["n_embd"])
        self.blocks = nn.ModuleList()
        for i in range(cfg["n_layer"]):
            # The last layer of every group of (kda_ratio + 1) uses full softmax attention; the rest use KDA.
            use_kda = (i % (cfg["kda_ratio"] + 1)) != cfg["kda_ratio"]
            self.blocks.append(Block(cfg["n_embd"], cfg["n_head"], use_kda, cfg["kda_chunk_size"]))
        self.ln_f = nn.LayerNorm(cfg["n_embd"])
        self.lm_head = nn.Linear(cfg["n_embd"], cfg["vocab_size"], bias=False)

    def _embed(self, idx):
        """Token + learned absolute position embeddings for idx of shape (b, t)."""
        pos = torch.arange(0, idx.shape[1], device=idx.device)
        return self.tok_emb(idx) + self.pos_emb(pos)

    def _forward_once(self, idx, z_h, z_l, x=None):
        """One recursion step: update z_l from (z_l + z_h + x), then z_h from (z_h + z_l).

        Both updates run the full shared block stack. `x` is the input embedding;
        it is recomputed from `idx` when omitted. Returns `(z_h, z_l)` in that order.
        """
        if x is None:
            x = self._embed(idx)
        z_l = z_l + z_h + x
        for block in self.blocks:
            z_l = block(z_l)
        z_h = z_h + z_l
        for block in self.blocks:
            z_h = block(z_h)
        return z_h, z_l

    def forward(self, idx, targets=None):
        """Run deep-supervised recursion over idx (b, t).

        Each deep step runs (h_cycles - 1) * l_cycles recursion steps without
        gradients, then l_cycles more with gradients. It reads logits from z_h
        and detaches both states before the next deep step.

        Returns:
            logits: (b, t, vocab_size) from the last deep step.
            loss: mean cross-entropy over deep steps, or None when targets is None.

        Raises:
            ValueError: if t exceeds block_size or deep_steps < 1.
        """
        b, t = idx.shape
        if t > self.cfg["block_size"]:
            raise ValueError(f"sequence length {t} exceeds block_size {self.cfg['block_size']}")
        if self.cfg["deep_steps"] < 1:
            raise ValueError("deep_steps must be >= 1")
        # The input embedding is identical for every recursion step; compute it once.
        x = self._embed(idx)
        z_h = torch.zeros(b, t, self.cfg["n_embd"], device=idx.device)
        z_l = torch.zeros(b, t, self.cfg["n_embd"], device=idx.device)
        losses = []
        for _ in range(self.cfg["deep_steps"]):
            with torch.no_grad():
                for _ in range(self.cfg["h_cycles"] - 1):
                    for _ in range(self.cfg["l_cycles"]):
                        z_h, z_l = self._forward_once(idx, z_h, z_l, x)
            for _ in range(self.cfg["l_cycles"]):
                z_h, z_l = self._forward_once(idx, z_h, z_l, x)
            logits = self.lm_head(self.ln_f(z_h))
            if targets is not None:
                losses.append(F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)))
            # Deep supervision: the next step starts from these states without backprop into this one.
            z_h = z_h.detach()
            z_l = z_l.detach()
        loss = sum(losses) / len(losses) if losses else None
        return logits, loss

def count_params(model):
    """Return the total number of scalar entries across all of `model`'s parameters."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

# Seed before building the model so weight init is reproducible. When cells run
# top to bottom, this also fixes the batch sampling in the training cell, which
# draws from the same global torch RNG.
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TRMLanguageModel(model_config).to(device)
print("param_count", count_params(model))


## training loop (route b)
prints loss, step time, and tokens/sec on the first step and every `log_interval` steps for quick feedback.


In [None]:
import time
import json
from pathlib import Path
import tiktoken

# gpt-2 bpe tokenizer (expected to match model_config["vocab_size"] = 50257)
enc = tiktoken.get_encoding("gpt2")
processed_dir = Path("..") / "data" / "processed"
shards = sorted(processed_dir.glob("stage1_shard_*.jsonl"))
if len(shards) == 0:
    raise FileNotFoundError("no processed shards found")

def load_tokens(max_tokens):
    """Tokenize the "text" field of every JSONL record across `shards`, in order.

    Records whose "text" is missing or None are skipped; non-string values are
    str()-converted first. Loading stops once `max_tokens` tokens have been
    collected (a falsy `max_tokens` means no limit).

    Returns:
        1-D torch.long tensor of at most `max_tokens` token ids from `enc`.
    """
    tokens = []
    for shard in shards:
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(shard, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines instead of failing in json.loads
                text = json.loads(line).get("text")
                if text is None:
                    continue
                if not isinstance(text, str):
                    text = str(text)
                tokens.extend(enc.encode(text))
                if max_tokens and len(tokens) >= max_tokens:
                    return torch.tensor(tokens[:max_tokens], dtype=torch.long)
    return torch.tensor(tokens, dtype=torch.long)

token_data = load_tokens(train_config["max_tokens"])

def get_batch(data=None, block_size=None, batch_size=None, batch_device=None):
    """Sample random contiguous (input, target) windows for next-token prediction.

    All arguments are optional and default to token_data,
    model_config["block_size"], train_config["batch_size"] and the notebook's
    `device`, so `get_batch()` keeps working as before.

    Returns:
        x, y: (batch_size, block_size) long tensors on `batch_device`, where y is
        x shifted left by one token.

    Raises:
        ValueError: if `data` holds fewer than block_size + 1 tokens.
    """
    data = token_data if data is None else data
    block_size = model_config["block_size"] if block_size is None else block_size
    batch_size = train_config["batch_size"] if batch_size is None else batch_size
    batch_device = device if batch_device is None else batch_device
    # A window needs block_size + 1 tokens (inputs plus the shifted target), so the
    # valid starts are 0 .. numel - block_size - 1; randint's upper bound is exclusive.
    n_starts = data.numel() - block_size
    if n_starts < 1:
        raise ValueError(f"need at least {block_size + 1} tokens, got {data.numel()}")
    idx = torch.randint(0, n_starts, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in idx])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in idx])
    return x.to(batch_device), y.to(batch_device)

def _sync_device():
    """Wait for queued CUDA work to finish so wall-clock timings include it (no-op on cpu)."""
    if str(device).startswith("cuda"):
        torch.cuda.synchronize()


opt = torch.optim.AdamW(model.parameters(), lr=train_config["lr"])
model.train()
for step in range(1, train_config["max_steps"] + 1):
    _sync_device()
    t0 = time.perf_counter()
    x, y = get_batch()
    logits, loss = model(x, y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    # CUDA kernels launch asynchronously; without waiting, dt would mostly
    # measure launch time rather than the step's compute.
    _sync_device()
    dt = time.perf_counter() - t0
    tokens_per_sec = x.numel() / dt
    if step % train_config["log_interval"] == 0 or step == 1:
        print(f"step {step} loss {loss.item():.4f} step_time {dt:.3f}s tok/s {tokens_per_sec:.1f}")
