In [None]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import torch

# Make the project packages (rnn_scripts, tasks) importable from this notebook's directory.
sys.path.append(os.getcwd()+"/..")
from rnn_scripts.utils import *
from rnn_scripts.model import *
from rnn_scripts.train import *
from rnn_scripts.coupled_oscillators import *
from tasks.seqDS import *
import matplotlib.pylab as pl
%matplotlib inline

# 3D plotting with mayavi's mlab; init_notebook() sets up mayavi's notebook backend.
from mayavi import mlab
mlab.init_notebook()
# Load four lookup tables from disk. The unpacking requires the stored array to have
# exactly four entries along its first axis.
# NOTE(review): presumably colormap LUTs; none of lut1/lut2/lut1a/lut2a is used in the
# cells shown here — confirm they are still needed.
lut1,lut2,lut1a,lut2a = np.load('../data/luts2.npy')

In [None]:
# Set the parameters for the phase precession model
# (model configuration in `params`, covariance settings in `cov_params`).

# Output directory for saved figures (trailing slash included).
fig_dir = "../figures/"

# Configuration dictionary handed to RNN(params).
# Low-rank (rank 2) tanh network with 4 inputs and 1 output; all weights, time
# constants and the initial state are frozen, only the covariance is trainable.
params = {
    # nonlinearities
    "nonlinearity": "tanh",
    "out_nonlinearity": "tanh",
    "readout_kappa": False,
    "train_meanfield": False,
    # covariance / mixture-of-Gaussians settings
    "n_supports": 4,
    "train_cov": True,
    # architecture
    "rank": 2,
    "n_inp": 4,
    "p_inp": 1,
    "n_rec": 2056,
    "p_rec": 1,
    "n_out": 1,
    "cov": None,
    "loadings": None,
    "cov_init_noise": 0,
    # weight scalings
    "scale_w_inp": 1,
    "scale_w_out": 1,
    "scale_n": 1,
    "scale_m": 1,
    "1overN_out_scaling": True,
    # which parameters are trainable
    "train_w_inp": False,
    "train_w_inp_scale": False,
    "train_w_rec": False,
    "train_b_rec": False,
    "train_m": False,
    "train_n": False,
    "train_taus": False,
    "train_w_out": False,
    "train_w_out_scale": False,
    "train_b_out": False,
    "train_x0": False,
    # dynamics and simulation
    "tau_lims": [25],
    "dt": 1,
    "noise_std": 0.0,
    "scale_x0": 0.1,
    "randomise_x0": True,
    "orth_indices": [],
}

# Standard deviations of the loading covariance blocks.
# The "osc_*" entries configure the oscillator population and the "coupl_*" entries
# the coupling population (as consumed by create_MF_covs_phase_prec).
cov_params = {
    # oscillator population
    "osc_sdw": 1.6,
    "osc_sd": 2.5,
    "osc_sdn": 9,
    "osc_sdm": 1,
    "osc_sdmW": 0,
    "osc_sdW": 4,
    # coupling population
    "coupl_sdnimi": 0,
    "coupl_sdnimj": 0,
    "coupl_sdIosc": 1,
    "coupl_sdIstim": 50,
    "coupl_sdIn": 2,
    "coupl_sdIm": 0,
    "coupl_sdn": 5,
    "coupl_sdm": 1,
    "coupl_sdmW": 0,
    "coupl_sdW": 4,
}


In [None]:
# Create initial conditions and inputs for the model
# Note: r_range (arange(0.5, 0.6, 0.2) -> [0.5]) and theta_range
# (arange(-pi, pi, 3*pi) -> [-pi]) each contain a single value.
r_range = np.arange(0.5, 0.6, 0.2)
phi_range = np.linspace(0.2, np.pi*2+0.2, 5)
theta_range = np.arange(-np.pi, np.pi, np.pi*3)
stim_range = np.linspace(0, np.pi/2, 8)
tau = 25
w = 8*np.pi
T = 3
dt = 1
x0s, input_ICs, phases = create_ICs_phase_prec(r_range, phi_range, theta_range, stim_range,
                                               tau, T, dt, w, amp=1)
total = len(x0s)
w_phases = wrap(phases)


# Plot input positions: one dot per stimulus distance, coloured to match the
# trajectory colours used later.
colors = pl.cm.viridis(np.linspace(0, 1, len(stim_range)))
ds = 30
fig_dist, ax_dist = plt.subplots(figsize=(3, 3))
for ind, dist in enumerate(stim_range):
    ax_dist.scatter(dist, 0, color=colors[ind], s=ds)
# Grey baseline behind the dots, spanning 0 .. largest distance. This is computed
# explicitly rather than relying on the loop variable leaking out of the loop.
ax_dist.plot(np.array([0, stim_range.max()]), np.array([0, 0]), color='grey', zorder=-2)
ax_dist.axis('off')

fig_dist.savefig(fig_dir + "distance.svg")



In [None]:
# Initialise an RNN model and overwrite its covariance with the hand-built one.
rnn_fs = RNN(params)
chol_covs = create_MF_covs_phase_prec(cov_params, plot=True)
with torch.no_grad():
    # In-place copy into the model's Cholesky-factor tensor, without autograd tracking.
    rnn_fs.rnn.cov_chols.copy_(torch.from_numpy(chol_covs))
# NOTE(review): presumably redraws the connectivity loadings from the new covariances — confirm.
rnn_fs.rnn.resample()

# extract the loadings
I, n, m, _ = extract_loadings(rnn_fs, split=True)
alphas, I_orth = orthogonolise_Im(I, m)

# Project the initial conditions onto the model loadings:
#   x0 = x0s[:,0]*m[0] + x0s[:,1]*m[1] + x0s[:,2]*I[0] + x0s[:,3]*I[1]
assert x0s.shape[1] >= 4, "expected at least 4 latent coordinates per initial condition"
x0s_fs = (np.outer(x0s[:, 0], m[0]) + np.outer(x0s[:, 1], m[1])
          + np.outer(x0s[:, 2], I[0]) + np.outer(x0s[:, 3], I[1]))
x0s_fs = torch.from_numpy(x0s_fs)

# Simulate the model and project each trial's rates onto the m-loadings
# (one array per trial, transposed so that time runs along the first axis).
ks_fss = []
rates, _ = predict(rnn_fs, input_ICs, x0=x0s_fs)
for ind in np.arange(total):
    ks_fss.append(np.array(proj(m, rates[ind, :, :])).T)

In [None]:
# Plot trajectories on a torus, each with a translucent "shadow" on the floor plane.

mlab.clf()
fig = mlab.figure(size=(1600, 1600),
                  bgcolor=(1, 1, 1), fgcolor=(0.5, 0.5, 0.5))

# plot settings
sc = 2                # global scale factor
floor = -2*sc         # z-level of the floor plane
r = 1*sc              # major torus radius (passed to def_torus and tor)
t_start = -251        # plot only the last 251 time steps
r_s = .6*sc           # minor torus radius
twfs = 0.02           # tube radius of the trajectories
alphafs = 1           # opacity of the trajectories
alphafssh = 0.2       # opacity of the floor shadows

# plot floor
torus = def_torus(r, r_s)
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2]) + floor,
          color=(0.5, 0.5, 0.5), opacity=0.05)

# plot trajectories
# NOTE(review): only the first len(colors) trials are drawn — confirm which conditions
# these correspond to in create_ICs_phase_prec's ordering.
for ind, ks_fs in enumerate(ks_fss[:len(colors)]):
    colr = tuple(colors[ind % len(colors)][:3])  # RGB only; drop the alpha channel

    x1_fs1, y1_fs1, z1_fs1 = tor(ks_fs[t_start:, 1], ks_fs[t_start:, 0],
                                 w_phases[ind, t_start:], r)

    surf_traj = mlab.plot3d(x1_fs1, y1_fs1, z1_fs1,
                            tube_radius=twfs, color=colr, opacity=alphafs)
    surf_traj.actor.property.lighting = False

    # same curve flattened onto the floor plane
    surf_shadow = mlab.plot3d(x1_fs1, y1_fs1, np.ones_like(z1_fs1)*floor,
                              tube_radius=twfs, color=colr, opacity=alphafssh)
    surf_shadow.actor.property.lighting = False

mlab.draw()
mlab.view(azimuth=90, elevation=67, distance=36,
          focalpoint=np.array([0, 0, -3.5]))
# NOTE(review): kept as the cell's last expression, presumably so the notebook backend
# renders the scene — confirm before removing.
mlab.plot3d(0, 0, 0)

In [None]:
# Extract coupling function from sampled models

# set/extract some parameters
N = params["n_rec"]
n_inp = params['n_inp']
w = np.pi*8
freq = 4
tau_sw = (rnn_fs.rnn.tau/1000)*w
rad = calculate_mean_radius(freq, rnn_fs)

# set up coupled oscillators from the extracted loadings
co = coupled_oscillators(tau=rnn_fs.rnn.tau/1000, freq=freq, m=m, n=n, I_orth=I_orth,
                         alphas=alphas, rad=rad, amp=1, n_osc=2)

# make inputs: one coupling map is drawn per input vector
inps = [np.array([1, 0]), np.array([0, 1])]
inp = inps[0]

# phase grid on which the coupling function is evaluated
lw = 2
x = np.arange(-np.pi, np.pi, np.pi/50)
y = np.arange(-np.pi, np.pi, np.pi/50)
grid_size = len(x)

# Evaluate the coupling function once per input and reuse it for the baseline,
# the colour range and the plots (previously recomputed up to three times).
raw_grids = [co.plot_coupling(x, y, inp=inp_k) for inp_k in inps]

# Baseline co.w, signed like the mean coupling of the first input. It is subtracted
# from every map so that all maps share one symmetric colour scale [-vmax, vmax].
basew = np.sign(np.mean(raw_grids[0]))*co.w
vmax = np.max([abs(raw_grid - basew) for raw_grid in raw_grids])

for inp, raw_grid in zip(inps, raw_grids):
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111)

    # coupling function relative to the baseline
    grid = raw_grid - basew
    im = ax.imshow(grid, origin='lower', cmap='coolwarm', vmin=-vmax, vmax=vmax)
    # NOTE(review): colorbar ticks at +-w are only visible when vmax >= w.
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, ticks=[-w, 0, w])

    # plot settings: open frame, ticks only at the grid edges, no tick labels
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim([0, len(grid)])
    ax.set_ylim([0, len(grid)])
    ax.set_xticks([0, len(grid)])
    ax.set_xticklabels([])
    ax.set_yticks([0, len(grid)])
    ax.set_yticklabels([])

    # fig_dir already ends with "/": join instead of concatenating a second "/".
    tag = str(inp).replace('.', '').replace(' ', '')
    fig.savefig(os.path.join(fig_dir, "coupling_phase_prec" + tag + ".svg"))

