Cloning into 'rtrl_moe'...
remote: Enumerating objects: 43, done.[K
remote: Counting objects: 100% (43/43), done.[K
remote: Compressing objects: 100% (32/32), done.[K
remote: Total 43 (delta 16), reused 35 (delta 8), pack-reused 0 (from 0)[K
Receiving objects: 100% (43/43), 21.27 KiB | 5.32 MiB/s, done.
Resolving deltas: 100% (16/16), done.


In [6]:
%cd rtrl_moe/

/content/rtrl_moe


In [2]:
!git config --global user.email "jean_manuel.cabrillana@yahoo.fr"
!git config --global user.name "jmCabrillana"

In [14]:
!git status

On branch main
Your branch is up to date with 'origin/main'.

nothing to commit, working tree clean


In [5]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import jacrev, vmap, functional_call
from torch.utils.tensorboard import SummaryWriter
from einops import rearrange
import random
import time
import importlib
import re

import structure
from structure import CircularTree, sparse_left_mul
from rtrl_minimal import RTRL

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
importlib.reload(structure)
from structure import CircularTree, sparse_left_mul

In [17]:
# @title
# per-point function to batch jacobian later
def make_f_single(model, proj=slice(None)):
    def f(params, h, x, kw):
        x1, h1 = rearrange(x, '... -> 1 ...'), rearrange(h, '... -> 1 ...')
        *_, h_next_b = functional_call(model, params, (x1, h1), kw)
        return rearrange(h_next_b, '1 h -> h')[proj] # remove fake batch, slice
    return f

def add_grad_(p, g):
    if g is None: return
    g = g.detach()
    if p.grad is None:
        p.grad = g.clone()
    else:
        p.grad.add_(g)

class BlockRTRL:
    @torch.no_grad()
    def __init__(self, state_params, B, H, len_buffer=64, len_buffer_hk=8):
        self.state_params = state_params
        self.P_t = {k:torch.zeros([B, H, p.numel()]).to(p) for k,p in self.state_params.items()}
        self.last_update = {k:0 for k in self.state_params.keys()}
        self.len_buffer = len_buffer
        self.len_buffer_hk = len_buffer_hk
        # self.buffer = CircularMatTree(len_buffer, torch.eye(H).expand(B, -1, -1).to(list(self.state_params.values())[0]))
        self.buffer = CircularTree(len_buffer, None, sparse_left_mul)
        self.last_hk = [0]
        self.t = 0

    def reset(self):
        [P_t.zero_() for P_t in self.P_t.values()]

    @torch.no_grad()
    def get_left_product(self, k, l):
        return self.buffer.query(k, l)
      #   if l is None: l = len(self.buffer)
      #   L = torch.eye(H).expand(B, -1, -1)
      #   for i in range(k, l):
      #     L = torch.bmm(self.buffer[i], L)
      #   return L

    @torch.no_grad()
    def step(self, model, x_t, h_t, loss, active_params, proj=None, **kw):
        """
        x_t: [B,...], h_t: [B,H], P_t: [B,H,Tp], dL_dH_t: [B,H]
        """
        params = dict(model.named_parameters())
        B, H = h_t.shape[:2]
        proj = list(range(H)) if proj is None else proj
        f1 = make_f_single(model, proj)

        # batched jacobian of per-sample f
        Jh_proj = vmap(jacrev(f1, argnums=1), in_dims=(None, 0, 0, None))(active_params, h_t, x_t, kw)  # [B,H,H]
        # Jh = torch.eye(H).repeat(B, 1, 1).to(h_t), Jh[:, proj] = Jh_proj
        Jtheta_proj = vmap(jacrev(f1, argnums=0), in_dims=(None, 0, 0, None))(active_params, h_t, x_t, kw) # [B,H,[...]]
        Jtheta_proj = {k:rearrange(v, 'b h ... -> b h (...)') for k, v in Jtheta_proj.items()}  # [B,H,[Tp]]
        # >>> Detach before storing <<<
        Jh_proj = Jh_proj.detach()
        Jtheta_proj = {k: v.detach() for k, v in Jtheta_proj.items()}

        # Update circular buffer
        self.buffer.update((proj, Jh_proj))
        # self.buffer.append(Jh); if len(self.buffer) > self.len_buffer: self.buffer.pop(0)

        # RTRL recursion on active or expiring sensitivities
        for k in self.state_params.keys():
            # Active parameters update
            if k in active_params.keys():
                # Jtheta = torch.zeros([B, H, self.state_params[k].numel()]); Jtheta[:, proj] = Jtheta_tree[k]
                # t <-> index [q-1]; product for time s < t starts at s+1, ie index [s - t + (q-1) + 1] = [s - t + q]
                idx, L_Jh = self.get_left_product(self.last_update[k] - self.t + self.len_buffer, self.len_buffer)
                self.P_t[k][:, idx] = L_Jh @ self.P_t[k]
                self.P_t[k][:, proj] += Jtheta_proj[k]
                self.last_update[k] = self.t
            # Expiring parameters update
            idx, L_Jh = self.get_left_product(0, self.len_buffer)
            if k not in active_params.keys() and self.last_update[k] <= self.t - self.len_buffer: # == in practice
                self.P_t[k][:, idx] = L_Jh @ self.P_t[k]
                self.last_update[k] = self.t


        # Backpropagate, Maintain last activated
        if loss is not None:
            dL_dTheta = dict(zip(params.keys(), torch.autograd.grad(loss, params.values(), retain_graph=True, allow_unused=True)))
            (dL_dH_t,) = torch.autograd.grad(loss, h_t, retain_graph=False) #[:,self.last_hk]
            for k, g in dL_dTheta.items():
                add_grad_(params[k], g)
            for k, p in self.state_params.items():
                g_mean = torch.einsum('b h, b h t -> b t', dL_dH_t, self.P_t[k]).mean(0) #[self.last_hk]
                add_grad_(params[k], g_mean.view(p.shape))
        if len(proj) != H:
            I = set(self.last_hk) & set(proj)
            self.last_hk = ([k for k in self.last_hk + proj if k not in I] + list(I))
            self.last_hk = self.last_hk[-self.len_buffer_hk:]

        self.t += 1


In [18]:
# @title
class Dummy(nn.Module):
    def __init__(self, input_dim=5, hidden_dim=3, output_dim=3):
        super().__init__()
        self.state_linear = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.output_linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, h):
        # concatenate input and state
        z = torch.cat([x, h], dim=-1)
        h_next = torch.tanh(self.state_linear(z))
        y = self.output_linear(h_next)
        return y, h_next

class Dummy2(nn.Module):
    def __init__(self, input_dim=5, hidden_dim=3):
        super().__init__()
        self.state_fc1 = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.state_fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.state_norm = nn.LayerNorm(hidden_dim)
        self.output_fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.output_fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x, h):
        z = torch.cat([x, h], dim=-1)
        s = F.silu(self.state_fc1(z))
        s = s + F.silu(self.state_fc2(s))
        s = self.state_norm(s)
        h_next = torch.tanh(s)
        y = self.output_fc2(F.silu(self.output_fc1(h_next)))
        return y, h_next


model = Dummy(input_dim=5, hidden_dim=3, output_dim=5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
B, H, D = 1, 3, 5
state_params = {k:v.to(device) for k,v in model.named_parameters() if k.startswith("state_")}
rtrl = BlockRTRL(state_params, B, H)
# writer = SummaryWriter(log_dir="runs/block_rtrl_1")
t=0

In [None]:
# @title
n=10
for _ in range(n):
  x_t = torch.randn(B, D).to(device)   # input batch
  h_t = torch.randn(B, H).to(device).requires_grad_()   # hidden states
  target = 2*x_t + 0.5
  y, h_next = model(x_t, h_t)
  loss = ((y - target) ** 2).mean()
  optimizer.zero_grad()
  rtrl.step(model, x_t, h_t, loss, state_params, [t % H])
  # loss.backward()
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
  optimizer.step()
  h_t = h_next.detach()
  # writer.add_scalar("train/loss", loss.item(), t)
  print(loss.item())
  t += 1

In [None]:
# @title
# use same Dummy as you defined above
B, H, D = 1, 30, 2
model = Dummy2(input_dim=D, hidden_dim=H).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
state_params = {k:v.to(device) for k,v in model.named_parameters() if k.startswith("state_")}
rtrl = BlockRTRL(state_params, B, H)
writer = SummaryWriter(log_dir="runs/anbn_Brtrl_0")
t = 0

def make_seq(len_max_a=8):
    # half the time make a valid a^n b^n, otherwise pick a mismatched n_b
    n_a = random.randint(1, len_max_a)
    if random.random() < 0.5:
        n_b = n_a
        tgt = torch.tensor([[1.0]]).to(device)
    else:
        n_b = random.choice([k for k in range(1,len_max_a+1) if k != n_a])
        tgt = torch.tensor([[0.0]]).to(device)
    return tgt, torch.cat([torch.tensor([[1.,0.]]).repeat(n_a,1), torch.tensor([[0.,1.]]).repeat(n_b,1)], dim=0).to(device)

In [None]:
# @title
n_steps = 100
for step in range(n_steps):
  tgt, x_seq = make_seq()
  h_t = torch.zeros(1, H).requires_grad_().to(device)
  rtrl.reset()
  for k in range(x_seq.size(0)):
      y, h_next = model(x_seq[k].unsqueeze(0), h_t)
      if k < x_seq.size(0) - 2:
          rtrl.step(model, x_seq[k].unsqueeze(0), h_t, None, active_params, [t % H, (t+1) % H])
          h_t = h_next.detach()
      if k == x_seq.size(0) - 2:
          h_t = h_next.detach().requires_grad_()

  loss = ((y - tgt) ** 2).mean()
  optimizer.zero_grad()
  # loss.backward()
  rtrl.step(model, x_seq[k].unsqueeze(0), h_t, loss, active_params, [t % H, (t+1) % H])
  nn.utils.clip_grad_norm_(model.parameters(), 1.0)
  optimizer.step()
  writer.add_scalar("train/loss", loss.item(), t)
  writer.add_scalar("train/P_t", list(rtrl.P_t.values())[0].abs().mean().item(), t)
  t += 1

In [None]:
%load_ext tensorboard
%tensorboard --logdir runs

# MoE

In [12]:
# @title
class ExpertBank(nn.Module):
    """
    E experts, each: y = act(W[e] @ x + b[e])
    Params are stored per-expert (ParameterList) but used via stacked banks in forward.
    """
    def __init__(self, E, d):
        super().__init__()
        self.E, self.d = E, d
        self.W = nn.ParameterList([nn.Parameter(torch.empty(d, d)) for _ in range(E)])
        self.b = nn.ParameterList([nn.Parameter(torch.empty(d))    for _ in range(E)])
        self.reset_parameters()

    def reset_parameters(self):
        for W, b in zip(self.W, self.b):
            nn.init.kaiming_uniform_(W, a=5**0.5)  # good for ReLU
            fan_in = W.size(1)
            bound = 1.0 / fan_in**0.5
            nn.init.uniform_(b, -bound, bound)

    def forward(self, x, w, idx):
        B, T, D = x.shape
        _, k = idx.shape

        W_bank = torch.stack(list(self.W), dim=0).contiguous()  # [E, D, D]
        b_bank = torch.stack(list(self.b), dim=0).contiguous()  # [E, D]

        flat_idx = idx.reshape(-1)
        W_sel = W_bank.index_select(0, flat_idx).reshape(B, k, D, D).contiguous()   # [B,k,D,D]
        b_sel = b_bank.index_select(0, flat_idx).reshape(B, k, D).contiguous()      # [B,k,D]

        y_k = torch.einsum('b t i, b k i o -> b t k o', x, W_sel) + b_sel.unsqueeze(1)
        y_k = F.relu(y_k)
        y = (y_k * w.view(B,1,k,1)).sum(dim=2)  # [B,T,D]
        return y


In [81]:
# @title
class MLP(nn.Module):
    def __init__(self, d, mult=1, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(d, mult * d)
        self.fc2 = nn.Linear(mult * d, d)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        return self.fc2(self.drop(F.gelu(self.fc1(x))))

class TopKGate(nn.Module):
    def __init__(self, d, n_experts, k=2):
        super().__init__()
        assert 1 <= k <= n_experts
        self.k, self.n = k, n_experts
        self.proj = nn.Linear(d, n_experts, bias=False)
    def forward(self, h):                         # h: [B,D]
        logits = self.proj(h)                     # [B,E]
        val, idx = torch.topk(logits, self.k, dim=-1)  # [B,k]
        return val.softmax(-1), idx               # normalize over top-k

# ---------- Recurrent MoE ----------
class RecurrentMoE(nn.Module):
    """
    Step:
      1) Mixed-attn: Q=latent, K/V=concat(latent, x)
      2) Pool latent -> gate -> mixture of experts (sequence-level)
      3) Write expert update back to latent slots chosen by a small gate
    """
    def __init__(self, d_model=512, n_heads=2, n_slots=4, n_experts=4, topk=2, dropout=0.0):
        super().__init__()
        d, E = d_model, n_experts
        self.d, self.n_slots = d, n_slots

        # Attention
        self.state_mha = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        self.state_ln_q = nn.LayerNorm(d)
        self.state_ln_kv = nn.LayerNorm(d)
        self.state_ffn = MLP(d, dropout=dropout)
        self.state_ln_ffn = nn.LayerNorm(d)

        # Sequence level gating with topk expert
        self.state_gate = TopKGate(d, n_experts, k=topk)
        # self.state_experts = nn.ModuleList([MLP(d, dropout=dropout) for _ in range(n_experts)])
        self.state_experts = ExpertBank(E, d)
        self.state_ln_moe_in = nn.LayerNorm(d)
        self.state_latent_proj = nn.Linear(d, d)

        # Latent projection
        self.state_ln_slot = nn.LayerNorm(d)
        self.state_slot_ctx = nn.Linear(d, 1, bias=False)
        self.out = nn.Linear(n_slots*d, d)


    @torch.no_grad()
    def init_state(self, batch_size, device=None, dtype=None):
        device = device or next(self.parameters()).device
        dtype = dtype or next(self.parameters()).dtype
        latent = torch.randn(batch_size, self.n_slots*self.d, device=device, dtype=dtype) * 0.02
        return latent

    def forward(self, x, state_flat):
        B, S, D = state_flat.shape[0], self.n_slots, self.d
        device, dtype = state_flat.device, state_flat.dtype
        state = state_flat.reshape([B, S, D]).contiguous()
        latent = state.clone()

        # (1) Mixed attention
        q = self.state_ln_q(latent)
        kv = torch.cat([latent, x], dim=1)
        kv = self.state_ln_kv(kv)
        attn_out, _ = self.state_mha(q, kv, kv, need_weights=False)
        latent = latent + attn_out
        latent = latent + self.state_ffn(self.state_ln_ffn(latent))

        # (2) MoE with top-k experts (sequence-level) =1 here
        pooled = self.state_ln_moe_in(latent.mean(dim=1))       # [B,D]
        w, idx = self.state_gate(pooled)                        # [B,k]]
        # mixed = torch.zeros(B, S, D, device=latent.device, dtype=latent.dtype)
        # for i in range(self.state_gate.k):
        #     for e in range(len(self.state_experts)):
        #         mask = (idx[:, i] == e)                   # [B]
        #         if mask.any():
        #             acc = torch.einsum('b, b S D-> b S D', w[mask, i], self.state_experts[e](latent[mask]))
        #             mixed = mixed.index_put((mask,), acc, accumulate=True)
        mixed = self.state_experts(latent, w, idx)
        latent = latent + mixed

        # (3) Choose *one* target slot per sample and update state only there
        logits = self.state_slot_ctx(self.state_ln_slot(latent)).squeeze(-1)        # [B,S]
        w, tgt_idx = torch.topk(logits, 2, dim=-1)                          # [B,2]
        # for i in range(2):
        #     for s in range(S):
        #         mask = (tgt_idx[:, i] == s)
        #         acc = torch.einsum('b, b D-> b D', w[mask, i], latent[mask, s])
        #         state = state.index_put((mask, torch.tensor(s, device=device)),
        #                                 acc, accumulate=True)
        alpha = torch.zeros(B, S, device=device, dtype=dtype)
        alpha = alpha.scatter_add_(1, tgt_idx, w)
        state = state + alpha.unsqueeze(-1) * latent     # out-of-place residual update

        # Output and info
        y = self.out(latent.reshape([B, S*D]))
        info = {
            "idx_experts": idx.detach(),
            "idx_slots": tgt_idx.detach()
        }
        return y, info, state.reshape([B, S*D]).contiguous()

# ---------- Tiny usage ----------
B, T, D = 1, 8, 64
model = RecurrentMoE(d_model=D, n_heads=4, n_slots=64, n_experts=64, topk=2).to(device)
x = torch.randn(B, T, D).to(device)
pattern = re.compile(r'^state_experts\.(W|b)\.\d+$')
state_params = {name:p for name,p in model.named_parameters() if name.startswith("state_")}
core_params = {name:p for name, p in state_params.items() if pattern.match(name) is None}
expert_params = {name:p for name, p in state_params.items() if pattern.match(name)}
state = model.init_state(B, device=x.device).requires_grad_()
rtrl = BlockRTRL(state_params, B, state.shape[-1])

In [None]:
y, info, state_new = model(x, state)
loss = ((y - torch.zeros_like(y))**2).mean()

In [30]:
def get_expert_latent_activated(info):
    idx_slots = list(set(info['idx_slots'].flatten().tolist()))
    proj = sum((list(range(D*i, D*(i+1))) for i in idx_slots), start=[])
    expert_ids = list(set(info['idx_experts'].flatten().tolist()))
    ids_pattern = "|".join(map(str, expert_ids))
    pattern = re.compile(rf'^state_experts\.(W|b)\.({ids_pattern})$')
    active_experts_params = {name:p for name, p in state_params.items() if pattern.match(name)}
    active_params = core_params | active_experts_params
    return active_params, proj
active_params, proj = get_expert_latent_activated(info)

In [31]:
rtrl.step(model, x, state, loss, active_params, proj)
state = state_new.detach()

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [53]:
sum(p[1].numel() for p in rtrl.buffer.tree if p is not None)

122683392

In [55]:
sum(p.numel() for p in rtrl.P_t.values())

21004288

In [49]:
list(rtrl.P_t.values())[2].shape

torch.Size([1, 8192, 256])

In [85]:
SEQ_LEN    = 5000
DELAY_MIN  = 50
DELAY_MAX  = 200
TBPTT      = 32
H          = 128
D          = 16  # vocab

# --- Data: random tokens and per-step random delay ---
x = torch.randint(0, D, (SEQ_LEN,), dtype=torch.long)
d = torch.randint(DELAY_MIN, DELAY_MAX + 1, (SEQ_LEN,), dtype=torch.long)

# y_t = x_{t - d_t} if t - d_t >= 0 else ignore
y = torch.full((SEQ_LEN,), -100, dtype=torch.long)  # -100 = ignore_index
valid = torch.arange(SEQ_LEN) - d >= 0
idx = torch.nonzero(valid, as_tuple=False).squeeze(-1)
y[idx] = x[idx - d[idx]]
x, y = x.to(device), y.to(device)

# model = Dummy(input_dim=D, hidden_dim=H, output_dim=D).to(device)
model = RecurrentMoE(d_model=D, n_heads=4, n_slots=64, n_experts=64, topk=2).to(device)
pattern = re.compile(r'^state_experts\.(W|b)\.\d+$')
state_params = {name:p for name,p in model.named_parameters() if name.startswith("state_")}
core_params = {name:p for name, p in state_params.items() if pattern.match(name) is None}
expert_params = {name:p for name, p in state_params.items() if pattern.match(name)}

h = model.init_state(1, device=device).requires_grad_()
rtrl = BlockRTRL(state_params, B, h.shape[-1])
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=-100, reduction="none")  # per-step

In [86]:
# --- Training: chunked outer loop, per-step inner loop; backward every step ---
model.train()
running = 0.0
for epoch in range(3):
    # h = torch.zeros(1, H, device=device)
    h = model.init_state(1, device=device).requires_grad_()
    for start in range(0, SEQ_LEN, TBPTT):
        end = min(start + TBPTT, SEQ_LEN)

        # one-hot inputs for this chunk
        x_chunk = F.one_hot(x[start:end], num_classes=D).float()  # [T, V]
        y_chunk = y[start:end]  # [T]
        loss = torch.zeros(1, device=device)

        for t in range(x_chunk.size(0)):
            xt = x_chunk[t][None, None, ...]        # [1, 1, D]
            yt = y_chunk[t].unsqueeze(0)            # [1]

            opt.zero_grad()
            logits, info, h_new = model(xt, h)             # logits: [1, D]
            active_params, proj = get_expert_latent_activated(info)
            loss_t = criterion(logits, yt).mean()  # scalar

            # Only backprop if target is valid; always advance state
            if yt.item() != -100:
                # rtrl.step(model, xt, h, loss_t, active_params, proj)
                # opt.step()
                loss = loss + loss_t
                running = (0.98 * running + 0.02 * loss_t.item()) if running else loss_t.item()
            else:
                # rtrl.step(model, xt, h, None, active_params, proj)
                pass
            h = h_new #.detach().requires_grad_()
            print(f"{running:.4f}")
        loss_t.backward()
        h = h.detach()
        opt.step()

    print(f"epoch {epoch+1} | running CE (masked): {running:.4f}")

# --- Eval on same stream (for simplicity) ---
# model.eval()
# with torch.no_grad():
#     h = torch.zeros(1, H, device=device)
#     preds = []
#     for t in range(SEQ_LEN):
#         xt = F.one_hot(x[t], num_classes=D).float().unsqueeze(0)
#         logits, h = model(xt, h)
#         preds.append(logits.argmax(dim=-1).item())
#     preds = torch.tensor(preds, device=device)

# mask = (y != -100)
# acc = (preds[mask] == y[mask]).float().mean().item()
# print(f"masked per-step accuracy: {acc*100:.2f}% "
#       f"(TBPTT={TBPTT}, delay∈[{DELAY_MIN},{DELAY_MAX}], H={H})")


0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
2.5287
2.5287
2.5287
2.5287
2.5287
2.5287
2.5287
2.5287
2.5287
2.5287
2.5287
2.5287
2.5287
2.5287
2.5287
2.5287
2.5287
2.5347
2.5347
2.5407
2.5407
2.5407
2.5407
2.5407
2.5407
2.5494
2.5549
2.5549
2.5549
2.5549
2.5549
2.5549
2.5549
2.5549
2.5549
2.5549
2.5550
2.5611
2.5611
2.5611
2.5695
2.5695
2.5730
2.5737
2.5731
2.5731
2.5731
2.5731
2.5790
2.5831
2.5831
2.5894
2.5894
2.5894
2.5925
2.5982
2.5982
2.6026
2.6026
2.6026
2.6064
2.6064
2.6108
2.6166
2.6166
2.6166
2.6194
2.6245
2.6313
2.6364
2.6364
2.6411
2.6482
2.6482
2.6495
2.6495
2.6495
2.6539

KeyboardInterrupt: 

In [55]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7e3043f48dd0>

In [58]:
import torch, gc

def list_tensors_on_gpu():
    tensors = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                tensors.append(obj)
            elif hasattr(obj, "data") and torch.is_tensor(obj.data) and obj.data.is_cuda:
                tensors.append(obj.data)
        except:
            pass

    # Sort tensors by memory size (largest first)
    tensors = sorted(tensors, key=lambda t: t.numel() * t.element_size(), reverse=True)

    total = 0
    print(f"{'Idx':>3} {'Shape':>20} {'Dtype':>10} {'Size (MB)':>12} {'Device':>10}")
    print("-" * 60)
    for i, t in enumerate(tensors):
        size_mb = t.numel() * t.element_size() / (1024 ** 2)
        total += size_mb
        print(f"{i:3d} {str(tuple(t.shape)):>20} {str(t.dtype):>10} {size_mb:12.2f} {str(t.device):>10}")

    print("-" * 60)
    print(f"Total GPU tensor memory: {total:.2f} MB")

list_tensors_on_gpu()


  return isinstance(obj, torch.Tensor)
  elif hasattr(obj, "data") and torch.is_tensor(obj.data) and obj.data.is_cuda:


Idx                Shape      Dtype    Size (MB)     Device
------------------------------------------------------------
  0     (1, 4096, 12288) torch.float32       192.00     cuda:0
  1      (1, 4096, 4096) torch.float32        64.00     cuda:0
  2      (1, 4096, 4096) torch.float32        64.00     cuda:0
  3      (1, 4096, 4096) torch.float32        64.00     cuda:0
  4      (1, 4096, 4096) torch.float32        64.00     cuda:0
  5      (1, 4096, 4096) torch.float32        64.00     cuda:0
  6      (1, 4096, 4096) torch.float32        64.00     cuda:0
  7      (1, 4096, 4096) torch.float32        64.00     cuda:0
  8      (1, 8192, 1024) torch.float32        32.00     cuda:0
  9       (1, 8192, 768) torch.float32        24.00     cuda:0
 10       (1, 8192, 256) torch.float32         8.00     cuda:0
 11       (1, 8192, 256) torch.float32         8.00     cuda:0
 12       (1, 8192, 256) torch.float32         8.00     cuda:0
 13       (1, 8192, 256) torch.float32         8.00     cuda

In [59]:
import gc, torch
gc.collect(); torch.cuda.empty_cache()
