# In [10]:
# trainer_rmt_partial_softk.py
# Recurrent Memory Transformer with differentiable soft top-k bank routing
# and partial memory updates (only a subset of memory slots, and optionally dims, update per chunk).

import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

# --------------- Utilities ---------------

def set_seed(seed: int = 1337):
    """Seed Python's and PyTorch's (CPU and every CUDA device) RNGs with ``seed``.

    cuDNN is deliberately left non-deterministic with autotuning enabled
    (``benchmark=True``), so GPU runs are seeded but not bit-reproducible.
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

def fourier_pos_enc(pos, d, base=10_000.0):
    """Sinusoidal (Fourier) positional encoding.

    Args:
        pos: positions shaped [T, 1], or [B, T, 1], in which case only the first
            batch row is used (positions are assumed identical across the batch).
        d: output feature width.
        base: wavelength base; frequency index i uses ``base ** (i / (d // 2))``
            as the divisor of the position.

    Returns:
        float32 tensor [T, d] laid out as ``[sin(ang), cos(ang)]`` (each d // 2
        wide). For odd ``d`` a trailing zero column pads the result to exactly
        ``d`` columns so it can be added to a [..., T, d] embedding. For d < 2
        (no complete sin/cos pair) the result is all zeros of width ``d``.
    """
    if pos.dim() == 3:
        pos = pos[0]
    half = d // 2
    if half == 0:
        return torch.zeros(*pos.shape[:-1], d, device=pos.device)
    freqs = torch.arange(half, device=pos.device).float() / half
    ang = pos.float() / (base ** freqs)                          # [T, half]
    enc = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)    # [T, 2*half]
    if d % 2:
        # Odd width: 2*half == d - 1, so pad one zero column to reach d.
        enc = F.pad(enc, (0, 1))
    return enc

# --------------- Model ---------------

class MLP(nn.Module):
    """Position-wise feed-forward net: Linear(d -> mult*d) -> GELU -> Dropout -> Linear(mult*d -> d)."""

    def __init__(self, d, mult=4, p=0.0):
        super().__init__()
        hidden = mult * d
        self.fc1 = nn.Linear(d, hidden)
        self.fc2 = nn.Linear(hidden, d)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Apply the expand -> GELU -> dropout -> project stack to the last dim of ``x``."""
        h = F.gelu(self.fc1(x))
        h = self.drop(h)
        return self.fc2(h)

class RMTBlockPartial(nn.Module):
    """One layer: token pass (reads full prev memory), then memory update (writes only selected banks).

    Args:
        d: model width.
        n_heads: number of heads for both the token and the memory attention.
        p: dropout probability used by the attentions and MLPs.
        d_proj: if set and smaller than the memory width, bank updates only
            change the first ``d_proj`` feature dims; the rest stay as-is.
    """
    def __init__(self, d, n_heads, p=0.0, d_proj=None):
        super().__init__()
        # Token path: tokens (queries) attend over [memory; tokens].
        self.ln_tok_q  = nn.LayerNorm(d)
        self.ln_tok_kv = nn.LayerNorm(d)
        self.mha_tok   = nn.MultiheadAttention(d, n_heads, dropout=p, batch_first=True)
        self.ff_tok    = MLP(d, p=p)
        self.ln_tok_ff = nn.LayerNorm(d)

        # Memory path: rows of a selected bank (queries) attend over [memory; tokens].
        self.ln_mem_q  = nn.LayerNorm(d)
        self.ln_mem_kv = nn.LayerNorm(d)
        self.mha_mem   = nn.MultiheadAttention(d, n_heads, dropout=p, batch_first=True)
        self.ff_mem    = MLP(d, p=p)
        self.ln_mem_ff = nn.LayerNorm(d)

        self.d_proj = d_proj  # if set, update only first d_proj dims of chosen bank(s)

    def token_pass(self, tok, mem, mask_tok):
        """Update tokens by attending over [mem; tok], then apply a residual MLP.

        tok: [B,T,D], mem: [B,M,D]. ``mask_tok`` is passed as ``attn_mask`` to the
        token attention (queries: the T tokens; keys: M memory slots followed by
        the T tokens). Memory is only read. Returns the updated tokens [B,T,D].
        """
        q  = self.ln_tok_q(tok)
        kv = self.ln_tok_kv(torch.cat([mem, tok], dim=1))  # [B,M+T,D]
        attn, _ = self.mha_tok(q, kv, kv, need_weights=False, attn_mask=mask_tok)
        tok = tok + attn
        tok = tok + self.ff_tok(self.ln_tok_ff(tok))
        return tok

    def bank_update(self, mem, tok, bank_slices, bank_weights):
        """
        Return a new memory where only the chosen banks receive a weighted update.

        mem: [B,M,D], tok: [B,T,D]
        bank_slices: per sample b, a list of (start, end) row ranges of mem to update
                     (length = topk); handled with a small loop over (B, topk).
        bank_weights: [B, topk] non-negative weights (softmax probs for chosen banks).
                      A bank whose weight is exactly 0 is skipped.

        Within one sample, banks are applied in order: bank i's attention reads the
        memory already updated by banks < i. If ``self.d_proj`` is set and < D,
        only the first d_proj feature dims of a bank change.

        ``mem`` is never written to. Every update is built out-of-place, so the
        caller's memory is not aliased and tensors autograd saved for backward
        (e.g. LayerNorm inputs that are views of ``mem``) stay valid.
        """
        B, M, D = mem.shape
        dproj = self.d_proj
        partial_dims = dproj is not None and dproj < D

        rows = []
        # small, robust loops (B and topk are modest)
        for b in range(B):
            mem_b = mem[b:b+1]   # [1,M,D]; only ever rebound, never written in place
            tok_b = tok[b:b+1]   # [1,T,D]
            for i, (start, end) in enumerate(bank_slices[b]):
                w = bank_weights[b, i]
                if w.item() == 0.0:
                    continue
                qm  = self.ln_mem_q(mem_b[:, start:end, :])                  # [1,group,D]
                kvm = self.ln_mem_kv(torch.cat([mem_b, tok_b], dim=1))        # [1,M+T,D]
                upd, _ = self.mha_mem(qm, kvm, kvm, need_weights=False)     # [1,group,D]
                upd = upd + self.ff_mem(self.ln_mem_ff(upd))                # [1,group,D]
                if partial_dims:
                    # Keep only the projected dims; the zero padding leaves the rest untouched.
                    kept = upd[..., :dproj]
                    upd = F.pad(kept, (0, D - kept.shape[-1]))
                # Place the weighted bank update at rows [start, end) of a zero [1,M,D] delta.
                delta = F.pad(w * upd, (0, 0, start, M - end))
                mem_b = mem_b + delta
            rows.append(mem_b)
        if not rows:  # B == 0
            return mem
        return torch.cat(rows, dim=0)

class BankGater(nn.Module):
    """Soft bank router: pooled features [B, d] -> K bank logits and temperature-scaled probabilities.

    ``tau`` is a plain Python attribute (not a parameter) that a trainer may change
    between steps; it is floored at 1e-3 when applied in ``forward``.
    """
    def __init__(self, d, K, hidden=None, p=0.0):
        super().__init__()
        width = hidden or (2 * d)
        layers = [
            nn.LayerNorm(d),
            nn.Linear(d, width),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(width, K),
        ]
        self.net = nn.Sequential(*layers)
        self.tau = 1.0  # router temperature; adjusted externally

    def forward(self, feat):
        """Return ``(probs, logits)`` with probs = softmax(logits / max(1e-3, tau)) over the last dim."""
        logits = self.net(feat)
        raw_tau = float(self.tau)
        tau = raw_tau if raw_tau > 1e-3 else 1e-3
        return torch.softmax(logits / tau, dim=-1), logits

class RMTPartialSoftK(nn.Module):
    """
    Recurrent memory transformer whose memory is split into banks and partially updated.

    Memory has M slots split into K contiguous banks of size group = M // K.
    Each chunk: tokens read the full mem_prev; the router picks per-sample top-k banks
    with soft weights (the router probabilities of those banks); only those banks are
    updated (optionally only the first d_proj feature dims).
    """
    def __init__(self, vocab, d=256, n_heads=4, n_layers=4, M=128, K=8, d_proj=None, p=0.0):
        super().__init__()
        assert M % K == 0, "M must be divisible by K"
        self.d, self.M, self.K = d, M, K
        self.group = M // K
        self.embed = nn.Embedding(vocab, d)
        self.pos_scale = nn.Parameter(torch.ones(1))
        self.blocks = nn.ModuleList([RMTBlockPartial(d, n_heads, p, d_proj) for _ in range(n_layers)])
        self.ln_out = nn.LayerNorm(d)
        self.out = nn.Linear(d, vocab)

        self.gater = BankGater(d, K, hidden=2*d, p=p)

        # tiny aux weights (can be 0.0 during warmup); read at forward time
        self.lb_lambda = 0.0   # load-balance
        self.ent_lambda = 0.0  # (optional) encourage entropy (diversity)

    @torch.no_grad()
    def init_mem(self, B, device=None, dtype=None):
        """Return an all-zero memory [B, M, d] (defaults: the model's device/dtype)."""
        device = device or next(self.parameters()).device
        dtype  = dtype  or next(self.parameters()).dtype
        return torch.zeros(B, self.M, self.d, device=device, dtype=dtype)

    def token_mask(self, T, device, dtype):
        """Additive attention mask [T, M+T] for the token pass.

        All memory columns are 0 (every token may read all memory). In the token
        block, entries strictly above the diagonal (future tokens) are -1e4 and the
        rest 0. A large finite negative is used instead of -inf.
        """
        mask = torch.zeros(T, self.M + T, device=device, dtype=dtype)
        neg = torch.full((T, T), -1e4, device=device, dtype=dtype)
        mask[:, self.M:] = torch.triu(neg, 1)
        return mask

    def forward(self, x, mem_prev, pos_offset=0, topk=2, collect_aux=True):
        """
        x: [B,T] token ids
        mem_prev: [B,M,D] (not modified)
        pos_offset: absolute position of x[:, 0] for the positional encoding
        topk: banks updated per sample (clipped to K)
        collect_aux: if True, also compute router regularizers

        Returns: logits [B,T,V], mem_next [B,M,D], aux dict. When collect_aux is
        True, aux holds "lb_loss" (lb_lambda * KL(batch-mean probs || uniform)),
        "ent_bonus" (ent_lambda * mean per-sample router entropy; a bonus, i.e.
        meant to be subtracted from the loss) and "gate_entropy" (python float).
        Otherwise aux is {}.
        """
        _, T = x.shape
        tok = self.embed(x)                                    # [B,T,D]
        # Build positions in float32 whatever the model dtype: float16 overflows to
        # inf above 65504 (and is inexact above 2048), which would make sin/cos NaN.
        pos = torch.arange(pos_offset, pos_offset + T, device=x.device, dtype=torch.float32).unsqueeze(-1)
        tok = tok + self.pos_scale * fourier_pos_enc(pos, self.d).to(tok.dtype)

        mask_tok = self.token_mask(T, x.device, tok.dtype)

        # router features (pre-update)
        pool = tok.mean(dim=1)                                 # [B,D]
        probs, logits = self.gater(pool)                       # [B,K]
        topk = min(topk, self.K)
        pvals, pidx = torch.topk(probs, topk, dim=-1)          # [B,topk], [B,topk]

        # Per-sample (start, end) row ranges of the chosen banks; one host transfer
        # via tolist() instead of a device sync per index.
        bank_slices = [
            [(g * self.group, (g + 1) * self.group) for g in row]
            for row in pidx.tolist()
        ]

        mem = mem_prev
        # Pass through blocks: token pass first, then weighted bank updates
        for blk in self.blocks:
            tok = blk.token_pass(tok, mem, mask_tok)
            mem = blk.bank_update(mem, tok, bank_slices, pvals)

        logits_tok = self.out(self.ln_out(tok))

        aux = {}
        if collect_aux:
            # Encourage average probs across batch to be ~uniform (rough load-balance)
            avg_p = probs.mean(dim=0)                                  # [K]
            uni = torch.full_like(avg_p, 1.0 / self.K)
            lb = torch.sum(avg_p * (avg_p.add(1e-8).log() - uni.add(1e-8).log()))  # KL(avg||uni)
            ent = (-(probs * (probs.add(1e-8).log())).sum(dim=-1).mean())
            aux = {
                "lb_loss": self.lb_lambda * lb,
                "ent_bonus": self.ent_lambda * ent,
                "gate_entropy": float(ent.detach().item())
            }

        return logits_tok, mem, aux

# In [11]:

# --------------- Minimal Trainer ---------------

def make_stream_batch(B, T_total, V, D, device):
    """Build a synthetic delayed-copy batch.

    Targets are the inputs delayed by D positions: ``y[:, t] = x[:, t - D]`` for
    ``t >= D``; the first ``min(D, T_total)`` targets are -100 so that
    ``CrossEntropyLoss(ignore_index=-100)`` skips them.

    Args:
        B: batch size.
        T_total: sequence length.
        V: vocabulary size; tokens are drawn uniformly from [0, V).
        D: copy delay in tokens; must be >= 0 (0 means y == x).
        device: torch device for both tensors.

    Returns:
        (x, y): int64 tensors of shape [B, T_total].

    Raises:
        ValueError: if D is negative.
    """
    if D < 0:
        raise ValueError(f"delay D must be >= 0, got {D}")
    x = torch.randint(0, V, (B, T_total), device=device)
    y = torch.full((B, T_total), -100, dtype=torch.long, device=device)  # ignore first D tokens
    # Slice with an explicit length: x[:, :-D] is empty when D == 0.
    if D < T_total:
        y[:, D:] = x[:, : T_total - D]
    return x, y

def main():
    """Train RMTPartialSoftK on a synthetic delayed-copy task and print running stats.

    Each episode draws a fresh random batch (targets = inputs delayed by DELAY
    tokens), starts from zero memory, and is fed in CHUNK-sized pieces with the
    memory carried, and backpropagated through, across chunks.
    """
    set_seed(1337)

    device   = "cuda" if torch.cuda.is_available() else "cpu"
    VOC      = 8
    DELAY    = 64
    CHUNK    = 32           # can be < DELAY
    BATCH    = 32
    LAYERS   = 3
    DIM      = 128
    HEADS    = 4
    M        = 128          # M >= DELAY + margin
    K        = 8            # number of banks; only top-k banks (per sample) updated
    TOPK     = 2
    D_PROJ   = 64           # update only first D_PROJ dims in chosen bank(s); None -> full D
    LR       = 1e-4
    STEPS    = 1200
    WARM_STEPS = 300        # during warmup: no aux losses, higher router temperature
    ACCUM    = 1            # gradient accumulation episodes per optimizer step
    LOG_EVERY = 20

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    model = RMTPartialSoftK(
        vocab=VOC, d=DIM, n_heads=HEADS, n_layers=LAYERS,
        M=M, K=K, d_proj=D_PROJ, p=0.0
    ).to(device)

    # Router temperature: start softer, anneal later if desired
    model.gater.tau = 1.5
    model.lb_lambda = 0.0
    model.ent_lambda = 0.0

    opt = torch.optim.AdamW(model.parameters(), lr=LR, betas=(0.9, 0.95), weight_decay=0.01)
    crit = nn.CrossEntropyLoss(ignore_index=-100)

    running = None
    acc = float("nan")

    def finite_or_die(*tensors):
        """Raise RuntimeError if any tensor contains NaN or +/-inf."""
        for t in tensors:
            if not torch.isfinite(t).all():
                raise RuntimeError("Non-finite detected. Lower LR, check mask, or print intermediates.")

    for step in range(STEPS):
        use_aux = step >= WARM_STEPS
        if step == WARM_STEPS:
            # Turn on tiny aux regularizers *before* this step's forward passes:
            # the model multiplies them into aux["lb_loss"] / aux["ent_bonus"].
            model.lb_lambda  = 1e-3
            model.ent_lambda = 5e-4

        opt.zero_grad(set_to_none=True)
        total_loss = torch.tensor(0.0, device=device)

        for _ in range(ACCUM):
            T_total = 4 * CHUNK
            x_full, y_full = make_stream_batch(BATCH, T_total, VOC, DELAY, device)
            mem = model.init_mem(BATCH, device=device)

            for s in range(0, T_total, CHUNK):
                x = x_full[:, s:s+CHUNK]
                y = y_full[:, s:s+CHUNK]

                # Every episode has fresh data and fresh memory, so absolute
                # positions restart at 0 rather than drifting upward forever.
                logits, mem, aux = model(x, mem, pos_offset=s, topk=TOPK, collect_aux=True)
                finite_or_die(logits, mem)

                ce = crit(logits.reshape(-1, VOC), y.reshape(-1))

                zero = ce.new_tensor(0.0)
                lb = aux.get("lb_loss", zero) if use_aux else zero
                eb = aux.get("ent_bonus", zero) if use_aux else zero

                # ent_bonus rewards router entropy, so it is subtracted: minimizing
                # the loss then *raises* gate entropy (diversity), as intended.
                loss_step = ce + lb - eb
                finite_or_die(loss_step)
                total_loss = total_loss + loss_step

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        with torch.no_grad():
            # Accuracy on the last chunk of the last episode.
            pred = logits.argmax(-1)
            mask = (y != -100)
            if mask.any():  # an all-ignored chunk would give mean([]) == NaN
                acc = (pred[mask] == y[mask]).float().mean().item()
            loss_val = float(total_loss.item())
            running = loss_val if running is None else 0.95 * running + 0.05 * loss_val

        if (step + 1) % LOG_EVERY == 0:
            print(f"step {step+1:4d} | loss {running:.3f} | acc {acc:.3f} | gate_H {aux.get('gate_entropy', 0.0):.3f} | tau {model.gater.tau:.2f}")

        # (optional) very mild anneal of router temperature after warmup
        if (step + 1) % 200 == 0 and step >= WARM_STEPS:
            model.gater.tau = max(1.0, model.gater.tau * 0.95)

if __name__ == "__main__":
    # Script entry point: run the training loop only when executed directly, not on import.
    main()

# Recorded output of the last notebook run:
# RuntimeError: Non-finite detected. Lower LR, check mask, or print intermediates.