In [17]:
# Symmetry-Constrained PINN-RNN for Structured Drift Recovery
# ===========================================================
# This notebook trains on the first 4 recorded days (days 0–3) and then extrapolates
# to predict activity on Day 10. Days are 0-indexed and only days 0–9 are generated,
# so Day 10 is never seen during data generation or training.
# It visualizes (i) the latent drift, (ii) simulated activity for Day 10, and
# (iii) decoder accuracy on the extrapolated day.

# -----------------------
# Setup
# -----------------------

import torch, numpy as np, matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchdiffeq import odeint
from sklearn.metrics import r2_score
from scipy.linalg import subspace_angles

# Plotting default and compute device (CUDA when available, otherwise CPU).
plt.rcParams.update({'figure.figsize': (6, 4)})
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)

# -----------------------
# Synthetic Data (10 recorded days)
# -----------------------
# Each recorded day t has its own connectivity W_t = W0 + U diag(z_t) V^T:
# a fixed base matrix plus a rank-`rank` drift whose coefficients z_t change
# smoothly across days. For every day we simulate S trials of K steps of a
# tanh RNN driven by a scalar random input v.

# Fixed seed so "Restart & Run All" regenerates identical data.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)  # not used in this cell; seeds any numpy randomness downstream

T_rec, S, K, N, rank = 10, 5, 40, 50, 3  # days, trials/day, steps/trial, neurons, drift rank

W0 = 0.0005*torch.randn(N, N)
# Drift acts in the top-`rank` left/right singular subspaces of W0.
U, _, Vt = torch.linalg.svd(W0)
U, V = U[:, :rank], Vt[:rank, :].T
B = 0.05*torch.randn(N, 1)   # input weights for the scalar input v
b = 0.1*torch.randn(N)       # constant drive
R = 0.05*torch.randn(1, N)   # not used in this cell; presumably a readout used downstream

# Latent drift: phase-shifted sinusoids plus a slow linear trend, shape (T_rec, rank).
t_vals = torch.linspace(0, 2*np.pi, T_rec)
z_true = torch.stack([torch.sin(t_vals + np.pi*i/rank) + 0.1*t_vals for i in range(rank)], 1)

x = torch.zeros(T_rec, S, K, N)  # activity  x[day, trial, step, neuron]
v = torch.zeros(T_rec, S, K)     # inputs    v[day, trial, step]
for t in range(T_rec):
    Wt = W0 + U @ torch.diag(z_true[t]) @ V.T  # connectivity on day t
    for s in range(S):
        v_ts = 0.5*torch.randn(K)
        x_ts = torch.zeros(K, N)
        x_ts[0] = 0.1*torch.randn(N)
        for k in range(K-1):
            x_ts[k+1] = torch.tanh(Wt @ x_ts[k] + B[:, 0]*v_ts[k] + b)
        v[t, s] = v_ts
        x[t, s] = x_ts

# Per-neuron mean activity ("fixed bias/excitability").
# NOTE(review): averaged over ALL recorded days, not just the training days;
# confirm this is intended if `a` is used by the model being trained.
a = x.mean((0, 1, 2))

# -----------------------
# Dataset / Dataloader (train on first 4 days)
# -----------------------
class SplitDS(Dataset):
    """Flat dataset over (day, trial) pairs restricted to a subset of days.

    Flat index ``idx`` maps to day ``days[idx // S]`` and trial ``idx % S``,
    where ``S = x.shape[1]`` is the number of trials per day.

    Args:
        x: activity tensor indexed as ``x[day, trial]``.
        v: input tensor indexed as ``v[day, trial]``.
        days: sequence of day indices exposed by this dataset.

    Each item is the tuple ``(x[day, trial], v[day, trial], day)``.
    """

    def __init__(self, x, v, days):
        self.x = x
        self.v = v
        self.days = days
        self.S = x.shape[1]

    def __len__(self):
        # one entry per (selected day, trial) pair
        return self.S * len(self.days)

    def __getitem__(self, idx):
        day = self.days[idx // self.S]
        trial = idx % self.S
        return self.x[day, trial], self.v[day, trial], day

# Train on the first N_TRAIN_DAYS recorded days (0-indexed); later days are held out.
N_TRAIN_DAYS = 4
BATCH_SIZE = 4
assert N_TRAIN_DAYS <= T_rec, 'cannot train on more days than were recorded'
train_days = list(range(N_TRAIN_DAYS))
loader = DataLoader(SplitDS(x, v, train_days), batch_size=BATCH_SIZE, shuffle=True)


Device: cuda


In [None]:
# ---- OPTIONAL: give the base connectivity a rotational rank-2 component
# Set RING_ATTRACTOR = True to overwrite W0 with a skew-symmetric rank-2 rotation
# embedded in the N-dimensional network (plus small dense noise). The cell then
# re-generates x/v/a with the same generative recursion and the same drift z_true,
# and rebuilds `loader` so training actually sees the new data.
# NOTE(review): the embedded rotation has eigenvalues ±i*ROT_GAIN*omega (modulus
# 0.81 with the defaults, i.e. < 1), and the daily drift U diag(z_t) V^T has
# singular values |z_t| that can exceed that. Whether this yields a true ring
# attractor (a continuum of stable states) has not been verified here.
RING_ATTRACTOR = True
if RING_ATTRACTOR:
    # Own seed so this cell is reproducible regardless of earlier RNG use.
    RING_SEED = 1
    torch.manual_seed(RING_SEED)

    # Orthonormal 2-D basis (p, q) of the rotation plane via one Gram-Schmidt step.
    p = torch.randn(N)
    p = p / p.norm()
    q = torch.randn(N)
    q = q - (p * q).sum() * p  # remove q's component along p
    q = q / q.norm()
    P = torch.stack([p, q], dim=1)  # (N, 2) basis

    omega = 0.9     # rotation rate of the 2x2 generator (larger -> faster rotation)
    ROT_GAIN = 0.9  # extra scale on the embedded rotation; effective rate = ROT_GAIN * omega
    A = torch.tensor([[0., -omega], [omega, 0.]])  # 2x2 skew-symmetric generator
    W_rot = ROT_GAIN * (P @ A @ P.T)  # exactly skew-symmetric, rank 2
    # Small dense noise so the new W0 is not exactly skew-symmetric.
    W0 = W_rot + 0.02 * torch.randn(N, N)

    # Recompute the drift factors U, V from the new W0 (same `rank` as before).
    U, _, Vt = torch.linalg.svd(W0)
    U, V = U[:, :rank], Vt[:rank, :].T
    print('Ring attractor W0 created (embedded rank‑2 rotation), omega=', omega,
          '| effective rotation rate =', ROT_GAIN * omega)

    # Re-generate x, v with the same recursion and the same drift z_true.
    x = torch.zeros(T_rec, S, K, N)
    v = torch.zeros(T_rec, S, K)
    for t in range(T_rec):
        Wt = W0 + U @ torch.diag(z_true[t]) @ V.T  # connectivity on day t
        for s in range(S):
            v_ts = 0.5 * torch.randn(K)
            x_ts = torch.zeros(K, N)
            x_ts[0] = 0.1 * torch.randn(N)
            for k in range(K-1):
                x_ts[k+1] = torch.tanh(Wt @ x_ts[k] + B[:, 0] * v_ts[k] + b)
            v[t, s] = v_ts
            x[t, s] = x_ts
    a = x.mean((0, 1, 2))  # recompute per-neuron mean activity

    # `loader` from the setup cell holds a SplitDS that references the OLD x/v
    # tensors (x and v were just rebound to new tensors above), so rebuild it.
    # Otherwise any training through `loader` would silently use the old data.
    loader = DataLoader(SplitDS(x, v, train_days), batch_size=loader.batch_size, shuffle=True)
else:
    print('RING_ATTRACTOR=False : keeping original W0 and data generation')

Ring attractor W0 created (embedded rank‑2 rotation), omega= 0.9
