In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np; np.set_printoptions(precision=4); np.random.seed(0)
import torch; torch.set_printoptions(precision=4)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rc('font', size=12)
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl
import seaborn as sns
import itertools
import os
import pickle
import random; random.seed(0)
import sys
import time
from textwrap import wrap

import scipy
from scipy.stats import wilcoxon


from functions import *

# Record the torch and Python versions next to the outputs for provenance.
for version_info in (torch.__version__, sys.version):
    print(version_info)
                
%matplotlib inline

# Force deterministic kernels so repeated runs give identical results.
# With CUDA >= 10.2, deterministic mode makes cuBLAS ops raise a RuntimeError unless
# CUBLAS_WORKSPACE_CONFIG is set before the first cuBLAS call. setdefault keeps any value
# the user already exported.
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
torch.backends.cudnn.benchmark = False  # the cuDNN autotuner may pick different kernels run to run
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True

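# Figure 5b: joint distribution of rule, response, and shared-feature selectivity across units

## Generate data: load each trained WCST model and its noiseless test data, then compute per-unit selectivity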

In [None]:
# Load every successfully trained WCST model from the 2023-05-10 batch. Skip models whose
# mean test performance is <= PERF_THRESHOLD, and compute per-unit selectivity and
# SR subpopulations for the rest.
# Relies on `pickle` (setup cell) and on helpers from `functions`: HiddenPrints,
# load_model_v2, label_trials_wcst, compute_sel_wcst, define_subpop_sr_wcst.
start = time.time()

plt.rc('font', size=12)

model_dir = ''
test_data_dir = ''

MODEL_NAME_TAGS = ('2023-05-10', 'wcst', 'success')  # substrings a model name must all contain
PERF_THRESHOLD = 0.8  # models with mean test performance <= this are excluded

all_data_sm_rule_resp_sel = []

# os.listdir returns entries in arbitrary order. Sort them so the result list (and every figure
# built from it) is reproducible. os.listdir('') raises, so an empty model_dir means the cwd.
for model_name in sorted(os.listdir(model_dir or os.curdir)):
    if not all(tag in model_name for tag in MODEL_NAME_TAGS):
        continue
    print(model_name)

    # Load the test data first: the performance check only needs this, so low-performing
    # models are skipped without calling load_model_v2.
    # NOTE(review): pickle.load can execute arbitrary code -- only use trusted test-data files.
    with open(os.path.join(test_data_dir, model_name + '_testdata_noiseless'), 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']
    mean_perf = np.mean(test_data['perfs'])
    if mean_perf <= PERF_THRESHOLD:
        print('low performance ({}), pass\n'.format(mean_perf))
        continue

    # load the model; its console output is suppressed
    path_to_file = os.path.join(model_dir, model_name)
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)

    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()

    # generate trial labels
    trial_labels = label_trials_wcst(test_data=test_data)
    rule1_trs_stable = trial_labels['rule1_trs_stable']
    rule2_trs_stable = trial_labels['rule2_trs_stable']
    # stable trials keyed by response, in the dict form compute_sel_wcst expects for `resp_trs_stable`
    resp_trs_stable = {'c1': trial_labels['c1_trs_stable'],
                       'c2': trial_labels['c2_trs_stable'],
                       'c3': trial_labels['c3_trs_stable']}

    # compute cell selectivity
    all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=np.asarray(test_data['rules']),
                                rule1_trs_stable=rule1_trs_stable, rule2_trs_stable=rule2_trs_stable,
                                rule1_trs_after_error=trial_labels['rule1_trs_after_error'], rule2_trs_after_error=trial_labels['rule2_trs_after_error'],
                                resp_trs_stable=resp_trs_stable, trs_by_center_card=trial_labels['trs_by_center_card_stable'],
                                stims=test_data['stims'], error_trials=trial_labels['error_trials'], trial_labels=trial_labels)

    # define subpopulations
    subcg_sr_idx = define_subpop_sr_wcst(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                         rule_sel=all_sels['rule_normalized_activity'], resp_sel=all_sels['resp_normalized'],
                                         ref_card_sel=all_sels['ref_card_normalized'],
                                         rule1_trs_stable=rule1_trs_stable,
                                         rule2_trs_stable=rule2_trs_stable,
                                         rule_threshold=0, resp_threshold=0)

    all_data_sm_rule_resp_sel.append({'model': model, 'model_name': model_name, 'hp': hp_test, 'cell_group_list': model.rnn.cell_group_list,
                                      'rule_sel': all_sels['rule_activity'],
                                      'rule_sel_roc': all_sels['rule_sel_roc'],
                                      'rule_sel_norm': all_sels['rule_normalized_activity'],
                                      'resp_sel': all_sels['resp'],
                                      'resp_sel_norm': all_sels['resp_normalized'],
                                      'mean_act': np.mean(rnn_activity, axis=(0, 1, 2)),
                                      'subcg_sr_idx': subcg_sr_idx,
                                      'all_sels': all_sels
                                      })
print('elapsed: {:.1f} s'.format(time.time() - start))

In [None]:
# Rule x response x shared-feature selectivity of the units in the `sr_*` cell groups,
# pooled across all loaded models.

SR_CELL_GROUPS = ('sr_esoma', 'sr_pv', 'sr_sst', 'sr_vip')

rule_sel_all = []
resp_sel_all = []
feature_sel_all = []

for x in all_data_sm_rule_resp_sel:
    # Take unit indices from *this* model. Do not use the `model` variable left over from the
    # loading loop: that is only the last model loaded, possibly one skipped for low performance.
    cg_idx = [n for cg in SR_CELL_GROUPS for n in x['model'].rnn.cg_idx[cg].tolist()]
    sels = x['all_sels']
    rule_sel_all.extend(sels['rule_normalized_activity'][n] for n in cg_idx)
    resp_sel_all.extend(sels['resp_normalized'][n]['max'] for n in cg_idx)
    feature_sel_all.extend(sels['common_feature'][n] for n in cg_idx)

fig = plt.figure()
fig.patch.set_facecolor('white')
ax = fig.add_subplot(projection='3d')
ax.scatter(xs=np.array(rule_sel_all), ys=np.array(resp_sel_all), zs=np.array(feature_sel_all), color='k')
ax.set_xlabel('Rule selectivity')
ax.set_ylabel('Response selectivity')
ax.set_zlabel('Shared feature selectivity')
ax.set_title('SR units pooled across {} models'.format(len(all_data_sm_rule_resp_sel)))
fig.tight_layout()
plt.show()

