In [4]:
%load_ext autoreload
%autoreload 2
import numpy as np; np.set_printoptions(precision=4); np.random.seed(0)
import torch; torch.set_printoptions(precision=4); seed = 1; torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl
import seaborn as sns
import time
import sys
import itertools
import random; random.seed(0)
import scipy
import os
import pickle
import warnings
from sklearn.decomposition import PCA
from textwrap import wrap
from scipy.stats import wilcoxon
from scipy.linalg import subspace_angles

from functions import *

print(torch.__version__)
print(sys.version)
                
%matplotlib inline

torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True) 
torch.backends.cudnn.deterministic = True    

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
1.13.1+cu116
3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0]


# Figure 6cd and Supplementary Figure 10ab: principal angles between rule subspaces and between response subspaces, compared with shuffled controls

In [None]:
# Change this to your directory.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('/.../angle_subspace_intact.pickle', 'rb') as handle:
    all_data_intact = pickle.load(handle)

# Histogram bin edges in degrees.
# Principal angles lie in [0, 90], so the last edge must be 90. np.arange
# excludes its stop value, hence 91; with a stop of 90, angles above 89 would
# fall outside every bin and be dropped from the density normalization.
ANGLE_BINS = np.arange(0, 91, 1)


def plot_angle_hist(angles, angles_shuffle, title, bins=ANGLE_BINS):
    """Overlay normalized histograms of model angles (black) and shuffled-control angles (gray).

    Parameters
    ----------
    angles : sequence of float
        Principal angles (degrees) from the trained models.
    angles_shuffle : sequence of float
        Principal angles (degrees) from the shuffled control.
    title : str
        Axes title.
    bins : array-like
        Histogram bin edges in degrees.

    Returns
    -------
    (fig, ax)
        The created matplotlib Figure and Axes.
    """
    fig, ax = plt.subplots(1, 1, figsize=[5, 3])
    ax.set_title(title)
    ax.hist(angles, color='k', density=True, bins=bins)
    ax.hist(angles_shuffle, color='gray', alpha=0.5, density=True, bins=bins)
    make_pretty_axes(ax)
    ax.set_xlim([0, 90])
    ax.set_xticks([0, 90])
    fig.tight_layout()
    plt.show()
    return fig, ax


# Read model names and hyperparameters from the data loaded above, so this
# cell runs on a fresh kernel.
# NOTE(review): assumes each intact entry carries an 'hp' dict, as the
# inactivation pickle's entries do -- confirm.
model_list = list(all_data_intact.keys())
for dend_nonlinearity in ['subtractive', 'divisive_2']:
    angle_rules_all_models = []
    angle_rules_all_models_shuffle = []
    angle_choices_all_models = []
    angle_choices_all_models_shuffle = []
    # Rule-vs-choice angles are collected here but not plotted in this cell.
    angle_rule_choice_all_models = []
    angle_rule_choice_all_models_shuffle = []
    for model in model_list:
        entry = all_data_intact[model]
        if entry['hp']['dend_nonlinearity'] != dend_nonlinearity:
            continue
        # One value per model for the real data; shuffles contribute a collection each.
        angle_rules_all_models.append(entry['angle_rule_subspace'])
        angle_rules_all_models_shuffle.extend(entry['angle_rule_subspace_shuffle'])
        angle_choices_all_models.append(entry['angle_choice_subspace_avg'])
        angle_choices_all_models_shuffle.extend(entry['angle_choice_subspace_avg_shuffle'])
        angle_rule_choice_all_models.append(entry['angle_rule_choice'])
        angle_rule_choice_all_models_shuffle.extend(entry['angle_rule_choice_shuffle'])

    # Subspace angles compared with shuffled data.
    plot_angle_hist(angle_rules_all_models, angle_rules_all_models_shuffle,
                    'principal angle between rule subspaces\n{}'.format(dend_nonlinearity))
    plot_angle_hist(angle_choices_all_models, angle_choices_all_models_shuffle,
                    'principal angle between choice subspaces\n{}'.format(dend_nonlinearity))

    # Keep the plotted values for export as figure source data.
    if dend_nonlinearity == 'subtractive':
        data_fig6cd = {'rule': angle_rules_all_models, 'rule_shuffle': angle_rules_all_models_shuffle, 'response': angle_choices_all_models, 'response_shuffle': angle_choices_all_models_shuffle}
    elif dend_nonlinearity == 'divisive_2':
        data_suppfig10ab = {'rule': angle_rules_all_models, 'rule_shuffle': angle_rules_all_models_shuffle, 'response': angle_choices_all_models, 'response_shuffle': angle_choices_all_models_shuffle}



# Figure 7f and Supplementary Figure 11c: principal angle between rule subspaces when SST neurons are silenced (PV silencing is also plotted for comparison)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('/.../angle_subspace_inactivation.pickle', 'rb') as handle:
    all_data_inact = pickle.load(handle)

data_fig7f = {}
data_suppfig11c = {}

# Model names do not change across iterations, so list them once.
model_list = list(all_data_inact.keys())

for cell_type in ['SST', 'PV']:
    # Key under which the rule-subspace angle with `cell_type` silenced is stored.
    inact_key = 'angle_rule_subspace_no{}'.format(cell_type)
    for dend_nonlinear in ['subtractive', 'divisive_2']:
        print('{} silenced, {}'.format(cell_type, dend_nonlinear))

        angles = []
        angles_inactivation = []

        # Paired plot: one line per model, intact vs. silenced `cell_type`.
        fig, ax = plt.subplots(1, 1, figsize=[2.75, 4])
        ax.set_title('principal angle between\n rule subspaces')

        for model in model_list:
            if all_data_inact[model]['hp']['dend_nonlinearity'] != dend_nonlinear:
                continue
            angle_intact = all_data_intact[model]['angle_rule_subspace']
            angle_silenced = all_data_inact[model][inact_key]
            # Skip models whose angle is undefined (NaN); they would otherwise
            # turn the statistics below into NaN.
            if np.isnan(angle_intact) or np.isnan(angle_silenced):
                continue
            ax.plot([0, 1], [angle_intact, angle_silenced], color='k', alpha=0.5, marker='o', clip_on=False)
            angles.append(angle_intact)
            angles_inactivation.append(angle_silenced)
        make_pretty_axes(ax)
        ax.set_ylim([0, 90])
        ax.set_yticks([0, 90])
        ax.set_xlim([-0.2, 1.2])
        ax.set_xticks([0, 1], ['intact', 'silenced \n{}'.format(cell_type)])
        fig.tight_layout()
        plt.show()

        # Each model contributes one intact and one silenced angle, so the
        # samples are paired and ttest_rel is the matching test.
        # The independent-samples test is kept for comparison with earlier output.
        print('paired:', scipy.stats.ttest_rel(a=angles, b=angles_inactivation, alternative='greater'))
        print('independent:', scipy.stats.ttest_ind(a=angles, b=angles_inactivation, alternative='greater'))
        print('n={}'.format(len(angles)))

        # Only the SST-silencing results are exported as figure source data.
        if cell_type == 'SST':
            if dend_nonlinear == 'subtractive':
                data_fig7f['intact'] = angles
                data_fig7f['silence_sst'] = angles_inactivation
            elif dend_nonlinear == 'divisive_2':
                data_suppfig11c['intact'] = angles
                data_suppfig11c['silence_sst'] = angles_inactivation
