In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np; np.set_printoptions(precision=4); np.random.seed(0)
import torch; torch.set_printoptions(precision=4); seed = 1; torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl
import seaborn as sns
import time
import sys
import itertools
import random; random.seed(0)
import scipy
import os
import warnings
from sklearn.decomposition import PCA
from textwrap import wrap
from scipy.stats import wilcoxon
from scipy.linalg import subspace_angles


from functions import *

# report library versions so results can be tied to the environment that produced them
print(torch.__version__)
print(sys.version)

%matplotlib inline

# request deterministic PyTorch / cuDNN behaviour for reproducibility
# NOTE(review): with use_deterministic_algorithms(True), ops lacking a deterministic
# implementation raise a RuntimeError, and some CUDA/cuBLAS ops additionally require the
# CUBLAS_WORKSPACE_CONFIG environment variable -- confirm this runs on the target device
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True) 
torch.backends.cudnn.deterministic = True    

In [None]:
def participation_ratio(lambdas):
    """ participation ratio of a spectrum: (sum_i l_i)**2 / sum_i l_i**2

        lambdas: iterable of eigenvalues (or explained-variance values)
        return: the participation ratio as a scalar
    """
    values = list(lambdas)    # materialize once so generators can be summed twice
    total = sum(values)
    total_of_squares = sum(v**2 for v in values)
    return total**2/total_of_squares



def compute_subspace(activity, d='pr'):
    """ compute the subspace from a collection of neural trajectories

        activity: 2-d array of shape (n_trials*n_timesteps, n_neurons), samples in rows
        d: number of dimensions for the subspace. Default: 'pr' (use the participation ratio
           of the PCA explained-variance spectrum, rounded to the nearest integer)

        return:
        subspace - array of shape (n_dim, n_neurons); each row is a principal axis
        exp_var_ratio - explained variance ratio of each retained component
        n_dim - number of retained dimensions
    """
    if d == 'pr':
        # PCA rejects n_components > min(n_samples, n_features). The centered data never has
        # rank above that, so this still yields the full explained-variance spectrum, and it
        # no longer fails when there are fewer samples than neurons.
        pca = PCA(n_components=min(activity.shape))
        pca.fit(activity)
        exp_var_ratio = pca.explained_variance_ratio_
        n_dim = int(np.round(participation_ratio(exp_var_ratio)))
        subspace = pca.components_[:n_dim]
        exp_var_ratio = exp_var_ratio[:n_dim]
    else:
        pca = PCA(n_components=d)
        pca.fit(activity)
        subspace = pca.components_
        exp_var_ratio = pca.explained_variance_ratio_
        n_dim = d

    return subspace, exp_var_ratio, n_dim


def normalize_along_row(x):
    """ normalize each row of x to unit L2 norm

        x: 2-d array-like
        return: float array of the same shape as x; all-zero rows are left as zeros
    """
    x = np.asarray(x, dtype=float)
    # np.linalg.norm already includes the square root, so divide by it directly
    row_norms = np.linalg.norm(x, ord=2, axis=1, keepdims=True)
    y = np.zeros(x.shape)
    # only divide where the norm is non-zero so zero rows stay zero instead of becoming NaN
    np.divide(x, row_norms, out=y, where=row_norms > 0)
    return y


def remove_pane_and_grid_3d(ax):
    """ turn off the grid of a 3d axes and paint its x/y/z panes opaque white """
    white = (1.0, 1.0, 1.0, 1.0)
    ax.grid(False)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(white)
    

# Generate data for figure 6c, d (principal angles between subspaces)

In [None]:
import pickle    # the saved test data below are pickles

all_data = {}

d = 'pr'    # number of dimensions for each subspace ('pr': rounded participation ratio)
N_SHUFFLES = 100    # number of random split-half shuffles per model
MIN_PERF = 0.8    # models with mean choice or rule performance below this are excluded

# paths are built by string concatenation, so include the trailing path separator
model_dir = ''
test_data_dir = ''


def _flatten_activity(activity, trials, time_idx, neuron_idx):
    """ select trials, timesteps and neurons and stack (trial, timestep) pairs into rows

        activity: array indexed as [trial, timestep, k, neuron]; only k=0 is used
        return: array of shape (len(trials)*len(time_idx), len(neuron_idx))
    """
    selected = activity[trials, :, 0, :][:, time_idx, :][:, :, neuron_idx]
    return selected.reshape(selected.shape[0]*selected.shape[1], selected.shape[-1])


def _subspace_of(activity, trials, time_idx, neuron_idx, d):
    """ basis (one row per dimension) of the subspace spanned by the selected activity """
    subspace, _, _ = compute_subspace(_flatten_activity(activity, trials, time_idx, neuron_idx), d=d)
    return subspace


def _angle_deg(subspace_a, subspace_b):
    """ first angle returned by scipy.linalg.subspace_angles for two row-basis subspaces, in degrees """
    return np.rad2deg(subspace_angles(subspace_a.T, subspace_b.T)[0])


def _split_half(trials):
    """ randomly split `trials` into two disjoint parts; the first holds len(trials)//2 trials

        draws once from numpy's global RNG (np.random.choice)
    """
    split1 = np.random.choice(trials, size=len(trials)//2, replace=False)
    split1_members = set(split1.tolist())    # O(1) membership instead of scanning the array
    split2 = [tr for tr in trials if tr not in split1_members]
    return split1, split2


for model_name in sorted(os.listdir(model_dir)):
    if not (('2023-05-10' in model_name) and 'wcst' in model_name and 'success' in model_name):
        continue
    print(model_name)

    # load model
    path_to_file = model_dir+model_name
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)

    # load data
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files
    with open(test_data_dir+model_name+'_testdata_noiseless_no_current_matrix', 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']
    mean_perf = np.mean([_[0] for _ in test_data['perfs']])
    mean_perf_rule = np.mean([_[0] for _ in test_data['perf_rules']])
    if mean_perf<MIN_PERF or mean_perf_rule<MIN_PERF:
        print('low performing model ({}/{})'.format(mean_perf, mean_perf_rule))
        continue
    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()

    # group trials
    # assumes label_trials_wcst returns lists, so `+` concatenates them -- TODO confirm
    trial_labels = label_trials_wcst(test_data=test_data)
    rule1_trs_stable = trial_labels['rule1_trs_stable']
    rule2_trs_stable = trial_labels['rule2_trs_stable']
    c1_trs_stable = trial_labels['c1_trs_stable']
    c2_trs_stable = trial_labels['c2_trs_stable']
    c3_trs_stable = trial_labels['c3_trs_stable']
    stable_trs = rule1_trs_stable + rule2_trs_stable
    choice_trs_stable = c1_trs_stable + c2_trs_stable + c3_trs_stable

    # neurons and time windows used to build the subspaces
    neuron_used = (list(model.rnn.cg_idx['sr_esoma']) + list(model.rnn.cg_idx['sr_pv'])
                   + list(model.rnn.cg_idx['sr_sst']) + list(model.rnn.cg_idx['sr_vip']))
    time_used_rule = np.arange(hp_task_test['trial_history_start']//hp_test['dt'], hp_task_test['center_card_on']//hp_test['dt'])    # use the inter-trial epoch
    time_used_choice = np.arange(hp_task_test['resp_start']//hp_test['dt'], hp_task_test['resp_end']//hp_test['dt'])    # resp_start..resp_end

    def rule_subspace(trials):
        """ subspace of `trials` over the rule time window """
        return _subspace_of(rnn_activity, trials, time_used_rule, neuron_used, d)

    def choice_subspace(trials):
        """ subspace of `trials` over the choice (response) time window """
        return _subspace_of(rnn_activity, trials, time_used_choice, neuron_used, d)

    # angles between the subspaces of the intact network
    angle_rule_subspace = _angle_deg(rule_subspace(rule1_trs_stable), rule_subspace(rule2_trs_stable))
    subspace_c1 = choice_subspace(c1_trs_stable)
    subspace_c2 = choice_subspace(c2_trs_stable)
    subspace_c3 = choice_subspace(c3_trs_stable)
    angle_choice_subspace_avg = np.mean([_angle_deg(subspace_c1, subspace_c2),
                                         _angle_deg(subspace_c1, subspace_c3),
                                         _angle_deg(subspace_c2, subspace_c3)])
    # NOTE(review): the choice subspace here uses `stable_trs` (rule-stable trials), while its
    # shuffle control below splits `choice_trs_stable` -- confirm these are the same trial set
    angle_rule_choice = _angle_deg(rule_subspace(stable_trs), choice_subspace(stable_trs))

    # shuffle control: angles between subspaces fit to two random halves of the same trials
    angle_rules_shuffle = []
    angle_choices_shuffle = []
    angle_rule_choices_shuffle = []
    for _ in range(N_SHUFFLES):
        # within-class split halves (RNG draw order: rule1, rule2, c1, c2, c3)
        rule1_split1, rule1_split2 = _split_half(rule1_trs_stable)
        rule2_split1, rule2_split2 = _split_half(rule2_trs_stable)
        c1_split1, c1_split2 = _split_half(c1_trs_stable)
        c2_split1, c2_split2 = _split_half(c2_trs_stable)
        c3_split1, c3_split2 = _split_half(c3_trs_stable)

        # control for angle_rule_subspace: within-rule split halves, averaged over both rules
        angle_rule_subspace_shuffle = np.mean([
            _angle_deg(rule_subspace(rule1_split1), rule_subspace(rule1_split2)),
            _angle_deg(rule_subspace(rule2_split1), rule_subspace(rule2_split2)),
        ])
        # control for angle_choice_subspace_avg: within-choice split halves, averaged over choices
        angle_choice_subspace_avg_shuffle = np.mean([
            _angle_deg(choice_subspace(c1_split1), choice_subspace(c1_split2)),
            _angle_deg(choice_subspace(c2_split1), choice_subspace(c2_split2)),
            _angle_deg(choice_subspace(c3_split1), choice_subspace(c3_split2)),
        ])

        # control for angle_rule_choice: pooled rule and pooled choice split halves; kept under
        # their own names so they do not overwrite the within-rule control above
        rule_split1, rule_split2 = _split_half(stable_trs)
        choice_split1, choice_split2 = _split_half(choice_trs_stable)
        angle_rule_pooled_shuffle = _angle_deg(rule_subspace(rule_split1), rule_subspace(rule_split2))
        angle_choice_pooled_shuffle = _angle_deg(choice_subspace(choice_split1), choice_subspace(choice_split2))
        angle_rule_choice_subspace_shuffle = np.mean([angle_rule_pooled_shuffle, angle_choice_pooled_shuffle])

        angle_rules_shuffle.append(angle_rule_subspace_shuffle)
        angle_choices_shuffle.append(angle_choice_subspace_avg_shuffle)
        angle_rule_choices_shuffle.append(angle_rule_choice_subspace_shuffle)

    # collect data
    all_data[model_name] = {
                           'model_name': model_name,
                           'hp': hp_test,
                           'angle_rule_subspace': angle_rule_subspace,
                           'angle_rule_subspace_shuffle': angle_rules_shuffle,
                           'angle_choice_subspace_avg': angle_choice_subspace_avg,
                           'angle_choice_subspace_avg_shuffle': angle_choices_shuffle,
                           'angle_rule_choice': angle_rule_choice,
                           'angle_rule_choice_shuffle': angle_rule_choices_shuffle
                           }

# Figure 6c, d

In [None]:
# pool models by dendritic nonlinearity; set to e.g. 'subtractive' or 'divisive_2'
NONLINEARITY = 'subtractive'

angle_rules_all_models = []
angle_rules_all_models_shuffle = []
angle_choices_all_models = []
angle_choices_all_models_shuffle = []
angle_rule_choice_all_models = []
angle_rule_choice_all_models_shuffle = []

# iterate over values so the torch `model` variable is not shadowed by a name string
for model_data in all_data.values():
    if model_data['hp']['dend_nonlinearity'] != NONLINEARITY:
        continue
    angle_rules_all_models.append(model_data['angle_rule_subspace'])
    angle_rules_all_models_shuffle.extend(model_data['angle_rule_subspace_shuffle'])
    angle_choices_all_models.append(model_data['angle_choice_subspace_avg'])
    angle_choices_all_models_shuffle.extend(model_data['angle_choice_subspace_avg_shuffle'])
    angle_rule_choice_all_models.append(model_data['angle_rule_choice'])
    angle_rule_choice_all_models_shuffle.extend(model_data['angle_rule_choice_shuffle'])


# subspace angle compared to shuffled data
fig, ax = plt.subplots(1, 2, figsize=[10, 3])
# 1-degree bins covering [0, 90]; the last edge must be 90 so angles in (89, 90] are counted
bins = np.arange(0, 91, 1)
ax[0].set_title('principal angle between rule subspaces')
ax[1].set_title('principal angle between choice subspaces')

ax[0].hist(angle_rules_all_models, color='k', density=True, bins=bins)
ax[0].hist(angle_rules_all_models_shuffle, color='gray', alpha=0.5, density=True, bins=bins)

ax[1].hist(angle_choices_all_models, color='k', density=True, bins=bins)
ax[1].hist(angle_choices_all_models_shuffle, color='gray', alpha=0.5, density=True, bins=bins)

for panel in ax:
    make_pretty_axes(panel)
    panel.set_xlim([0, 90])
    panel.set_xticks([0, 90])
fig.tight_layout()
plt.show()

# Figure 7f: the principal angle when SST/PV neurons are inhibited

In [None]:
import pickle    # the saved test data below are pickles

d = 'pr'

# paths are built by string concatenation, so include the trailing path separator
model_dir = ''
test_data_dir = ''


def _rule_subspace_angle_from_file(path, model, hp_test, hp_task_test, d):
    """ angle (degrees) between the rule-1 and rule-2 subspaces of a saved test set

        path: pickle holding a dict with 'test_data' and 'rnn_activity' (the latter is
              detached and moved to CPU, so presumably a torch tensor)
        return: first angle from scipy.linalg.subspace_angles, in degrees, or np.nan
                (after printing a message) when the stored activity contains NaNs
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files
    with open(path, 'rb') as f:
        neural_data = pickle.load(f)
    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()
    if np.isnan(rnn_activity).any():
        print('NAN in rnn_activity')
        return np.nan

    trial_labels = label_trials_wcst(test_data=neural_data['test_data'])
    neuron_used = (list(model.rnn.cg_idx['sr_esoma']) + list(model.rnn.cg_idx['sr_pv'])
                   + list(model.rnn.cg_idx['sr_sst']) + list(model.rnn.cg_idx['sr_vip']))
    time_used_rule = np.arange(hp_task_test['trial_history_start']//hp_test['dt'], hp_task_test['center_card_on']//hp_test['dt'])

    # NOTE(review): all rule trials are used here, whereas the intact analysis uses only
    # the stable trials -- confirm this is intended
    subspaces = []
    for rule_trs in (trial_labels['rule1_trs'], trial_labels['rule2_trs']):
        activity = rnn_activity[rule_trs, :, 0, :][:, time_used_rule, :][:, :, neuron_used]
        activity_flat = activity.reshape(activity.shape[0]*activity.shape[1], activity.shape[-1])
        subspace, _, _ = compute_subspace(activity_flat, d=d)
        subspaces.append(subspace)
    return np.rad2deg(subspace_angles(subspaces[0].T, subspaces[1].T)[0])


for model_name in sorted(os.listdir(model_dir)):
    if not (('2023-05-10' in model_name) and 'wcst' in model_name and 'success' in model_name):
        continue
    print(model_name)

    if model_name not in all_data:
        print('model perf is low, excluded in previous analysis, skip')
        continue

    # load model
    path_to_file = model_dir+model_name
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)

    if hp_test['dend_nonlinearity'] not in ['subtractive', 'divisive_2']:
        print('filtered')
        continue

    # subspace with SST silenced
    # NOTE(review): this is the "withnoise" file while the PV and intact files are "noiseless" -- confirm
    angle_rule_subspace_nosst = _rule_subspace_angle_from_file(
        test_data_dir+model_name+'_testdata_silenceSRSST_withnoise_no_current_matrix',
        model, hp_test, hp_task_test, d)

    # subspace with PV silenced
    angle_rule_subspace_nopv = _rule_subspace_angle_from_file(
        test_data_dir+model_name+'_testdata_silenceSRPV_noiseless_no_current_matrix',
        model, hp_test, hp_task_test, d)

    # add to the data dict
    all_data[model_name]['angle_rule_subspace_nopv'] = angle_rule_subspace_nopv
    all_data[model_name]['angle_rule_subspace_nosst'] = angle_rule_subspace_nosst
       

In [None]:
# plot

# subspace angle when silencing SST neurons (divisive_2 models only)
fig, ax = plt.subplots(1, 1, figsize=[3, 4])    # returns a single Axes, not an array
ax.set_title('principal angle between\n rule subspaces')

for model_data in all_data.values():
    if model_data['hp']['dend_nonlinearity'] != 'divisive_2':
        continue
    # NaN marks silenced runs whose activity contained NaNs; .get also tolerates models
    # for which the silencing analysis stored no value
    angle_nosst = model_data.get('angle_rule_subspace_nosst', np.nan)
    if not np.isnan(angle_nosst):
        ax.plot([0, 1], [model_data['angle_rule_subspace'], angle_nosst], color='k', alpha=0.5, marker='o')
make_pretty_axes(ax)
ax.set_ylim([0, 90])
ax.set_yticks([0, 90])
ax.set_xlim([-0.5, 1.5])
ax.set_xticks([0, 1], ['intact', 'silenced \nSST'])
fig.tight_layout()
plt.show()
