In [3]:
%load_ext autoreload
%autoreload 2
# autoreload 2: edits to imported project modules (e.g. `functions`) are picked up without restarting the kernel

# Seeds: numpy and `random` are seeded with 0, torch with `seed` (= 1) below.
import numpy as np; np.set_printoptions(precision=4); np.random.seed(0)
import torch; torch.set_printoptions(precision=4)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rc('font', size=12)
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl
import seaborn as sns
import time
import sys
import itertools
import random; random.seed(0)
import scipy
import os
from textwrap import wrap
from scipy.stats import wilcoxon

# Make the project helpers importable and switch the working directory to the project's code folder
# (hard-coded, machine-specific path).
# NOTE(review): relative sys.path entries are interpreted against the working directory, which the
# os.chdir below changes — confirm this entry is still needed.
sys.path.append("../two_module_rnn/code")
os.chdir('/home/yl4317/Documents/two_module_rnn/code')
# Star import of the project helpers used in later cells (e.g. HiddenPrints, load_model_v2,
# label_trials_wcst, compute_sel_wcst, define_subpop_pfc, define_subpop_sr_wcst, make_pretty_axes).
# NOTE(review): later cells call pickle.load although pickle is never imported in this cell;
# presumably `functions` re-exports it — verify.
from functions import *
# os.chdir('/home/yl4317/Documents/two_module_rnn/')

print(torch.__version__)
print(sys.version)
                
%matplotlib inline

# Reproducibility: disable cuDNN autotuning and request deterministic algorithms
# (PyTorch raises an error for operations that have no deterministic implementation).
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True) 
torch.backends.cudnn.deterministic = True    

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
1.13.1+cu116
3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0]


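# Supplementary figure 1 - single-cell traces

Example activity of a few units from each rule-selective subpopulation of every cell group. Each trace covers three consecutive trials, starting at the last 'color' trial before the first color → shape rule switch.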

In [5]:
import pickle  # explicit import; this cell previously relied on `from functions import *` providing it

# Supplementary figure 1: traces of a few rule-selective units per cell group over consecutive
# trials, starting at the last 'color' trial before the first color -> shape rule switch.
MODEL_DIR = '/scratch/yl4317/two_module_rnn/saved_models/'
TESTDATA_DIR = '/scratch/yl4317/two_module_rnn/saved_testdata/'
TARGET_MODEL = 'success_2023-05-10-14-28-42_wcst_105_sparsity0'

n_to_plot = 3    # number of neurons to plot for each group
colors = ['#377eb8', '#4daf4a', '#e41a1c']    # rule1 / rule2 / mix_err_rule2 groups
span_colors = ['green', 'green', 'red']    # trial-history shading for each plotted trial
n_trs_plot = len(span_colors)    # number of consecutive trials shown per figure


def _sample_units(pool, k):
    """Return up to `k` distinct unit indices drawn at random from `pool`.

    Uses the global numpy RNG (seeded in the setup cell). Sampling is without replacement,
    so a unit is never plotted twice; if `pool` has fewer than `k` units, all of them are
    returned (in random order). An empty pool yields an empty list.
    """
    pool = np.asarray(pool)
    if pool.size == 0:
        return []
    return list(np.random.choice(pool, min(k, pool.size), replace=False))


def _unit_groups(cg, pfc_idx, sr_idx):
    """Return the index lists to sample from for cell group `cg`, or None if it has none.

    The first two entries are always the 'rule1_<cg>' and 'rule2_<cg>' subpopulations;
    groups whose name contains 'pfc_esoma' additionally get 'mix_err_rule2_<cg>'.
    """
    if 'pfc_esoma' in cg:
        return [pfc_idx['rule1_{}'.format(cg)], pfc_idx['rule2_{}'.format(cg)],
                pfc_idx['mix_err_rule2_{}'.format(cg)]]
    if 'sr' in cg:
        return [sr_idx['rule1_{}'.format(cg)], sr_idx['rule2_{}'.format(cg)]]
    if 'pfc' in cg:
        return [pfc_idx['rule1_{}'.format(cg)], pfc_idx['rule2_{}'.format(cg)]]
    return None


for model_name in sorted(os.listdir(MODEL_DIR)):
    if not (('2023-05-10' in model_name) and 'wcst' in model_name and 'success' in model_name):
        continue
    print(model_name)
    if model_name != TARGET_MODEL:
        continue

    # load model
    path_to_file = MODEL_DIR + model_name
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file,model_name=model_name, simple=False, plot=False, toprint=False)

    # load data (pickle.load can execute code embedded in the file: only open trusted, locally produced files)
    with open(TESTDATA_DIR + model_name + '_testdata_noiseless_no_current_matrix', 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']
    mean_perf = np.mean([p[0] for p in test_data['perfs']])
    mean_perf_rule = np.mean([p[0] for p in test_data['perf_rules']])
    if mean_perf<0.8 or mean_perf_rule<0.8:
        print('low performing model ({}/{})'.format(mean_perf, mean_perf_rule))
        continue
    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()

    # generate trial labels
    trial_labels = label_trials_wcst(test_data=test_data)
    rule1_trs_stable = trial_labels['rule1_trs_stable']
    rule2_trs_stable = trial_labels['rule2_trs_stable']
    rule1_trs_after_error = trial_labels['rule1_trs_after_error']
    rule2_trs_after_error = trial_labels['rule2_trs_after_error']
    c1_trs_stable = trial_labels['c1_trs_stable']
    c2_trs_stable = trial_labels['c2_trs_stable']
    c3_trs_stable = trial_labels['c3_trs_stable']
    resp_trs_stable = {'c1': c1_trs_stable, 'c2': c2_trs_stable, 'c3': c3_trs_stable}    # to be used as an argument in the "compute_sel_wcst" function
    error_trials = trial_labels['error_trials']

    # compute cell selectivity
    all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'],
                                rule1_trs_stable=trial_labels['rule1_trs_stable'], rule2_trs_stable=trial_labels['rule2_trs_stable'],
                                rule1_trs_after_error=trial_labels['rule1_trs_after_error'], rule2_trs_after_error=trial_labels['rule2_trs_after_error'],
                                resp_trs_stable=resp_trs_stable, trs_by_center_card=trial_labels['trs_by_center_card_stable'],
                                stims=test_data['stims'], error_trials=trial_labels['error_trials'], trial_labels=trial_labels)
    resp_sel_normalized = all_sels['resp_normalized']
    rule_sel_normalized = all_sels['rule_normalized_activity']

    # subregions
    subcg_pfc_idx = define_subpop_pfc(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                      rule_sel=all_sels['rule_normalized_activity'], err_sel=all_sels['error_normalized'],
                                      rule1_trs_stable=rule1_trs_stable,
                                      rule2_trs_stable=rule2_trs_stable,
                                      rule1_after_error_trs=rule1_trs_after_error,
                                      rule2_after_error_trs=rule2_trs_after_error,
                                      rule1_after_correct_trs=trial_labels['rule1_trs_after_correct'],
                                      rule2_after_correct_trs=trial_labels['rule2_trs_after_correct'],
                                      rule_threshold=0.5, err_threshold=0.5)
    subcg_sr_idx = define_subpop_sr_wcst(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                         rule_sel=all_sels['rule_normalized_activity'], resp_sel=all_sels['resp_normalized'], ref_card_sel=all_sels['ref_card_normalized'],
                                         rule1_trs_stable=rule1_trs_stable,
                                         rule2_trs_stable=rule2_trs_stable,
                                         rule_threshold=0, resp_threshold=0)
    for subcg in subcg_pfc_idx.keys():
        model.rnn.cg_idx['subcg_pfc_'+subcg] = subcg_pfc_idx[subcg]
    for subcg in subcg_sr_idx.keys():
        model.rnn.cg_idx['subcg_sr_'+subcg] = subcg_sr_idx[subcg]

    #======= PLOT ========#
    n_trials = rnn_activity.shape[0]
    n_ts = rnn_activity.shape[1]
    n_units = rnn_activity.shape[-1]
    rules = test_data['rules']
    # a window starts at the last 'color' trial before a switch to 'shape' and must fit inside the session
    switch_trs = [tr for tr in range(n_trials-1) if rules[tr]=='color' and rules[tr+1]=='shape']
    valid_starts = [tr for tr in switch_trs if tr + n_trs_plot <= n_trials]
    if not valid_starts:
        print('no color->shape switch followed by {} trials, skip'.format(n_trs_plot))
        continue
    first_tr = valid_starts[0]
    trs_to_plot = list(range(first_tr, first_tr + n_trs_plot))
    for tr in trs_to_plot:
        print(rules[tr])
    # concatenate the selected trials (index 0 on axis 2) along time -> (n_trs_plot*n_ts, n_units)
    rnn_activity_selected = rnn_activity[trs_to_plot, :, 0, :].reshape(len(trs_to_plot)*n_ts, n_units)

    # task-event time steps (loop-invariant across cell groups)
    time_step = hp_test['dt']
    hist_start = int(hp_task_test['trial_history_start']/time_step)
    hist_end = int(hp_task_test['trial_history_end']/time_step)
    event_steps = [hp_task_test[key]//time_step
                   for key in ('trial_history_start', 'trial_history_end', 'center_card_on', 'test_cards_on', 'resp_end')]

    for cg in model.rnn.cell_group_list:
        groups = _unit_groups(cg, subcg_pfc_idx, subcg_sr_idx)
        if groups is None:
            # neither an 'sr' nor a 'pfc' cell group: no subpopulations to sample from
            continue
        if len(groups[0])==0 or len(groups[1])==0:
            print('No rule 1/2 neurons')
            continue
        # (unit, color) pairs: each trace is colored by the group it was drawn from
        units_and_colors = [(unit, colors[g]) for g, pool in enumerate(groups) for unit in _sample_units(pool, n_to_plot)]
        units_to_plot = [unit for unit, _color in units_and_colors]
        print(cg, 'units shown: {}'.format(units_to_plot))

        fig, ax = plt.subplots(figsize=[15, 3])
        ax.set_title(cg)
        for unit, color in units_and_colors:
            ax.plot(rnn_activity_selected[:, unit], color=color)
        for i, span_color in enumerate(span_colors):
            offset = i*n_ts
            ax.axvspan(hist_start+offset, hist_end-1+offset, color=span_color, alpha=0.1)
            for step in event_steps:
                ax.axvline(x=step+offset, linestyle='dashed', color='gray')
        make_pretty_axes(ax)
        fig.tight_layout()
        plt.show()

success_2023-05-10-14-28-42_wcst_105_sparsity0


TypeError: define_subpop_pfc() missing 2 required positional arguments: 'rule1_after_correct_trs' and 'rule2_after_correct_trs'

In [6]:
# Inspect which trial groupings label_trials_wcst returned (the cells above index into these keys).
trial_label_names = trial_labels.keys()
trial_label_names

dict_keys(['error_trials', 'correct_trials', 'error_trials_rule', 'correct_trials_rule', 'rule1_trs_after_error', 'rule1_after_error_now_correct_trs', 'rule1_trs_after_correct', 'rule1_trs_stable', 'rule2_trs_after_error', 'rule2_after_error_now_correct_trs', 'rule2_trs_after_correct', 'rule2_trs_stable', 'switch_trs', 'c1_trs_stable', 'c2_trs_stable', 'c3_trs_stable', 'trs_by_center_card_stable', 'rule1_trs', 'rule2_trs', 'c1_trs', 'c2_trs', 'c3_trs'])

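# Archived

Earlier exploratory analyses of the input currents onto VIP and SST neurons, kept for reference. The plotting cells below read `all_data`, which is built by the first cell in this section, so run that cell first.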

In [None]:
# Look at the E and I input to VIP neurons & SST neurons
import pickle  # explicit import; this cell previously relied on `from functions import *` providing it

MODEL_DIR = '/scratch/yl4317/two_module_rnn/saved_models/'
TESTDATA_DIR = '/scratch/yl4317/two_module_rnn/saved_testdata/'

# (output key, sending cell group, receiving cell group)
INPUT_PATHWAYS = [('i_pfc_to_vip', 'pfc_esoma', 'sr_vip'),
                  ('i_sst_to_vip', 'sr_sst', 'sr_vip'),
                  ('i_pfc_to_sst', 'pfc_esoma', 'sr_sst'),
                  ('i_vip_to_sst', 'sr_vip', 'sr_sst'),
                  ('i_exc_to_sst', 'sr_esoma', 'sr_sst')]


def _mean_input(activity, w, src_idx, dst_idx):
    """Mean recurrent input from units `src_idx` onto each unit in `dst_idx`.

    activity : array indexed (trial, time, unit)
    w        : effective recurrent weights indexed (sender, receiver)
    Returns one value per receiving unit: activity[:, :, src] contracted with w[src, dst]
    over the sending units, then averaged over trials and time steps.
    """
    current = np.tensordot(activity[:, :, src_idx], w[np.ix_(src_idx, dst_idx)], axes=(-1, 0))
    return np.mean(current, axis=(0, 1))


all_data = {}

for model_name in sorted(os.listdir(MODEL_DIR)):
    if not (('2023-05-01' in model_name) and 'wcst' in model_name and 'success' in model_name):
        continue
    print(model_name)

    # load model
    path_to_file = MODEL_DIR + model_name
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file,model_name=model_name, simple=False, plot=False, toprint=False)

    # load data (pickle.load can execute code embedded in the file: only open trusted, locally produced files)
    with open(TESTDATA_DIR + model_name + '_testdata_noiseless_no_current_matrix', 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']
    mean_perf = np.mean([p[0] for p in test_data['perfs']])
    mean_perf_rule = np.mean([p[0] for p in test_data['perf_rules']])
    if mean_perf<0.8 or mean_perf_rule<0.8:
        print('low performing model ({}/{}), skip'.format(mean_perf, mean_perf_rule))
        continue
    w_rec = model.rnn.effective_weight(w=model.rnn.w_rec, mask=model.rnn.mask, w_fix=model.rnn.w_fix).detach().cpu().numpy()
    # squeeze(2) removes axis 2 (it must have size 1), leaving activity indexed (trial, time, unit)
    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy().squeeze(2)

    # mean input current per receiving neuron, one entry per pathway
    all_data[model_name] = {key: _mean_input(rnn_activity, w_rec, model.rnn.cg_idx[src], model.rnn.cg_idx[dst])
                            for key, src, dst in INPUT_PATHWAYS}

In [None]:
# Scatter, for every model and receiving neuron, the mean input through one pathway (x)
# against another (y). When a panel lists several x pathways, x is their elementwise sum.
model_list = list(all_data.keys())

# (axes position, pathway keys summed on x, pathway key on y, x label, y label)
scatter_panels = [((0, 0), ('i_pfc_to_vip',), 'i_sst_to_vip', 'PFC to VIP', 'SST to VIP'),
                  ((0, 1), ('i_pfc_to_sst',), 'i_vip_to_sst', 'PFC to SST', 'VIP to SST'),
                  ((1, 0), ('i_exc_to_sst',), 'i_vip_to_sst', 'Local exc to SST', 'VIP to SST'),
                  ((1, 1), ('i_pfc_to_sst', 'i_exc_to_sst'), 'i_vip_to_sst', 'Local exc + PFC to SST', 'VIP to SST')]

fig, ax = plt.subplots(2, 2)
for pos, x_keys, y_key, x_label, y_label in scatter_panels:
    for model_name in model_list:
        x = all_data[model_name][x_keys[0]]
        for key in x_keys[1:]:
            x = x + all_data[model_name][key]
        ax[pos].scatter(x=x, y=all_data[model_name][y_key], color='k', s=5)
    # labels are loop-invariant: set once per panel (previously re-set for every model,
    # and never set at all when all_data was empty)
    ax[pos].set_xlabel(x_label); ax[pos].set_ylabel(y_label)
fig.tight_layout()
plt.show()

In [3]:
# Compare, per pathway, the excitatory input (blue) with the magnitude of the inhibitory input (red).
# Faint lines pair the two values for each receiving neuron of each model; bars show mean ± std
# over all neurons of all models.
model_list = list(all_data.keys())

# per-neuron inputs pooled across all models (same model order for every key)
i_pfc_to_vip_all = []    # across all models
i_sst_to_vip_all = []    # across all models
i_pfc_to_sst_all = []    # across all models
i_vip_to_sst_all = []    # across all models
i_exc_to_sst_all = []    # across all models
for model_name in model_list:
    i_pfc_to_vip_all.extend(all_data[model_name]['i_pfc_to_vip'])
    i_pfc_to_sst_all.extend(all_data[model_name]['i_pfc_to_sst'])
    i_vip_to_sst_all.extend(all_data[model_name]['i_vip_to_sst'])
    i_sst_to_vip_all.extend(all_data[model_name]['i_sst_to_vip'])
    i_exc_to_sst_all.extend(all_data[model_name]['i_exc_to_sst'])

pooled = {'i_pfc_to_vip': i_pfc_to_vip_all,
          'i_sst_to_vip': i_sst_to_vip_all,
          'i_pfc_to_sst': i_pfc_to_sst_all,
          'i_vip_to_sst': i_vip_to_sst_all,
          'i_exc_to_sst': i_exc_to_sst_all}


def _total_input(inputs, keys):
    """Elementwise sum of the per-neuron inputs `inputs[k]` for k in `keys`.

    Converts to numpy first: for the pooled Python lists, `+` would concatenate the
    lists instead of adding the inputs neuron by neuron (the old "Local E + PFC" bar bug).
    """
    total = np.asarray(inputs[keys[0]])
    for key in keys[1:]:
        total = total + np.asarray(inputs[key])
    return total


# (axes position, excitatory pathway keys (summed), inhibitory pathway key, tick labels, tick rotation)
bar_panels = [((0, 0), ('i_pfc_to_vip',), 'i_sst_to_vip', ['PFC to VIP', 'SST to VIP'], 0),
              ((0, 1), ('i_pfc_to_sst',), 'i_vip_to_sst', ['PFC to SST', 'VIP to SST'], 0),
              ((1, 0), ('i_exc_to_sst',), 'i_vip_to_sst', ['Local E to SST', 'VIP to SST'], 0),
              ((1, 1), ('i_exc_to_sst', 'i_pfc_to_sst'), 'i_vip_to_sst', ['Local E + PFC to SST', 'VIP to SST'], 10)]

fig, ax = plt.subplots(2, 2, figsize=[10, 10])
for pos, exc_keys, inh_key, tick_labels, tick_rotation in bar_panels:
    # paired per-neuron lines for each model
    for model_name in model_list:
        y1, y2 = _total_input(all_data[model_name], exc_keys), np.abs(all_data[model_name][inh_key])
        ax[pos].plot([1, 2], [y1, y2], marker='o', color='k', alpha=0.05, markersize=2)
    # pooled summary bars
    y1, y2 = _total_input(pooled, exc_keys), np.abs(pooled[inh_key])
    ax[pos].bar(x=[1, 2], height=[np.mean(y1), np.mean(y2)], color=['blue', 'red'], yerr=[np.std(y1), np.std(y2)], alpha=0.5)
    ax[pos].set_xticks([1, 2]); ax[pos].set_xticklabels(tick_labels, rotation=tick_rotation)
    make_pretty_axes(ax[pos])

fig.tight_layout()
plt.show()

NameError: name 'all_data' is not defined