In [None]:
import numpy as np
import os, sys
sys.path.append(os.getcwd()+"/..")
from rnn_scripts.train import *
from rnn_scripts.utils import *
from tasks.seqDS import *

import matplotlib.pyplot as plt
%matplotlib inline

from mayavi import mlab
mlab.init_notebook()

# Colours shared by every panel below.
cls = green_blue_colours()
turq = cls[1]
purple = [channel/255 for channel in (91, 59, 179)]  # 8-bit RGB scaled to 0-1

# The file unpacks into four entries; only the last two are kept. They are
# later assigned to Mayavi LUT tables in the torus panels.
_,_,lut1,lut2 = np.load('../data/luts.npy')


In [None]:
# Load models
# Figure output folder and checkpoint folder, both located one level above
# the current working directory.
fig_dir = f"{os.getcwd()}/../figures/"
model_dir = f"{os.getcwd()}/../models/"

# Main model: uncomment exactly one of the three checkpoints.
# model = "N512_T0217-141442"  # rat 1
model = "N512_T0217-151523"  # rat 2
# model = "N512_T0217-151542"  # rat 3

# Second model used for the *_alt panels.
model_alt = "N512_T0221-113711"  # alternative solution

# load_rnn returns a 4-tuple per checkpoint: network, params, task params,
# training params.
rnn, params, task_params, training_params = load_rnn(f"{model_dir}{model}")
rnn_alt, params_alt, task_params_alt, training_params_alt = load_rnn(f"{model_dir}{model_alt}")

In [None]:
# Some preprocessing / extracting parameters and trajectories
# Step size handed to set_dt below (units are not visible from this cell).
dt =2
# NOTE(review): svd_orth, set_dt and make_deterministic are project helpers
# whose bodies are not shown here; presumably they orthogonalise the
# low-rank weights, apply `dt` to the task/model and switch off noise —
# confirm. Only `rnn` (never `rnn_alt`) is passed to them.
rnn.rnn.svd_orth()
set_dt(task_params,rnn,dt)
make_deterministic(task_params, rnn)
# I, n, m, W, alphas and I_orth are not referenced by any later cell in
# this notebook section.
I,n,m,W = extract_loadings(rnn, orth_I=False,split=True)
alphas, I_orth = orthogonolise_Im(I,m)

# Overrides the probe duration in the shared task settings before the
# trajectories are extracted.
task_params["probe_dur"]=2
# get_traj returns three values per model; the plotting cells below index
# `rates` as [trace, sample, unit] and `ks` as [trace, coordinate, sample].
# NOTE(review): the alternative model is run with the main model's
# task_params (task_params_alt is never used) — confirm this is intended.
ks,phases,rates = get_traj(rnn,task_params,freq=8,amp_scale=1)
ks_alt,phases_alt,rates_alt = get_traj(rnn_alt,task_params,freq=8,amp_scale=1)


In [None]:
# NOTE(review): these two calls repeat, with identical arguments, the
# trajectory extraction that ends the preceding cell. The cell looks
# redundant unless get_traj is non-deterministic or the intent is to reset
# `ks` / `ks_alt` after the in-place edits made by the torus cells —
# confirm before removing.
ks,phases,rates = get_traj(rnn,task_params,freq=8,amp_scale=1)
ks_alt,phases_alt,rates_alt = get_traj(rnn_alt,task_params,freq=8,amp_scale=1)


In [None]:
# Panel A: phase-coding RNN rates
# 2x2 grid of single-unit traces. Each subplot overlays the two traces in
# `rates` (index 0 in purple, index 1 in turquoise) for one randomly drawn
# unit. The x-axis covers all samples and is labelled 0 .. 2*pi.

period = len(rates[0])          # number of samples per trace
n_neurons = rates.shape[-1]     # unit count taken from the data, not hard-coded
rng = np.random.default_rng(0)  # fixed seed: the same units are drawn on every run
fig,axs = plt.subplots(2,2,figsize=(2,1.5))
ylim=1.05

for i,ax in enumerate(axs.flatten()):

    neuron_ind = rng.integers(n_neurons)
    ax.plot(rates[0,:,neuron_ind],color=purple)
    ax.plot(rates[1,:,neuron_ind],color=turq)

    ax.set_xlim(0,period)
    ax.set_ylim(-ylim,ylim)

    if i<2:
        # Top row: no x-axis at all.
        ax.spines['bottom'].set_visible(False)
        ax.set_xticks([])
    else:
        # Bottom row: phase ticks at start, middle and end of the trace.
        ax.set_xticks([0,period/2,period])
        ax.set_xticklabels([r'$0$',r'$\pi$',r'$2\pi$'],fontsize=6)

    if i%2==1:
        # Right column: no y-axis (removing the ticks also removes the labels).
        ax.spines['left'].set_visible(False)
        ax.set_yticks([])
    else:
        ax.set_yticks([-1,0,1],fontsize=6,labels =[r'$-1$',r'$0$',r'$1$'])

    # Thin axes, offset slightly away from the data.
    for side in ('bottom','left'):
        ax.spines[side].set_position(('outward',3))
        ax.spines[side].set_linewidth(0.5)
    ax.tick_params(axis='both', which='major', width=.5, length=1, pad=.5)

plt.tight_layout()
# fig_dir already ends with a separator; os.path.join avoids the doubled
# slash that fig_dir+"/rates.svg" produced.
plt.savefig(os.path.join(fig_dir,"rates.svg"))

In [None]:
# Panel D: rate-coding RNN rates
# Same 2x2 layout as the phase-coding rates panel, drawn from `rates_alt`.
# All sizes are derived from `rates_alt` itself, so this cell does not rely
# on `period` left behind by another cell.

period_alt = len(rates_alt[0])      # number of samples per trace
n_neurons_alt = rates_alt.shape[-1] # unit count taken from the data, not hard-coded
rng = np.random.default_rng(0)      # fixed seed: the same units are drawn on every run
fig,axs = plt.subplots(2,2,figsize=(2,1.5))
ylim=1.05

for i,ax in enumerate(axs.flatten()):

    neuron_ind = rng.integers(n_neurons_alt)
    ax.plot(rates_alt[0,:,neuron_ind],color=purple)
    ax.plot(rates_alt[1,:,neuron_ind],color=turq)

    ax.set_xlim(0,period_alt)
    ax.set_ylim(-ylim,ylim)

    if i<2:
        # Top row: no x-axis at all.
        ax.spines['bottom'].set_visible(False)
        ax.set_xticks([])
    else:
        # Bottom row: phase ticks at start, middle and end of the trace.
        ax.set_xticks([0,period_alt/2,period_alt])
        ax.set_xticklabels([r'$0$',r'$\pi$',r'$2\pi$'],fontsize=6)

    if i%2==1:
        # Right column: no y-axis (removing the ticks also removes the labels).
        ax.spines['left'].set_visible(False)
        ax.set_yticks([])
    else:
        ax.set_yticks([-1,0,1],fontsize=6,labels =[r'$-1$',r'$0$',r'$1$'])

    # Thin axes, offset slightly away from the data.
    for side in ('bottom','left'):
        ax.spines[side].set_position(('outward',3))
        ax.spines[side].set_linewidth(0.5)
    ax.tick_params(axis='both', which='major', width=.5, length=1, pad=.5)

plt.tight_layout()
plt.savefig(os.path.join(fig_dir,"rates_alt.svg"))


In [None]:
# Panel B: phase-coding RNN kappa space
# Both trajectories in `ks` are drawn in the plane of their first two
# coordinates (ks[trace, 0] against ks[trace, 1]), with a black dot marking
# sample `ph_i` on each.

# Marker settings. The Panel E cell reads ph_i, s and a, so the names are kept.
ph_i=10   # index of the highlighted sample
s=20      # scatter marker size
a=.8      # scatter marker opacity

fig,ax = plt.subplots(figsize=(2,2))
ax.plot(ks[0,0], ks[0,1],color = purple)
ax.plot(ks[1,0], ks[1,1],color = turq)
for trace in (1,0):
    ax.scatter(ks[trace,0,ph_i],ks[trace,1,ph_i],color='black',zorder=2,s=s,alpha=a)

# Bare panel: no ticks, no bottom/left axis lines.
ax.set_xticks([])
ax.set_yticks([])
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)

# fig_dir already ends with a separator; os.path.join avoids the doubled
# slash that fig_dir+"/kspace.svg" produced.
plt.savefig(os.path.join(fig_dir,"kspace.svg"))

In [None]:
# Panel E: rate-coding RNN kappa space
# Trajectories of the alternative model in the plane of the first two
# coordinates of ks_alt. The highlighted sample index (ph_i) and the marker
# size/opacity (s, a) come from the Panel B cell, which must run first.

plt.figure(figsize=(2,2))
ax = plt.gca()

# One line per trace: index 0 in purple, index 1 in turquoise.
for trace, colour in ((0, purple), (1, turq)):
    ax.plot(ks_alt[trace,0], ks_alt[trace,1], color=colour)

# Black dot at sample ph_i on each trace (trace 1 first, then trace 0).
for trace in (1, 0):
    ax.scatter(ks_alt[trace,0,ph_i], ks_alt[trace,1,ph_i],
               color='black', zorder=2, s=s, alpha=a)

# Bare panel: no ticks, no bottom/left axis lines.
ax.set_xticks([])
ax.set_yticks([])
for side in ('bottom', 'left'):
    ax.spines[side].set_visible(False)

plt.savefig(fig_dir+"kspace_alt.svg")








In [None]:
# Panel C: phase-coding RNN toroid coordinates
# Renders each trajectory in `ks` as a coloured tube in 3D via the project
# helper `tor`, plus a faint copy of the same tube on a floor plane.
# floor, r, tw, torus, luts and cvs defined here are reused by the Panel F cell.

floor=-3     # z-level of the floor plane (torus footprint and tube shadows)
r=1.3*1.3    # passed to def_torus and tor — presumably the major radius, confirm
r_s=.7*1.3   # passed to def_torus only — presumably the minor radius, confirm
tw=.1        # tube radius of the plotted trajectories

mlab.clf()
fig = mlab.figure(size = (1600,1600),\
            bgcolor = (1,1,1), fgcolor = (0.5, 0.5, 0.5))
# Lookup tables indexed by trajectory: trajectory 0 gets lut2, trajectory 1 gets lut1.
luts = [lut2,lut1]

#plot surfaces
torus=def_torus(r,r_s)
# In-place edit of `ks`: the last sample is overwritten with the first one,
# presumably so the plotted curve closes on itself.
ks[:,:,-1]=ks[:,:,0]
# Scalar attached to every sample; the LUT maps it to the tube colour.
cvs = np.sin((phases-np.pi))

# Only the x/y coordinates of the torus are used; z is constant at `floor`,
# so this draws a flat, almost transparent grey footprint.
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor
          , color=(0.5,0.5,0.5), opacity=0.05)

for i in range(len(ks)):
    # Note the argument order: coordinate 1 is passed before coordinate 0.
    x1,y1,z1 = tor(ks[i,1],ks[i,0], phases-np.pi*0.5,r)
    surf1 = mlab.plot3d(x1, y1, z1,
                        cvs, tube_radius=tw, colormap='cool')

    #set colormap
    # The 'cool' colormap above is only a placeholder; its table is replaced here.
    surf1.module_manager.scalar_lut_manager.lut.table = luts[i]

    #set lighting
    surf1.actor.property.lighting = False

    # Shadow: same x/y path, flattened onto the floor plane.
    sh_surf1 = mlab.plot3d(x1,y1, floor*np.ones_like(phases),
                        cvs, tube_radius=tw, colormap='cool')

    # Copy so the shared LUT is not modified; column 3 is set to 20, which is
    # presumably the alpha channel of an RGBA table (near-transparent) — confirm.
    shadow_lut1 = np.copy(luts[i])
    shadow_lut1[:,3]=20
    sh_surf1.module_manager.scalar_lut_manager.lut.table = shadow_lut1
    sh_surf1.actor.property.lighting = False

mlab.draw()
# Camera: azimuth -20, elevation 60, distance 15, looking at the origin.
mlab.view(-20, 60, 15,  
          np.array([0, 0, 0]))
# NOTE(review): a single point at the origin as the cell's last expression —
# presumably only there so the notebook displays the scene; confirm.
mlab.plot3d(0,0,0)


In [None]:
# Panel F: rate-coding RNN toroid coordinates
# Same rendering as the Panel C cell, applied to the alternative model's
# trajectories (ks_alt). This cell relies on torus, floor, r, tw, luts and
# cvs left in the namespace by the Panel C cell, which must run first.

mlab.clf()
fig = mlab.figure(size = (1600,1600),\
            bgcolor = (1,1,1), fgcolor = (0.5, 0.5, 0.5))

# In-place edit of `ks_alt`: the last sample is overwritten with the first
# one, presumably so the plotted curve closes on itself.
ks_alt[:,:,-1]=ks_alt[:,:,0]

# Flat, almost transparent grey footprint of the torus on the floor plane.
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor
          , color=(0.5,0.5,0.5), opacity=0.05)

for i in range(len(ks_alt)):
    # NOTE(review): uses `phases` (and `cvs`, derived from it) from the main
    # model rather than `phases_alt` — harmless only if both are equal; confirm.
    x1,y1,z1 = tor(ks_alt[i,1],ks_alt[i,0], phases-np.pi*0.5,r)
    surf1 = mlab.plot3d(x1, y1, z1,
                        cvs, tube_radius=tw, colormap='cool')


    #set colormap
    # The 'cool' colormap above is only a placeholder; its table is replaced here.
    surf1.module_manager.scalar_lut_manager.lut.table = luts[i]

    #set lighting
    surf1.actor.property.lighting = False

    # Shadow: same x/y path, flattened onto the floor plane.
    sh_surf1 = mlab.plot3d(x1,y1, floor*np.ones_like(phases),
                        cvs, tube_radius=tw, colormap='cool')

    # Copy so the shared LUT is not modified; column 3 is set to 20, which is
    # presumably the alpha channel of an RGBA table (near-transparent) — confirm.
    shadow_lut1 = np.copy(luts[i])
    shadow_lut1[:,3]=20
    sh_surf1.module_manager.scalar_lut_manager.lut.table = shadow_lut1
    sh_surf1.actor.property.lighting = False

mlab.draw()
# Camera: azimuth -20, elevation 60, distance 15, looking at the origin.
mlab.view(-20, 60, 15,  
          np.array([0, 0, 0]))
# NOTE(review): a single point at the origin as the cell's last expression —
# presumably only there so the notebook displays the scene; confirm.
mlab.plot3d(0,0,0)