In [None]:
%load_ext autoreload
# Re-import edited modules (e.g. the project's `functions` module) automatically before each cell runs.
%autoreload 2

# Print options and random seeds for the whole notebook.
# NOTE(review): numpy and `random` are seeded with 0 but torch with `seed` (= 1) -- confirm this mismatch is intended.
import numpy as np; np.set_printoptions(precision=2); np.random.seed(0)
import torch; torch.set_printoptions(precision=2)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
# Default figure font size for the notebook (later cells may override it with plt.rc).
import matplotlib.pyplot as plt; plt.rc('font', size=12)
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl

import seaborn as sns
import time
import sys
import itertools
import random; random.seed(0)
import datetime
import pickle
import copy
import pandas as pd
import scipy
import os

from sklearn.cluster import KMeans
from sklearn.manifold import MDS
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity

# NOTE(review): duplicate of the `import sys` above; harmless.
import sys

# from model import *
# Star import: presumably supplies the project helpers used in later cells (load_model_v2, HiddenPrints,
# label_trials_wcst, compute_sel_wcst, define_subpop_pfc, plot_conn_subpop, make_pretty_axes) -- confirm.
from functions import *


# Record library / interpreter versions in the output for reproducibility.
print(torch.__version__)
print(sys.version)

%matplotlib inline

# Figure 3a, b: connectivity matrices between different groups of units

In [None]:
start = time.time()

# --- configuration -------------------------------------------------------------------------------
MODEL_DIR = '/model/directory/'                 # directory holding the trained models
TEST_DATA_DIR = '/where/test/run/is/stored/'    # directory holding the pickled noiseless test runs
MODEL_DATE_TAG = '2023-05-10'                   # only models whose name contains this tag (and 'success')
PERF_THRESHOLD = 0.8                            # skip models whose mean perf or mean rule perf is below this

all_data_conn_bias = []

plot = True

subcg_label_converter = {'rule1_pfc_esoma': 'rule 1 exc', 'rule2_pfc_esoma': 'rule 2 exc', 'rule1_pfc_pv': 'rule 1\n PV', 'rule2_pfc_pv': 'rule 2\n PV', 'mix_err_rule1_pfc_esoma': 'error x rule 1', 
                         'mix_err_rule2_pfc_esoma': 'error x rule 2'}

# Subpopulations that make up the "rule network".
rule_subregion = ['rule1_pfc_esoma', 'rule2_pfc_esoma', 'rule1_pfc_pv', 'rule2_pfc_pv']

# (senders, receivers) pairs drawn with plot_conn_subpop for every model (Figure 3a, b).
CONN_PLOTS = [
    (['rule1_pfc_esoma', 'rule2_pfc_esoma'], ['mix_err_rule1_pfc_esoma', 'mix_err_rule2_pfc_esoma']),
    (['rule1_pfc_pv', 'rule2_pfc_pv'], ['mix_err_rule1_pfc_esoma', 'mix_err_rule2_pfc_esoma']),
    (['mix_err_rule1_pfc_esoma', 'mix_err_rule2_pfc_esoma'], ['rule1_pfc_esoma', 'rule2_pfc_esoma']),
    (['mix_err_rule1_pfc_esoma', 'mix_err_rule2_pfc_esoma'], ['rule1_pfc_pv', 'rule2_pfc_pv']),
    (rule_subregion, rule_subregion),
]


def mean_conn(w, idx, sender, receiver):
    """Mean weight from subpopulation `sender` onto subpopulation `receiver`.

    `w` is indexed as w[sending units, receiving units], i.e. rows are senders and
    columns are receivers, as for the effective recurrent weights in this notebook.
    `idx` maps a subpopulation name to its unit indices. An empty subpopulation
    yields NaN (numpy emits a RuntimeWarning).
    """
    return np.mean(w[np.ix_(idx[sender], idx[receiver])])


def conn_bias(w, idx, sender, plus, minus):
    """Rule-symmetric connectivity bias, summed over both rules.

    `sender`, `plus` and `minus` are subpopulation-name templates in which ``{r}``
    is the sender's rule and ``{o}`` is the other rule. For r = 1 and r = 2 this adds
        mean_conn(sender_r -> plus_r) - mean_conn(sender_r -> minus_r).
    """
    terms = []
    for r, o in ((1, 2), (2, 1)):
        sender_r = sender.format(r=r, o=o)
        terms.append(mean_conn(w, idx, sender_r, plus.format(r=r, o=o))
                     - mean_conn(w, idx, sender_r, minus.format(r=r, o=o)))
    return terms[0] + terms[1]


# Bias name -> (sender, plus, minus) templates for conn_bias. For PV/SST senders the receivers are
# listed in the order that reproduces "compute (minus - plus), then negate" exactly.
BIAS_SPECS = {
    # rule -> error x rule
    'bias_ruleesoma_mixerr':    ('rule{r}_pfc_esoma', 'mix_err_rule{o}_pfc_esoma', 'mix_err_rule{r}_pfc_esoma'),
    'bias_rulepv_mixerr':       ('rule{r}_pfc_pv',    'mix_err_rule{o}_pfc_esoma', 'mix_err_rule{r}_pfc_esoma'),
    'bias_rulesst_mixerr':      ('rule{r}_pfc_sst',   'mix_err_rule{o}_pfc_edend', 'mix_err_rule{r}_pfc_edend'),
    # rule -> correct x rule
    'bias_ruleesoma_mixcorr':   ('rule{r}_pfc_esoma', 'mix_corr_rule{r}_pfc_esoma', 'mix_corr_rule{o}_pfc_esoma'),
    'bias_rulepv_mixcorr':      ('rule{r}_pfc_pv',    'mix_corr_rule{r}_pfc_esoma', 'mix_corr_rule{o}_pfc_esoma'),
    'bias_rulesst_mixcorr':     ('rule{r}_pfc_sst',   'mix_corr_rule{r}_pfc_edend', 'mix_corr_rule{o}_pfc_edend'),
    # error x rule -> rule
    'bias_mixerr_ruleesoma':    ('mix_err_rule{r}_pfc_esoma', 'rule{r}_pfc_esoma', 'rule{o}_pfc_esoma'),
    'bias_mixerr_rulepv':       ('mix_err_rule{r}_pfc_esoma', 'rule{r}_pfc_pv',    'rule{o}_pfc_pv'),
    'bias_mixerr_rulesst':      ('mix_err_rule{r}_pfc_esoma', 'rule{r}_pfc_sst',   'rule{o}_pfc_sst'),
    # correct x rule -> rule
    'bias_mixcorr_ruleesoma':   ('mix_corr_rule{r}_pfc_esoma', 'rule{r}_pfc_esoma', 'rule{o}_pfc_esoma'),
    'bias_mixcorr_rulepv':      ('mix_corr_rule{r}_pfc_esoma', 'rule{r}_pfc_pv',    'rule{o}_pfc_pv'),
    'bias_mixcorr_rulesst':     ('mix_corr_rule{r}_pfc_esoma', 'rule{r}_pfc_sst',   'rule{o}_pfc_sst'),
    # within the rule network
    'bias_ruleesoma_ruleesoma': ('rule{r}_pfc_esoma', 'rule{r}_pfc_esoma', 'rule{o}_pfc_esoma'),
    'bias_ruleesoma_rulepv':    ('rule{r}_pfc_esoma', 'rule{r}_pfc_pv',    'rule{o}_pfc_pv'),
    'bias_rulepv_ruleesoma':    ('rule{r}_pfc_pv',    'rule{r}_pfc_esoma', 'rule{o}_pfc_esoma'),
    'bias_rulepv_rulepv':       ('rule{r}_pfc_pv',    'rule{r}_pfc_pv',    'rule{o}_pfc_pv'),
    'bias_rulesst_ruleedend':   ('rule{r}_pfc_sst',   'rule{r}_pfc_edend', 'rule{o}_pfc_edend'),
    'bias_ruleesoma_rulesst':   ('rule{r}_pfc_esoma', 'rule{r}_pfc_sst',   'rule{o}_pfc_sst'),
    'bias_rulesst_rulepv':      ('rule{r}_pfc_sst',   'rule{r}_pfc_pv',    'rule{o}_pfc_pv'),
}


for model_name in sorted(os.listdir(MODEL_DIR)):
    if MODEL_DATE_TAG not in model_name or 'success' not in model_name:
        continue
    print(model_name+'\n')

    # load model
    path_to_file = os.path.join(MODEL_DIR, model_name)
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)

    # load the stored noiseless test run
    # NOTE: pickle.load can execute arbitrary code -- only load files produced by this project.
    with open(os.path.join(TEST_DATA_DIR, model_name + '_testdata_noiseless'), 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']
    # `p`, not `_`: `_` is IPython's "last output" variable.
    mean_perf = np.mean([p[0] for p in test_data['perfs']])
    mean_perf_rule = np.mean([p[0] for p in test_data['perf_rules']])
    if mean_perf < PERF_THRESHOLD or mean_perf_rule < PERF_THRESHOLD:
        print('low performing model ({}/{})'.format(mean_perf, mean_perf_rule))
        continue
    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()

    # generate trial labels
    trial_labels = label_trials_wcst(test_data=test_data)
    rule1_trs_stable = trial_labels['rule1_trs_stable']
    rule2_trs_stable = trial_labels['rule2_trs_stable']
    rule1_trs_after_error = trial_labels['rule1_trs_after_error']
    rule2_trs_after_error = trial_labels['rule2_trs_after_error']
    rule1_trs_after_correct = trial_labels['rule1_trs_after_correct']
    rule2_trs_after_correct = trial_labels['rule2_trs_after_correct']
    c1_trs_stable = trial_labels['c1_trs_stable']
    c2_trs_stable = trial_labels['c2_trs_stable']
    c3_trs_stable = trial_labels['c3_trs_stable']
    resp_trs_stable = {'c1': c1_trs_stable, 'c2': c2_trs_stable, 'c3': c3_trs_stable}    # to be used as an argument in the "compute_sel_wcst" function
    error_trials = trial_labels['error_trials']

    # compute cell selectivity
    all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'],
                                rule1_trs_stable=rule1_trs_stable, rule2_trs_stable=rule2_trs_stable,
                                rule1_trs_after_error=rule1_trs_after_error, rule2_trs_after_error=rule2_trs_after_error,
                                resp_trs_stable=resp_trs_stable, trs_by_center_card=trial_labels['trs_by_center_card_stable'],
                                stims=test_data['stims'], error_trials=error_trials, trial_labels=trial_labels)

    # define subpopulations within PFC
    rule_sel_used = all_sels['rule_normalized_activity']
    error_sel_used = all_sels['error_normalized']
    subcg_pfc_idx = define_subpop_pfc(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                      rule_sel=rule_sel_used, err_sel=error_sel_used,
                                      rule1_trs_stable=rule1_trs_stable, rule2_trs_stable=rule2_trs_stable,
                                      rule1_after_error_trs=rule1_trs_after_error, rule2_after_error_trs=rule2_trs_after_error,
                                      rule1_after_correct_trs=rule1_trs_after_correct, rule2_after_correct_trs=rule2_trs_after_correct,
                                      rule_threshold=0.5, err_threshold=0.5, dend_pop='same_as_soma')
    for subcg, unit_idx in subcg_pfc_idx.items():
        model.rnn.cg_idx['subcg_pfc_'+subcg] = unit_idx

    # calculate connectivity biases
    w_rec_eff = model.rnn.effective_weight(w=model.rnn.w_rec, mask=model.rnn.mask, w_fix=model.rnn.w_fix).detach().cpu().numpy()
    biases = {name: conn_bias(w_rec_eff, subcg_pfc_idx, *spec) for name, spec in BIAS_SPECS.items()}

    # Averages over the E and PV pathways.
    # NOTE(review): the PV term enters the rule->mix averages with a minus sign but the mix->rule
    # averages with a plus sign; kept unchanged -- confirm this is intended.
    biases['bias_ruletomixerr'] = np.mean([biases['bias_ruleesoma_mixerr'], -biases['bias_rulepv_mixerr']])
    biases['bias_mixerrtorule'] = np.mean([biases['bias_mixerr_ruleesoma'], biases['bias_mixerr_rulepv']])
    biases['bias_ruletomixcorr'] = np.mean([biases['bias_ruleesoma_mixcorr'], -biases['bias_rulepv_mixcorr']])
    biases['bias_mixcorrtorule'] = np.mean([biases['bias_mixcorr_ruleesoma'], biases['bias_mixcorr_rulepv']])

    if plot:
        plt.rc('font', size=18)
        for senders, receivers in CONN_PLOTS:
            fig, ax = plot_conn_subpop(weight=w_rec_eff, cg_idx=subcg_pfc_idx, subcg_to_plot_sender=list(senders), subcg_to_plot_receiver=list(receivers), subcg_label_converter=subcg_label_converter)
        # Render this model's figures now and release them, so open figures don't accumulate across models.
        plt.show()
        plt.close('all')

    all_data_conn_bias.append({'model': model_name,
                               'hp': hp_test,
                               **biases,
                               'subcg_pfc_idx': subcg_pfc_idx})

print(time.time()-start)

# Figure 3d: connectivity biases

In [None]:
# Figure 3d panels: one grey line per model across the selected connectivity biases.

def plot_bias_panel(records, keys, tick_labels=None, tick_labelsize=15):
    """Plot the connectivity biases `keys` for every model record in `records`.

    Each record (a dict produced by the connectivity-bias cell) becomes one line
    across the categories in `keys`; the dashed horizontal line marks zero bias.
    `tick_labels`, if given, names the x categories. Returns (fig, ax).
    """
    fig, ax = plt.subplots(figsize=[15,5])
    fig.patch.set_facecolor('white')
    for record in records:
        ax.plot([record[key] for key in keys], marker='o', color='k', linewidth=2, markersize=10, alpha=0.5)
    # Size the axis from `keys`, not from the last plotted row, so an empty `records` still works.
    ax.set_xticks(np.arange(len(keys)))
    if tick_labels is not None:
        ax.set_xticklabels(tick_labels)
    ax.tick_params(axis='both', which='major', labelsize=tick_labelsize)
    ax.axhline(y=0, ls='--', color='k', linewidth=5)
    ax.set_xlim(-0.5, len(keys)-0.5)
    ax.set_ylabel('Connectivity bias', fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()
    return fig, ax


# between rule and errorxrule neurons
fig, ax = plot_bias_panel(all_data_conn_bias,
                          keys=['bias_ruleesoma_mixerr', 'bias_rulepv_mixerr', 'bias_mixerr_ruleesoma', 'bias_mixerr_rulepv'],
                          tick_labels=['rule E → error x rule', 'rule PV → error x rule', 'error x rule → rule E', 'error x rule → rule PV'],
                          tick_labelsize=15)

# among the rule neurons
fig, ax = plot_bias_panel(all_data_conn_bias,
                          keys=['bias_ruleesoma_ruleesoma', 'bias_ruleesoma_rulepv', 'bias_rulepv_ruleesoma', 'bias_rulepv_rulepv'],
                          tick_labels=['rule E → rule E', 'rule E → rule PV', 'rule PV → rule E', 'rule PV → rule PV'],
                          tick_labelsize=20)