In [None]:
from google.colab import userdata
import os, stat, textwrap, subprocess
# Fetch the GitHub personal access token from Colab Secrets (never hard-code it).
token = userdata.get('GITHUB_PAT')  # stored securely in Colab Secrets
if not token:
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    raise RuntimeError("Add GITHUB_PAT in the Secrets panel first.")

# Write ~/.netrc (git/curl read this for auth)
netrc_path = os.path.expanduser("~/.netrc")
netrc_body = textwrap.dedent(f"""
    machine github.com
      login {os.environ.get('USER','colab')}
      password {token}
    """).lstrip()

# Create the file with owner-only permissions from the start, so the token is
# never readable by other users between creation and a later chmod.
owner_rw = stat.S_IRUSR | stat.S_IWUSR
netrc_fd = os.open(netrc_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_rw)
with os.fdopen(netrc_fd, "w") as f:
    f.write(netrc_body)

# The mode passed to os.open only applies to newly created files; tighten a
# pre-existing ~/.netrc as well.
os.chmod(netrc_path, owner_rw)

# Now you can clone without putting the token in the URL
subprocess.run(["git", "clone", "https://github.com/jmCabrillana/rtrl_moe.git"], check=True)

In [None]:
# Move into the repository cloned above so later cells run from inside it.
%cd rtrl_moe/

In [None]:
# Set the git author identity with --global (i.e. for this user account, not just this repo),
# then show the working-tree status.
!git config --global user.email "jean_manuel.cabrillana@yahoo.fr"
!git config --global user.name "jmCabrillana"
!git status

In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import jacrev, vmap, functional_call
from torch.utils.tensorboard import SummaryWriter
from einops import rearrange
import random
import time
import importlib
import re

import circular_seg_tree
from circular_seg_tree import CircularTree, sparse_left_mul
from rtrl_minimal import RTRL

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
# The cells below move their models and tensors to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Reload imported modules automatically before each cell executes, so edits to the
# local .py modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Two small sparse matrices (2x3 and 3x3) that require grad, used to poke at torch.sparse.mm.
a = torch.tensor([[1., 0, 2], [0, 0, 0]]).to_sparse().requires_grad_()
b = torch.tensor([[0., 0, 1], [0, 0, 0], [0, 0, 0]]).to_sparse().requires_grad_()
# The product is discarded (this is not the last expression of the cell); the line only
# exercises the call.
torch.sparse.mm(a, b)

def sparse_left_mm(a, b):
    """Left-multiply ``a`` by ``b`` with ``None`` acting as the identity element.

    Returns ``torch.sparse.mm(b, a)`` when both operands are present; when one of
    them is ``None`` the other operand is returned unchanged (``None`` if both are).
    """
    if a is not None and b is not None:
        return torch.sparse.mm(b, a)
    # At least one operand is the identity: hand back the other one.
    return b if a is None else a

  torch.sparse.mm(a, b)


In [80]:
# Build a 10-slot CircularTree with no identity element and sparse_left_mm as the combine op.
# NOTE(review): this rebinds the global name `t`, which other cells use as a step counter.
t = CircularTree(10, None, sparse_left_mm)
t

<structure.CircularTree at 0x7ecd7efedea0>

In [None]:
# test circular
def fmt(t):
    """Render a tree node for printing: the literal "None" for empty nodes, else its f-string form."""
    return "None" if t is None else f"{t}"

def print_tree_state(step):
    """Print the head position and the first ``2*n`` nodes of the global tree ``cmt``.

    Reads the module-level names ``cmt`` and ``n``; ``step`` is only used in the banner.
    """
    print(f"\n=== After update step {step} (head={cmt.head}) ===")
    node_count = 2*n  # index 0 is printed too, although it is marked as unused
    node_idx = 0
    while node_idx < node_count:
        print(f"{node_idx:2d}: {fmt(cmt.tree[node_idx])}")
        node_idx += 1

# print("TEST CIRCULAR MULT")
# def op(A, B):
#     return B @ A
# I = torch.eye(2).unsqueeze(0)  # shape: (1, 2, 2)
# n = 5
# cmt = CircularTree(n=n, I=I, op=op)

# # Perform more than n inserts to exercise wrap-around
# for k in range(1, n + 3):
#     new = I.clone()
#     new[..., 0, 1] = float(k)  # set top-right entry to k
#     cmt.update(new)
#     print_tree_state(k)

# # Sum
# res = cmt.query(2, 5)  # adjust as you like
# print("\nQuery result (logical [2,5)):", fmt(res))

# Exercise CircularTree with sparse 2x2 operands and no identity element.
print("TEST CIRCULAR SPARSE")
I = None 
n = 5
cmt = CircularTree(n=n, I=None, op=sparse_left_mm)

# Perform more than n inserts to exercise wrap-around
for k in range(1, n + 3):
    # Upper-triangular matrix [[1, k], [0, 1]], so each insert is identifiable by k.
    cmt.update( torch.tensor([[1., k],[0,1]]).to_sparse())
    print_tree_state(k)

# Query the combined value over a sub-range of the stored entries
res = cmt.query(2, 5)  # adjust as you like
print("\nQuery result (logical [2,5)):", fmt(res))



# RTRL

In [3]:
# Block RTRL
# per-point function to batch jacobian later
def make_f_single(model, write=slice(None)):
    """Build a per-sample state-transition function suitable for ``vmap(jacrev(...))``.

    The returned callable takes ``(params, h, x, kw)`` for ONE sample (no batch
    dimension), runs ``model`` functionally with ``params`` substituted, and returns
    the entries ``write`` of the next hidden state. The model's next state is taken
    to be the LAST element of its return value.
    """
    def single_step(params, h, x, kw):
        # The model expects a batch axis: add a fake batch of size 1 to both inputs.
        x_batched = rearrange(x, '... -> 1 ...')
        h_batched = rearrange(h, '... -> 1 ...')
        outputs = functional_call(model, params, (x_batched, h_batched), kw)
        *_, next_state = outputs
        # Drop the fake batch axis again, then keep only the written state entries.
        flat_state = rearrange(next_state, '1 h -> h')
        return flat_state[write]
    return single_step

def add_grad_(p, g):
    """Accumulate gradient ``g`` into ``p.grad`` in place; a ``None`` gradient is ignored.

    ``g`` is detached first. When ``p`` has no gradient yet, a clone of ``g`` becomes
    the gradient, so ``p.grad`` never aliases the caller's tensor.
    """
    if g is not None:
        detached = g.detach()
        if p.grad is not None:
            p.grad.add_(detached)
        else:
            p.grad = detached.clone()

class BlockRTRL:
    """Real-time recurrent learning (RTRL) with optional block-restricted state updates.

    For every tracked parameter ``k`` the object keeps a sensitivity tensor
    ``P_t[k]`` of shape ``[B, H, numel(param)]`` holding d h_t / d theta_k for the
    hidden state that is fed INTO the next ``step`` call.

    Args:
        state_params: dict name -> parameter whose sensitivities are tracked. The names
            must also be keys of ``model.named_parameters()`` for the model given to ``step``.
        B: batch size.
        H: flattened hidden-state size.
        len_buffer: capacity of the circular Jacobian buffer (allocated, but the lazy
            buffered update is not wired into ``step`` yet).
        len_buffer_hk: maximum number of recently written state indices kept in ``last_hk``.
    """

    @torch.no_grad()
    def __init__(self, state_params, B, H, len_buffer=64, len_buffer_hk=8):
        self.state_params = state_params
        # One [B, H, numel] sensitivity per tracked parameter, matching its dtype/device.
        self.P_t = {k: torch.zeros([B, H, p.numel()]).to(p) for k, p in self.state_params.items()}
        self.last_update = {k: 0 for k in self.state_params.keys()}
        self.len_buffer = len_buffer
        self.len_buffer_hk = len_buffer_hk
        self.buffer = CircularTree(len_buffer, None, sparse_left_mul)
        self.last_hk = [0]
        self.t = 0

    def reset(self):
        """Zero all sensitivities in place (call at the start of a new sequence).

        Only ``P_t`` is cleared; ``t``, ``last_update``, ``last_hk`` and the buffer are kept.
        """
        for sensitivity in self.P_t.values():
            sensitivity.zero_()

    @torch.no_grad()
    def get_left_product(self, k, l):
        """Return the buffered Jacobian product over buffer positions ``[k, l)``."""
        return self.buffer.query(k, l)

    def _accumulate_loss_grads(self, params, h_t, loss):
        """Add d loss / d theta to ``.grad``: the direct term plus the recurrent term through ``h_t``."""
        names = list(params.keys())
        # Direct dependence of the loss on the parameters, with h_t treated as a constant.
        direct = torch.autograd.grad(loss, tuple(params.values()), retain_graph=True, allow_unused=True)
        (dL_dH_t,) = torch.autograd.grad(loss, h_t, retain_graph=False)
        for name, g in zip(names, direct):
            add_grad_(params[name], g)
        # Recurrent dependence: dL/dh_t contracted with d h_t / d theta.
        for k, p in self.state_params.items():
            # NOTE(review): the result is averaged over the batch; if `loss` is already a
            # batch mean this scales the term by an extra 1/B (no effect for B == 1) - confirm.
            g_mean = torch.einsum('b h, b h t -> b t', dL_dH_t, self.P_t[k]).mean(0)
            add_grad_(params[k], g_mean.view(p.shape))

    def _update_sensitivities(self, Jh_proj, Jtheta_proj, active_params, read, write):
        """RTRL recursion on the written rows: ``P[write] = Jh[:, :, read] @ P[read] (+ Jtheta)``."""
        for k in self.state_params.keys():
            propagated = Jh_proj[:, :, read] @ self.P_t[k][:, read]
            if k in active_params:
                # Active parameters also contribute their immediate Jacobian.
                propagated = propagated + Jtheta_proj[k]
                self.last_update[k] = self.t
            # Inactive ("expiring") parameters are only propagated through the state Jacobian.
            self.P_t[k][:, write] = propagated

    @torch.no_grad()
    def step(self, model, x_t, h_t, loss, active_params, read=None, write=None, **kw):
        """
        Advance RTRL by one model step and, if ``loss`` is given, accumulate gradients.

        x_t: [B,...], h_t: [B,H], P_t: [B,H,Tp], dL_dH_t: [B,H]

        Args:
            model: module called as ``model(x, h, **kw)`` whose last output is the next state.
            x_t: input of this step.
            h_t: hidden state fed to the model for this step; must be part of the graph
                of ``loss`` when ``loss`` is given.
            loss: scalar loss computed from ``model(x_t, h_t)``, or ``None`` to only
                propagate sensitivities.
            active_params: dict name -> parameter for which the immediate Jacobian is computed.
            read: state indices the written block depends on (default: all).
            write: state indices modified by this step (default: all).

        Gradients are ADDED to ``param.grad``; the caller is responsible for zeroing them.
        """
        params = dict(model.named_parameters())
        B, H = h_t.shape[:2]
        read = list(range(H)) if read is None else read
        write = list(range(H)) if write is None else write

        # Credit assignment must use the sensitivities of the INPUT state h_t, i.e. P_t
        # before the recursion below overwrites it with d h_{t+1} / d theta:
        #   dL/dtheta = dL/dtheta|_{h_t fixed} + dL/dh_t @ dh_t/dtheta
        # (The previous version updated P_t first and thereby contracted dL/dh_t with
        # d h_{t+1} / d theta, which counts this step's transition twice.)
        if loss is not None:
            self._accumulate_loss_grads(params, h_t, loss)

        # Batched Jacobians of the per-sample transition, restricted to the written rows.
        f1 = make_f_single(model, write)
        with torch.enable_grad():
            Jh_proj = vmap(jacrev(f1, argnums=1), in_dims=(None, 0, 0, None))(active_params, h_t, x_t, kw)  # [B,W,H]
            Jtheta_proj = vmap(jacrev(f1, argnums=0), in_dims=(None, 0, 0, None))(active_params, h_t, x_t, kw)  # [B,W,[...]]
        # Detach before storing so no autograd graph is kept alive across steps.
        Jh_proj = Jh_proj.detach()
        Jtheta_proj = {k: rearrange(v, 'b h ... -> b h (...)').detach() for k, v in Jtheta_proj.items()}  # [B,W,Tp]

        # NOTE: the circular buffer (self.buffer / get_left_product) is not updated here;
        # every tracked parameter is propagated eagerly at every step.
        self._update_sensitivities(Jh_proj, Jtheta_proj, active_params, read, write)

        # Maintain the list of most recently written state indices (most recent last).
        if len(write) != H:
            rewritten = set(self.last_hk) & set(write)
            merged = [k for k in self.last_hk + list(write) if k not in rewritten] + list(rewritten)
            self.last_hk = merged[-self.len_buffer_hk:]

        self.t += 1


# Smoke Test

In [4]:
class Dummy(nn.Module):
    """Small recurrent cell used for smoke tests.

    ``forward(x, h)`` maps an input ``x`` and hidden state ``h`` (concatenated on the
    last axis) to ``(y, h_next)``. Parameters of the state transition are prefixed
    ``state_`` and those of the read-out ``output_``.
    """

    def __init__(self, input_dim=5, hidden_dim=3, output_dim=1):
        super().__init__()
        # State transition: linear -> residual SiLU block -> LayerNorm -> tanh.
        self.state_fc1 = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.state_fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.state_norm = nn.LayerNorm(hidden_dim)
        # Read-out head applied to the new hidden state.
        self.output_fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.output_fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, h):
        joint = torch.cat([x, h], dim=-1)
        pre = F.silu(self.state_fc1(joint))
        residual = pre + F.silu(self.state_fc2(pre))
        h_next = torch.tanh(self.state_norm(residual))
        hidden_out = F.silu(self.output_fc1(h_next))
        y = self.output_fc2(hidden_out)
        return y, h_next
    
class DummyChunk(nn.Module):
    """Wrap a single-step cell ``f(x, h) -> (y, h)`` so it consumes a whole chunk at once.

    ``forward`` feeds the slices ``x_chunk[:, k]`` to ``f`` in order, threading the
    hidden state, and returns the output and state of the LAST step.
    """

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x_chunk, h0):
        state = h0.clone()
        # Iterate over axis 1 (the step axis of the chunk).
        for x_step in x_chunk.unbind(dim=1):
            out, state = self.f(x_step, state)
        return out, state

# Smoke-test fixtures: a tiny Dummy cell, its chunked wrapper, an optimizer and an RTRL tracker.
model = Dummy(input_dim=5, hidden_dim=3, output_dim=5).to(device)
model_seq = DummyChunk(model).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
B, H, D = 1, 3, 5
state_params = {k:v.to(device) for k,v in model.named_parameters() if k.startswith("state_")}
# NOTE(review): this immediately overwrites the "state_"-only selection above, so RTRL
# tracks ALL parameters (including the output_* head) in this smoke test - confirm intended.
state_params = dict(model.named_parameters())
rtrl = BlockRTRL(state_params, B, H)
# writer = SummaryWriter(log_dir="runs/block_rtrl_1")
# Global step counter used by the loops below.
t=0

In [None]:
# regression: fit y = 2*x + 0.5 with gradients produced by rtrl.step instead of loss.backward()
n=10
for _ in range(n):
  x_t = torch.randn(B, D).to(device)   # input batch
  # NOTE(review): h_t is re-sampled at the top of every iteration, so the
  # `h_t = h_next.detach()` at the bottom of the loop has no effect - confirm intended.
  h_t = torch.randn(B, H).to(device).requires_grad_()   # hidden states
  target = 2*x_t + 0.5
  y, h_next = model(x_t, h_t)
  loss = ((y - target) ** 2).mean()
  optimizer.zero_grad()
  # Block update: write only state index t % H, reading indices t % H and (t-1) % H.
  rtrl.step(model, x_t, h_t, loss, state_params, [t % H, (t-1) % H], [t % H])
  # loss.backward()
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
  optimizer.step()
  h_t = h_next.detach()
  # writer.add_scalar("train/loss", loss.item(), t)
  print(loss.item())
  t += 1

In [6]:
# anbn
# Fixtures for the a^n b^n experiments: one Dummy cell shared by the per-step model and
# the chunked wrapper, with a separate optimizer and RTRL tracker for each view.
B, H, D, T = 1, 30, 2, 4
model = Dummy(input_dim=D, hidden_dim=H, output_dim=D).to(device)
model_seq = DummyChunk(model).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
seq_optimizer = torch.optim.Adam(model_seq.parameters(), lr=1e-3)
# Only the state-transition parameters are tracked by RTRL; through the wrapper the same
# parameters carry the "f." name prefix.
state_params = {k:v.to(device) for k,v in model.named_parameters() if k.startswith("state_")}
seq_state_params = {k:v.to(device) for k,v in model_seq.named_parameters() if k.startswith("f.state_")}
rtrl = BlockRTRL(state_params, B, H)
seq_rtrl = BlockRTRL(seq_state_params, B, H)
# All three training cells below log to this single writer / run directory.
writer = SummaryWriter(log_dir="runs/anbn-shuffle_dummy_bptt_1")
criterion = nn.CrossEntropyLoss() 
# Global step counter used as the TensorBoard x-axis.
t = 0

def make_seq(len_max_a=8):
    """Sample one shuffled a/b counting example.

    Draws ``n_a`` in ``[1, len_max_a]``. With probability 0.5 the number of b's equals
    ``n_a`` (label 1); otherwise ``n_b`` is drawn from the other values in
    ``[1, len_max_a]`` (label 0). Symbols are one-hot rows (a = [1, 0], b = [0, 1])
    and the rows are randomly permuted before being returned.

    Returns:
        (tgt, seq): ``tgt`` is a length-1 integer tensor, ``seq`` is ``[n_a + n_b, 2]``;
        both live on the global ``device``.
    """
    n_a = random.randint(1, len_max_a)
    balanced = random.random() < 0.5
    if balanced:
        n_b = n_a
    else:
        # Any count different from n_a makes the example a negative one.
        n_b = random.choice([m for m in range(1, len_max_a + 1) if m != n_a])
    tgt = torch.tensor([1 if balanced else 0]).to(device)
    a_rows = torch.tensor([[1., 0.]]).repeat(n_a, 1)
    b_rows = torch.tensor([[0., 1.]]).repeat(n_b, 1)
    seq = torch.cat([a_rows, b_rows], dim=0).to(device)
    order = torch.randperm(seq.size(0), device=seq.device)
    return tgt, seq[order]
    

In [7]:
# anbn bptt: baseline that backpropagates through the fully unrolled sequence.
n_steps = 10000
acc = 64  # number of sequences whose gradients are accumulated per optimizer step
for step in range(n_steps):
      tgt, x_seq = make_seq(T)
      x_seq = x_seq.unsqueeze(0)  # add the batch axis -> [1, L, D]
      ht = torch.zeros(B, H).to(device)
      # Unroll the cell over the whole sequence; only the final output y is used for the loss.
      for k in range(x_seq.shape[1]):
            y, h_next = model(x_seq[:, k], ht)
            ht = h_next
      # Divide by acc so the accumulated gradient is the mean over the accumulation window.
      loss = criterion(y, tgt)/acc
      loss.backward()
      if (step+1) % acc == 0:
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
      # Log the un-scaled loss; "train/P_t" is logged as 0 here (no sensitivities in BPTT).
      writer.add_scalar("train/loss", loss.item()*acc, t)
      writer.add_scalar("train/P_t", 0, t)
      t += 1

In [None]:
# anbn rtrl chunk: RTRL where one "step" of the recurrence is a chunk of `chunk` tokens.
n_steps = 30000
acc = 64  # number of sequences whose gradients are accumulated per optimizer step
chunk = 4
for step in range(n_steps):
    tgt, x_seq = make_seq(T)
    x_seq = x_seq.unsqueeze(0)  # add the batch axis -> [1, L, D]
    h = torch.zeros(B, H).requires_grad_().to(device)
    seq_rtrl.reset()  # sensitivities start from zero for every sequence
    for i in range(0, x_seq.shape[1], chunk):
        j = min(i + chunk, x_seq.shape[1])
        y, h_next = model_seq(x_seq[:, i:j], h)
        # Intermediate chunks only propagate sensitivities (loss=None) and advance the
        # state; the final chunk is handled after the loop, where the loss is available.
        if j < x_seq.shape[1]:
            seq_rtrl.step(model_seq, x_seq[:, i:j], h, None, seq_state_params) #, [t % H, (t+1) % H])
            h = h_next.detach().requires_grad_()
    loss = criterion(y, tgt)/acc
    # i, j and h still refer to the last chunk and its input state here.
    seq_rtrl.step(model_seq, x_seq[:, i:j], h, loss, seq_state_params) #, [t % H, (t+1) % H])
    if (step+1) % acc == 0:
        nn.utils.clip_grad_norm_(model_seq.parameters(), 1.0)
        seq_optimizer.step()
        seq_optimizer.zero_grad()
    writer.add_scalar("train/loss", loss.item()*acc, t)
    writer.add_scalar("train/P_t", 0, t) #list(seq_rtrl.P_t.values())[0].abs().mean().item(), t)
    t += 1

In [None]:
# anbn rtrl naive: RTRL with one token per recurrence step.
n_steps = 100000
acc = 64  # number of sequences whose gradients are accumulated per optimizer step
for step in range(n_steps):
    tgt, x_seq = make_seq(T)
    x_seq = x_seq.unsqueeze(0)  # add the batch axis -> [1, L, D]
    h = torch.zeros(B, H).requires_grad_().to(device)
    rtrl.reset()  # sensitivities start from zero for every sequence
    for k in range(x_seq.shape[1]):
        y, h_next = model(x_seq[:, k], h)
        # All but the last token only propagate sensitivities (loss=None) and advance the
        # state; the last token is handled after the loop, where the loss is available.
        if k < x_seq.shape[1] - 1:
            rtrl.step(model, x_seq[:, k], h, None, state_params) #, [t % H, (t+1) % H])
            h = h_next.detach().requires_grad_()
    loss = criterion(y, tgt)/acc
    # k and h still refer to the last token and its input state here.
    rtrl.step(model, x_seq[:, k], h, loss, state_params) #, [t % H, (t+1) % H])
    if (step+1) % acc == 0:
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
    writer.add_scalar("train/loss", loss.item()*acc, t)
    writer.add_scalar("train/P_t", 0, t) #list(rtrl.P_t.values())[0].abs().mean().item(), t)
    t += 1

In [None]:
# Launch TensorBoard inline on the "runs" directory that the SummaryWriter cells log to.
%load_ext tensorboard
%tensorboard --logdir runs

# MoE

In [50]:
# MoE

# Expert (vmap compatible)
class ExpertBank(nn.Module):
    """
    Bank of E single-layer experts of width d, each computing relu(x @ W[e] + b[e]).

    Parameters are registered per expert (two ParameterLists, named ``W.<e>`` and
    ``b.<e>``); ``forward`` stacks them so the selected experts of every sample are
    applied with a single einsum.
    """

    def __init__(self, E, d):
        super().__init__()
        self.E = E
        self.d = d
        weights = [nn.Parameter(torch.empty(d, d)) for _ in range(E)]
        biases = [nn.Parameter(torch.empty(d)) for _ in range(E)]
        self.W = nn.ParameterList(weights)
        self.b = nn.ParameterList(biases)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise each expert: Kaiming-uniform weight, then uniform bias in +-1/sqrt(fan_in)."""
        for expert in range(len(self.W)):
            weight, bias = self.W[expert], self.b[expert]
            nn.init.kaiming_uniform_(weight, a=5**0.5)
            limit = 1.0 / weight.size(1)**0.5
            nn.init.uniform_(bias, -limit, limit)

    def forward(self, x, w, idx):
        """Apply the k selected experts per sample and mix their outputs.

        Args:
            x: [B, T, D] inputs.
            w: mixture weights, viewable as [B, 1, k, 1].
            idx: [B, k] integer expert indices.

        Returns:
            [B, T, D] weighted sum over the k selected experts.
        """
        B, T, D = x.shape
        _, k = idx.shape

        weight_bank = torch.stack(list(self.W), dim=0).contiguous()  # [E, D, D]
        bias_bank = torch.stack(list(self.b), dim=0).contiguous()    # [E, D]

        # Gather the parameters of the chosen experts for every sample.
        chosen = idx.reshape(-1)
        weight_sel = weight_bank.index_select(0, chosen).reshape(B, k, D, D).contiguous()  # [B,k,D,D]
        bias_sel = bias_bank.index_select(0, chosen).reshape(B, k, D).contiguous()         # [B,k,D]

        per_expert = torch.einsum('b t i, b k i o -> b t k o', x, weight_sel)
        per_expert = F.relu(per_expert + bias_sel.unsqueeze(1))
        mix = w.view(B,1,k,1)
        return (per_expert * mix).sum(dim=2)  # [B,T,D]


class MLP(nn.Module):
    """Two-layer feed-forward block d -> mult*d -> d with GELU and dropout on the hidden layer."""

    def __init__(self, d, mult=1, dropout=0.0):
        super().__init__()
        hidden = mult * d
        self.fc1 = nn.Linear(d, hidden)
        self.fc2 = nn.Linear(hidden, d)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.gelu(self.fc1(x))
        activated = self.drop(activated)
        return self.fc2(activated)

class TopKGate(nn.Module):
    """Linear router that keeps the k highest-scoring experts per sample.

    ``forward(h)`` returns ``(weights, indices)``, both ``[B, k]``; the weights are a
    softmax over the k retained logits only.
    """

    def __init__(self, d, n_experts, k=2):
        super().__init__()
        assert 1 <= k <= n_experts
        self.k = k
        self.n = n_experts
        self.proj = nn.Linear(d, n_experts, bias=False)

    def forward(self, h):                         # h: [B,D]
        scores = self.proj(h)                     # [B,E]
        top_scores, top_idx = torch.topk(scores, self.k, dim=-1)  # [B,k]
        weights = top_scores.softmax(-1)          # normalize over top-k
        return weights, top_idx
    
def fourier_pos_enc(pos, d, base=10000):
    """Sinusoidal position features: ``d//2`` sines followed by ``d//2`` cosines on the last axis.

    The i-th frequency divides ``pos`` by ``base ** (i / (d//2))``. Intermediate
    tensors are cast to the dtype/device of ``pos``.
    """
    half = d//2
    exponents = torch.arange(0, half).to(pos) / half
    angles = pos / base**exponents
    sin_part = torch.sin(angles).to(pos)
    cos_part = torch.cos(angles).to(pos)
    return torch.cat([sin_part, cos_part], dim=-1)


# ---------- Recurrent MoE ----------
class RecurrentMoE(nn.Module):
    """
    Recurrent model whose state is a set of ``n_slots`` latent vectors of width ``d_model``,
    carried between calls as one flat ``[B, n_slots * d_model]`` tensor.

    Step:
      1) Mixed-attn: Q=latent, K/V=concat(latent, x)
      2) Pool latent -> gate -> mixture of experts (sequence-level) -- currently disabled
         in ``forward``; the gate/expert modules are still constructed
      3) Blend a tanh-squashed update into the two highest-scoring latent slots
      4) Output attention: Q=tokens, K/V=concat(tokens, new state); read out the last token

    Args:
        d_model: latent / token width ``d``.
        n_heads: attention heads of both attention blocks.
        n_slots: number of latent slots ``S``.
        n_experts, topk: size of the expert bank and of the expert gate.
        dropout: dropout used in attention and MLP blocks.
        d_in, d_out: input feature size and output size; falsy values (None or 0)
            default to ``d_model``.
    """
    def __init__(self, d_model=512, n_heads=2, n_slots=4, n_experts=4, topk=2, dropout=0.0, d_in=None, d_out=None):
        super().__init__()
        d, E = d_model, n_experts
        self.d, self.n_slots = d, n_slots
        # Resolve the defaults first, then ALWAYS store the attributes. (Previously
        # `if not d_in: d_in = d; self.d_in = d_in` put the assignment inside the `if`
        # suite, so self.d_in / self.d_out were missing whenever the caller supplied them.)
        if not d_in: d_in = d
        if not d_out: d_out = d
        self.d_in, self.d_out = d_in, d_out

        # Attention
        self.state_embedding = nn.Linear(d_in, d)
        self.state_mha = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        self.state_ln_q = nn.LayerNorm(d)
        self.state_ln_kv = nn.LayerNorm(d)
        self.state_ffn = MLP(d, dropout=dropout)
        self.state_ln_ffn = nn.LayerNorm(d)

        # Sequence level gating with topk expert
        self.state_gate = TopKGate(d, n_experts, k=topk)
        # self.state_experts = nn.ModuleList([MLP(d, dropout=dropout) for _ in range(n_experts)])
        self.state_experts = ExpertBank(E, d)
        self.state_ln_moe_in = nn.LayerNorm(d)
        self.state_latent_proj = nn.Linear(d, d)

        # Latent projection
        self.state_ln_slot = nn.LayerNorm(d)
        self.state_slot_ctx = nn.Linear(d, 1, bias=False)

        # Output Attention
        self.out_embedding = nn.Linear(d_in, d)
        self.out_mha = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        self.out_ln_q = nn.LayerNorm(d)
        self.out_ln_kv = nn.LayerNorm(d)
        self.out_ffn = MLP(d, dropout=dropout)
        self.out_ln_ffn = nn.LayerNorm(d)
        self.out_proj = nn.Linear(d, d_out)

        # Step counter; the code that would advance it in forward() is commented out.
        self.t=0

    @torch.no_grad()
    def init_state(self, batch_size, device=None, dtype=None):
        """Return a small random initial state of shape [batch_size, n_slots * d]."""
        device = device or next(self.parameters()).device
        dtype = dtype or next(self.parameters()).dtype
        latent = torch.randn(batch_size, self.n_slots*self.d, device=device, dtype=dtype) * 0.02
        return latent

    def forward(self, x, state_flat):
        """
        Args:
            x: [B, T, d_in] input chunk.
            state_flat: [B, n_slots * d] flattened latent state.

        Returns:
            (y, info, state): ``y`` is [B, d_out] read from the last token,
            ``info`` holds the selected slot indices (``idx_experts`` is the placeholder
            ``0`` while the MoE stage is disabled), ``state`` is the new flat state.
        """
        B, S, D = state_flat.shape[0], self.n_slots, self.d
        _, T, _ = x.shape
        device, dtype = state_flat.device, state_flat.dtype
        state_old = state_flat.reshape([B, S, D]).contiguous()
        latent = state_old.clone()
        # Positions restart at 0 for every chunk (the running offset self.t is not applied).
        pos = torch.arange(T, device=device).unsqueeze(-1).float() #+ self.t
        pe = fourier_pos_enc(pos, D, base=S)
        # self.t += T

        # (1) Mixed attention
        latent_state_x = self.state_embedding(x) + pe
        q = self.state_ln_q(latent)
        kv = torch.cat([latent, latent_state_x], dim=1)
        kv = self.state_ln_kv(kv)
        attn_out, _ = self.state_mha(q, kv, kv, need_weights=False)
        latent = latent + attn_out
        latent = latent + self.state_ffn(self.state_ln_ffn(latent))

        # (2) MoE with top-k experts (sequence-level) =1 here -- disabled
        # pooled = self.state_ln_moe_in(latent.mean(dim=1))       # [B,D]
        # w, idx = self.state_gate(pooled)                        # [B,k]]
        # mixed = self.state_experts(latent, w, idx)
        # latent = latent + mixed

        # (3) Choose the two highest-scoring slots per sample and update state only there
        logits = self.state_slot_ctx(self.state_ln_slot(latent)).squeeze(-1)        # [B,S]
        logits, tgt_idx = torch.topk(logits, 2, dim=-1)                          # [B,2]
        w = F.softmax(logits, dim=-1)
        # alpha is zero everywhere except the two selected slots, which get their softmax weight.
        alpha = torch.zeros(B, S, device=device, dtype=dtype)
        alpha = alpha.scatter_add_(1, tgt_idx, w)
        # Every slot is scaled by beta; only the selected slots also receive new content.
        beta = 0.6
        state = beta*state_old + (1-beta)* alpha.unsqueeze(-1) * torch.tanh(latent)     # out-of-place residual update

        # (4) Output
        latent_out = self.out_embedding(x) + pe
        q = self.out_ln_q(latent_out)
        kv = torch.cat([latent_out, state], dim=1)
        kv = self.out_ln_kv(kv)
        # No attention mask: every token attends to all tokens of the chunk and to all slots.
        # attn_tok  = torch.triu(torch.full((T, T), float('-inf'), device=x.device), 1)
        # attn_slot = torch.zeros(T, S, device=device)
        # attn_mask = torch.cat([attn_tok, attn_slot], dim=1)           # [T, T+S]
        attn_out, _ = self.out_mha(q, kv, kv, need_weights=False, attn_mask=None)
        latent_out = latent_out + attn_out
        latent_out = latent_out + self.out_ffn(self.out_ln_ffn(latent_out))
        y = self.out_proj(latent_out[:, -1])

        info = {"idx_experts": 0, "idx_slots": tgt_idx.detach()} # idx
        return y, info, state.reshape([B, S*D]).contiguous()

# ---------- Tiny usage ----------
# Tiny end-to-end setup: a RecurrentMoE with 64 slots of width 64 and an RTRL tracker on it.
B, T, D = 1, 8, 64
model = RecurrentMoE(d_model=D, n_heads=4, n_slots=64, n_experts=64, topk=2).to(device)
x = torch.randn(B, T, D).to(device)
# Per-expert parameters are named state_experts.W.<e> / state_experts.b.<e>.
pattern = re.compile(r'^state_experts\.(W|b)\.\d+$')
# Split the recurrent ("state_") parameters into always-active core ones and per-expert ones.
state_params = {name:p for name,p in model.named_parameters() if name.startswith("state_")}
core_params = {name:p for name, p in state_params.items() if pattern.match(name) is None}
expert_params = {name:p for name, p in state_params.items() if pattern.match(name)}
state = model.init_state(B, device=x.device).requires_grad_()
# BlockRTRL allocates one [B, state_dim, numel(param)] sensitivity per tracked parameter,
# so memory grows with state_dim * total parameter count.
rtrl = BlockRTRL(state_params, B, state.shape[-1])

def get_expert_latent_activated(info):
    """Translate a model ``info`` dict into the inputs of a block RTRL step.

    Args:
        info: dict with ``'idx_slots'`` and ``'idx_experts'``. Each may be a tensor of
            indices or a plain int / list (``RecurrentMoE.forward`` currently returns
            the int ``0`` for ``'idx_experts'``, which used to crash on ``.flatten()``).

    Returns:
        (active_params, proj): ``active_params`` are the core state parameters plus the
        parameters of the selected experts; ``proj`` is the sorted list of flat state
        indices covered by the selected slots (slot ``i`` spans ``[D*i, D*(i+1))``).

    Reads the module-level names ``D``, ``state_params`` and ``core_params``.
    """
    # torch.as_tensor accepts ints, lists and tensors alike; sorting makes the result
    # independent of set iteration order.
    idx_slots = sorted(set(torch.as_tensor(info['idx_slots']).flatten().tolist()))
    proj = [j for i in idx_slots for j in range(D*i, D*(i+1))]

    expert_ids = sorted(set(torch.as_tensor(info['idx_experts']).flatten().tolist()))
    if expert_ids:
        ids_pattern = "|".join(map(str, expert_ids))
        pattern = re.compile(rf'^state_experts\.(W|b)\.({ids_pattern})$')
        active_experts_params = {name:p for name, p in state_params.items() if pattern.match(name)}
    else:
        # No expert selected: avoid compiling a regex with an empty alternation.
        active_experts_params = {}
    active_params = core_params | active_experts_params
    return active_params, proj

In [51]:
# One forward pass and a dummy loss (mean squared output) to drive an RTRL step.
y, info, state_new = model(x, state)
loss = ((y - torch.zeros_like(y))**2).mean()
# NOTE(review): with this line commented out, `active_params` and `proj` are not
# defined by this cell; the rtrl.step cell further down uses both names.
# active_params, proj = get_expert_latent_activated(info)

In [53]:
# Sanity check of the output shape.
y.shape

torch.Size([1, 64])

In [18]:
# One block-RTRL step; `proj` is passed positionally as the `read` index list.
# NOTE(review): `active_params` and `proj` must already exist in the kernel; the visible
# assignment that creates them (in the forward-pass cell) is commented out.
rtrl.step(model, x, state, loss, active_params, proj)
# Carry the new state forward without its autograd history.
state = state_new.detach()

  attn_output = scaled_dot_product_attention(
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [11]:
# Total element count over the second item of every non-empty node in the RTRL buffer tree.
sum(p[1].numel() for p in rtrl.buffer.tree if p is not None)

0

In [72]:
# Total number of elements held by all RTRL sensitivity tensors.
sum(p.numel() for p in rtrl.P_t.values())

21004288

# Haystack

In [104]:
# Haystack
# Task configuration for the needle-in-a-haystack key-retrieval task.
VOC        = 8    # vocabulary size, including the special tokens below
SEQ_LEN    = 64   # tokens per sequence
B          = 1    # batch size
H          = 64   # hidden size

# Special token ids; ids from BASE up to VOC-1 are the ordinary (filler / key value) tokens.
BOS, KEY, SEP, Q, BASE = 0, 1, 2, 3, 4

# ---- data: on-the-fly batch sampler ----
def sample_batch(bs):
    """Sample ``bs`` key-retrieval sequences on the fly.

    Each sequence is ``BOS``, random filler tokens, the triple ``KEY, BASE + k, SEP``
    starting at a random position, more filler, and a final ``Q``. The label is ``k``,
    drawn from ``range(VOC - BASE)``.

    Returns:
        (x, y): ``x`` is ``[bs, SEQ_LEN]`` token ids and ``y`` is ``[bs]`` labels, both
        long tensors on the global ``device``.
    """
    x = torch.empty(bs, SEQ_LEN, dtype=torch.long)
    y = torch.empty(bs, dtype=torch.long)
    for row in range(bs):
        key = random.randrange(VOC-BASE)
        # choose insertion so there's space before final Q
        ins = random.randrange(1, SEQ_LEN - 1 - 16)
        tokens = [BOS]
        # Filler up to the insertion point, then the key triple.
        tokens.extend(random.randrange(BASE, VOC) for _ in range(ins - len(tokens)))
        tokens.extend((KEY, BASE + key, SEP))
        # Filler for the remainder, leaving the last position for the query token.
        tokens.extend(random.randrange(BASE, VOC) for _ in range(SEQ_LEN - 1 - len(tokens)))
        tokens.append(Q)
        x[row] = torch.tensor(tokens)
        y[row] = key
    return x.to(device), y.to(device)

In [34]:
# Probe: draw one example and compare the model's argmax prediction with the true key.
# `model` is called with the tokens only, so this cell needs a model whose forward accepts a
# single argument (LSTMClassifier, defined further down, has h=None as a default).
# NOTE(review): the saved output below shows token ids up to 64, so it was presumably
# produced with a different VOC than the one set in the cell above - confirm / re-run.
x1, y1 = sample_batch(1)
with torch.no_grad():
    pred = model(x1)[0].argmax(-1).item()
print(f"probe true_key={y1.item()} predicted={pred}")
print(x1, y1)

probe true_key=45 predicted=11
tensor([[ 0, 19, 29, 64,  7, 40, 51, 20, 21, 10, 38, 45, 12, 22, 50, 52, 45, 45,
         19, 22, 32, 19, 56, 12, 14, 55, 41, 25, 38, 56,  1, 50,  2, 50, 17, 26,
         12, 64, 48, 25, 40,  6, 17, 12, 57, 41, 24, 50, 63, 33, 30, 52, 59, 25,
         34, 31, 10, 59, 14,  8, 63, 55, 54,  3]], device='cuda:0') tensor([45], device='cuda:0')


In [48]:
# RecurrentMoE + BlockRTRL setup; D, B, H, VOC and device are taken from earlier cells.
# model = Dummy(input_dim=D, hidden_dim=H, output_dim=D).to(device)
model = RecurrentMoE(d_model=D, n_heads=4, n_slots=H, n_experts=4, topk=2, d_in=VOC, d_out=VOC).to(device)
# Per-expert parameters are named state_experts.W.<e> / state_experts.b.<e>.
pattern = re.compile(r'^state_experts\.(W|b)\.\d+$')
state_params = {name:p for name,p in model.named_parameters() if name.startswith("state_")}
core_params = {name:p for name, p in state_params.items() if pattern.match(name) is None}
expert_params = {name:p for name, p in state_params.items() if pattern.match(name)}

h = model.init_state(1, device=device).requires_grad_()
# One [B, n_slots*d, numel(param)] sensitivity tensor is allocated per tracked parameter.
rtrl = BlockRTRL(state_params, B, h.shape[-1], len_buffer=1, len_buffer_hk=2*2*D)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=-100, reduction="none")  # per-step
writer = SummaryWriter(log_dir="runs/copy2000_tbptt_moe_0")

In [111]:
# model LSTM
class LSTMClassifier(nn.Module):
    """Two-layer LSTM baseline for the key-retrieval task.

    Token ids are embedded to width ``H``, run through the LSTM, and the output at the
    final position is layer-normalised and projected to ``VOC - BASE`` class logits.
    Sizes come from the module-level constants ``VOC``, ``H`` and ``BASE``.
    """

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(VOC, H)
        self.rnn = nn.LSTM(H, H, num_layers=2, batch_first=True, dropout=0.)
        self.head = nn.Linear(H, VOC-BASE)
        self.norm = nn.LayerNorm(H)

    def forward(self, x, h=None):
        """Return ``(logits, h)``; ``h`` is the LSTM state and may be fed back for the next window."""
        embedded = self.emb(x)
        out, h = self.rnn(embedded, h)      # [B,T,H]
        last_step = out[:, -1]              # read at final token (Q)
        logits = self.head(self.norm(last_step))
        return logits, h
    
# ---- train (truncated BPTT) ----
# Instantiate the LSTM baseline; this rebinds the globals `model` and `opt` used by other cells.
lr  = 2e-3
model = LSTMClassifier().to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
# Summed (not averaged) cross-entropy over the batch.
loss_fn = nn.CrossEntropyLoss(reduction='sum')

In [112]:
acc = 64            # sequences accumulated per optimizer step
tbptt = SEQ_LEN     # window length; equal to SEQ_LEN, so each sequence is a single window
cum_loss = 0
for step in range(1, 1000000):
    model.train()
    x, y = sample_batch(B)
    h = None
    # roll through sequence in windows; carry hidden; detach between
    for s in range(0, SEQ_LEN, tbptt):
        chunk = x[:, s:s+tbptt]
        logits, h = model(chunk, h)
        h = (h[0].detach(), h[1].detach())
    # Only the logits of the last window are scored; because the state is detached between
    # windows, gradients do not reach earlier windows when tbptt < SEQ_LEN.
    loss = loss_fn(logits, y)
    (loss/acc).backward()
    cum_loss += loss.item()
    
    # Report the mean loss over the last acc*100 steps.
    if step % (acc*100) == 0:
        print(f"step {step:4d} | loss {cum_loss/(acc*100):.3f}")
        cum_loss = 0
    # `step` starts at 1, so the `step != 0` guard is always true.
    if step % acc == 0 and step != 0:
        # nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        opt.zero_grad(set_to_none=True)
    
    

step 6400 | loss 1.406
step 12800 | loss 1.394
step 19200 | loss 1.393
step 25600 | loss 1.391
step 32000 | loss 1.390
step 38400 | loss 1.390
step 44800 | loss 1.390
step 51200 | loss 1.390
step 57600 | loss 1.390
step 64000 | loss 1.390


KeyboardInterrupt: 

# Copy Task

In [127]:
# Copy-task configuration (rebinds SEQ_LEN, H, VOC and D from earlier sections).
SEQ_LEN    = 400  # tokens per sequence
DELAY_MIN  = 1    # smallest copy delay (inclusive)
DELAY_MAX  = 1    # largest copy delay (inclusive)
CHUNK      = 1    # tokens fed to the model per call
TBPTT      = 10   # steps between backward/optimizer updates in the training loop
H          = 64   # number of latent slots / LSTM hidden size
VOC        = 2  # vocab
D          = 128  # model width

# --- Data: random tokens and per-step random delay ---
def random_seq_copy():
    """Sample one delayed-copy sequence.

    Returns ``(x, y, d)``, each a long tensor of length ``SEQ_LEN`` on the global
    ``device``: ``x`` are random tokens in ``[0, VOC)``, ``d`` are per-step delays in
    ``[DELAY_MIN, DELAY_MAX]`` and ``y[t] = x[t - d[t]]`` wherever ``t - d[t] >= 0``;
    the remaining targets are ``-100`` (the loss's ignore index).
    """
    x = torch.randint(0, VOC, (SEQ_LEN,), dtype=torch.long)
    d = torch.randint(DELAY_MIN, DELAY_MAX + 1, (SEQ_LEN,), dtype=torch.long)
    # Position each target is copied from; negative means "before the sequence start".
    source = torch.arange(SEQ_LEN) - d
    has_source = source >= 0
    y = torch.full((SEQ_LEN,), -100, dtype=torch.long)  # -100 = ignore_index
    y[has_source] = x[source[has_source]]
    return x.to(device), y.to(device), d.to(device)

# RecurrentMoE + BlockRTRL setup for the copy task (B and device come from earlier cells).
# model = Dummy(input_dim=D, hidden_dim=H, output_dim=D).to(device)
model = RecurrentMoE(d_model=D, n_heads=4, n_slots=H, n_experts=4, topk=2, d_in=VOC, d_out=VOC).to(device)
# Per-expert parameters are named state_experts.W.<e> / state_experts.b.<e>.
pattern = re.compile(r'^state_experts\.(W|b)\.\d+$')
state_params = {name:p for name,p in model.named_parameters() if name.startswith("state_")}
core_params = {name:p for name, p in state_params.items() if pattern.match(name) is None}
expert_params = {name:p for name, p in state_params.items() if pattern.match(name)}

h = model.init_state(1, device=device).requires_grad_()
# One [B, n_slots*d, numel(param)] sensitivity tensor is allocated per tracked parameter.
rtrl = BlockRTRL(state_params, B, h.shape[-1], len_buffer=1, len_buffer_hk=2*2*D)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=-100, reduction="none")  # per-step
writer = SummaryWriter(log_dir="runs/copy2000_tbptt_moe_0")
epoch = 0
# Exponential moving average of the per-step loss, updated by the training loop.
running = 0.0
x, y, d = random_seq_copy()

In [101]:
# LSTM
class LSTM(nn.Module):
    """
    Single-layer LSTM baseline exposing the same flat-state interface as RecurrentMoE.

    Input:  x  [B, T, D]
    State:  h  [B, 2H]  (concat of h and c to fit the flattened state pattern)
    Output: logits [B, D] read at the LAST step of the chunk, info dict, h_new [B, 2H]
    """
    def __init__(self, D=16, H=128):
        super().__init__()
        self.D, self.H = D, H
        self.embed = nn.Linear(D, H)
        self.rnn = nn.LSTM(input_size=H, hidden_size=H, num_layers=1, batch_first=True)
        self.out = nn.Linear(H, D)

    @torch.no_grad()
    def init_state(self, batch_size, device=None, dtype=None):
        """Return an all-zero flat state [batch_size, 2H] on the model's device/dtype by default."""
        device = device or next(self.parameters()).device
        dtype  = dtype  or next(self.parameters()).dtype
        # pack (h, c) into one flat tensor [B, 2H]
        return torch.zeros(batch_size, 2*self.H, device=device, dtype=dtype)

    def forward(self, x, h_flat):
        """Run the chunk ``x`` from state ``h_flat``; returns ``(logits, info, h_new)``."""
        # unpack flat [B,2H] -> (h0, c0) both [1,B,H]
        h0, c0 = h_flat.split(self.H, dim=-1)
        h0, c0 = h0.unsqueeze(0).contiguous(), c0.unsqueeze(0).contiguous()
        out, (hN, cN) = self.rnn(self.embed(x), (h0, c0))   # out: [B,T,H]
        # Read the prediction at the final step so it matches the returned state (and the
        # last-token read-out of the other models in this notebook). The previous
        # `logits[:,0]` returned the FIRST step, which is only the same thing for T == 1.
        logits = self.out(out[:, -1])           # [B,D]
        info = {}
        h_new = torch.cat([hN.squeeze(0), cN.squeeze(0)], dim=-1).contiguous()  # [B,2H]
        return logits, info, h_new


In [102]:
# Swap in the LSTM baseline; rebinds the globals `model` and `opt`. The `rtrl` object
# created earlier was built from the previous model's parameters and is not rebuilt here.
model = LSTM(VOC, H).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# --- Training: one fresh sequence per epoch, one model call per step, truncated BPTT:
# the loss is accumulated and backpropagated whenever t is a multiple of TBPTT (and at the
# last step), after which the hidden state is detached ---
model.train()
for epoch in range(0, 300):
    x, y, d = random_seq_copy()
    # h = torch.zeros(1, H, device=device)
    h = model.init_state(1, device=device).requires_grad_()
    # model.t = 0
    rtrl.reset()
    # Accumulator for the current TBPTT window.
    loss = torch.zeros(1, device=device).requires_grad_()
    for t in range(0, SEQ_LEN-CHUNK): #, CHUNK):
        end = min(t + CHUNK, SEQ_LEN-1)
        x_chunk = F.one_hot(x[t:end], num_classes=VOC).float().unsqueeze(0)  # [1, T, VOC]
        y_t = y[end-1]  # target of the last token in the chunk (scalar)
        
        logits, info, h_new = model(x_chunk, h)             # logits: [1, VOC]
        # active_params, proj = get_expert_latent_activated(info)
        loss_t = criterion(logits.squeeze(0), y_t).mean()  
        
        # opt.zero_grad()
        # rtrl.step(model, x_chunk, h, loss_t, active_params, proj)
        # loss_t.backward()
        # opt.step()
        loss = loss + loss_t
        h = h_new #.detach().requires_grad_()
        # Window boundary: the condition is already true at t == 0, so the first window
        # holds a single step, yet every window is divided by TBPTT.
        if (t % TBPTT == 0) or (t == SEQ_LEN-CHUNK-1) :
            loss /= TBPTT
            loss.backward()
            opt.step()
            opt.zero_grad()
            # Cut the graph so the next window does not backpropagate into this one.
            h = h.detach()
            loss = torch.zeros(1, device=device).requires_grad_()
        
        # Exponential moving average of the per-step loss (`running` comes from the setup cell).
        running = (0.98 * running + 0.02 * loss_t.item()) if running else loss_t.item()
        if t % 100 == 0: print(f"loss: {running:.4f}"); print(info)
        writer.add_scalar("train/loss", loss_t.item(), epoch*SEQ_LEN + t)
        # P_t_mean = sum(p.abs().mean() for p in rtrl.P_t.values()) / len(rtrl.P_t)
        # writer.add_scalar("train/P_t", P_t_mean.item(), epoch*(SEQ_LEN/CHUNK) + start/CHUNK)
    # opt.zero_grad(); loss.backward(); opt.step()
    print(f"epoch {epoch+1} | running CE (masked): {running:.4f}")

In [131]:
# Debug aid: turns on autograd anomaly detection for the rest of the session (it stays
# on until explicitly set back to False).
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x758f08de8c40>

In [None]:
import torch, gc

def list_tensors_on_gpu():
    """Print a table of all CUDA tensors reachable by the garbage collector, largest first.

    Walks ``gc.get_objects()`` and collects every CUDA tensor, plus the ``.data`` tensor
    of non-tensor objects that expose one. Sizes are computed as ``numel * element_size``
    per tensor object, so tensors sharing storage (views) are counted more than once.
    Returns ``None``; the report is written to stdout.
    """
    tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                tensors.append(obj)
            elif hasattr(obj, "data") and torch.is_tensor(obj.data) and obj.data.is_cuda:
                tensors.append(obj.data)
        except Exception:
            # Best effort: some objects raise on attribute access. Unlike a bare
            # `except:`, this still lets KeyboardInterrupt / SystemExit through.
            continue

    # Sort tensors by memory size (largest first)
    tensors = sorted(tensors, key=lambda t: t.numel() * t.element_size(), reverse=True)

    total = 0
    print(f"{'Idx':>3} {'Shape':>20} {'Dtype':>10} {'Size (MB)':>12} {'Device':>10}")
    print("-" * 60)
    for i, t in enumerate(tensors):
        size_mb = t.numel() * t.element_size() / (1024 ** 2)
        total += size_mb
        print(f"{i:3d} {str(tuple(t.shape)):>20} {str(t.dtype):>10} {size_mb:12.2f} {str(t.device):>10}")

    print("-" * 60)
    print(f"Total GPU tensor memory: {total:.2f} MB")

# Print the current CUDA tensor inventory.
list_tensors_on_gpu()


In [59]:
import gc, torch
# Collect unreachable Python objects first, then release PyTorch's cached, unused CUDA memory.
gc.collect(); torch.cuda.empty_cache()
