In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np; np.set_printoptions(precision=4); np.random.seed(0)
import torch; torch.set_printoptions(precision=4)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rc('font', size=12); 
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl
import seaborn as sns
import time
import sys
import itertools
import random; random.seed(0)
import scipy
import os
import warnings

from textwrap import wrap
from scipy.stats import wilcoxon
from sklearn.metrics.pairwise import cosine_similarity


from functions import *

print(torch.__version__)
print(sys.version)
                
%matplotlib inline

torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True) 
torch.backends.cudnn.deterministic = True    

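# Figure 7e and Supplementary Figure 11b: performance after silencing SST (Fig. 7e: networks with subtractive dendritic nonlinearity; Supp. Fig. 11b: networks with divisive dendritic nonlinearity)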

In [None]:
# Imported explicitly rather than relying on `from functions import *`.
import pickle

# Pickled per-network silencing results. '/...' is a placeholder prefix that
# must be replaced with the real data directory before running.
OPTO_PERF_PATH = '/.../silencing_perf.pickle'

# NOTE: pickle.load can execute arbitrary code -- only load this file from a
# trusted source.
with open(OPTO_PERF_PATH, 'rb') as handle:
    all_data_opto_perf = pickle.load(handle)

# Intact vs. SST-silenced performance, grouped by dendritic nonlinearity:
# 'subtractive' networks feed Figure 7e, 'divisive_2' networks feed
# Supplementary Figure 11b. Dict order fixes the plotting order.
data_fig7e = {'intact': [], 'silence_sst': []}
data_suppfig11b = {'intact': [], 'silence_sst': []}
perf_by_nonlinearity = {'subtractive': data_fig7e, 'divisive_2': data_suppfig11b}

for dend_nonlinear, perf_record in perf_by_nonlinearity.items():
    print(dend_nonlinear)
    # Apply the style *before* creating the figure: rcParams are read when the
    # figure/axes are created, so styling afterwards leaves the first figure
    # only partly styled.
    plt.style.use('classic')
    fig, ax = plt.subplots(1, 1, figsize=[3.5, 5])
    fig.patch.set_facecolor('white')

    for data in all_data_opto_perf:
        if data['hp']['dend_nonlinearity'] != dend_nonlinear:
            continue
        perf_intact = data['mean_perf_intact']
        perf_nosst = data['mean_perf_nosst']
        # One line per network joining intact (x=0) and SST-silenced (x=1).
        ax.plot([perf_intact, perf_nosst], color='k', alpha=0.5, marker='o')
        perf_record['intact'].append(perf_intact)
        perf_record['silence_sst'].append(perf_nosst)

    # Decorate the axes once per figure (not once per network), so the dashed
    # reference line is drawn a single time even when no network matches.
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Intact', 'Silence SST'], rotation=20)
    ax.set_xlim([-0.5, 1.5])
    ax.set_ylim([0, 1])
    # Dashed line at 1/3 -- presumably chance level for a 3-way choice (confirm).
    ax.axhline(y=1/3, linestyle='dashed', color='k')
    ax.set_ylabel('Performance', fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()



# Supplementary Figure 11: performance after silencing PV and VIP in the sensorimotor module

In [None]:
# Intact vs. PV-/VIP-silenced performance for every network in
# `all_data_opto_perf` (loaded in the SST-silencing cell above).
data_silencepv = {'intact': [], 'silence_pv': []}
data_silencevip = {'intact': [], 'silence_vip': []}
perf_by_cell_type = {'pv': data_silencepv, 'vip': data_silencevip}

for cell_type, perf_record in perf_by_cell_type.items():
    # Apply the style before creating the figure so the figure and axes pick
    # it up.
    plt.style.use('classic')
    fig, ax = plt.subplots(figsize=[3.5, 5])
    fig.patch.set_facecolor('white')

    # NOTE(review): unlike the SST cell, networks are not filtered by
    # data['hp']['dend_nonlinearity'] here, so all nonlinearities are pooled --
    # confirm this is intended for Supplementary Figure 11.
    for data in all_data_opto_perf:
        perf_intact = data['mean_perf_intact']
        perf_inactivate = data['mean_perf_no{}'.format(cell_type)]
        # One line per network joining intact (x=0) and silenced (x=1).
        ax.plot([perf_intact, perf_inactivate], color='k', alpha=0.5, marker='o')
        perf_record['intact'].append(perf_intact)
        perf_record['silence_{}'.format(cell_type)].append(perf_inactivate)

    # Decorate the axes once per figure rather than once per network.
    ax.set_xticks([0, 1])
    # Upper-case the cell type so labels read 'Silence PV'/'Silence VIP',
    # matching 'Silence SST' in the SST-silencing figures.
    ax.set_xticklabels(['Intact', 'Silence {}'.format(cell_type.upper())], rotation=20)
    ax.set_xlim([-0.5, 1.5])
    ax.set_ylim([0, 1])
    # Dashed line at 1/3 -- presumably chance level for a 3-way choice (confirm).
    ax.axhline(y=1/3, linestyle='dashed', color='k')
    ax.set_ylabel('Performance', fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()