In [None]:
import numpy as np
import os, sys
sys.path.append(os.getcwd()+"/..")
from rnn_scripts.train import *
from rnn_scripts.utils import *
from rnn_scripts.coupled_oscillators import *
import matplotlib.pyplot as plt
%matplotlib inline

from mayavi import mlab
mlab.init_notebook()


# Colour look-up tables for the mayavi plots, stored stacked in a single .npy file.
# NOTE(review): each table is assigned to `lut.table` below and column 3 is
# overwritten for shadows, so presumably (256, 4) RGBA — confirm.
lut_stack = np.load('../data/luts2.npy')
lut1, lut2, lut1a, lut2a = lut_stack

In [None]:
# Load models
cwd = os.getcwd()
fig_dir = cwd + "/../figures/"
model_dir = cwd + "/../models/"

# One trained model per rat (labels as originally annotated); pick the key to analyse.
trained_models = {
    "rat 1": "N512_T0217-141442",
    "rat 2": "N512_T0217-151523",
    "rat 3": "N512_T0217-151542",
}
model = trained_models["rat 2"]

rnn, params, task_params, training_params = load_rnn(model_dir + model)


In [None]:
# Some preprocessing / extracting parameters.
# Every statement below mutates `rnn` or defines globals used by later cells,
# so the order matters.
dt =.5  # time step handed to set_dt; NOTE(review): units presumably ms (tau/1000 is used later) — confirm
rnn.rnn.svd_orth()  # NOTE(review): presumably re-orthogonalises the low-rank connectivity via SVD — confirm
set_dt(task_params,rnn,dt)
make_deterministic(task_params, rnn)  # NOTE(review): presumably disables noise / randomised x0 — confirm
# With split=True the loadings come back as four separate arrays (inputs, n, m, W).
I,n,m,W = extract_loadings(rnn, orth_I=False,split=True)
alphas, I_orth = orthogonolise_Im(I,m)  # input loadings orthogonalised against m, plus coefficients
freq=8.4  # oscillation frequency; the next cell uses w = 2*pi*freq, so presumably Hz
rad=calculate_mean_radius(freq, rnn)


In [None]:
# Extract the coupling function from the trained model and plot it, together with
# the trajectories of the reduced coupled-oscillator system, once per input condition.

# set/extract some parameters
N = params["n_rec"]
n_inp = params['n_inp']
w = np.pi*2*freq  # angular frequency; used for the colorbar ticks
tau_sw = (rnn.rnn.tau/1000)*w

# set up coupled oscillators
co = coupled_oscillators(tau=rnn.rnn.tau/1000, freq=freq, m=m, n=n, I_orth=I_orth,
                         alphas=alphas, rad=rad, amp=1)

# Input conditions: inps[0] is all zeros, inps[i] (i > 0) is one-hot on channel i-1.
inps = []
for i in range(n_inp):
    inp = np.zeros((n_inp-1))
    if i > 0:
        inp[i-1] = 1
    inps.append(inp)

# settings for trajectories
init_state = np.zeros(2)
dur = 2
dt = 0.0005
# First sample kept when plotting: roughly the last oscillation period of each run.
T_st = int((dur-(1/freq))/dt)-2
init_phis = np.arange(0, np.pi*2, np.pi)
init_thetas = [0]

# plot settings
lw = 2
x = np.arange(-np.pi, np.pi, np.pi/50)
y = np.arange(-np.pi, np.pi, np.pi/50)
grid_size = len(x)


def _plot_trajectories(ax, inp, forward, **line_kwargs):
    """Simulate the oscillators for one input and draw the tail of every run on `ax`.

    Uses the cell-level settings (co, init_phis, init_thetas, dur, dt, T_st,
    grid_size, lw). Phases are wrapped and rescaled to image-grid coordinates so
    they overlay the imshow of the coupling function. Jumps across the wrap
    boundary are split by detect_breaks.
    """
    all_states = co.run_sims(init_phis, init_thetas, dur, dt, inp, forward=forward)
    for states in all_states:
        x_traj = rescale(wrap(states[T_st:, 0]), grid_size)
        y_traj = rescale(wrap(states[T_st:, 1]), grid_size)
        x_traj, y_traj = detect_breaks(x_traj, y_traj, tol=3)
        ax.plot(x_traj, y_traj, lw=lw, **line_kwargs)


# Shared, symmetric colour scale across all input conditions.
vmax = np.max([abs(co.plot_coupling(x, y, inp=inp) + co.w) for inp in inps])

for inp in inps:
    fig, ax = plt.subplots(figsize=(6, 6))

    # stable (forward in time) and unstable (backward in time) trajectories
    _plot_trajectories(ax, inp, forward=True, color='black')
    _plot_trajectories(ax, inp, forward=False, color='grey', ls='--')

    # coupling function
    grid = co.plot_coupling(x, y, inp=inp) + co.w
    im = ax.imshow(grid, origin='lower', cmap='coolwarm', vmin=-vmax, vmax=vmax)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, ticks=[-w, 0, w])

    # axes: no top/right spines, unlabeled ticks only at the grid edges
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim([0, len(grid)])
    ax.set_ylim([0, len(grid)])
    ax.set_xticks([0, len(grid)])
    ax.set_xticklabels([])
    ax.set_yticks([0, len(grid)])
    ax.set_yticklabels([])

    # e.g. input [1., 0.] -> "coupling[10].svg"
    fname = "coupling" + str(inp).replace('.', '').replace(' ', '') + ".svg"
    fig.savefig(os.path.join(fig_dir, fname))
    # Display inline and release the figure, so figures do not pile up across inputs.
    plt.show()



In [None]:
# Initialise a mean-field RNN whose covariance matrix is set by hand (not trained).

dt = 2

# RNN settings: rank-2 mean-field network with every training flag switched off.
params = {
    "nonlinearity": "tanh",
    "out_nonlinearity": "tanh",
    "readout_kappa": False,
    "train_meanfield": True,
    "n_supports": 4,
    "train_cov": False,
    "rank": 2,
    "n_inp": 3,
    "p_inp": 1,
    "n_rec": 5,
    "p_rec": 1,
    "n_out": 1,
    "cov": None,
    "loadings": None,
    "cov_init_noise": 0,
    "scale_w_inp": 1,
    "scale_w_out": 1,
    "scale_n": 1,
    "scale_m": 1,
    "1overN_out_scaling": True,
    "train_w_inp": False,
    "train_w_inp_scale": False,
    "train_w_rec": False,
    "train_b_rec": False,
    "train_m": False,
    "train_n": False,
    "train_taus": False,
    "train_w_out": False,
    "train_w_out_scale": False,
    "train_b_out": False,
    "train_x0": False,
    "tau_lims": [30],
    "dt": dt,
    "noise_std": 0.05,
    "scale_x0": 0.1,
    "randomise_x0": True,
    "orth_indices": []
}

# Covariance settings, grouped by key prefix (osc_* and coupl_*).
# Their exact meaning is defined by create_MF_covs.
cov_params = {
    "osc_w": 1.56*0.4*np.pi,
    "osc_r": 1.5,
    "osc_sdn": 10,
    "osc_sdm": 2.5,
    "osc_sdmW": 1,
    "osc_sdW": 16,

    "coupl_sdnimi": 3,
    "coupl_sdnimj": 1,
    "coupl_sdIosc": 5,
    "coupl_sdIstim": 5,
    "coupl_sdIn": 3,
    "coupl_sdIm": 0,
    "coupl_sdn": 10,
    "coupl_sdm": 2,
    "coupl_sdmW": 0,
    "coupl_sdW": 16,
}

chol_covs = create_MF_covs(cov_params, plot=True, vm=12)

rnn = RNN(params)
with torch.no_grad():
    # Overwrite the model's Cholesky factors in place, outside autograd.
    rnn.rnn.cov_chols.copy_(torch.from_numpy(chol_covs))

In [None]:
# Run simulations with the reduced (mean-field) equations from a grid of initial conditions.

r_range = np.arange(0.5, 0.6, 0.2)
phi_range = np.arange(-np.pi, np.pi, np.pi * 3)
theta_range = np.arange(-np.pi, np.pi, np.pi / 3)
tau = 30
w = 16 * np.pi
T = 4

x0s, input_ICs, phases = create_ICs_MF(r_range, phi_range, theta_range, tau, T, dt, w)
total = len(x0s)
w_phases = wrap(phases)


def _with_stimulus(channel):
    """Copy of input_ICs with input `channel` set to 1 during time steps 100..299."""
    stimulated = input_ICs.clone()
    stimulated[:, 100:300, channel] = 1
    return stimulated


input_ICs_st1 = _with_stimulus(1)
input_ICs_st2 = _with_stimulus(2)

ks1, _ = predict(rnn, input_ICs_st1, x0=x0s)
ks2, _ = predict(rnn, input_ICs_st2, x0=x0s)


In [None]:
# Run simulations with several finite-size networks whose weights are sampled from
# the mean-field covariance, and project their rates onto loadings 5 and 6.

ks_fss1, ks_fss2 = [], []

# initialise a finite-size model sharing the mean-field covariance
params_fs = {
    **params,
    'train_meanfield': False,
    'train_cov': True,
    'n_rec': 4096,
    'noise_std': 0,
    'n_supports': 4,
}
rnn_fs = RNN(params_fs)
with torch.no_grad():
    rnn_fs.rnn.cov_chols.copy_(torch.from_numpy(chol_covs))

# (stimulus inputs, list collecting the projected trajectories)
stim_conditions = ((input_ICs_st1, ks_fss1), (input_ICs_st2, ks_fss2))

for modeli in range(10):
    # draw a new set of weights / loadings from the covariance
    rnn_fs.rnn.resample()
    loadings = extract_loadings(rnn_fs)

    # Map the mean-field initial states into the N-dimensional space.
    x0s_fs = torch.from_numpy(
        np.outer(x0s[:, 0], loadings[5])
        + np.outer(x0s[:, 1], loadings[6])
        + np.outer(x0s[:, 2], loadings[0])
    )

    for stim_inputs, collected in stim_conditions:
        rates, _ = predict(rnn_fs, stim_inputs, x0=x0s_fs)
        # note: `ind` stays bound (to total-1) after this cell finishes
        for ind in np.arange(total):
            collected.append(np.array(proj(loadings[5:7], rates[ind, :, :])).T)

In [None]:
# Plot mean-field and finite-size trajectories on a torus, with shadows on a floor plane.

mlab.clf()
fig = mlab.figure(size=(1600, 1600),
                  bgcolor=(1, 1, 1), fgcolor=(0.5, 0.5, 0.5))

# plot settings
tw = .06          # tube radius, mean-field trajectories
sc = 1.5
floor = -1.5*sc   # z level of the shadow plane
r = 1*sc
tau = rnn.rnn.tau
freq = task_params['freq']
period = int((1000/freq)/dt)
t_start = -(period+1)  # plot only the last period (+1 sample) of each trajectory
r_s = .6*sc
twfs = 0.02        # tube radius, finite-size trajectories
alphafs = 1
alphafssh = 0.02

# Initial condition whose mean-field trajectory is drawn (the last one simulated).
# It is set explicitly rather than taken from a loop variable left over from
# another cell, so this cell can be re-run safely.
mf_ind = total - 1
cvs = np.sin((w_phases[mf_ind, t_start:]))

# plot floor
torus = def_torus(r, r_s)
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor,
          color=(0.5, 0.5, 0.5), opacity=0.05)


# plot mean field model
x1k, y1k, z1k = tor(ks1[mf_ind, t_start:, 1], ks1[mf_ind, t_start:, 0], w_phases[mf_ind, t_start:], r)
surf1 = mlab.plot3d(x1k, y1k, z1k,
                    cvs, tube_radius=tw, colormap='cool')

x2k, y2k, z2k = tor(ks2[mf_ind, t_start:, 1], ks2[mf_ind, t_start:, 0], w_phases[mf_ind, t_start:], r)
surf2 = mlab.plot3d(x2k, y2k, z2k,
                    cvs, tube_radius=tw, colormap='cool')

surf1.module_manager.scalar_lut_manager.lut.table = lut1
surf1.actor.property.lighting = False
surf2.module_manager.scalar_lut_manager.lut.table = lut2
surf2.actor.property.lighting = False

# plot mean-field shadows: same colour tables with column 3 (presumably alpha) lowered
sh_surf1 = mlab.plot3d(x1k, y1k, floor*np.ones_like(x1k),
                       cvs, tube_radius=tw, colormap='cool')
shadow_lut1 = np.copy(lut1)
shadow_lut1[:, 3] = 20
sh_surf1.module_manager.scalar_lut_manager.lut.table = shadow_lut1
sh_surf1.actor.property.lighting = False

sh_surf2 = mlab.plot3d(x2k, y2k, floor*np.ones_like(y2k),
                       cvs, tube_radius=tw, colormap='cool')
shadow_lut2 = np.copy(lut2)
shadow_lut2[:, 3] = 20
sh_surf2.module_manager.scalar_lut_manager.lut.table = shadow_lut2
sh_surf2.actor.property.lighting = False

# plot finite size models
# ks_fss* hold `total` trajectories per sampled model, so fs_ind % total is the
# initial-condition index within each model.
for fs_ind, (ks_fs1, ks_fs2) in enumerate(zip(ks_fss1, ks_fss2)):
    phases_fs = w_phases[fs_ind % total, t_start:]

    # trajectories
    x1_fs1, y1_fs1, z1_fs1 = tor(ks_fs1[t_start:, 1], ks_fs1[t_start:, 0], phases_fs, r)
    surf1 = mlab.plot3d(x1_fs1, y1_fs1, z1_fs1,
                        tube_radius=twfs, color=(0.6, 0.6, 0.6), opacity=alphafs)

    x2_fs2, y2_fs2, z2_fs2 = tor(ks_fs2[t_start:, 1], ks_fs2[t_start:, 0], phases_fs, r)
    surf2 = mlab.plot3d(x2_fs2, y2_fs2, z2_fs2,
                        tube_radius=twfs, color=(.8, .8, .8), opacity=alphafs)

    surf1.actor.property.lighting = False
    surf2.actor.property.lighting = False

    # add shadows (each trajectory's own z shape is used for its floor coordinates)
    surf1 = mlab.plot3d(x1_fs1, y1_fs1, np.ones_like(z1_fs1)*floor,
                        tube_radius=twfs, color=(0, 0, 0), opacity=alphafssh)

    surf2 = mlab.plot3d(x2_fs2, y2_fs2, np.ones_like(z2_fs2)*floor,
                        tube_radius=twfs, color=(.8, .8, .8), opacity=alphafssh)
    surf1.actor.property.lighting = False
    surf2.actor.property.lighting = False

mlab.draw()
mlab.view(azimuth=90, elevation=67, distance=36,
          focalpoint=np.array([0, 0, -3.5]))
# NOTE(review): presumably returned as the cell's last value so the notebook renders the scene — confirm
mlab.plot3d(0, 0, 0)