In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np; np.set_printoptions(precision=4); np.random.seed(0)
import torch; torch.set_printoptions(precision=4); seed = 1; torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl
import seaborn as sns
import time
import sys
import itertools
import random; random.seed(0)
import scipy
import os
import pickle
import warnings
from sklearn.decomposition import PCA
from textwrap import wrap
from scipy.stats import wilcoxon
from scipy.linalg import subspace_angles


# from model import *
from functions import *
# os.chdir('/home/yl4317/Documents/two_module_rnn/')

print(torch.__version__)
print(sys.version)
                
%matplotlib inline

torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True) 
torch.backends.cudnn.deterministic = True    

In [None]:
def participation_ratio(lambdas):
    """ participation ratio of a spectrum: (sum of values)**2 / (sum of squared values)

        lambdas: iterable of eigenvalues (e.g. PCA explained-variance ratios)
        return: the ratio; for non-negative, not-all-zero inputs it lies between 1
                (all mass on one value) and len(lambdas) (mass spread evenly)
    """
    # materialize once so one-shot iterables (generators) can be summed twice
    values = list(lambdas)

    total = sum(values)
    total_of_squares = sum(value ** 2 for value in values)

    return total ** 2 / total_of_squares



def compute_subspace(activity, d='pr'):
    """ compute a PCA subspace from a collection of neural trajectories

        activity: 2-D array, (n_trials*n_timesteps) * n_neurons
        d: dimensionality of the subspace. Either anything accepted by sklearn's
           PCA(n_components=...) (e.g. an int), or 'pr' (default) to keep as many
           components as the participation ratio of the explained-variance
           spectrum, rounded to the nearest integer.

        return:
        subspace - n_dim * n_neurons array; each row is a principal axis
        exp_var_ratio - explained variance ratio of each retained axis
        n_dim - number of retained dimensions (always a count, never a string)
    """
    if isinstance(d, str) and d == 'pr':
        # Fit the full spectrum. n_components=None keeps min(n_samples, n_features)
        # components; asking for n_features explicitly raises when there are fewer
        # samples than neurons.
        pca = PCA(n_components=None)
        pca.fit(activity)
        full_ratio = pca.explained_variance_ratio_
        n_dim = int(np.round(participation_ratio(full_ratio)))
        subspace = pca.components_[:n_dim]
        exp_var_ratio = full_ratio[:n_dim]
    else:
        pca = PCA(n_components=d)
        pca.fit(activity)
        subspace = pca.components_
        exp_var_ratio = pca.explained_variance_ratio_
        # n_components_ is the number of components actually kept; equals d for an
        # int d and is still a count when d is a variance fraction or 'mle'
        n_dim = pca.n_components_

    return subspace, exp_var_ratio, n_dim


def normalize_along_row(x):
    """ scale every row of a 2-D array to unit L2 norm

        x: 2-D array-like, n_rows * n_cols
        return: float64 array of the same shape whose rows have L2 norm 1;
                rows that are entirely zero are returned as zeros (instead of NaN)
    """
    x = np.asarray(x, dtype=float)
    norms = np.linalg.norm(x, ord=2, axis=1, keepdims=True)
    # divide by the norm itself (not its square root) so each row ends up unit-length;
    # zero rows are divided by 1 to avoid 0/0
    safe_norms = np.where(norms > 0, norms, 1.0)
    return x / safe_norms


def remove_pane_and_grid_3d(ax):
    """ turn off the grid of a 3-D axes and paint its x, y and z background panes opaque white

        ax: a 3-D axes exposing grid() and xaxis/yaxis/zaxis.set_pane_color()
    """
    opaque_white = (1.0, 1.0, 1.0, 1.0)

    ax.grid(False)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(opaque_white)
    

# Figure 6a, b: visualize neural trajectories

In [None]:
d = 2

# Locations of the trained models and of their saved noiseless test data (edit for your setup).
MODEL_DIR = '/model/directory/'
TEST_DATA_DIR = '/where/test/data/is/stored/'
N_PCS = 10               # principal components fitted for each task period
PERF_THRESHOLD = 0.8     # models whose task or rule accuracy is below this are skipped


def _trials_with_ref(ref_cards, color, shape):
    """ indices of trials whose reference (center) card has the given color and shape """
    return [tr for tr, card in enumerate(ref_cards) if card['color'] == color and card['shape'] == shape]


def _fit_trial_pca(activity, n_components):
    """ fit PCA on all timesteps of all trials, then project every trial onto it

        activity: (n_trials, n_timesteps, n_neurons) array
        return: fitted PCA,
                the flattened (n_trials*n_timesteps, n_neurons) data it was fit on,
                the projection reshaped to (n_trials, n_timesteps, n_components)
    """
    n_trials, n_steps, n_neurons = activity.shape
    flat = activity.reshape(n_trials * n_steps, n_neurons)
    pca = PCA(n_components=n_components)
    pca.fit(flat)
    # project everything once instead of once per trial per panel
    lowd = pca.transform(flat).reshape(n_trials, n_steps, -1)
    return pca, flat, lowd


def _trial_colors(n_trials, color_groups):
    """ map trial index -> color

        color_groups: list of (trial indices, color); the first group containing a trial wins
        return: dict in ascending trial order; trials in no group are omitted (not plotted)
    """
    group_sets = [(set(trials), color) for trials, color in color_groups]    # O(1) membership tests
    colors = {}
    for tr in range(n_trials):
        for trials, color in group_sets:
            if tr in trials:
                colors[tr] = color
                break
    return colors


def _plot_pc_trajectories(activity_lowd, trial_colors, title, linewidth=None, marker_size=50):
    """ plot trial trajectories in consecutive PC triplets (PC k, k+1, k+2), one 3-D panel per triplet

        activity_lowd: (n_trials, n_timesteps, n_pcs) projected activity
        trial_colors: {trial index: color}; only these trials are drawn
        linewidth: trajectory line width (None -> matplotlib default)
        marker_size: size of the start (black) and end (trial-colored) markers
        return: the figure (3x3 grid; unused grid cells are removed)
    """
    n_panels = max(activity_lowd.shape[-1] - 2, 0)
    fig, ax = plt.subplots(3, 3, figsize=[20, 20], subplot_kw={'projection': '3d'})
    fig.suptitle(title)
    fig.patch.set_facecolor('white')
    line_kw = {} if linewidth is None else {'linewidth': linewidth}
    axes = ax.ravel()    # row-major: same panel order as filling row by row
    for pc_idx, panel in zip(range(n_panels), axes):
        for tr, color in trial_colors.items():
            traj_x, traj_y, traj_z = (activity_lowd[tr, :, pc_idx + k] for k in range(3))
            panel.plot(traj_x, traj_y, traj_z, color=color, **line_kw)
            panel.scatter(traj_x[-1], traj_y[-1], traj_z[-1], s=marker_size, color=color)    # end of trajectory
            panel.scatter(traj_x[0], traj_y[0], traj_z[0], s=marker_size, color='k')         # start of trajectory
        panel.set_xlabel('PC {}'.format(pc_idx + 1))
        panel.set_ylabel('PC {}'.format(pc_idx + 2))
        panel.set_zlabel('PC {}'.format(pc_idx + 3))
        remove_pane_and_grid_3d(panel)
    # grid cells with no PC triplet (e.g. the 9th panel when n_pcs=10) would show as empty 3-D boxes
    for panel in axes[n_panels:]:
        panel.remove()
    fig.tight_layout()
    return fig


for model_name in sorted(os.listdir(MODEL_DIR)):
    if not ('2023-05-10' in model_name and 'wcst' in model_name and 'success' in model_name):
        continue
    print(model_name)

    # load model
    path_to_file = os.path.join(MODEL_DIR, model_name)
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)

    # load data
    # NOTE(review): pickle.load can execute arbitrary code - only open files produced by this project
    with open(os.path.join(TEST_DATA_DIR, model_name + '_testdata_noiseless'), 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']
    mean_perf = np.mean([_[0] for _ in test_data['perfs']])
    mean_perf_rule = np.mean([_[0] for _ in test_data['perf_rules']])
    if mean_perf < PERF_THRESHOLD or mean_perf_rule < PERF_THRESHOLD:
        print('low performing model ({}/{})'.format(mean_perf, mean_perf_rule))
        continue
    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()

    # group trials
    trial_labels = label_trials_wcst(test_data=test_data)
    rule1_trs_stable = trial_labels['rule1_trs_stable']
    rule2_trs_stable = trial_labels['rule2_trs_stable']
    rule1_trs_after_error = trial_labels['rule1_trs_after_error']
    rule2_trs_after_error = trial_labels['rule2_trs_after_error']
    c1_trs_stable = trial_labels['c1_trs_stable']
    c2_trs_stable = trial_labels['c2_trs_stable']
    c3_trs_stable = trial_labels['c3_trs_stable']
    resp_trs_stable = {'c1': c1_trs_stable, 'c2': c2_trs_stable, 'c3': c3_trs_stable}    # to be used as an argument in the "compute_sel_wcst" function
    error_trials = trial_labels['error_trials']
    stable_trs = rule1_trs_stable + rule2_trs_stable

    stims = [_[0] for _ in test_data['stims']]
    ref_cards = [_['center_card'] for _ in stims]
    ref00trs = _trials_with_ref(ref_cards, color=0, shape=0)
    ref01trs = _trials_with_ref(ref_cards, color=0, shape=1)
    ref10trs = _trials_with_ref(ref_cards, color=1, shape=0)
    ref11trs = _trials_with_ref(ref_cards, color=1, shape=1)

    # do PCA over all trajectories
    time_period_rule = np.arange(hp_task_test['trial_start']//hp_test['dt'], hp_task_test['center_card_on']//hp_test['dt'])    # fdbk + ITI
    time_period_choice = np.arange(hp_task_test['resp_start']//hp_test['dt'], hp_task_test['resp_end']//hp_test['dt'])    # ref card + test cards (response)
    neuron_used = list(model.rnn.cg_idx['sr_esoma']) + list(model.rnn.cg_idx['sr_pv']) + list(model.rnn.cg_idx['sr_sst']) + list(model.rnn.cg_idx['sr_vip'])
    # axis 2 is the batch dimension; take entry 0 (a view), then slice time and neurons
    # so the neuron fancy-index is not applied to the whole 4-D tensor
    activity_batch0 = rnn_activity[:, :, 0, :]
    rnn_activity_used_rule = activity_batch0[:, time_period_rule][:, :, neuron_used]       # (trials, time, neurons) used for PCA
    rnn_activity_used_choice = activity_batch0[:, time_period_choice][:, :, neuron_used]   # (trials, time, neurons) used for PCA

    pca_all_traj_rule, rnn_activity_flat_rule, rule_lowd = _fit_trial_pca(rnn_activity_used_rule, N_PCS)
    print('% of explained variance (rule)', pca_all_traj_rule.explained_variance_ratio_)
    pca_all_traj_choice, rnn_activity_flat_choice, choice_lowd = _fit_trial_pca(rnn_activity_used_choice, N_PCS)
    print('% of explained variance (choice)', pca_all_traj_choice.explained_variance_ratio_)

    # Figure 6a: rule-period trajectories, colored by the rule in effect (stable trials only)
    rule_colors = _trial_colors(rnn_activity_used_rule.shape[0],
                                [(rule1_trs_stable, 'royalblue'), (rule2_trs_stable, 'lightcoral')])
    fig_rule = _plot_pc_trajectories(rule_lowd, rule_colors, 'rule PCA', marker_size=50)

    # Figure 6b: response-period trajectories, colored by the chosen card (stable trials only)
    choice_colors = _trial_colors(rnn_activity_used_choice.shape[0],
                                  [(c1_trs_stable, '#7fc97f'), (c2_trs_stable, '#beaed4'), (c3_trs_stable, '#fdc086')])
    fig_choice = _plot_pc_trajectories(choice_lowd, choice_colors, 'choice PCA', linewidth=3, marker_size=75)

    plt.show()
    # release the figures so memory does not grow across models
    plt.close(fig_rule)
    plt.close(fig_choice)