In [None]:
import numpy as np
import os, sys
sys.path.append(os.getcwd()+"/..")
from rnn_scripts.train import *
from rnn_scripts.utils import *
from rnn_scripts.coupled_oscillators import *
import matplotlib.pyplot as plt
%matplotlib inline

from mayavi import mlab
# Switch mayavi to its notebook backend, then load the stored colour lookup
# tables. Only the last two of the four stored arrays (lut1a, lut2a) are kept;
# they are later written into `lut.table` of the RNN trajectory tubes.
mlab.init_notebook()
luts = np.load('../data/luts.npy')
_, _, lut1a, lut2a = luts

In [None]:
# Load the trained RNN (model of rat 2) along with its parameter dictionaries.
# Both directories are resolved relative to the notebook's working directory.
fig_dir, model_dir = (os.getcwd() + sub for sub in ("/../figures/", "/../models/"))
model = "N512_T0217-151523"  # rat 2
rnn, params, task_params, training_params = load_rnn(model_dir + model)

In [None]:
# Preprocess the loaded RNN and extract the quantities used by the
# coupled-oscillator model and the plot below. The calls on `rnn` modify it,
# so their order matters.
dt = .5  # integration step; in ms, since `period` below divides 1000/freq by it
rnn.rnn.svd_orth()
set_dt(task_params, rnn, dt)
make_deterministic(task_params, rnn)
I, n, m, W = extract_loadings(rnn, orth_I=False, split=True)
alphas, I_orth = orthogonolise_Im(I, m)

freq = 8.4  # oscillation frequency (Hz)
rad = calculate_mean_radius(freq, rnn)
N, n_inp = params["n_rec"], params['n_inp']
w = np.pi * 2 * freq  # angular frequency (rad/s)
tau_sw = (rnn.rnn.tau / 1000) * w  # assumes tau is stored in ms — TODO confirm
inp = np.zeros(n_inp - 1)
period = int((1000 / freq) / dt)  # number of dt-steps in one oscillation cycle (truncated)


In [None]:
# set up coupled oscillators
co = coupled_oscillators(
    tau=rnn.rnn.tau / 1000,  # presumably ms -> s; matches tau_sw's conversion
    freq=freq,
    m=m,
    n=n,
    I_orth=I_orth,
    alphas=alphas,
    rad=rad,
    amp=1,
)

In [None]:
# Make plot: a faint torus, the coupled-oscillator (CO) trajectories drawn as
# dashed tubes, and the RNN trajectories projected onto the same torus.

mlab.clf()
fig = mlab.figure(size=(1600, 1600),
                  bgcolor=(1, 1, 1), fgcolor=(0.5, 0.5, 0.5))

R = 2.5   # torus major radius
r = 1     # torus minor radius
tw = .07  # CO tube radius; also lifts the tubes off the torus surface (r + tw)

# plot torus as an unlit, nearly transparent grey mesh
torus = def_torus(R, r)
grey = 0.4
opac = 0.05
torusurf = mlab.mesh(torus[0], torus[1], torus[2], opacity=opac, color=(grey, grey, grey))
torusurf.actor.property.lighting = False


init_state = np.zeros(2)  # NOTE(review): not used in this cell

# Plot CO trajectories
dur = 2  # simulated duration in seconds (1/freq is subtracted from it below)
# NOTE: this overwrites the global `dt` (0.5, in ms) with a step in seconds.
# 0.0005 s == 0.5 ms, so `period` from the preprocessing cell still counts
# samples per cycle here. Re-run the preprocessing cell before relying on `dt` again.
dt = 0.0005
T_st = dur - (1 / freq)
T_st = int(T_st / dt) - 2  # start index: final cycle of the run, minus 2 samples
init_phis = np.arange(0, np.pi * 2, np.pi)  # initial phases 0 and pi
init_thetas = [0]
inp = np.zeros(n_inp - 1)
all_states = co.run_sims(init_phis, init_thetas, dur, dt, inp, forward=True)
for i, states in enumerate(all_states):
    x, y, z = tor_from_angles(wrap(states[T_st:, 0]), wrap(states[T_st:, 1]), R, r + tw)
    cvs = np.sin(wrap(states[T_st:, 0]))
    # colour depends only on the trajectory index, so choose it once per trajectory
    color = (0, 0, 0) if i < 2 else (0.8, 0.8, 0.8)

    # Make dashed 3d lines: each dash joins samples it and it+1; the step from
    # it+1 to it+2 is left undrawn, producing the gaps.
    for it in np.arange(0, period, 2):
        surf1 = mlab.plot3d(x[it:it + 2], y[it:it + 2], z[it:it + 2], cvs[it:it + 2],
                            tube_radius=tw, color=color)
        surf1.actor.property.lighting = False


# Plot RNN trajectories projected in the same space
tw = .05
# NOTE(review): freq=8 here differs from freq=8.4 used for the CO model and
# calculate_mean_radius above — confirm this is intentional.
ks, phases, rates = get_traj(rnn, task_params, freq=8, amp_scale=1)
cvs = np.sin(phases)  # identical for every trajectory, so computed once
for ind in np.arange(2):
    k_phase = np.arctan2(ks[ind, 1], ks[ind, 0])
    x, y, z = tor_from_angles(phases - 0.5 * np.pi, k_phase, R, r + tw)

    surf1 = mlab.plot3d(x, y, z, cvs, tube_radius=tw, colormap='cool')
    # replace the colormap's table with a custom LUT: lut2a for trajectory 1, lut1a otherwise
    surf1.module_manager.scalar_lut_manager.lut.table = lut2a if ind == 1 else lut1a
    surf1.actor.property.lighting = False

# NOTE(review): kept from the original; presumably its returned object makes the
# notebook backend render the scene as the cell output — confirm.
mlab.plot3d(0, 0, 0)

