In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np; np.set_printoptions(precision=2, threshold=100); np.random.seed(0)
import torch; torch.set_printoptions(precision=2, threshold=100)
seed = 1 

torch.manual_seed(seed)
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rc('font', size=12)
import matplotlib   

from matplotlib.font_manager import FontProperties
from mpl_toolkits import mplot3d
import matplotlib.pylab as plt

import seaborn as sns
import time
import sys 
import itertools
import random; random.seed(0)
import datetime
import pickle
import copy
import pandas as pd
import scipy
import os


import sys
# from task import *
from functions import *
# from train import *
# from model import *


print(torch.__version__)
print(sys.version)
                
%matplotlib inline

# Generate data

In [None]:
# Load, for every selected model, the precomputed noiseless test-run data and
# summarize its performance. One dict per model is appended to
# `perfs_across_models`, which the plotting cells below consume.
MODEL_DIR = '/where/models/are/stored/'                # replace with the model directory on your machine
TEST_DATA_DIR = '/where/data/for/test/run/is/stored/'  # replace with the test-run data directory on your machine
MODEL_DATE = '2023-05-10'                              # date tag used to sub-select models ('2023-12-22' for Figure 1f)
REQUIRED_DT = 10                                       # keep only models whose hp_test['dt'] equals this value

start = time.time()
plt.rc('font', size=18)
perfs_across_models = []
for model_name in sorted(os.listdir(MODEL_DIR)):
    # Sub-select models as you wish.
    if MODEL_DATE not in model_name or 'success' not in model_name:
        continue
    print(model_name)
    path_to_file = os.path.join(MODEL_DIR, model_name)
    with HiddenPrints():
        model, hp_test, hp_task_test, optim, saved_data = load_model_v2(path_to_file=path_to_file, model_name=model_name, simple=False, plot=False, toprint=False)
    if hp_test['dt'] != REQUIRED_DT:    # put the filtering condition here
        print('pass\n')
        continue

    # Make the loaded model noiseless. NOTE: the performance numbers below are
    # read from a precomputed pickle, so these settings do not change them.
    model.rnn.network_noise = 0
    model.output_noise = 0
    hp_test['input_noise_perceptual'] = 0
    hp_test['input_noise_rule'] = 0

    # NOTE(review): pickle.load can execute arbitrary code; only load files you produced yourself.
    with open(os.path.join(TEST_DATA_DIR, model_name + '_testdata_noiseless'), 'rb') as f:
        neural_data = pickle.load(f)
    test_data = neural_data['test_data']
    perfs = test_data['perfs']
    perf_rules = test_data['perf_rules']
    rules = test_data['rules']

    # Each stored record is indexed with [0] to obtain the per-trial value.
    mean_perf = np.mean([p[0] for p in perfs])
    mean_perf_rule = np.mean([p[0] for p in perf_rules])
    # Trial tr is a switch trial when the rule of trial tr+1 differs from that of trial tr.
    switch_trials = [tr for tr in range(len(rules) - 1) if rules[tr] != rules[tr + 1]]

    # The original also stored 'switch_to_color_trials' / 'switch_to_shape_trials',
    # but those names were never defined in this cell (NameError at runtime).
    # TODO: add them back once the encoding of `rules` is confirmed.
    perfs_across_models.append({'model': model_name, 'mean_perf': mean_perf, 'mean_perf_rule': mean_perf_rule,
                                'switch_trials': switch_trials, 'perfs': perfs, 'perf_rules': perf_rules})
print(time.time() - start)

In [None]:
def plot_perf_after_switch(switch_trials, perfs, n_trs_max=10, title='Figure',
                           perf_rules=None, chance=1/3, verbose=False):
    """
        Plot the performance aligned to rule switches.

        Draws two panels, each averaged across switches with SEM error bars:
        the performance (top) and the performance for the rule (bottom), for
        the first `n_trs_max` trials after each switch.

        Args:
            switch_trials: trial indices of the rule switches. As in the original
                implementation, the last switch in the list is not used.
            perfs: per-trial performance, indexable by trial number.
            n_trs_max: plot the performance for this many trials after each
                switch (offsets 0 .. n_trs_max-1).
            title: figure title.
            perf_rules: per-trial performance for the rule, indexable like
                `perfs`. If None, the notebook-level variable `perf_rules` is
                used, because the original function silently read that global.
            chance: y-value of the dotted reference line (default 1/3, as before).
            verbose: if True, print the per-switch and aggregated values. The
                original always printed them.

        Notes:
            At offset 0 the plotted value is `1 - perf` (and `1 - perf_rule`),
            exactly as in the original code.
            Switches whose `n_trs_max`-trial window runs past the end of the
            data are skipped instead of raising IndexError.
            If no switch is usable, a warning is issued and nothing is plotted.

        Raises:
            NameError: if `perf_rules` is not passed and no global `perf_rules` exists.
    """
    import warnings
    from scipy import stats  # `import scipy` alone does not expose scipy.stats on older SciPy versions

    if perf_rules is None:
        # Backward compatibility: callers may rely on the global `perf_rules`.
        if 'perf_rules' not in globals():
            raise NameError("name 'perf_rules' is not defined; pass perf_rules explicitly")
        perf_rules = globals()['perf_rules']

    offsets = list(range(n_trs_max))
    n_trials = min(len(perfs), len(perf_rules))
    # The original iterated over range(len(switch_trials) - 1), i.e. it dropped
    # the last switch; keep that, and also drop windows that overrun the data.
    usable_switches = [s for s in list(switch_trials)[:-1] if s + n_trs_max <= n_trials]
    if not usable_switches:
        warnings.warn('plot_perf_after_switch: no switch with a full {}-trial window; nothing to plot'.format(n_trs_max))
        return

    # perf_after_switch[tr]: values observed `tr` trials after each switch
    perf_after_switch = {tr: [] for tr in offsets}
    perf_rule_after_switch = {tr: [] for tr in offsets}
    for current_switch in usable_switches:
        if verbose:
            print('current_switch={}'.format(current_switch))
        for tr in offsets:
            perf = perfs[current_switch + tr]
            perf_rule = perf_rules[current_switch + tr]
            if tr == 0:
                # Offset 0 is plotted as 1 - perf (kept from the original code).
                perf, perf_rule = 1 - perf, 1 - perf_rule
            perf_after_switch[tr].append(perf)
            perf_rule_after_switch[tr].append(perf_rule)
    if verbose:
        print('perf_after_switch={}'.format(perf_after_switch))

    y = [np.mean(perf_after_switch[tr]) for tr in offsets]
    y_err = [stats.sem(perf_after_switch[tr]) for tr in offsets]
    y_rule = [np.mean(perf_rule_after_switch[tr]) for tr in offsets]
    y_err_rule = [stats.sem(perf_rule_after_switch[tr]) for tr in offsets]
    if verbose:
        print('y={}'.format(y))

    fig, ax = plt.subplots(2, 1, figsize=[7, 10])
    fig.suptitle(title)
    x = np.arange(n_trs_max)
    panels = [(ax[0], 'Perf', y, y_err), (ax[1], 'Perf for rule', y_rule, y_err_rule)]
    for axis, ylabel, y_mean, y_sem in panels:
        axis.set_xlabel('Trial after switch')
        axis.set_ylabel(ylabel)
        axis.set_xticks(x)
        axis.errorbar(x=x, y=y_mean, yerr=y_sem, color='gray', marker='s', fillstyle='none')
        axis.axhline(y=chance, color='k', linestyle='dotted')
        make_pretty_axes(axis)

    plt.show()

# Producing Figure 1 c, d, f
The model shown in Figure 1f is `success_2023-12-22-16-37-35_wcst_15_early_stopping_correct2`.

Before running the cell below, re-run the data-loading cell near the top of the notebook (the one that builds `perfs_across_models`), with the date filter `'2023-05-10'` replaced by `'2023-12-22'`.

In [None]:
# Plot each model's performance as a function of trial number, marking every
# rule switch, then plot its performance aligned to those switches.

def _plot_perf_trace(trial_perfs, switch_trials):
    """Plot one model's per-trial performance with a vertical line at trial tr+1 for each switch trial tr."""
    fig, ax = plt.subplots(1, 1, figsize=[50, 10])
    fig.suptitle('Performance')
    ax.set_xlabel('Trial')
    ax.set_ylabel('Perf')
    ax.set_xlim([0, 1000])
    ax.plot(trial_perfs, color='k', linewidth=2)
    for switch_tr in switch_trials:
        # switch_tr is the last trial before the rule changes, so mark the next one.
        ax.axvline(x=switch_tr + 1, color='#e6550d', linewidth=5)
    make_pretty_axes(ax)
    plt.show()
    return fig, ax


for data in perfs_across_models:
    print(data['model'])

    # Per-trial values: the first element of each stored record.
    perfs = [record[0] for record in data['perfs']]
    # Kept as a notebook-level name on purpose: plot_perf_after_switch may read
    # the global `perf_rules` when it is not passed explicitly.
    perf_rules = [record[0] for record in data['perf_rules']]
    switch_trials = data['switch_trials']

    fig, ax = _plot_perf_trace(perfs, switch_trials)

    plot_perf_after_switch(switch_trials=switch_trials, perfs=perfs, title='All switches')