In [14]:
%load_ext autoreload
%autoreload 2

import numpy as np; np.set_printoptions(precision=4); np.random.seed(0)
import torch; torch.set_printoptions(precision=4)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rc('font', size=12)
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl
import seaborn as sns
import time
import sys
import itertools
import random; random.seed(0)
import scipy
import os
import pickle
from textwrap import wrap
from scipy import stats
from scipy.stats import wilcoxon

from functions import *

print(torch.__version__)
print(sys.version)
                
%matplotlib inline

torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True) 
torch.backends.cudnn.deterministic = True    


colors = ['#b3e2cd', '#fdcdac']


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
1.13.1+cu116
3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0]


# Figure 4a, d, f: example model

In [18]:
# Load the precomputed connectivity records consumed by the figure-4a/d/f and
# supplementary-figure-8 (exc) cells below.
# NOTE: pickle.load can execute arbitrary code - only load result files you produced yourself.
with open('/.../conn_bias_to_smExc.pickle', mode='rb') as pickle_file:
    all_data_to_exc = pickle.load(pickle_file)

In [None]:
# Plot the example model used in figure 4 of the paper; all other models are skipped.

# Source data for figure 4 of the paper. Each entry splits weights into a
# 'from_same_rule' and a 'from_different_rule' group; this cell fills only the
# three *_to_exc entries.
data_fig4 = {'top_down_weight_to_exc': {'from_same_rule': [], 'from_different_rule': []},
             'weight_from_sst_to_exc': {'from_same_rule': [], 'from_different_rule': []},
             'weight_from_pv_to_exc': {'from_same_rule': [], 'from_different_rule': []},
             'top_down_weight_to_vip': {'from_same_rule': [], 'from_different_rule': []},
             'top_down_weight_to_sst': {'from_same_rule': [], 'from_different_rule': []},
             'top_down_weight_to_pv': {'from_same_rule': [], 'from_different_rule': []}}

# names of all models present in the pickle
all_model_names = list(set([x['model'] for x in all_data_to_exc]))

FIG4_EXAMPLE_MODEL = 'success_2023-05-10-14-28-42_wcst_58_sparsity0'    # this is the model used in figure 4 of the paper

# All records of the example model (empty if it is not in the pickle -> nothing is plotted).
fig4_records = [x for x in all_data_to_exc if x['model'] == FIG4_EXAMPLE_MODEL]

if fig4_records:
    hp = fig4_records[0]['hp']

    w_pfc_same_rule_this_model = [x['w_pfc_same_rule_soma'] for x in fig4_records]
    w_pfc_diff_rule_this_model = [x['w_pfc_diff_rule_soma'] for x in fig4_records]
    w_smsst_same_rule_this_model = [x['w_smsst_same_rule_soma'] for x in fig4_records]
    w_smsst_diff_rule_this_model = [x['w_smsst_diff_rule_soma'] for x in fig4_records]
    # The data for PV was written twice. De-duplicate the (same, different) pairs
    # jointly, keeping first-seen order. Calling set() on each list separately
    # reorders the two lists independently, which breaks the per-entry pairing in
    # the line plot below and can leave the lists with different lengths.
    pv_pairs = list(dict.fromkeys((x['w_smpv_same_rule_soma'], x['w_smpv_diff_rule_soma'])
                                  for x in fig4_records))
    w_smpv_same_rule_this_model = [same for same, _ in pv_pairs]
    w_smpv_diff_rule_this_model = [diff for _, diff in pv_pairs]

    data_fig4['top_down_weight_to_exc']['from_same_rule'] = w_pfc_same_rule_this_model
    data_fig4['top_down_weight_to_exc']['from_different_rule'] = w_pfc_diff_rule_this_model
    data_fig4['weight_from_sst_to_exc']['from_same_rule'] = w_smsst_same_rule_this_model
    data_fig4['weight_from_sst_to_exc']['from_different_rule'] = w_smsst_diff_rule_this_model
    data_fig4['weight_from_pv_to_exc']['from_same_rule'] = w_smpv_same_rule_this_model
    # previously stored the same-rule list here by mistake
    data_fig4['weight_from_pv_to_exc']['from_different_rule'] = w_smpv_diff_rule_this_model

    weights_by_source = {'PFC': (w_pfc_same_rule_this_model, w_pfc_diff_rule_this_model),
                         'SST': (w_smsst_same_rule_this_model, w_smsst_diff_rule_this_model),
                         'PV': (w_smpv_same_rule_this_model, w_smpv_diff_rule_this_model)}

    for source, (data_same, data_different) in weights_by_source.items():
        fig, ax = plt.subplots(1, 1, figsize=[5, 5])
        fig.patch.set_facecolor('white')
        ax.set_xlim([-0.5, 1.5])

        # line plot: one line per entry, joining its same-rule and different-rule weight
        ax.plot([0, 1], [data_same, data_different], marker='o', color='k', alpha=0.25, clip_on=False)
        ax.set_xticks([0, 1])
        ax.set_xticklabels(['{}, same rule'.format(source), '{}, different rule'.format(source)], rotation=0)
        ax.set_ylabel('Weight from...', fontsize=20)

        # bar plot of the group means
        y = [data_same, data_different]
        ax.bar([0, 1], height=[np.mean(yi) for yi in y],
               width=0.2,    # bar width
               color=colors,
               edgecolor=colors)
        # One-sided t test (same rule > different rule), skipped when all weights are zero.
        if np.sum(y) != 0:     # no PFC->Edend connection
            ttest = stats.ttest_ind(y[0], y[1], alternative='greater')
            print('student t test, {}, p={}, n={}'.format(ttest[0], ttest[1], len(y[0])))

        make_pretty_axes(ax)

        fig.tight_layout()
        plt.show()



# Supplementary Figure 8a, d, f, g, j, l: across all models

In [None]:
# Source data for supplementary figure 8, keyed by dendritic nonlinearity. Every
# entry gets its own fresh 'from_same_rule' / 'from_different_rule' lists.
data_suppfig8 = {
    nonlinearity: {entry: {'from_same_rule': [], 'from_different_rule': []}
                   for entry in ['weight_from_pfc_to_exc',
                                 'weight_from_smsst_to_exc',
                                 'weight_from_smpv_to_exc',
                                 'top_down_weight_to_vip',
                                 'top_down_weight_to_sst',
                                 'top_down_weight_to_pv']}
    for nonlinearity in ['subtractive', 'divisive_2']}

for dend_nonlinear in ['subtractive', 'divisive_2']:
    # all records (across models) trained with this dendritic nonlinearity
    records = [x for x in all_data_to_exc if x['hp']['dend_nonlinearity'] == dend_nonlinear]

    for source in ['pfc', 'smsst', 'smpv']:
        print(dend_nonlinear, source)
        fig, ax = plt.subplots(1, 1, figsize=[5, 5])
        fig.patch.set_facecolor('white')
        ax.set_xlim([-0.5, 1.5])

        w_same = [x['w_{}_same_rule_soma'.format(source)] for x in records]
        w_diff = [x['w_{}_diff_rule_soma'.format(source)] for x in records]
        if source == 'smpv':
            # Data was included twice for PV neurons. De-duplicate (same, diff) pairs
            # jointly, in first-seen order, so the two lists stay aligned for the paired
            # line plot. Independent set() calls reorder each list separately.
            pairs = list(dict.fromkeys(zip(w_same, w_diff)))
            w_same = [same for same, _ in pairs]
            w_diff = [diff for _, diff in pairs]

        # keep the plotted data, mirroring how the aggregate PFC->interneuron cell stores its data
        data_suppfig8[dend_nonlinear]['weight_from_{}_to_exc'.format(source)]['from_same_rule'] = w_same
        data_suppfig8[dend_nonlinear]['weight_from_{}_to_exc'.format(source)]['from_different_rule'] = w_diff

        # one line per entry, joining its same-rule and different-rule weight
        ax.plot([0, 1], [w_same, w_diff], marker='o', color='k', alpha=0.05)
        ax.set_xticks([0, 1])
        ax.set_xticklabels(['{}, same rule'.format(source), '{}, different rule'.format(source)], rotation=0)

        # plot the means
        y = [w_same, w_diff]
        ax.bar([0, 1], height=[np.mean(yi) for yi in y],
               width=0.2,    # bar width
               color=colors,
               edgecolor=colors)
        # the y label is deliberately left blank in this figure
        ax.set_ylabel('')

        # One-sided t test (same rule > different rule), skipped when all weights are zero.
        if np.sum(y) != 0:
            ttest = stats.ttest_ind(y[0], y[1], alternative='greater')
            print('student t test {}, p={}, n={}'.format(ttest[0], ttest[1], len(y[0])))

        make_pretty_axes(ax)
        fig.tight_layout()
        plt.show()

       

# Figure 4b, c, e: example model

In [17]:
# Load the precomputed PFC -> inhibitory-population weight records consumed by
# the figure-4b/c/e and supplementary-figure-8 (interneuron) cells below.
# NOTE: pickle.load can execute arbitrary code - only load result files you produced yourself.
with open('/.../conn_bias_pfc_to_inh.pickle', mode='rb') as pickle_file:
    all_data_frompfc = pickle.load(pickle_file)

In [None]:
# Figure 4b (PV) and 4c, e (SST, VIP): PFC -> interneuron weights of one example
# model per target, same rule vs. different rule, with a one-sided t test.
# The data in figure 4b is taken from one example model; the data in figure 4c, e
# is taken from another one.
FIG4_EXAMPLE_MODEL_BY_TARGET = {'pv': 'success_2023-05-10-14-28-42_wcst_10_sparsity0',
                                'sst': 'success_2023-05-10-14-28-42_wcst_106_sparsity0',
                                'vip': 'success_2023-05-10-14-28-42_wcst_106_sparsity0'}

for x in all_data_frompfc:
    for target in ['pv', 'sst', 'vip']:
        if x['name'] != FIG4_EXAMPLE_MODEL_BY_TARGET[target]:
            continue
        y = [x['w_same_rule_{}'.format(target)], x['w_diff_rule_{}'.format(target)]]
        # Skip entries with no weights at all. Checking the lengths directly instead
        # of np.array(y).size also avoids the ValueError NumPy (>= 1.24) raises when
        # the two groups have different lengths.
        if len(y[0]) == 0 and len(y[1]) == 0:
            print('size of y is 0, pass')
            continue
        # NOTE(review): only entries with exactly 10 weights per group are plotted - confirm why 10
        if len(y[0]) != 10 or len(y[1]) != 10:
            continue

        # plot: one line per entry (same rule -> different rule) plus the group means as bars
        fig, ax = plt.subplots(1, 1, figsize=[5, 6])
        fig.patch.set_facecolor('white')
        fig.suptitle(target)
        ax.plot([0, 1], y, marker='o', color='k', alpha=0.5)
        ax.bar([0, 1],
               height=[np.mean(yi) for yi in y],
               width=0.2,    # bar width
               color=colors,
               edgecolor=colors)
        make_pretty_axes(ax)
        ax.set_xticks([0, 1])
        ax.set_xticklabels(['Same rule', 'Different rule'], rotation=15)
        ax.set_ylabel('Weight from PFC to {} in SM'.format(target), fontsize=20)
        ax.set_xlim([-0.2, 1.2])
        fig.tight_layout()
        plt.show()

        # do statistical test (one-sided: same rule > different rule)
        ttest = stats.ttest_ind(y[0], y[1], alternative='greater')
        print('student t test, {}, p={}, n={}'.format(ttest[0], ttest[1], len(y[0])))



# Supplementary Figure 8b, c, e, h, i, k: aggregate across models

In [None]:
# Pool the PFC -> interneuron weights of all models that share a dendritic
# nonlinearity, compare same-rule vs. different-rule weights (two-sided t test),
# plot them, and store the pooled data in data_suppfig8.
for dend_nonlinear in ['subtractive', 'divisive_2']:
    print(dend_nonlinear)
    for target in ['sst', 'vip', 'pv']:
        w_same_rule_all = []
        w_diff_rule_all = []
        for x in all_data_frompfc:
            if x['hp']['dend_nonlinearity'] != dend_nonlinear:
                continue
            w_same_rule_all.extend(x['w_same_rule_{}'.format(target)])
            w_diff_rule_all.extend(x['w_diff_rule_{}'.format(target)])
        yy = [w_same_rule_all, w_diff_rule_all]

        fig, ax = plt.subplots(1, 1, figsize=[5, 6])
        fig.patch.set_facecolor('white')
        # do statistical test (two-sided here, unlike the one-sided tests in the other cells)
        ttest = stats.ttest_ind(yy[0], yy[1])
        print('student t test: {}, p={}, n={}'.format(ttest[0], ttest[1], len(yy[0])))

        ax.plot([0, 1], yy, marker='o', color='k', alpha=0.1)
        ax.bar([0, 1],
               height=[np.mean(yi) for yi in yy],
               width=0.2,    # bar width
               color=colors,
               edgecolor=colors)
        ax.set_xticks([0, 1])
        ax.set_xticklabels(['Same rule', 'Different rule'], rotation=15)
        ax.set_ylabel('Weight from PFC to {} in SM'.format(target), fontsize=20)
        # Round y tick labels to 2 decimals with a formatter rather than
        # set_yticklabels(np.round(ax.get_yticks(), 2)). Fixed labels computed from the
        # current ticks go stale if the ticks or limits change afterwards, and
        # matplotlib warns when a FixedFormatter is used without a FixedLocator.
        ax.yaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(lambda tick, pos: str(round(tick, 2))))
        ax.set_xlim([-0.2, 1.2])
        make_pretty_axes(ax)
        fig.tight_layout()
        plt.show()

        data_suppfig8[dend_nonlinear]['top_down_weight_to_{}'.format(target)]['from_same_rule'] = yy[0]
        data_suppfig8[dend_nonlinear]['top_down_weight_to_{}'.format(target)]['from_different_rule'] = yy[1]