# Figure 2: emergence of two populations of excitatory neurons in the PFC module

In [None]:
# --- Notebook setup: autoreload, global RNG seeds, imports, project path ---
%load_ext autoreload
%autoreload 2

# numpy's global RNG is seeded with 0 here; torch's RNG is seeded with `seed` (=1) just below
import numpy as np; np.set_printoptions(precision=2); np.random.seed(0)
import torch; torch.set_printoptions(precision=2)
seed = 1

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rc('font', size=12)
import matplotlib 
from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as pl

import seaborn as sns
import time
import sys
import itertools
import random; random.seed(0)    # Python's built-in RNG is also seeded with 0
import datetime
import pickle
import copy
import pandas as pd
import scipy
import os

from sklearn.cluster import KMeans
from sklearn.manifold import MDS
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity

import sys    # duplicate of the import above (harmless)
sys.path.append("../two_module_rnn/code")
# NOTE(review): machine-specific absolute path; the notebook only runs on this machine/layout
os.chdir('/home/yl4317/Documents/two_module_rnn/code')
# from model_working import *
# from task import *
# Project helpers used in later cells (HiddenPrints, load_model_v2, label_trials_wcst,
# compute_sel_wcst, define_subpop_pfc, plot_single_cell, make_pretty_axes, ...) are presumably
# provided by this star import -- TODO confirm
from functions import *

# os.chdir('/home/yl4317/Documents/two_module_rnn/')

# record library/interpreter versions in the notebook output
print(torch.__version__)
print(sys.version)

%matplotlib inline

## Figure 2a example neurons

In [None]:
# Figure 2a: plot the example PFC excitatory units of the example network used in the paper,
# then write their single-trial and trial-averaged activity to CSV as source data.
MODELS_DIR = '/scratch/yl4317/two_module_rnn/saved_models/'
EXAMPLE_MODEL_NAME = 'success_2023-05-10-14-28-42_wcst_136_sparsity0'    # this is the example used in the paper
# Example units shown in the paper, keyed by the cell group they illustrate.
# Feel free to check out other units as well!
EXAMPLE_UNITS = {'rule1': 252, 'rule2': 279, 'errorxrule1': 290, 'errorxrule2': 250}
UNIT_TO_CELL_GROUP = {unit: group for group, unit in EXAMPLE_UNITS.items()}

model_name = EXAMPLE_MODEL_NAME
# Fail loudly if the example model is missing. Otherwise the source-data section below would
# run on `rnn_activity` / `trial_labels` left over from whichever cell last defined them.
if model_name not in os.listdir(MODELS_DIR):
    raise FileNotFoundError('example model {} not found in {}'.format(model_name, MODELS_DIR))

# load model
path_to_file = MODELS_DIR + model_name
with HiddenPrints():
    model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)

# load test data
with open('/scratch/yl4317/two_module_rnn/saved_testdata/{}_testdata_noiseless_no_current_matrix'.format(model_name), 'rb') as f:
    neural_data = pickle.load(f)
test_data = neural_data['test_data']
mean_perf = np.mean([_[0] for _ in test_data['perfs']])
mean_perf_rule = np.mean([_[0] for _ in test_data['perf_rules']])
rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()

# generate trial labels
trial_labels = label_trials_wcst(test_data=test_data)
rule1_trs_stable = trial_labels['rule1_trs_stable']
rule2_trs_stable = trial_labels['rule2_trs_stable']
resp_trs_stable = {'c1': trial_labels['c1_trs_stable'], 'c2': trial_labels['c2_trs_stable'], 'c3': trial_labels['c3_trs_stable']}

# compute cell selectivity
all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'],
                            rule1_trs_stable=trial_labels['rule1_trs_stable'], rule2_trs_stable=trial_labels['rule2_trs_stable'],
                            rule1_trs_after_error = trial_labels['rule1_trs_after_error'], rule2_trs_after_error=trial_labels['rule2_trs_after_error'],
                            resp_trs_stable = resp_trs_stable, trs_by_center_card=trial_labels['trs_by_center_card_stable'],
                            stims=test_data['stims'], error_trials=trial_labels['error_trials'], trial_labels=trial_labels)
resp_sel_normalized = all_sels['resp_normalized']
rule_sel_normalized = all_sels['rule_normalized_activity']
error_sel_normalized = all_sels['error_normalized']
ref_card_sel_normalized = all_sels['ref_card_normalized']

# define PFC subpopulations from rule and error selectivity (both thresholds 0.5) and register
# them on the model as 'subcg_pfc_<name>' cell groups
subcg_pfc_idx = define_subpop_pfc(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity, 
                                  rule_sel=rule_sel_normalized, err_sel=error_sel_normalized, 
                                  rule1_trs_stable=rule1_trs_stable, 
                                  rule2_trs_stable=rule2_trs_stable, 
                                  rule1_after_error_trs=trial_labels['rule1_trs_after_error'],
                                  rule2_after_error_trs=trial_labels['rule2_trs_after_error'],
                                  rule1_after_correct_trs=trial_labels['rule1_trs_after_correct'],
                                  rule2_after_correct_trs=trial_labels['rule2_trs_after_correct'],
                                  rule_threshold=0.5, err_threshold=0.5)
for subcg in subcg_pfc_idx.keys():
    model.rnn.cg_idx['subcg_pfc_'+subcg] = subcg_pfc_idx[subcg]

#=== analysis ===#
plot_info_rulexerror = [{'name': 'rule 1, after correct', 'trials': rule1_trs_stable, 'color': 'blue', 'ls': '-'},
                        {'name': 'rule 1, after error', 'trials': trial_labels['rule1_trs_after_error'], 'color': 'blue', 'ls': '--'},
                        {'name': 'rule 2, after correct', 'trials': rule2_trs_stable, 'color': 'green', 'ls': '-'},
                        {'name': 'rule 2, after error', 'trials': trial_labels['rule2_trs_after_error'], 'color': 'green', 'ls': '--'}]

for n in model.rnn.cg_idx['pfc_esoma']:
    # int() so the lookup works whether the indices are Python or numpy integers
    cell_group = UNIT_TO_CELL_GROUP.get(int(n))
    if cell_group is None:
        continue
    # every PFC subpopulation that contains this unit
    cg = [group for group in subcg_pfc_idx.keys() if n in subcg_pfc_idx[group]]

    # plot the single cell activity
    fig, ax = plt.subplots(figsize=[8, 4])
    fig.patch.set_facecolor('white')
    plt.style.use('classic')
    fig.suptitle('{}\n Unit {}. Cell group={}'.format(model_name, n, cg))
    plot_single_cell(ax=ax, n=n, rnn_activity=rnn_activity, plot_info=plot_info_rulexerror, hp_task=hp_task_test, hp=hp_test, legend_fontsize=20)
    ax.legend(bbox_to_anchor=[1, 0.5], loc='center left')
    fig.tight_layout()
    plt.show()
    # fig.savefig('/home/yl4317/Documents/two_module_rnn/figs/example_units_error_rule_{}.pdf'.format(cell_group))

# generate source data: for each example unit and trial type, the single-trial traces
# rnn_activity[trials, :, 0, n] and their mean over trials
source_trials = {'rule1_aftercorr': rule1_trs_stable,
                 'rule2_aftercorr': rule2_trs_stable,
                 'rule1_aftererr': trial_labels['rule1_trs_after_error'],
                 'rule2_aftererr': trial_labels['rule2_trs_after_error']}
data_fig2a_example_units = {}
for cell_group, n in EXAMPLE_UNITS.items():
    data_fig2a_example_units[cell_group] = {}
    for key, trials in source_trials.items():
        single_trial = rnn_activity[trials, :, 0, n]
        data_fig2a_example_units[cell_group][key] = {'trial_average': np.mean(single_trial, axis=0),
                                                     'single_trial': single_trial}
for cell_group in data_fig2a_example_units.keys():
    for trial_type in data_fig2a_example_units[cell_group].keys():
        pd.DataFrame.from_dict(data=data_fig2a_example_units[cell_group][trial_type], orient='index').to_csv('/home/yl4317/Documents/two_module_rnn/source_data/fig2a_example_units_{}_{}.csv'.format(cell_group, trial_type), header=False)

## Generate data

In [None]:
# Generate the per-model data used by the figure cells below. For every trained model that passes
# the filters, this cell:
#   - computes cell selectivity,
#   - defines the PFC subpopulations,
#   - stores both together with the effective weights returned by model.rnn.effective_weight(w_rew).
# One pickle is written per training condition.
start = time.time()


plt.rc('font', size=12)

MODELS_DIR = '/scratch/yl4317/two_module_rnn/saved_models/'
# training dates of the models included in each condition
MODEL_DATES = {'fast_switching': ('2023-05-10',),
               'slow_switching': ('2023-12-22', '2024-01-02', '2024-01-04')}
# suffix of the saved test-data file for each condition
TESTDATA_APPENDIX = {'fast_switching': '_testdata_noiseless_no_current_matrix',
                     'slow_switching': '_testdata_noiseless_moreblocks_17in20'}
# models whose mean test performance is below this value are skipped
MIN_MEAN_PERF = {'fast_switching': 0.7, 'slow_switching': 0.5}
OUTPUT_PICKLE = {'fast_switching': '/home/yl4317/Documents/two_module_rnn/processed_data/input_weight_rule_sel_across_models.pickle',
                 'slow_switching': '/home/yl4317/Documents/two_module_rnn/processed_data/input_weight_rule_sel_across_models_slow.pickle'}

for condition in ['fast_switching', 'slow_switching']:
    
    all_data = []
    
    for model_name in sorted(os.listdir(MODELS_DIR)):
        if 'success' not in model_name or not any(date in model_name for date in MODEL_DATES[condition]):
            continue
            
        print(condition, model_name)
        
        # load model
        path_to_file = MODELS_DIR + model_name
        with HiddenPrints():
            model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)
        if hp_test['dt'] != 10:
            print('pass\n')
            continue
        if hp_test['dend_nonlinearity'] not in ['subtractive', 'divisive_2']:
            continue
        if len(hp_test['cell_group_list']) == 2:
            print('pass')
            continue

        # load test data
        appendix = TESTDATA_APPENDIX[condition]
        with open('/scratch/yl4317/two_module_rnn/saved_testdata/{}'.format(model_name+appendix), 'rb') as f:
            neural_data = pickle.load(f)
        test_data = neural_data['test_data']
        rnn_activity = neural_data['rnn_activity'].detach().cpu().numpy()
        mean_perf = np.mean([np.mean(_) for _ in neural_data['test_data']['perfs']])
        print('mean_perf={}'.format(mean_perf))
        if mean_perf < MIN_MEAN_PERF[condition]:
            print('low perf, pass')
            continue
        mean_perf_rule = np.mean([np.mean(_) for _ in neural_data['test_data']['perf_rules']])
        
        # generate trial labels. Every name used further down must be (re)assigned here for the
        # current model, otherwise values from the previous model in the loop would leak through.
        if hp_test['task'] == 'cxtdm':
            trial_labels = label_trials_cxtdm(test_data=test_data)
            rule1_trs_stable = trial_labels['rule1_trs_stable']
            rule2_trs_stable = trial_labels['rule2_trs_stable']
            left_trs_stable = trial_labels['left_trs_stable']
            right_trs_stable = trial_labels['right_trs_stable']
            rule1_trs_after_error = trial_labels['rule1_trs_after_error']
            rule2_trs_after_error = trial_labels['rule2_trs_after_error']
            error_trials = trial_labels['error_trials']
            # NOTE(review): assumes label_trials_cxtdm returns the after-correct trial sets under the
            # same keys as label_trials_wcst -- TODO confirm. They are required by define_subpop_pfc below.
            rule1_after_correct_trs = trial_labels['rule1_trs_after_correct']
            rule2_after_correct_trs = trial_labels['rule2_trs_after_correct']
        elif hp_test['task'] == 'wcst':
            trial_labels = label_trials_wcst(test_data=test_data)
            rule1_trs_stable = trial_labels['rule1_trs_stable']
            rule2_trs_stable = trial_labels['rule2_trs_stable']
            rule1_trs_after_error = trial_labels['rule1_trs_after_error']
            rule2_trs_after_error = trial_labels['rule2_trs_after_error']
            c1_trs_stable = trial_labels['c1_trs_stable']
            c2_trs_stable = trial_labels['c2_trs_stable']
            c3_trs_stable = trial_labels['c3_trs_stable']
            resp_trs_stable = {'c1': c1_trs_stable, 'c2': c2_trs_stable, 'c3': c3_trs_stable}    # to be used as an argument in the "compute_sel_wcst" function
            error_trials = trial_labels['error_trials']
            rule1_after_correct_trs = trial_labels['rule1_trs_after_correct']
            rule2_after_correct_trs = trial_labels['rule2_trs_after_correct']
        else:
            raise ValueError('unsupported task: {}'.format(hp_test['task']))
                
        # compute cell selectivity
        if hp_test['task'] == 'cxtdm':
            all_sels = compute_sel_cxtdm(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, 
                                         rule1_trs_stable=rule1_trs_stable, rule2_trs_stable=rule2_trs_stable, 
                                         rule1_trs_after_error=rule1_trs_after_error, rule2_trs_after_error=rule2_trs_after_error,
                                         left_trs_stable=left_trs_stable, right_trs_stable=right_trs_stable, stims=test_data['stims'], error_trials=error_trials)
            rule_sel_normalized_activity = all_sels['rule_normalized_activity']
            rule_sel_activity = all_sels['rule_activity']
            error_sel_used = all_sels['error_normalized']
        else:    # wcst (any other task was rejected above)
            all_sels = compute_sel_wcst(rnn_activity=rnn_activity, hp=hp_test, hp_task=hp_task_test, rules=test_data['rules'], 
                                        rule1_trs_stable=trial_labels['rule1_trs_stable'], rule2_trs_stable=trial_labels['rule2_trs_stable'],
                                        rule1_trs_after_error = trial_labels['rule1_trs_after_error'], rule2_trs_after_error=trial_labels['rule2_trs_after_error'],
                                        resp_trs_stable = resp_trs_stable, trs_by_center_card=trial_labels['trs_by_center_card_stable'],
                                        stims=test_data['stims'], error_trials=trial_labels['error_trials'], trial_labels=trial_labels)
        
        # define subpopulations within PFC
        rule_sel_used = all_sels['rule_normalized_activity']
        error_sel_used = all_sels['error_normalized']
        
        subcg_pfc_idx = define_subpop_pfc(model=model, hp_task=hp_task_test, hp=hp_test, rnn_activity=rnn_activity, 
                                          rule_sel=rule_sel_used, err_sel=error_sel_used, rule1_trs_stable=rule1_trs_stable, 
                                          rule2_trs_stable=rule2_trs_stable, rule1_after_error_trs=trial_labels['rule1_trs_after_error'],
                                          rule2_after_error_trs=trial_labels['rule2_trs_after_error'],
                                          rule1_after_correct_trs=rule1_after_correct_trs, rule2_after_correct_trs=rule2_after_correct_trs,
                                          rule_threshold=0.5, err_threshold=0.5, toprint=False)
        for subcg in subcg_pfc_idx.keys():
            model.rnn.cg_idx['subcg_pfc_'+subcg] = subcg_pfc_idx[subcg]
            
        #=== analysis ===#
        w_rew_eff = model.rnn.effective_weight(w=model.rnn.w_rew, mask=model.rnn.mask_rew).detach().cpu().numpy()
        rule_sel_used_unnormalized = all_sels['rule_activity']
        err_sel_unnormalized = all_sels['error']
        err_sel_normalized = all_sels['error_normalized']
        all_data.append({
                         'model_name': model_name, 
                         'cg_idx': model.rnn.cg_idx,
                         'hp': hp_test,
                         'w_rew_eff': w_rew_eff, 
                         'subcg_pfc_idx': subcg_pfc_idx, 
                         'rule_sel_unnormalized': rule_sel_used_unnormalized, 
                         'err_sel_unnormalized': err_sel_unnormalized,
                         'err_sel_normalized': err_sel_normalized,
                         'rule_sel_normalized': rule_sel_used,
                         'rule_sel_aftererr': all_sels['rule_aftererr'], 
                         'rule_sel_aftererr_normalized': all_sels['rule_aftererr_roc'],
                         'all_sels': all_sels,
                         'mean_perf': mean_perf,
                         'mean_perf_rule': mean_perf_rule})

    with open(OUTPUT_PICKLE[condition], 'wb') as handle:
        pickle.dump(all_data, handle, protocol=pickle.HIGHEST_PROTOCOL)
print(time.time()-start)

## Figure 3b, error input weight x rule modulation, example model

In [None]:
# Read back the fast-switching per-model data written by the "Generate data" cell.
fast_switching_pickle = '/home/yl4317/Documents/two_module_rnn/processed_data/input_weight_rule_sel_across_models.pickle'
with open(fast_switching_pickle, 'rb') as handle:
    all_data = pickle.load(handle)

In [None]:
# Example network: for each PFC excitatory dendrite, plot the input weight of the negative-feedback
# signal (w_rew_eff[1, n]) against the rule modulation of its soma. Blue = rule neurons,
# red = error x rule neurons. The plotted points are also collected into `data_fig2b`.
data_fig2b = {'type': [], 'w_neg_fdbk': [], 'rule_sel': []}

for data in all_data:
    print(data['model_name'])
    if data['model_name'] != 'success_2023-05-10-14-28-42_wcst_50_sparsity0':
        continue
    fig, ax = plt.subplots(1, 1, figsize=[6,5])
    fig.patch.set_facecolor('white')
    plt.style.use('classic')
    subcg = data['subcg_pfc_idx']
    n_soma = len(data['cg_idx']['pfc_esoma'])
    first_dend = data['cg_idx']['pfc_edend'][0]
    for n in data['cg_idx']['pfc_edend']:
        branch_id = (n - first_dend) // n_soma + 1    # this is dendrite number X
        soma_id = n - n_soma * branch_id
        w_neg_fdbk = data['w_rew_eff'][1, n]
        # `neuron_type`, not `type`: assigning to `type` here would shadow the builtin
        # for the rest of the notebook session
        if soma_id in subcg['rule1_pfc_esoma'] or soma_id in subcg['rule2_pfc_esoma']:
            color = 'blue'
            neuron_type = 'rule_neuron'
        elif soma_id in subcg['mix_err_rule1_pfc_esoma'] or soma_id in subcg['mix_err_rule2_pfc_esoma']:
            color = 'red'
            neuron_type = 'error_x_rule_neuron'
        else:
            # correct x rule mixed neurons (mix_corr_*) and all other units are not plotted
            continue
            
        data_fig2b['type'].append(neuron_type)
        data_fig2b['w_neg_fdbk'].append(w_neg_fdbk)
        data_fig2b['rule_sel'].append(data['rule_sel_unnormalized'][soma_id])
        
        ax.scatter(x=w_neg_fdbk, y=data['rule_sel_unnormalized'][soma_id], color=color)
    ax.set_xlabel('Input weight for negative feedback', fontsize=20)
    ax.set_ylabel('Rule modulation', fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()

# fig.savefig('/home/yl4317/Documents/two_module_rnn/figs/inp_weight_rule_sel_example_network.pdf')
# pd.DataFrame.from_dict(data=data_fig2b, orient='columns').to_csv('/home/yl4317/Documents/two_module_rnn/source_data/fig2a_exampleNetwork.csv', header=False)     # save source data

## Figure 3c & Supplementary Figure 3: error input weight x rule modulation, across models

In [None]:
# Across models: negative-feedback input weight of each PFC excitatory dendrite vs. the rule
# modulation of its soma, one figure per dendritic nonlinearity.
# Subtractive models -> data_fig2c; divisive_2 models -> data_suppfig3.
start = time.time()

data_fig2c = {'type': [], 'w_rew': [], 'rule_sel': []}
data_suppfig3 = {'type': [], 'w_rew': [], 'rule_sel': []}
source_data_by_nonlinearity = {'subtractive': data_fig2c, 'divisive_2': data_suppfig3}
color_by_type = {'rule_neuron': 'blue', 'error_x_rule_neuron': 'red'}

# load a sample model (the indices for all models are the same). It does not depend on the
# nonlinearity, so it is loaded once, outside the loop.
model_name = all_data[0]['model_name']
path_to_file = '/scratch/yl4317/two_module_rnn/saved_models/'+model_name
with HiddenPrints():
    model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)

for dend_nonlinear in ['subtractive', 'divisive_2']:
    # plt.rc('font', size=12)
    fig, ax = plt.subplots(1, 1, figsize=[7.5, 6])
    fig.suptitle(dend_nonlinear)
    fig.patch.set_facecolor('white')
    plt.style.use('classic')
    points = source_data_by_nonlinearity[dend_nonlinear]

    for x in all_data:
        print(x['model_name'])
        if x['hp']['dend_nonlinearity'] != dend_nonlinear:
            continue
        subcg_pfc_idx = x['subcg_pfc_idx']
        w_rew_eff = x['w_rew_eff']
        rule_sel = x['rule_sel_unnormalized']
        n_soma = len(x['cg_idx']['pfc_esoma'])
        first_dend = x['cg_idx']['pfc_edend'][0]
        
        for n in x['cg_idx']['pfc_edend']:
            branch_id = (n - first_dend) // n_soma + 1    # this is dendrite number X
            soma_id = n - n_soma * branch_id
            # `neuron_type`, not `type`, so the builtin is not shadowed for the notebook session
            if soma_id in subcg_pfc_idx['rule1_pfc_esoma'] or soma_id in subcg_pfc_idx['rule2_pfc_esoma']:
                neuron_type = 'rule_neuron'
            elif soma_id in subcg_pfc_idx['mix_err_rule1_pfc_esoma'] or soma_id in subcg_pfc_idx['mix_err_rule2_pfc_esoma']:
                neuron_type = 'error_x_rule_neuron'
            else:
                # correct x rule mixed neurons (mix_corr_*) and all other units are not plotted
                continue
            points['type'].append(neuron_type)
            points['w_rew'].append(w_rew_eff[1, n])
            points['rule_sel'].append(rule_sel[soma_id])

    # one scatter call for all collected points (in collection order) instead of one artist per point
    if points['type']:
        ax.scatter(x=points['w_rew'], y=points['rule_sel'],
                   color=[color_by_type[t] for t in points['type']], alpha=0.2)
    ax.set_xlabel('Input weight for the negative feedback signal', fontsize=20)
    ax.set_ylabel('Rule modulation', fontsize=20)
    make_pretty_axes(ax)
    fig.tight_layout()
    plt.show()
    # fig.savefig('/home/yl4317/Documents/two_module_rnn/figs/inpWeight_rulSel_{}.pdf'.format(dend_nonlinear))

# save to csv
# pd.DataFrame.from_dict(data=data_fig2c, orient='columns').to_csv('/home/yl4317/Documents/two_module_rnn/source_data/fig2c.csv', header=False)
# pd.DataFrame.from_dict(data=data_suppfig3, orient='columns').to_csv('/home/yl4317/Documents/two_module_rnn/source_data/suppfig3.csv', header=False)

print(time.time()-start)

## Supplementary Figure 6: input weight for negative feedback x rule modulation, for slow-switching models

In [None]:
# Read back the slow-switching per-model data written by the "Generate data" cell.
with open('/home/yl4317/Documents/two_module_rnn/processed_data/input_weight_rule_sel_across_models_slow.pickle', 'rb') as handle:
    all_data_slow = pickle.load(handle)

# TEST - don't accidentally include fast-switching models!
# The fast-switching condition of the "Generate data" cell selects models whose name contains
# '2023-05-10', so any such name here means the wrong data was loaded.
for x in all_data_slow:
    print(x['model_name'])
fast_models_found = [x['model_name'] for x in all_data_slow if '2023-05-10' in x['model_name']]
if fast_models_found:
    raise ValueError('fast-switching models found in the slow-switching data: {}'.format(fast_models_found))

In [None]:
# Slow-switching models: negative-feedback input weight of each PFC excitatory dendrite vs. the
# rule modulation of its soma. Blue = rule neurons, red = error x rule neurons.
# The plotted points are also collected into `data_suppfig6a`.
start = time.time()

data_suppfig6a = {'type': [], 'w_rew': [], 'rule_sel': []}
color_by_type = {'rule_neuron': 'blue', 'error_x_rule_neuron': 'red'}

fig, ax = plt.subplots(1, 1, figsize=[7.5, 6])
fig.suptitle('slow switching models')
fig.patch.set_facecolor('white')
plt.style.use('classic')

for x in all_data_slow:
    print(x['model_name'])
    subcg_pfc_idx = x['subcg_pfc_idx']
    w_rew_eff = x['w_rew_eff']
    rule_sel = x['rule_sel_unnormalized']
    n_soma = len(x['cg_idx']['pfc_esoma'])
    first_dend = x['cg_idx']['pfc_edend'][0]
    
    for n in x['cg_idx']['pfc_edend']:
        branch_id = (n - first_dend) // n_soma + 1    # this is dendrite number X
        soma_id = n - n_soma * branch_id
        # `neuron_type`, not `type`, so the builtin is not shadowed for the notebook session
        if soma_id in subcg_pfc_idx['rule1_pfc_esoma'] or soma_id in subcg_pfc_idx['rule2_pfc_esoma']:
            neuron_type = 'rule_neuron'
        elif soma_id in subcg_pfc_idx['mix_err_rule1_pfc_esoma'] or soma_id in subcg_pfc_idx['mix_err_rule2_pfc_esoma']:
            neuron_type = 'error_x_rule_neuron'
        else:
            # correct x rule mixed neurons (mix_corr_*) and all other units are not plotted
            continue
        data_suppfig6a['type'].append(neuron_type)
        data_suppfig6a['w_rew'].append(w_rew_eff[1, n])
        data_suppfig6a['rule_sel'].append(rule_sel[soma_id])
        if np.isnan(rule_sel[soma_id]):
            print('NaN rule modulation: model={}, soma={}'.format(x['model_name'], soma_id))

# one scatter call for all collected points (in collection order) instead of one artist per point
if data_suppfig6a['type']:
    ax.scatter(x=data_suppfig6a['w_rew'], y=data_suppfig6a['rule_sel'],
               color=[color_by_type[t] for t in data_suppfig6a['type']], alpha=0.2)
ax.set_xlabel('Input weight for the negative feedback signal', fontsize=20)
ax.set_ylabel('Rule modulation', fontsize=20)
make_pretty_axes(ax)
fig.tight_layout()
plt.show()

# fig.savefig('/home/yl4317/Documents/two_module_rnn/figs/inpWeight_rulSel_slowSwitchModels.pdf')

# save to csv
# pd.DataFrame.from_dict(data=data_suppfig6a, orient='columns').to_csv('/home/yl4317/Documents/two_module_rnn/source_data/suppfig6a.csv', header=True)

print(time.time()-start)