In [None]:
# import things
%matplotlib notebook
import tensorflow as tf
from tensorflow import keras
from mpl_toolkits import mplot3d
import numpy as np
from numpy import matlib
from scipy import signal
from scipy.spatial import distance
from scipy.stats import norm
from scipy.stats import percentileofscore
import scipy.linalg as linalg

import random
import matplotlib.pyplot as plt
import os

from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.metrics import pairwise_distances
from scipy.optimize import least_squares
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import normalize

from dPCA import dPCA

In [None]:
# import desired emg or velocity trajectories
predict_emg = 1
do_emg_norm = 1
go_cue_idx = 20
hold_idx = 10

%run ./get_stim_dynamics_inputs.ipynb
x_train,y_train,x_test,y_test,degs = get_input_data(predict_emg,do_emg_norm,go_cue_idx,hold_idx)
plt.plot(y_train[1,:,:]);

# params
%run Param_classes.ipynb
args = InitialParams()   
stim_params = StimParams()

# more verbose params
if(predict_emg):
    l2_weight = 1e-3
    activation_weight = 1e-7
    args.lambda_l = 1e-6
    dropout_rate = 0.5
    conn_prob = 1
    args.latent_shape = [30,30]
    lr = 0.0001
    n_epochs = 3000
else:
    l2_weight = 1e-3
    activation_weight = 1e-3
    args.lambda_l = 1e-5
    dropout_rate = 0.5
    conn_prob = 1
    args.latent_shape = [20,20]
    lr = 0.0001
    n_epochs = 2000    

In [None]:
%run ./RNN_model.ipynb

num_units = args.latent_shape[0]*args.latent_shape[1]
temp_lat = lateral_effect(args)   

input_layer = keras.Input(shape=[None,x_train.shape[-1]],batch_size=x_train.shape[0])
cell = SimpleRNN_pos_loss(num_units,activation="tanh",conn_prob = conn_prob,
                              kernel_regularizer=keras.regularizers.l2(l2_weight),activation_weight=activation_weight)
RNN_layer, state_h = keras.layers.RNN(cell,return_sequences=True,return_state=True,dynamic=True)(input_layer)                                   
dropout_layer = keras.layers.Dropout(dropout_rate)(RNN_layer)
output_layer = keras.layers.TimeDistributed(keras.layers.Dense(y_test.shape[2]))(dropout_layer)

model = tf.keras.Model(inputs=input_layer,outputs=output_layer)
stim_params.is_stim=False
optimizer = keras.optimizers.Adam(learning_rate=lr)


early_stopping_cb = keras.callbacks.EarlyStopping(patience=n_epochs/5,
                                                  restore_best_weights=True,monitor="loss")

model.compile(loss="mse",optimizer=optimizer,metrics=[keras.metrics.MeanSquaredError()])

In [None]:
history = model.fit(
    x_train,
    y_train,
    epochs=n_epochs,
    verbose=True,
    callbacks=[early_stopping_cb],
)

# Auxiliary model returning every layer's output; used below to read out the RNN activations.
activation_model = keras.models.Model(
    inputs=model.input,
    outputs=[lyr.output for lyr in model.layers],
)
layer_outputs = activation_model.outputs if False else [lyr.output for lyr in model.layers]

In [None]:
# summary plots
# Load the helper notebooks into this namespace; they presumably define the analysis/plot
# functions used below (get_pds, make_activation_grid, plot_activation_grids, ...) -- confirm.
%run ./RNN_analysis_functions.ipynb
%run ./RNN_plot_functions.ipynb

In [None]:
# Summary of the trained network on the test set: preferred-direction (PD) statistics,
# the PD map over the unit grid, predicted vs. true outputs, and activation-grid snapshots.
color_list = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'tab:purple']

go_cue_offset = [10, 30]              # PD window, in samples after the go cue
activations = activation_model.predict(x_test)
rnn_activations = activations[1]      # RNN-layer output
y_pred = activations[-1]              # readout output
pd_window_start = go_cue_idx + go_cue_offset[0]
pd_window_end = go_cue_idx + go_cue_offset[1]
pd_data, depth_data = get_pds(rnn_activations, pd_window_start, pd_window_end)

plt.figure()
for panel, values in ((1, pd_data), (2, depth_data)):
    plt.subplot(2, 2, panel)
    plt.hist(values, 20)
plt.subplot(2, 2, 3)

# Paint each unit's PD at its (row, col) position on the grid.
pd_grid = np.zeros(args.latent_shape)
pos_lis = cell.pos.astype("int")
for unit_idx, unit_pd in enumerate(pd_data):
    pd_grid[pos_lis[unit_idx, 0], pos_lis[unit_idx, 1]] = unit_pd
plt.imshow(pd_grid)

# Output channels to inspect (EMG channels vs. the two velocity components).
out_idx = [12, 14, 16, 18] if predict_emg == 1 else [0, 1]

plt.figure()
for panel_idx, channel in enumerate(out_idx, start=1):
    plt.subplot(2, 2, panel_idx)
    for tgt_dir_idx in range(y_test.shape[0]):
        tgt_color = color_list[tgt_dir_idx]
        plt.plot(y_pred[tgt_dir_idx, :, channel], '--', color=tgt_color)   # prediction
        plt.plot(y_test[tgt_dir_idx, :, channel], '-', color=tgt_color)    # target

rnn_activation_grid = make_activation_grid(rnn_activations, cell, args)
# Activation-grid snapshots at pairs of time points, offset from the go cue.
for first_offset, second_offset in ((0, 10), (15, 20), (30, 35)):
    idx_plot = [go_cue_idx + first_offset, go_cue_idx + second_offset]
    plot_activation_grids(rnn_activation_grid, idx_plot)

In [None]:
# Compare the activation map against shuffled maps to see whether the units are spatially organised.
# NOTE(review): num_shuffles is not passed explicitly; presumably evaluate_topography reads it -- confirm.
num_shuffles = 500

activations = activation_model.predict(x_test)
rnn_activations = activations[1]
# Recompute the readout here so this cell does not depend on y_pred leaking from an earlier cell.
y_pred = activations[-1]
act_loss, shuffled_loss, act_loss_perc = evaluate_topography(rnn_activations)

plt.figure()
plt.subplot(2, 1, 1)
plt.plot(100 - act_loss_perc)  # bigger = more organization
plt.ylabel("100 - act_loss_perc")
plt.subplot(2, 1, 2)
plt.plot(y_pred[0, :, 0])
plt.ylabel("output 0, target 0")
plt.xlabel("time index");

In [None]:
# Example stimulation experiment: stimulate one grid location and compare with the unstimulated run.
stim_params = StimParams()
args.noise_val = 0.0
stim_params.stim_pos = np.array([12, 12])
stim_params.stim_dist_tau = 2
stim_params.stim_time = [go_cue_idx + 10, go_cue_idx + 15]  # [stim start, stim end] time indices

(rnn_activations_stim, rnn_activations_no_stim,
 rnn_activation_grid_stim, rnn_activation_grid_no_stim,
 y_pred_stim, y_pred_no_stim) = run_example_stim_exp()

rnn_activation_grid_diff = rnn_activation_grid_stim - rnn_activation_grid_no_stim

# 3x3 subplot slots for targets 0-3: right, top, left, bottom of the grid.
subplot_idx = [6, 2, 4, 8]


def _plot_grid_diff_at(time_idx):
    """Open one figure showing (stim - no stim) activation grids at `time_idx` for targets 0-3."""
    plt.figure()
    for tgt_i, slot in enumerate(subplot_idx):
        plt.subplot(3, 3, slot)
        plt.imshow(rnn_activation_grid_diff[tgt_i, time_idx, :, :])
        plt.colorbar()


# Activation difference at stim onset, 2 samples into the stim, and 1 sample after stim end.
for diff_time in (stim_params.stim_time[0], stim_params.stim_time[0] + 2, stim_params.stim_time[1] + 1):
    _plot_grid_diff_at(diff_time)

# Predicted outputs without (solid) and with (dashed) stimulation; stim window marked.
plt.figure()
out_idx = [0, 1, 2, 3, 4, 5, 6, 7]
color_list = ['k', 'r', 'b', 'g']
for counter, musc_idx in enumerate(out_idx):
    plt.subplot(4, 2, counter + 1)
    for tgt_dir_idx in range(y_pred_no_stim.shape[0]):
        plt.plot(y_pred_no_stim[tgt_dir_idx, :, musc_idx], color=color_list[tgt_dir_idx])
        plt.plot(y_pred_stim[tgt_dir_idx, :, musc_idx], linestyle='--', color=color_list[tgt_dir_idx])
    # Mark stim start AND end (previously stim_time[0] was drawn twice and the end never appeared).
    plt.axvline(stim_params.stim_time[0], linewidth=0.3)
    plt.axvline(stim_params.stim_time[1], linewidth=0.3)

# Example neuron -- picked from the stim electrode location via get_recorded_neuron.
rec_neuron = get_recorded_neuron(stim_params.stim_pos, cell.pos)

print(cell.pos[rec_neuron, :])
print(pd_data[rec_neuron])

plt.figure()
unit = rec_neuron
for tgt_idx in range(rnn_activations_no_stim.shape[0]):
    plt.subplot(3, 3, tgt_idx + 1)
    plt.plot(np.transpose(rnn_activations_no_stim[tgt_idx, :, unit]), color=color_list[tgt_idx])
    plt.plot(np.transpose(rnn_activations_stim[tgt_idx, :, unit]), '--', color=color_list[tgt_idx])
    plt.axvline(stim_params.stim_time[0], linewidth=0.3)
    plt.axvline(stim_params.stim_time[1], linewidth=0.3)

In [None]:
# Sweep stimulation location, spatial spread (tau, used as "amplitude") and duration,
# then summarise stim effect and recovery.
stim_params = StimParams()
# Stim locations: 5x5 grid spanning [0, latent_shape] on each axis (row k -> (x[k % 5], y[k // 5])).
# NOTE(review): the upper end equals latent_shape itself, one past the last 0-based unit index -- confirm intended.
x = np.linspace(0, args.latent_shape[0], 5, dtype=np.float32)
y = np.linspace(0, args.latent_shape[1], 5, dtype=np.float32)
xv, yv = np.meshgrid(x, y)
stim_pos_test = np.column_stack((xv.ravel(), yv.ravel()))  # (25, 2): columns are x, y
stim_dist_tau_test = [0.5, 1, 2, 4, 8]
stim_duration_test = [2, 4, 8, 12]

emg_msd, rnn_msd, emg_recov, emg_stim_effect, rnn_recov, rnn_stim_effect = run_stim_exp()

# Average over axes 3 and 0 (stim positions per the original note), leaving a (tau, duration) table.
# NOTE(review): the axis layout is defined by run_stim_exp -- confirm which axes these are.
emg_recov_mean = np.mean(np.mean(emg_recov, axis=3), axis=0)
emg_stim_effect_mean = np.mean(np.mean(emg_stim_effect, axis=3), axis=0)
rnn_recov_mean = np.mean(np.mean(rnn_recov, axis=3), axis=0)
rnn_stim_effect_mean = np.mean(np.mean(rnn_stim_effect, axis=3), axis=0)


def _scatter_vs_tau_duration(values, zlabel):
    """New figure with a 3-D scatter of values[i_tau, i_dur] against stim tau and stim duration."""
    plt.figure()
    ax3d = plt.axes(projection='3d')
    for i_tau, tau in enumerate(stim_dist_tau_test):
        for i_dur, dur in enumerate(stim_duration_test):
            # `marker` must be a keyword: scatter3D's 4th positional parameter is `zdir`.
            ax3d.scatter3D(tau, dur, values[i_tau, i_dur], marker='.', color='k')
    ax3d.set_xlabel("stim tau")
    ax3d.set_ylabel("stim duration")
    ax3d.set_zlabel(zlabel)
    return ax3d


# effect of amplitude and stim duration
ax = _scatter_vs_tau_duration(rnn_stim_effect_mean, "stim effect")
# recovery with different amplitudes and stim durations
ax = _scatter_vs_tau_duration(rnn_recov_mean, "recovery")

# Example MSD for rnn and emg (fixed indices into run_stim_exp's output; TODO document the axes).
plt.figure()
plt.subplot(2, 1, 1)
plt.plot(np.transpose(rnn_msd[:, -1, 0, 2, :]))
plt.subplot(2, 1, 2)
plt.plot(np.transpose(emg_msd[:, -1, 0, 2, :]));

In [None]:
# Dynamical analysis: PCA on RNN-layer activations, with and without stimulation.
# Activations are tgt_dir x time x neuron.
# NOTE: the full time window is used (no truncation to the go cue); slice the *_trunc arrays if wanted.
stim_params = StimParams()
args.noise_val = 0.0
stim_params.stim_pos = np.array([10, 10])
stim_params.stim_dist_tau = 5
stim_params.stim_time = [go_cue_idx + 10, go_cue_idx + 15]

(rnn_activations_stim, rnn_activations_no_stim,
 rnn_activation_grid_stim, rnn_activation_grid_no_stim,
 y_pred_stim, y_pred_no_stim) = run_example_stim_exp()

rnn_no_stim_trunc = rnn_activations_no_stim[:, :, :]
rnn_stim_trunc = rnn_activations_stim[:, :, :]
n_tgt, n_time, n_neurons = rnn_no_stim_trunc.shape

# Fit PCA on (tgt*time, neuron) samples from the unstimulated run, then project both runs.
rnn_activations_no_stim_flat = rnn_no_stim_trunc.reshape(-1, n_neurons)
rnn_activations_stim_flat = rnn_stim_trunc.reshape(-1, rnn_stim_trunc.shape[-1])

pca = PCA(n_components=10)
pca.fit(rnn_activations_no_stim_flat)

# Restore (tgt, time, component) from the activations' own shapes rather than len(degs), so the
# reshape cannot silently interleave targets and time if the two counts ever disagree.
rnn_pca_no_stim = pca.transform(rnn_activations_no_stim_flat).reshape(n_tgt, n_time, -1)
rnn_pca_stim = pca.transform(rnn_activations_stim_flat).reshape(
    rnn_stim_trunc.shape[0], rnn_stim_trunc.shape[1], -1)

dims = [0, 1]
# All targets overlaid on one axis.
plt.figure()
for tgt_idx in range(n_tgt):
    plot_pca_traces(rnn_pca_no_stim, tgt_idx, dims)
    plot_pca_traces(rnn_pca_stim, tgt_idx, dims, is_stim=True)

# One subplot per target.
plt.figure()
for tgt_idx in range(n_tgt):
    plt.subplot(3, 3, tgt_idx + 1)
    plot_pca_traces(rnn_pca_no_stim, tgt_idx, dims)
    plot_pca_traces(rnn_pca_stim, tgt_idx, dims, is_stim=True)
    
    


In [None]:
# dPCA on RNN activations.
# dPCA takes X as [n, t, s, ...] with n = neurons, t = time, s = stimulus (here: target direction);
# no decision axis is used, so labels = 'ts' (one character per parameter axis).
# n_components is left at its default here (it may be an int, or a dict per marginalization).
stim_params = StimParams()
args.noise_val = 0.0
stim_params.stim_pos = np.array([0, 0])
stim_params.stim_dist_tau = 1
stim_params.stim_time = [go_cue_idx + 10, go_cue_idx + 15]

(rnn_activations_stim, rnn_activations_no_stim,
 rnn_activation_grid_stim, rnn_activation_grid_no_stim,
 y_pred_stim, y_pred_no_stim) = run_example_stim_exp()

# (tgt, time, neuron) -> (neuron, time, tgt)
axis_order = (2, 1, 0)
X_no_stim = np.transpose(rnn_activations_no_stim, axis_order)
X_stim = np.transpose(rnn_activations_stim, axis_order)

labels = 'ts'
dpca = dPCA.dPCA(labels)
dpca.protect = ['t']  # time axis excluded from shuffling (see dPCA `protect`)
Z_no_stim = dpca.fit_transform(X_no_stim)  # fitted on the unstimulated run only
Z_stim = dpca.transform(X_stim)

for i_dim in range(3):
    plt.figure()
    plot_dpca_traces(Z_no_stim, 0, 't', i_dim)
    plot_dpca_traces(Z_stim, 1, 't', i_dim)

# Stimulus-marginalization trajectories (component dims[0] vs dims[1]) per target, unstimulated run only.
dims = [0, 1]
x_comp, y_comp = dims
s_no_stim = Z_no_stim['s']
plt.figure()
for i_tgt in range(Z_stim['s'].shape[-1]):
    plt.plot(s_no_stim[x_comp, :, i_tgt], s_no_stim[y_comp, :, i_tgt])

In [None]:
# Potent and Null space analysis
# get a mapping from neuron activation to muscle activation
explained_var_thresh = 0.75

%run ./RNN_analysis_functions.ipynb

activations = activation_model.predict(x_test)
neural_act = activations[1] 
muscle_act = activations[-1]

neural_act_flat = flatten_activation_mat(neural_act)
muscle_act_flat = flatten_activation_mat(muscle_act)

# reduce dimensions -- match explained variance across muscles and neurons
neural_pca = PCA()
muscle_pca = PCA()
neural_pca.fit(neural_act_flat)
muscle_pca.fit(muscle_act_flat)

n_neural_comp = np.argwhere(np.cumsum(neural_pca.explained_variance_ratio_) >= explained_var_thresh)[0][0]
n_muscle_comp = np.argwhere(np.cumsum(muscle_pca.explained_variance_ratio_) >= explained_var_thresh)[0][0]

neural_pca = PCA(n_components=n_neural_comp)
muscle_pca = PCA(n_components=n_muscle_comp)
neural_act_low_dim = unflatten_activation_mat(neural_pca.fit_transform(neural_act_flat),neural_act)
muscle_act_low_dim = unflatten_activation_mat(muscle_pca.fit_transform(muscle_act_flat),muscle_act)

# project
null_space, potent_space = get_null_potent_space(neural_act_low_dim, muscle_act_low_dim) # neural dims, muscle dims
neural_null_proj = project_into_space(neural_act_low_dim, null_space)
neural_pot_proj = project_into_space(neural_act_low_dim, potent_space)

# stimulate and project activations into potent and null space -- do this for different amplitudes and stim locations
stim_params = StimParams()
args.noise_val = 0.0
stim_params.stim_time = [go_cue_idx+10,go_cue_idx+15]
amps_test = np.array([0.5,1,2,5,10])
# get locs test
x = np.linspace(0, args.latent_shape[0], 5, dtype=np.float32)
y = np.linspace(0, args.latent_shape[1], 5, dtype=np.float32)
xv, yv = np.meshgrid(x, y)
xv = np.reshape(xv, (xv.size, 1))
yv = np.reshape(yv, (yv.size, 1))
stim_pos_test = np.hstack((xv, yv))

null_dist = np.zeros((amps_test.shape[0],stim_pos_test.shape[0],neural_act.shape[1]))
pot_dist = np.zeros_like(null_dist)

neural_stim_null_proj_all = np.zeros((amps_test.shape[0],stim_pos_test.shape[0],\
                                      neural_act.shape[0],neural_act.shape[1],\
                                      n_neural_comp-n_muscle_comp))
neural_stim_pot_proj_all = np.zeros((amps_test.shape[0],stim_pos_test.shape[0],\
                                     neural_act.shape[0],neural_act.shape[1],\
                                      n_muscle_comp))

for i_amp in range(amps_test.shape[0]):
    for i_loc in range(stim_pos_test.shape[0]):
        stim_params.stim_pos = stim_pos_test[i_loc,:]
        stim_params.stim_dist_tau = amps_test[i_amp]


        neural_act_stim,neural_act_no_stim,\
            neural_act_grid_stim,neural_act_grid_no_stim,\
            musc_pred_stim,musc_pred_no_stim = run_example_stim_exp()

        # lower dimensionality then project stim data into spaces
        neural_act_stim_flat = flatten_activation_mat(neural_act_stim)
        neural_act_stim_low_d =  unflatten_activation_mat(neural_pca.transform(neural_act_stim_flat),neural_act_stim)

        neural_stim_null_proj = project_into_space(neural_act_stim_low_d,null_space)
        neural_stim_pot_proj = project_into_space(neural_act_stim_low_d,potent_space)

        # compute distance moved from normal trajectory during stim condition in null and potent space across time and store
        null_dist[i_amp,i_loc,:] = np.sum(np.sum(np.square(neural_stim_null_proj - neural_null_proj),axis=2),axis=0)
        pot_dist[i_amp,i_loc,:] = np.sum(np.sum(np.square(neural_stim_pot_proj - neural_pot_proj),axis=2),axis=0)
        
        # store all proj data
        neural_stim_null_proj_all[i_amp,i_loc,:,:,:] = neural_stim_null_proj
        neural_stim_pot_proj_all[i_amp,i_loc,:,:,:] = neural_stim_pot_proj

In [None]:
# Analyse the distance data:
#   - time course: does the potent-space deviation recover more slowly than the null-space one?
#   - magnitude after stim end relative to during stim (a post-stim hump might imply planning)

# Example condition: amps_test[4] at stim_pos_test[12] (the centre of the 5x5 location grid).
example_amp_idx, example_loc_idx = 4, 12
plt.figure()
plt.plot(null_dist[example_amp_idx, example_loc_idx, :], color='k')
plt.plot(pot_dist[example_amp_idx, example_loc_idx, :], color='r')
stim_start, stim_end = stim_params.stim_time[0], stim_params.stim_time[1]
for stim_edge in (stim_start, stim_end):
    plt.axvline(stim_edge, linewidth=0.5)

# Compare null vs potent distance across amps and sites, summed over three time windows:
# the whole trial, [stim_start, stim_end), and stim_end+1 .. stim_end+4.
for time_window in (slice(None), slice(stim_start, stim_end), slice(stim_end + 1, stim_end + 5)):
    summed_null_dist = np.sum(null_dist[:, :, time_window], axis=2)
    summed_pot_dist = np.sum(pot_dist[:, :, time_window], axis=2)
    plot_space_dist(summed_null_dist, summed_pot_dist)

In [None]:
# Plot the null-space projection against the potent-space projection (component `dim`, target `tgt`).
# NOTE(review): neural_stim_null_proj / neural_stim_pot_proj hold whatever the last iteration of the
# stim sweep left behind (last amplitude, last location) -- confirm that is the intended condition.
dim = 0
tgt = 2
null_trace = neural_stim_null_proj[tgt, :, dim]
pot_trace = neural_stim_pot_proj[tgt, :, dim]
stim_window = slice(stim_params.stim_time[0], stim_params.stim_time[1])

plt.figure()
plt.plot(null_trace, pot_trace)                            # full trajectory
plt.plot(null_trace[stim_window], pot_trace[stim_window])  # stim window on top
for marked_time in (0, go_cue_idx, 50):                    # trial start, go cue, time index 50
    plt.plot(null_trace[marked_time], pot_trace[marked_time], markersize=20, marker='.')

In [None]:
# For one target and one stim condition, overlay the unstimulated (solid) and stimulated (dashed)
# projections: potent space on top, null space below. Hold, go cue and stim window are marked.
tgt = 2
amp = 3
loc = 10

panels = (
    (neural_pot_proj, neural_stim_pot_proj_all),
    (neural_null_proj, neural_stim_null_proj_all),
)
event_times = (hold_idx, go_cue_idx, stim_params.stim_time[0], stim_params.stim_time[1])

plt.figure()
for panel_idx, (baseline_proj, stim_proj_all) in enumerate(panels, start=1):
    plt.subplot(2, 1, panel_idx)
    plt.plot(baseline_proj[tgt, :, :])
    plt.gca().set_prop_cycle(None)  # reuse the same colours for the stim traces
    plt.plot(stim_proj_all[amp, loc, tgt, :, :], linestyle='--')
    for event_time in event_times:
        plt.axvline(event_time, linewidth=0.5)