In [None]:
import numpy as np
import os, sys
import pickle
# The project root must be on sys.path before the project-local star imports below.
sys.path.append(os.getcwd()+"/..")
from rnn_scripts.train import *
from rnn_scripts.utils import *
from rnn_scripts.bifurcations import *
from tasks.seqDS import *
from scipy.cluster.vq import kmeans2

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1 import make_axes_locatable

from mayavi import mlab
mlab.init_notebook()

# Shared colours and lookup tables (LUTs) used by the figures in this notebook.
cls = green_blue_colours()
turq = cls[1]  # second entry of the project palette
purple = [channel/255 for channel in (91, 59, 179)]  # 8-bit RGB rescaled to [0, 1]
# The file must unpack into exactly four entries; only the last two are kept.
_,_,lut1,lut2 = np.load('../data/luts.npy')


In [None]:
# Load models
# Output / input directories, expressed relative to the current working directory.
fig_dir = f"{os.getcwd()}/../figures/"
model_dir = f"{os.getcwd()}/../models/"

# Pick the trained network to analyse (uncomment exactly one `model` line).
#model = "N512_T0217-141442" #rat 1
model = "N512_T0217-151523" #rat 2
#model = "N512_T0217-151542" #rat 3
model_alt = "N512_T0221-113711" #alternative solution

# load_rnn returns the network together with its parameter dictionaries.
rnn, params, task_params, training_params = load_rnn(model_dir + model)


In [None]:
# Some preprocessing / extracting parameters
# NOTE(review): every helper below comes from the star-imported project modules; their
# exact semantics are not visible in this notebook. The calls act on `rnn` / `task_params`
# and look order-dependent, so keep the order as is.
dt =.5  # time step handed to set_dt; presumably ms, since `period` is later computed as (1000/freq)/dt — TODO confirm
rnn.rnn.svd_orth()
set_dt(task_params,rnn,dt)
make_deterministic(task_params, rnn)
# I, n, m, W: loadings extracted from the network; I and m are reused below
# (m as the projection basis passed to `proj`) — presumably input / output / connectivity vectors, TODO confirm.
I,n,m,W = extract_loadings(rnn, orth_I=False,split=True)
# I_orth is what create_ICs receives later; presumably I with its overlap on m removed — TODO confirm.
alphas, I_orth = orthogonolise_Im(I,m)

In [None]:
# Create trajectories for Poincare Maps

T = 2
tau = rnn.rnn.tau
freq = task_params['freq']
period = int((1000/freq)/dt)  # number of simulation steps per oscillation cycle
w = np.pi*2*freq
rad = calculate_mean_radius(freq, rnn)

# Grid of initial conditions: radius x phi x theta
r_range = np.arange(rad, rad+0.21, .2)
phi_range = np.arange(-np.pi, np.pi, np.pi/4)
theta_range = [-0.5*np.pi]
total = len(r_range)*len(phi_range)*len(theta_range)

# Trajectories without stimulus (these also provide the initial states reused below)
x0s, input_ICs, phases = create_ICs(r_range, phi_range, theta_range, tau, T, dt, w, m, I_orth)
rates_ICs, _ = predict(rnn.cpu(), input_ICs, x0=x0s)

# Trajectories with stimulus 1 and stimulus 2, started from the same initial states
_, input_ICs_st1, phases_st1 = create_ICs(r_range, phi_range, theta_range, tau, T, dt, w, m, I_orth, stim_ind=1)
rates_ICs_st1, _ = predict(rnn.cpu(), input_ICs_st1, x0=x0s)
_, input_ICs_st2, phases_st2 = create_ICs(r_range, phi_range, theta_range, tau, T, dt, w, m, I_orth, stim_ind=2)
rates_ICs_st2, _ = predict(rnn.cpu(), input_ICs_st2, x0=x0s)

# Project rates on M: one (2, n_steps) projection per initial condition
proj_shape = (total, 2, len(input_ICs[0])+1)
Ks = np.zeros(proj_shape)
Ks_st1 = np.zeros(proj_shape)
Ks_st2 = np.zeros(proj_shape)

for traj in range(total):
    Ks[traj] = np.array(proj(m, rates_ICs[traj]))
    Ks_st1[traj] = np.array(proj(m, rates_ICs_st1[traj]))
    Ks_st2[traj] = np.array(proj(m, rates_ICs_st2[traj]))

In [None]:
# Make Poincare map and cluster
# Each condition is paired with the phases returned by its own create_ICs call.
# (pm_st1 previously used the no-stimulus `phases`, leaving `phases_st1` unused and
# disagreeing with how pm_st2 is built.)
pm = poincare_map(Ks,period,0,phases[:,0],ph0=0)
pm_st1 = poincare_map(Ks_st1,period,0,phases_st1[:,0],ph0=0)
pm_st2 = poincare_map(Ks_st2,period,0,phases_st2[:,0],ph0=0)

# Split the initial conditions into two groups by clustering the last slice of the
# no-stimulus map; the fixed seed keeps the label assignment reproducible.
_, labels = kmeans2(pm[:,:,-1],2,seed=3)

# One plotting colour per initial condition: label 0 -> turquoise, label 1 -> purple.
colors = [turq if lab == 0 else purple for lab in labels]


In [None]:
# Poincare map plots
# One small square panel per selected map index, all sharing the same symmetric limits.
l = np.max(pm)*1.1
ds = 30  # marker size
for i in (0, 1, 13):
    fig = plt.figure(figsize=(2,2))
    ax = fig.gca()
    ax.scatter(pm[:,1,i], pm[:,0,i], color=colors, alpha=1, s=ds)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim(-l, l)
    ax.set_ylim(-l, l)

    fig.savefig(f"{fig_dir}PM{i}.svg")

In [None]:
# Poincare map plots with stimulus input

# Common axis limits and marker size for both stimulus panels.
l = np.max([pm_st1,pm_st2])*1.1
ds = 60

# (map, marker colour, output file): only the first initial condition at map index 10 is shown.
stim_panels = [
    (pm_st1, colors[0], "PM_S1.svg"),
    (pm_st2, purple, "PM_S2.svg"),
]
for stim_pm, stim_color, stim_file in stim_panels:
    fig = plt.figure(figsize=(2,2))
    plt.scatter(stim_pm[0,1,10], stim_pm[0,0,10], color=stim_color, alpha=1, s=ds)
    plt.xticks([])
    plt.yticks([])
    plt.xlim(-l, l)
    plt.ylim(-l, l)

    plt.savefig(fig_dir + stim_file)





In [None]:
# Poincare map trajectory plots, C=0

mlab.clf()
fig = mlab.figure(size=(1600,1600), bgcolor=(1,1,1), fgcolor=(0.5, 0.5, 0.5))

# Time window: first half cycle
T_start = 0
T_end = period//2
window = slice(T_start, T_end)

# Scene geometry (also read by the C=1 and C=13 cells)
floor = -2.7   # height of the shadow plane
r = 1.3        # torus radius handed to def_torus / tor
r_s = .7       # second torus radius handed to def_torus
tw = .02       # tube radius of the plotted trajectories

# Create floor: the torus outline flattened onto the plane z = floor
torus = def_torus(r, r_s)
m_color = (0.5,0.5,0.5)
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor,
          color=(0.5,0.5,0.5), opacity=0.05)

# Create Poincare section: a flat patch in the x = 0 plane
surf = mlab.mesh(np.array([[0,0],[0,0]]),
                 np.array([[-r-r_s,-r-r_s],[-r+r_s,-r+r_s]]),
                 np.array([[-r_s,r_s],[-r_s,r_s]]),
                 color=m_color, opacity=1)
surf.actor.property.lighting = False

# Create Poincare shadow on floor: the same patch squeezed to a thin strip at z = floor
surf = mlab.mesh(np.array([[0,0],[0,0]]),
                 np.array([[-r-r_s,-r-r_s],[-r+r_s,-r+r_s]]),
                 np.array([[floor,floor+0.01],[floor,floor+0.01]]),
                 color=(0.5,0.5,0.5), opacity=1)
surf.actor.property.lighting = False

# Create luts for trajectories and shades:
# the alpha column is replaced by a quadratic ramp scaled to 0..255
trans_lut1 = np.copy(lut1)
trans_lut1[:,3] = np.flip(np.linspace(1,0,len(trans_lut1))**2)*255
trans_lut2 = np.copy(lut2)
trans_lut2[:,3] = np.flip(np.linspace(1,0,len(trans_lut2))**2)*255

# Shadows reuse the same ramp at 2% of the opacity
shadow_lut1 = np.copy(trans_lut1)
shadow_lut1[:,3] *= 0.02
shadow_lut2 = np.copy(trans_lut2)
shadow_lut2[:,3] *= 0.02

# Create trajectories
for ind in range(total):

    k = Ks[ind]
    ki = wrap(phases[ind])
    ki_win = ki[window]
    cvs = np.sin(ki_win-np.pi)  # scalar that indexes into the colour table
    x,y,z = tor(k[1, window], k[0, window], ki_win, r=r)

    # trajectory
    surf1 = mlab.plot3d(x, y, z, cvs, tube_radius=tw, colormap='cool')
    surf1.actor.property.lighting = False

    # shadow: same x/y path, pinned to the floor
    sh_surf1 = mlab.plot3d(x, y, floor*np.ones_like(ki_win),
                           cvs, tube_radius=tw, colormap='cool')
    sh_surf1.actor.property.lighting = False

    # set colormap according to the cluster label of this initial condition
    if labels[ind]==0:
        traj_lut, shade_lut = trans_lut1, shadow_lut1
    else:
        traj_lut, shade_lut = trans_lut2, shadow_lut2
    surf1.module_manager.scalar_lut_manager.lut.table = traj_lut
    sh_surf1.module_manager.scalar_lut_manager.lut.table = shade_lut

mlab.draw()
mlab.view(-20, 60, 15, np.array([0, 0, 0]))
mlab.plot3d(0,0,0)

In [None]:
# Poincare map trajectory plots, C=1
# Relies on torus, floor, r, r_s, tw, m_color and the trans_/shadow_ luts set up in the C=0 cell.
mlab.clf()
fig = mlab.figure(size=(1600,1600), bgcolor=(1,1,1), fgcolor=(0.5, 0.5, 0.5))

# Time window: a third of a cycle on either side of the first full period
T_start = period-period//3
T_end = period+period//3
window = slice(T_start, T_end)

# Create floor
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor,
          color=(0.5,0.5,0.5), opacity=0.05)

# Create Poincare section
surf = mlab.mesh(np.array([[0,0],[0,0]]),
                 np.array([[-r-r_s,-r-r_s],[-r+r_s,-r+r_s]]),
                 np.array([[-r_s,r_s],[-r_s,r_s]]),
                 color=m_color, opacity=1)
surf.actor.property.lighting = False

# Create Poincare shadow on floor
surf = mlab.mesh(np.array([[0,0],[0,0]]),
                 np.array([[-r-r_s,-r-r_s],[-r+r_s,-r+r_s]]),
                 np.array([[floor,floor+0.01],[floor,floor+0.01]]),
                 color=(0.5,0.5,0.5), opacity=1)
surf.actor.property.lighting = False

# Create trajectories
for ind in range(total):

    k = Ks[ind]
    ki = wrap(phases[ind])
    ki_win = ki[window]
    cvs = np.sin(ki_win-np.pi)  # scalar that indexes into the colour table
    x,y,z = tor(k[1, window], k[0, window], ki_win, r=r)

    # trajectory
    surf1 = mlab.plot3d(x, y, z, cvs, tube_radius=tw, colormap='cool')
    surf1.actor.property.lighting = False

    # shadow: same x/y path, pinned to the floor
    sh_surf1 = mlab.plot3d(x, y, floor*np.ones_like(ki_win),
                           cvs, tube_radius=tw, colormap='cool')
    sh_surf1.actor.property.lighting = False

    # set colormap according to the cluster label of this initial condition
    if labels[ind]==0:
        traj_lut, shade_lut = trans_lut1, shadow_lut1
    else:
        traj_lut, shade_lut = trans_lut2, shadow_lut2
    surf1.module_manager.scalar_lut_manager.lut.table = traj_lut
    sh_surf1.module_manager.scalar_lut_manager.lut.table = shade_lut

mlab.draw()
mlab.view(-20, 60, 15, np.array([0, 0, 0]))
mlab.plot3d(0,0,0)






In [None]:
# Poincare map trajectory plots, C=13
# Relies on torus, floor, r, r_s, tw and m_color set up in the C=0 cell.
mlab.clf()
fig = mlab.figure(size=(1600,1600), bgcolor=(1,1,1), fgcolor=(0.5, 0.5, 0.5))

# Time window: the full cycle from period 13 to period 14
T_start = period*13
T_end = period*14
window = slice(T_start, T_end)

# Create floor
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor,
          color=(0.5,0.5,0.5), opacity=0.05)

# Create Poincare section
surf = mlab.mesh(np.array([[0,0],[0,0]]),
                 np.array([[-r-r_s,-r-r_s],[-r+r_s,-r+r_s]]),
                 np.array([[-r_s,r_s],[-r_s,r_s]]),
                 color=m_color, opacity=1)
surf.actor.property.lighting = False

# Create Poincare shadow on floor
surf = mlab.mesh(np.array([[0,0],[0,0]]),
                 np.array([[-r-r_s,-r-r_s],[-r+r_s,-r+r_s]]),
                 np.array([[floor,floor+0.01],[floor,floor+0.01]]),
                 color=(0.5,0.5,0.5), opacity=1)
surf.actor.property.lighting = False

# Shadow luts built from the original (non-faded) luts, at 2% of their opacity.
# Note: this rebinds shadow_lut1 / shadow_lut2 from the C=0 cell.
shadow_lut1 = np.copy(lut1)
shadow_lut1[:,3] *= 0.02
shadow_lut2 = np.copy(lut2)
shadow_lut2[:,3] *= 0.02

# Create trajectories
for ind in range(total):

    k = Ks[ind]
    ki = wrap(phases[ind])
    ki_win = ki[window]
    cvs = np.sin(ki_win-np.pi)  # scalar that indexes into the colour table
    x,y,z = tor(k[1, window], k[0, window], ki_win, r=r)

    # trajectory
    surf1 = mlab.plot3d(x, y, z, cvs, tube_radius=tw, colormap='cool')
    surf1.actor.property.lighting = False

    # shadow: same x/y path, pinned to the floor
    sh_surf1 = mlab.plot3d(x, y, floor*np.ones_like(ki_win),
                           cvs, tube_radius=tw, colormap='cool')
    sh_surf1.actor.property.lighting = False

    # set colormap: original luts for the trajectory, faint copies for the shadow
    if labels[ind]==0:
        traj_lut, shade_lut = lut1, shadow_lut1
    else:
        traj_lut, shade_lut = lut2, shadow_lut2
    surf1.module_manager.scalar_lut_manager.lut.table = traj_lut
    sh_surf1.module_manager.scalar_lut_manager.lut.table = shade_lut

mlab.draw()
mlab.view(-20, 60, 15, np.array([0, 0, 0]))
mlab.plot3d(0,0,0)







In [None]:
# Check if stimulus-induced-bifurcation data for this model is available
stim_bifur_avail = False
file = f"../data/stim_bifur_dat-{model}.pkl"
if not os.path.isfile(file):
    print("NO STIMULUS BIFURCATION DATA AVAILABLE")
    print("Run rnn_scripts/run_bifurcations_stimulus.py first")
else:
    # NOTE: pickle can execute arbitrary code on load; only open files you generated yourself.
    with open(file, "rb") as f:
        bifur_dat = pickle.load(f)
    # The flag is raised only after the data has been read successfully.
    stim_bifur_avail = True

In [None]:
# Plot the eigenvalues /floquet multipliers for the stimulus-induced bifurcation

if stim_bifur_avail:
    # Extract data. The projections are bound to `stim_Ks` (not `Ks`) so that the
    # trajectory array `Ks` used by the Poincare-map cells is not overwritten; the
    # cluster colour list `colors` is likewise left untouched by this cell.
    evs = bifur_dat["evs"]
    stim_Ks = bifur_dat["Ks"]
    amps = bifur_dat["amps"]
    e0l0,e0l1,e1l0,e1l1 = get_evs_stim_bifurc(stim_Ks,evs)

    lw = 4       # line width of the multiplier curves
    lw_dot = 2   # line width of the dotted marker line

    # One panel per stimulus:
    # (curve drawn only up to its maximum, its colour, curve drawn over the full range, its colour, file)
    panels = [
        (e0l0, turq, e0l1, purple, "FM_S1.svg"),   # bifurcation from stimulus 1
        (e1l1, purple, e1l0, turq, "FM_S2.svg"),   # bifurcation from stimulus 2
    ]
    for cut_curve, cut_color, full_curve, full_color, fname in panels:
        plt.figure(figsize=(4,4))

        # Index of the curve's maximum; marked by the dotted vertical line and used as
        # the last plotted point (presumably where this solution disappears — TODO confirm).
        bifurt = np.argmax(cut_curve)

        plt.axvline(amps[bifurt],ls=':',color='black',lw=lw_dot)
        plt.plot(amps[:bifurt+1],cut_curve[:bifurt+1],color=cut_color,zorder=-1,lw=lw)
        plt.plot(amps[:-1],full_curve[:-1],color=full_color,zorder=0,lw=lw)

        # Ticks only at the axis ends, without labels.
        plt.ylim(0,1)
        plt.yticks(np.arange(0,1.01,1),labels=[])
        plt.xticks(np.arange(0,0.251,.25),labels=[])
        plt.xlim(0,.25)
        plt.savefig(fig_dir + fname)



In [None]:
# check if amplitude, frequency induced bifurcation data is available
bifur_avail = False
# NOTE(review): the model-specific path is computed but not used — the fixed file below
# is the one that is actually checked and loaded. Confirm which one is intended.
model_bifur_file = "../data/bifur_dat-"+model+'.pkl'
file = "../data/300pix_bifur_dat.pkl"
if not os.path.isfile(file):
    print("NO BIFURCATION DATA AVAILABLE")
    print("Run rnn_scripts/run_bifurcations.py first")
else:
    # NOTE: pickle can execute arbitrary code on load; only open files you generated yourself.
    with open(file, "rb") as f:
        bifur_dat = pickle.load(f)
    # The flag is raised only after the data has been read successfully.
    bifur_avail = True

In [None]:
# Load or extract training statistics for the plot
file = "../data/train_stats_rat2.npy"
if not os.path.isfile(file):
    # No cached statistics: extract them first, this might take a while.
    # NOTE(review): data_path is a placeholder; point it at the training data before running.
    data_path = "_"
    # `file` is passed as the last argument and loaded right after, so get_dataset_stats
    # is expected to write its result there — TODO confirm.
    get_dataset_stats(task_params,training_params, data_path,2/1000, file)

# Load in both cases. (The freshly extracted statistics were previously loaded but never
# assigned, leaving `train_stats` undefined whenever the cache file was missing.)
train_stats = np.load(file)

# Rows of the statistics array, in the order they are stored.
amp_stats = train_stats[0]
pow_stats = train_stats[1]
freq_stats = train_stats[2]


In [None]:
# Create the bifurcation plot

if bifur_avail:
    evs = bifur_dat["evs"]
    freqs = np.flip(bifur_dat["freqs"])  # flipped to match the flipped image rows below
    amps = bifur_dat["amps"]

    # plot the maximum Floquet multiplier for each frequency and amplitude
    # (the maximum is taken over the last axis twice: here and again in imshow)
    evs_max = np.max(evs,axis=-1)
    vmax = 1
    ytick_ind = 60   # label every 60th frequency
    xtick_ind = 60   # label every 60th amplitude

    fig, axs = plt.subplots(1,1, figsize=(8,8))
    im = axs.imshow(np.flip(np.max(evs_max,axis=-1),axis=0),cmap='GnBu',vmax=vmax,
                    interpolation='None')
    divider = make_axes_locatable(axs)
    cax = divider.append_axes('right', size='5%', pad=0.25)
    fig.colorbar(im, cax=cax)
    axs.set_yticks(np.arange(len(freqs))[::ytick_ind])
    axs.set_yticklabels(['{:.1f}'.format(i) for i in freqs[::ytick_ind]])
    axs.set_xticks(np.arange(len(amps))[::xtick_ind])
    axs.set_xticklabels(['{:.1f}'.format(i) for i in amps[::xtick_ind]])

    # Add a box to the plot denoting parameters seen during training:
    # it spans the 5th to 95th percentile of the training amplitudes and frequencies.
    # arg_is_close presumably returns the grid index closest to the given value — TODO confirm.
    x1 = arg_is_close(amps,np.percentile(amp_stats,5))
    x2 = arg_is_close(amps,np.percentile(amp_stats,95))
    width = x2-x1
    y1 = arg_is_close(freqs,np.percentile(freq_stats,5))
    y2 = arg_is_close(freqs,np.percentile(freq_stats,95))
    height = y2-y1 # differences of grid positions (original note: needed as log scale)
    rect = patches.Rectangle([x1, y1], width, height, linewidth=2, edgecolor='black', facecolor='none',ls='--')
    axs.add_patch(rect)

    # Add red dots to indicate which models we plot the trajectories of:
    # the training mean first, then four hand-picked (amplitude, frequency) pairs.
    marked_points = [
        (np.mean(amp_stats), np.mean(freq_stats)),
        (0.27, 8.4),
        (0.27, 6.2),
        (0.7, 4.5),
        (0.5, 2.2),
    ]
    for amp_val, freq_val in marked_points:
        xm = arg_is_close(amps,amp_val)
        ym = arg_is_close(freqs,freq_val)
        axs.scatter(xm,ym,color='red')

    # Save next to the other figures instead of a machine-specific absolute path.
    plt.savefig(fig_dir + "Bif.svg")

In [None]:
# Plot trajectories for the models we selected
task_params["probe_dur"]=3
ks,phases,_ = get_traj(rnn,task_params,freq=8.4,amp_scale=1)

# Scene geometry
v_scale = 1      # vertical stretch applied to the trajectory
floor = -3       # height of the shadow plane
r = 1.3*1.3      # torus radius handed to def_torus / tor
r_s = .7*1.3     # second torus radius handed to def_torus
tw = .1          # tube radius

mlab.clf()
fig = mlab.figure(size=(1600,1600), bgcolor=(1,1,1), fgcolor=(0.5, 0.5, 0.5))

# One colour table per trajectory (also read by the following trajectory cells)
luts = [lut1,lut2]

# plot floor
torus = def_torus(r,r_s)
cvs = np.sin(phases)  # scalar that indexes into the colour table
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor,
          color=(0.5,0.5,0.5), opacity=0.05)

# plot trajectories
for i, k_traj in enumerate(ks):
    x1,y1,z1 = tor(k_traj[1], k_traj[0], phases, r)

    # trajectory with its own colour table, lighting switched off
    surf1 = mlab.plot3d(x1, y1, z1*v_scale, cvs, tube_radius=tw, colormap='cool')
    surf1.module_manager.scalar_lut_manager.lut.table = luts[i]
    surf1.actor.property.lighting = False

    # shadow: same x/y path on the floor, same colours with a constant low alpha
    sh_surf1 = mlab.plot3d(x1, y1, floor*np.ones_like(phases),
                           cvs, tube_radius=tw, colormap='cool')
    shadow_lut1 = np.copy(luts[i])
    shadow_lut1[:,3] = 20
    sh_surf1.module_manager.scalar_lut_manager.lut.table = shadow_lut1
    sh_surf1.actor.property.lighting = False

mlab.draw()
mlab.view(-60, 60, 15, np.array([0, 0, 0]))
mlab.plot3d(0,0,0)

In [None]:
# Trajectories for freq=6.2, amp_scale=.27 (uses `luts` defined in the first trajectory cell)
ks,phases,_ = get_traj(rnn,task_params,freq=6.2,amp_scale=.27)

# Scene geometry
v_scale = 1      # vertical stretch applied to the trajectory
floor = -3       # height of the shadow plane
r = 1.3*1.3      # torus radius handed to def_torus / tor
r_s = .7*1.3     # second torus radius handed to def_torus
tw = .1          # tube radius

mlab.clf()
fig = mlab.figure(size=(1600,1600), bgcolor=(1,1,1), fgcolor=(0.5, 0.5, 0.5))

# plot floor
torus = def_torus(r,r_s)
cvs = np.sin(phases)  # scalar that indexes into the colour table
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor,
          color=(0.5,0.5,0.5), opacity=0.05)

# plot trajectories
for i, k_traj in enumerate(ks):
    x1,y1,z1 = tor(k_traj[1], k_traj[0], phases, r)

    # trajectory with its own colour table, lighting switched off
    surf1 = mlab.plot3d(x1, y1, z1*v_scale, cvs, tube_radius=tw, colormap='cool')
    surf1.module_manager.scalar_lut_manager.lut.table = luts[i]
    surf1.actor.property.lighting = False

    # shadow: same x/y path on the floor, same colours with a constant low alpha
    sh_surf1 = mlab.plot3d(x1, y1, floor*np.ones_like(phases),
                           cvs, tube_radius=tw, colormap='cool')
    shadow_lut1 = np.copy(luts[i])
    shadow_lut1[:,3] = 20
    sh_surf1.module_manager.scalar_lut_manager.lut.table = shadow_lut1
    sh_surf1.actor.property.lighting = False

mlab.draw()
mlab.view(-60, 60, 15, np.array([0, 0, 0]))
mlab.plot3d(0,0,0)

In [None]:
# Trajectory for freq=8.4, amp_scale=.27 (uses `luts` defined in the first trajectory cell)
ks,phases,_ = get_traj(rnn,task_params,freq=8.4,amp_scale=.27)

# Scene geometry
v_scale = 1      # vertical stretch applied to the trajectory
floor = -3       # height of the shadow plane
r = 1.3*1.3      # torus radius handed to def_torus / tor
r_s = .7*1.3     # second torus radius handed to def_torus
tw = .1          # tube radius

mlab.clf()
fig = mlab.figure(size=(1600,1600), bgcolor=(1,1,1), fgcolor=(0.5, 0.5, 0.5))

# plot floor
torus = def_torus(r,r_s)
cvs = np.sin(phases)  # scalar that indexes into the colour table
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor,
          color=(0.5,0.5,0.5), opacity=0.05)

# plot trajectory: only the first one is drawn for this parameter setting
i = 0
x1,y1,z1 = tor(ks[i,1], ks[i,0], phases, r)

# trajectory with its own colour table, lighting switched off
surf1 = mlab.plot3d(x1, y1, z1*v_scale, cvs, tube_radius=tw, colormap='cool')
surf1.module_manager.scalar_lut_manager.lut.table = luts[i]
surf1.actor.property.lighting = False

# shadow: same x/y path on the floor, same colours with a constant low alpha
sh_surf1 = mlab.plot3d(x1, y1, floor*np.ones_like(phases),
                       cvs, tube_radius=tw, colormap='cool')
shadow_lut1 = np.copy(luts[i])
shadow_lut1[:,3] = 20
sh_surf1.module_manager.scalar_lut_manager.lut.table = shadow_lut1
sh_surf1.actor.property.lighting = False

mlab.draw()
mlab.view(-60, 60, 15, np.array([0, 0, 0]))
mlab.plot3d(0,0,0)


In [None]:
# Trajectories for freq=4.5, amp_scale=.7 (uses `luts` defined in the first trajectory cell)
ks,phases,_ = get_traj(rnn,task_params,freq=4.5,amp_scale=.7)

# Scene geometry
v_scale = 3      # vertical stretch applied to the trajectory
floor = -1.5     # height of the shadow plane
r = 1            # torus radius handed to def_torus / tor
r_s = .5         # second torus radius handed to def_torus
tw = .1/2        # tube radius

mlab.clf()
fig = mlab.figure(size=(1600,1600), bgcolor=(1,1,1), fgcolor=(0.5, 0.5, 0.5))

# plot floor
torus = def_torus(r,r_s)
cvs = np.sin(phases)  # scalar that indexes into the colour table
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor,
          color=(0.5,0.5,0.5), opacity=0.05)

# plot trajectories
for i, k_traj in enumerate(ks):
    x1,y1,z1 = tor(k_traj[1], k_traj[0], phases, r)

    # trajectory with its own colour table, lighting switched off
    surf1 = mlab.plot3d(x1, y1, z1*v_scale, cvs, tube_radius=tw, colormap='cool')
    surf1.module_manager.scalar_lut_manager.lut.table = luts[i]
    surf1.actor.property.lighting = False

    # shadow: same x/y path on the floor, same colours with a constant low alpha
    sh_surf1 = mlab.plot3d(x1, y1, floor*np.ones_like(phases),
                           cvs, tube_radius=tw, colormap='cool')
    shadow_lut1 = np.copy(luts[i])
    shadow_lut1[:,3] = 20
    sh_surf1.module_manager.scalar_lut_manager.lut.table = shadow_lut1
    sh_surf1.actor.property.lighting = False

mlab.draw()
mlab.view(-60, 60, 15, np.array([0, 0, 0]))
mlab.plot3d(0,0,0)

In [None]:
# Trajectories for freq=2.2, amp_scale=.5 (uses `luts` defined in the first trajectory cell)
ks,phases,_ = get_traj(rnn,task_params,freq=2.2,amp_scale=.5)

# Scene geometry
v_scale = 3      # vertical stretch applied to the trajectory
floor = -1.5     # height of the shadow plane
r = 1            # torus radius handed to def_torus / tor
r_s = .5         # second torus radius handed to def_torus
tw = .1/2        # tube radius

mlab.clf()
fig = mlab.figure(size=(1600,1600), bgcolor=(1,1,1), fgcolor=(0.5, 0.5, 0.5))

# plot floor
torus = def_torus(r,r_s)
cvs = np.sin(phases)  # scalar that indexes into the colour table
mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor,
          color=(0.5,0.5,0.5), opacity=0.05)

# plot trajectories
for i, k_traj in enumerate(ks):
    x1,y1,z1 = tor(k_traj[1], k_traj[0], phases, r)

    # trajectory with its own colour table, lighting switched off
    surf1 = mlab.plot3d(x1, y1, z1*v_scale, cvs, tube_radius=tw, colormap='cool')
    surf1.module_manager.scalar_lut_manager.lut.table = luts[i]
    surf1.actor.property.lighting = False

    # shadow: same x/y path on the floor, same colours with a constant low alpha
    sh_surf1 = mlab.plot3d(x1, y1, floor*np.ones_like(phases),
                           cvs, tube_radius=tw, colormap='cool')
    shadow_lut1 = np.copy(luts[i])
    shadow_lut1[:,3] = 20
    sh_surf1.module_manager.scalar_lut_manager.lut.table = shadow_lut1
    sh_surf1.actor.property.lighting = False

mlab.draw()
mlab.view(-60, 60, 15, np.array([0, 0, 0]))
mlab.plot3d(0,0,0)