In [None]:
import numpy as np
import os, sys
sys.path.append(os.getcwd()+"/..")
from rnn_scripts.train import *
from rnn_scripts.utils import *
from tasks.seqDS import *
import matplotlib.colors as colors
import matplotlib.pyplot as plt
%matplotlib inline

from mayavi import mlab
mlab.init_notebook()

# Colour lookup tables for the 3D trajectory figure; only the third and fourth are kept
lut_stack = np.load('../data/luts.npy')
_, _, luts1, luts2 = lut_stack


In [None]:
# Load the trained model (this checkpoint is the 4-item network)

model_dir = os.getcwd() + "/../models/"
fig_dir = os.getcwd() + "/../figures/"

model = "N512_T0222-185615"
model_path = model_dir + model
rnn, params, task_params, training_params = load_rnn(model_path)


In [None]:
# Some preprocessing / extracting parameters
# NOTE(review): all helpers below come from the rnn_scripts star-imports; several are
# called only for their side effects on `rnn` / `task_params`, so their order matters.

dt =2  # time step handed to set_dt (units as used by the project — confirm)
rnn.rnn.svd_orth()  # presumably re-orthogonalises the low-rank factors via an SVD — confirm
set_dt(task_params,rnn,dt)  # apply dt to the task parameters and the network
make_deterministic(task_params, rnn)  # presumably disables noise sources — confirm
# Presumably input (I), left (n) and right (m) loadings plus a remaining weight term W — confirm
I,n,m,W = extract_loadings(rnn, orth_I=False,split=True)
# NOTE(review): presumably returns projection coefficients of I onto m and the part of I orthogonal to m — confirm
alphas, I_orth = orthogonolise_Im(I,m)

In [None]:
# Extract the trajectories (with a shortened probe period) and plot them in the
# plane spanned by their first two components

task_params["probe_dur"]=.375
ks,phases,rates = get_traj(rnn,task_params,freq=8,amp_scale=1)

plt.figure(figsize=(4,4))
for i in range(len(ks)):
    plt.plot(ks[i,0], ks[i,1])
# The limits are symmetric, so they must be based on the largest *magnitude*;
# np.max(ks) alone ignores negative excursions, which would then be clipped.
vm = np.max(np.abs(ks[:, :2]))*1.1
plt.xlim(-vm,vm)
plt.ylim(-vm,vm)
plt.xlabel("$k_1$")
plt.ylabel("$k_2$");

In [None]:
# Make plot of trained model: 3D rendering of the trajectories from get_traj, mapped to
# 3D by the project helper `tor` and drawn as tubes, plus flat "shadow" copies drawn on
# a floor plane below them.

mlab.clf()
fig = mlab.figure(size = (1600,1600),\
            bgcolor = (1,1,1), fgcolor = (0.5, 0.5, 0.5))
# One RGBA lookup table per trajectory, blending from luts1 to luts2. It is indexed by
# trajectory number below, so this assumes ks holds at most 4 trajectories.
luts = [luts1,luts2/3+luts1*2/3,luts2*2/3+luts1*1/3,luts2]
tw=.08      # tube radius for both trajectories and shadows
floor=-1.5  # z-coordinate of the floor plane that carries the shadows
r_tor=1.5   # first argument of def_torus for the floor mesh (presumably major radius — confirm)
r=2         # radius argument passed to `tor` for the trajectories
r_s=1       # second argument of def_torus (presumably minor radius — confirm)
cvs = np.sin((phases))  # one scalar per time point; mapped to colour through the LUTs

# plot floor
torus=def_torus(r_tor,r_s)
# Close each trajectory: copy the first time sample onto the last one so the tube forms
# a loop. NOTE: this modifies `ks` in place.
ks[:,:,-1]=ks[:,:,0]
# Flatten the torus onto z = floor and draw it almost fully transparent
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor
          , color=(0.5,0.5,0.5), opacity=0.05)

# plot trajectories
for i in range(len(ks)):
    # Map the two components of trajectory i together with the phase to 3D coordinates
    x1,y1,z1 = tor(ks[i,1],ks[i,0], phases,r)
    surf1 = mlab.plot3d(x1, y1, z1,
                        cvs, tube_radius=tw, colormap='cool')

    # Replace the placeholder 'cool' colormap by this trajectory's LUT; flat shading
    surf1.module_manager.scalar_lut_manager.lut.table = luts[i]
    surf1.actor.property.lighting = False

    # plot shadows: the same (x, y) path drawn at constant height z = floor
    sh_surf1 = mlab.plot3d(x1,y1, floor*np.ones_like(phases),
                        cvs, tube_radius=tw, colormap='cool')

    # Same colours, but the 4th (alpha) column of the RGBA table is set to a low
    # constant so the shadow is faint
    shadow_lut1 = np.copy(luts[i])
    shadow_lut1[:,3]=20
    sh_surf1.module_manager.scalar_lut_manager.lut.table = shadow_lut1
    sh_surf1.actor.property.lighting = False

mlab.draw()
mlab.view(azimuth=8, elevation=67, distance=36, 
          focalpoint=np.array([ 0,  0, -3.5]))
# NOTE(review): degenerate single-point plot left as the cell's last expression;
# presumably so the notebook backend (mlab.init_notebook) displays the scene — confirm
mlab.plot3d(0,0,0)

In [None]:
# Initialise a mean-field RNN whose covariance matrices allow coding for 4 stimuli

dt = 2

# RNN settings (key order kept as in the original configuration)
params = {
    "nonlinearity": "tanh",
    "out_nonlinearity": "identity",
    "readout_kappa": False,
    "train_meanfield": True,
    "n_supports": 5,
    "train_cov": False,
    "rank": 2,
    "n_inp": 5,
    "p_inp": 1,
    "n_rec": 7,
    "p_rec": 1,
    "n_out": 1,
    "cov": None,
    "loadings": None,
    "cov_init_noise": 0,
    "apply_dale": False,
    "scale_w_inp": 1,
    "scale_w_out": 1,
    "scale_n": 1,
    "scale_m": 1,
    "1overN_out_scaling": True,
    "train_w_inp": False,
    "train_w_inp_scale": False,
    "train_w_rec": False,
    "train_b_rec": False,
    "train_m": False,
    "train_n": False,
    "train_taus": False,
    "train_w_out": False,
    "train_w_out_scale": False,
    "train_b_out": False,
    "train_x0": False,
    "tau_lims": [30],
    "dt": dt,
    "noise_std": 0.05,
    "scale_x0": 0.1,
    "randomise_x0": True,
    "orth_indices": [],
}

# Covariance settings, consumed by create_MF_covs_4Stims
cov_params = {
    # oscillator entries
    "osc_w": 10,
    "osc_r": 7,
    "osc_sdn": 50,
    "osc_sdm": 4,
    "osc_sdmW": 0,
    "osc_sdW": 16,
    # "coupl1" entries
    "coupl1_sdn1m2": 10,
    "coupl1_sdn2m1": 5,
    "coupl1_sdn1m1": 5,
    "coupl1_sdn2m2": -10,
    "coupl1_sdIosc": 25,
    "coupl1_sdm1": 25,
    "coupl1_sdm2": 1.25,
    # "coupl2" entries
    "coupl2_sdn1m2": -5,
    "coupl2_sdn2m1": -10,
    "coupl2_sdn1m1": -10,
    "coupl2_sdn2m2": 5,
    "coupl2_sdIosc": 25,
    "coupl2_sdm1": 25,
    "coupl2_sdm2": 1.25,
    # entries shared by the coupling components
    "coupl_sdIstim": 100,
    "coupl_sdIn": 15,
    "coupl_sdIm": 0,
    "coupl_sdn": 70,
    "coupl_sdmW": 2,
    "coupl_sdW": 16,
}

chol_covs = create_MF_covs_4Stims(cov_params, plot=False, vm=12)

params['dt'] = dt

rnn = RNN(params)
# Overwrite the Cholesky factors of the covariances in place, outside autograd
with torch.no_grad():
    chol_stack = torch.from_numpy(np.array(chol_covs))
    rnn.rnn.cov_chols.copy_(chol_stack)

In [None]:
# Plot the covariance matrices: the top-left 9x9 block (input, n and m loadings) of each
# of the five covariance components

titles=["oscillator","coupling a","coupling b","coupling c","coupling d"]

labels = ["$I_{osc}$", "$I_{s_a}$", "$I_{s_b}$","$I_{s_c}$","$I_{s_d}$", "$n_1$", "$n_2$", "$m_1$", "$m_2$"]

fig, axs = plot_covs(
    [(covch@covch.T)[:9,:9] for covch in chol_covs],
    vm=120,
    labels=labels,
    titles=titles,
    figsize=(6,3),
    dpi=100,
    fontsize=5,
    float_labels=False,
    atol=0.1,
    float_lims=5,
    label_fs=5,
    numbers_fs=5,
)
# Save into the configured figure directory rather than a second hard-coded path
plt.savefig(os.path.join(fig_dir, "covs4s.svg"),dpi=100)

In [None]:
# Break down the coupling function into constituent parts

def gain(z):
    """Approximate Gaussian average of the slope of tanh.

    Approximates E[tanh'(x)] for zero-mean Gaussian x with variance ``z``. tanh is
    replaced by an error function with the same slope at the origin, which gives
    the closed form 1 / sqrt(1 + (pi/2) * z).

    Args:
        z: variance of the pre-activation (scalar or numpy array, expected >= 0).

    Returns:
        The approximate mean gain, with the same shape as ``z``; equals 1 at z = 0.
    """
    return 1/np.sqrt(1+(np.pi/2)*z)

def extract_gain_linear(cov,r=.5,n_inp=3, n_grid = 30, w= 1, tau = 1):
    """Extract the linear part of the coupling function and the gain function.

    For one covariance component, evaluate two factors on an (n_grid x n_grid) grid.
    Rows are the latent angle ``theta`` and columns are the input phase ``phi``; both
    span [-pi, pi]. The latent state is kappa = r * (cos(theta), sin(theta)). For this
    component, the angular velocity d theta/dt is proportional to gain * linear (up to
    the component weight and time constant).

    The covariance is indexed as [I_0, ..., I_{n_inp-1}, n1, n2, m1, m2, ...]. Only the
    first input (index 0) is driven, with

        u = sin(phi - arctan(w*tau)) / sqrt(1 + (w*tau)**2),

    i.e. the steady-state response of tau*du/dt = -u + sin(phi).

    Args:
        cov: square covariance matrix of the loadings (numpy array).
        r: radius of the latent state on the grid (must be non-zero).
        n_inp: number of input loadings preceding n1 in ``cov``.
        n_grid: number of grid points per axis.
        w: angular frequency of the oscillatory input.
        tau: time constant of the filter applied to the oscillatory input.

    Returns:
        (grid_linear, grid_gain): arrays of shape (n_grid, n_grid), indexed
        [theta, phi].
        - ``grid_linear`` collects the n-m and n-I covariance terms.
        - ``grid_gain`` is ``gain`` evaluated at the variance of
          m1*k1 + m2*k2 + I_0*u.
    """
    n1, n2 = n_inp, n_inp+1
    m1, m2 = n_inp+2, n_inp+3

    # (Co)variances entering the pre-activation variance
    sdm1 = cov[m1, m1]
    sdm2 = cov[m2, m2]
    sdm1m2 = cov[m1, m2]
    sdI = cov[0, 0]
    sdIm1 = cov[0, m1]
    sdIm2 = cov[0, m2]
    # n-m and n-I covariances entering the linear part
    sdm1n1 = cov[n1, m1]
    sdm1n2 = cov[n2, m1]
    sdm2n1 = cov[n1, m2]
    sdm2n2 = cov[n2, m2]
    sdIn1 = cov[0, n1]
    sdIn2 = cov[0, n2]

    angles = np.linspace(-np.pi, np.pi, n_grid)
    theta = angles[:, None]  # varies along rows
    phi = angles[None, :]    # varies along columns

    wt = w*tau
    inp = 1/np.sqrt(1+wt**2) * np.sin(phi-np.arctan2(wt, 1))
    k1 = r*np.cos(theta)
    k2 = r*np.sin(theta)

    # Full variance of m1*k1 + m2*k2 + I*inp. The cross terms vanish when m1, m2 and I
    # are mutually uncorrelated, recovering the diagonal-only expression.
    z = (sdm1*k1**2 + sdm2*k2**2 + sdI*inp**2
         + 2*(sdm1m2*k1*k2 + sdIm1*k1*inp + sdIm2*k2*inp))

    # (k1*dk2 - k2*dk1)/r**2 with the gain factored out; cos^2, sin^2 and sin*cos are
    # rewritten via double angles
    grid_linear = (0.5*(sdm1n2-sdm2n1)
                   + 0.5*(sdm2n1+sdm1n2)*np.cos(2*theta)
                   + 0.5*(sdm2n2-sdm1n1)*np.sin(2*theta)
                   + (inp/r)*(np.cos(theta)*sdIn2 - np.sin(theta)*sdIn1))
    grid_gain = gain(z)

    return grid_linear,grid_gain



In [None]:
# The gain function is always > 0, so only the upper (warm) half of the diverging
# 'coolwarm' colormap is used for it.
# Adapted from: https://stackoverflow.com/a/18926541
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a colormap restricted to the [minval, maxval] fraction of ``cmap``.

    ``n`` evenly spaced colours are sampled from ``cmap`` over that interval and
    re-interpolated into a new LinearSegmentedColormap.
    """
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    cmap_label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)
    return colors.LinearSegmentedColormap.from_list(cmap_label, sampled_colors)

cmap = plt.get_cmap('coolwarm')
cw_half = truncate_colormap(cmap, 0.5, 1)


In [None]:
# Plot the coupling function deconstructed.
# Each panel is indexed [theta, phi]: rows (y-axis) are the latent angle and columns
# (x-axis) the input phase, both sampled on [-pi, pi] by extract_gain_linear.
# Layout:
#   - top row: components 1+2; bottom row: components 3+4
#   - columns: linear part, gain, gain * linear
#   - bottom-right panel: the total coupling function

n_grid = 60
w=16*np.pi
tau =.03
r = .4

# Linear part and gain of each coupling component (chol_covs[1..4])
coupling_grids = [
    extract_gain_linear(chol_covs[ind]@chol_covs[ind].T, r=r, n_inp=5,
                        n_grid=n_grid, w=w, tau=tau)
    for ind in (1, 2, 3, 4)
]
(grid_linear0, grid_gain0), (grid_linear1, grid_gain1) = coupling_grids[:2]
(grid_linear2, grid_gain2), (grid_linear3, grid_gain3) = coupling_grids[2:]

linear_rows = [grid_linear0 + grid_linear1, grid_linear2 + grid_linear3]
gain_rows = [grid_gain0 + grid_gain1, grid_gain2 + grid_gain3]
pop1 = grid_gain0*grid_linear0 + grid_gain1*grid_linear1
pop2 = grid_gain2*grid_linear2 + grid_gain3*grid_linear3
coupling_total = pop1 + pop2

# Each column shares one colorbar across both rows, so its limits must cover both rows
vm1 = max(np.max(np.abs(grid)) for grid in linear_rows)
vm2 = max(np.max(np.abs(grid)) for grid in gain_rows)
vm3 = max(np.max(np.abs(pop1)), np.max(np.abs(pop2)))
vm4 = np.max(np.abs(coupling_total))

fig,ax = plt.subplots(3,4,figsize=(6,4),gridspec_kw={'width_ratios': [1, 1, 1, 1.5], 'height_ratios': [1, 1, .07]})
for row, (linear, gain_sum, product) in enumerate(zip(linear_rows, gain_rows, (pop1, pop2))):
    im_linear = ax[row,0].imshow(linear, cmap='coolwarm', vmin=-vm1, vmax=vm1, origin='lower')
    im_gain = ax[row,1].imshow(gain_sum, cmap=cw_half, vmin=0, vmax=vm2, origin='lower')
    im_product = ax[row,2].imshow(product, cmap='coolwarm', vmin=-vm3, vmax=vm3, origin='lower')
fig.colorbar(im_linear, cax=ax[2,0], orientation='horizontal')
fig.colorbar(im_gain, cax=ax[2,1], orientation='horizontal')
fig.colorbar(im_product, cax=ax[2,2], orientation='horizontal')

# Total coupling function (bottom right; the panel above it is hidden)
im_total = ax[1,3].imshow(coupling_total, cmap='coolwarm', vmin=-vm4, vmax=vm4, origin='lower')
fig.colorbar(im_total, cax=ax[2,3], orientation='horizontal')

# Label the two extreme grid points.
# - Raw strings avoid the invalid "\p" escape.
# - The labels match the sampled range: index 0 is -pi, not 0.
edge_labels = [r"$-\pi$", r"$\pi$"]
for i, axi in enumerate(ax[:2]):
    for j, axij in enumerate(axi):
        axij.set_xticks([0,n_grid-1])
        axij.set_yticks([0,n_grid-1])
        axij.set_xticklabels(edge_labels if i == 1 else [])
        axij.set_yticklabels(edge_labels if j in (0, 3) else [])

ax[0,3].set_visible(False)

plt.savefig(os.path.join(fig_dir, "dyns.svg"),dpi=100)

In [None]:
# Run some simulations with the reduced (mean-field) equations: once without stimulus
# and once per stimulus channel, all from the same initial conditions

r_range = [.5]
# NOTE: with a step of 3*pi, np.arange over [-pi, pi) yields the single value -pi
phi_range = np.arange(-np.pi,np.pi,np.pi*3)
theta_range = np.arange(-np.pi,np.pi,np.pi/15)  # initial latent angles, pi/15 apart
tau =30
w= 16*np.pi
T = 4

# Create initial conditions and input
x0s, input_ICs, phases =  create_ICs_MF(r_range,phi_range,theta_range, tau, T, dt,w,n_inp=5)
total=len(x0s)
w_phases = wrap(phases)
st_start = 100  # first time step of the stimulus window
ste=500         # end (exclusive) of the stimulus window

# Also run simulations with stimulus input: a copy of the input with channel c (1..4)
# set to 1 inside the stimulus window; channel 0 is left untouched
stim_inputs = []
for channel in range(1, 5):
    stim_input = input_ICs.clone()
    stim_input[:, st_start:ste, channel] = 1
    stim_inputs.append(stim_input)
input_ICs_st1, input_ICs_st2, input_ICs_st3, input_ICs_st4 = stim_inputs

# Set the 5 component weights (presumably one per covariance component — confirm);
# they sum to 1. This is an in-place write into a model tensor, so do it outside
# autograd, as for cov_chols when the RNN was created.
with torch.no_grad():
    rnn.rnn.weights.copy_(torch.tensor([.25,.1875,.1875,.1875,.1875]))

ks0,_ = predict(rnn,input_ICs, x0=x0s)
stim_trajs = []
for stim_input in stim_inputs:
    ks_stim, _ = predict(rnn, stim_input, x0=x0s)
    stim_trajs.append(ks_stim)
ks1, ks2, ks3, ks4 = stim_trajs


In [None]:
# Make some plots of the attractors: first latent component against the wrapped phase,
# dropping the first t_st time steps (presumably the initial transient)
t_st = 1000
attractor_runs = [
    ("No input", ks0),
    ("Stim A", ks1),
    ("Stim B", ks2),
    ("Stim C", ks3),
    ("Stim D", ks4),
]
for title, ks_run in attractor_runs:
    plt.figure()
    plt.title(title)
    for i in range(len(phases)):
        # NOTE(review): ks is sliced [t_st:-1] but phases [t_st:], i.e. ks appears to
        # hold one more time step than phases — confirm the intended alignment
        plt.scatter(wrap(phases[i,t_st:]),ks_run[i,t_st:-1,0],alpha=.2)
    plt.xlabel("phase (wrapped)")
    plt.ylabel("$k_1$")
