# Figure 2: emergence of two populations of excitatory neurons in the PFC module

In [None]:
# Notebook setup: auto-reload edited local modules, fix RNG seeds, import the
# plotting / analysis libraries and the project helpers from `functions`.
%load_ext autoreload
%autoreload 2

# NOTE(review): numpy and `random` are seeded with 0 while torch is seeded with
# `seed` (= 1); confirm this mismatch is intentional.
import numpy as np; np.set_printoptions(precision=2); np.random.seed(0)
import torch; torch.set_printoptions(precision=2)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rc('font', size=12)
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl

import seaborn as sns
import time
import sys
import itertools
import random; random.seed(0)
import datetime
import pickle
import copy
import pandas as pd
import scipy
import os

from sklearn.cluster import KMeans
from sklearn.manifold import MDS
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity

# (duplicate of the `import sys` above; harmless)
import sys
# Wildcard import: presumably provides the helpers used later in this notebook
# (make_pretty_axes, HiddenPrints, load_model_v2) — verify against functions.py.
from functions import *


# Record library / interpreter versions for reproducibility.
print(torch.__version__)
print(sys.version)
                
%matplotlib inline

## Figure 3b: negative-feedback input weight vs. rule modulation, example model

In [None]:
# Per-model analysis results: a sequence of dicts read below through keys such as
# 'model_name', 'cg_idx', 'subcg_pfc_idx', 'w_rew_eff', 'rule_sel_unnormalized', 'hp'.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('/.../input_weight_rule_sel_across_models.pickle', 'rb') as pickle_file:
    all_data = pickle.load(pickle_file)

In [None]:
# Figure 3b: for one example model, scatter each PFC excitatory dendrite's
# effective negative-feedback input weight (row 1 of w_rew_eff) against the
# (unnormalized) rule modulation of its soma.
# Blue = rule neurons, red = error x rule neurons; all other somas (including
# correct x rule mixed-selective ones) are skipped.
# The collected points are kept in `data_fig2b` (historical name, kept for
# compatibility with any downstream use).
EXAMPLE_MODEL_NAME = 'success_2023-05-10-14-28-42_wcst_50_sparsity0'
data_fig2b = {'type': [], 'w_neg_fdbk': [], 'rule_sel': []}

for data in all_data:
    print(data['model_name'])
    if data['model_name'] != EXAMPLE_MODEL_NAME:
        continue
    # Apply the style BEFORE creating the figure: rcParams changed afterwards do
    # not restyle axes that already exist (on a fresh kernel this figure would
    # otherwise be drawn with the default style).
    plt.style.use('classic')
    fig, ax = plt.subplots(1, 1, figsize=[6, 5])
    fig.patch.set_facecolor('white')

    subcg_pfc_idx = data['subcg_pfc_idx']
    edend_idx = data['cg_idx']['pfc_edend']
    n_esoma = len(data['cg_idx']['pfc_esoma'])
    for n in edend_idx:
        # Dendrites come in consecutive blocks of n_esoma: branch_id is the
        # 1-based dendrite number and soma_id the index of the soma it maps to.
        branch_id = (n - edend_idx[0]) // n_esoma + 1
        soma_id = n - n_esoma * branch_id
        if soma_id in subcg_pfc_idx['rule1_pfc_esoma'] or soma_id in subcg_pfc_idx['rule2_pfc_esoma']:
            color, neuron_type = 'blue', 'rule_neuron'
        elif soma_id in subcg_pfc_idx['mix_err_rule1_pfc_esoma'] or soma_id in subcg_pfc_idx['mix_err_rule2_pfc_esoma']:
            color, neuron_type = 'red', 'error_x_rule_neuron'
        else:
            # correct x rule mixed-selective and unclassified somas are not plotted
            continue

        w_neg_fdbk = data['w_rew_eff'][1, n]
        rule_mod = data['rule_sel_unnormalized'][soma_id]
        data_fig2b['type'].append(neuron_type)
        data_fig2b['w_neg_fdbk'].append(w_neg_fdbk)
        data_fig2b['rule_sel'].append(rule_mod)

        ax.scatter(x=w_neg_fdbk, y=rule_mod, color=color)
    ax.set_xlabel('Input weight for negative feedback', fontsize=20)
    ax.set_ylabel('Rule modulation', fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()



## Figure 3c & Supplementary Figure 3: negative-feedback input weight vs. rule modulation, across models

In [None]:
start = time.time()

# Figure 3c (subtractive dendrites) and Supplementary Figure 3 (divisive_2
# dendrites): for every model with the given dendritic nonlinearity, scatter the
# effective negative-feedback input weight (row 1 of w_rew_eff) of each PFC
# excitatory dendrite against the rule modulation of its soma.
# Blue = rule neurons, red = error x rule neurons.
data_fig2c = {'type': [], 'w_rew': [], 'rule_sel': []}
data_suppfig3 = {'type': [], 'w_rew': [], 'rule_sel': []}
# Destination dict per nonlinearity; one figure is drawn per entry, in this order.
data_by_nonlinearity = {'subtractive': data_fig2c, 'divisive_2': data_suppfig3}

# Load a sample model (the indices for all models are the same). The load does
# not depend on the nonlinearity, so it is done once instead of once per figure.
# NOTE(review): model / hp_test / hp_task_test / optim / saved_data are not used
# in this cell; drop the load if no other cell relies on them.
model_name = all_data[0]['model_name']
path_to_file = '/scratch/yl4317/two_module_rnn/saved_models/'+model_name
with HiddenPrints():
    model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)

for dend_nonlinear, fig_data in data_by_nonlinearity.items():
    # Apply the style BEFORE creating the figure so it affects these axes.
    plt.style.use('classic')
    fig, ax = plt.subplots(1, 1, figsize=[7.5, 6])
    fig.suptitle(dend_nonlinear)
    fig.patch.set_facecolor('white')

    for x in all_data:
        print(x['model_name'])
        if x['hp']['dend_nonlinearity'] != dend_nonlinear:
            continue
        subcg_pfc_idx = x['subcg_pfc_idx']
        w_rew_eff = x['w_rew_eff']
        rule_sel = x['rule_sel_unnormalized']
        edend_idx = x['cg_idx']['pfc_edend']
        n_esoma = len(x['cg_idx']['pfc_esoma'])

        for n in edend_idx:
            # Dendrites come in consecutive blocks of n_esoma: branch_id is the
            # 1-based dendrite number and soma_id the index of the soma it maps to.
            branch_id = (n - edend_idx[0]) // n_esoma + 1
            soma_id = n - n_esoma * branch_id
            if soma_id in subcg_pfc_idx['rule1_pfc_esoma'] or soma_id in subcg_pfc_idx['rule2_pfc_esoma']:
                color, neuron_type = 'blue', 'rule_neuron'
            elif soma_id in subcg_pfc_idx['mix_err_rule1_pfc_esoma'] or soma_id in subcg_pfc_idx['mix_err_rule2_pfc_esoma']:
                color, neuron_type = 'red', 'error_x_rule_neuron'
            else:
                # correct x rule mixed-selective and unclassified somas are not plotted
                continue

            fig_data['type'].append(neuron_type)
            fig_data['w_rew'].append(w_rew_eff[1, n])
            fig_data['rule_sel'].append(rule_sel[soma_id])
            ax.scatter(x=w_rew_eff[1, n], y=rule_sel[soma_id], color=color, alpha=0.2)
    ax.set_xlabel('Input weight for the negative feedback signal', fontsize=20)
    ax.set_ylabel('Rule modulation', fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()

print(time.time()-start)

## Supplementary Figure 6: negative-feedback input weight vs. rule modulation, slow-switching models

In [None]:
# Per-model analysis results for the slow-switching models.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('/.../input_weight_rule_sel_across_models_slow.pickle', 'rb') as pickle_file:
    all_data_slow = pickle.load(pickle_file)

# Sanity check: list every loaded model name so that no fast-switching model
# is accidentally included.
for x in all_data_slow:
    print(x['model_name'])

In [None]:
start = time.time()

# Supplementary Figure 6: pooled over all slow-switching models, scatter each
# PFC excitatory dendrite's effective negative-feedback input weight (row 1 of
# w_rew_eff) against the rule modulation of its soma.
# Blue = rule neurons, red = error x rule neurons.
data_suppfig6a = {'type': [], 'w_rew': [], 'rule_sel': []}

# Apply the style BEFORE creating the figure so it affects these axes.
plt.style.use('classic')
fig, ax = plt.subplots(1, 1, figsize=[7.5, 6])
fig.suptitle('slow switching models')
fig.patch.set_facecolor('white')

for x in all_data_slow:
    print(x['model_name'])
    subcg_pfc_idx = x['subcg_pfc_idx']
    w_rew_eff = x['w_rew_eff']
    rule_sel = x['rule_sel_unnormalized']
    edend_idx = x['cg_idx']['pfc_edend']
    n_esoma = len(x['cg_idx']['pfc_esoma'])

    for n in edend_idx:
        # Dendrites come in consecutive blocks of n_esoma: branch_id is the
        # 1-based dendrite number and soma_id the index of the soma it maps to.
        branch_id = (n - edend_idx[0]) // n_esoma + 1
        soma_id = n - n_esoma * branch_id
        if soma_id in subcg_pfc_idx['rule1_pfc_esoma'] or soma_id in subcg_pfc_idx['rule2_pfc_esoma']:
            color, neuron_type = 'blue', 'rule_neuron'
        elif soma_id in subcg_pfc_idx['mix_err_rule1_pfc_esoma'] or soma_id in subcg_pfc_idx['mix_err_rule2_pfc_esoma']:
            color, neuron_type = 'red', 'error_x_rule_neuron'
        else:
            # correct x rule mixed-selective and unclassified somas are not plotted
            continue

        data_suppfig6a['type'].append(neuron_type)
        data_suppfig6a['w_rew'].append(w_rew_eff[1, n])
        data_suppfig6a['rule_sel'].append(rule_sel[soma_id])
        if np.isnan(rule_sel[soma_id]):
            # Report NaN rule modulation with enough context to trace it back;
            # the NaN is still recorded in data_suppfig6a and passed to scatter.
            print(f"NaN rule modulation: model {x['model_name']}, soma {soma_id}, dendrite {n}")
        ax.scatter(x=w_rew_eff[1, n], y=rule_sel[soma_id], color=color, alpha=0.2)
ax.set_xlabel('Input weight for the negative feedback signal', fontsize=20)
ax.set_ylabel('Rule modulation', fontsize=20)
make_pretty_axes(ax)
fig.tight_layout()
plt.show()

print(time.time()-start)