In [None]:
# built-in
import sys
import os
from copy import deepcopy
from collections import defaultdict
from pprint import pprint

# third-party
import numpy as np
import joblib
import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib inline

# ours
root = os.path.dirname(os.path.abspath(os.curdir))
sys.path.append(root)
import plot_all_figs_helper as helper
import configs
import tools
from analysis.analysis import evaluate_run
from models.model_utils import get_model
from datasets.recall import RecallDataset, EcstasyRecall

# setup
# All figures and cached results are written to <repo root>/publish.
figpath = os.path.join(root, 'publish')

import seaborn as sns

# Small tick marks on both axes: major ticks 2pt long, minor ticks 1pt long,
# all 0.5pt wide (expands to the 8 '[xy]tick.{major,minor}.{size,width}' keys).
_tick_rc = {
    '{}tick.{}.{}'.format(axis, which, attr): value
    for axis in ('x', 'y')
    for which, length in (('major', 2), ('minor', 1))
    for attr, value in (('size', length), ('width', .5))
}

# Publication style: grey axes/labels/text, no grid, no top/right spines,
# outward ticks on the bottom and left axes only.
_publication_rc = {
    'axes.axisbelow': True,
    'axes.edgecolor': 'lightgrey',
    'axes.facecolor': 'None',
    'axes.grid': False,
    'axes.labelcolor': 'dimgrey',
    'axes.spines.right': False,
    'axes.spines.top': False,
    'text.color': 'dimgrey',  # e.g. legend text
    'lines.solid_capstyle': 'round',
    'legend.facecolor': 'white',
    'legend.framealpha': 0.8,
    'xtick.bottom': True,
    'xtick.color': 'dimgrey',
    'xtick.direction': 'out',
    'ytick.color': 'dimgrey',
    'ytick.direction': 'out',
    'ytick.left': True,
}
_publication_rc.update(_tick_rc)

sns.set(font='Arial',
        font_scale=7/12.,  # default size is 12pt, scale down to 7pt
        palette='Set1',
        rc=_publication_rc)


# Figure 2: Benchmark
## (a) Dataset

In [None]:
# Example trial of the benchmark recall task (T_min = T_max = 15,
# stim_dim = 30, p_recall = 1), saved as 'benchmark_dataset'.
dataset = RecallDataset(
    stim_dim=30,
    T_min=15,
    T_max=15,
    p_recall=1,
)
dataset.visualize(figname='benchmark_dataset', figpath=figpath)

## (b) Performance: simple-sequential, simple-random (p=4/N), Hopfield, TVT

In [None]:
# Untrained reference configurations compared in the benchmark, all sized to
# N = net_size. The Hopfield baseline only takes the stimulus dimension.
net_size = 40
ref_configs = {}
ref_configs['Sequential'] = configs.get_config('ref_seq', stim_dim=net_size, hidden_size=net_size)
ref_configs['Random'] = configs.get_config('ref_rand', stim_dim=net_size, hidden_size=net_size)
ref_configs['Hopfield'] = configs.get_config('hopfield', stim_dim=net_size)

pprint(ref_configs)

In [None]:
# Accuracy vs. sequence length for the reference networks (cached via `fname`
# by the helper), then the pretrained TVT model on the same sequence lengths.
fname = os.path.join(figpath, 'benchmark_acc_vs_seqlen.pkl')

(seqlen_list,
 acc_vs_seqlen,
 err_vs_seqlen) = helper.get_acc_vs_seqlen_configs(ref_configs, fname=fname, repeat='auto')

path = '../files/tvt/000000'
config = tools.load_config(path)
net = get_model(config)
net.load(os.path.join(path, 'model.pt'))
(acc_vs_seqlen['TVT'],
 err_vs_seqlen['TVT']) = helper._get_acc_vs_seqlen(
    net, seqlen_list, config=config, repeat='auto')

In [None]:
# Plot benchmark accuracy vs. number of stimuli and save it next to its cache.
# The path is rebuilt from figpath so this cell does not depend on `fname`
# still holding the value set by the previous cell.
fname = os.path.join(figpath, 'benchmark_acc_vs_seqlen.pkl')
ax = helper.plot_acc_vs_seqlen(
    seqlen_list, acc_vs_seqlen, err_vs_seqlen,
    save_fname=fname[:-4] + '.pdf'
    )

## (c) Capacity

In [None]:
# Capacity vs. network size N for each reference network, with 50 independent
# estimates (repeat=1 each) per (network, N) at accuracy threshold `thres`.
# Results are cached in a .pkl next to the figure.
fname = os.path.join(figpath, 'benchmark_capacity_vs_size.pkl')
try:
    result = joblib.load(fname)
    sizes = result['sizes']
    capacities = result['capacities']
    labels = result['labels']
    thres = result['thres']
except Exception:  # cache missing or stale -> recompute
    thres = 0.98
    size_list = [20, 40, 80, 160]
    n_repeats = 50
    sizes = []
    capacities = []
    labels = []
    for label, base_config in ref_configs.items():
        # Work on a copy so ref_configs keeps N = net_size instead of being
        # left resized to the last entry of size_list.
        config = deepcopy(base_config)
        # Loop variable is `size` (not `net_size`) to avoid clobbering the global.
        for size in size_list:
            print('{} N={}'.format(label, size))
            config['dataset']['stim_dim'] = size
            if 'plasticnet' in config:
                config['plasticnet']['hidden_size'] = size
            for _ in range(n_repeats):
                capacity = helper.get_capacity(config, thres=thres, repeat=1)
                capacities.append(capacity)
                labels.append(label)
                sizes.append(size)
    sizes = np.array(sizes)
    capacities = np.array(capacities)
    labels = np.array(labels)
    result = {
        'sizes': sizes,
        'capacities': capacities,
        'labels': labels,
        'thres': thres
        }
    joblib.dump(result, fname)

ax = helper.plot_capacity(sizes, capacities, labels, ref_configs.keys(), thres)

ax.set_xlim([15, 165])  # brackets size_list = [20, ..., 160]
helper.format_and_save(ax.get_figure(), fname[:-4] + '.pdf')


# Figure 3: Optimization
## (a) Optimized performance
Note that the trained networks use passive decay in the h2o (hidden-to-output) layer, yet perform about as well as active decay when correctly tuned.

In [None]:
# Select the trained models of size N = net_size from the random-capacity
# sweep: the sequential third-factor model, plus random third-factor models
# with b0 = p for each p = k/N in prob_list.
net_size = 40
prob_list = [k / net_size for k in [1, 4]]

exp_path = '../files/train_random_capacity'

base_select = {'plasticnet.hidden_size': net_size,
               'dataset.stim_dim': net_size}

select = {**base_select, 'plasticnet.local_thirdfactor.mode': 'sequential'}
modeldirs = tools.get_modeldirs(exp_path, select_dict=select)

for p in prob_list:
    select = {**base_select,
              'plasticnet.local_thirdfactor.mode': 'random',
              'plasticnet.local_thirdfactor.b0': p}
    modeldirs += tools.get_modeldirs(exp_path, select_dict=select)

for path in modeldirs:
    config = tools.load_config(path)
    ltf = config['plasticnet']['local_thirdfactor']
    # 'b0' may be absent (e.g. for sequential models); show '' in that case.
    p = ltf.get('b0', '')
    print(path, ltf['mode'], p)

In [None]:
# Accuracy vs. number of stimuli for every selected model, evaluated twice:
# with the plasticity rule reset to the reference with active h2o decay
# ("Simple"), and with its trained parameters ("Trained"). Cached in a .pkl.
fname = os.path.join(figpath, 'optimized_acc_vs_seqlen.pkl')
try:
    result = joblib.load(fname)
    seqlen_list = result['seqlen_list']
    acc_vs_seqlen = result['acc_vs_seqlen']
    err_vs_seqlen = result['err_vs_seqlen']
except Exception:  # cache missing or stale -> recompute
    acc_vs_seqlen = {}
    err_vs_seqlen = {}
    seqlen_list = np.unique(np.logspace(0, np.log10(180), 50, dtype=int))
    for path in modeldirs:
        # Simple first, then Trained: the plotting loop below relies on this order.
        for trained in [False, True]:
            # Reload every pass so each variant starts from the saved config,
            # in case nested_update modifies it in place.
            config = tools.load_config(path)
            if trained:
                net = get_model(config)
                net.load(os.path.join(path, 'model.pt'))
            else:
                ref_config = tools.nested_update(config,
                    {'plasticnet': {'h2o': {'decay': 'active'},
                                    'h2o_use_local_thirdfactor': True,  # required for active decay
                                    'reset_to_reference': True}})
                net = get_model(ref_config)
            label = '{} {}'.format('Trained' if trained else 'Simple', net.rnn.ltf.mode)
            if net.rnn.ltf.mode == 'random':
                label += ' {}'.format(config['plasticnet']['local_thirdfactor']['b0'])
            print(label)
            acc_vs_seqlen[label], err_vs_seqlen[label] = helper._get_acc_vs_seqlen(
                net, seqlen_list, config, repeat='auto'
                )
    result = {'seqlen_list': seqlen_list, 'acc_vs_seqlen': acc_vs_seqlen, 'err_vs_seqlen': err_vs_seqlen}
    joblib.dump(result, fname)

fig, ax = plt.subplots()
legend_handles = []
for label, acc in acc_vs_seqlen.items():
    err = err_vs_seqlen[label]
    lower = [e[0] for e in err]
    upper = [e[1] for e in err]
    if label.startswith('Trained'):
        # Reuse the color of the preceding "Simple" curve of the same model
        # (relies on the Simple/Trained insertion order from the compute loop).
        color = ax.get_lines()[-1].get_color()
        handle = ax.plot(seqlen_list, acc, color=color, label='Trained', ls='-')
        ax.fill_between(seqlen_list, lower, upper,
                        label='Trained', color=color, alpha=0.3)
    else:
        # e.g. 'Simple random 0.1' -> 'Random 0.1'
        label_legend = label[label.find(' ') + 1:].capitalize()
        handle = ax.plot(seqlen_list, acc, label=label_legend, ls='--')
        ax.fill_between(seqlen_list, lower, upper,
                        label=label_legend, alpha=0.3)
    legend_handles.append(handle[0])
ax.legend(handles=legend_handles)
ax.set_xlabel('Number of stimuli')
ax.set_ylabel('Accuracy')
fig.set_size_inches(2.75, 1.7)
fig.tight_layout()
fig.savefig(fname[:-4] + '.pdf')

## (b) Optimized capacity
Note that scaling p with N (p = k/N), rather than keeping it constant (p = 0.1), is helpful.

In [None]:
# Group all trained models of the random-capacity sweep (every size) by rule:
# sequential, random with b0 = k/N, and random with a fixed b0 = 0.1.
size_list = [20, 40, 80, 160]
k = 4  # ltf_prob = k/N

exp_path = '../files/train_random_capacity'
modeldirs_dict = defaultdict(list)

# sorted() makes the traversal (and hence the grouping/print order) deterministic.
for trial in sorted(os.listdir(exp_path)):
    path = os.path.join(exp_path, trial)
    if not os.path.isdir(path):  # skip stray files (e.g. .DS_Store)
        continue
    config = tools.load_config(path)
    ltf = config['plasticnet']['local_thirdfactor']
    size = config['plasticnet']['hidden_size']

    if ltf['mode'] == 'sequential':
        modeldirs_dict['Sequential'].append(path)

    # b0 is a float read back from disk: compare with a tolerance, not ==.
    if ltf['mode'] == 'random' and np.isclose(ltf['b0'], k / size):
        modeldirs_dict['Random {}/N'.format(k)].append(path)

    if ltf['mode'] == 'random' and np.isclose(ltf['b0'], 0.1):
        modeldirs_dict['Random 0.1'].append(path)

# Use `dirs` (not `modeldirs`) so the earlier cells' `modeldirs` is not clobbered.
for label, dirs in modeldirs_dict.items():
    for path in dirs:
        config = tools.load_config(path)
        ltf = config['plasticnet']['local_thirdfactor']
        p = ltf.get('b0', '')  # absent for some models (e.g. sequential)
        print(label, path, ltf['mode'], p, config['plasticnet']['hidden_size'])

In [None]:
# Capacity vs. network size for the trained models grouped in modeldirs_dict,
# with 50 independent estimates (repeat=1 each) per model. Cached in a .pkl.
fname = os.path.join(figpath, 'optimized_capacity_vs_size.pkl')
try:
    result = joblib.load(fname)
    sizes = result['sizes']
    capacities = result['capacities']
    labels = result['labels']
    thres = result['thres']
except Exception:  # cache missing or stale -> recompute
    thres = 0.98
    n_repeats = 50
    sizes = []
    capacities = []
    labels = []

    # Use `dirs` (not `modeldirs`) so the earlier cells' `modeldirs` is not clobbered.
    for label, dirs in modeldirs_dict.items():
        for path in dirs:
            config = tools.load_config(path)
            net = get_model(config)
            net.load(os.path.join(path, 'model.pt'))
            print(label, config['plasticnet']['hidden_size'])
            for _ in range(n_repeats):
                capacity = helper.get_capacity(config, net, thres=thres, repeat=1)
                capacities.append(capacity)
                labels.append(label)
                sizes.append(net.rnn.hidden_size)
    sizes = np.array(sizes)
    capacities = np.array(capacities)
    labels = np.array(labels)
    result = {
        'sizes': sizes,
        'capacities': capacities,
        'labels': labels,
        'thres': thres
        }
    joblib.dump(result, fname)

ax = helper.plot_capacity(sizes, capacities, labels, ['Sequential', 'Random 4/N', 'Random 0.1'], thres)
ax.set_xlim([15, 165])  # brackets sizes 20..160
helper.format_and_save(ax.get_figure(), fname[:-4] + '.pdf')

## (c) Training loss/accuracy: trained-seq, trained-rand 4/N

In [None]:
# Trained models (N = net_size) from the zero-initialized pre/post experiment:
# the sequential third-factor model(s) and the random ones with b0 = 0.1 (= 4/N).
net_size = 40

exp_path = '../files/train_prepost_zero_init'

base_select = {'plasticnet.hidden_size': net_size,
               'dataset.stim_dim': net_size}

modeldirs = []
modeldirs += tools.get_modeldirs(exp_path, select_dict={
    **base_select,
    'plasticnet.local_thirdfactor.mode': 'sequential'})
modeldirs += tools.get_modeldirs(exp_path, select_dict={
    **base_select,
    'plasticnet.local_thirdfactor.mode': 'random',
    'plasticnet.local_thirdfactor.b0': 0.1})

for path in modeldirs:
    config = tools.load_config(path)
    ltf = config['plasticnet']['local_thirdfactor']
    # 'b0' may be absent (e.g. for sequential models); show '' in that case.
    p = ltf.get('b0', '')
    print(path, ltf['mode'], p)

In [None]:
# Training curves over the first `clip` logged steps of each selected model:
# loss on the top axes, accuracy on the bottom, labelled by third-factor mode.
fig, ax = plt.subplots(2, 1, sharex=True)
clip = 60  # number of logged steps shown (also used by the next cell)
for path in modeldirs:
    log = tools.load_log(path)

    config = tools.load_config(path)
    label = config['plasticnet']['local_thirdfactor']['mode'].capitalize()

    steps = log['steps'][:clip]
    for a, key in zip(ax, ['loss_train', 'acc_train']):
        a.plot(steps, log[key][:clip], label=label)
ax[0].legend()
for a, ylabel in zip(ax, ['Loss', 'Accuracy']):
    a.set_ylabel(ylabel)
ax[-1].set_xlabel('Iterations')

fname = os.path.join(figpath, 'loss_acc_optimized.pdf')
fig.set_size_inches(2, 3.5)
fig.tight_layout()
fig.savefig(fname)

## (d) Training plasticity params: trained-seq, trained-rand 4/N

In [None]:
import matplotlib as mpl

# For each model, plot the trajectory of the learned (weight, bias) pair of the
# pre- and post-synaptic plasticity functions in the i2h and h2o layers over
# the first `clip` logged steps (`clip` is set in the previous cell). Segments
# are colored by training progress with the jet colormap.
for path in modeldirs:
    fig, ax = plt.subplots(2, 2, sharex=True, sharey=True)
    log = tools.load_log(path)
    config = tools.load_config(path)
    label = config['plasticnet']['local_thirdfactor']['mode']

    # NOTE(review): the colorbar spans 0..log['steps'][clip] (one past the last
    # plotted step) while segment i is colored jet(i/N); these agree only if
    # steps are evenly spaced from 0 -- confirm. Raises IndexError if the log
    # has <= clip entries.
    iters = log['steps'][clip]
    N = len(log['steps'][:clip])
    for c, layer in enumerate(['i2h', 'h2o']):
        for r, syn in enumerate(['pre', 'post']):
            ax[r, c].axvline(0, color='grey', linewidth=1)
            ax[r, c].axhline(0, color='grey', linewidth=1)

            w = log['rnn.{}.{}_fn.weight'.format(layer, syn)][:clip]
            b = log['rnn.{}.{}_fn.bias'.format(layer, syn)][:clip]
            # Draw segment by segment so each one gets its own color.
            for i in range(N - 1):
                ax[r, c].plot(w[i:i+2], b[i:i+2], color=plt.cm.jet(i / N))

            layer_str = 'Input-to-hidden' if layer == 'i2h' else 'Hidden-to-output'
            title = '{}, {}-synaptic'.format(layer_str, syn)
            ax[r, c].set_title(title)

    cb = fig.colorbar(plt.cm.ScalarMappable(norm=mpl.colors.Normalize(0, iters),
                                            cmap=plt.cm.jet),
                      ax=ax, fraction=0.05, aspect=40, ticks=[0, iters])
    cb.set_label('Iterations', labelpad=-15)
    # Raw strings: '\w' is an invalid escape sequence in a normal string literal.
    for a in ax[1, :]:
        a.set_xlabel(r'$\widetilde{a}$')
    for a in ax[:, 0]:
        a.set_ylabel(r'$\widetilde{b}$')

    fname = os.path.join(figpath, 'plast_params_{}.pdf'.format(label))
    fig.set_size_inches(4, 3.5)
    # tight_layout is left disabled here (presumably because of the colorbar).
    fig.savefig(fname)

# Figure 4: Harder tasks
## (a) Continual recall

In [None]:
# Example continual-recall trial: interleaved recall order
# (recall_interleave_delay=2, p_recall=0.6), saved as 'continual_dataset'.
dataset = RecallDataset(
    stim_dim=30, T_min=15, T_max=15, p_recall=0.6,
    recall_order='interleave', recall_interleave_delay=2,
)
dataset.visualize(figname='continual_dataset', figpath=figpath)

## Continual performance: simple-seq, simple/trained-rand (p=4/N), Hopfield

In [None]:
# Reference configurations for the continual-recall task: interleaved recall
# with p_recall = 0.5, and Hopfield decay_rate = 0.95. The per-delay settings
# (recall_interleave_delay, T_min/T_max) are filled in by the next cell.
# Alternative Hopfield settings from Parisi 1986 (not used): clamp_val = 0.4,
# learning_rate = 1/sqrt(net_size), steps = 20.
net_size = 40
continual_configs = {}
continual_configs['Sequential'] = configs.get_config('ref_seq', stim_dim=net_size, hidden_size=net_size)
continual_configs['Random'] = configs.get_config('ref_rand', stim_dim=net_size, hidden_size=net_size)
continual_configs['Hopfield'] = configs.get_config('hopfield', stim_dim=net_size)

for config in continual_configs.values():
    config['dataset']['recall_order'] = 'interleave'
    config['dataset']['p_recall'] = 0.5
continual_configs['Hopfield']['hopfield']['decay_rate'] = 0.95

pprint(continual_configs)

In [None]:
# Accuracy vs. delay interval for the reference networks on continual recall.
# Each delay uses T_min = T_max = max(1000, 20*delay). Cached in a .pkl.
fname = os.path.join(figpath, 'continual_acc_vs_delay.pkl')
try:
    result = joblib.load(fname)
    delay_list = result['delay_list']
    acc_vs_delay = result['acc_vs_delay']
    err_vs_delay = result['err_vs_delay']
except Exception:  # cache missing or stale -> recompute
    delay_list = np.unique(np.logspace(0, np.log10(180), 50, dtype=int))
    acc_vs_delay = defaultdict(list)
    err_vs_delay = defaultdict(list)
    for label, config in continual_configs.items():
        print(label)
        for delay in delay_list:
            config['dataset']['recall_interleave_delay'] = delay
            config['dataset']['T_min'] = config['dataset']['T_max'] = max(1000, delay*20)
            net = get_model(config)
            # Separate name so the cache dict `result` is not shadowed mid-loop.
            run_result = evaluate_run(model=net, update_config=config, n_batch=3)
            acc, err = helper.get_avg_recall_acc(run_result)
            print(' Delay = {}, acc = {}'.format(delay, acc))
            acc_vs_delay[label].append(acc)
            err_vs_delay[label].append(err)
    result = {
        'delay_list': delay_list, 'acc_vs_delay': acc_vs_delay,
        'err_vs_delay': err_vs_delay
        }
    joblib.dump(result, fname)

In [None]:
# Evaluate the trained continual-recall model on the same delays and store it
# under 'Trained'. The lists are built fresh and assigned (not appended), so
# re-running this cell does not duplicate entries, and it also works when the
# cached dicts are plain dicts without a 'Trained' key.
path = '../files/train_continual'
config = tools.load_config(path)
net = get_model(config)
model_path = os.path.join(path, 'model.pt')
net.load(model_path)
label = 'Trained'
trained_acc = []
trained_err = []
for delay in delay_list:
    config['dataset']['recall_interleave_delay'] = delay
    config['dataset']['T_min'] = config['dataset']['T_max'] = max(1000, delay*20)
    result = evaluate_run(model=net, update_config=config, n_batch=3)
    acc, err = helper.get_avg_recall_acc(result)
    print(' Delay = {}, acc = {}'.format(delay, acc))
    trained_acc.append(acc)
    trained_err.append(err)
acc_vs_delay[label] = trained_acc
err_vs_delay[label] = trained_err

In [None]:
# Continual recall: accuracy vs. delay for all networks, saved as a PDF next
# to the .pkl cache. The path is rebuilt here rather than relying on `fname`
# still holding the value set by the compute cell.
fname = os.path.join(figpath, 'continual_acc_vs_delay.pkl')
ax = helper.plot_acc_vs_seqlen(delay_list, acc_vs_delay, err_vs_delay,
                               labels=['Sequential', 'Random', 'Trained', 'Hopfield'])
ax.set_xlabel('Delay interval')
ax.set_title('Continual recall', fontdict={'fontweight': 'bold'})
helper.format_and_save(ax.get_figure(), fname[:-4] + '.pdf')

## (b) Flashbulb: simple-seq, simple/trained-rand (p=4/N), Hopfield

In [None]:
# Example flashbulb trial: EcstasyRecall with interleaved recall
# (recall_interleave_delay=5, p_recall=0.5), ec_strength=3 and p_ec=0.4.
dataset = EcstasyRecall(
    stim_dim=30, T_min=15, T_max=15, p_recall=0.5,
    recall_order='interleave', recall_interleave_delay=5,
    ec_strength=3, p_ec=0.4,
)
dataset.visualize(figname='flashbulb_dataset', figpath=figpath)

In [None]:
def plot_flashbulb(net_type, net_title, add_legend=True, ax=None):
    """Plot recall accuracy vs. delay for regular and flashbulb items.

    Results are loaded from ``<figpath>/flashbulb_<net_type lower>.pkl`` if
    that cache can be read; otherwise they are computed with
    ``helper.get_flashbulb_performance`` (``vary_ec_strength=True`` only for
    'Hopfield') and written to the cache. Each entry of
    ``result['ec_strength_list']`` gets its own shade (later entries darker);
    regular items are drawn in blues, flashbulb items in reds, each with a
    shaded error band.

    Args:
        net_type: network type passed to ``helper.get_flashbulb_performance``
            (e.g. 'Reference', 'Random', 'Hopfield'); also names the cache file.
        net_title: axes title.
        add_legend: if True, draw the legend and the y-axis label.
        ax: matplotlib Axes to draw on; a new figure is created if None.

    Returns:
        The Axes that was drawn on.
    """
    fname = os.path.join(figpath, f'flashbulb_{net_type.lower()}.pkl')
    try:
        result = joblib.load(fname)
        print("Loaded pickle file")
    except Exception:  # cache missing or unreadable -> recompute
        result = helper.get_flashbulb_performance(
            net_type, vary_ec_strength=(net_type == 'Hopfield')
            )
        joblib.dump(result, fname)
        print("Ran results and saved to pickle file")

    ec_strength_list = result['ec_strength_list']
    R_list = result['R_list']

    if ax is None:
        _, ax = plt.subplots()
    n_strengths = len(ec_strength_list)
    # Light (0.5) to dark (0.99) shades; a single strength uses the darkest one.
    if n_strengths == 1:
        cmap_space = [0.99]
    else:
        cmap_space = np.linspace(0.5, 0.99, n_strengths)
    colors_reg = [cm.Blues(i) for i in cmap_space]
    colors_flashbulb = [cm.Reds(i) for i in cmap_space]

    for ec_idx, ec_strength in enumerate(ec_strength_list):
        # Only the last (darkest) strength contributes legend entries.
        if ec_idx == n_strengths - 1:
            reg_label = 'Regular'
            flashbulb_label = 'Flashbulb'
        else:
            reg_label = flashbulb_label = None

        ax.plot(R_list, result[ec_strength]['acc_vs_R_reg'],
                color=colors_reg[ec_idx], label=reg_label)
        ax.fill_between(R_list, *zip(*result[ec_strength]['err_vs_R_reg']),
                        color=colors_reg[ec_idx], alpha=0.3)

        ax.plot(R_list, result[ec_strength]['acc_vs_R_flashbulb'],
                color=colors_flashbulb[ec_idx], label=flashbulb_label)
        ax.fill_between(R_list, *zip(*result[ec_strength]['err_vs_R_flashbulb']),
                        color=colors_flashbulb[ec_idx], alpha=0.3)

    if add_legend:
        ax.legend()
        ax.set_ylabel('Accuracy')
    ax.set_xlabel('Delay interval')
    ax.set_title(net_title, fontdict={'fontweight': 'bold'})
    ax.set_ylim(0.55, 1.025)
    return ax

In [None]:
# Flashbulb results for the three untrained networks side by side; only the
# first panel carries the legend and the y-axis label.
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True)
panels = [
    ("Reference", "Flashbulb: Sequential"),
    ("Random", "Random"),  # This is random reference!
    ("Hopfield", "Hopfield"),
]
for panel_idx, (net_type, net_title) in enumerate(panels):
    plot_flashbulb(net_type, net_title, add_legend=(panel_idx == 0), ax=ax[panel_idx])
fig.set_size_inches(5.5, 1.7)
fig.tight_layout()
fig.savefig(os.path.join(figpath, 'flashbulb_all.pdf'))
plt.show()

## (c) Correlated: simple-seq, simple/trained-rand (p=4/N), Hopfield

In [None]:
# Example trial with temporally correlated stimuli (temporal_corr=0.6,
# 'template' mode), saved as 'corr_dataset'.
dataset = RecallDataset(
    stim_dim=30, T_min=15, T_max=15, p_recall=1,
    temporal_corr_mode='template', temporal_corr=0.6,
)
dataset.visualize(figname='corr_dataset', figpath=figpath)

In [None]:
# For each temporal-correlation level: accuracy vs. sequence length of fresh
# reference networks (cached via `fname` by the helper), plus the trained model
# stored under train_corr/<i>, then one figure per level.
net_size = 40
for i, corr in enumerate([0.3, 0.6, 0.9]):
    corr_configs = {}
    corr_configs['Sequential'] = configs.get_config('ref_seq', stim_dim=net_size, hidden_size=net_size)
    corr_configs['Random'] = configs.get_config('ref_rand', stim_dim=net_size, hidden_size=net_size)
    corr_configs['Hopfield'] = configs.get_config('hopfield', stim_dim=net_size)

    for config in corr_configs.values():
        config['dataset']['temporal_corr'] = corr
        config['dataset']['temporal_corr_mode'] = 'template'

    fname = os.path.join(figpath, 'corr_{}_acc_vs_seqlen.pkl'.format(corr))
    (seqlen_list,
     acc_vs_seqlen,
     err_vs_seqlen) = helper.get_acc_vs_seqlen_configs(corr_configs, repeat='auto', fname=fname)

    # NOTE(review): assumes run i of train_corr was trained at the i-th corr value.
    path = '../files/train_corr/{:06d}'.format(i)
    config = tools.load_config(path)
    net = get_model(config)
    net.load(os.path.join(path, 'model.pt'))
    label = 'Trained'
    print(label)
    (acc_vs_seqlen[label],
     err_vs_seqlen[label]) = helper._get_acc_vs_seqlen(
        net, seqlen_list, config=config, repeat='auto')

    ax = helper.plot_acc_vs_seqlen(seqlen_list, acc_vs_seqlen, err_vs_seqlen,
                                   labels=['Sequential', 'Random', 'Trained', 'Hopfield'])
    ax.set_title('Correlation={}'.format(corr), fontdict={'fontweight': 'bold'})

    helper.format_and_save(ax.get_figure(), fname[:-4] + '.pdf')

# Figure 5: Heteroassociative tasks

In [None]:
import experiments
from pathlib import Path
import torch

In [None]:
# If you're recomputing results, make sure this is correct for your system!!!
filesdir = root / Path("files")
print(filesdir)

## (a) Heteroassociative Recall

In [None]:
# Example heteroassociative trial (30-dim stimuli, 15-dim associated targets),
# with the second and third panels drawn at half height.
dataset = RecallDataset(
    stim_dim=30, T_min=15, T_max=15, p_recall=1,
    heteroassociative=True, heteroassociative_stim_dim=15,
)
dataset.visualize(
    figsize_scalings=[[1, 1], [1, 0.5], [1, 0.5]],
    figname='heteroassociative_dataset', figpath=figpath,
)

In [None]:
# Heteroassociative recall: accuracy vs. number of stimuli for each
# experiment's trained model (loaded from filesdir). Cached in a .pkl.
exps = [
    'heteroassociative',
    'heteroassociative_random_reference',
    'heteroassociative_random',
    'heteroassociative_hopfield'
    ]
exp_labels = [
    'Sequential', 'Random',
    'Trained', 'BAM'
    ]
seqlen_list = np.arange(0, 90, 5)
seqlen_list[0] = 1  # start at 1 stimulus rather than 0

fname = os.path.join(figpath, 'heteroassociative_recall.pkl')
try:
    result = joblib.load(fname)
    print("Loaded pickle file")
except Exception:  # cache missing or unreadable -> recompute
    result = {'acc': {}, 'err': {}}
    for exp in exps:
        net = helper.load_model(exp, filesdir=filesdir)
        fullconfig, _, _ = getattr(experiments, exp)()
        acc_vs_seqlen, err_vs_seqlen = helper._get_acc_vs_seqlen(
            net, seqlen_list, config=fullconfig
            )
        result['acc'][exp] = acc_vs_seqlen
        result['err'][exp] = err_vs_seqlen
    joblib.dump(result, fname)
    print("Ran results and saved to pickle file")


fig, ax = plt.subplots()
legend_handles = []
for exp, label in zip(exps, exp_labels):
    acc_vs_seqlen = result['acc'][exp]
    err_vs_seqlen = result['err'][exp]
    handle = ax.plot(seqlen_list, acc_vs_seqlen, label=label)
    ax.fill_between(seqlen_list, *zip(*err_vs_seqlen), label=label, alpha=0.3)
    legend_handles.append(handle[0])

ax.legend(handles=legend_handles)  # one entry per curve, not per band
ax.set_xlabel('Number of stimuli')
ax.set_ylabel('Accuracy')
fig.set_size_inches(5.5/3, 1.7)
ax.set_title("Heteroassociative Recall", fontdict={'fontweight': 'bold'})
fig.tight_layout()
fig.savefig(fname[:-4] + '.pdf')

## (b) Sequence Recall

In [None]:
# Example sequence-recall trial; the seed presumably fixes which random trial
# is drawn and shown.
np.random.seed(1)
dataset = helper.load_dataset('seqrecall')
dataset.visualize(figname='seqrecall_dataset', figpath=figpath)

In [None]:
# Sequence recall: accuracy vs. number of patterns for each experiment,
# computed by helper.get_generalization_curves and cached in a .pkl.
exps = [
    'seqrecall', 'seqrecall_random_reference',
    'seqrecall_random', 'seqrecall_hopfield'
    ]
exp_labels = [
    'Sequential', 'Random',
    'Trained', 'BAM'
    ]
dset_param = 'n_patterns'
dset_param_range = np.arange(0, 61, 5)
dset_param_range[0] = 2  # start at 2 patterns rather than 0

fname = os.path.join(figpath, 'seqrecall.pkl')
try:
    result = joblib.load(fname)
    print("Loaded pickle file")
except Exception:  # cache missing or unreadable -> recompute
    result = helper.get_generalization_curves(
        exps, dset_param, dset_param_range,
        num_iters=40, filesdir=filesdir
        )
    joblib.dump(result, fname)
    print("Ran results and saved to pickle file")

fig, ax = plt.subplots()
for exp, label in zip(exps, exp_labels):
    x_range = result[exp]['params']
    accs = result[exp]['accs']
    errs = result[exp]['errs']
    ax.plot(x_range, accs, linewidth=1, label=label)
    ax.fill_between(
        x_range,
        [e[0] for e in errs], [e[1] for e in errs],
        label=label, alpha=0.3
        )

ax.set_xlabel('Number of patterns')
ax.set_ylabel('Accuracy')
ax.set_title("Sequence Recall", fontdict={'fontweight': 'bold'})
fig.set_size_inches(5.5/3, 1.7)
fig.tight_layout()
fig.savefig(fname[:-4] + '.pdf')
plt.show()

## (c) Copy-Paste

In [None]:
# Example copy-paste trial; the seed presumably fixes which random trial is
# drawn and shown.
np.random.seed(6)
dataset = helper.load_dataset('copy')
dataset.visualize(figname='copy_dataset', figpath=figpath)

In [None]:
# Copy-paste: accuracy vs. number of patterns for each experiment, computed by
# helper.get_generalization_curves and cached in a .pkl. The dashed line marks
# the largest number of patterns seen in training.
exps = [
    'copy', 'copy_random_reference',
    'copy_random', 'copy_hopfield'
    ]
exp_labels = [
    'Sequential', 'Random',
    'Trained', 'BAM'
    ]
dset_param = 'n_patterns'
dset_param_range = np.arange(1, 21)

fname = os.path.join(figpath, 'copy.pkl')
try:
    result = joblib.load(fname)
    print("Loaded pickle file")
except Exception:  # cache missing or unreadable -> recompute
    result = helper.get_generalization_curves(
        exps, dset_param, dset_param_range,
        num_iters=40, filesdir=filesdir
        )
    joblib.dump(result, fname)
    print("Ran results and saved to pickle file")

fig, ax = plt.subplots()
for exp, label in zip(exps, exp_labels):
    params = result[exp]['params']
    accs = result[exp]['accs']
    errs = result[exp]['errs']
    ax.plot(params, accs, linewidth=1, label=label)
    ax.fill_between(
        params,
        [e[0] for e in errs], [e[1] for e in errs],
        label=label, alpha=0.3
        )

ax.set_xlabel('Number of patterns')
ax.set_ylabel('Accuracy')
ax.axvline(10, color="gray", linestyle="dashed")  # Maximum seen in training
fig.set_size_inches(5.5/3, 1.7)
ax.set_title("Copy-Paste", fontdict={'fontweight': 'bold'})
fig.tight_layout()
fig.savefig(fname[:-4] + '.pdf')
plt.show()