In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np; np.set_printoptions(precision=2); np.random.seed(0)
import torch; torch.set_printoptions(precision=2)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt
# import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl
from mpltern.ternary.datasets import get_scatter_points


import seaborn as sns
import time
import sys
import itertools 
import random; random.seed(0)
import datetime
import pickle
import copy
import pandas as pd
import scipy; from scipy import stats; from scipy.stats import wilcoxon
import os

from sklearn.cluster import KMeans
from sklearn.manifold import MDS
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity

import sys    # duplicate of the earlier `import sys` in this cell
# NOTE(review): machine-specific paths. The relative sys.path entry below is resolved against the
# working directory at import time, and the following chdir presumably is what lets `functions`
# be imported from the code directory (assumes the kernel's sys.path includes the cwd) — confirm.
sys.path.append("../two_module_rnn/code")
os.chdir('/home/yl4317/Documents/two_module_rnn/code')
# from model_working import *
# The project helpers used later in this notebook (load_model_v2, HiddenPrints, label_trials_wcst,
# compute_sel_wcst, define_subpop_sr_wcst, make_pretty_axes, plot_conn_subpop) are not defined here
# and are presumably provided by this star import — TODO confirm and switch to explicit imports.
from functions import *

# after importing the helpers, work from the figure-code directory
os.chdir('/home/yl4317/Documents/two_module_rnn/code/code_for_figs')

# record library / interpreter versions alongside the outputs
print(torch.__version__)
print(sys.version)
                
%matplotlib inline

# Generate data: effective input weights of the sensorimotor module (output weights are generated in a later section)

In [None]:
# For every successful WCST model trained on 2023-05-10: skip models whose mean noiseless test
# performance is at most 0.8, compute single-cell selectivity, define the sensorimotor (SR) neuron
# pools, and store the model's effective input weight matrix. All results are pickled at the end.
start = time.time()
plt.rc('font', size=12)

models_dir = '/scratch/yl4317/two_module_rnn/saved_models/'
testdata_dir = '/scratch/yl4317/two_module_rnn/saved_testdata/'
required_tags = ('2023-05-10', 'wcst', 'success')    # a model must carry all of these in its name

all_data = []

for model_name in sorted(os.listdir(models_dir)):
    if not all(tag in model_name for tag in required_tags):
        continue
    print(model_name + '\n')

    # trained model (loader output silenced)
    path_to_file = models_dir + model_name
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name,
                                                                        simple=False, plot=False, toprint=False)

    # pre-generated noiseless test data; drop under-performing models
    with open(testdata_dir + model_name + '_testdata_noiseless_no_current_matrix', 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']
    mean_test_perf = np.mean([perf[0] for perf in test_data['perfs']])
    if mean_test_perf <= 0.8:
        print('perf too low ({}), pass\n'.format(mean_test_perf))
        continue
    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()

    # trial labels; stable trials split by response are bundled for compute_sel_wcst
    trial_labels = label_trials_wcst(test_data=test_data)
    rule1_trs_stable = trial_labels['rule1_trs_stable']
    rule2_trs_stable = trial_labels['rule2_trs_stable']
    resp_trs_stable = {c: trial_labels[c + '_trs_stable'] for c in ('c1', 'c2', 'c3')}

    # single-cell selectivity
    all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'],
                                rule1_trs_stable=rule1_trs_stable, rule2_trs_stable=rule2_trs_stable,
                                rule1_trs_after_error=trial_labels['rule1_trs_after_error'],
                                rule2_trs_after_error=trial_labels['rule2_trs_after_error'],
                                resp_trs_stable=resp_trs_stable, trs_by_center_card=trial_labels['trs_by_center_card_stable'],
                                stims=test_data['stims'], error_trials=trial_labels['error_trials'], trial_labels=trial_labels)

    # SR neuron pools, also registered on the model under a 'subcg_sr_' prefix
    subcg_sr_idx = define_subpop_sr_wcst(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                         rule_sel=all_sels['rule_normalized_activity'], resp_sel=all_sels['resp_normalized'],
                                         ref_card_sel=all_sels['ref_card_normalized'],
                                         rule1_trs_stable=rule1_trs_stable, rule2_trs_stable=rule2_trs_stable,
                                         rule_threshold=0, resp_threshold=0)
    for pop_name, pop_idx in subcg_sr_idx.items():
        model.rnn.cg_idx['subcg_sr_' + pop_name] = pop_idx

    w_in_eff = model.rnn.effective_weight(w=model.rnn.w_in, mask=model.rnn.mask_in).detach().cpu().numpy()
    all_data.append(dict(model=model, model_name=model_name, hp=hp_test,
                         n_sr_esoma=model.rnn.n['sr_esoma'], n_branches=model.rnn.n_branches,
                         w_in_eff=w_in_eff, subcg_sr_idx=subcg_sr_idx, all_sels=all_sels))

with open('/home/yl4317/Documents/two_module_rnn/processed_data/conn_bias_sm_w_in.pickle', 'wb') as handle:
    pickle.dump(all_data, handle, protocol=pickle.HIGHEST_PROTOCOL)
print('Elapsed time: {}s'.format(time.time() - start))

# Figure 5c: bias in the input weight, one example model

In [None]:
# Figure 5c (one example network): for each sensorimotor E neuron jointly selective for a rule, a choice
# and a reference card, compare its mean effective input weight from the two input units carrying its
# preferred feature (reference-card unit and test-card unit) with its mean weight from all other inputs.
with open('/home/yl4317/Documents/two_module_rnn/processed_data/conn_bias_sm_w_in.pickle', 'rb') as handle:
    all_data = pickle.load(handle)

example_model_name = 'success_2023-05-10-14-28-42_wcst_106_sparsity0'    # this is the example model
ref_cards = ['(0, 0)', '(1, 0)', '(0, 1)', '(1, 1)']
# per rule: reference cards whose rule-relevant feature has index 0 (blue or circle); others -> index 1 (red or square)
feature0_cards = {1: ('(0, 0)', '(0, 1)'), 2: ('(0, 0)', '(1, 0)')}

data_fig5c = {'big': [], 'small': []}

for data in all_data:
    if data['model_name'] != example_model_name:
        continue
    w_in = data['w_in_eff']
    pops = data['subcg_sr_idx']
    w_in_larges, w_in_large_refs, w_in_large_tests, w_in_smalls = [], [], [], []

    # same visiting order as nested loops rule -> choice -> reference card
    for r, choice, card in itertools.product([1, 2], [1, 2, 3], ref_cards):
        by_rule = pops['rule{}_sr_esoma'.format(r)]
        by_choice = pops['respc{}_sr_esoma'.format(choice)]
        by_card = pops['ref_card{}_sr_esoma'.format(card)]
        selective = [n for n in by_rule if n in by_choice and n in by_card]

        feature_id = 0 if card in feature0_cards[r] else 1
        ref_unit = 2 * (r - 1) + feature_id                            # reference-card input unit of the preferred feature
        test_unit = 4 + 4 * (choice - 1) + 2 * (r - 1) + feature_id    # test-card input unit of the preferred feature
        pref_units = [ref_unit, test_unit]
        other_units = [i for i in range(w_in.shape[0]) if i not in pref_units]

        for n in selective:
            dendrites = [n + (b + 1) * data['n_sr_esoma'] for b in range(data['n_branches'])]    # dendritic indices of neuron n
            w_in_larges.append(np.mean(w_in[np.ix_(pref_units, dendrites)]))
            w_in_smalls.append(np.mean(w_in[np.ix_(other_units, dendrites)]))
            w_in_large_refs.append(np.mean(w_in[ref_unit, dendrites]))
            w_in_large_tests.append(np.mean(w_in[test_unit, dendrites]))

    #=== plotting: one line per neuron, preferred (x=0) vs non-preferred (x=1) ===#
    fig, ax = plt.subplots(figsize=[10,7])
    fig.patch.set_facecolor('white')
    for big, small in zip(w_in_larges, w_in_smalls):
        ax.plot([big, small], marker='o', markersize=10, alpha=0.5, color='k')
    ax.set_xlim([-0.2, 1.2])
    ax.set_ylabel(r'Input weight', fontsize=20)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Preferred\nfeature', 'Non-preferred\nfeature'], rotation=0, fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()

    # save figure
    # fig.savefig('/home/yl4317/Documents/two_module_rnn/figs/w_in_sm_example.pdf')

    # one-sided test: preferred-feature weights larger than non-preferred ones
    # NOTE(review): both samples hold one value per neuron (paired); a paired test such as
    # scipy.stats.ttest_rel or wilcoxon may be more appropriate — confirm before changing reported stats.
    t, p = scipy.stats.ttest_ind(w_in_larges, w_in_smalls, alternative='greater')
    print('t={}, p-value={}, n={}'.format(t, p, len(w_in_larges)))

    # source data
    data_fig5c['big'] = w_in_larges
    data_fig5c['small'] = w_in_smalls
    # pd.DataFrame.from_dict(data=data_fig5c, orient='index').to_csv('/home/yl4317/Documents/two_module_rnn/source_data/fig5c_w_in_example.csv', header=False)





# Supplementary Figure 9a, b: input-weight bias across all networks

In [None]:
# Supplementary Figure 9a, b: the Figure 5c comparison (input weight from preferred- vs non-preferred-
# feature input units onto sensorimotor E-neuron dendrites), pooled over every network and plotted once
# per dendritic nonlinearity. Relies on `all_data` loaded by the Figure 5c cell.
dend_nonlinearities = ['subtractive', 'divisive_2']

# Build the per-nonlinearity containers directly (the previous version looped over
# `data_suppfig9.keys()`, an undefined name, and raised NameError on a fresh kernel).
data_suppfig9ab = {key: {'big': [], 'small': []} for key in dend_nonlinearities}


def _w_in_pref_vs_nonpref(data):
    """Per-neuron mean effective input weights for one network.

    For each (rule, choice, reference card) combination, selects the SR E somata that appear in all
    three selectivity pools of ``data['subcg_sr_idx']``. For each selected neuron it averages
    ``data['w_in_eff']`` over the neuron's dendritic columns, using four row sets:
      - the two "preferred" input units (reference-card and test-card unit of the matching feature),
      - each of those two units alone,
      - all remaining input units.

    Returns
    -------
    larges, large_refs, large_tests, smalls : list of float
        One entry per (combination, neuron), in the same order in all four lists.
    """
    w_in = data['w_in_eff']
    larges, large_refs, large_tests, smalls = [], [], [], []
    for r in [1, 2]:
        for choice in [1, 2, 3]:
            for card in ['(0, 0)', '(1, 0)', '(0, 1)', '(1, 1)']:
                neuron_id_rule = data['subcg_sr_idx']['rule{}_sr_esoma'.format(r)]
                neuron_id_choice = data['subcg_sr_idx']['respc{}_sr_esoma'.format(choice)]
                neuron_id_refcard = data['subcg_sr_idx']['ref_card{}_sr_esoma'.format(card)]
                # neurons preferring this combination of rule, choice and reference card (and therefore feature)
                neuron_id = [n for n in neuron_id_rule if (n in neuron_id_choice and n in neuron_id_refcard)]
                if (r == 1 and card in ('(0, 0)', '(0, 1)')) or (r == 2 and card in ('(0, 0)', '(1, 0)')):
                    feature_id = 0    # preferred matching feature is blue or circle
                else:
                    feature_id = 1    # preferred matching feature is red or square
                # input units that should project strongly to these neurons
                input_neuron_id_large_ref = 2*(r-1) + feature_id
                input_neuron_id_large_test = 4 + 4*(choice-1) + 2*(r-1) + feature_id
                input_neuron_id_large = [input_neuron_id_large_ref, input_neuron_id_large_test]
                input_neuron_id_small = [i for i in range(w_in.shape[0]) if i not in input_neuron_id_large]
                for n in neuron_id:
                    dend_id = [n + (b+1)*data['n_sr_esoma'] for b in range(data['n_branches'])]    # dendritic indices of neuron n
                    larges.append(np.mean(w_in[np.ix_(input_neuron_id_large, dend_id)]))
                    smalls.append(np.mean(w_in[np.ix_(input_neuron_id_small, dend_id)]))
                    large_refs.append(np.mean(w_in[input_neuron_id_large_ref, dend_id]))
                    large_tests.append(np.mean(w_in[input_neuron_id_large_test, dend_id]))
    return larges, large_refs, large_tests, smalls


for dend_nonlinear in dend_nonlinearities:
    w_in_large_all_models = []
    w_in_large_ref_all_models = []
    w_in_large_test_all_models = []
    w_in_small_all_models = []
    for data in all_data:
        if data['hp']['dend_nonlinearity'] != dend_nonlinear:
            continue
        larges, large_refs, large_tests, smalls = _w_in_pref_vs_nonpref(data)
        w_in_large_all_models.extend(larges)
        w_in_large_ref_all_models.extend(large_refs)
        w_in_large_test_all_models.extend(large_tests)
        w_in_small_all_models.extend(smalls)

    #=== plotting: one line per neuron, preferred (x=0) vs non-preferred (x=1) ===#
    fig, ax = plt.subplots(figsize=[10,7])
    fig.suptitle(dend_nonlinear)
    fig.patch.set_facecolor('white')
    for big, small in zip(w_in_large_all_models, w_in_small_all_models):
        ax.plot([big, small], marker='o', markersize=10, alpha=0.1, color='k')
    ax.set_xlim([-0.2, 1.2])
    ax.set_ylabel(r'Input weight', fontsize=20)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Preferred\nfeature', 'Non-preferred\nfeature'], rotation=0, fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()

    # save before showing, as matplotlib recommends, so the file does not depend on backend show/close behaviour
    fig.savefig('/home/yl4317/Documents/two_module_rnn/figs/w_in_sm_allNetworks_{}.pdf'.format(dend_nonlinear))
    plt.show()

    # one-sided test: preferred-feature weights larger than non-preferred ones
    t, p = scipy.stats.ttest_ind(w_in_large_all_models, w_in_small_all_models, alternative='greater')
    print('t={}, p-value={}, n={}'.format(t, p, len(w_in_large_all_models)))

    # save source data
    data_suppfig9ab[dend_nonlinear]['big'] = w_in_large_all_models
    data_suppfig9ab[dend_nonlinear]['small'] = w_in_small_all_models
    pd.DataFrame.from_dict(data=data_suppfig9ab[dend_nonlinear], orient='index').to_csv('/home/yl4317/Documents/two_module_rnn/source_data/suppfig9ab_w_in_allNets_{}.csv'.format(dend_nonlinear), header=False)
    

# Generate data for output weight

In [None]:
# For every successful WCST model trained on 2023-05-10: skip models whose mean noiseless test
# performance is at most 0.8, define the sensorimotor (SR) neuron pools, and store the model's
# effective readout weight matrix. All results are pickled at the end.
start = time.time()
plt.rc('font', size=12)

models_dir = '/scratch/yl4317/two_module_rnn/saved_models/'
testdata_dir = '/scratch/yl4317/two_module_rnn/saved_testdata/'
required_tags = ('2023-05-10', 'wcst', 'success')    # a model must carry all of these in its name

all_data_wout = []

for model_name in sorted(os.listdir(models_dir)):
    if not all(tag in model_name for tag in required_tags):
        continue
    print(model_name + '\n')

    # trained model (loader output silenced)
    path_to_file = models_dir + model_name
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name,
                                                                        simple=False, plot=False, toprint=False)

    # pre-generated noiseless test data; drop under-performing models
    with open(testdata_dir + model_name + '_testdata_noiseless_no_current_matrix', 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']
    mean_test_perf = np.mean([perf[0] for perf in test_data['perfs']])
    if mean_test_perf <= 0.8:
        print('perf too low ({}), pass\n'.format(mean_test_perf))
        continue
    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()

    # trial labels; stable trials split by response are bundled for compute_sel_wcst
    trial_labels = label_trials_wcst(test_data=test_data)
    rule1_trs_stable = trial_labels['rule1_trs_stable']
    rule2_trs_stable = trial_labels['rule2_trs_stable']
    resp_trs_stable = {c: trial_labels[c + '_trs_stable'] for c in ('c1', 'c2', 'c3')}

    # single-cell selectivity
    all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'],
                                rule1_trs_stable=rule1_trs_stable, rule2_trs_stable=rule2_trs_stable,
                                rule1_trs_after_error=trial_labels['rule1_trs_after_error'],
                                rule2_trs_after_error=trial_labels['rule2_trs_after_error'],
                                resp_trs_stable=resp_trs_stable, trs_by_center_card=trial_labels['trs_by_center_card_stable'],
                                stims=test_data['stims'], error_trials=trial_labels['error_trials'], trial_labels=trial_labels)

    # SR neuron pools, also registered on the model under a 'subcg_sr_' prefix
    subcg_sr_idx = define_subpop_sr_wcst(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                         rule_sel=all_sels['rule_normalized_activity'], resp_sel=all_sels['resp_normalized'],
                                         ref_card_sel=all_sels['ref_card_normalized'],
                                         rule1_trs_stable=rule1_trs_stable, rule2_trs_stable=rule2_trs_stable,
                                         rule_threshold=0, resp_threshold=0)
    for pop_name, pop_idx in subcg_sr_idx.items():
        model.rnn.cg_idx['subcg_sr_' + pop_name] = pop_idx

    # NOTE(review): the input-weight cell takes its mask from model.rnn.mask_in, but the readout mask here is
    # read from model.mask_out (not model.rnn.mask_out) — confirm this is the intended attribute.
    w_out_eff = model.rnn.effective_weight(w=model.rnn.w_out, mask=model.mask_out).detach().cpu().numpy()

    all_data_wout.append(dict(model=model, model_name=model_name, hp=hp_test,
                              w_out_eff=w_out_eff, subcg_sr_idx=subcg_sr_idx))

with open('/home/yl4317/Documents/two_module_rnn/processed_data/conn_bias_sm_w_out.pickle', 'wb') as handle:
    pickle.dump(all_data_wout, handle, protocol=pickle.HIGHEST_PROTOCOL)

print('Elapsed time: {}s'.format(time.time() - start))

# Figure 5d: w_out for an example model

In [None]:
# Figure 5d (one example network): for each sensorimotor E neuron selective for a choice, compare its
# effective readout weight onto that choice's output column with its mean readout weight onto the two
# other choices' columns.
with open('/home/yl4317/Documents/two_module_rnn/processed_data/conn_bias_sm_w_out.pickle', 'rb') as handle:
    all_data_wout = pickle.load(handle)

example_model_name = 'success_2023-05-10-14-28-42_wcst_136_sparsity0'    # this is the example model shown in the paper
# choice -> (output column of that choice, output columns of the two other choices)
choice_out_idx = {'c1': (0, [1, 2]), 'c2': (1, [0, 2]), 'c3': (2, [0, 1])}

data_fig5d = {'big': [], 'small': []}
found_example = False

for data in all_data_wout:
    if data['model_name'] != example_model_name:
        continue
    found_example = True
    print(data['model_name'])
    wout_pref = []
    wout_nonpref = []
    wout = data['w_out_eff']
    for resp, (wout_idx_pref, wout_idx_nonpref) in choice_out_idx.items():
        neuron_idx = data['subcg_sr_idx']['resp{}_sr_esoma'.format(resp)]
        wout_pref.extend(wout[neuron_idx, wout_idx_pref])
        wout_nonpref.extend(np.mean(wout[np.ix_(neuron_idx, wout_idx_nonpref)], axis=1))

    #===== plotting: one line per neuron, preferred (x=0) vs non-preferred (x=1) =====#
    fig, ax = plt.subplots(figsize=[10,7])
    fig.patch.set_facecolor('white')
    ax.plot([0, 1], [wout_pref, wout_nonpref], color='k', marker='o', alpha=0.5)
    ax.set_xlim([-0.2, 1.2])
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Preferred\nchoice', 'Non-preferred\nchoice'], rotation=0)
    ax.set_ylabel('Readout weight', fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()

    # The statistical test and source-data collection now run inside the loop, next to the data they use.
    # Previously they sat at cell level, so they raised NameError (fresh kernel) or silently reused stale
    # wout_pref / wout_nonpref from an earlier run whenever the example model was not found.
    t, p = scipy.stats.ttest_ind(wout_pref, wout_nonpref, alternative='greater')
    print('t={}, p-value={}, n={}'.format(t, p, len(wout_pref)))

    # save figure
    # fig.savefig('/home/yl4317/Documents/two_module_rnn/figs/w_out_example.pdf')

    # collect source data
    data_fig5d['big'].extend(wout_pref)
    data_fig5d['small'].extend(wout_nonpref)
    # pd.DataFrame.from_dict(data=data_fig5d, orient='index').to_csv('/home/yl4317/Documents/two_module_rnn/source_data/fig5d.csv', header=False)

if not found_example:
    raise ValueError('example model {} not found in conn_bias_sm_w_out.pickle'.format(example_model_name))

# Supplementary Figure 9c, d: w_out, across all models

In [None]:
# Supplementary Figure 9c, d: the Figure 5d readout-weight comparison pooled over every network, plotted
# once per dendritic nonlinearity. Relies on `all_data_wout` loaded by the Figure 5d cell.
# NOTE(review): the variable and CSV names say "suppfig19cd" while the section header says 9c, d — confirm which is intended.
data_suppfig19cd = {key: {'big': [], 'small': []} for key in ('subtractive', 'divisive_2')}
# choice -> (output column of that choice, output columns of the two other choices)
choice_out_idx = {'c1': (0, [1, 2]), 'c2': (1, [0, 2]), 'c3': (2, [0, 1])}


for dend_nonlinear in ['subtractive', 'divisive_2']:
    selected_models = [d for d in all_data_wout if d['hp']['dend_nonlinearity'] == dend_nonlinear]
    wout_pref_all = []
    wout_nonpref_all = []
    for data in selected_models:
        wout = data['w_out_eff']
        for resp, (pref_col, nonpref_cols) in choice_out_idx.items():
            neuron_idx = data['subcg_sr_idx']['resp{}_sr_esoma'.format(resp)]
            wout_pref_all.extend(wout[neuron_idx, pref_col])
            wout_nonpref_all.extend(np.mean(wout[np.ix_(neuron_idx, nonpref_cols)], axis=1))

    #===== plotting: one line per neuron, preferred (x=0) vs non-preferred (x=1) =====#
    fig, ax = plt.subplots(figsize=[10,7])
    fig.patch.set_facecolor('white')
    ax.plot([0, 1], [wout_pref_all, wout_nonpref_all], color='k', marker='o', alpha=0.1)
    ax.set_xlim([-0.2, 1.2])
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Preferred\nchoice', 'Non-preferred\nchoice'], rotation=0)
    ax.set_ylabel('Readout weight', fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()

    # one-sided test: preferred-choice readout weights larger than non-preferred ones
    t, p = scipy.stats.ttest_ind(wout_pref_all, wout_nonpref_all, alternative='greater')
    print('t={}, p-value={}, n={}'.format(t, p, len(wout_pref_all)))

    # save figure
    fig.savefig('/home/yl4317/Documents/two_module_rnn/figs/w_out_allNets_{}.pdf'.format(dend_nonlinear))

    # collect source data
    pooled = data_suppfig19cd[dend_nonlinear]
    pooled['big'].extend(wout_pref_all)
    pooled['small'].extend(wout_nonpref_all)
    pd.DataFrame.from_dict(data=pooled, orient='index').to_csv('/home/yl4317/Documents/two_module_rnn/source_data/suppfig19cd_wout_{}.csv'.format(dend_nonlinear), header=False)

# Figure 6e-g: compute the connectivity bias between the sensorimotor populations

In [None]:
# whether SR E soma are selective for both rule and response
start = time.time()
plot = False

conn_bias_sr_all_models = []

for model_name in sorted(os.listdir('/scratch/yl4317/two_module_rnn/saved_models/')):
    if ('2023-05-10' in model_name) and 'wcst' in model_name and 'success' in model_name:
#     if 'success_2022-09-19-13-52-05_wcst_15_longer_iti_5dend_nonlineaerities' in model_name:
#     if ('2022-10-24' in model_name) and 'success' in model_name:
        print(model_name)
        
        # load model
        path_to_file = '/scratch/yl4317/two_module_rnn/saved_models/'+model_name
        with HiddenPrints():
            model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)
#         print(hp_test['dend_nonlinearity'])
        # add filters here
        if hp_test['dt']!=10:
            continue
        if hp_test['dend_nonlinearity'] not in ['subtractive', 'divisive_2']:
            continue
#         if hp_test['sparse_srsst_to_sredend']!=0:
#             continue
#         if hp_test['no_pfcesoma_to_srsst']==True:
#             continue
        
        for key in ['dend_nonlinearity', 'sparse_srsst_to_sredend', 'initialization_weights', 'activation']:
            print(key, hp_test[key])
            
        # make noiseless
#         model.rnn.network_noise = 0
#         hp_test['input_noise_perceptual'] = 0
#         hp_test['input_noise_rule'] = 0
        
        # generate some neural data
#         neural_data = generate_neural_data_test(model=model, n_trials_test=100, switch_every_test=10, n_switches=5, to_plot=False, hp_test=hp_test, hp_task_test=hp_task_test)
        with open('/scratch/yl4317/two_module_rnn/saved_testdata/{}'.format(model_name+'_testdata_noiseless_no_current_matrix'), 'rb') as f:
            neural_data = pickle.load(f)
        test_data = neural_data['test_data']
        mean_perf = np.mean([_[0] for _ in test_data['perfs']])
        mean_perf_rule = np.mean([_[0] for _ in test_data['perf_rules']])
        if mean_perf<0.8 or mean_perf_rule<0.8:
            print('low performing model ({}/{})'.format(mean_perf, mean_perf_rule))
            continue
        rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()

        # generate trial labels
        trial_labels = label_trials_wcst(test_data=test_data)
        resp_trs_stable = {'c1': trial_labels['c1_trs_stable'], 'c2': trial_labels['c2_trs_stable'], 'c3': trial_labels['c3_trs_stable']}
        
        # compute cell selectivity
        all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'],
                                     rule1_trs_stable=trial_labels['rule1_trs_stable'], rule2_trs_stable=trial_labels['rule2_trs_stable'],
                                     rule1_trs_after_error = trial_labels['rule1_trs_after_error'], rule2_trs_after_error = trial_labels['rule2_trs_after_error'],
                                     resp_trs_stable=resp_trs_stable, trs_by_center_card=trial_labels['trs_by_center_card_stable'],
                                     stims=test_data['stims'], error_trials=trial_labels['error_trials'], trial_labels=trial_labels)
        rule_sel_used= all_sels['rule_normalized_activity']    
        
        
        resp_sel_normalized = all_sels['resp_normalized']
        rule_sel_normalized = all_sels['rule_normalized_activity']

        # subregions
        subcg_sr_idx = define_subpop_sr_wcst(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity, 
                                          rule_sel=all_sels['rule_normalized_activity'], resp_sel=all_sels['resp_normalized'], 
                                          ref_card_sel=all_sels['ref_card_normalized'],
                                          rule1_trs_stable=trial_labels['rule1_trs_stable'], 
                                          rule2_trs_stable=trial_labels['rule2_trs_stable'], 
                                          rule_threshold=0.0, resp_threshold=0.0)
    
        # plot connectivity between subpopulations
        w_rec_eff = model.rnn.effective_weight(w=model.rnn.w_rec, mask=model.rnn.mask, w_fix=model.rnn.w_fix).detach().cpu().numpy()
        if plot==True:
            _, _ = plot_conn_subpop(weight=w_rec_eff, cg_idx=subcg_sr_idx, subcg_to_plot_sender=['rule1_sr_esoma', 'rule2_sr_esoma', 'rule1_sr_pv', 'rule2_sr_pv'], subcg_to_plot_receiver=['rule1_sr_esoma', 'rule2_sr_esoma', 'rule1_sr_pv', 'rule2_sr_pv'])
#             _, _ = plot_conn_subpop(weight=w_rec_eff, cg_idx=subcg_sr_idx, subcg_to_plot_sender=['rule1_sr_pv', 'rule2_sr_pv'], subcg_to_plot_receiver=['rule1_sr_pv', 'rule2_sr_pv'])
#             resp_pops = ['{}_{}'.format(resp, cell) for (resp, cell) in itertools.product(*[['respc1', 'respc2', 'respc3'], ['sr_esoma', 'sr_pv']])]
            resp_pops = ['respc1_sr_esoma', 'respc2_sr_esoma', 'respc3_sr_esoma', 'respc1_sr_pv', 'respc2_sr_pv', 'respc3_sr_pv']
            _, _ = plot_conn_subpop(weight=w_rec_eff, cg_idx=subcg_sr_idx, subcg_to_plot_sender=resp_pops, subcg_to_plot_receiver=resp_pops)
            _, _ = plot_conn_subpop(weight=w_rec_eff, cg_idx=subcg_sr_idx, subcg_to_plot_sender=['respc1_sr_pv', 'respc2_sr_pv', 'respc3_sr_pv'], subcg_to_plot_receiver=['respc1_sr_pv', 'respc2_sr_pv', 'respc3_sr_pv'])
#             ref_card_pops = ['ref_card(0, 0)_sr_esoma', 'ref_card(0, 1)_sr_esoma', 'ref_card(1, 0)_sr_esoma', 'ref_card(1, 1)_sr_esoma', 
#                              'ref_card(0, 0)_sr_pv', 'ref_card(0, 1)_sr_pv', 'ref_card(1, 0)_sr_pv', 'ref_card(1, 1)_sr_pv', ]
#             _, _ = plot_conn_subpop(weight=w_rec_eff, cg_idx=subcg_sr_idx, subcg_to_plot_sender=ref_card_pops, subcg_to_plot_receiver=ref_card_pops)
            shared_feature_pops_color = ['blue_sr_esoma', 'red_sr_esoma', 'blue_sr_pv', 'red_sr_pv']
            _, _ = plot_conn_subpop(weight=w_rec_eff, cg_idx=subcg_sr_idx, subcg_to_plot_sender=shared_feature_pops_color, subcg_to_plot_receiver=shared_feature_pops_color)
            shared_feature_pops_shape = ['circle_sr_esoma', 'triangle_sr_esoma', 'circle_sr_pv', 'triangle_sr_pv']
            _, _ = plot_conn_subpop(weight=w_rec_eff, cg_idx=subcg_sr_idx, subcg_to_plot_sender=shared_feature_pops_shape, subcg_to_plot_receiver=shared_feature_pops_shape)
            
        
        
        
        
        # compute connectivity bias for rule
        conn_bias_rulee_rulee = np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule1_sr_esoma'], subcg_sr_idx['rule1_sr_esoma'])])\
                                  +np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule2_sr_esoma'], subcg_sr_idx['rule2_sr_esoma'])])\
                                  -np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule1_sr_esoma'], subcg_sr_idx['rule2_sr_esoma'])])\
                                  -np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule2_sr_esoma'], subcg_sr_idx['rule1_sr_esoma'])])
        conn_bias_rulee_rulepv = np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule1_sr_esoma'], subcg_sr_idx['rule1_sr_pv'])])\
                                  +np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule2_sr_esoma'], subcg_sr_idx['rule2_sr_pv'])])\
                                  -np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule1_sr_esoma'], subcg_sr_idx['rule2_sr_pv'])])\
                                  -np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule2_sr_esoma'], subcg_sr_idx['rule1_sr_pv'])])
        conn_bias_rulepv_rulee = np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule1_sr_pv'], subcg_sr_idx['rule2_sr_esoma'])])\
                                  +np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule2_sr_pv'], subcg_sr_idx['rule1_sr_esoma'])])\
                                  -np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule1_sr_pv'], subcg_sr_idx['rule1_sr_esoma'])])\
                                  -np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule2_sr_pv'], subcg_sr_idx['rule2_sr_esoma'])])
        conn_bias_rulepv_rulee = -conn_bias_rulepv_rulee
        conn_bias_rulepv_rulepv = np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule1_sr_pv'], subcg_sr_idx['rule2_sr_pv'])])\
                                  +np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule2_sr_pv'], subcg_sr_idx['rule1_sr_pv'])])\
                                  -np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule1_sr_pv'], subcg_sr_idx['rule1_sr_pv'])])\
                                  -np.mean(w_rec_eff[np.ix_(subcg_sr_idx['rule2_sr_pv'], subcg_sr_idx['rule2_sr_pv'])])
        conn_bias_rulepv_rulepv = -conn_bias_rulepv_rulepv
        
                     
            
            
        # connectivity bias based on neurons preferring different responses              
        conn = {}
        conn[('resp_sr_esoma', 'resp_sr_esoma')] = {}
        conn[('resp_sr_esoma', 'resp_sr_pv')] = {}
        conn[('resp_sr_pv', 'resp_sr_esoma')] = {}
        conn[('resp_sr_pv', 'resp_sr_pv')] = {}
        for key in conn.keys():
            for x in ['same', 'cross']:
                conn[key][x] = []
        
        for resp1 in ['c1', 'c2', 'c3']:
            for resp2 in ['c1', 'c2', 'c3']:
                if resp1==resp2:
                    for sender, receiver in itertools.product(['sr_esoma', 'sr_pv'], repeat=2):
                        conn[('resp_{}'.format(sender), 'resp_{}'.format(receiver))]['same'].append(np.mean(w_rec_eff[np.ix_(subcg_sr_idx['resp{}_{}'.format(resp1, sender)], subcg_sr_idx['resp{}_{}'.format(resp2, receiver)])]))
                elif resp1!=resp2:
                    for sender, receiver in itertools.product(['sr_esoma', 'sr_pv'], repeat=2):
                        conn[('resp_{}'.format(sender), 'resp_{}'.format(receiver))]['cross'].append(np.mean(w_rec_eff[np.ix_(subcg_sr_idx['resp{}_{}'.format(resp1, sender)], subcg_sr_idx['resp{}_{}'.format(resp2, receiver)])]))     
                        
        conn_bias_respe_respe = np.mean(conn[('resp_sr_esoma', 'resp_sr_esoma')]['same']) - np.mean(conn[('resp_sr_esoma', 'resp_sr_esoma')]['cross'])
        conn_bias_respe_resppv = np.mean(conn[('resp_sr_esoma', 'resp_sr_pv')]['same']) - np.mean(conn[('resp_sr_esoma', 'resp_sr_pv')]['cross'])
        conn_bias_resppv_respe = - (np.mean(conn[('resp_sr_pv', 'resp_sr_esoma')]['cross']) - np.mean(conn[('resp_sr_pv', 'resp_sr_esoma')]['same']))
        conn_bias_resppv_resppv = - (np.mean(conn[('resp_sr_pv', 'resp_sr_pv')]['cross']) - np.mean(conn[('resp_sr_pv', 'resp_sr_pv')]['same']))
        
        
        
        
        # connectivity bias based on neurons preferring different reference cards
        conn = {}
        conn[('ref_card_sr_esoma', 'ref_card_sr_esoma')] = {}
        conn[('ref_card_sr_esoma', 'ref_card_sr_pv')] = {}
        conn[('ref_card_sr_pv', 'ref_card_sr_esoma')] = {}
        conn[('ref_card_sr_pv', 'ref_card_sr_pv')] = {}
        for key in conn.keys():
            for x in ['same', 'cross']:
                conn[key][x] = []
            
        for ref_card1 in [(0, 0), (0, 1), (1, 0), (1, 1)]:
            for ref_card2 in [(0, 0), (0, 1), (1, 0), (1, 1)]:
                if ref_card1==ref_card2:
                    for sender, receiver in itertools.product(['sr_esoma', 'sr_pv'], repeat=2):
                        conn[('ref_card_{}'.format(sender), 'ref_card_{}'.format(receiver))]['same'].append(np.mean(w_rec_eff[np.ix_(subcg_sr_idx['ref_card{}_{}'.format(ref_card1, sender)], subcg_sr_idx['ref_card{}_{}'.format(ref_card2, receiver)])]))
                elif ref_card1!=ref_card2:
                    for sender, receiver in itertools.product(['sr_esoma', 'sr_pv'], repeat=2):
                        conn[('ref_card_{}'.format(sender), 'ref_card_{}'.format(receiver))]['cross'].append(np.mean(w_rec_eff[np.ix_(subcg_sr_idx['ref_card{}_{}'.format(ref_card1, sender)], subcg_sr_idx['ref_card{}_{}'.format(ref_card2, receiver)])]))     
                        
        conn_bias_ref_card_e_ref_card_e = np.mean(conn[('ref_card_sr_esoma', 'ref_card_sr_esoma')]['same']) - np.mean(conn[('ref_card_sr_esoma', 'ref_card_sr_esoma')]['cross'])
        conn_bias_ref_card_e_ref_card_pv = np.mean(conn[('ref_card_sr_esoma', 'ref_card_sr_pv')]['same']) - np.mean(conn[('ref_card_sr_esoma', 'ref_card_sr_pv')]['cross'])
        conn_bias_ref_card_pv_ref_card_e = - (np.mean(conn[('ref_card_sr_pv', 'ref_card_sr_esoma')]['cross']) - np.mean(conn[('ref_card_sr_pv', 'ref_card_sr_esoma')]['same']))
        conn_bias_ref_card_pv_ref_card_pv = - (np.mean(conn[('ref_card_sr_pv', 'ref_card_sr_pv')]['cross']) - np.mean(conn[('ref_card_sr_pv', 'ref_card_sr_pv')]['same']))
        
        
        # connectivity bias based on neurons preferring different shared features (colors blue/red, shapes circle/triangle)
        conn = {}
        conn[('shared_feature_sr_esoma', 'shared_feature_sr_esoma')] = {}
        conn[('shared_feature_sr_esoma', 'shared_feature_sr_pv')] = {}
        conn[('shared_feature_sr_pv', 'shared_feature_sr_esoma')] = {}
        conn[('shared_feature_sr_pv', 'shared_feature_sr_pv')] = {}
        for key in conn.keys():
            for x in ['same', 'cross']:
                conn[key][x] = []
        
        shared_features = ['blue', 'red', 'circle', 'triangle']
        for f1 in shared_features:
            for f2 in shared_features:
                for sender, receiver in itertools.product(['sr_esoma', 'sr_pv'], repeat=2):
                    mean_weight = np.mean(w_rec_eff[np.ix_(subcg_sr_idx['{}_{}'.format(f1, sender)], subcg_sr_idx['{}_{}'.format(f2, receiver)])])
                    if f1==f2:
                        conn[('shared_feature_{}'.format(sender), 'shared_feature_{}'.format(receiver))]['same'].append(mean_weight)
                    elif f1!=f2 and shared_features.index(f1)//2 == shared_features.index(f2)//2:    # do not include for example f1=blue and f2=square
                        conn[('shared_feature_{}'.format(sender), 'shared_feature_{}'.format(receiver))]['cross'].append(mean_weight)     
                        
        conn_bias_shared_feature_e_e = np.mean(conn[('shared_feature_sr_esoma', 'shared_feature_sr_esoma')]['same']) - np.mean(conn[('shared_feature_sr_esoma', 'shared_feature_sr_esoma')]['cross'])
        conn_bias_shared_feature_e_pv = np.mean(conn[('shared_feature_sr_esoma', 'shared_feature_sr_pv')]['same']) - np.mean(conn[('shared_feature_sr_esoma', 'shared_feature_sr_pv')]['cross'])
        conn_bias_shared_feature_pv_e = - (np.mean(conn[('shared_feature_sr_pv', 'shared_feature_sr_esoma')]['cross']) - np.mean(conn[('shared_feature_sr_pv', 'shared_feature_sr_esoma')]['same']))
        conn_bias_shared_feature_pv_pv = - (np.mean(conn[('shared_feature_sr_pv', 'shared_feature_sr_pv')]['cross']) - np.mean(conn[('shared_feature_sr_pv', 'shared_feature_sr_pv')]['same']))
        
        
        
        
        
        conn_bias_sr_all_models.append({'model': model_name, 
                                        'hp': hp_test,
                                        'bias_ruleesoma_ruleesoma': conn_bias_rulee_rulee, 
                                        'bias_ruleesoma_rulepv': conn_bias_rulee_rulepv, 
                                        'bias_rulepv_ruleesoma': conn_bias_rulepv_rulee, 
                                        'bias_rulepv_rulepv': conn_bias_rulepv_rulepv,
                                        'bias_respesoma_respesoma': conn_bias_respe_respe, 
                                        'bias_respesoma_resppv': conn_bias_respe_resppv, 
                                        'bias_resppv_respesoma': conn_bias_resppv_respe, 
                                        'bias_resppv_resppv': conn_bias_resppv_resppv,
                                        'bias_ref_card_esoma_ref_card_esoma': conn_bias_ref_card_e_ref_card_e, 
                                        'bias_ref_card_esoma_ref_card_pv': conn_bias_ref_card_e_ref_card_pv, 
                                        'bias_ref_card_pv_ref_card_esoma': conn_bias_ref_card_pv_ref_card_e, 
                                        'bias_ref_card_pv_ref_card_pv': conn_bias_ref_card_pv_ref_card_pv,
                                        'bias_shared_feature_esoma_esoma': conn_bias_shared_feature_e_e, 
                                        'bias_shared_feature_esoma_pv': conn_bias_shared_feature_e_pv, 
                                        'bias_shared_feature_pv_esoma': conn_bias_shared_feature_pv_e, 
                                        'bias_shared_feature_pv_pv': conn_bias_shared_feature_pv_pv,
                                       })
        # print('biases: {}'.format(conn_bias_sr_all_models[-1][2:]))

print(time.time()-start)

In [None]:
# Connectivity biases of the sensorimotor module across models, one figure per family of
# selective subpopulations (rule, response location, reference card, shared feature).
# Values behind the rule / response / shared-feature figures are also collected into
# data_fig7e / data_fig7f / data_fig7g, which the next cell exports as source data.
data_fig7e = {'exc_exc': [], 'exc_pv': [], 'pv_exc': [], 'pv_pv': []}    # connectivity biases between rule-selective populations
data_fig7f = {'exc_exc': [], 'exc_pv': [], 'pv_exc': [], 'pv_pv': []}    # connectivity biases between response location-selective populations
data_fig7g = {'exc_exc': [], 'exc_pv': [], 'pv_exc': [], 'pv_pv': []}    # connectivity biases between shared feature-selective populations

# order in which plotted values are appended to a store; matches the E->E, E->PV, PV->E, PV->PV
# order of every bias_keys list passed to plot_conn_bias below
CONN_BIAS_STORE_KEYS = ('exc_exc', 'exc_pv', 'pv_exc', 'pv_pv')


def plot_conn_bias(models, bias_keys, tick_labels, store=None, nonlinearities=('subtractive',)):
    """Plot one line per model showing its connectivity biases for one subpopulation family.

    Parameters
    ----------
    models : list of dict
        Per-model records such as those in ``conn_bias_sr_all_models``; each must contain
        ``'hp'`` (with a ``'dend_nonlinearity'`` entry) and every key in ``bias_keys``.
    bias_keys : sequence of str
        Record keys to plot, in the order E->E, E->PV, PV->E, PV->PV.
    tick_labels : sequence of str
        x-tick labels, one per entry of ``bias_keys``.
    store : dict of lists, optional
        If given, the plotted values of each kept model are appended to
        ``store['exc_exc']``, ``store['exc_pv']``, ``store['pv_exc']`` and ``store['pv_pv']``.
    nonlinearities : tuple of str
        Only models whose ``hp['dend_nonlinearity']`` is in this tuple are plotted.
        Models with any NaN bias are skipped as well.
    """
    fig, ax = plt.subplots(figsize=[7,3])
    fig.suptitle('Connectivity biases across all models', fontsize=20)
    fig.patch.set_facecolor('white')
    for record in models:
        if record['hp']['dend_nonlinearity'] not in nonlinearities:
            continue
        # look the biases up by name rather than slicing list(record.values()):
        # positional slicing silently mislabels data if the record layout ever changes
        data = [record[key] for key in bias_keys]
        if np.isnan(data).any():
            continue
        if store is not None:
            for store_key, value in zip(CONN_BIAS_STORE_KEYS, data):
                store[store_key].append(value)
        ax.plot(data, marker='o', color='k', linewidth=2, markersize=10, alpha=0.5)
    # size the axis from the requested biases, not from the leftover loop variable,
    # which is undefined (or stale from another cell) when no model passes the filters
    n_bias = len(bias_keys)
    ax.set_xticks(np.arange(n_bias))
    ax.set_xticklabels(tick_labels, rotation=10)
    ax.axhline(y=0, ls='--', color='k')
    ax.set_xlim(-0.5, n_bias-0.5)
    ax.set_ylabel('Connectivity bias', fontsize=20)
    make_pretty_axes(ax)
    plt.show()


# rule-selective populations
plot_conn_bias(conn_bias_sr_all_models,
               bias_keys=['bias_ruleesoma_ruleesoma', 'bias_ruleesoma_rulepv', 'bias_rulepv_ruleesoma', 'bias_rulepv_rulepv'],
               tick_labels=[r'rule E $\rightarrow$ rule E', r'rule E $\rightarrow$ rule PV', r'rule PV $\rightarrow$ rule E', r'rule PV $\rightarrow$ rule PV'],
               store=data_fig7e)

# response location-selective populations
plot_conn_bias(conn_bias_sr_all_models,
               bias_keys=['bias_respesoma_respesoma', 'bias_respesoma_resppv', 'bias_resppv_respesoma', 'bias_resppv_resppv'],
               tick_labels=[r'resp E $\rightarrow$ resp E', r'resp E $\rightarrow$ resp PV', r'resp PV $\rightarrow$ resp E', r'resp PV $\rightarrow$ resp PV'],
               store=data_fig7f)

# reference card-selective populations (plotted only, not exported)
plot_conn_bias(conn_bias_sr_all_models,
               bias_keys=['bias_ref_card_esoma_ref_card_esoma', 'bias_ref_card_esoma_ref_card_pv', 'bias_ref_card_pv_ref_card_esoma', 'bias_ref_card_pv_ref_card_pv'],
               tick_labels=[r'ref card E $\rightarrow$ ref card E', r'ref card E $\rightarrow$ ref card PV', r'ref card PV $\rightarrow$ ref card E', r'ref card PV $\rightarrow$ ref card PV'])

# shared feature-selective populations
plot_conn_bias(conn_bias_sr_all_models,
               bias_keys=['bias_shared_feature_esoma_esoma', 'bias_shared_feature_esoma_pv', 'bias_shared_feature_pv_esoma', 'bias_shared_feature_pv_pv'],
               tick_labels=[r'shared feature E $\rightarrow$ E', r'shared feature E $\rightarrow$ PV', r'shared feature PV $\rightarrow$ E', r'shared feature PV $\rightarrow$ PV'],
               store=data_fig7g)
    

In [None]:
# Export the per-model connectivity biases collected above as source-data CSVs.
# Each table has one row per connection type (exc_exc, exc_pv, pv_exc, pv_pv) and one column per model.
source_data_exports = [
    (data_fig7e, '/home/yl4317/Documents/two_module_rnn/source_data/fig7e.csv'),
    (data_fig7f, '/home/yl4317/Documents/two_module_rnn/source_data/fig7f.csv'),
    (data_fig7g, '/home/yl4317/Documents/two_module_rnn/source_data/fig7g.csv'),
]
for fig_data, csv_path in source_data_exports:
    fig_table = pd.DataFrame.from_dict(data=fig_data, orient='index')
    fig_table.to_csv(csv_path, header=False)

In [None]:
# start = time.time()
# plt.rc('font', size=12)

# all_data_wout = []

# for model_name in sorted(os.listdir('/scratch/yl4317/two_module_rnn/saved_models/')):
#     if ('2023-05-10' in model_name) and 'wcst' in model_name and 'success' in model_name:
#         print(model_name+'\n')
        
# #         # load model
#         path_to_file = '/scratch/yl4317/two_module_rnn/saved_models/'+model_name
#         with HiddenPrints():
#             model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file,model_name=model_name, simple=False, plot=False, toprint=False)
# #         if hp_test['dt']!=10:
# #             print('pass\n')
# #             continue
# #         if len(hp_test['cell_group_list'])==2:
# #             print('pass\n')
# #             continue
# #         if hp_test['dend_nonlinearity'] not in ['v2', 'subtract_v2', 'old_v2', 'subtract_v3', 'v3', 'v2_std']:
# #             continue
# #         if hp_test['no_pfcesoma_to_srsst']==True:
# #             continue
# #         if hp_test['sparse_srsst_to_sredend']!=0:
# #             continue
        
#         for key in ['dend_nonlinearity', 'sparse_srsst_to_sredend', 'initialization_weights', 'activation']:
#             print(key, hp_test[key])
            
#         # make noiseless
# #         model.rnn.network_noise = 0
# #         hp_test['input_noise_perceptual'] = 0
# #         hp_test['input_noise_rule'] = 0
        
#         # generate some neural data
# #         neural_data = generate_neural_data_test(model=model, n_trials_test=100, switch_every_test=10, to_plot=False, hp_test=hp_test, hp_task_test=hp_task_test, compute_current=False)
#         with open('/scratch/yl4317/two_module_rnn/saved_testdata/'+model_name+'_testdata_noiseless_no_current_matrix', 'rb') as f:
#             neural_data = pickle.load(f)
#         test_data = neural_data['test_data']
#         mean_test_perf = np.mean([_[0] for _ in test_data['perfs']])
#         if mean_test_perf<=0.8:
#             print('perf too low ({}), pass\n'.format(mean_test_perf))
#             continue
#         rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()
# #         current_matrix = neural_data['current_matrix']
        
#         # generate trial labels
#         trial_labels = label_trials_wcst(test_data=test_data)
#         rule1_trs_stable = trial_labels['rule1_trs_stable']
#         rule2_trs_stable = trial_labels['rule2_trs_stable']
#         rule1_trs_after_error = trial_labels['rule1_trs_after_error']
#         rule2_trs_after_error = trial_labels['rule2_trs_after_error']
#         c1_trs_stable = trial_labels['c1_trs_stable']
#         c2_trs_stable = trial_labels['c2_trs_stable']
#         c3_trs_stable = trial_labels['c3_trs_stable']
#         resp_trs_stable = {'c1': c1_trs_stable, 'c2': c2_trs_stable, 'c3': c3_trs_stable}    # to be used as an argument in the "compute_sel_wcst" function
#         error_trials = trial_labels['error_trials']
        
#         # compute cell selectivity
#         all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'],
#                                          rule1_trs_stable=trial_labels['rule1_trs_stable'], rule2_trs_stable=trial_labels['rule2_trs_stable'],
#                                          rule1_trs_after_error = trial_labels['rule1_trs_after_error'], rule2_trs_after_error=trial_labels['rule2_trs_after_error'],
#                                          resp_trs_stable = resp_trs_stable, trs_by_center_card=trial_labels['trs_by_center_card_stable'],
#                                          stims=test_data['stims'], error_trials=trial_labels['error_trials'])

#         # define neuron pools
#         subcg_sr_idx = define_subpop_sr_wcst(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity, 
#                                           rule_sel=all_sels['rule_normalized_activity'], resp_sel=all_sels['resp_normalized'], ref_card_sel=all_sels['ref_card_normalized'],
#                                           rule1_trs_stable=rule1_trs_stable, 
#                                           rule2_trs_stable=rule2_trs_stable, 
#                                           rule_threshold=0, resp_threshold=0)
#         for subcg in subcg_sr_idx.keys():
#             model.rnn.cg_idx['subcg_sr_'+subcg] = subcg_sr_idx[subcg]
            
#         w_out_eff = model.rnn.effective_weight(w=model.rnn.w_out, mask=model.mask_out).detach().cpu().numpy()

#         all_data_wout.append({
#                               'model': model, 
#                               'model_name': model_name, 
#                               'hp': hp_test, 
#                               'w_out_eff': w_out_eff,
#                               'subcg_sr_idx': subcg_sr_idx
#                              })
                
# print('Elapsed time: {}s'.format(time.time()-start))     

In [None]:
# wout_pref_all = []
# wout_nonpref_all = []

# for data in all_data_wout:
#     if data['hp']['dend_nonlinearity']!='divisive_2':
#         continue
#     print(data['model_name'])
#     wout_pref = []
#     wout_nonpref = []
#     wout = data['w_out_eff']
#     for resp in ['c1', 'c2', 'c3']:
#         neuron_idx = data['subcg_sr_idx']['resp{}_sr_esoma'.format(resp)]
#         if resp=='c1':
#             wout_idx_pref, wout_idx_nonpref = 0, [1, 2]
#         elif resp=='c2':
#             wout_idx_pref, wout_idx_nonpref = 1, [0, 2]
#         elif resp=='c3':
#             wout_idx_pref, wout_idx_nonpref = 2, [0, 1]
#         wout_pref.extend(wout[neuron_idx, wout_idx_pref])
#         wout_nonpref.extend(np.mean(wout[np.ix_(neuron_idx, wout_idx_nonpref)], axis=1))
#         wout_pref_all.extend(wout[neuron_idx, wout_idx_pref])
#         wout_nonpref_all.extend(np.mean(wout[np.ix_(neuron_idx, wout_idx_nonpref)], axis=1))
        
#     fig, ax = plt.subplots(figsize=[10,7])
#     fig.patch.set_facecolor('white')
#     ax.plot([0, 1], [wout_pref, wout_nonpref], color='k', marker='o', alpha=0.5)
#     ax.set_xlim([-0.2, 1.2])
#     ax.set_xticks([0, 1])
#     ax.set_xticklabels(['Preferred\nchoice', 'Non-preferred\nchoice'], rotation=0)
#     ax.set_ylabel('Readout weight', fontsize=20)
#     make_pretty_axes(ax)
#     fig.tight_layout()
#     plt.show()
    
# print('all models')
# fig, ax = plt.subplots(figsize=[7,7])
# fig.patch.set_facecolor('white')
# ax.plot([0, 1], [wout_pref_all, wout_nonpref_all], color='k', marker='o', alpha=0.02)
# ax.set_xlim([-0.3, 1.3])
# ax.set_xticks([0, 1])
# ax.set_xticklabels(['Preferred\nchoice', 'Non-preferred\nchoice'], rotation=0)
# ax.set_ylabel('Readout weight', fontsize=20)
# make_pretty_axes(ax)
# fig.tight_layout()
# plt.show()