In [3]:
# Notebook setup: autoreload of edited local modules, library imports, print precision,
# RNG seeding, and PyTorch determinism flags.
# NOTE(review): numpy and `random` are seeded with 0 below, while torch uses `seed` (= 1);
# confirm this mismatch is intended.
%load_ext autoreload
%autoreload 2
import numpy as np; np.set_printoptions(precision=4); np.random.seed(0)
import torch; torch.set_printoptions(precision=4)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rc('font', size=12); 
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl
import seaborn as sns
import time
import sys
import itertools
import random; random.seed(0)
import scipy
import os
import warnings

from textwrap import wrap
from scipy.stats import wilcoxon
from sklearn.metrics.pairwise import cosine_similarity

# NOTE(review): this star import hides where names used later come from, e.g.
# `make_pretty_axes` and `pickle` (neither is defined/imported in this cell);
# presumably both are provided by functions.py — verify, and prefer explicit imports.
from functions import *

# Record library versions alongside the outputs for reproducibility.
print(torch.__version__)
print(sys.version)
                
%matplotlib inline

# Ask PyTorch/cuDNN for deterministic kernels.
# NOTE(review): per the PyTorch reproducibility notes, use_deterministic_algorithms(True)
# can make some CUDA ops raise unless the CUBLAS_WORKSPACE_CONFIG env var is set;
# only relevant if this notebook runs GPU ops.
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True) 
torch.backends.cudnn.deterministic = True    

1.13.1+cu116
3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0]


# Figure 7b, c: rule selectivity of one dendritic branch vs. the other, at SST -> E-dendrite sparsity 0 (7b) and 0.8 (7c)

In [38]:
# Load the per-model branch-coding results. Each entry is accessed below via the keys
# 'hp' (hyperparameters), 'rule_sel_dend1_norm' and 'rule_sel_dend2_norm'.
# `pickle` was never imported explicitly; import it here instead of relying on the
# star import from `functions`.
import pickle

# NOTE: pickle.load can execute arbitrary code — only load files you produced/trust.
# The path is a redacted placeholder; point it at your local branch_coding.pickle.
with open('/.../branch_coding.pickle', 'rb') as handle:
    all_data_branch_coding = pickle.load(handle)

In [None]:
# Scatter the normalized rule selectivity of one dendritic branch against the other,
# pooled over all models, for the (nonlinearity, sparsity) panels of Figure 7b/c, and
# collect the plotted points as source data.
all_dend_nonlinears = ['subtractive', 'divisive_2']
all_sparsitys = [0, 0.2, 0.4, 0.6, 0.8]
n_exc = 70    # neurons per model whose branch pair is plotted (previously a bare range(70))

data_fig7b = {'x': [], 'y': []}    # rule selectivity of two branches, for fully-connected SST -> exc
data_fig7c = {'x': [], 'y': []}    # rule selectivity of two branches, for 20% connected SST -> exc

# Only these (dendritic nonlinearity, sparsity) combinations are plotted; each maps to the
# source-data dict collecting its points. Keeping this in one place avoids repeating the
# same condition for filtering and for source-data collection.
panel_source_data = {
    ('subtractive', 0): data_fig7b,
    ('subtractive', 0.8): data_fig7c,
}

for dend_nonlinear in all_dend_nonlinears:
    for sparsity in all_sparsitys:
        source_data = panel_source_data.get((dend_nonlinear, sparsity))
        if source_data is None:
            continue

        fig, ax = plt.subplots(1, 1, figsize=[7, 7])
        fig.patch.set_facecolor('white')
        fig.suptitle('Rule selectivity for the two branches, \nacross all models, {} {}'.format(sparsity, dend_nonlinear), fontsize=20)
        ax.set_xlabel('Rule selectivity of \none dendritic branch', fontsize=20)
        ax.set_ylabel('Rule selectivity of \nthe other dendritic branch', fontsize=20)

        for model in all_data_branch_coding:
            hp = model['hp']
            if hp['sparse_srsst_to_sredend'] != sparsity or hp['dend_nonlinearity'] != dend_nonlinear:
                continue
            # Index (rather than slice) so a model with fewer than n_exc entries still
            # fails loudly, as before.
            sel_dend1 = [model['rule_sel_dend1_norm'][n] for n in range(n_exc)]
            sel_dend2 = [model['rule_sel_dend2_norm'][n] for n in range(n_exc)]
            # One scatter call per model instead of one per point: identical markers,
            # far fewer matplotlib artists.
            ax.scatter(x=sel_dend1, y=sel_dend2, s=30, color='k', alpha=1)

            # source data
            source_data['x'].extend(sel_dend1)
            source_data['y'].extend(sel_dend2)

        # Reference lines and limits are figure-level; draw them once (they were previously
        # re-added for every single point).
        ax.axvline(x=0, linestyle='dotted', color='k')
        ax.axhline(y=0, linestyle='dotted', color='k')
        ax.set_xlim([-1.1, 1.1])
        ax.set_ylim([-1.1, 1.1])

        make_pretty_axes(ax)
        fig.tight_layout()
        plt.show()


# Figure 7d: branch coding (mean absolute difference in rule selectivity between the two branches) as a function of SST -> E-dendrite sparsity

In [None]:
# branch coding as a function of sparsity of SST->Edend
# For each dendritic nonlinearity, plot the mean (± SEM) absolute difference in normalized
# rule selectivity between the two dendritic branches, versus SST -> E-dendrite sparsity,
# pooling neurons over all models with that setting; collect the plotted values as source data.
import scipy.stats as stats

# sorted(): iterating a bare set of strings has a hash-seed-dependent order, which would
# make the printed output and figure order change between kernel restarts.
all_dend_nonlinears = sorted(set(model['hp']['dend_nonlinearity'] for model in all_data_branch_coding))
# Sparsity levels present in the data, largest first; loop-invariant, so computed once.
sparsities = sorted(set(model['hp']['sparse_srsst_to_sredend'] for model in all_data_branch_coding), reverse=True)

data_fig7d = {'x': [], 'y': [], 'y_err': []}         # filled for the 'subtractive' nonlinearity
data_suppfig11a = {'x': [], 'y': [], 'y_err': []}    # filled for the 'divisive_2' nonlinearity


def _diff_rulesel_stats(models, dend_nonlinear, sparsity):
    """Summarize |rule_sel_dend1_norm - rule_sel_dend2_norm| over the matching models.

    Pools the elementwise absolute branch differences from every model whose 'hp' has the
    given 'dend_nonlinearity' and 'sparse_srsst_to_sredend', prints the pool size, and
    returns a dict with 'mean_diff_rulesel', 'std_diff_rulesel' and 'sem_diff_rulesel'.
    If no model matches, all three are NaN (set explicitly rather than via numpy's
    empty-slice warning).
    """
    all_diff_rulesel = []
    for model in models:
        hp = model['hp']
        if hp['dend_nonlinearity'] != dend_nonlinear or hp['sparse_srsst_to_sredend'] != sparsity:
            continue
        diff = np.abs(np.array(model['rule_sel_dend1_norm']) - np.array(model['rule_sel_dend2_norm']))
        all_diff_rulesel.extend(diff)
    print('{}, {}, n={}'.format(dend_nonlinear, sparsity, len(all_diff_rulesel)))
    if not all_diff_rulesel:
        return {'mean_diff_rulesel': np.nan, 'std_diff_rulesel': np.nan, 'sem_diff_rulesel': np.nan}
    return {
        'mean_diff_rulesel': np.mean(all_diff_rulesel),
        'std_diff_rulesel': np.std(all_diff_rulesel),
        'sem_diff_rulesel': stats.sem(all_diff_rulesel),
    }


for dend_nonlinear in all_dend_nonlinears:
    print(dend_nonlinear)
    sparsity_vs_diffrulesel = {}
    for s in sparsities:
        sparsity_vs_diffrulesel[s] = _diff_rulesel_stats(all_data_branch_coding, dend_nonlinear, s)

    #===== Plotting =====#
    fig, ax = plt.subplots(figsize=[6, 5])
    fig.patch.set_facecolor('white')
    # Fresh list per nonlinearity so the two source-data dicts never share one list object.
    x = list(sparsities)
    y = [sparsity_vs_diffrulesel[s]['mean_diff_rulesel'] for s in sparsities]
    y_err = [sparsity_vs_diffrulesel[s]['sem_diff_rulesel'] for s in sparsities]
    ax.errorbar(x=x, y=y, yerr=y_err, marker='o', color='k')
    ax.set_xlim([min(sparsities) - 0.1, max(sparsities) + 0.1])
    ax.set_xlabel('Sparsity', fontsize=20)
    ax.set_ylabel('Mean difference in rule \nselectivity between branches', fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()

    # collect source data
    if dend_nonlinear == 'subtractive':
        data_fig7d.update(x=x, y=y, y_err=y_err)
    elif dend_nonlinear == 'divisive_2':
        data_suppfig11a.update(x=x, y=y, y_err=y_err)

