# Figure 2: emergence of two populations of excitatory neurons in the PFC module

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np; np.set_printoptions(precision=2); np.random.seed(0)
import torch; torch.set_printoptions(precision=2)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rc('font', size=12)
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl

import seaborn as sns
import time
import sys
import itertools
import random; random.seed(0)
import datetime
import pickle
import copy
import pandas as pd
import scipy
import os

from sklearn.cluster import KMeans
from sklearn.manifold import MDS
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity

import sys
from functions import *


print(torch.__version__)
print(sys.version)
                
%matplotlib inline

# Figure 2a: example neurons

In [None]:
# Figure 2a: for every selected model, plot the activity of individual PFC units
# (units in sub-groups whose name contains 'esoma' or 'pv') on four trial types:
# rule 1 / rule 2, each after a correct or after an error trial.

n_to_plot = 100    # maximum number of units plotted per model

for model_name in sorted(os.listdir('/model/directory/')):
    if ('2023-05-10' in model_name) and 'success' in model_name:
        print(model_name+'\n')

        # load model
        path_to_file = '/model/directory/'+model_name
        with HiddenPrints():
            model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file,model_name=model_name, simple=False, plot=False, toprint=False)
        # load the noiseless test run of this model (trusted, locally generated pickle)
        with open('/where/run/data/is/stored/{}_testdata_noiseless'.format(model_name), 'rb') as f:
            neural_data = pickle.load(f)
        test_data = neural_data['test_data']
        mean_perf = np.mean([_[0] for _ in test_data['perfs']])
        mean_perf_rule = np.mean([_[0] for _ in test_data['perf_rules']])
        rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()
        current_matrix = neural_data['current_matrix']

        # generate trial labels
        trial_labels = label_trials_wcst(test_data=test_data)
        rule1_trs_stable = trial_labels['rule1_trs_stable']
        rule2_trs_stable = trial_labels['rule2_trs_stable']
        c1_trs_stable = trial_labels['c1_trs_stable']
        c2_trs_stable = trial_labels['c2_trs_stable']
        c3_trs_stable = trial_labels['c3_trs_stable']
        resp_trs_stable = {'c1': c1_trs_stable, 'c2': c2_trs_stable, 'c3': c3_trs_stable}    # to be used as an argument in the "compute_sel_wcst" function
        trs_by_center_card = trial_labels['trs_by_center_card_stable']

        # compute cell selectivity
        all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'],
                                    rule1_trs_stable=trial_labels['rule1_trs_stable'], rule2_trs_stable=trial_labels['rule2_trs_stable'],
                                    rule1_trs_after_error = trial_labels['rule1_trs_after_error'], rule2_trs_after_error=trial_labels['rule2_trs_after_error'],
                                    resp_trs_stable = resp_trs_stable, trs_by_center_card=trs_by_center_card,
                                    stims=test_data['stims'], error_trials=trial_labels['error_trials'], trial_labels=trial_labels)
        # resp_sel_normalized and ref_card_sel_normalized are required by
        # define_subpop_sr_wcst below and by the per-unit printout, so they must be set.
        # NOTE(review): key names come from the previously commented-out lines —
        # confirm compute_sel_wcst still returns 'resp_normalized' / 'ref_card_normalized'.
        resp_sel_normalized = all_sels['resp_normalized']
        rule_sel_normalized = all_sels['rule_normalized_activity']
        error_sel_normalized = all_sels['error_normalized']
        ref_card_sel_normalized = all_sels['ref_card_normalized']

        # define subpopulations
        subcg_pfc_idx = define_subpop_pfc(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                          rule_sel=rule_sel_normalized, err_sel=error_sel_normalized,
                                          rule1_trs_stable=rule1_trs_stable,
                                          rule2_trs_stable=rule2_trs_stable,
                                          rule1_after_error_trs=trial_labels['rule1_trs_after_error'],
                                          rule2_after_error_trs=trial_labels['rule2_trs_after_error'],
                                          rule1_after_correct_trs=trial_labels['rule1_trs_after_correct'],
                                          rule2_after_correct_trs=trial_labels['rule2_trs_after_correct'],
                                          rule_threshold=0.5, err_threshold=0.5)
        subcg_sr_idx = define_subpop_sr_wcst(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                          rule_sel=rule_sel_normalized, resp_sel=resp_sel_normalized, ref_card_sel=ref_card_sel_normalized,
                                          rule1_trs_stable=rule1_trs_stable,
                                          rule2_trs_stable=rule2_trs_stable,
                                          rule_threshold=0, resp_threshold=0)
        # register the sub-groups on the model so they can be looked up by name
        for subcg in subcg_pfc_idx.keys():
            model.rnn.cg_idx['subcg_pfc_'+subcg] = subcg_pfc_idx[subcg]
        for subcg in subcg_sr_idx.keys():
            model.rnn.cg_idx['subcg_sr_'+subcg] = subcg_sr_idx[subcg]

        #=== analysis ===#
        # line style encodes previous-trial outcome, color encodes the current rule
        plot_info_rulexerror = [{'name': 'rule 1, after correct', 'trials': rule1_trs_stable, 'color': 'blue', 'ls': '-'},
                               {'name': 'rule 1, after error', 'trials': trial_labels['rule1_trs_after_error'], 'color': 'blue', 'ls': '--'},
                               {'name': 'rule 2, after correct', 'trials': rule2_trs_stable, 'color': 'green', 'ls': '-'},
                               {'name': 'rule 2, after error', 'trials': trial_labels['rule2_trs_after_error'], 'color': 'green', 'ls': '--'}]

        # collect the PFC excitatory-soma and PV units, group by group
        all_pfc_units_sorted = []
        for cg in subcg_pfc_idx.keys():
            if 'esoma' in cg or 'pv' in cg:
                all_pfc_units_sorted.extend(subcg_pfc_idx[cg])

        # plot at most n_to_plot units for this model
        for n in all_pfc_units_sorted[:n_to_plot]:
            print('rule sel normalized: {:.4f}, resp sel normalized: {}, error sel: {:.4f}'.format(rule_sel_normalized[n], list(resp_sel_normalized[n].values()), error_sel_normalized[n]))

            # every sub-group (sensorimotor first, then PFC) that contains unit n
            unit_groups = ([name for name in subcg_sr_idx.keys() if n in subcg_sr_idx[name]]
                           + [name for name in subcg_pfc_idx.keys() if n in subcg_pfc_idx[name]])

            # plot the single cell activity
            fig, ax = plt.subplots(figsize=[6, 4])
            fig.patch.set_facecolor('white')
            plt.style.use('classic')
            fig.suptitle('{}\n Unit {}. Cg={}'.format(model_name, n, unit_groups))
            plot_single_cell(ax=ax, n=n, rnn_activity=rnn_activity, plot_info=plot_info_rulexerror, hp_task=hp_task_test, hp=hp_test, legend_fontsize=20)
            fig.tight_layout()
            plt.show()

# Generate data for Figures 2b and 2c

In [None]:
# Build `all_data`: one dict per selected model holding the model, its
# effective reward-input weights, PFC sub-group indices, selectivity measures
# and test performance. Consumed by the Figure 2b and 2c cells.
start = time.time()

all_data = []
plt.rc('font', size=12)

for model_name in sorted(os.listdir('/model/directory/')):
    if ('2023-05-10' in model_name) and 'success' in model_name:
        print(model_name)

        # load model
        path_to_file = '/model/directory/'+model_name
        with HiddenPrints():
            model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)

        # NOTE(review): this directory differs from the one used in the Figure 2a
        # cell ('/where/run/data/is/stored/') — confirm both hold the same test data.
        with open('/where/test/run/data/is/saved/{}'.format(model_name+'_testdata_noiseless'), 'rb') as f:
            neural_data = pickle.load(f)
        test_data = neural_data['test_data']
        rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()
        # performance of THIS model; must be computed here, otherwise the values
        # stored below would be stale leftovers from another cell
        mean_perf = np.mean([_[0] for _ in test_data['perfs']])
        mean_perf_rule = np.mean([_[0] for _ in test_data['perf_rules']])

        # generate trial labels
        trial_labels = label_trials_wcst(test_data=test_data)
        rule1_trs_stable = trial_labels['rule1_trs_stable']
        rule2_trs_stable = trial_labels['rule2_trs_stable']
        rule1_trs_after_error = trial_labels['rule1_trs_after_error']
        rule2_trs_after_error = trial_labels['rule2_trs_after_error']
        c1_trs_stable = trial_labels['c1_trs_stable']
        c2_trs_stable = trial_labels['c2_trs_stable']
        c3_trs_stable = trial_labels['c3_trs_stable']
        resp_trs_stable = {'c1': c1_trs_stable, 'c2': c2_trs_stable, 'c3': c3_trs_stable}    # to be used as an argument in the "compute_sel_wcst" function
        error_trials = trial_labels['error_trials']
        rule1_after_correct_trs = trial_labels['rule1_trs_after_correct']
        rule2_after_correct_trs = trial_labels['rule2_trs_after_correct']

        # compute cell selectivity
        all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'],
                                     rule1_trs_stable=rule1_trs_stable, rule2_trs_stable=rule2_trs_stable,
                                     rule1_trs_after_error=rule1_trs_after_error, rule2_trs_after_error=rule2_trs_after_error,
                                     resp_trs_stable=resp_trs_stable, trs_by_center_card=trial_labels['trs_by_center_card_stable'],
                                     stims=test_data['stims'], error_trials=error_trials, trial_labels=trial_labels)
        # normalized rule / error selectivity used to define the PFC sub-groups
        # (same keys and thresholds as the Figure 2a cell)
        rule_sel_used = all_sels['rule_normalized_activity']
        error_sel_used = all_sels['error_normalized']

        # define subpopulations within PFC
        subcg_pfc_idx = define_subpop_pfc(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity,
                                          rule_sel=rule_sel_used, err_sel=error_sel_used, rule1_trs_stable=rule1_trs_stable,
                                          rule2_trs_stable=rule2_trs_stable, rule1_after_error_trs=rule1_trs_after_error,
                                          rule2_after_error_trs=rule2_trs_after_error,
                                          rule1_after_correct_trs=rule1_after_correct_trs, rule2_after_correct_trs=rule2_after_correct_trs,
                                          rule_threshold=0.5, err_threshold=0.5, toprint=False)
        for subcg in subcg_pfc_idx.keys():
            model.rnn.cg_idx['subcg_pfc_'+subcg] = subcg_pfc_idx[subcg]

        #=== analysis ===#
        w_rew_eff = model.rnn.effective_weight(w=model.rnn.w_rew, mask=model.rnn.mask_rew).detach().cpu().numpy()
        rule_sel_used_unnormalized = all_sels['rule_activity']
        err_sel_unnormalized = all_sels['error']
        err_sel_normalized = error_sel_used

        all_data.append({
                         'model_name': model_name,
                         'model': model,
                         'hp': hp_test,
                         'w_rew_eff': w_rew_eff,
                         'subcg_pfc_idx': subcg_pfc_idx,
                         'rule_sel_unnormalized': rule_sel_used_unnormalized,
                         'err_sel_unnormalized': err_sel_unnormalized,
                         'err_sel_normalized': err_sel_normalized,
                         'rule_sel_normalized': rule_sel_used,
                         'rule_sel_aftererr': all_sels['rule_aftererr'],
                         'rule_sel_aftererr_normalized': all_sels['rule_aftererr_roc'],
                         'all_sels': all_sels,
                         'mean_perf': mean_perf,
                         'mean_perf_rule': mean_perf_rule})

print(time.time()-start)

# Figure 2b: error (negative-feedback) input weight vs. rule modulation, example model (one plot per model)

In [None]:
# Figure 2b: one scatter plot per model. Each point is one PFC excitatory
# dendrite. x is its entry in row 1 of the effective reward-input weights,
# which the axis labels as the negative-feedback input. y is the unnormalized
# rule modulation of the soma that the dendrite belongs to.
# Blue marks somata in a rule-selective group, red marks somata in an
# error x rule mixed group, and every other dendrite is skipped.
for data in all_data:
    print(data['model_name'])
    model = data['model']
    cg_idx = model.rnn.cg_idx
    subcg = data['subcg_pfc_idx']

    fig, ax = plt.subplots(1, 1, figsize=[6,5])
    fig.suptitle('perf: {:.4f}, perf_rule: {:.4f}'.format(data['mean_perf'], data['mean_perf_rule']), fontsize=30)
    fig.patch.set_facecolor('white')
    plt.style.use('classic')

    for n in cg_idx['pfc_edend']:
        n_soma = len(cg_idx['pfc_esoma'])
        # 1-based index of the dendritic branch that unit n belongs to
        branch_id = (n - cg_idx['pfc_edend'][0]) // n_soma + 1
        # index of the soma this dendrite is attached to
        # (assumes dendrite blocks follow the somata, branch by branch — TODO confirm)
        soma_id = n - n_soma*branch_id
        w_neg_fdbk = data['w_rew_eff'][1, n]

        if any(soma_id in subcg[key] for key in ('rule1_pfc_esoma', 'rule2_pfc_esoma')):
            color = 'blue'
        elif any(soma_id in subcg[key] for key in ('mix_err_rule1_pfc_esoma', 'mix_err_rule2_pfc_esoma')):
            color = 'red'
        elif any(soma_id in subcg[key] for key in ('mix_corr_rule1_pfc_esoma', 'mix_corr_rule2_pfc_esoma')):
            # 'mix_corr' groups are deliberately not plotted
            continue
        else:
            continue

        ax.scatter(x=w_neg_fdbk, y=data['rule_sel_unnormalized'][soma_id], color=color)

    ax.set_xlabel('Input weight for negative feedback', fontsize=20)
    ax.set_ylabel('Rule modulation', fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()

# Figure 2c: error (negative-feedback) input weight vs. rule modulation, pooled across models

In [None]:
# Figure 2c: all models pooled in one scatter plot. Each point is one PFC
# excitatory dendrite. x is its entry in row 1 of the effective reward-input
# weights, which the axis labels as the negative-feedback signal. y is the
# unnormalized rule modulation of its soma.
# Blue marks rule-selective somata, red marks error x rule mixed somata,
# and all other dendrites are skipped.
start = time.time()

plt.rc('font', size=12)
fig, ax = plt.subplots(1, 1, figsize=[6, 5])
fig.patch.set_facecolor('white')
plt.style.use('classic')

for x in all_data:
    print(x['model_name'])
    # Use the model stored with this entry (already loaded by the data-generation
    # cell). This avoids reloading a sample model from a separate hard-coded path
    # and does not rely on all models sharing identical cell-group indices.
    model = x['model']
    subcg_pfc_idx = x['subcg_pfc_idx']
    w_rew_eff = x['w_rew_eff']
    rule_sel = x['rule_sel_unnormalized']
    dend_idx = model.rnn.cg_idx['pfc_edend']
    n_soma = len(model.rnn.cg_idx['pfc_esoma'])
    for n in dend_idx:
        # 1-based index of the dendritic branch that unit n belongs to
        branch_id = (n - dend_idx[0]) // n_soma + 1
        # index of the soma this dendrite is attached to
        # (assumes dendrite blocks follow the somata, branch by branch — TODO confirm)
        soma_id = n - n_soma*branch_id
        if soma_id in subcg_pfc_idx['rule1_pfc_esoma'] or soma_id in subcg_pfc_idx['rule2_pfc_esoma']:
            color = 'blue'
        elif soma_id in subcg_pfc_idx['mix_err_rule1_pfc_esoma'] or soma_id in subcg_pfc_idx['mix_err_rule2_pfc_esoma']:
            color = 'red'
        else:
            # 'mix_corr' groups and unclassified units are not plotted
            continue
        ax.scatter(x=w_rew_eff[1, n], y=rule_sel[soma_id], color=color, alpha=0.2)
ax.set_xlabel('Input weight for the negative feedback signal', fontsize=20)
ax.set_ylabel('Rule modulation', fontsize=20)
ax.tick_params(axis='x',
                    direction='out',
                    which='both',      # both major and minor ticks are affected
                    bottom=True,       # ticks along the bottom edge are on
                    top=False,         # ticks along the top edge are off
                    labelbottom=True,
                    labelsize=20)
ax.tick_params(axis='y',
                    direction='out',
                    which='both',      # both major and minor ticks are affected
                    left=True,         # ticks along the left edge are on
                    right=False,       # ticks along the right edge are off
                    labelleft=True,
                    labelsize=20)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
fig.tight_layout()
plt.show()

print(time.time()-start)