In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np; np.set_printoptions(precision=4); np.random.seed(0)
import torch; torch.set_printoptions(precision=4)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rc('font', size=12); 
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl
import seaborn as sns
import time
import sys
import itertools
import random; random.seed(0)
import scipy
import os
import warnings

from textwrap import wrap
from scipy.stats import wilcoxon
from sklearn.metrics.pairwise import cosine_similarity

# from model import *
from functions import *

                
%matplotlib inline

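# Figure 7f: Network performance after silencing PV, SST or VIP cells in the sensorimotor module (columns: silenced cell type; top row: overall performance; bottom row: rule performance; one line per model)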

In [None]:
start = time.time()

all_data_opto_perf = []

model_dir = ''
test_data_dir = ''

for model_name in sorted(os.listdir(model_dir)):
    if ('2023-05-10' in model_name) and 'wcst' in model_name and ('success' in model_name):
        print(model_name)
        path_to_file = model_dir + model_name
        
        with HiddenPrints():
            model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)
        
        with open(test_data_dir+'{}_testdata_silenceSRSST_noiseless'.format(model_name), 'rb') as f: 
            neural_data_silenceSST = pickle.load(f)  
        test_data_inhibit_sst = neural_data_silenceSST['test_data']
        mean_perf_nosst = np.mean([_[0] for _ in test_data_inhibit_sst['perfs']])
        mean_perf_rule_nosst = np.mean([_[0] for _ in test_data_inhibit_sst['perf_rules']])
        
        with open(test_data_dir+'{}_testdata_silenceSRVIP_noiseless'.format(model_name), 'rb') as f:
            neural_data_silenceVIP = pickle.load(f)
        test_data_inhibit_vip = neural_data_silenceVIP['test_data']
        mean_perf_novip = np.mean([_[0] for _ in test_data_inhibit_vip['perfs']])
        mean_perf_rule_novip = np.mean([_[0] for _ in test_data_inhibit_vip['perf_rules']])
        
        with open(test_data_dir+'{}_testdata_silenceSRPV_noiseless'.format(model_name), 'rb') as f:
            neural_data_silencePV = pickle.load(f)
        test_data_inhibit_pv = neural_data_silencePV['test_data']
        mean_perf_nopv = np.mean([_[0] for _ in test_data_inhibit_pv['perfs']])
        mean_perf_rule_nopv = np.mean([_[0] for _ in test_data_inhibit_pv['perf_rules']])
        
        with open(test_data_dir+'{}_testdata_noiseless'.format(model_name), 'rb') as f:
            neural_data = pickle.load(f)
        test_data_intact = neural_data['test_data']
        mean_perf_intact = np.mean([_[0] for _ in test_data_intact['perfs']])
        mean_perf_rule_intact = np.mean([_[0] for _ in test_data_intact['perf_rules']])
        if mean_perf_intact<=0.8 or mean_perf_rule_intact<=0.8:
            print('low perf, pass ({}/{})'.format(mean_perf_intact, mean_perf_rule_intact))
            continue
        
        
        all_data_opto_perf.append({'hp': hp_test,
                                 'mean_perf_intact': mean_perf_intact, 
                                 'mean_perf_rule_intact': mean_perf_rule_intact, 
                                 'mean_perf_nosst': mean_perf_nosst,
                                 'mean_perf_novip': mean_perf_novip,
                                 'mean_perf_nopv': mean_perf_nopv,
                                 'mean_perf_rule_nosst': mean_perf_rule_nosst,
                                 'mean_perf_rule_novip': mean_perf_rule_novip,
                                 'mean_perf_rule_nopv': mean_perf_rule_nopv})   
        
        
print(time.time()-start)

In [None]:
# Figure 7f. Columns: silenced cell type (PV, SST, VIP).
# Rows: overall performance (top) and rule performance (bottom).
# Each grey line is one model, going from intact to silenced.
DEND_NONLINEARITY = 'divisive_2'   # sub-select dendritic nonlinearity here: 'subtractive' or 'divisive_2'

fig, ax = plt.subplots(2, 3, figsize=[10, 10])
# NOTE(review): plt.style.use mutates global rcParams. The figure above is created under the
# previous style and the lines below under 'classic', so re-running this cell changes the look.
plt.style.use('classic')
fig.patch.set_facecolor('white')

# (column, cell-type label, condition tag in the result dict keys)
cell_type_panels = [(0, 'PV', 'nopv'), (1, 'SST', 'nosst'), (2, 'VIP', 'novip')]
# (row, result-key prefix, dashed reference level, y label)
perf_rows = [(0, 'mean_perf', 1/3, 'Performance'),
             (1, 'mean_perf_rule', 1/2, 'Performance (rule)')]

for data in all_data_opto_perf:
    if data['hp']['dend_nonlinearity'] != DEND_NONLINEARITY:
        continue
    for row, prefix, _, _ in perf_rows:
        for col, _, condition in cell_type_panels:
            ax[row, col].plot([data[prefix + '_intact'], data[prefix + '_' + condition]],
                              color='k', alpha=0.5, marker='o')

# Format every panel exactly once. This used to run inside the model loop, which stacked one
# duplicate dashed line per model and left the axes unformatted when no model matched.
for row, _, ref_level, ylabel in perf_rows:
    for col, ctype, _ in cell_type_panels:
        panel = ax[row, col]
        panel.set_xticks([0, 1])
        panel.set_xticklabels(['Intact', 'Opto {}'.format(ctype)], rotation=20)
        panel.set_xlim([-0.5, 1.5])
        panel.set_ylim([0, 1])
        panel.axhline(y=ref_level, linestyle='dashed', color='k')
        panel.set_ylabel(ylabel, fontsize=20)
        make_pretty_axes(panel)
fig.tight_layout()
plt.show()
    