Hcekcing the compute aspec and epoch run time per second 

In [None]:
# ===============================================================
#  DESTINODE with neuron & lap subsampling  (variable‑K, u‑decoder)
# ===============================================================
import numpy as np, torch, matplotlib.pyplot as plt, torch.nn as nn, random
from torch.utils.data import Dataset, DataLoader
from torchdiffeq import odeint

# ------------------- 0. KNOBS -------------------
KEEP_NEURON_FRAC = 0.05          # 0–1   (20 % of neurons)
KEEP_LAP_FRAC    = 0.4          # 0–1   (50 % of laps per session)
RNG_SEED         = 42            # for reproducibility
torch.manual_seed(RNG_SEED); np.random.seed(RNG_SEED); random.seed(RNG_SEED)

# ------------------ 1. SETTINGS ------------------
SUBSET_FILE   = "destin_debug_subset.npz"
rank          = 3
train_days    = [0, 1, 2, 3]
test_day      = 6
epochs        = 10
LAMBDA_SMOOTH = 5e-2
device        = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ------------------ 2. LOAD & SUBSAMPLE ----------
data = np.load(SUBSET_FILE, allow_pickle=True)
sess_ids = sorted({int(k.split('_')[1]) for k in data if k.startswith('dff_')})

# 2a. decide once‑and‑for‑all neuron mask
N_full = data['dff_0'].shape[0]
n_keep = int(np.ceil(N_full * KEEP_NEURON_FRAC))
keep_neurons = np.random.choice(N_full, n_keep, replace=False)
keep_neurons.sort()                                      # keep order

to_tensor = lambda x: torch.tensor(x, dtype=torch.float32)
x_sess = []; v_sess = []; u_sess = []
for t in sess_ids:
    dff_full = data[f'dff_{t}'][keep_neurons]             # [N_keep, K_vr]
    dff_full = to_tensor(dff_full.T)                      # [K_vr, N_keep]
    vel_full = to_tensor(data[f'vel_{t}'])
    pos_full = to_tensor(data[f'pos_{t}'])
    pos_full = (pos_full - pos_full.mean()) / pos_full.std()

    lap_idx  = data[f'lap_idx_{t}']

    # 2b. optional lap subsampling
    all_laps = np.unique(lap_idx)
    n_laps_keep = int(np.ceil(len(all_laps) * KEEP_LAP_FRAC))
    keep_laps = np.random.choice(all_laps, n_laps_keep, replace=False)

    laps_x = []; laps_v = []; laps_u = []
    for s in sorted(keep_laps):
        mask = lap_idx == s
        if not mask.any(): continue
        laps_x.append(dff_full[mask])
        laps_v.append(vel_full[mask])
        laps_u.append(pos_full[mask])

    x_sess.append(laps_x); v_sess.append(laps_v); u_sess.append(laps_u)

T, N = len(x_sess), len(keep_neurons)
print(f"Sessions: {T} | Neurons kept: {N}/{N_full} "
      f"| Example lap shape: {x_sess[0][0].shape}")

# ------------------ 3. DATASET -------------------
class RaggedSplitDS(Dataset):
    def __init__(self, x, v, u, days):
        self.x,self.v,self.u = x,v,u
        self.map = [(t,s) for t in days for s in range(len(x[t]))]
    def __len__(self): return len(self.map)
    def __getitem__(self, idx):
        t,s = self.map[idx]
        return self.x[t][s], self.v[t][s], self.u[t][s], t

loader = DataLoader(RaggedSplitDS(x_sess,v_sess,u_sess,train_days),
                    batch_size=1, shuffle=True, collate_fn=lambda b:b)

# ------------------ 4. MODEL (unchanged except N) -----------
W0 = 0.05*torch.randn(N,N)
U,_,Vt = torch.linalg.svd(W0)
U = U[:,:rank]; V = Vt[:rank,:].T
B = 0.05*torch.randn(N,1); R = 0.05*torch.randn(1,N)

# ===============================================================
#  DAY‑0‑DRIVEN INITIALISATION  (insert before “# ------------------ 4. MODEL ---------------------”)
# ===============================================================
print("‣ deriving Day‑0 statistics for W0 and b_fixed …")

# -------- gather all frames from session‑0 after subsampling --------
X0 = torch.cat(x_sess[0], dim=0)           # [K_total_day0 , N]
X0 = X0.float()                            # ensure fp32

# -------- 1) activation bias  b_fixed -------------------------------
avg_rate = X0.mean(0)                      # mean across time
b_init   = -(avg_rate - avg_rate.mean())   # anti‑correlated w/ firing
b_init   = (0.5 / b_init.abs().max()) * b_init   # scale to ±0.5

# # -------- 2) connectivity seed  W0 ---------------------------------
# #   use Pearson correlation matrix of day‑0 activity
# X0c   = X0 - avg_rate                      # centre
# std   = X0c.std(0, unbiased=False) + 1e-9
# corr  = (X0c.T @ X0c) / X0c.shape[0]
# corr  = corr / (std[:,None] * std[None,:]) # now ρ_{ij} in [‑1,1]
# W0_init = 0.05 * corr                      # match earlier 0.05 scale

# # -------- 3) low‑rank bases from new W0 ----------------------------
# U,_,Vt = torch.linalg.svd(W0_init.cpu())
# U = U[:,:rank]; V = Vt[:rank,:].T

# print(f"   » b range  {b_init.min():.2f} … {b_init.max():.2f}")
# print(f"   » W0 diag mean±sd  {W0_init.diag().mean():.3f} ± {W0_init.diag().std():.3f}")

# # keep tensors for the model section
# W0 = W0_init.clone()
# B  = 0.05 * torch.randn(N,1)
# R  = 0.05 * torch.randn(1,N)


class SlowODE(nn.Module):
    def __init__(self,r):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(r,64), nn.Tanh(), nn.Linear(64,r))
    def forward(self,t,z): return self.net(z)

class PINN(nn.Module):
    def __init__(self):
        super().__init__()
        self.U,self.V,self.W0 = U.to(device), V.to(device), W0.to(device)
        self.z0 = nn.Parameter(torch.zeros(rank))
        self.slow = SlowODE(rank)
        self.B = nn.Parameter(B.clone()); self.R = nn.Parameter(R.clone())
        self.b_fixed = nn.Parameter(torch.zeros(N), requires_grad=True)
        # self.b_fixed = nn.Parameter(b_init.to(device), requires_grad=False)

    def weights(self,T):
        t = torch.arange(T, dtype=torch.float32, device=self.z0.device)
        z = odeint(self.slow, self.z0[None], t).squeeze(1)       # [T,r]
        Ud   = self.U[None] * z[:,None]                          # [T,N,r]
        Vt   = self.V.T.expand(T, -1, -1)                        # [T,r,N]
        Wadd = torch.bmm(Ud, Vt)                                 # [T,N,N]
        return self.W0[None] + Wadd
    def rnn_cell(self,xk,vk,W):
        return torch.tanh(W@xk + (self.B*vk).squeeze() + self.b_fixed)
    def forward(self,x_gt,v_seq,u_seq,t_idx,Ws):
        K = x_gt.size(0); W = Ws[t_idx]
        rec = dec = 0.; x_prev = x_gt[0]
        for k in range(K-1):
            x_pred = self.rnn_cell(x_prev, v_seq[k], W)
            rec += (x_pred - x_gt[k+1]).pow(2).sum()
            dec += (self.R @ x_pred - u_seq[k]).pow(2)
            x_prev = x_pred
            # ----- inside PINN.forward  (keep all code above unchanged) -----
            lambda_rec = 1.0      # full weight on reconstruction
            lambda_dec = 0.1      # put decoder on same numeric scale

        return lambda_rec * rec + lambda_dec * dec
        

model = PINN().to(device)
opt   = torch.optim.Adam(model.parameters(), lr=1e-4)

# ------------------ 5. TRAIN ---------------------
for ep in range(epochs):
    total=0.
    for batch in loader:
        Ws = model.weights(T)
        losses=[]
        for x_gt,v_gt,u_gt,t_idx in batch:
            x_gt,v_gt,u_gt = (x_gt.to(device), v_gt.to(device), u_gt.to(device))
            losses.append(model(x_gt,v_gt,u_gt,t_idx,Ws))
        rec_dec = torch.stack(losses).mean()
        smooth  = (Ws[1:]-Ws[:-1]).pow(2).sum()
        loss = rec_dec + LAMBDA_SMOOTH*smooth
        opt.zero_grad(); loss.backward(); opt.step(); total+=loss.item()
    if ep%5==0:
        print(f"Epoch {ep:03d} | total {total/len(loader):.3e} "
              f"| rec+dec {rec_dec.item():.3e}")




Sessions: 8 | Neurons kept: 16/305 | Example lap shape: torch.Size([668, 16])
‣ deriving Day‑0 statistics for W0 and b_fixed …
Epoch 000 | total 3.649e+03 | rec+dec 3.429e+03
Epoch 005 | total 2.192e+03 | rec+dec 1.854e+03
Epoch 010 | total 8.811e+02 | rec+dec 9.773e+02
Epoch 015 | total 5.657e+02 | rec+dec 8.556e+02
Epoch 020 | total 5.476e+02 | rec+dec 6.504e+02
Epoch 025 | total 5.438e+02 | rec+dec 7.545e+02
Epoch 030 | total 5.410e+02 | rec+dec 3.101e+02
Epoch 035 | total 5.390e+02 | rec+dec 6.378e+02


KeyboardInterrupt: 

In [2]:
import time, platform, torch
import numpy as np

def count_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

def get_device_info(device):
    if device.type == "cuda":
        i = torch.cuda.current_device()
        name = torch.cuda.get_device_name(i)
        cap = torch.cuda.get_device_capability(i)
        mem_gb = torch.cuda.get_device_properties(i).total_memory / (1024**3)
        return f"{name} (CUDA cc {cap[0]}.{cap[1]}, {mem_gb:.1f} GB)"
    return "CPU"

# ---- static info ----
print("=== Compute / system ===")
print("PyTorch:", torch.__version__)
print("Device:", device)
print("GPU:", get_device_info(device))
print("CPU:", platform.processor() or platform.platform())
print("Trainable params:", f"{count_params(model):,}")

# ---- data scale ----
num_sessions = len(x_sess)
laps_per_day = [len(x_sess[d]) for d in range(num_sessions)]
Ks = [x.shape[0] for d in range(num_sessions) for x in x_sess[d]]
print("\n=== Data scale ===")
print("Sessions:", num_sessions)
print("Laps per session:", laps_per_day)
print("Seq length K: mean±sd =", f"{np.mean(Ks):.1f}±{np.std(Ks):.1f}",
      "| min/max =", (int(np.min(Ks)), int(np.max(Ks))))
print("Neurons N:", x_sess[0][0].shape[1])

# ---- quick timing: a few training steps (no full retrain) ----
def time_steps(n_steps=20):
    model.train()
    t0 = time.time()
    it = iter(loader)
    for _ in range(n_steps):
        batch = next(it)
        Ws = model.weights(T)
        losses=[]
        for x_gt, v_gt, u_gt, t_idx in batch:
            x_gt, v_gt, u_gt = x_gt.to(device), v_gt.to(device), u_gt.to(device)
            losses.append(model(x_gt, v_gt, u_gt, t_idx, Ws))
        rec_dec = torch.stack(losses).mean()
        smooth  = (Ws[1:]-Ws[:-1]).pow(2).sum()
        loss = rec_dec + LAMBDA_SMOOTH*smooth
        loss.backward()
        model.zero_grad(set_to_none=True)
    if device.type == "cuda":
        torch.cuda.synchronize()
    dt = time.time() - t0
    return dt / n_steps

sec_per_step = time_steps(n_steps=20)
print("\n=== Timing (approx) ===")
print(f"Seconds per optimization step (batch_size=1): {sec_per_step:.4f}s")
print("Note: total runtime ≈ sec_per_step × (#batches per epoch) × (#epochs)")


=== Compute / system ===
PyTorch: 2.8.0+cu128
Device: cuda
GPU: NVIDIA H200 (CUDA cc 9.0, 139.8 GB)
CPU: x86_64
Trainable params: 502

=== Data scale ===
Sessions: 8
Laps per session: [12, 12, 14, 14, 12, 12, 13, 12]
Seq length K: mean±sd = 692.6±148.5 | min/max = (419, 1521)
Neurons N: 16

=== Timing (approx) ===
Seconds per optimization step (batch_size=1): 0.2561s
Note: total runtime ≈ sec_per_step × (#batches per epoch) × (#epochs)


High neuron number

In [3]:
# ===============================================================
#  DESTINODE with neuron & lap subsampling  (variable‑K, u‑decoder)
# ===============================================================
import numpy as np, torch, matplotlib.pyplot as plt, torch.nn as nn, random
from torch.utils.data import Dataset, DataLoader
from torchdiffeq import odeint

# ------------------- 0. KNOBS -------------------
KEEP_NEURON_FRAC = 0.5          # 0–1   (20 % of neurons)
KEEP_LAP_FRAC    = 0.4          # 0–1   (50 % of laps per session)
RNG_SEED         = 42            # for reproducibility
torch.manual_seed(RNG_SEED); np.random.seed(RNG_SEED); random.seed(RNG_SEED)

# ------------------ 1. SETTINGS ------------------
SUBSET_FILE   = "destin_debug_subset.npz"
rank          = 3
train_days    = [0, 1, 2, 3]
test_day      = 6
epochs        = 10
LAMBDA_SMOOTH = 5e-2
device        = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ------------------ 2. LOAD & SUBSAMPLE ----------
data = np.load(SUBSET_FILE, allow_pickle=True)
sess_ids = sorted({int(k.split('_')[1]) for k in data if k.startswith('dff_')})

# 2a. decide once‑and‑for‑all neuron mask
N_full = data['dff_0'].shape[0]
n_keep = int(np.ceil(N_full * KEEP_NEURON_FRAC))
keep_neurons = np.random.choice(N_full, n_keep, replace=False)
keep_neurons.sort()                                      # keep order

to_tensor = lambda x: torch.tensor(x, dtype=torch.float32)
x_sess = []; v_sess = []; u_sess = []
for t in sess_ids:
    dff_full = data[f'dff_{t}'][keep_neurons]             # [N_keep, K_vr]
    dff_full = to_tensor(dff_full.T)                      # [K_vr, N_keep]
    vel_full = to_tensor(data[f'vel_{t}'])
    pos_full = to_tensor(data[f'pos_{t}'])
    pos_full = (pos_full - pos_full.mean()) / pos_full.std()

    lap_idx  = data[f'lap_idx_{t}']

    # 2b. optional lap subsampling
    all_laps = np.unique(lap_idx)
    n_laps_keep = int(np.ceil(len(all_laps) * KEEP_LAP_FRAC))
    keep_laps = np.random.choice(all_laps, n_laps_keep, replace=False)

    laps_x = []; laps_v = []; laps_u = []
    for s in sorted(keep_laps):
        mask = lap_idx == s
        if not mask.any(): continue
        laps_x.append(dff_full[mask])
        laps_v.append(vel_full[mask])
        laps_u.append(pos_full[mask])

    x_sess.append(laps_x); v_sess.append(laps_v); u_sess.append(laps_u)

T, N = len(x_sess), len(keep_neurons)
print(f"Sessions: {T} | Neurons kept: {N}/{N_full} "
      f"| Example lap shape: {x_sess[0][0].shape}")

# ------------------ 3. DATASET -------------------
class RaggedSplitDS(Dataset):
    def __init__(self, x, v, u, days):
        self.x,self.v,self.u = x,v,u
        self.map = [(t,s) for t in days for s in range(len(x[t]))]
    def __len__(self): return len(self.map)
    def __getitem__(self, idx):
        t,s = self.map[idx]
        return self.x[t][s], self.v[t][s], self.u[t][s], t

loader = DataLoader(RaggedSplitDS(x_sess,v_sess,u_sess,train_days),
                    batch_size=1, shuffle=True, collate_fn=lambda b:b)

# ------------------ 4. MODEL (unchanged except N) -----------
W0 = 0.05*torch.randn(N,N)
U,_,Vt = torch.linalg.svd(W0)
U = U[:,:rank]; V = Vt[:rank,:].T
B = 0.05*torch.randn(N,1); R = 0.05*torch.randn(1,N)

# ===============================================================
#  DAY‑0‑DRIVEN INITIALISATION  (insert before “# ------------------ 4. MODEL ---------------------”)
# ===============================================================
print("‣ deriving Day‑0 statistics for W0 and b_fixed …")

# -------- gather all frames from session‑0 after subsampling --------
X0 = torch.cat(x_sess[0], dim=0)           # [K_total_day0 , N]
X0 = X0.float()                            # ensure fp32

# -------- 1) activation bias  b_fixed -------------------------------
avg_rate = X0.mean(0)                      # mean across time
b_init   = -(avg_rate - avg_rate.mean())   # anti‑correlated w/ firing
b_init   = (0.5 / b_init.abs().max()) * b_init   # scale to ±0.5

# # -------- 2) connectivity seed  W0 ---------------------------------
# #   use Pearson correlation matrix of day‑0 activity
# X0c   = X0 - avg_rate                      # centre
# std   = X0c.std(0, unbiased=False) + 1e-9
# corr  = (X0c.T @ X0c) / X0c.shape[0]
# corr  = corr / (std[:,None] * std[None,:]) # now ρ_{ij} in [‑1,1]
# W0_init = 0.05 * corr                      # match earlier 0.05 scale

# # -------- 3) low‑rank bases from new W0 ----------------------------
# U,_,Vt = torch.linalg.svd(W0_init.cpu())
# U = U[:,:rank]; V = Vt[:rank,:].T

# print(f"   » b range  {b_init.min():.2f} … {b_init.max():.2f}")
# print(f"   » W0 diag mean±sd  {W0_init.diag().mean():.3f} ± {W0_init.diag().std():.3f}")

# # keep tensors for the model section
# W0 = W0_init.clone()
# B  = 0.05 * torch.randn(N,1)
# R  = 0.05 * torch.randn(1,N)


class SlowODE(nn.Module):
    def __init__(self,r):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(r,64), nn.Tanh(), nn.Linear(64,r))
    def forward(self,t,z): return self.net(z)

class PINN(nn.Module):
    def __init__(self):
        super().__init__()
        self.U,self.V,self.W0 = U.to(device), V.to(device), W0.to(device)
        self.z0 = nn.Parameter(torch.zeros(rank))
        self.slow = SlowODE(rank)
        self.B = nn.Parameter(B.clone()); self.R = nn.Parameter(R.clone())
        self.b_fixed = nn.Parameter(torch.zeros(N), requires_grad=True)
        # self.b_fixed = nn.Parameter(b_init.to(device), requires_grad=False)

    def weights(self,T):
        t = torch.arange(T, dtype=torch.float32, device=self.z0.device)
        z = odeint(self.slow, self.z0[None], t).squeeze(1)       # [T,r]
        Ud   = self.U[None] * z[:,None]                          # [T,N,r]
        Vt   = self.V.T.expand(T, -1, -1)                        # [T,r,N]
        Wadd = torch.bmm(Ud, Vt)                                 # [T,N,N]
        return self.W0[None] + Wadd
    def rnn_cell(self,xk,vk,W):
        return torch.tanh(W@xk + (self.B*vk).squeeze() + self.b_fixed)
    def forward(self,x_gt,v_seq,u_seq,t_idx,Ws):
        K = x_gt.size(0); W = Ws[t_idx]
        rec = dec = 0.; x_prev = x_gt[0]
        for k in range(K-1):
            x_pred = self.rnn_cell(x_prev, v_seq[k], W)
            rec += (x_pred - x_gt[k+1]).pow(2).sum()
            dec += (self.R @ x_pred - u_seq[k]).pow(2)
            x_prev = x_pred
            # ----- inside PINN.forward  (keep all code above unchanged) -----
            lambda_rec = 1.0      # full weight on reconstruction
            lambda_dec = 0.1      # put decoder on same numeric scale

        return lambda_rec * rec + lambda_dec * dec
        

model = PINN().to(device)
opt   = torch.optim.Adam(model.parameters(), lr=1e-4)

# ------------------ 5. TRAIN ---------------------
for ep in range(epochs):
    total=0.
    for batch in loader:
        Ws = model.weights(T)
        losses=[]
        for x_gt,v_gt,u_gt,t_idx in batch:
            x_gt,v_gt,u_gt = (x_gt.to(device), v_gt.to(device), u_gt.to(device))
            losses.append(model(x_gt,v_gt,u_gt,t_idx,Ws))
        rec_dec = torch.stack(losses).mean()
        smooth  = (Ws[1:]-Ws[:-1]).pow(2).sum()
        loss = rec_dec + LAMBDA_SMOOTH*smooth
        opt.zero_grad(); loss.backward(); opt.step(); total+=loss.item()
    if ep%5==0:
        print(f"Epoch {ep:03d} | total {total/len(loader):.3e} "
              f"| rec+dec {rec_dec.item():.3e}")




Sessions: 8 | Neurons kept: 153/305 | Example lap shape: torch.Size([668, 153])
‣ deriving Day‑0 statistics for W0 and b_fixed …
Epoch 000 | total 3.247e+04 | rec+dec 2.885e+04
Epoch 005 | total 1.529e+04 | rec+dec 1.323e+04


KeyboardInterrupt: 

In [4]:
import time, platform, torch
import numpy as np

def count_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

def get_device_info(device):
    if device.type == "cuda":
        i = torch.cuda.current_device()
        name = torch.cuda.get_device_name(i)
        cap = torch.cuda.get_device_capability(i)
        mem_gb = torch.cuda.get_device_properties(i).total_memory / (1024**3)
        return f"{name} (CUDA cc {cap[0]}.{cap[1]}, {mem_gb:.1f} GB)"
    return "CPU"

# ---- static info ----
print("=== Compute / system ===")
print("PyTorch:", torch.__version__)
print("Device:", device)
print("GPU:", get_device_info(device))
print("CPU:", platform.processor() or platform.platform())
print("Trainable params:", f"{count_params(model):,}")

# ---- data scale ----
num_sessions = len(x_sess)
laps_per_day = [len(x_sess[d]) for d in range(num_sessions)]
Ks = [x.shape[0] for d in range(num_sessions) for x in x_sess[d]]
print("\n=== Data scale ===")
print("Sessions:", num_sessions)
print("Laps per session:", laps_per_day)
print("Seq length K: mean±sd =", f"{np.mean(Ks):.1f}±{np.std(Ks):.1f}",
      "| min/max =", (int(np.min(Ks)), int(np.max(Ks))))
print("Neurons N:", x_sess[0][0].shape[1])

# ---- quick timing: a few training steps (no full retrain) ----
def time_steps(n_steps=20):
    model.train()
    t0 = time.time()
    it = iter(loader)
    for _ in range(n_steps):
        batch = next(it)
        Ws = model.weights(T)
        losses=[]
        for x_gt, v_gt, u_gt, t_idx in batch:
            x_gt, v_gt, u_gt = x_gt.to(device), v_gt.to(device), u_gt.to(device)
            losses.append(model(x_gt, v_gt, u_gt, t_idx, Ws))
        rec_dec = torch.stack(losses).mean()
        smooth  = (Ws[1:]-Ws[:-1]).pow(2).sum()
        loss = rec_dec + LAMBDA_SMOOTH*smooth
        loss.backward()
        model.zero_grad(set_to_none=True)
    if device.type == "cuda":
        torch.cuda.synchronize()
    dt = time.time() - t0
    return dt / n_steps

sec_per_step = time_steps(n_steps=20)
print("\n=== Timing (approx) ===")
print(f"Seconds per optimization step (batch_size=1): {sec_per_step:.4f}s")
print("Note: total runtime ≈ sec_per_step × (#batches per epoch) × (#epochs)")


=== Compute / system ===
PyTorch: 2.8.0+cu128
Device: cuda
GPU: NVIDIA H200 (CUDA cc 9.0, 139.8 GB)
CPU: x86_64
Trainable params: 913

=== Data scale ===
Sessions: 8
Laps per session: [12, 12, 14, 14, 12, 12, 13, 12]
Seq length K: mean±sd = 692.6±148.5 | min/max = (419, 1521)
Neurons N: 153

=== Timing (approx) ===
Seconds per optimization step (batch_size=1): 0.2644s
Note: total runtime ≈ sec_per_step × (#batches per epoch) × (#epochs)
