# Simulation 1b: Fitting 24-Item Lists

This notebook covers part of Simulation 3 from the paper: fitting the pooled-modality data from 24-item lists.

In [None]:
import os
import re
import sys
import json
import numpy as np
import pickle as pkl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import CMR2_pack_cyth as CMR2
from glob import glob
from cluster_helper.cluster import cluster_view

# Set Up Helper Functions

In [None]:
def jitter(x, y, s=20, c='b', marker='o', cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, verts=None, **kwargs):
    """
    Creates a jittered scatterplot, for use in plotting particle swarms.

    Gaussian noise (SD = 0.12) is added to the x-coordinates only, so that points which
    share an x-value do not sit exactly on top of one another. The y-coordinates are
    plotted unchanged.

    :param x: Array-like of x-coordinates to jitter.
    :param y: Array-like of y-coordinates, one per element of x.
    :param s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths: Passed straight
        through to plt.scatter.
    :param verts: Legacy plt.scatter argument. It is only forwarded when a value is
        actually supplied (see the note in the body).
    :param kwargs: Any further keyword arguments for plt.scatter.
    :return: Whatever plt.scatter returns.
    """
    stdev = .12
    jittered = x + np.random.randn(len(x)) * stdev

    # plt.scatter dropped its `verts` parameter in newer Matplotlib releases, where passing
    # it (even as None) is rejected as an unknown keyword. Only forward it when the caller
    # explicitly provided one, so the default call works on both old and new versions.
    if verts is not None:
        kwargs['verts'] = verts

    return plt.scatter(jittered, y, s=s, c=c, marker=marker, cmap=cmap, norm=norm, vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths, **kwargs)

def read_err_files(paths):
    """
    Read in the error values from each file in the path list.

    The number of particles is taken from the first file. Files with fewer entries than
    that (e.g. an iteration that has not finished yet) are padded with NaN. If a file
    contains two-dimensional data, only its first column is used.

    :param paths: Sequence of paths to space-delimited error files, one per iteration,
        in the order the iterations should appear in the output.
    :return: An iterations x particles array of error values, or an empty array if no
        paths were provided.
    """
    # Return an empty array if no paths are provided
    if len(paths) == 0:
        return np.array([])

    err_vals = None
    for i, path in enumerate(paths):
        # ndmin=1 keeps a file holding a single value as a length-1 array rather than a
        # 0-d array, which would otherwise break the shape lookup and len() below
        scores = np.loadtxt(path, delimiter=' ', ndmin=1)

        # Allocate the iterations x particles matrix once the first file has been read
        if err_vals is None:
            err_vals = np.full((len(paths), scores.shape[0]), np.nan)

        column = scores[:, 0] if scores.ndim == 2 else scores
        err_vals[i, :len(column)] = column

    return err_vals

def read_xfiles(paths):
    """
    Read in the particle locations from each file in the path list.

    The number of particles and parameters is taken from the first file. Files with
    fewer rows than that are padded with NaN.

    :param paths: Sequence of paths to space-delimited parameter files (one row per
        particle, one column per parameter), in the order the iterations should appear
        in the output.
    :return: An iterations x particles x parameters array, or an empty array if no paths
        were provided.
    """
    # Return an empty array if no paths are provided
    if len(paths) == 0:
        return np.array([])

    param_values = None
    for i, path in enumerate(paths):
        # ndmin=2 keeps a file with a single row as shape (1, n_params). Without it the
        # row comes back 1-D, so len() reports the number of parameters instead of the
        # number of particles and the row is broadcast across every particle.
        params = np.loadtxt(path, delimiter=' ', ndmin=2)

        # Allocate the iterations x particles x parameters matrix from the first file
        if param_values is None:
            n_particles, n_params = params.shape
            param_values = np.full((len(paths), n_particles, n_params), np.nan)

        param_values[i, :len(params), :] = params

    return param_values

def load_stats(n, i, p, done=True):
    """
    Unpickle and return the behavioral stats saved for run n, iteration i, particle p.

    When done is True the file is read from the archived param_saves/out<n>/ folder;
    otherwise it is read from the live outfiles/ folder, in which case n is not used.
    """
    if done:
        stats_file = '/scratch/jpazdera/cmr2/param_saves/out%s/%sdata%s.pkl' % (n, i, p)
    else:
        stats_file = '/scratch/jpazdera/cmr2/outfiles/%sdata%s.pkl' % (i, p)

    with open(stats_file, 'rb') as handle:
        return pkl.load(handle)

def load_targets(target_stat_file):
    """
    Load target stats from a JSON file, converting lists into float arrays.

    Lists stored at the top level, or one level down inside a dictionary, are converted.
    Every other value is returned as loaded.
    """
    def _as_array(value):
        # Only lists are converted; scalars, strings, etc. pass through untouched
        return np.array(value, dtype=float) if isinstance(value, list) else value

    with open(target_stat_file, 'r') as f:
        targets = json.load(f)

    for key, value in targets.items():
        if isinstance(value, dict):
            for subkey, subvalue in value.items():
                value[subkey] = _as_array(subvalue)
        else:
            targets[key] = _as_array(value)

    return targets

def _despine(ax):
    """
    Hide the top and right spines of ax and keep tick marks on the left/bottom only.
    """
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')


def _plot_curve_panel(position, title, x, data_mean, data_sem, model_mean, model_sem, xlabel, ylabel, ymax, xticks=None):
    """
    Draw one data-vs-model curve panel, with a shaded band of +/- one SEM around each curve.

    The data are drawn as a faded solid black line and the model as a dashed black line.
    The y-axis always runs from 0 to ymax. xticks is only applied when provided.
    """
    ax = plt.subplot(position)
    plt.title(title)
    plt.plot(x, np.array(data_mean), 'k-', alpha=.3)
    plt.fill_between(x, np.add(data_mean, data_sem), np.subtract(data_mean, data_sem), alpha=.12, color='k')
    plt.plot(x, np.array(model_mean), 'k--')
    plt.fill_between(x, np.add(model_mean, model_sem), np.subtract(model_mean, model_sem), alpha=.2, color='k')
    if xticks is not None:
        plt.xticks(xticks)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.ylim(0, ymax)
    plt.legend(['Data', 'Model'], loc=2)
    _despine(ax)
    return ax


def plot_fit(data, model, savefile=None):
    """
    Plot model performance against target data.

    Creates a new 3x2 figure containing the SPC, PFR, SPC for trials starting at serial
    position 1, SPC for trials starting in the last four positions, PLIs per trial (bars
    with 1.96 * SEM error bars), and PLI recency.

    :param data: Target stats. Must provide 'spc', 'pfr', 'spc_fr1' and 'spc_frl4' (plus a
        matching '<stat>_sem' entry for each), every one of them indexed by the key '24',
        as well as 'plis', 'plis_sem', 'pli_recency' and 'pli_recency_sem'.
    :param model: Model stats, with the same layout as data.
    :param savefile: Optional path. If provided, the figure is saved there.

    Note that the font sizes are set through plt.rc, so they remain in effect for any
    figures created afterwards.
    """
    SMALL_SIZE = 12
    MEDIUM_SIZE = 14
    LARGE_SIZE = 16

    plt.rc('font', size=LARGE_SIZE)          # controls default text sizes
    plt.rc('axes', titlesize=LARGE_SIZE)     # fontsize of the axes title
    plt.rc('axes', labelsize=LARGE_SIZE)     # fontsize of the x and y labels
    plt.rc('xtick', labelsize=MEDIUM_SIZE)   # fontsize of the tick labels
    plt.rc('ytick', labelsize=MEDIUM_SIZE)   # fontsize of the tick labels
    plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
    plt.rc('figure', titlesize=LARGE_SIZE)   # fontsize of the figure title

    plt.figure(figsize=(10, 12.5))

    # Panels 1-4: curves over the 24 serial positions
    serial_positions = range(1, 25)
    sp_ticks = [1, 3, 6, 9, 12, 15, 18, 21, 24]
    sp_panels = (
        # (subplot, title, stat key, y-axis label, upper y-limit)
        (321, 'SPC', 'spc', 'Recall Prob.', 1),
        (322, 'PFR', 'pfr', 'Prob. of First Recall', .6),
        (323, 'SPC (Start=SP1)', 'spc_fr1', 'Recall Prob.', 1),
        (324, 'SPC (Start=L4)', 'spc_frl4', 'Recall Prob.', 1),
    )
    for position, title, stat, ylabel, ymax in sp_panels:
        _plot_curve_panel(
            position, title, serial_positions,
            data[stat]['24'], data[stat + '_sem']['24'],
            model[stat]['24'], model[stat + '_sem']['24'],
            'Serial Position', ylabel, ymax, xticks=sp_ticks
        )

    # Panel 5: prior-list intrusions per trial, with error bars of 1.96 * SEM
    ax = plt.subplot(325)
    plt.title('PLIs')
    plt.bar([1], [data['plis']], color='w', ec='k', width=.5, capsize=3, yerr=[data['plis_sem']*1.96])
    plt.bar([1.75], [model['plis']], color='w', hatch='/', ec='k', width=.5, capsize=3, yerr=[model['plis_sem']*1.96])
    plt.xlim(.5, 2.25)
    plt.xticks([1, 1.75], ['Data', 'Model'])
    plt.ylabel('PLIs per Trial')
    plt.ylim(0, .2)
    _despine(ax)

    # Panel 6: PLI recency over five list-recency values (default x ticks)
    _plot_curve_panel(
        326, 'PLI Recency', range(1, 6),
        data['pli_recency'], data['pli_recency_sem'],
        model['pli_recency'], model['pli_recency_sem'],
        'List Recency', 'Proportion of PLIs', .6
    )

    plt.tight_layout()
    if savefile is not None:
        plt.gcf().savefig(savefile)

# Generate sessions with 24-item lists

In [None]:
# Session dimensions
n_sess = 3000      # number of sessions to generate
n_trials = 16      # lists per session
list_length = 24   # words per list

# Read the ltpFR3 wordpool IDs (loaded as strings) and make sure the pool can fill one
# whole session without repeating a word
wp = np.loadtxt('word_IDs.txt', dtype='U256')
words_per_session = n_trials * list_length
if words_per_session > len(wp):
    raise ValueError('Wordpool is not large enough to support that long of a session!')

# Build a sessions x trials x serial positions array of word IDs. Each session is drawn
# without replacement, so no word repeats within a session; the string IDs are cast to
# integers when they are written into the integer array.
presw = np.zeros((n_sess, n_trials, list_length), dtype=int)
for sess_idx in range(n_sess):
    presw[sess_idx] = np.random.choice(wp, size=(n_trials, list_length), replace=False)

# Write the generated lists out as JSON
d = {'pres_words': presw.tolist()}
with open('sim1b_lists.json', 'w') as f:
    json.dump(d, f)

# Create target stat file

### Fixed List Length (24 Items)

In [None]:
# Read the averaged stats file, which holds a 'mean' and a 'sem' entry for every stat
with open('/data/eeg/scalp/ltp/ltpFR3_MTurk/stats/all_v2_excl_wn.json', 'r') as f:
    d = json.load(f)
m, s = d['mean'], d['sem']

# Pull out the 24-item-list values of each target stat, together with its SEM
targets = {}

# These stats stay nested under the list-length key '24'
for stat in ('spc', 'spc_fr1', 'spc_frl4', 'pfr', 'temp_fact', 'sem_fact'):
    targets[stat] = {'24': m[stat]['24']}
    targets[stat + '_sem'] = {'24': s[stat]['24']}

# The PLI stats are stored flat, without the list-length key
for stat in ('plis', 'pli_recency'):
    targets[stat] = m[stat]['24']
    targets[stat + '_sem'] = s[stat]['24']

# Write the targets where the particle swarm can load them
with open('../CMR2/target_stats_sim1b.json', 'w') as f:
    json.dump(targets, f)

# Archive Completed Run

In [None]:
# Number under which the finished run is archived (used in the opt<N>.txt and out<N> names)
fit_num = 203

# Refuse to archive if an out<fit_num> folder already exists, so that a previously
# archived run is not clobbered by mistake
if os.path.exists('/scratch/jpazdera/cmr2/param_saves/out%s' % fit_num):
    print('Save files already exist! Check to make sure the fit number is correct...')
else:
    # NOTE: the exit status of each os.system call is ignored, so a step that fails does
    # not prevent the later (destructive) steps from running.
    # Copy optimal parameters to opt.txt file
    os.system('cp /scratch/jpazdera/cmr2/outfiles/xopt_ltpFR3.txt /scratch/jpazdera/cmr2/param_saves/opt%s.txt' % fit_num)
    # Clean up noise files
    os.system('rm /scratch/jpazdera/cmr2/noise_files/*')
    # Archive outfiles folder by renaming it to param_saves/out<fit_num>
    os.system('mv /scratch/jpazdera/cmr2/outfiles/ /scratch/jpazdera/cmr2/param_saves/out%s' % fit_num)
    # Create a new, empty outfiles folder to replace the one that was just moved
    os.system('mkdir /scratch/jpazdera/cmr2/outfiles/')
    # Remove tempfiles (anything with "temp" in its name) from archive
    os.system('rm /scratch/jpazdera/cmr2/param_saves/out%s/*temp*' % fit_num)

### Alternative: Clean outfiles without archiving

In [None]:
# WARNING: this cell is destructive. It discards the current contents of the outfiles and
# noise_files folders without keeping a copy. The exit status of each os.system call is
# ignored, so a failed step does not stop the later ones.

# Throwaway fit number, only used to name the temporary copies that are deleted below
fit_num = 999
# Copy optimal parameters to opt.txt file
os.system('cp /scratch/jpazdera/cmr2/outfiles/xopt_ltpFR3.txt /scratch/jpazdera/cmr2/param_saves/opt%s.txt' % fit_num)
# Clean up noise files
os.system('rm /scratch/jpazdera/cmr2/noise_files/*')
# Move the outfiles folder out of the way (it is deleted below rather than kept)
os.system('mv /scratch/jpazdera/cmr2/outfiles/ /scratch/jpazdera/cmr2/param_saves/out%s' % fit_num)
# Create a new outfiles folder
os.system('mkdir /scratch/jpazdera/cmr2/outfiles/')
# Remove all files: both the moved folder and the copied parameter file
os.system('rm -r /scratch/jpazdera/cmr2/param_saves/out%s' % fit_num)
os.system('rm -r /scratch/jpazdera/cmr2/param_saves/opt%s.txt' % fit_num)

# Particle Swarm Plotting
The scripts that run the particle swarm and the genetic algorithm can be found in ../CMR2/fitting/pso_cmr2_ltpFR3.py and ../CMR2/fitting/genalg_cmr2_ltpFR3.py, respectively. Particles/individuals can be processed in parallel by running multiple instances of the script.
### Plotting parameter sets

In [None]:
############
#
#   SETTINGS
#
############

fit_num = 203
done = False  # True reads the archived out<fit_num> folder; False reads the live outfiles folder
folder_path = '/scratch/jpazdera/cmr2/param_saves/out%s/' % fit_num if done else '/scratch/jpazdera/cmr2/outfiles/'

# Lower and upper bound of each parameter, used for the y-axis limits and ticks
#    [ b_e,  b_r, g_fc, g_cf, p_s, p_d,  k,  e, s_cf, b_rp,  o,  a, c_t,   l]
lb = [  .1,    0,    0,  .15,   0,   0,  0,  0,   .5,    0,  5, .5,   0,   0]
ub = [  .9,    1,    1,  .85,   8,   5, .5, .5,    3,    1, 20,  1,  .5, .25]

# Set labels for each parameter's graph (only the first n_params labels are used)
param_labels = [
    r'$\beta_{enc}$',
    r'$\beta_{rec}$',
    r'$\gamma_{FC}$',
    r'$\gamma_{CF}$',
    r'$\phi_{s}$',
    r'$\phi_{d}$',
    r'$\kappa$',
    r'$\eta$',
    r'$s$',
    r'$\beta_{post}$',
    r'$\omega$',
    r'$\alpha$',
    r'$c_{thresh}$',
    r'$\lambda$',
    r'$\beta_{source}$',
    r'$\gamma_{source}$'
]

############
#
#   Data preparation
#
############

# Get all parameter and error value file paths
xfile_paths = glob(folder_path + '*xfile*')
err_paths = glob(folder_path + 'err_iter*')

# Sort paths numerically, rather than alphabetically
xfile_paths.sort(key=lambda var:[int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])
err_paths.sort(key=lambda var:[int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])

# Load parameter and error values
param_vals = read_xfiles(xfile_paths)  # iteration x particle x param
err_vals = np.sqrt(read_err_files(err_paths))  # iteration x particle

# Both readers return an empty 1-D array when no files are found
if param_vals.ndim != 3 or err_vals.ndim != 2:
    raise ValueError('No parameter and/or error files found in ' + folder_path)

# Determine number of iterations, swarm size, and number of parameters from shape of param matrix
n_iter, swarmsize, n_params = param_vals.shape

# Only plot iterations that have both a parameter file and an error file. While a fit is
# still running there may be an xfile for an iteration whose errors are not written yet.
n_iter = min(n_iter, err_vals.shape[0])

# Get range of iteration values for plotting purposes
iter_vals = range(n_iter)

# Within each iteration, order particles from worst to best fit, so that best fits always
# appear in front. The ordering, the sorted errors and the iteration labels are identical
# for every parameter, so they are computed once here rather than once per subplot.
rows = np.arange(n_iter)[:, np.newaxis]
sort_idx = np.argsort(1/err_vals[:n_iter], axis=1)
err_sorted = err_vals[rows, sort_idx].ravel()
iterations = np.repeat(np.arange(n_iter), swarmsize)

############
#
#   Plotting
#
############

plt.figure(figsize=(10, 28))
# Plot one row in the figure for each parameter
for param_num in range(n_params):

    # Select the subplot for the current parameter, and label the axes
    plt.subplot(n_params, 1, param_num+1)
    plt.ylabel(param_labels[param_num])
    plt.yticks([lb[param_num], lb[param_num] + (ub[param_num] - lb[param_num]) / 2, ub[param_num]])

    # Particle positions on this parameter, in the same worst-to-best order as err_sorted
    positions = param_vals[rows, sort_idx, param_num].ravel()

    # Iterations are plotted 1-indexed; color encodes 1/error, so better fits map to higher values
    jitter(iterations+1, positions, s=2, alpha=1, c=1/err_sorted, cmap=cm.rainbow)
    plt.ylim(lb[param_num], ub[param_num])

plt.tight_layout()
plt.gcf().savefig('CMR2_plots/param_plot_%s.jpg' % fit_num, dpi=150)

### Plotting error values over time

In [None]:
fit_num = 203
done = False
# Archived runs live in param_saves/out<fit_num>/; a run in progress writes to outfiles/
if done:
    saved_files_path = '/scratch/jpazdera/cmr2/param_saves/out%s/' % fit_num
else:
    saved_files_path = '/scratch/jpazdera/cmr2/outfiles/'

# Collect the error files in numerical (not alphabetical) order of their names
err_paths = sorted(
    glob(saved_files_path+'err_iter*'),
    key=lambda var:[int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)]
)

# iteration x particle matrix holding the square root of the stored error values
err_array = np.sqrt(read_err_files(err_paths))
nits = err_array.shape[0]

# Best and average error within each iteration, ignoring missing (NaN) particles
min_errs = np.nanmin(err_array, axis=1)
mean_errs = np.nanmean(err_array, axis=1)

print('\nIterations completed: ', len(min_errs))
print("\nSmallest error is at iteration:")
# Iteration numbers are 1-indexed, hence the + 1
smallest_err_loc = np.nanargmin(min_errs) + 1
print(smallest_err_loc)
print("And its value is: ")
print(np.nanmin(min_errs))

# Left panel: minimum error per iteration, with a dashed line marking the best iteration.
# Right panel: mean error per iteration (linear y-axis).
iteration_axis = range(1, nits+1)
fig, (ax_min, ax_mean) = plt.subplots(1, 2, figsize=(14, 5))

ax_min.axvline(smallest_err_loc, color='k', linestyle='--')
ax_min.plot(iteration_axis, min_errs, 'b-', label="Particle Swarm")
ax_min.set_title("Minimum Error Over Time")
ax_min.set_xlabel("Iteration")
ax_min.set_ylabel("Min Error")

ax_mean.plot(iteration_axis, mean_errs, 'b-')
ax_mean.set_title("Mean Error Over Time")
ax_mean.set_xlabel("Iteration")
ax_mean.set_ylabel("Mean Error")

fig.tight_layout(w_pad=3)
fig.savefig('CMR2_plots/error%s.jpg' % fit_num, dpi=150)

# Plotting best fit from an iteration

In [None]:
############
#
#   Settings
#
############

fit_num = 203
iter_num = 171
done = False  # True reads the archived out<fit_num> folder; False reads the live outfiles folder

############
#
#   Identify Best Fit
#
############

iter_num = str(iter_num)
target_stat_file = '../CMR2/target_stats_sim1b.json'
err_dir = '/scratch/jpazdera/cmr2/param_saves/out%s/' % fit_num if done else '/scratch/jpazdera/cmr2/outfiles/'
err_path = 'err_iter' + iter_num

# Load error values. ndmin=1 keeps a single-particle file as a length-1 array.
err_vals = np.sqrt(np.loadtxt(err_dir + err_path, delimiter=' ', ndmin=1))
# If the file has several columns, use the first one, as read_err_files does
if err_vals.ndim == 2:
    err_vals = err_vals[:, 0]

# Get the index of the minimum error value, ignoring NaNs. nanargmin returns a single
# index even when several particles are tied, whereas converting the output of np.where
# to an int fails as soon as more than one particle matches.
min_err_index = int(np.nanargmin(err_vals))
min_err = err_vals[min_err_index]

print("\nMinimum error and matching index are: ")
print("Error: ", min_err)
print("Index: ", min_err_index)

# Find the set of parameters in the associated xfile that match (one row per particle)
xfile_path = iter_num + 'xfile.txt'
xfile = np.loadtxt(err_dir + xfile_path, delimiter=' ', ndmin=2)

print("\nBest-fitting parameters for iteration " + iter_num + " are: ")
for x in xfile[min_err_index]:
    print('%s,' % x)

############
#
#   Load Stats
#
############

# Behavioral stats that the best particle produced on this iteration
cmr_stats = load_stats(fit_num, iter_num, min_err_index, done=done)

# Load target stats from JSON file
targets = load_targets(target_stat_file)

############
#
#   Plot Fit
#
############

plot_fit(targets, cmr_stats)
plt.gcf().savefig('CMR2_plots/fit%s.jpg' % fit_num, dpi=150)

# Finalizing Best Fit

In [None]:
def eval_model(param_vec):
    """
    Simulate the model with one parameter set and return the resulting behavioral stats.

    The stats are cached as a pickle file in the archived folder of the fit, named after
    the ID attached to the parameter vector. If that file already exists, it is loaded
    and returned without re-running the model.

    Everything the function needs is imported inside its body, so that it is
    self-contained when it is mapped over parameter sets on cluster workers.

    :param param_vec: Parameter vector. If it has 15 elements, the last one is treated as
        an integer ID number and stripped off before the model is run.
    :return: The stats object returned by optimization_utils.obj_func (or its cached copy).
    """
    import os
    import sys
    import json
    import numpy as np
    import pickle as pkl
    import CMR2_pack_cyth as CMR2
    from glob import glob

    sys.path.append('../CMR2/fitting/')
    import optimization_utils as opt

    ##########
    #
    # Initialization
    #
    ##########

    # Settings
    n_sess = 2500  # Number of sessions to simulate
    fit_number = 203

    w2v_file = '../CMR2/w2v.txt'  # Path to semantic associative matrix file
    target_stat_file = '../CMR2/target_stats_sim1b.json'  # Path to file with target stats

    # Extract parallel ID number if one was provided in the parameter vector
    if len(param_vec) == 15:
        parallel_ID = int(param_vec[-1])
        param_vec = param_vec[:-1]
    else:
        parallel_ID = None

    # Return the cached stats if this parameter set was already evaluated. This is checked
    # before any input is loaded, since none of the inputs are needed in that case.
    # NOTE: without an ID the cache file is named final_dataNone.pkl, so all ID-less calls
    # share a single cache entry.
    data_path = '/scratch/jpazdera/cmr2/param_saves/out%i/final_data%s.pkl' % (fit_number, parallel_ID)
    if os.path.exists(data_path):
        with open(data_path, 'rb') as f:
            # pkl.load takes the open file object, not the path
            return pkl.load(f)

    # Load randomly generated 24-item lists
    with open('sim1b_lists.json', 'r') as f:
        data_pres = np.array(json.load(f)['pres_words'][:n_sess])

    # Create session indices and collapse sessions and trials of presented items into one dimension
    sessions = np.repeat(np.arange(data_pres.shape[0]), data_pres.shape[1])
    data_pres = data_pres.reshape((data_pres.shape[0] * data_pres.shape[1], data_pres.shape[2]))
    sources = None

    # Load semantic similarity matrix (word2vec)
    w2v = np.loadtxt(w2v_file)

    # Load target stats from JSON file, converting (nested) lists to float arrays
    with open(target_stat_file, 'r') as f:
        targets = json.load(f)
    for key in targets:
        if isinstance(targets[key], list):
            targets[key] = np.array(targets[key], dtype=float)
        if isinstance(targets[key], dict):
            for subkey in targets[key]:
                if isinstance(targets[key][subkey], list):
                    targets[key][subkey] = np.array(targets[key][subkey], dtype=float)

    ##########
    #
    # Run Model
    #
    ##########

    # Run model with the parameters given in param_vec, and cache the stats it produces.
    # NOTE(review): only cmr_stats is kept, yet the notebook later reads results[...]['err'],
    # so cmr_stats presumably carries the error value itself. Confirm against opt.obj_func.
    err, cmr_stats = opt.obj_func(param_vec, targets, data_pres, sessions, w2v, sources, return_recalls=False)
    with open(data_path, 'wb') as f:
        pkl.dump(cmr_stats, f)

    return cmr_stats

### Load parameters and scores

In [None]:
run = 203

folder_path = '/scratch/jpazdera/cmr2/param_saves/out%s/' % run
target_stat_file = '../CMR2/target_stats_sim1b.json'

# Get parameter and error value file paths
xfile_paths = glob(folder_path + '*xfile*')
err_paths = glob(folder_path + 'err_iter*')

# Sort paths numerically, rather than alphabetically
xfile_paths.sort(key=lambda var:[int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])
err_paths.sort(key=lambda var:[int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])

# Load parameter and error values
params = read_xfiles(xfile_paths)  # iteration x particle x parameter
scores = read_err_files(err_paths)  # iteration x particle

# Identify each particle's best parameter sets (up to 5 per particle; fewer if the run has
# fewer iterations than that)
n_best = 5
n_iter, swarmsize = scores.shape
nparams = params.shape[2]
n_keep = min(n_best, n_iter)
best_params = np.full((n_keep * swarmsize, nparams), np.nan)
for i in range(swarmsize):
    # argsort places NaN scores last, so iterations with missing scores are only selected
    # when a particle has fewer than n_keep valid scores
    best_iters = np.argsort(scores[:, i])[:n_keep]
    best_params[i*n_keep:(i+1)*n_keep, :] = params[best_iters, i]

# Attach ID numbers to parameter sets as an extra last column. The IDs are generated from
# the actual number of parameter sets, so this works for any swarm size.
best_params = np.hstack((best_params, np.arange(best_params.shape[0])[:, np.newaxis]))

### Evaluate best models

In [None]:
# Evaluate every parameter set in best_params, or reload the saved results if the
# evaluation has already been completed (done = True)
done = True

results = None
if done:
    with open(folder_path + 'final_results.pkl', 'rb') as f:
        results = pkl.load(f)
else:
    try:
        # Map eval_model over the parameter sets using cluster jobs
        with cluster_view(scheduler='sge', queue='RAM.q', num_jobs=250, cores_per_job=1) as view:
            results = view.map(eval_model, best_params)
    except IOError as e:
        # results stays None in this case. Nothing is written, so a failed run cannot
        # replace a previously saved results file with None.
        print(e)
    else:
        with open(folder_path + 'final_results.pkl', 'wb') as f:
            pkl.dump(results, f, 2)

### Identify best fitting model

In [None]:
# Find the evaluated parameter set with the lowest error, then print that error together
# with the set's row index in best_params
scores = np.array([r['err'] for r in results])
best_idx = scores.argmin()
best_err = scores[best_idx]
print(best_err, best_idx)
print()

# Print every value in the winning row of best_params, rounded to 5 decimal places
for p in best_params[best_idx]:
    print(round(p, 5))

# Plot the winning model's behavioral stats against the targets and save the figure
targets = load_targets(target_stat_file)
plot_fit(targets, results[best_idx], savefile='/home1/jpazdera/jupyter/ltpFR3/notebooks-analysis/ltpFR3_Figs/sim1b.pdf')