In [None]:
# Setup cell: autoreload of edited project modules, imports, print precision and random seeds.
%load_ext autoreload
%autoreload 2

# numpy and `random` are seeded with 0 here; torch is seeded with `seed` (=1) below.
import numpy as np; np.set_printoptions(precision=2); np.random.seed(0)
import torch; torch.set_printoptions(precision=2)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt
# import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl
from mpltern.ternary.datasets import get_scatter_points


import seaborn as sns
import time
import sys
import itertools 
import random; random.seed(0)
import datetime
import pickle
import copy
import pandas as pd
import scipy; from scipy import stats; from scipy.stats import wilcoxon
import os

from sklearn.cluster import KMeans
from sklearn.manifold import MDS
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity

import sys    # NOTE(review): repeats the `import sys` above; harmless but redundant

# Project helpers used in later cells (load_model_v2, HiddenPrints, label_trials_wcst,
# compute_sel_wcst, define_subpop_sr_wcst, make_pretty_axes) are presumably provided by
# this star import -- no other line in the notebook defines them.
from functions import *


# Log library versions so results can be tied to an environment.
print(torch.__version__)
print(sys.version)
                
%matplotlib inline

# Generate data for Figure 5c (structure in the input weights)

In [None]:
# Figure 5c data: load every successfully trained WCST model of the MODEL_DATE_TAG batch whose
# noiseless test performance exceeds PERF_THRESHOLD, label its test trials, compute single-cell
# selectivity, define the sensorimotor (sr) neuron pools, and keep the effective input weights.
MODEL_DIR = '/model/directory/'                 # directory containing the saved models
TEST_DATA_DIR = '/where/test/data/is/stored/'   # directory containing '<model_name>_testdata_noiseless'
MODEL_DATE_TAG = '2023-05-10'                   # only model names containing this tag are analysed
PERF_THRESHOLD = 0.8                            # models with mean test performance <= this are skipped

start = time.time()
plt.rc('font', size=12)

all_data = []

for model_name in sorted(os.listdir(MODEL_DIR)):
    if not (MODEL_DATE_TAG in model_name and 'wcst' in model_name and 'success' in model_name):
        continue
    print(model_name+'\n')

    # load model (HiddenPrints presumably suppresses the loader's console output)
    path_to_file = os.path.join(MODEL_DIR, model_name)
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)

    # load the neural data
    # NOTE: pickle.load can execute arbitrary code -- only load test-data files you generated yourself.
    with open(os.path.join(TEST_DATA_DIR, model_name+'_testdata_noiseless'), 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']
    mean_test_perf = np.mean([perf[0] for perf in test_data['perfs']])
    if mean_test_perf <= PERF_THRESHOLD:
        print('perf too low ({}), pass\n'.format(mean_test_perf))
        continue
    rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()

    # generate trial labels
    trial_labels = label_trials_wcst(test_data=test_data)
    rule1_trs_stable = trial_labels['rule1_trs_stable']
    rule2_trs_stable = trial_labels['rule2_trs_stable']
    rule1_trs_after_error = trial_labels['rule1_trs_after_error']
    rule2_trs_after_error = trial_labels['rule2_trs_after_error']
    c1_trs_stable = trial_labels['c1_trs_stable']
    c2_trs_stable = trial_labels['c2_trs_stable']
    c3_trs_stable = trial_labels['c3_trs_stable']
    resp_trs_stable = {'c1': c1_trs_stable, 'c2': c2_trs_stable, 'c3': c3_trs_stable}    # to be used as an argument in the "compute_sel_wcst" function
    error_trials = trial_labels['error_trials']

    # compute cell selectivity (the locals above are the same objects as the trial_labels entries)
    all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'],
                                rule1_trs_stable=rule1_trs_stable, rule2_trs_stable=rule2_trs_stable,
                                rule1_trs_after_error=rule1_trs_after_error, rule2_trs_after_error=rule2_trs_after_error,
                                resp_trs_stable=resp_trs_stable, trs_by_center_card=trial_labels['trs_by_center_card_stable'],
                                stims=test_data['stims'], error_trials=error_trials, trial_labels=trial_labels)

    # define neuron pools
    subcg_sr_idx = define_subpop_sr_wcst(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                         rule_sel=all_sels['rule_normalized_activity'], resp_sel=all_sels['resp_normalized'], ref_card_sel=all_sels['ref_card_normalized'],
                                         rule1_trs_stable=rule1_trs_stable,
                                         rule2_trs_stable=rule2_trs_stable,
                                         rule_threshold=0, resp_threshold=0)
    # register every pool on the model under the key 'subcg_sr_<pool name>'
    for subcg in subcg_sr_idx.keys():
        model.rnn.cg_idx['subcg_sr_'+subcg] = subcg_sr_idx[subcg]

    w_in_eff = model.rnn.effective_weight(w=model.rnn.w_in, mask=model.rnn.mask_in).detach().cpu().numpy()

    all_data.append({'model': model, 'model_name': model_name, 'hp': hp_test,
                     'n_sr_esoma': model.rnn.n['sr_esoma'], 'n_branches': model.rnn.n_branches,
                     'w_in_eff': w_in_eff, 'subcg_sr_idx': subcg_sr_idx, 'all_sels': all_sels})

print('Elapsed time: {}s'.format(time.time()-start))

# Figure 5c (structure in the input weights: preferred vs. non-preferred feature inputs)

In [None]:

# Figure 5c: structure in the effective input weights ('divisive_2' models only).
# For every sr soma that prefers a given (rule, choice, reference card) combination, compare its
# mean input weight from the two input units carrying its preferred feature (reference card and
# chosen test card) with its mean weight from all other input units, averaged over its branches.
PREF_FEATURE_LABELS = ['Preferred\nfeature', 'Non-preferred\nfeature']


def _plot_pref_vs_nonpref(pref, nonpref, labels, figsize, xlim, alpha, tick_fontsize=None, **pretty_kwargs):
    """Draw one black line per neuron from x=0 (`pref`) to x=1 (`nonpref`) and report t-tests.

    Prints the independent-samples t-test (the statistic originally reported) and, because each
    pref/non-pref pair comes from the same neuron, also the paired t-test.
    Extra keyword arguments are forwarded to `make_pretty_axes`.
    Returns (t, p) of the independent-samples test.
    """
    fig, ax = plt.subplots(figsize=figsize)
    fig.patch.set_facecolor('white')
    if len(pref):
        # a (2, n_neurons) array draws one line per column, i.e. per neuron, in a single call
        ax.plot([0, 1], np.array([pref, nonpref]), marker='o', markersize=10, alpha=alpha, color='k')
    t, p = scipy.stats.ttest_ind(pref, nonpref)
    print('t={}, p-value={}'.format(t, p))
    t_rel, p_rel = scipy.stats.ttest_rel(pref, nonpref)
    print('paired t={}, p-value={}'.format(t_rel, p_rel))
    ax.set_xlim(xlim)
    ax.set_ylabel(r'Input weight', fontsize=20)
    ax.set_xticks([0, 1])
    ticklabel_kwargs = {} if tick_fontsize is None else {'fontsize': tick_fontsize}
    ax.set_xticklabels(labels, rotation=0, **ticklabel_kwargs)
    make_pretty_axes(ax, **pretty_kwargs)
    fig.tight_layout()
    plt.show()
    return t, p


w_in_large_all_models = []
w_in_large_ref_all_models = []
w_in_large_test_all_models = []
w_in_small_all_models = []

for data in all_data:
    if data['hp']['dend_nonlinearity'] != 'divisive_2':
        continue
    w_in_model = data['w_in_eff']
    pools = data['subcg_sr_idx']
    w_in_larges = []
    w_in_large_refs = []
    w_in_large_tests = []
    w_in_smalls = []
    for r in [1, 2]:
        neuron_id_rule = pools['rule{}_sr_esoma'.format(r)]
        for choice in [1, 2, 3]:
            neuron_id_choice = pools['respc{}_sr_esoma'.format(choice)]
            for card in ['(0, 0)', '(1, 0)', '(0, 1)', '(1, 1)']:
                neuron_id_refcard = pools['ref_card{}_sr_esoma'.format(card)]
                # all neurons that prefer this combination of rule, choice and reference card (and therefore feature)
                neuron_id = [n for n in neuron_id_rule if (n in neuron_id_choice and n in neuron_id_refcard)]
                if (r == 1 and card in ('(0, 0)', '(0, 1)')) or (r == 2 and card in ('(0, 0)', '(1, 0)')):
                    feature_id = 0    # this neuron prefers when the matching feature is blue or circle
                else:
                    feature_id = 1    # this neuron prefers when the matching feature is red or square
                # indices of the input units that should project strongly to these neurons.
                # The arithmetic assumes inputs 0-3 encode the reference card (2 per rule) and each
                # test card occupies the next block of 4 -- TODO confirm against the task input layout.
                input_neuron_id_large_ref = 2*(r-1) + feature_id
                input_neuron_id_large_test = 4 + 4*(choice-1) + 2*(r-1) + feature_id
                input_neuron_id_large = [input_neuron_id_large_ref, input_neuron_id_large_test]
                input_neuron_id_small = [i for i in range(w_in_model.shape[0]) if i not in input_neuron_id_large]
                for n in neuron_id:
                    # assumes branch b of soma n is stored at column n + (b+1)*n_sr_esoma -- TODO confirm
                    dend_id = [n + (b+1)*data['n_sr_esoma'] for b in range(data['n_branches'])]
                    w_in_larges.append(np.mean(w_in_model[np.ix_(input_neuron_id_large, dend_id)]))
                    w_in_smalls.append(np.mean(w_in_model[np.ix_(input_neuron_id_small, dend_id)]))
                    w_in_large_refs.append(np.mean(w_in_model[input_neuron_id_large_ref, dend_id]))
                    w_in_large_tests.append(np.mean(w_in_model[input_neuron_id_large_test, dend_id]))

    w_in_large_all_models.extend(w_in_larges)
    w_in_large_ref_all_models.extend(w_in_large_refs)
    w_in_large_test_all_models.extend(w_in_large_tests)
    w_in_small_all_models.extend(w_in_smalls)

    # per-model figure
    t, p = _plot_pref_vs_nonpref(w_in_larges, w_in_smalls, PREF_FEATURE_LABELS,
                                 figsize=[10, 7], xlim=[-0.2, 1.2], alpha=0.5, tick_fontsize=20)

# pooled over all models (low alpha because there are many neurons)
t, p = _plot_pref_vs_nonpref(w_in_large_all_models, w_in_small_all_models, PREF_FEATURE_LABELS,
                             figsize=[7, 7], xlim=[-0.3, 1.3], alpha=0.02, labelsize=12)

# Generate data for Figure 5d (structure in the output weights)

In [None]:
# Figure 5d data: effective readout weights of the same models used for Figure 5c.
# The previous version of this cell re-ran the Figure 5c data-generation loop verbatim (same
# model filter, same performance threshold, same selectivity and pool definitions) only to add
# `w_out_eff`. Reusing `all_data` from that cell avoids reloading every checkpoint and test-data
# pickle, and guarantees that panels 5c and 5d use exactly the same models and neuron pools.
start = time.time()
plt.rc('font', size=12)

all_data_wout = []

for entry in all_data:
    model = entry['model']
    # NOTE(review): the readout mask is read from `model.mask_out`, whereas the input mask in the
    # Figure 5c cell is `model.rnn.mask_in`; kept as in the original -- confirm this is intended.
    w_out_eff = model.rnn.effective_weight(w=model.rnn.w_out, mask=model.mask_out).detach().cpu().numpy()

    all_data_wout.append({
                          'model': model,
                          'model_name': entry['model_name'],
                          'hp': entry['hp'],
                          'w_out_eff': w_out_eff,
                          'subcg_sr_idx': entry['subcg_sr_idx']
                         })

print('Elapsed time: {}s'.format(time.time()-start))

In [None]:
# Figure 5d: structure in the effective readout weights ('divisive_2' models only).
# For each choice-selective sr soma, compare its readout weight onto its preferred choice output
# with the mean of its readout weights onto the two other choice outputs.


def _plot_readout_weights(pref, nonpref, figsize, xlim, alpha):
    """One black line per neuron from its preferred (x=0) to mean non-preferred (x=1) readout weight."""
    fig, ax = plt.subplots(figsize=figsize)
    fig.patch.set_facecolor('white')
    ax.plot([0, 1], [pref, nonpref], color='k', marker='o', alpha=alpha)
    ax.set_xlim(xlim)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Preferred\nchoice', 'Non-preferred\nchoice'], rotation=0)
    ax.set_ylabel('Readout weight', fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()
    return fig, ax


wout_pref_all = []
wout_nonpref_all = []

for data in all_data_wout:
    if data['hp']['dend_nonlinearity'] != 'divisive_2':
        continue
    print(data['model_name'])
    wout = data['w_out_eff']
    wout_pref, wout_nonpref = [], []
    # choice 'c<k>' reads out through column k-1; its non-preferred columns are the other two
    for wout_idx_pref, resp in enumerate(['c1', 'c2', 'c3']):
        neuron_idx = data['subcg_sr_idx']['resp{}_sr_esoma'.format(resp)]
        wout_idx_nonpref = [col for col in range(3) if col != wout_idx_pref]
        pref_weights = wout[neuron_idx, wout_idx_pref]
        nonpref_weights = np.mean(wout[np.ix_(neuron_idx, wout_idx_nonpref)], axis=1)
        for per_model, pooled, values in ((wout_pref, wout_pref_all, pref_weights),
                                          (wout_nonpref, wout_nonpref_all, nonpref_weights)):
            per_model.extend(values)
            pooled.extend(values)

    fig, ax = _plot_readout_weights(wout_pref, wout_nonpref, figsize=[10,7], xlim=[-0.2, 1.2], alpha=0.5)

print('all models')
fig, ax = _plot_readout_weights(wout_pref_all, wout_nonpref_all, figsize=[7,7], xlim=[-0.3, 1.3], alpha=0.02)