# Analysis script for RNNs

In [1]:
import glob
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import spearmanr as spearmanr
from skfda.exploratory.stats import geometric_median as geometric_median

# Examine all losses coming from some configuration
def extract_run_data(file_seed):
    """Load every saved training run whose path starts with ``file_seed``.

    Files matching the glob pattern ``f'{file_seed}*.pckl'`` are visited in
    sorted order, so the integer keys of the result are reproducible.

    Parameters
    ----------
    file_seed : str
        Path prefix shared by all runs of one configuration, e.g.
        ``'./saved_data/n064_noise100_cosPnlt_040'``.

    Returns
    -------
    dict[int, dict]
        Maps a 0-based file index to a dict with the keys ``'run_number'``,
        ``'loss_total'``, ``'loss_acc_b'``, ``'loss_acc_f'``,
        ``'loss_cos_bf'``, ``'pcor_amx_b'``, ``'pcor_amx_f'``,
        ``'weights_dict'``, ``'cos_penalty'`` and ``'noise_value'``.
        Returns an empty dict when no file matches.

    Notes
    -----
    Files are read with :mod:`pickle`, which can execute arbitrary code;
    only use this on trusted, locally generated data.
    """
    dict_runs = {}
    for ifile, path in enumerate(sorted(glob.glob(f'{file_seed}*.pckl'))):
        # `with` guarantees the file handle is released
        with open(path, 'rb') as handle:
            pckl_file = pickle.load(handle)

        # Parse only the file name so that directory components containing
        # 'run' or 'noise' cannot corrupt the parsed values.
        fname = os.path.basename(path)
        run = {'run_number': int(fname.rsplit('run', 1)[1].split('.pckl')[0])}
        # the total loss is stored under the plural key 'loss_totals'
        run['loss_total'] = pckl_file['loss_totals']
        for key in ('loss_acc_b', 'loss_acc_f', 'loss_cos_bf', 'pcor_amx_b',
                    'pcor_amx_f', 'weights_dict', 'cos_penalty'):
            run[key] = pckl_file[key]
        # the three digits after 'noise' encode noise x 100 (e.g. '050' -> 0.5)
        run['noise_value'] = float(fname.split('noise')[1][:3]) / 100

        dict_runs[ifile] = run

    return dict_runs

# Print unique configurations found in directory with training data.
# Dropping the last 11 characters of each run-file path (presumably a
# "_runNN.pckl" suffix -- TODO confirm) leaves the configuration prefix
# ("file seed") that every run of that configuration shares.
run_files = glob.glob("./saved_data/*.pckl")
list_names = [path[:-11] for path in run_files]
set_names = set(list_names)
for config_name in set_names:
    print(config_name)


## Visualize loss trajectories for trained RNNs

In [None]:
# Plot logged loss trajectories for every run of one configuration
''' ---------- INPUT ---------- '''
#  [number of hidden units, noise value, penalty amount (int, or 'Non')]
in_hidden   = 64
in_noise    = 1
in_penalty  = 40
''' -------- END INPUT -------- '''
# Convert the inputs to the zero-padded strings used in the file names.
# Noise is stored x 100 (1 -> '100', 0.5 -> '050'); round() keeps fractional
# noise values from failing the integer format spec.
str_hidden   = f'{in_hidden:03d}'
str_noise    = f'{round(in_noise*100):03d}'
str_penalty  = in_penalty if isinstance(in_penalty, str) else f'{in_penalty:03d}'
cfg = [str_hidden, str_noise, str_penalty]

fileseed = f'./saved_data/n{cfg[0]}_noise{cfg[1]}_cosPnlt_{cfg[2]}'
dict_runs = extract_run_data(fileseed)
fig, axs = plt.subplots(3, figsize=(6,10), constrained_layout=True)

# Index [1] of every logged loss is plotted and treated as the test-set loss
# (assumption -- TODO confirm). Thin curves are per-run trajectories;
# horizontal lines mark each run's final value.
for run in dict_runs.values():
    # total loss
    loss_total_test = run['loss_total'][1]
    axs[0].plot(loss_total_test, linewidth=.2, alpha=.6)
    axs[0].axhline(loss_total_test[-1], linewidth=.2)
    # condition-wise accuracy loss (red: loss_acc_b, blue: loss_acc_f)
    loss_acc_b = run['loss_acc_b'][1]
    loss_acc_f = run['loss_acc_f'][1]
    axs[1].plot(loss_acc_b, linewidth=.2, alpha=.1, color='r')
    axs[1].plot(loss_acc_f, linewidth=.2, alpha=.1, color='b')
    axs[1].axhline(loss_acc_b[-1], linewidth=.2, color='r')
    axs[1].axhline(loss_acc_f[-1], linewidth=.2, color='b')
    # cosine-similarity (alignment) loss
    loss_cos_bf = run['loss_cos_bf'][1]
    axs[2].plot(loss_cos_bf, linewidth=.2, alpha=.6)
    axs[2].axhline(loss_cos_bf[-1], linewidth=.2)

axs[0].set_title('total loss')
axs[1].set_title('acc loss')
axs[2].set_title('alignment loss')
plt.suptitle(f'{fileseed} | n_runs={len(dict_runs)}');


# Generate an instance of the task to evaluate trained RNNs

In [None]:
from eval_noisyRNN import (
    eval_noisyRNN as eval_noisyRNN,
    generate_blocks as generate_blocks
)

''' ---------- INPUT ---------- '''
n_trials   = 72
n_episodes = 6
n_blocks   = 500
''' -------- END INPUT -------- '''

task_list = generate_blocks(n_trials, n_episodes, n_blocks) # task generator for the human task


def _to_numpy(tensor):
    """Return ``tensor`` detached from the autograd graph as a CPU NumPy array."""
    return tensor.detach().cpu().numpy()


# Evaluate every run of every configuration on the same task instance and
# collect results as dat_eval[units][noise][penalty][metric], where the three
# outer keys are the zero-padded strings taken from the file names.
dat_eval = {}
for file_seed in set_names:
    print(f'\nProcessing file seed ({file_seed})...')
    cfg_str = os.path.basename(file_seed)
    # skip prefixes that are not training-run configurations
    if not cfg_str.startswith('n'):
        continue
    nunit_str = cfg_str.split('_')[0][1:]        # 'n064_...'         -> '064'
    nunit = int(nunit_str)
    noise_str = cfg_str.split('noise')[1][:3]    # '..._noise100_...' -> '100'
    cpnlt_str = cfg_str.split('cosPnlt_')[1]     # '..._cosPnlt_040'  -> '040'
    # create the nested units -> noise -> penalty levels on first use
    entry = (dat_eval.setdefault(nunit_str, {})
                     .setdefault(noise_str, {})
                     .setdefault(cpnlt_str, {}))

    dict_runs = extract_run_data(file_seed)
    nruns = len(dict_runs)
    # NaN-initialised result arrays, one slot per run
    entry['run_numbers']   = np.full(nruns, np.nan)
    entry['accuracy']      = np.full((nruns, 2), np.nan)
    entry['bayesacc']      = np.full((nruns, 2), np.nan)
    entry['trace_linear']  = np.full(nruns, np.nan)
    entry['trace_quad']    = np.full(nruns, np.nan)
    entry['rho_pc_bayes']  = np.full((nruns, 2, nunit), np.nan)
    entry['eigvec_b']      = np.full((nunit, nunit, nruns), np.nan)
    entry['eigvec_f']      = np.full((nunit, nunit, nruns), np.nan)
    entry['cos_dec_bayes'] = np.full(nruns, np.nan)
    entry['cos_enc_bayes'] = np.full(nruns, np.nan)

    print('Processing run ', end='')
    for irun, dict_run in dict_runs.items():
        print(f'{irun}', end=' ')

        # Evaluate the weights of the LAST epoch stored in weights_dict.
        # To use the minimum-test-loss epoch instead, use:
        #   epochs[np.array(dict_run['loss_total'])[1][epochs].argmin()]
        epochs = np.array(list(dict_run['weights_dict'].keys()))
        epoch_eval = epochs[-1]

        run_weights = dict_run['weights_dict'][epoch_eval]
        # attach the noise value parsed from the run's file name
        run_weights['noise_value'] = dict_run['noise_value']

        # evaluate the network
        out = eval_noisyRNN(run_weights, task_list)

        # parse output into the evaluation dictionary
        entry['run_numbers'][irun]        = dict_run['run_number']
        entry['accuracy'][irun, :]        = _to_numpy(out['accuracy'].squeeze())
        entry['bayesacc'][irun, :]        = _to_numpy(out['bayesacc'].squeeze())
        entry['trace_linear'][irun]       = _to_numpy(torch.dot(out['trace_components'], out['eig_lin_weights']))
        entry['trace_quad'][irun]         = _to_numpy(torch.dot(out['trace_components'], out['eig_quad_weights']))
        # PC / Bayes-encoder correlations. Assumes corr_pc_bayesenc is
        # (nunit, 2) so its transpose fills the (2, nunit) slot -- TODO confirm.
        entry['rho_pc_bayes'][irun, :, :] = _to_numpy(out['corr_pc_bayesenc']).T
        entry['eigvec_b'][:, :, irun]     = _to_numpy(out['eigvecs_b'])
        entry['eigvec_f'][:, :, irun]     = _to_numpy(out['eigvecs_f'])
        entry['cos_dec_bayes'][irun]      = _to_numpy(out['cos_dec_bayes'][0])
        entry['cos_enc_bayes'][irun]      = _to_numpy(out['cos_enc_bayes'][0])

print(' ')
print('Finished')

# save to disk
with open('./saved_data/evaluation_runs.pckl', 'wb') as f:
    pickle.dump(dat_eval, f)

## Plots (Figure 4b,c)

In [None]:
''' ---------- INPUT ---------- '''
save_directory = '../figs/'
save_name      = 'acc_benc_constrained'  # figure file name, without extension

# specify desired model configurations
#        units  noise  penalty
cfgs = [
        ['016', '000', '040'],
        ['032', '000', '040'],
        ['064', '000', '040'],
        ['016', '050', '040'],
        ['032', '050', '040'],
        ['064', '050', '040'],
        ['016', '100', '040'],
        ['032', '100', '040'],
        ['064', '100', '040'],
    
        # ['016', '000', 'Non'],
        # ['032', '000', 'Non'],
        # ['064', '000', 'Non'],
        # ['016', '050', 'Non'],
        # ['032', '050', 'Non'],
        # ['064', '050', 'Non'],
        # ['016', '100', 'Non'],
        # ['032', '100', 'Non'],
        # ['064', '100', 'Non'],
       ]

''' -------- END INPUT -------- '''

# load from disk (pickle: only open trusted, locally generated files)
with open('./saved_data/evaluation_runs.pckl', 'rb') as f:
    dat_eval = pickle.load(f)

''' By default, the y-axis will show accuracy '''
# x-axis options
is_plotx_trace    = False # plot normalized traces on x-axis
is_plotx_cos_benc = False # plot cosine similarity between Bayes enc weights on x-axis
# y-axis options
is_ploty_benc = False # plot cosine similarity between Bayes enc weights on y-axis
is_ploty_bdec = False # plot cosine similarity between Bayes decoding weights on y-axis

if is_ploty_benc and is_ploty_bdec:
    sys.exit('Both the encoding and decoding weights cannot be simultaneously plotted!')

is_ploty_cos  = is_ploty_benc or is_ploty_bdec
is_plotx_data = is_plotx_trace or is_plotx_cos_benc
ylabelstr = 'Cosine similarity' if is_ploty_cos else 'Accuracy'
xlabelstr = 'Normalized trace' if is_plotx_trace else 'Cosine similarity'

# Marker colour by noise level (noise x 100, as in the file names) and marker
# shape / categorical x-position by network size. Unsupported values raise
# KeyError rather than silently reusing the previous configuration's style.
noise_colors = {0: [.5, .5, .5], 50: [.43, .81, .96], 100: [.43, .61, .96]}
unit_styles  = {16: ('s', 1), 32: ('o', 2), 64: ('d', 3)}

# Select the requested configurations; cfg_values keeps the parsed
# (units, noise) integers so they need not be re-parsed from the label.
dat = {}
cfg_values = {}
for cfg in cfgs:
    nunit = int(cfg[0])
    label = f'{nunit}_n{cfg[1]}_{cfg[2]}'
    dat[label] = dat_eval[cfg[0]][cfg[1]][cfg[2]]
    cfg_values[label] = (nunit, int(cfg[1]))

fig, axs = plt.subplots(2, figsize=(2.4, 4.4), constrained_layout=True)
for i, ax in enumerate(axs):  # panel i uses column i of 'accuracy' / 'bayesacc'
    # light reference grid, drawn once per panel
    if is_ploty_cos:
        for y in np.arange(0, 1, .1):
            ax.axhline(y, alpha=1, linewidth=.01, c='k')
    else:
        for y in np.arange(.5, .8, .05):
            ax.axhline(y, alpha=1, linewidth=.01, c='k')
        ax.set_ylim([.5, .8])
    if is_plotx_data:
        for x in np.arange(0, 1, .2):
            ax.axvline(x, alpha=1, linewidth=.01, c='k')
        ax.set_xlim([0, 1])
        ax.set_xlabel(xlabelstr)
        ax.set_ylabel(ylabelstr)

    trace_parts, y_parts = [], []  # pooled per panel for the Spearman correlation
    for key, run_data in dat.items():
        nunits, noisev = cfg_values[key]
        rgb = noise_colors[noisev]
        marker, xval = unit_styles[nunits]
        dat_tr = run_data['trace_linear']

        # y data for this configuration
        if is_ploty_benc:
            ydat = np.abs(run_data['cos_enc_bayes'])
        elif is_ploty_bdec:
            ydat = np.abs(run_data['cos_dec_bayes'])
        else:
            ydat = run_data['accuracy'][:, i]
            # Bayes accuracy reference line (value stored for the first run)
            ax.axhline(run_data['bayesacc'][0, i], linewidth=.5, c='k')

        trace_parts.append(dat_tr)
        y_parts.append(ydat)
        y_sem = ydat.std() / np.sqrt(np.size(ydat))

        if not is_plotx_data:
            # mean +/- SEM at a fixed x position per network size
            ax.errorbar(xval, ydat.mean(), yerr=y_sem,
                        label=key, marker=marker, c=rgb, ms=7)
            continue

        xdat = dat_tr if is_plotx_trace else np.abs(run_data['cos_enc_bayes'])
        # mean +/- SEM in x and y (open marker); the x SEM comes from the
        # plotted x data itself
        ax.errorbar(xdat.mean(), ydat.mean(),
                    xerr=xdat.std() / np.sqrt(np.size(xdat)), yerr=y_sem,
                    label=key, marker=marker, c=rgb, ms=7, mfc='w')
        # geometric median of the per-run (x, y) points (small filled marker)
        median = geometric_median(np.column_stack((xdat, ydat)))
        ax.scatter(median[0], median[1], marker=marker, color=rgb, zorder=10, s=10)

    # ax.legend(loc='lower right', prop={'size': 8})
    ax.tick_params(axis='both', which='major', labelsize=14, labelfontfamily='Arial')
    ax.tick_params(axis='both', which='minor', labelsize=14, labelfontfamily='Arial')

    # Spearman correlation between the linear trace and the plotted y values,
    # pooled over all selected configurations of this panel
    if trace_parts:
        rho = spearmanr(np.concatenate(trace_parts, axis=None),
                        np.concatenate(y_parts, axis=None))
        print(f'panel {i}: {rho.statistic:.3f} p={rho.pvalue:.4f}')

fig.savefig(os.path.join(save_directory, f'{save_name}.pdf'), format='pdf', bbox_inches='tight')

