In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import os, sys
sys.path.append(os.getcwd()+"/..")
from rnn_scripts.model import *
from rnn_scripts.train import *
from rnn_scripts.utils import *
import matplotlib.pyplot as plt
%matplotlib inline
from tasks.seqDS import seqDS
from sklearn.decomposition import PCA
from mayavi import mlab
mlab.init_notebook()
# luts.npy holds four colour lookup tables; only the last two are kept
# (they colour the two trajectories in the plotting cell).
lut1, lut2 = np.load('../data/luts.npy')[2:]


In [None]:
# Load a trained model.
model_dir = os.getcwd() + "/../models/"
model = "fr_lfp_Trained"  # Phase solution
# model = "N100_T1208-100644"  # Rate solution

rnn, params, task_params, training_params = load_rnn(model_dir + model)

# some processing / extraction
make_deterministic(task_params, rnn)
set_dt(task_params, rnn, 0.5)

# Oscillation period expressed in simulation steps.
# Assumes task_params["dt"] is in ms and task_params["freq"] in Hz (hence /1000) — TODO confirm.
ts = task_params["dt"] / 1000   # step length in seconds
ps = 1 / task_params["freq"]    # oscillation period in seconds
# round(), not int(): truncating a float ratio such as 249.99999999999997 would
# silently drop a sample and misalign every per-period slice in later cells.
period = int(round(ps / ts))
# Quarter-period lag in steps; later cells use it to form a lagged copy of the input.
qp = period // 4

# NOTE(review): overrides the probe duration consumed by seqDS; units defined there — confirm.
task_params['probe_dur'] = 5


In [None]:
# Set up the task and run the model on one batch.
ds_det = seqDS(task_params)
# A seeded generator makes the shuffled batch, and therefore which trials are
# analysed below, identical across Restart & Run All.
dataloader_det = DataLoader(
    ds_det, batch_size=12, shuffle=True, generator=torch.Generator().manual_seed(0)
)
test_input, test_target, test_mask = next(iter(dataloader_det))
rates, pred = predict(rnn, test_input, mse_loss, test_target, test_mask)

n_labels = rnn.params["n_inp"] - 1
labels = extract_labels(test_input, n_labels)

# Index of the first trial in the batch carrying each label. Fail with an
# actionable message (instead of a bare IndexError) if a label was not drawn.
trials = []
for label in range(n_labels):
    matches = np.where(labels == label)[0]
    if len(matches) == 0:
        raise ValueError(
            f"label {label} does not occur in this batch of {test_input.shape[0]} "
            "trials; increase batch_size"
        )
    trials.append(matches[0])

In [None]:
# Find a low-D orthogonal basis: the recurrent PCs plus the orthogonalised input direction.
# Stack the last `period` steps of every trial into a (samples, units) matrix.
rates_r = rates[:, -period:].reshape(-1, rates.shape[-1])

# Recurrent subspace: PCA of J tanh(x), computed row-wise as tanh(X) @ J.T.
n_comp = 2
pca = PCA(n_components=n_comp)
# .cpu() lets this run when the model lives on a GPU (no-op for CPU tensors).
J = np.copy(rnn.rnn.w_rec.detach().cpu().numpy())
I = np.copy(rnn.rnn.w_inp[0].detach().cpu().numpy())
pca.fit(np.tanh(rates_r) @ J.T)

# Orthogonalise I against every fitted PC (driven by n_comp, not a hard-coded 2).
# NOTE(review): assumes orth_proj(v, x) returns (projection of x onto v, coefficient) — confirm.
I_orth = np.copy(I)
for component in pca.components_:
    projt = orth_proj(component, I_orth)[0]
    I_orth -= projt

# Variance of the rates captured by the basis, relative to the total variance.
# Summing per-direction variances is only meaningful because the directions are orthonormal.
S = np.sum(np.var(rates_r, axis=0))
v1 = pca.components_[0]
v2 = pca.components_[1]
norm_I_orth = np.linalg.norm(I_orth)
assert norm_I_orth > 0, "input direction lies entirely inside the recurrent PC subspace"
v3 = I_orth / norm_I_orth
var = sum(np.var(rates_r @ v) for v in (*pca.components_, v3))

print(var / S)


In [None]:
# Project each selected trial's last period onto the first two recurrent PCs,
# rotated so that every trajectory starts at the same input phase.
ks = np.zeros((len(trials), 2, period + 1))
phases = np.linspace(0, np.pi * 2, period + 1)

for i, ind in enumerate(trials):
    k1 = rates[ind, -period:] @ pca.components_[0]
    k2 = rates[ind, -period:] @ pca.components_[1]

    # Phase of input channel 0 from arctan2(x(t - qp), x(t)).
    # Convert to NumPy *before* arctan2: works for GPU tensors and does not rely on
    # NumPy ufuncs round-tripping torch tensors. Explicit start/stop indices keep
    # the lagged slice non-empty even when qp == 0 (where "-0" would yield []).
    inp = test_input[ind, :, 0].detach().cpu().numpy()
    n_steps = inp.shape[0]
    phase = np.arctan2(inp[n_steps - qp - period : n_steps - qp], inp[n_steps - period :])
    phase = wrap(phase)

    # sort such that all trajectories start with the same phase
    time_ind = np.roll(np.arange(len(phase)), -np.argmin(phase))
    ks[i, 0, :period] = k1[time_ind]
    ks[i, 1, :period] = k2[time_ind]

# Repeat the first sample at the end so each trajectory closes into a loop.
ks[:, :, -1] = ks[:, :, 0]

In [None]:
# Plot trajectories mapped through tor(), each with a translucent shadow on a floor plane.
mlab.clf()
fig = mlab.figure(size=(1600, 1600),
                  bgcolor=(1, 1, 1), fgcolor=(0.5, 0.5, 0.5))
tw = 1.8              # tube radius used for every trajectory and shadow
floor = -40           # z-coordinate of the floor mesh and the shadows
luts = [lut1, lut2]   # one lookup table per plotted trajectory
cvs = np.sin(phases)  # per-point scalars mapped through each lookup table

# plot floor: the def_torus surface flattened onto z = floor
r = 35
r_s = 20
torus = def_torus(r, r_s)
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2]) + floor,
          color=(0.5, 0.5, 0.5), opacity=0.05)

# Shadows all sit at the same constant z; build that array once.
shadow_z = floor * np.ones_like(phases)

# plot trajectories: pair each projected trajectory with its lookup table.
# zip stops at the shorter sequence instead of assuming exactly two of each.
for traj, lut in zip(ks, luts):
    x1, y1, z1 = tor(traj[1], traj[0], phases, r)
    surf1 = mlab.plot3d(x1, y1, z1,
                        cvs, tube_radius=tw, colormap='cool')

    # set colormap: replace the 'cool' table with this trajectory's LUT
    surf1.module_manager.scalar_lut_manager.lut.table = lut

    # set lighting: disabled for the trajectory tube
    surf1.actor.property.lighting = False

    # plot shadow: same x/y path at z = floor, with a low-alpha copy of the LUT
    sh_surf1 = mlab.plot3d(x1, y1, shadow_z,
                           cvs, tube_radius=tw, colormap='cool')
    shadow_lut1 = np.copy(lut)
    shadow_lut1[:, 3] = 20  # presumably the alpha column of an RGBA (0-255) table
    sh_surf1.module_manager.scalar_lut_manager.lut.table = shadow_lut1
    sh_surf1.actor.property.lighting = False

mlab.draw()
mlab.view(azimuth=-20, elevation=67, distance=36,
          focalpoint=np.array([0, 0, -3.5]))
# NOTE(review): degenerate single-point plot3d; presumably kept as the cell's last
# expression so the scene renders via mlab.init_notebook() — confirm before removing.
mlab.plot3d(0, 0, 0)



