In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np; np.set_printoptions(precision=4); np.random.seed(0)
import torch; torch.set_printoptions(precision=4)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rc('font', size=12)
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl
import seaborn as sns
import itertools
import os
import pickle
import random; random.seed(0)
import sys
import time
from textwrap import wrap

import scipy
from scipy import stats
from scipy.stats import wilcoxon

from functions import *

print(torch.__version__)
print(sys.version)
                
%matplotlib inline

torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True) 
torch.backends.cudnn.deterministic = True    


In [None]:
# Bar colors for the two-bar comparisons below: index 0 -> "same rule" bar, index 1 -> "different rule" bar.
colors = [
    '#b3e2cd',
    '#fdcdac',
]

# Generate data for Figure 4a, d, f: structure of the connections that target excitatory cells in the sensorimotor module

In [None]:
# For every dendritic branch of every excitatory (E) soma in the sensorimotor (SR) module, collect
# the mean weight it receives from presynaptic pools that prefer the same vs. the other rule.
#   *_soma entries: presynaptic pools classified relative to the rule preference of the soma
#   *_dend entries: presynaptic pools classified relative to the rule preference of the branch
# PFC/SST weights are read at the branch column, PV weights at the soma column.
start = time.time()
plt.rc('font', size=12)

model_dir = '/model/directory/'
testdata_dir = '/where/test/run/is/saved/'

all_data_to_exc = []

for model_name in sorted(os.listdir(model_dir)):
    if not ('2023-05-10' in model_name and 'wcst' in model_name and 'success' in model_name):
        continue
    print(model_name+'\n')

    # load model
    path_to_file = model_dir + model_name
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name,
                                                                        simple=False, plot=False, toprint=False)

    # load the saved noiseless test run of this model
    # (pickle.load can execute code embedded in the file: only load files produced by this project)
    with open(testdata_dir + model_name + '_testdata_noiseless', 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']
    mean_test_perf = np.mean([_ for _ in test_data['perfs']])
    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()

    # generate trial labels
    trial_labels = label_trials_wcst(test_data=test_data)
    rule1_trs_stable = trial_labels['rule1_trs_stable']
    rule2_trs_stable = trial_labels['rule2_trs_stable']
    rule1_trs_after_error = trial_labels['rule1_trs_after_error']
    rule2_trs_after_error = trial_labels['rule2_trs_after_error']
    rule1_trs_after_correct = trial_labels['rule1_trs_after_correct']
    rule2_trs_after_correct = trial_labels['rule2_trs_after_correct']
    c1_trs_stable = trial_labels['c1_trs_stable']
    c2_trs_stable = trial_labels['c2_trs_stable']
    c3_trs_stable = trial_labels['c3_trs_stable']
    resp_trs_stable = {'c1': c1_trs_stable, 'c2': c2_trs_stable, 'c3': c3_trs_stable}    # to be used as an argument in the "compute_sel_wcst" function
    error_trials = trial_labels['error_trials']

    # compute cell selectivity
    all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'],
                                rule1_trs_stable=trial_labels['rule1_trs_stable'], rule2_trs_stable=trial_labels['rule2_trs_stable'],
                                rule1_trs_after_error=trial_labels['rule1_trs_after_error'], rule2_trs_after_error=trial_labels['rule2_trs_after_error'],
                                resp_trs_stable=resp_trs_stable, trs_by_center_card=trial_labels['trs_by_center_card_stable'],
                                stims=test_data['stims'], error_trials=trial_labels['error_trials'], trial_labels=trial_labels)
    resp_sel_normalized = all_sels['resp_normalized']
    rule_sel_normalized = all_sels['rule_normalized_activity']

    # subpopulations
    subcg_pfc_idx = define_subpop_pfc(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                      rule_sel=all_sels['rule_normalized_activity'], err_sel=all_sels['error_normalized'],
                                      rule1_trs_stable=rule1_trs_stable,
                                      rule2_trs_stable=rule2_trs_stable,
                                      rule1_after_error_trs=rule1_trs_after_error,
                                      rule2_after_error_trs=rule2_trs_after_error,
                                      rule1_after_correct_trs=rule1_trs_after_correct,
                                      rule2_after_correct_trs=rule2_trs_after_correct,
                                      rule_threshold=0.5, err_threshold=0.5)
    subcg_sr_idx = define_subpop_sr_wcst(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                         rule_sel=all_sels['rule_normalized_activity'], resp_sel=all_sels['resp_normalized'],
                                         rule1_trs_stable=rule1_trs_stable,
                                         rule2_trs_stable=rule2_trs_stable,
                                         ref_card_sel=all_sels['ref_card_normalized'],
                                         rule_threshold=0, resp_threshold=0)

    #=== analysis ===#
    # assumes w_rec_eff is indexed [presynaptic, postsynaptic] - TODO confirm against model.rnn
    w_rec_eff = model.rnn.effective_weight(w=model.rnn.w_rec, mask=model.rnn.mask).detach().cpu().numpy()

    # Every record needs rule-1 AND rule-2 pools of each presynaptic type. These checks depend only
    # on the model, so they run once per model instead of once per neuron / branch.
    if len(subcg_sr_idx['rule1_sr_pv'])==0 or len(subcg_sr_idx['rule2_sr_pv'])==0:
        print('# of rule1/2 PV neuron = {}/{}, pass'.format(len(subcg_sr_idx['rule1_sr_pv']), len(subcg_sr_idx['rule2_sr_pv'])))
        continue
    if len(subcg_pfc_idx['rule1_pfc_esoma'])==0 or len(subcg_pfc_idx['rule2_pfc_esoma'])==0 or len(subcg_sr_idx['rule1_sr_sst'])==0 or len(subcg_sr_idx['rule2_sr_sst'])==0:
        print('# of rule1/2 PFC exc/SR SST neuron = {}/{}, {}/{}, pass'.format(len(subcg_pfc_idx['rule1_pfc_esoma']), len(subcg_pfc_idx['rule2_pfc_esoma']), len(subcg_sr_idx['rule1_sr_sst']), len(subcg_sr_idx['rule2_sr_sst'])))
        continue

    n_sr_esoma = len(model.rnn.cg_idx['sr_esoma'])
    for n in model.rnn.cg_idx['sr_esoma']:
        # rule preference of the soma; zero (or NaN) selectivity has no preference -> skip
        rule_sel_soma = all_sels['rule_normalized_activity'][n]
        if rule_sel_soma>0:
            same_rule_soma, diff_rule_soma = 'rule1', 'rule2'
        elif rule_sel_soma<0:
            same_rule_soma, diff_rule_soma = 'rule2', 'rule1'
        else:
            continue

        # PV -> E (soma) connections
        w_smpv_same_rule = np.mean(w_rec_eff[subcg_sr_idx['{}_sr_pv'.format(same_rule_soma)], n])
        w_smpv_diff_rule = np.mean(w_rec_eff[subcg_sr_idx['{}_sr_pv'.format(diff_rule_soma)], n])

        # PFC -> E (dendrite) and SST -> E (dendrite) connections.
        # Branch b of soma n is taken to sit at index n + (b+1)*n_sr_esoma.
        dend_idx = [n+(b+1)*n_sr_esoma for b in range(model.rnn.n_branches)]
        for n_dend in dend_idx:
            rule_sel_dend = all_sels['rule_normalized_activity'][n_dend]
            if rule_sel_dend>0:
                same_rule_dend, diff_rule_dend = 'rule1', 'rule2'
            elif rule_sel_dend<0:
                same_rule_dend, diff_rule_dend = 'rule2', 'rule1'
            else:
                continue

            w_pfc_same_rule_soma = np.mean(w_rec_eff[subcg_pfc_idx['{}_pfc_esoma'.format(same_rule_soma)], n_dend])
            w_pfc_diff_rule_soma = np.mean(w_rec_eff[subcg_pfc_idx['{}_pfc_esoma'.format(diff_rule_soma)], n_dend])
            w_smsst_same_rule_soma = np.mean(w_rec_eff[subcg_sr_idx['{}_sr_sst'.format(same_rule_soma)], n_dend])
            w_smsst_diff_rule_soma = np.mean(w_rec_eff[subcg_sr_idx['{}_sr_sst'.format(diff_rule_soma)], n_dend])
            w_pfc_same_rule_dend = np.mean(w_rec_eff[subcg_pfc_idx['{}_pfc_esoma'.format(same_rule_dend)], n_dend])
            w_pfc_diff_rule_dend = np.mean(w_rec_eff[subcg_pfc_idx['{}_pfc_esoma'.format(diff_rule_dend)], n_dend])
            w_smsst_same_rule_dend = np.mean(w_rec_eff[subcg_sr_idx['{}_sr_sst'.format(same_rule_dend)], n_dend])
            w_smsst_diff_rule_dend = np.mean(w_rec_eff[subcg_sr_idx['{}_sr_sst'.format(diff_rule_dend)], n_dend])
            all_data_to_exc.append({'model': model_name,
                                    'hp': hp_test,
                                    'n_dend': n_dend,
                                    'rule_sel_soma': rule_sel_soma,
                                    'rule_sel_dend': rule_sel_dend,
                                    'w_pfc_same_rule_soma': w_pfc_same_rule_soma,
                                    'w_pfc_diff_rule_soma': w_pfc_diff_rule_soma,
                                    'w_smsst_same_rule_soma': w_smsst_same_rule_soma,
                                    'w_smsst_diff_rule_soma': w_smsst_diff_rule_soma,
                                    'w_smpv_same_rule_soma': w_smpv_same_rule,
                                    'w_smpv_diff_rule_soma': w_smpv_diff_rule,
                                    'w_pfc_same_rule_dend': w_pfc_same_rule_dend,
                                    'w_pfc_diff_rule_dend': w_pfc_diff_rule_dend,
                                    'w_smsst_same_rule_dend': w_smsst_same_rule_dend,
                                    'w_smsst_diff_rule_dend': w_smsst_diff_rule_dend})

print(time.time()-start)

# Figure 4a, d, f: connections onto E cells in the sensorimotor module, one figure per (example) model

In [None]:
# Plot each model separately. Each panel shows one line per (E soma, dendritic branch) record, with
# bars at the mean. A one-sided Wilcoxon signed-rank test (same rule > different rule) is printed
# for each presynaptic type.

# sorted() keeps the figure order stable across kernel restarts (iteration order of a set is not)
all_model_names = sorted(set([x['model'] for x in all_data_to_exc]))

# (record-key prefix, tick-label prefix, label of the printed statistic)
panels = [('w_pfc', 'PFC', 'pfc->exc'),
          ('w_smsst', 'SST', 'sst->exc'),
          ('w_smpv', 'PV', 'pv->exc')]

for model_name in all_model_names:
    model_records = [x for x in all_data_to_exc if x['model']==model_name]
    hp = model_records[0]['hp']

    if hp['no_pfcesoma_to_srsst']==True:
        continue

    fig, ax = plt.subplots(1, 3, figsize=[15, 5])
    fig.suptitle(model_name, fontsize=20)
    fig.patch.set_facecolor('white')

    for col, (key, cell_label, test_label) in enumerate(panels):
        w_same = [x['{}_same_rule_soma'.format(key)] for x in model_records]
        w_diff = [x['{}_diff_rule_soma'.format(key)] for x in model_records]

        ax[col].set_xlim([-0.5, 1.5])

        # line plot: one line per record
        ax[col].plot([0, 1], [w_same, w_diff], marker='o', color='k', alpha=0.25)
        ax[col].set_xticks([0, 1])
        ax[col].set_xticklabels(['{}, same rule'.format(cell_label), '{}, different rule'.format(cell_label)], rotation=0)
        ax[col].set_ylabel('Weight from...', fontsize=20)

        # bar plot of the means (pass yerr=[stats.sem(yi) for yi in y] to add error bars)
        y = [w_same, w_diff]
        ax[col].bar([0, 1],
                    height=[np.mean(yi) for yi in y],
                    capsize=12,    # error bar cap width in points
                    width=0.2,     # bar width
                    color=colors,
                    edgecolor=colors
                    )

        # The signed-rank test is undefined when every paired difference is zero (e.g. a projection
        # that is absent in this model), so skip it in that case instead of erroring.
        if np.any(np.subtract(w_same, w_diff) != 0):
            res = wilcoxon(x=w_same, y=w_diff, alternative='greater')
            print('{}, wilcoxon test, statistic={}, p={}'.format(test_label, res.statistic, res.pvalue))
        else:
            print('{}: all paired differences are zero, wilcoxon test skipped'.format(test_label))

    for i in range(3):
        make_pretty_axes(ax[i])

    fig.tight_layout()
    plt.show()

# Supplementary Figure 7a, d, f, g, j, l: connections onto E cells in the sensorimotor module, across all models

In [None]:
# Aggregate over every model with a subtractive dendritic nonlinearity. Each panel shows one line per
# (E soma, dendritic branch) record, with bars at the mean. A one-sided Wilcoxon signed-rank test
# (same rule > different rule) is printed for each presynaptic type.
nonlinearity_to_plot = 'subtractive'
records_to_plot = [x for x in all_data_to_exc if x['hp']['dend_nonlinearity']==nonlinearity_to_plot]

# (record-key prefix, tick-label prefix, label of the printed statistic)
panels = [('w_pfc', 'PFC', 'pfc->exc'),
          ('w_smsst', 'SST', 'sst->exc'),
          ('w_smpv', 'PV', 'pv->exc')]

fig, ax = plt.subplots(1, 3, figsize=[15, 5])
fig.patch.set_facecolor('white')

for col, (key, cell_label, test_label) in enumerate(panels):
    w_same = [x['{}_same_rule_soma'.format(key)] for x in records_to_plot]
    w_diff = [x['{}_diff_rule_soma'.format(key)] for x in records_to_plot]

    ax[col].set_xlim([-0.5, 1.5])

    # one line per record (matplotlib cannot draw a zero-column 2-D y array, hence the guard)
    if w_same:
        ax[col].plot([0, 1], [w_same, w_diff], marker='o', color='k', alpha=0.05)
    ax[col].set_xticks([0, 1])
    ax[col].set_xticklabels(['{}, same rule'.format(cell_label), '{}, different rule'.format(cell_label)], rotation=0)

    # bar plot of the means (pass yerr=[stats.sem(yi) for yi in y] to add error bars)
    y = [w_same, w_diff]
    ax[col].bar([0, 1],
                height=[np.mean(yi) for yi in y],
                capsize=12,    # error bar cap width in points
                width=0.2,     # bar width
                color=colors,
                edgecolor=colors
                )

    # the y-axis is left unlabeled in this aggregate figure
    ax[col].set_ylabel('')
    # Use a tick formatter instead of freezing tick labels: labels fixed before the bars were drawn
    # no longer matched the tick positions once the bars rescaled the axis.
    ax[col].yaxis.set_major_formatter('{x:.2f}')

    # The signed-rank test is undefined when there is no data or every paired difference is zero.
    if np.any(np.subtract(w_same, w_diff) != 0):
        res = wilcoxon(x=w_same, y=w_diff, alternative='greater')
        print('{}, wilcoxon test, statistic={}, p={}'.format(test_label, res.statistic, res.pvalue))
    else:
        print('{}: no data or all paired differences are zero, wilcoxon test skipped'.format(test_label))

for i in range(3):
    make_pretty_axes(ax[i])

fig.tight_layout()
plt.show()










# Generate data for Figure 4b, c, e: connectivity structure from PFC to different cell types in the sensorimotor module

In [None]:
# Look at the top-down projection from PFC excitatory neurons to the VIP / SST / PV pools in the
# sensorimotor (SR) module. Presynaptic and postsynaptic neurons are split by rule preference.
from scipy.stats import pearsonr    # `scipy.stats.stats` is a deprecated namespace; import from scipy.stats

start = time.time()
plt.rc('font', size=12)

model_dir = '/model/directory/'
testdata_dir = '/where/test/run/data/is/stored/'

all_data_frompfc = []

for model_name in sorted(os.listdir(model_dir)):
    if not ('2023-05-10' in model_name and 'wcst' in model_name and 'success' in model_name):
        continue
    print(model_name+'\n')

    # load the model from the directory whose contents are being listed
    path_to_file = model_dir + model_name
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name,
                                                                        simple=False, plot=False, toprint=False)

    # load the saved noiseless test run of this model
    # (pickle.load can execute code embedded in the file: only load files produced by this project)
    with open(testdata_dir + model_name + '_testdata_noiseless', 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']

    # keep only models whose test performance is at least 0.8
    mean_perf = np.mean([_[0] for _ in test_data['perfs']])
    mean_perf_rule = np.mean([_[0] for _ in test_data['perf_rules']])
    if mean_perf<0.8 or mean_perf_rule<0.8:
        print('low performing model ({}/{})'.format(mean_perf, mean_perf_rule))
        continue
    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()

    # generate trial labels
    trial_labels = label_trials_wcst(test_data=test_data)
    rule1_trs_stable = trial_labels['rule1_trs_stable']
    rule2_trs_stable = trial_labels['rule2_trs_stable']
    rule1_trs_after_error = trial_labels['rule1_trs_after_error']
    rule2_trs_after_error = trial_labels['rule2_trs_after_error']
    rule1_trs_after_correct = trial_labels['rule1_trs_after_correct']
    rule2_trs_after_correct = trial_labels['rule2_trs_after_correct']
    c1_trs_stable = trial_labels['c1_trs_stable']
    c2_trs_stable = trial_labels['c2_trs_stable']
    c3_trs_stable = trial_labels['c3_trs_stable']
    resp_trs_stable = {'c1': c1_trs_stable, 'c2': c2_trs_stable, 'c3': c3_trs_stable}    # to be used as an argument in the "compute_sel_wcst" function
    error_trials = trial_labels['error_trials']

    # compute cell selectivity
    all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'],
                                rule1_trs_stable=trial_labels['rule1_trs_stable'], rule2_trs_stable=trial_labels['rule2_trs_stable'],
                                rule1_trs_after_error=trial_labels['rule1_trs_after_error'], rule2_trs_after_error=trial_labels['rule2_trs_after_error'],
                                resp_trs_stable=resp_trs_stable,
                                stims=test_data['stims'], error_trials=trial_labels['error_trials'], trs_by_center_card=trial_labels['trs_by_center_card_stable'], trial_labels=trial_labels)
    resp_sel_normalized = all_sels['resp_normalized']
    rule_sel_normalized = all_sels['rule_normalized_activity']

    # subpopulations
    subcg_pfc_idx = define_subpop_pfc(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                      rule_sel=all_sels['rule_normalized_activity'], err_sel=all_sels['error_normalized'],
                                      rule1_trs_stable=rule1_trs_stable,
                                      rule2_trs_stable=rule2_trs_stable,
                                      rule1_after_error_trs=rule1_trs_after_error,
                                      rule2_after_error_trs=rule2_trs_after_error,
                                      rule1_after_correct_trs=rule1_trs_after_correct,
                                      rule2_after_correct_trs=rule2_trs_after_correct,
                                      rule_threshold=0.5, err_threshold=0.5)
    subcg_sr_idx = define_subpop_sr_wcst(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                         rule_sel=all_sels['rule_normalized_activity'], resp_sel=all_sels['resp_normalized'],
                                         rule1_trs_stable=rule1_trs_stable,
                                         rule2_trs_stable=rule2_trs_stable,
                                         rule_threshold=0, resp_threshold=0,
                                         ref_card_sel=all_sels['ref_card_normalized'])
    for subcg in subcg_pfc_idx.keys():
        model.rnn.cg_idx['subcg_pfc_'+subcg] = subcg_pfc_idx[subcg]
    for subcg in subcg_sr_idx.keys():
        model.rnn.cg_idx['subcg_sr_'+subcg] = subcg_sr_idx[subcg]

    w_rec_eff = model.rnn.effective_weight(w=model.rnn.w_rec, mask=model.rnn.mask)
    w_rec_eff = w_rec_eff.detach().cpu().numpy()

    # PFC exc neurons grouped by rule preference, taken from THIS model's subpopulations (must not
    # come from variables left over in the kernel).
    rule1_pfcesoma_idx = subcg_pfc_idx['rule1_pfc_esoma']
    rule2_pfcesoma_idx = subcg_pfc_idx['rule2_pfc_esoma']
    # With an empty PFC pool every mean below is NaN, which would poison the across-model statistics.
    if len(rule1_pfcesoma_idx)==0 or len(rule2_pfcesoma_idx)==0:
        print('# of rule1/2 PFC exc neuron = {}/{}, pass'.format(len(rule1_pfcesoma_idx), len(rule2_pfcesoma_idx)))
        continue

    record = {'name': model_name, 'hp': hp_test}
    for ctype in ['vip', 'sst', 'pv']:
        rule1_post_idx = subcg_sr_idx['rule1_sr_{}'.format(ctype)]
        rule2_post_idx = subcg_sr_idx['rule2_sr_{}'.format(ctype)]
        # rows = PFC neurons, columns = postsynaptic neurons; average over the PFC rows
        # -> one value per postsynaptic neuron
        w_r1_to_r1 = np.mean(w_rec_eff[np.ix_(rule1_pfcesoma_idx, rule1_post_idx)], axis=0)
        w_r1_to_r2 = np.mean(w_rec_eff[np.ix_(rule1_pfcesoma_idx, rule2_post_idx)], axis=0)
        w_r2_to_r1 = np.mean(w_rec_eff[np.ix_(rule2_pfcesoma_idx, rule1_post_idx)], axis=0)
        w_r2_to_r2 = np.mean(w_rec_eff[np.ix_(rule2_pfcesoma_idx, rule2_post_idx)], axis=0)
        # weight from PFC exc neurons selective for the same / the other rule;
        # entry k of both arrays refers to the same postsynaptic neuron
        record['w_same_rule_{}'.format(ctype)] = np.concatenate((w_r1_to_r1, w_r2_to_r2))
        record['w_diff_rule_{}'.format(ctype)] = np.concatenate((w_r2_to_r1, w_r1_to_r2))
    all_data_frompfc.append(record)

print(time.time()-start)
        

# Figure 4b, c, e and Supplementary Figure 7b, c, e, h, i, k: connections from PFC to different cell types in the sensorimotor module

In [None]:
# Compare the weight onto SR VIP / SST / PV neurons from PFC exc neurons that prefer the same vs. the
# other rule, first for each model and then aggregated across models.
# NOTE(review): entry k of 'w_same_rule_*' and 'w_diff_rule_*' refers to the same postsynaptic neuron
# (the lines below connect these pairs), yet the reported test is the independent-samples t-test
# (stats.ttest_ind). A paired test (stats.ttest_rel / wilcoxon) matches this design - confirm which
# one the figure should report.
cell_types = ['sst', 'vip', 'pv']

# plot each model
for x in all_data_frompfc:
    print(x['name'])

    for key in ['dend_nonlinearity', 'sparse_srsst_to_sredend', 'initialization_weights', 'activation']:
        print(key, x['hp'][key])

    for ctype in cell_types:
        y = [x['w_same_rule_{}'.format(ctype)], x['w_diff_rule_{}'.format(ctype)]]
        if np.array(y).size==0:
            print('size of y is 0, pass')
            continue

        fig, ax = plt.subplots(1, 1, figsize=[5, 6])
        fig.patch.set_facecolor('white')
        fig.suptitle(ctype)

        # statistical test: t statistic and p-value go into the title
        ttest = stats.ttest_ind(y[0], y[1])
        ax.set_title('{:.2f}, p={:.4f}'.format(ttest[0], ttest[1]))
        ax.plot([0, 1], y, marker='o', color='k', alpha=0.5)
        # bars at the means (pass yerr=[stats.sem(yi) for yi in y] to add error bars)
        ax.bar([0, 1],
               height=[np.mean(yi) for yi in y],
               capsize=12,    # error bar cap width in points
               width=0.2,     # bar width
               color=colors,
               edgecolor=colors
               )

        make_pretty_axes(ax)
        ax.set_xticks([0, 1])
        ax.set_xticklabels(['PFC, same rule', 'PFC, different rule'], rotation=15)
        ax.set_ylabel('Weight from...', fontsize=20)
        ax.set_xlim([-0.2, 1.2])
        fig.tight_layout()
        plt.show()


# aggregate across models that use this dendritic nonlinearity
agg_nonlinearity = 'divisive_2'
for ctype in cell_types:
    w_same_rule_all = []
    w_diff_rule_all = []
    for x in all_data_frompfc:
        if x['hp']['dend_nonlinearity']!=agg_nonlinearity:
            continue
        w_same_rule_all.extend(x['w_same_rule_{}'.format(ctype)])
        w_diff_rule_all.extend(x['w_diff_rule_{}'.format(ctype)])
    if len(w_same_rule_all)==0:
        print('{}: no data from models with dend_nonlinearity={}, pass'.format(ctype, agg_nonlinearity))
        continue
    yy = [w_same_rule_all, w_diff_rule_all]

    fig, ax = plt.subplots(1, 1, figsize=[5, 6])
    fig.suptitle(ctype, fontsize=30)
    fig.patch.set_facecolor('white')

    # statistical test: t statistic and p-value go into the title
    ttest = stats.ttest_ind(yy[0], yy[1])
    ax.set_title('{:.2f}, p={:.4f}'.format(ttest[0], ttest[1]))

    ax.plot([0, 1], yy, marker='o', color='k', alpha=0.1)
    # bars at the means (pass yerr=[stats.sem(yi) for yi in yy] to add error bars)
    ax.bar([0, 1],
           height=[np.mean(yi) for yi in yy],
           capsize=12,    # error bar cap width in points
           width=0.2,     # bar width
           color=colors,
           edgecolor=colors
           )
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['PFC, same rule', 'PFC, different rule'], rotation=15)
    ax.set_ylabel('Weight from...', fontsize=20)
    # a tick formatter (rather than frozen tick labels) stays correct if the y-axis is rescaled later
    ax.yaxis.set_major_formatter('{x:.2f}')
    ax.set_xlim([-0.2, 1.2])
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()
