In [1]:
# import things
%matplotlib notebook
import tensorflow as tf
from tensorflow import keras
from mpl_toolkits import mplot3d
import numpy as np
from numpy import matlib
from scipy import signal
from scipy.spatial import distance
from scipy.stats import norm
from scipy.stats import percentileofscore
import random
import matplotlib.pyplot as plt
import os

from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.metrics import pairwise_distances
from scipy.optimize import least_squares

from dPCA import dPCA

In [2]:
# import desired emg or velocity trajectories
predict_emg = 1
do_emg_norm = 1
go_cue_idx = 20
hold_idx = 10

%run ./get_stim_dynamics_inputs.ipynb
x_train,y_train,x_test,y_test = get_input_data(predict_emg,do_emg_norm,go_cue_idx,hold_idx)
plt.plot(y_train[1,:,:]);

# params
%run Param_classes.ipynb
args = InitialParams()   
stim_params = StimParams()

# more verbose params
if(predict_emg):
    l2_weight = 1e-3
    activation_weight = 1e-7
    args.lambda_l = 1e-6
    dropout_rate = 0.5
    conn_prob = 1
    args.latent_shape = [30,30]
    lr = 0.0001
    n_epochs = 2500
else:
    l2_weight = 1e-3
    activation_weight = 1e-3
    args.lambda_l = 1e-5
    dropout_rate = 0.5
    conn_prob = 1
    args.latent_shape = [20,20]
    lr = 0.0001
    n_epochs = 2000    

<IPython.core.display.Javascript object>

In [4]:
%run ./RNN_model.ipynb

num_units = args.latent_shape[0]*args.latent_shape[1]
temp_lat = lateral_effect(args)   

input_layer = keras.Input(shape=[None,x_train.shape[-1]],batch_size=x_train.shape[0])
cell = SimpleRNN_pos_loss(num_units,activation="tanh",conn_prob = conn_prob,
                              kernel_regularizer=keras.regularizers.l2(l2_weight),activation_weight=activation_weight)
RNN_layer, state_h = keras.layers.RNN(cell,return_sequences=True,return_state=True,dynamic=True)(input_layer)                                   
dropout_layer = keras.layers.Dropout(dropout_rate)(RNN_layer)
output_layer = keras.layers.TimeDistributed(keras.layers.Dense(y_test.shape[2]))(dropout_layer)

model = tf.keras.Model(inputs=input_layer,outputs=output_layer)
stim_params.is_stim=False
optimizer = keras.optimizers.Adam(learning_rate=lr)


early_stopping_cb = keras.callbacks.EarlyStopping(patience=n_epochs/5,
                                                  restore_best_weights=True,monitor="loss")

model.compile(loss="mse",optimizer=optimizer,metrics=[keras.metrics.MeanSquaredError()])

In [None]:
# Train with the early-stopping callback configured in the model-building cell.
history = model.fit(x_train, y_train,
                    epochs=n_epochs,
                    verbose=True,
                    callbacks=[early_stopping_cb])
# Side model exposing every layer's output. Downstream cells take index 1 as the
# RNN sequence and index -1 as the readout.
# NOTE(review): the RNN layer was built with return_state=True, so its .output is a
# [sequence, state] pair; index 1 being the sequence relies on Keras flattening the
# outputs list — confirm.
layer_outputs = list(map(lambda lyr: lyr.output, model.layers))
activation_model = keras.models.Model(inputs=model.input, outputs=layer_outputs)

Train on 4 samples
Epoch 1/2500
Epoch 2/2500
Epoch 3/2500
Epoch 4/2500
Epoch 5/2500
Epoch 6/2500
Epoch 7/2500
Epoch 8/2500
Epoch 9/2500
Epoch 10/2500
Epoch 11/2500
Epoch 12/2500
Epoch 13/2500
Epoch 14/2500
Epoch 15/2500
Epoch 16/2500
Epoch 17/2500
Epoch 18/2500
Epoch 19/2500
Epoch 20/2500
Epoch 21/2500
Epoch 22/2500
Epoch 23/2500
Epoch 24/2500
Epoch 25/2500
Epoch 26/2500
Epoch 27/2500
Epoch 28/2500
Epoch 29/2500
Epoch 30/2500
Epoch 31/2500
Epoch 32/2500
Epoch 33/2500
Epoch 34/2500
Epoch 35/2500
Epoch 36/2500
Epoch 37/2500
Epoch 38/2500
Epoch 39/2500
Epoch 40/2500
Epoch 41/2500
Epoch 42/2500
Epoch 43/2500
Epoch 44/2500
Epoch 45/2500
Epoch 46/2500
Epoch 47/2500
Epoch 48/2500
Epoch 49/2500
Epoch 50/2500
Epoch 51/2500
Epoch 52/2500
Epoch 53/2500
Epoch 54/2500
Epoch 55/2500
Epoch 56/2500
Epoch 57/2500
Epoch 58/2500
Epoch 59/2500
Epoch 60/2500
Epoch 61/2500
Epoch 62/2500
Epoch 63/2500
Epoch 64/2500
Epoch 65/2500
Epoch 66/2500
Epoch 67/2500
Epoch 68/2500
Epoch 69/2500
Epoch 70/2500
Epoch 71/2

Epoch 75/2500
Epoch 76/2500
Epoch 77/2500
Epoch 78/2500
Epoch 79/2500
Epoch 80/2500
Epoch 81/2500
Epoch 82/2500
Epoch 83/2500
Epoch 84/2500
Epoch 85/2500
Epoch 86/2500
Epoch 87/2500
Epoch 88/2500
Epoch 89/2500
Epoch 90/2500
Epoch 91/2500
Epoch 92/2500
Epoch 93/2500
Epoch 94/2500
Epoch 95/2500
Epoch 96/2500
Epoch 97/2500
Epoch 98/2500
Epoch 99/2500
Epoch 100/2500
Epoch 101/2500
Epoch 102/2500
Epoch 103/2500
Epoch 104/2500
Epoch 105/2500
Epoch 106/2500
Epoch 107/2500
Epoch 108/2500
Epoch 109/2500
Epoch 110/2500
Epoch 111/2500
Epoch 112/2500
Epoch 113/2500
Epoch 114/2500
Epoch 115/2500
Epoch 116/2500
Epoch 117/2500
Epoch 118/2500
Epoch 119/2500
Epoch 120/2500
Epoch 121/2500
Epoch 122/2500
Epoch 123/2500
Epoch 124/2500
Epoch 125/2500
Epoch 126/2500
Epoch 127/2500
Epoch 128/2500
Epoch 129/2500
Epoch 130/2500
Epoch 131/2500
Epoch 132/2500
Epoch 133/2500
Epoch 134/2500
Epoch 135/2500
Epoch 136/2500
Epoch 137/2500
Epoch 138/2500
Epoch 139/2500
Epoch 140/2500
Epoch 141/2500
Epoch 142/2500
Epoch

Epoch 147/2500
Epoch 148/2500
Epoch 149/2500
Epoch 150/2500
Epoch 151/2500
Epoch 152/2500
Epoch 153/2500
Epoch 154/2500
Epoch 155/2500
Epoch 156/2500
Epoch 157/2500
Epoch 158/2500
Epoch 159/2500
Epoch 160/2500
Epoch 161/2500
Epoch 162/2500
Epoch 163/2500
Epoch 164/2500
Epoch 165/2500
Epoch 166/2500
Epoch 167/2500
Epoch 168/2500
Epoch 169/2500
Epoch 170/2500
Epoch 171/2500
Epoch 172/2500
Epoch 173/2500
Epoch 174/2500
Epoch 175/2500
Epoch 176/2500
Epoch 177/2500
Epoch 178/2500
Epoch 179/2500
Epoch 180/2500
Epoch 181/2500
Epoch 182/2500
Epoch 183/2500
Epoch 184/2500
Epoch 185/2500
Epoch 186/2500
Epoch 187/2500
Epoch 188/2500
Epoch 189/2500
Epoch 190/2500
Epoch 191/2500
Epoch 192/2500
Epoch 193/2500
Epoch 194/2500
Epoch 195/2500
Epoch 196/2500
Epoch 197/2500
Epoch 198/2500
Epoch 199/2500
Epoch 200/2500
Epoch 201/2500
Epoch 202/2500
Epoch 203/2500
Epoch 204/2500
Epoch 205/2500
Epoch 206/2500
Epoch 207/2500
Epoch 208/2500
Epoch 209/2500
Epoch 210/2500
Epoch 211/2500
Epoch 212/2500
Epoch 213/

Epoch 220/2500
Epoch 221/2500
Epoch 222/2500
Epoch 223/2500
Epoch 224/2500
Epoch 225/2500
Epoch 226/2500
Epoch 227/2500
Epoch 228/2500
Epoch 229/2500
Epoch 230/2500
Epoch 231/2500
Epoch 232/2500
Epoch 233/2500
Epoch 234/2500
Epoch 235/2500
Epoch 236/2500
Epoch 237/2500
Epoch 238/2500
Epoch 239/2500
Epoch 240/2500
Epoch 241/2500
Epoch 242/2500
Epoch 243/2500
Epoch 244/2500
Epoch 245/2500
Epoch 246/2500
Epoch 247/2500
Epoch 248/2500
Epoch 249/2500
Epoch 250/2500
Epoch 251/2500
Epoch 252/2500
Epoch 253/2500
Epoch 254/2500
Epoch 255/2500
Epoch 256/2500
Epoch 257/2500
Epoch 258/2500
Epoch 259/2500
Epoch 260/2500
Epoch 261/2500
Epoch 262/2500
Epoch 263/2500
Epoch 264/2500
Epoch 265/2500
Epoch 266/2500
Epoch 267/2500
Epoch 268/2500
Epoch 269/2500
Epoch 270/2500
Epoch 271/2500
Epoch 272/2500
Epoch 273/2500
Epoch 274/2500
Epoch 275/2500
Epoch 276/2500
Epoch 277/2500
Epoch 278/2500
Epoch 279/2500
Epoch 280/2500
Epoch 281/2500
Epoch 282/2500
Epoch 283/2500
Epoch 284/2500
Epoch 285/2500
Epoch 286/

Epoch 293/2500
Epoch 294/2500
Epoch 295/2500
Epoch 296/2500
Epoch 297/2500
Epoch 298/2500
Epoch 299/2500
Epoch 300/2500
Epoch 301/2500
Epoch 302/2500
Epoch 303/2500
Epoch 304/2500
Epoch 305/2500
Epoch 306/2500
Epoch 307/2500
Epoch 308/2500
Epoch 309/2500
Epoch 310/2500
Epoch 311/2500
Epoch 312/2500
Epoch 313/2500
Epoch 314/2500
Epoch 315/2500
Epoch 316/2500
Epoch 317/2500
Epoch 318/2500
Epoch 319/2500
Epoch 320/2500
Epoch 321/2500
Epoch 322/2500
Epoch 323/2500
Epoch 324/2500
Epoch 325/2500
Epoch 326/2500
Epoch 327/2500
Epoch 328/2500
Epoch 329/2500
Epoch 330/2500
Epoch 331/2500
Epoch 332/2500
Epoch 333/2500
Epoch 334/2500
Epoch 335/2500
Epoch 336/2500
Epoch 337/2500
Epoch 338/2500
Epoch 339/2500
Epoch 340/2500
Epoch 341/2500
Epoch 342/2500
Epoch 343/2500
Epoch 344/2500
Epoch 345/2500
Epoch 346/2500
Epoch 347/2500
Epoch 348/2500
Epoch 349/2500
Epoch 350/2500
Epoch 351/2500
Epoch 352/2500
Epoch 353/2500
Epoch 354/2500
Epoch 355/2500
Epoch 356/2500
Epoch 357/2500
Epoch 358/2500
Epoch 359/

Epoch 366/2500
Epoch 367/2500
Epoch 368/2500
Epoch 369/2500
Epoch 370/2500
Epoch 371/2500
Epoch 372/2500
Epoch 373/2500
Epoch 374/2500
Epoch 375/2500
Epoch 376/2500
Epoch 377/2500
Epoch 378/2500
Epoch 379/2500
Epoch 380/2500
Epoch 381/2500
Epoch 382/2500
Epoch 383/2500
Epoch 384/2500
Epoch 385/2500
Epoch 386/2500
Epoch 387/2500
Epoch 388/2500
Epoch 389/2500
Epoch 390/2500
Epoch 391/2500
Epoch 392/2500
Epoch 393/2500
Epoch 394/2500
Epoch 395/2500
Epoch 396/2500
Epoch 397/2500
Epoch 398/2500
Epoch 399/2500
Epoch 400/2500
Epoch 401/2500
Epoch 402/2500
Epoch 403/2500
Epoch 404/2500
Epoch 405/2500
Epoch 406/2500
Epoch 407/2500
Epoch 408/2500
Epoch 409/2500
Epoch 410/2500
Epoch 411/2500
Epoch 412/2500
Epoch 413/2500
Epoch 414/2500
Epoch 415/2500
Epoch 416/2500
Epoch 417/2500
Epoch 418/2500
Epoch 419/2500
Epoch 420/2500
Epoch 421/2500
Epoch 422/2500
Epoch 423/2500
Epoch 424/2500
Epoch 425/2500
Epoch 426/2500
Epoch 427/2500
Epoch 428/2500
Epoch 429/2500
Epoch 430/2500
Epoch 431/2500
Epoch 432/

Epoch 439/2500
Epoch 440/2500
Epoch 441/2500
Epoch 442/2500
Epoch 443/2500
Epoch 444/2500
Epoch 445/2500
Epoch 446/2500
Epoch 447/2500
Epoch 448/2500
Epoch 449/2500
Epoch 450/2500
Epoch 451/2500
Epoch 452/2500
Epoch 453/2500
Epoch 454/2500
Epoch 455/2500
Epoch 456/2500
Epoch 457/2500
Epoch 458/2500
Epoch 459/2500
Epoch 460/2500
Epoch 461/2500
Epoch 462/2500
Epoch 463/2500
Epoch 464/2500
Epoch 465/2500
Epoch 466/2500
Epoch 467/2500
Epoch 468/2500
Epoch 469/2500
Epoch 470/2500
Epoch 471/2500
Epoch 472/2500
Epoch 473/2500
Epoch 474/2500
Epoch 475/2500
Epoch 476/2500
Epoch 477/2500
Epoch 478/2500
Epoch 479/2500
Epoch 480/2500
Epoch 481/2500
Epoch 482/2500
Epoch 483/2500
Epoch 484/2500
Epoch 485/2500
Epoch 486/2500
Epoch 487/2500
Epoch 488/2500
Epoch 489/2500
Epoch 490/2500
Epoch 491/2500
Epoch 492/2500
Epoch 493/2500
Epoch 494/2500
Epoch 495/2500
Epoch 496/2500
Epoch 497/2500
Epoch 498/2500
Epoch 499/2500
Epoch 500/2500
Epoch 501/2500
Epoch 502/2500
Epoch 503/2500
Epoch 504/2500
Epoch 505/

Epoch 512/2500
Epoch 513/2500
Epoch 514/2500
Epoch 515/2500
Epoch 516/2500
Epoch 517/2500
Epoch 518/2500
Epoch 519/2500
Epoch 520/2500
Epoch 521/2500
Epoch 522/2500
Epoch 523/2500
Epoch 524/2500
Epoch 525/2500
Epoch 526/2500
Epoch 527/2500
Epoch 528/2500
Epoch 529/2500
Epoch 530/2500
Epoch 531/2500
Epoch 532/2500
Epoch 533/2500
Epoch 534/2500
Epoch 535/2500
Epoch 536/2500
Epoch 537/2500
Epoch 538/2500
Epoch 539/2500
Epoch 540/2500
Epoch 541/2500
Epoch 542/2500
Epoch 543/2500
Epoch 544/2500
Epoch 545/2500
Epoch 546/2500
Epoch 547/2500
Epoch 548/2500
Epoch 549/2500
Epoch 550/2500
Epoch 551/2500
Epoch 552/2500
Epoch 553/2500
Epoch 554/2500
Epoch 555/2500
Epoch 556/2500
Epoch 557/2500
Epoch 558/2500
Epoch 559/2500
Epoch 560/2500
Epoch 561/2500
Epoch 562/2500
Epoch 563/2500
Epoch 564/2500
Epoch 565/2500
Epoch 566/2500
Epoch 567/2500
Epoch 568/2500
Epoch 569/2500
Epoch 570/2500
Epoch 571/2500
Epoch 572/2500
Epoch 573/2500
Epoch 574/2500
Epoch 575/2500
Epoch 576/2500
Epoch 577/2500
Epoch 578/

Epoch 585/2500
Epoch 586/2500
Epoch 587/2500
Epoch 588/2500
Epoch 589/2500
Epoch 590/2500
Epoch 591/2500
Epoch 592/2500
Epoch 593/2500
Epoch 594/2500
Epoch 595/2500
Epoch 596/2500
Epoch 597/2500
Epoch 598/2500
Epoch 599/2500
Epoch 600/2500
Epoch 601/2500
Epoch 602/2500
Epoch 603/2500
Epoch 604/2500
Epoch 605/2500
Epoch 606/2500
Epoch 607/2500
Epoch 608/2500
Epoch 609/2500
Epoch 610/2500
Epoch 611/2500
Epoch 612/2500
Epoch 613/2500
Epoch 614/2500
Epoch 615/2500
Epoch 616/2500
Epoch 617/2500
Epoch 618/2500
Epoch 619/2500
Epoch 620/2500
Epoch 621/2500
Epoch 622/2500
Epoch 623/2500
Epoch 624/2500
Epoch 625/2500
Epoch 626/2500
Epoch 627/2500
Epoch 628/2500
Epoch 629/2500
Epoch 630/2500
Epoch 631/2500
Epoch 632/2500
Epoch 633/2500
Epoch 634/2500
Epoch 635/2500
Epoch 636/2500
Epoch 637/2500
Epoch 638/2500
Epoch 639/2500
Epoch 640/2500
Epoch 641/2500
Epoch 642/2500
Epoch 643/2500
Epoch 644/2500
Epoch 645/2500
Epoch 646/2500
Epoch 647/2500
Epoch 648/2500
Epoch 649/2500
Epoch 650/2500
Epoch 651/

Epoch 658/2500
Epoch 659/2500
Epoch 660/2500
Epoch 661/2500
Epoch 662/2500
Epoch 663/2500
Epoch 664/2500
Epoch 665/2500
Epoch 666/2500
Epoch 667/2500
Epoch 668/2500
Epoch 669/2500
Epoch 670/2500
Epoch 671/2500
Epoch 672/2500
Epoch 673/2500
Epoch 674/2500
Epoch 675/2500
Epoch 676/2500
Epoch 677/2500
Epoch 678/2500
Epoch 679/2500
Epoch 680/2500
Epoch 681/2500
Epoch 682/2500
Epoch 683/2500
Epoch 684/2500
Epoch 685/2500
Epoch 686/2500
Epoch 687/2500
Epoch 688/2500
Epoch 689/2500
Epoch 690/2500
Epoch 691/2500
Epoch 692/2500
Epoch 693/2500
Epoch 694/2500
Epoch 695/2500
Epoch 696/2500
Epoch 697/2500
Epoch 698/2500
Epoch 699/2500
Epoch 700/2500
Epoch 701/2500
Epoch 702/2500
Epoch 703/2500
Epoch 704/2500
Epoch 705/2500
Epoch 706/2500
Epoch 707/2500
Epoch 708/2500
Epoch 709/2500
Epoch 710/2500
Epoch 711/2500
Epoch 712/2500
Epoch 713/2500
Epoch 714/2500
Epoch 715/2500
Epoch 716/2500
Epoch 717/2500
Epoch 718/2500
Epoch 719/2500
Epoch 720/2500
Epoch 721/2500
Epoch 722/2500
Epoch 723/2500
Epoch 724/

Epoch 731/2500
Epoch 732/2500
Epoch 733/2500
Epoch 734/2500
Epoch 735/2500
Epoch 736/2500
Epoch 737/2500
Epoch 738/2500
Epoch 739/2500
Epoch 740/2500
Epoch 741/2500
Epoch 742/2500
Epoch 743/2500
Epoch 744/2500
Epoch 745/2500
Epoch 746/2500
Epoch 747/2500
Epoch 748/2500
Epoch 749/2500
Epoch 750/2500
Epoch 751/2500
Epoch 752/2500
Epoch 753/2500
Epoch 754/2500
Epoch 755/2500
Epoch 756/2500
Epoch 757/2500
Epoch 758/2500
Epoch 759/2500
Epoch 760/2500
Epoch 761/2500
Epoch 762/2500
Epoch 763/2500
Epoch 764/2500
Epoch 765/2500
Epoch 766/2500
Epoch 767/2500
Epoch 768/2500
Epoch 769/2500
Epoch 770/2500
Epoch 771/2500
Epoch 772/2500
Epoch 773/2500
Epoch 774/2500
Epoch 775/2500
Epoch 776/2500
Epoch 777/2500
Epoch 778/2500
Epoch 779/2500
Epoch 780/2500
Epoch 781/2500
Epoch 782/2500
Epoch 783/2500
Epoch 784/2500
Epoch 785/2500
Epoch 786/2500
Epoch 787/2500
Epoch 788/2500
Epoch 789/2500
Epoch 790/2500
Epoch 791/2500
Epoch 792/2500
Epoch 793/2500
Epoch 794/2500
Epoch 795/2500
Epoch 796/2500
Epoch 797/

Epoch 804/2500
Epoch 805/2500
Epoch 806/2500
Epoch 807/2500
Epoch 808/2500
Epoch 809/2500
Epoch 810/2500
Epoch 811/2500
Epoch 812/2500
Epoch 813/2500
Epoch 814/2500
Epoch 815/2500
Epoch 816/2500
Epoch 817/2500
Epoch 818/2500
Epoch 819/2500
Epoch 820/2500
Epoch 821/2500
Epoch 822/2500
Epoch 823/2500
Epoch 824/2500
Epoch 825/2500
Epoch 826/2500
Epoch 827/2500
Epoch 828/2500
Epoch 829/2500
Epoch 830/2500
Epoch 831/2500
Epoch 832/2500
Epoch 833/2500
Epoch 834/2500
Epoch 835/2500
Epoch 836/2500
Epoch 837/2500
Epoch 838/2500
Epoch 839/2500
Epoch 840/2500
Epoch 841/2500
Epoch 842/2500
Epoch 843/2500
Epoch 844/2500
Epoch 845/2500
Epoch 846/2500
Epoch 847/2500
Epoch 848/2500
Epoch 849/2500
Epoch 850/2500
Epoch 851/2500


In [None]:
# summary stuff
color_list=['b','g','r','c','m','y','k','tab:purple']

def get_pds(activations, start_idx, end_idx, directions=None):
    """Fit a cosine tuning curve to every unit; return preferred directions and depths.

    For each unit, the mean activation over time steps [start_idx, end_idx) in each
    target condition is fit with  depth * cos(theta - pd) + offset  by bounded least
    squares, where theta is the direction of each target condition.

    Parameters
    ----------
    activations : array indexed [target, time, unit]
    start_idx, end_idx : int
        Time window (end exclusive) averaged before fitting.
    directions : array-like, optional
        Direction of each target condition (one entry per row of ``activations``).
        It is passed straight to ``np.cos``, so it must be in radians. Defaults to
        the notebook-level ``degs``.
        NOTE(review): the name ``degs`` suggests degrees — confirm it holds radians.

    Returns
    -------
    pds : ndarray, shape (n_units,)
        Fitted preferred direction in [-pi, pi], i.e. where the fitted curve peaks.
    depths : ndarray, shape (n_units,)
        Fitted modulation depth (cosine amplitude), bounded to [0, 3].
    """
    theta = np.asarray(degs if directions is None else directions, dtype=float)
    # Mean activation per (target, unit) over the analysis window.
    mean_activation = np.mean(activations[:, start_idx:end_idx, :], axis=1)
    n_units = mean_activation.shape[1]
    pds = np.zeros((n_units,))
    depths = np.zeros((n_units,))
    x0 = [1.0, 0.0, 0.0]                          # initial [depth, pd, offset]
    bounds = ([0, -np.pi, -1], [3, np.pi, 1])
    for unit in range(n_units):
        tuning = mean_activation[:, unit]
        # cos(theta - pd) peaks at theta == pd, so the fitted phase is the preferred
        # direction itself (a cos(theta + phase) model would peak at -phase).
        residuals = lambda p, y=tuning: p[0] * np.cos(theta - p[1]) + p[2] - y
        fit = least_squares(residuals, x0=x0, bounds=bounds)
        depths[unit] = fit.x[0]
        pds[unit] = fit.x[1]
    return pds, depths

def make_activation_grid(activations,cell,args):
    """Scatter per-unit activations onto the 2-D latent sheet.

    Parameters
    ----------
    activations : array indexed [target, time, unit]
    cell : object whose ``pos`` attribute holds a (row, col) sheet position per unit
        (row k for unit k); positions are truncated to int before indexing.
    args : object whose ``latent_shape`` gives the (rows, cols) size of the sheet.

    Returns
    -------
    ndarray of shape (n_targets, n_times, rows, cols). Sheet sites that no unit maps
    to stay 0.
    """
    n_tgt, n_time = activations.shape[0], activations.shape[1]
    grid = np.zeros((n_tgt, n_time, args.latent_shape[0], args.latent_shape[1]))
    unit_pos = cell.pos.astype("int")
    # Copy each unit's whole [target, time] block to its sheet site. Units are visited
    # in order, so if two units share a site the later one wins.
    for unit in range(unit_pos.shape[0]):
        grid[:, :, unit_pos[unit, 0], unit_pos[unit, 1]] = activations[:, :, unit]
    return grid

        

# Analysis window for tuning fits, as offsets (in time steps) from the go cue.
go_cue_offset = [10,30]
activations = activation_model.predict(x_test)
rnn_activations = activations[1]
y_pred = activations[-1]

# Preferred-direction / depth summary over the post-go-cue window.
pd_data,depth_data = get_pds(rnn_activations,go_cue_offset[0]+go_cue_idx,go_cue_offset[1]+go_cue_idx)
plt.figure()
plt.subplot(2,2,1)
plt.hist(pd_data,20);
plt.title("preferred direction")
plt.subplot(2,2,2)
plt.hist(depth_data,20);
plt.title("tuning depth")
plt.subplot(2,2,3)

# Map each unit's preferred direction onto its position on the latent sheet.
pd_grid = np.zeros(args.latent_shape)
pos_lis = cell.pos.astype("int")
for i in range(pd_data.shape[0]):
    pd_grid[pos_lis[i,0],pos_lis[i,1]] = pd_data[i]
plt.imshow(pd_grid)

# Output channels to compare: a subset of EMG channels, or both velocity components.
out_idx = [12,14,16,18] if predict_emg==1 else [0,1]

# Predicted (dashed) vs. true (solid) outputs, one colour per target direction.
plt.figure()
for i_out in range(len(out_idx)):
    plt.subplot(2,2,i_out+1)
    for tgt_dir_idx in range(y_test.shape[0]):
        plt.plot(y_pred[tgt_dir_idx,:,out_idx[i_out]],'--',color=color_list[tgt_dir_idx])
        plt.plot(y_test[tgt_dir_idx,:,out_idx[i_out]],'-',color=color_list[tgt_dir_idx])

# Sheet snapshots at the go cue and 15 / 30 steps after it, for the first (up to) 4 targets.
rnn_activation_grid = make_activation_grid(rnn_activations,cell,args)
n_show = min(4, rnn_activation_grid.shape[0])
for idx_plot in (go_cue_idx, go_cue_idx+15, go_cue_idx+30):
    plt.figure()
    for i in range(n_show):
        plt.subplot(2,2,i+1)
        plt.imshow(rnn_activation_grid[i,idx_plot,:,:])

In [None]:
# Topography test: compare each activation map's lateral loss against maps with
# unit identities randomly shuffled.
num_shuffles = 1000  # shuffled maps drawn per time step

def shuffle_map(act_map):
    """Return a copy of ``act_map`` with its columns (last axis) randomly permuted.

    The same column permutation is applied to every row, using the global NumPy RNG.
    With act_map indexed [target, unit], this reassigns unit identities consistently
    across targets.
    """
    return np.transpose(np.random.permutation(np.transpose(act_map)))

def get_lateral_loss(act_map):
    """Lateral-interaction loss of one activation snapshot, as a TF scalar.

    Casts ``act_map`` to float32, right-multiplies it by ``cell.lateral_effect`` (read
    from the notebook-level ``cell``), and returns the negated mean of the diagonal of
    the product.

    NOTE(review): callers pass act_map indexed [target, unit]. If lateral_effect is
    [unit, unit] (assumed — TODO confirm), the product is [target, unit], and its
    "diagonal" pairs target i with unit i for i < min(n_targets, n_units). Confirm this
    matches the lateral loss used during training in RNN_model.ipynb.
    """
    snapshot = tf.convert_to_tensor(act_map, dtype="float32")
    lateral_drive = tf.matmul(snapshot, cell.lateral_effect)
    diagonal = tf.linalg.diag_part(lateral_drive)
    return -tf.reduce_mean(diagonal)

def evaluate_topography(activations):
    """Score each time step's activation map against ``num_shuffles`` unit-shuffled maps.

    Parameters
    ----------
    activations : array indexed [target, time, unit]

    Returns
    -------
    act_loss : ndarray (n_steps,)
        Lateral loss of the real map at each step.
    shuffled_loss : ndarray (n_steps, num_shuffles)
        Lateral loss of each shuffled map.
    act_loss_perc : ndarray (n_steps,)
        Percentile of the real loss within that step's shuffled distribution.
    """
    n_steps = activations.shape[1]
    act_loss = np.zeros((n_steps,))
    act_loss_perc = np.zeros_like(act_loss)
    shuffled_loss = np.zeros((n_steps, num_shuffles))

    for step in range(n_steps):
        snapshot = activations[:, step, :]
        act_loss[step] = get_lateral_loss(snapshot)
        # Null distribution: the same snapshot with unit identities permuted.
        for k in range(num_shuffles):
            shuffled_loss[step, k] = get_lateral_loss(shuffle_map(snapshot))
        act_loss_perc[step] = percentileofscore(shuffled_loss[step, :], act_loss[step])

    return act_loss, shuffled_loss, act_loss_perc
        
# Score topography on a fresh (unstimulated) forward pass of the test set.
activations = activation_model.predict(x_test)
rnn_activations = activations[1]
act_loss, shuffled_loss, act_loss_perc = evaluate_topography(rnn_activations)

plt.figure()
plt.subplot(2,1,1)
plt.plot(act_loss_perc)
plt.ylabel("percentile vs. shuffles")
plt.subplot(2,1,2)
# Output channel 0 of target 0, taken from the same forward pass that was scored above
# (not from y_pred of another cell).
plt.plot(activations[-1][0,:,0])
plt.xlabel("time step")

In [None]:
# stimulation experiments
def mean_squared_difference(x,y):
    """Mean over the last (signal) axis of the squared difference between x and y.

    x and y are arrays indexed [target, time, signal]; the result is [target, time].
    """
    diff = x - y
    return np.mean(diff * diff, axis=2)

def get_recovery_time(msd, stim_end=None):
    """Steps after ``stim_end`` until each target's MSD first drops to threshold.

    The threshold is 10% of the maximum MSD over all targets and time steps.
    Index 0 of the result corresponds to time step ``stim_end`` itself.

    Parameters
    ----------
    msd : array indexed [target, time]
    stim_end : int, optional
        First time step searched. Defaults to ``stim_params.stim_time[1]``.

    Returns
    -------
    ndarray of int, shape (n_targets,)
        First index at or after ``stim_end`` where msd <= threshold. Targets that
        never recover get the sentinel ``msd.shape[1] - stim_end + 1``.
    """
    if stim_end is None:
        stim_end = stim_params.stim_time[1]
    threshold = 0.1 * np.max(msd)
    recovered = msd[:, stim_end:] <= threshold
    # argmax on a boolean array returns the first True; rows with no True fall back
    # to the "never recovered" sentinel.
    first_idx = np.argmax(recovered, axis=1)
    never = msd.shape[1] - stim_end + 1
    # Plain int: the np.int alias was removed in NumPy 1.24.
    return np.where(recovered.any(axis=1), first_idx, never).astype(int)
   
def get_stim_effect(msd, stim_time=None):
    """Mean MSD over the stimulation window, per target.

    ``stim_time`` is [first_step, last_step] with an inclusive last step:
    run_stim_exp sets it to [start, start + duration - 1], so a 1-step pulse has
    start == end. The previous exclusive slice dropped the last stimulated step and
    produced an empty window (NaN mean) for 1-step pulses.

    Parameters
    ----------
    msd : array indexed [target, time]
    stim_time : sequence of two ints, optional
        Defaults to ``stim_params.stim_time``.
    """
    if stim_time is None:
        stim_time = stim_params.stim_time
    return np.mean(msd[:, stim_time[0]:stim_time[1] + 1], axis=1)
    
def get_metrics(x,y):
    """Compare two [target, time, signal] arrays.

    Returns a tuple (msd, recov_time, stim_effect):
    - msd: per-step mean squared difference, [target, time];
    - recov_time: per-target recovery time derived from msd;
    - stim_effect: per-target mean msd during stimulation.
    """
    msd = mean_squared_difference(x, y)
    return msd, get_recovery_time(msd), get_stim_effect(msd)


def get_stim_data():
    """Run the test set without and then with stimulation, and compare the two runs.

    Both runs call ``cell.reset_counter()`` first (presumably rewinding the cell's
    step counter used for stim timing — confirm in RNN_model.ipynb).
    Side effect: ``stim_params.is_stim`` is left set to True.

    Returns
    -------
    (emg_metrics, rnn_metrics) : the get_metrics tuples for the readout outputs and
        for the RNN-layer activations.
    """
    def _forward(stim_on):
        # Toggle stimulation, reset the cell counter, and return (rnn, readout) outputs.
        stim_params.is_stim = stim_on
        cell.reset_counter()
        outputs = activation_model.predict(x_test)
        return outputs[1], outputs[-1]

    rnn_baseline, emg_baseline = _forward(False)
    rnn_stimulated, emg_stimulated = _forward(True)

    emg_metrics = get_metrics(emg_baseline, emg_stimulated)
    rnn_metrics = get_metrics(rnn_baseline, rnn_stimulated)
    return emg_metrics, rnn_metrics
    
def run_stim_exp():
    """Sweep stim position x spatial tau x duration and collect recovery/effect metrics.

    Reads the notebook-level ``stim_pos_test``, ``stim_dist_tau_test``,
    ``stim_duration_test``, ``go_cue_idx`` and ``x_test``, and mutates ``stim_params``.
    Each pulse starts 5 steps after the go cue and spans ``duration`` steps
    (inclusive end index start + duration - 1).

    Returns
    -------
    emg_recov, emg_stim_effect, rnn_recov, rnn_stim_effect :
        arrays of shape (n_positions, n_taus, n_durations, n_targets).
    """
    n_pos = stim_pos_test.shape[0]
    result_shape = (n_pos, len(stim_dist_tau_test), len(stim_duration_test), x_test.shape[0])
    emg_recov = np.zeros(result_shape)
    rnn_recov = np.zeros_like(emg_recov)
    emg_stim_effect = np.zeros_like(emg_recov)
    rnn_stim_effect = np.zeros_like(emg_recov)

    stim_start = go_cue_idx + 5
    for i_pos, pos in enumerate(stim_pos_test):
        print(i_pos / n_pos)  # progress fraction
        stim_params.stim_pos = pos
        for i_tau, tau in enumerate(stim_dist_tau_test):
            stim_params.stim_dist_tau = tau
            for i_dur, duration in enumerate(stim_duration_test):
                stim_params.stim_time = [stim_start, stim_start + duration - 1]

                emg_metrics, rnn_metrics = get_stim_data()
                _, emg_rt, emg_eff = emg_metrics
                _, rnn_rt, rnn_eff = rnn_metrics
                emg_recov[i_pos, i_tau, i_dur] = emg_rt
                emg_stim_effect[i_pos, i_tau, i_dur] = emg_eff
                rnn_recov[i_pos, i_tau, i_dur] = rnn_rt
                rnn_stim_effect[i_pos, i_tau, i_dur] = rnn_eff

    print("done")
    return emg_recov, emg_stim_effect, rnn_recov, rnn_stim_effect

# NOTE(review): this rebinds stim_params to a new object. If the RNN cell kept a
# reference to the earlier StimParams instance, the settings below would not reach
# it — confirm how SimpleRNN_pos_loss reads stim_params.
stim_params = StimParams()
# Stim conditions: a 5 x 5 grid of stimulation centres over the latent sheet.
# NOTE(review): linspace ends at latent_shape (e.g. 30), one past the last integer
# unit position (29) used by make_activation_grid — confirm edge sites are intended.
x = np.linspace(0, args.latent_shape[0], 5, dtype=np.float32)
y = np.linspace(0, args.latent_shape[1], 5, dtype=np.float32)
xv, yv = np.meshgrid(x, y)
xv = np.reshape(xv, (xv.size, 1))
yv = np.reshape(yv, (yv.size, 1))
stim_pos_test = np.hstack((xv, yv))
stim_dist_tau_test = [0.5,1,2,5,10,20]
stim_duration_test = [1,2,4,8,16]

emg_recov, emg_stim_effect, rnn_recov, rnn_stim_effect=run_stim_exp()

# Average over stim positions (axis 0) and targets (axis 3) -> [tau, duration].
emg_recov_mean = np.mean(np.mean(emg_recov,axis=3),axis=0)
emg_stim_effect_mean = np.mean(np.mean(emg_stim_effect,axis=3),axis=0)
rnn_recov_mean = np.mean(np.mean(rnn_recov,axis=3),axis=0)
rnn_stim_effect_mean = np.mean(np.mean(rnn_stim_effect,axis=3),axis=0)

# Plot the RNN stim effect and recovery over the (tau, duration) grid.
# 'ij' indexing so tau_grid[i, j] / dur_grid[i, j] line up with mean[i_tau, i_dur].
tau_grid, dur_grid = np.meshgrid(stim_dist_tau_test, stim_duration_test, indexing="ij")
for z_values, z_label in ((rnn_stim_effect_mean, "stim effect"), (rnn_recov_mean, "recovery")):
    plt.figure()
    ax = plt.axes(projection='3d')
    # The marker must be a keyword: scatter3D's 4th positional parameter is zdir.
    ax.scatter3D(tau_grid.ravel(), dur_grid.ravel(), z_values.ravel(), marker='.', color='k')
    ax.set_xlabel("stim tau")
    ax.set_ylabel("stim duration")
    ax.set_zlabel(z_label)

In [None]:
# Single stimulation example: stimulate one site and compare against an unstimulated run.
# NOTE(review): presumably disables noise inside the cell — confirm that it reads args.noise_val.
args.noise_val = 0.0
stim_params.is_stim = True
stim_params.stim_pos = np.array([10,10])
stim_params.stim_dist_tau = 4
stim_params.stim_time = [go_cue_idx+10,go_cue_idx+15]
cell.reset_counter()
activations = activation_model.predict(x_test)
rnn_activations_stim = activations[1]
y_pred_stim = activations[-1]
rnn_activation_grid_stim = make_activation_grid(rnn_activations_stim,cell,args)

stim_params.is_stim=False
cell.reset_counter()
activations = activation_model.predict(x_test)
rnn_activations_no_stim = activations[1]
y_pred_no_stim = activations[-1]
rnn_activation_grid_no_stim = make_activation_grid(rnn_activations_no_stim,cell,args)

rnn_activation_grid_diff = rnn_activation_grid_stim - rnn_activation_grid_no_stim

# Activation difference (stim - no stim) at stim onset and on the first step after the
# stim window. The first 4 targets are placed at 3x3 positions 6, 2, 4, 8
# (right, top, left, bottom).
subplot_idx = [6,2,4,8]
for t_diff in (stim_params.stim_time[0], stim_params.stim_time[1]+1):
    plt.figure()
    for i in range(len(subplot_idx)):
        plt.subplot(3,3,subplot_idx[i])
        plt.imshow(rnn_activation_grid_diff[i,t_diff,:,:])
        plt.colorbar()

# Prediction with and without stimulation. Channel 7 only exists for EMG outputs;
# use channel 0 for velocity, consistent with the summary cell's channel choice.
plt.figure()
out_idx = 7 if predict_emg==1 else 0
for tgt_dir_idx in range(y_pred_no_stim.shape[0]):
    plt.subplot(3,3,tgt_dir_idx+1)
    plt.plot(y_pred_no_stim[tgt_dir_idx,:,out_idx])
    plt.plot(y_pred_stim[tgt_dir_idx,:,out_idx])
    # Mark the stim window start and end (the end line previously repeated the start).
    plt.axvline(stim_params.stim_time[0],linewidth=0.2)
    plt.axvline(stim_params.stim_time[1],linewidth=0.2)

# Example unit: no stim (solid) vs. stim (dashed), one panel per target.
plt.figure()
unit=37
for tgt_idx in range(rnn_activations_no_stim.shape[0]):
    plt.subplot(3,3,tgt_idx+1)
    plt.plot(rnn_activations_no_stim[tgt_idx,:,unit],color=color_list[tgt_idx]);
    plt.plot(rnn_activations_stim[tgt_idx,:,unit],'--',color=color_list[tgt_idx]);
    plt.axvline(stim_params.stim_time[0],linewidth=0.2)
    plt.axvline(stim_params.stim_time[1],linewidth=0.2)

In [None]:
def plot_pca_traces(data,tgt_idx,dim_idx,is_stim=False):
    """Plot one target's trajectory in a 2-D projection of PC space.

    data is indexed [target, time, component]. dim_idx gives the two components to
    plot (x, y). The trajectory is solid, or dashed when ``is_stim`` is true. The first
    time step is marked with '.', and the go cue (``go_cue_idx``) with 'o'.
    """
    color = color_list[tgt_idx]
    x_dim, y_dim = dim_idx[0], dim_idx[1]
    trace = data[tgt_idx]
    style = '--' if is_stim else '-'
    plt.plot(trace[:, x_dim], trace[:, y_dim], style, color=color)
    plt.plot(trace[0, x_dim], trace[0, y_dim], '.', color=color)
    plt.plot(trace[go_cue_idx, x_dim], trace[go_cue_idx, y_dim], 'o', color=color)
    
# Dynamical analysis: PCA on the RNN-layer activations, indexed [target, time, unit].
# The full trial is used (no truncation), so time index go_cue_idx still marks the
# go cue in plot_pca_traces.
rnn_no_stim_trunc = rnn_activations_no_stim[:,:,:]
rnn_stim_trunc = rnn_activations_stim[:,:,:]

rnn_activations_no_stim_flat = rnn_no_stim_trunc.reshape(-1,rnn_activations_no_stim.shape[-1])
rnn_activations_stim_flat = rnn_stim_trunc.reshape(-1,rnn_activations_stim.shape[-1])

# Fit PCs on unstimulated activity only, then project both conditions into that space.
pca = PCA(n_components=10)
pca.fit(rnn_activations_no_stim_flat)

# Restore the [target, time, component] layout from each array's own shape. This avoids
# relying on len(degs) matching the number of test targets.
rnn_pca_no_stim = pca.transform(rnn_activations_no_stim_flat).reshape(
    rnn_no_stim_trunc.shape[0], rnn_no_stim_trunc.shape[1], -1)
rnn_pca_stim = pca.transform(rnn_activations_stim_flat).reshape(
    rnn_stim_trunc.shape[0], rnn_stim_trunc.shape[1], -1)

# NOTE(review): components are 0-based, so [1,2] plots PC2 vs PC3 — confirm PC1 is
# meant to be skipped.
dims = [1,2]
plt.figure()
for tgt_idx in range(rnn_pca_no_stim.shape[0]):
    plot_pca_traces(rnn_pca_no_stim,tgt_idx,dims)

# Per-target panels: unstimulated (solid) vs. stimulated (dashed).
plt.figure()
for tgt_idx in range(rnn_pca_no_stim.shape[0]):
    plt.subplot(3,3,tgt_idx+1)
    plot_pca_traces(rnn_pca_no_stim,tgt_idx,dims)
    plot_pca_traces(rnn_pca_stim,tgt_idx,dims,is_stim=True)

In [None]:
# dPCA
# X = multi-dimensional array [n, t, s, d]: n neurons, t time, s stimulus, d decision.
# labels = optional string of characters naming the parameter axes (e.g. 'tsd').
# n_components = int or dict. An int uses the same number of components for every
#     marginalization; a dict maps marginalization key -> number of components.

def plot_dpca_traces(data,is_stim,label,dim):
    """Plot one dPCA component of one marginalization, one subplot per target.

    ``data[label]`` is indexed [component, time, target]. Uses a 3x3 grid when there
    are more than 4 targets, otherwise 2x2. Lines are dashed when ``is_stim`` is
    truthy, and vertical lines mark ``stim_params.stim_time``.
    """
    linestyle = '--' if is_stim else '-'
    traces = data[label]
    n_tgt = traces.shape[-1]
    x_data = np.arange(0, traces.shape[1])
    n_rows, n_cols = (3, 3) if n_tgt > 4 else (2, 2)

    for i_tgt in range(n_tgt):
        plt.subplot(n_rows, n_cols, i_tgt + 1)
        plt.plot(x_data, traces[dim, :, i_tgt], color=color_list[i_tgt], linestyle=linestyle)
        plt.axvline(stim_params.stim_time[0], linewidth=0.2)
        plt.axvline(stim_params.stim_time[1], linewidth=0.2)
        
# dPCA input layout is [neuron, time, condition]; the activations are [target, time, unit].
X_no_stim = rnn_activations_no_stim.transpose(2, 1, 0)
X_stim = rnn_activations_stim.transpose(2, 1, 0)

# Marginalize over time ('t') and stimulus/target ('s').
labels = 'ts'
dpca = dPCA.dPCA(labels)
# Protect the time axis (presumably from shuffling inside dPCA — see the dPCA docs).
dpca.protect = ['t']
# Fit on unstimulated data, then project the stimulated data with the same axes.
Z_no_stim = dpca.fit_transform(X_no_stim)
Z_stim = dpca.transform(X_stim)

# First three time-marginalization components: no stim (solid) vs. stim (dashed).
for i_dim in range(3):
    plt.figure()
    for Z, stim_flag in ((Z_no_stim, 0), (Z_stim, 1)):
        plot_dpca_traces(Z, stim_flag, 't', i_dim)
     
