In [None]:
# default parameters
# Every tunable setting of the simulation; later cells read them through the
# `param` Series built from this dict (e.g. `param.EPOCH_N`).

param = dict(
    # input preprocessing (passed to Dataset)
    USE_MASK=False,
    GAUSS_MASK_SIGMA=1.0,
    IMAGE_FILTER=(-1, 1),
    DOG_KSIZE=(5, 5),
    DOG_SIGMA1=1.3,
    DOG_SIGMA2=2.6,
    INPUT_SCALE=1.0,
    # training schedule
    ITER_N=1,
    EPOCH_N=100,
    CLEAR_SAVED_WEIGHTS=True,
    # data locations
    IN_DIR='data/slex_len3_small',
    OUT_DIR='results/slex_len3_small_results',
    # receptive-field geometry (passed to Dataset)
    RF1_SIZE={'x': 1, 'y': 3},
    RF1_OFFSET={'x': 1, 'y': 3},
    RF1_LAYOUT={'x': 1, 'y': 7},
    # network size (passed to Model)
    LEVEL1_MODULE_SIZE=8,
    LEVEL2_MODULE_SIZE=32,
    # learning rates and their per-epoch decay / floor
    ALPHA_R=0.1,
    ALPHA_U=0.1,
    ALPHA_V=0.1,
    ALPHA_DECAY=1,
    ALPHA_MIN=0,
    # checkpoint interval (epochs) and value assigned to model.af
    TEST_INTERVAL=10,
    AF='linear',
)

In [None]:
%matplotlib inline

import numpy as np
import cv2
import imageio

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from dataset import Dataset
from model import Model

import re
import glob

import multiprocessing as mp
import os
import shutil
from itertools import product

from datetime import datetime
import pytz

In [None]:
# Record the short hash of the checked-out commit so results can be traced back
# to the code that produced them. The `with` block closes the pipe (and waits for
# the child process) instead of leaving it to the garbage collector.
# NOTE(review): expect an empty string when git is unavailable or this is not a checkout.
with os.popen('git rev-parse --short HEAD') as git_pipe:
    GIT_COMMIT_HASH = git_pipe.read().replace('\n', '')
print(GIT_COMMIT_HASH)

In [None]:
# wall-clock start of the run, in US Eastern time
start_time = datetime.now(pytz.timezone('US/Eastern'))
print('Start:', start_time.strftime('%c'))

# Model Parameters and Simulations

In [None]:
# Wrap the parameter dict in a Series so entries are reachable as attributes
# (e.g. `param.EPOCH_N`) and the cell renders them as a table.
param = pd.Series(data=param)
param

In [None]:
%%capture

in_dir = param.IN_DIR
out_dir = param.OUT_DIR

test_set = Dataset(scale=param.INPUT_SCALE,
                   shuffle=False,
                   data_dir=in_dir,
                   rf1_x=param.RF1_SIZE['x'],
                   rf1_y=param.RF1_SIZE['y'],
                   rf1_offset_x=param.RF1_OFFSET['x'],
                   rf1_offset_y=param.RF1_OFFSET['y'],
                   rf1_layout_x=param.RF1_LAYOUT['x'],
                   rf1_layout_y=param.RF1_LAYOUT['y'],
                   use_mask=param.USE_MASK,
                   gauss_mask_sigma=param.GAUSS_MASK_SIGMA,
                   image_filter=param.IMAGE_FILTER,
                   DoG_ksize=param.DOG_KSIZE,
                   DoG_sigma1=param.DOG_SIGMA1,
                   DoG_sigma2=param.DOG_SIGMA2)

test_int = param.TEST_INTERVAL

epoch_n = param.EPOCH_N
zero_pad_len = len(str(epoch_n))

model = Model(dataset=test_set,
              level1_module_size=param.LEVEL1_MODULE_SIZE,
              level2_module_size=param.LEVEL2_MODULE_SIZE)

# parameters
model.iteration = param.ITER_N

model.alpha_r = param.ALPHA_R
model.alpha_u = param.ALPHA_U
model.alpha_v = param.ALPHA_V

model.af = param.AF

In [None]:
%%capture

if param.CLEAR_SAVED_WEIGHTS == True and os.path.exists(out_dir):
    shutil.rmtree(out_dir)

# determining which set of weights to use

out_dir_all = glob.glob(os.path.join(out_dir, '*'))
out_dir_epoch = glob.glob(os.path.join(out_dir, 'epoch_*'))
out_dir_pretrain = os.path.join(out_dir, 'pretraining')

if len(out_dir_epoch) > 0:
    # load weights from previous results
    regex = re.compile(os.path.join(out_dir, 'epoch_(?P<epoch>.*)'))
    epoch_all = [int(regex.match(x).group('epoch')) for x in out_dir_epoch]
    epoch_max_idx = np.argmax(epoch_all)
    epoch_max = epoch_all[epoch_max_idx]
    model.load(out_dir_epoch[epoch_max_idx])
elif out_dir_pretrain not in out_dir_all:
    # save current pretraining weights
    epoch_max = -1
    model.save(out_dir_pretrain)
else:
    # load previous pretraining weights
    epoch_max = -1
    model.save(out_dir_pretrain)

In [None]:
# learning rate decay over epochs

def decayed_learning_rates(alpha_0, decay, alpha_min, n_epochs):
    """Return per-epoch learning rates ``alpha_0 / decay**epoch`` floored at ``alpha_min``.

    Parameters
    ----------
    alpha_0 : float
        Rate at epoch 0.
    decay : number
        Divisor applied once per epoch (1 keeps the rate constant).
    alpha_min : float
        Lower bound; any smaller rate is replaced by it.
    n_epochs : int
        Length of the returned schedule.

    Returns
    -------
    numpy.ndarray
        Float array of length ``n_epochs``.
    """
    rates = np.array([alpha_0 / (decay ** i) for i in range(n_epochs)], dtype=float)
    # element-wise floor: same effect as overwriting entries below alpha_min
    return np.maximum(rates, alpha_min)

u = decayed_learning_rates(param.ALPHA_U, param.ALPHA_DECAY, param.ALPHA_MIN, param.EPOCH_N)
v = decayed_learning_rates(param.ALPHA_V, param.ALPHA_DECAY, param.ALPHA_MIN, param.EPOCH_N)

fig, ax = plt.subplots()
ax.plot(u, label='u')
ax.plot(v, label='v')
ax.set(title='Learning-rate schedule', xlabel='epoch', ylabel='learning rate')
ax.legend();

In [None]:
%%capture

for epoch in (x for x in range(epoch_n) if x > epoch_max):
    # learning rate
    model.alpha_u = u[epoch]
    model.alpha_v = v[epoch]
    
    # images are shuffled for each training epoch
    train_set = Dataset(scale=param.INPUT_SCALE,
                        shuffle=True,
                        data_dir=in_dir,
                        rf1_x=param.RF1_SIZE['x'],
                        rf1_y=param.RF1_SIZE['y'],
                        rf1_offset_x=param.RF1_OFFSET['x'],
                        rf1_offset_y=param.RF1_OFFSET['y'],
                        rf1_layout_x=param.RF1_LAYOUT['x'],
                        rf1_layout_y=param.RF1_LAYOUT['y'],
                        use_mask=param.USE_MASK,
                        gauss_mask_sigma=param.GAUSS_MASK_SIGMA,
                        image_filter=param.IMAGE_FILTER,
                        DoG_ksize=param.DOG_KSIZE,
                        DoG_sigma1=param.DOG_SIGMA1,
                        DoG_sigma2=param.DOG_SIGMA2)

    # replaced model.train(train_set)
    for word_i in range(train_set.rf2_patches.shape[0]):
        inputs = train_set.rf2_patches[word_i]
        labels = train_set.labels[word_i]

        # HACK: remove rf2 patches where all values are identical (assuming no variation = silence or no information)
        bool_mask = np.max(inputs, axis=(2,3)) != np.min(inputs, axis=(2,3))
        mask_y = bool_mask.any(axis=1).sum()
        mask_x = bool_mask.any(axis=0).sum()
        
        inputs = inputs[bool_mask].reshape((mask_y, mask_x) + inputs.shape[2:])
        labels = labels[bool_mask].reshape((mask_y, mask_x) + labels.shape[2:])

        output = pd.DataFrame.from_dict(model.apply_input(inputs, labels, train_set, training=True))

    if epoch == 0 or epoch % test_int == test_int-1:
        model.save(os.path.join(out_dir, 'epoch_{:0>{}d}'.format(epoch, zero_pad_len)))

# Inputs

In [None]:
# Stimulus images are named "<index>_<word>.png". Each basename is matched once;
# matching basenames (not full paths) keeps in_dir out of the regex, so
# metacharacters or Windows separators in the path cannot break the pattern.
# The greedy index group splits on the last "_".
filenames = sorted(glob.glob(os.path.join(in_dir, '*.png')))

regex = re.compile(r'(?P<index>.*)_(?P<word>.*)\.png')
stimulus_matches = [regex.match(os.path.basename(x)) for x in filenames]
f_index = [m.group('index') for m in stimulus_matches]
f_word = [m.group('word') for m in stimulus_matches]

In [None]:
# Plot a panel for every word. For a random subset of 10 words instead, use:
# word_select = sorted(np.random.choice(len(test_set.filtered_images), 10, replace=False))
word_select = [*range(len(test_set.filtered_images))]

# Load Simulation Results as a DataFrame

In [None]:
# load data by epoch for parallelization

def load_data(epoch, iteration):
    """Evaluate the checkpoint saved for `epoch` on every word of `test_set`.

    Loads the weights saved under ``out_dir/epoch_*`` whose epoch number equals
    `epoch` into the module-level `model`. It then runs each test word through
    ``model.apply_input(..., training=False)`` and keeps the output steps whose
    ``iteration`` equals `iteration`. Every kept step contributes one row per
    element of its r3 output.

    Parameters
    ----------
    epoch : int
        Epoch number of the checkpoint to evaluate.
    iteration : int
        Only output steps with this ``iteration`` value are recorded.

    Returns
    -------
    pandas.DataFrame
        One row per (word, step, r3 node). Columns hold the softmax activation,
        target/response indices, the step's accuracy and its squared-error terms.

    Raises
    ------
    ValueError
        If no checkpoint exists for `epoch`, or if a word's informative rf2
        patches do not form a rectangular grid.
    """
    # error terms reported per step, in the order they are summed into e_all
    error_names = ('e10', 'e21', 'e32', 'e43', 'e11', 'e22', 'e33')

    results = {'epoch': [], 'target': [], 'idx': [], 'last_idx': [], 'iteration': [],
               'node': [], 'node_word': [], 'target_label': [],
               'activation_raw': [], 'activation': [],
               'target_n': [], 'response_n': [], 'accuracy': [],
               'e10': [], 'e21': [], 'e32': [], 'e43': [], 'e11': [], 'e22': [], 'e33': [], 'e_all': []}

    # Locate this epoch's checkpoint. Only basenames are parsed, so out_dir is never
    # interpreted as a regex; int() accepts the zero padding ("epoch_009" -> 9).
    filenames = sorted(glob.glob(os.path.join(out_dir, 'epoch_*')))
    f_epoch = [int(os.path.basename(x)[len('epoch_'):]) for x in filenames]

    model.load(filenames[f_epoch.index(epoch)])

    for word_i in range(test_set.rf2_patches.shape[0]):

        inputs = test_set.rf2_patches[word_i]
        labels = test_set.labels[word_i]

        # HACK: remove rf2 patches where all values are identical (assuming no variation = silence or no information).
        # The kept patches are reshaped back into a (rows, cols) grid, which requires a full rectangle.
        bool_mask = np.max(inputs, axis=(2,3)) != np.min(inputs, axis=(2,3))
        mask_y = bool_mask.any(axis=1).sum()
        mask_x = bool_mask.any(axis=0).sum()
        if mask_y * mask_x != bool_mask.sum():
            raise ValueError('word {}: informative rf2 patches do not form a rectangular grid'.format(word_i))
        inputs = inputs[bool_mask].reshape((mask_y, mask_x) + inputs.shape[2:])
        labels = labels[bool_mask].reshape((mask_y, mask_x) + labels.shape[2:])

        output = pd.DataFrame.from_dict(model.apply_input(inputs, labels, test_set, training=False))

        idx_list = [x for x in output.index if output.iteration[x] == iteration]
        # computed once rather than on every pass of the loop below
        last_step = max(idx_list) if idx_list else None

        for idx in idx_list:
            last_idx = bool(idx == last_step)

            target_n = np.argmax(output.label[idx])

            # Softmax of r3 in extended precision. np.longdouble is the portable name
            # for the type exposed as np.float128 only on some platforms. Subtracting
            # the max leaves the result mathematically unchanged but keeps np.exp
            # from overflowing.
            r3_raw = output.r3[idx].astype(np.longdouble)
            r3_exp = np.exp(r3_raw - r3_raw.max())
            r3 = r3_exp / np.sum(r3_exp)

            # a tie for the maximum activation counts as no response
            if np.count_nonzero(r3 == r3.max()) != 1:
                response_n = None
            else:
                response_n = np.argmax(r3)

            accuracy = 1 if target_n == response_n else 0

            # squared L2 norm of each error term, and their total
            errors = {}
            for name in error_names:
                flat = output[name][idx].flatten()
                errors[name] = flat @ flat
            errors['e_all'] = sum(errors[name] for name in error_names)

            n_nodes = len(r3)
            results['epoch'] += [epoch] * n_nodes
            results['target'] += [f_word[target_n]] * n_nodes
            results['idx'] += [idx] * n_nodes
            results['last_idx'] += [last_idx] * n_nodes
            results['iteration'] += [output.iteration[idx]] * n_nodes
            results['node'] += list(range(n_nodes))
            results['node_word'] += [f_word[x] for x in range(n_nodes)]
            results['target_label'] += list(output.label[idx])
            results['activation_raw'] += list(r3_raw)
            results['activation'] += list(r3)
            results['target_n'] += [target_n] * n_nodes
            results['response_n'] += [response_n] * n_nodes
            results['accuracy'] += [accuracy] * n_nodes

            for name, value in errors.items():
                results[name] += [value] * n_nodes

    return pd.DataFrame.from_dict(results)

In [None]:
%%capture

pool = mp.Pool(8) # may increase number of CPU if available

epoch_list = (x for x in range(epoch_n) if x == 0 or x % test_int == test_int-1)
iter_list = (x for x in range(model.iteration) if x == max(range(model.iteration)))

df_list = pool.starmap(load_data, product(epoch_list, iter_list))
pool.close()

results_df = pd.concat(df_list, ignore_index=True)

In [None]:
def category_given_target(target, word):
    """Classify how `word` relates to `target`.

    Checks are applied in priority order and the first match wins:
    identical -> 'target'; same first two characters -> 'cohort';
    identical from the second character on -> 'rhyme';
    `word` occurs inside `target` -> 'embedded'; otherwise 'other'.
    """
    # each test is a thunk so later checks only run if earlier ones fail
    checks = (
        ('target', lambda: target == word),
        ('cohort', lambda: target[0:2] == word[0:2]),
        ('rhyme', lambda: target[1:] == word[1:]),
        ('embedded', lambda: word in target),
    )
    for label, matches in checks:
        if matches():
            return label
    return 'other'

In [None]:
# every (target, node word) pair, labelled with its relation category
word_combo_df = pd.DataFrame(list(product(f_word, f_word)), columns=['target','node_word'])

cat_list = ['target', 'cohort', 'rhyme', 'other', 'embedded']
# A list comprehension avoids row-wise DataFrame.apply, and it still yields a
# column when word_combo_df is empty. pd.Categorical fixes the category order to cat_list.
word_combo_df['category'] = pd.Categorical(
    [category_given_target(target, word)
     for target, word in zip(word_combo_df['target'], word_combo_df['node_word'])],
    categories=cat_list)

# Dropping any existing 'category' column keeps a re-run of this cell from
# producing duplicated category_x / category_y columns.
results_df = pd.merge(results_df.drop(columns=['category'], errors='ignore'),
                      word_combo_df, on=['target', 'node_word'])

In [None]:
# convert to by-target accuracy, as opposed to by-idx:
#   accuracy_idx - accuracy of each individual time step (idx)
#   accuracy     - per (epoch, target) mean over last-time-step rows
# The per-step column is renamed only once, so re-running this cell is safe.
if 'accuracy_idx' not in results_df:
    results_df = results_df.rename(columns={'accuracy': 'accuracy_idx'})

accuracy_by_target = (results_df[(results_df.last_idx == True)]
                      .groupby(['epoch', 'target_n'])['accuracy_idx'].mean()
                      .rename('accuracy'))
results_df = pd.merge(results_df.drop(columns=['accuracy'], errors='ignore'),
                      accuracy_by_target, on=['epoch', 'target_n'])

In [None]:
# persist the assembled results next to the checkpoints
results_path = os.path.join(out_dir, 'results.pkl')
results_df.to_pickle(results_path)

In [None]:
# Uncomment to reload the results pickled to out_dir/results.pkl instead of recomputing them.
# results_df = pd.read_pickle(os.path.join(out_dir, 'results.pkl'))

# Weight Change by Epoch

In [None]:
%%capture

ALL_WEIGHTS = {"epoch": [], "U1": [], "U2": [], "U3": [], "V1": [], "V2": [], "V3": []}

model.load(out_dir_pretrain)

ALL_WEIGHTS["epoch"].append(-1)
ALL_WEIGHTS["U1"].append(model.U1.mean())
ALL_WEIGHTS["U2"].append(model.U2.mean())
ALL_WEIGHTS["U3"].append(model.U3.mean())
ALL_WEIGHTS["V1"].append(model.V1.mean())
ALL_WEIGHTS["V2"].append(model.V2.mean())
ALL_WEIGHTS["V3"].append(model.V3.mean())

for epoch in range(epoch_n):
    if epoch % test_int == test_int-1:
        filenames = sorted(glob.glob(os.path.join(out_dir, 'epoch_*')))
        regex = re.compile(os.path.join(out_dir, 'epoch_(?P<pad>0*)(?P<epoch>.+)'))
        f_epoch = [int(regex.match(x).group('epoch')) for x in filenames]

        model.load(filenames[f_epoch.index(epoch)])
        
        ALL_WEIGHTS["epoch"].append(epoch)
        ALL_WEIGHTS["U1"].append(model.U1.mean())
        ALL_WEIGHTS["U2"].append(model.U2.mean())
        ALL_WEIGHTS["U3"].append(model.U3.mean())
        ALL_WEIGHTS["V1"].append(model.V1.mean())
        ALL_WEIGHTS["V2"].append(model.V2.mean())
        ALL_WEIGHTS["V3"].append(model.V3.mean())
        
ALL_WEIGHTS_DF = pd.DataFrame.from_dict(ALL_WEIGHTS)

In [None]:
# one line per weight matrix: its mean value at each recorded epoch
weights_ax = ALL_WEIGHTS_DF.groupby('epoch').mean().plot()
weights_ax.legend(loc='lower left');

# Loss by Epoch

In [None]:
# mean of each squared-error term (and their total) per evaluated epoch
error_columns = ['e10', 'e21', 'e32', 'e43', 'e11', 'e22', 'e33', 'e_all']
loss_ax = results_df.groupby(['epoch'])[error_columns].mean().plot()
loss_ax.legend(loc='lower left');

# Accuracy by Epoch

In [None]:
# mean `accuracy` over last-time-step rows, per evaluated epoch
final_step_rows = results_df[results_df.last_idx == True]
final_step_rows.groupby('epoch')['accuracy'].mean().plot(ylim=(0, 1));

# Accuracy by Item

In [None]:
# mean `accuracy` over last-time-step rows, per target word (all epochs pooled)
final_step_rows_by_item = results_df[results_df.last_idx == True]
final_step_rows_by_item.groupby('target')['accuracy'].mean().plot.bar(figsize=(40, 5), ylim=(0, 1));

# Activation of Top 10 Activated Items for a Given Target

In [None]:
# First epoch included in the over-epoch plots below. For example,
# int(epoch_n * 1/5) would hide the first fifth of training.
epoch_cutoff = 0

## Over Epochs

In [None]:
# subplots settings: one panel per selected word, `ncols` panels per row,
# each panel subplot_x by subplot_y inches
ncols = 5
nrows = -(-len(word_select) // ncols)  # ceiling division
subplot_x = 5
subplot_y = 4

In [None]:
# One panel per selected target word. It shows the activation of the 10 nodes most
# active at the last time step of the final epoch shown, tracked across epochs.
# squeeze=False keeps `axes` 2-D even when there is only one row of panels.
fig, axes = plt.subplots(nrows, ncols, figsize=(subplot_x*ncols, subplot_y*nrows), sharey=True, squeeze=False)

for panel_i, word_i in enumerate(word_select):
    ax = axes[panel_i // ncols, panel_i % ncols]

    # last time step of every epoch from epoch_cutoff on, for this target
    results_df_select = results_df[(results_df.target_n == word_i) & (results_df.epoch >= epoch_cutoff) & (results_df.last_idx == True)]
    final_epoch_rows = results_df_select[results_df_select.epoch == results_df_select.epoch.max()]
    top_10 = final_epoch_rows.sort_values(by=['activation'], ascending=False).node_word.head(10)
    results_df_select = results_df_select[results_df_select.node_word.isin(top_10)]

    # Order by activation at the final epoch: a categorical node_word makes the
    # unstacked columns (and so the plotted lines) follow top_10's order.
    results_df_select = results_df_select.assign(
        node_word=results_df_select.node_word.astype('category').cat.set_categories(top_10))

    # plot; 'activation' is selected before averaging so non-numeric columns are never aggregated
    df_select = results_df_select.groupby(['epoch', 'node_word'], observed=False)['activation'].mean().unstack()
    df_plot = df_select.plot(title='target: {}'.format(f_word[word_i]), ax=ax)

    # Thicken the target line: width = 2*mean(target_label)+1 (3 for the target if labels
    # are one-hot). lws follows the same category order as the plotted columns.
    lws = results_df_select.groupby('node_word', observed=False)['target_label'].mean()*2+1
    for line, lw in zip(df_plot.lines, lws):
        line.set_linewidth(lw)

    df_plot.legend(loc='upper left')

plt.tight_layout()
plt.show()

## Over Timesteps at the Last Epoch

In [None]:
# One panel per selected target word. It shows the activation of the 10 nodes most
# active at the last time step, tracked across time steps of the last evaluated epoch.
# squeeze=False keeps `axes` 2-D even when there is only one row of panels.
fig, axes = plt.subplots(nrows, ncols, figsize=(subplot_x*ncols, subplot_y*nrows), sharey=True, squeeze=False)

last_epoch = results_df.epoch.max()

for panel_i, word_i in enumerate(word_select):
    ax = axes[panel_i // ncols, panel_i % ncols]

    # every time step of the last evaluated epoch, for this target
    results_df_select = results_df[(results_df.target_n == word_i) & (results_df.epoch == last_epoch)]
    final_step_rows = results_df_select[results_df_select.last_idx == True]
    top_10 = final_step_rows.sort_values(by=['activation'], ascending=False).node_word.head(10)
    results_df_select = results_df_select[results_df_select.node_word.isin(top_10)]

    # Order by activation at the last time step: a categorical node_word makes the
    # unstacked columns (and so the plotted lines) follow top_10's order.
    results_df_select = results_df_select.assign(
        node_word=results_df_select.node_word.astype('category').cat.set_categories(top_10))

    # plot; 'activation' is selected before averaging so non-numeric columns are never aggregated
    df_select = results_df_select.groupby(['idx', 'node_word'], observed=False)['activation'].mean().unstack()
    df_plot = df_select.plot(title='target: {}'.format(f_word[word_i]), ax=ax)

    # Thicken the target line: width = 2*mean(target_label)+1 (3 for the target if labels
    # are one-hot). lws follows the same category order as the plotted columns.
    lws = results_df_select.groupby('node_word', observed=False)['target_label'].mean()*2+1
    for line, lw in zip(df_plot.lines, lws):
        line.set_linewidth(lw)

    df_plot.legend(loc='upper left')

plt.tight_layout()
plt.show()

# Activation by Category

## Average Across All Items Over Epoch

### All Items

In [None]:
# mean activation per category at the last time step of each epoch (all items)
cat_rows_all = results_df[(results_df.last_idx == True) & (results_df.epoch >= epoch_cutoff)]
sns.lineplot(data=cat_rows_all, x="epoch", y="activation",
             hue="category", hue_order=cat_list, err_style=None);

### Correct Items

In [None]:
# as above, restricted to rows whose `accuracy` is 1
cat_rows_correct = results_df[(results_df.last_idx == True) & (results_df.epoch >= epoch_cutoff) & (results_df.accuracy == 1)]
sns.lineplot(data=cat_rows_correct, x="epoch", y="activation",
             hue="category", hue_order=cat_list, err_style=None);

### Incorrect Items

In [None]:
# as above, restricted to rows whose `accuracy` is 0
cat_rows_incorrect = results_df[(results_df.last_idx == True) & (results_df.epoch >= epoch_cutoff) & (results_df.accuracy == 0)]
sns.lineplot(data=cat_rows_incorrect, x="epoch", y="activation",
             hue="category", hue_order=cat_list, err_style=None);

## Average Across All Items Over Timesteps at the Last Epoch

### All Items

In [None]:
# Use the last epoch actually present in results_df. Results exist only for the
# checkpointed epochs (0 and every test_int-th), so epoch_n - 1 itself may be missing.
last_epoch = results_df.epoch.max()
sns.lineplot(x="idx", y="activation", hue="category", hue_order=cat_list, err_style="band",
             data=results_df[(results_df.epoch == last_epoch)]);

### Correct Items

In [None]:
# Use the last epoch actually present in results_df. Results exist only for the
# checkpointed epochs (0 and every test_int-th), so epoch_n - 1 itself may be missing.
last_epoch = results_df.epoch.max()
sns.lineplot(x="idx", y="activation", hue="category", hue_order=cat_list, err_style="band",
             data=results_df[(results_df.epoch == last_epoch) & (results_df.accuracy == 1)]);

### Incorrect Items

In [None]:
# Use the last epoch actually present in results_df. Results exist only for the
# checkpointed epochs (0 and every test_int-th), so epoch_n - 1 itself may be missing.
last_epoch = results_df.epoch.max()
sns.lineplot(x="idx", y="activation", hue="category", hue_order=cat_list, err_style="band",
             data=results_df[(results_df.epoch == last_epoch) & (results_df.accuracy == 0)]);

## By Item Over Epoch

In [None]:
# One panel per selected target word. It shows the mean activation of each word
# category at the last time step of every epoch from epoch_cutoff on.
# squeeze=False keeps `axes` 2-D even when there is only one row of panels.
fig, axes = plt.subplots(nrows, ncols, figsize=(subplot_x*ncols, subplot_y*nrows), squeeze=False)

for panel_i, word_i in enumerate(word_select):
    results_df_select = results_df[(results_df.target_n == word_i) & (results_df.epoch >= epoch_cutoff) & (results_df.last_idx == True)]
    # 'activation' is selected before averaging so non-numeric columns are never aggregated;
    # observed=False keeps a column for every category in cat_list
    df_select = results_df_select.groupby(['epoch', 'category'], observed=False)['activation'].mean().unstack()
    df_plot = df_select.plot(title='target: {}'.format(f_word[word_i]), ax=axes[panel_i // ncols, panel_i % ncols])
    df_plot.legend(loc='upper left')

plt.tight_layout()
plt.show()

## By Item Over Timesteps at the Last Epoch

In [None]:
# One panel per selected target word. It shows the mean activation of each word
# category at every time step of the last evaluated epoch.
# squeeze=False keeps `axes` 2-D even when there is only one row of panels.
fig, axes = plt.subplots(nrows, ncols, figsize=(subplot_x*ncols, subplot_y*nrows), squeeze=False)

last_epoch = results_df.epoch.max()

for panel_i, word_i in enumerate(word_select):
    results_df_select = results_df[(results_df.target_n == word_i) & (results_df.epoch == last_epoch)]
    # 'activation' is selected before averaging so non-numeric columns are never aggregated;
    # observed=False keeps a column for every category in cat_list
    df_select = results_df_select.groupby(['idx', 'category'], observed=False)['activation'].mean().unstack()
    df_plot = df_select.plot(title='target: {}'.format(f_word[word_i]), ax=axes[panel_i // ncols, panel_i % ncols])
    df_plot.legend(loc='upper left')

plt.tight_layout()
plt.show()

# Relationships Between Classification Accuracy and Item Properties

In [None]:
# Mean `accuracy` per target word over every row of results_df (all evaluated
# epochs and time steps pooled); groupby sorts the index by target word.
target_acc = results_df.groupby('target').accuracy.mean()

In [None]:
# word-level statistics stored alongside the stimulus images in in_dir
slex_path = os.path.join(in_dir, 'slex.csv')
slex_stats = pd.read_csv(slex_path)

In [None]:
# accuracy by neighborhood density
# NOTE(review): the pairing is positional; this assumes the sorted 'Phono.orig' strings
# are exactly the sorted target words indexing target_acc.
# .to_numpy() is used without .squeeze(), which would collapse a one-row table to a scalar.
target_N = slex_stats.sort_values(by=['Phono.orig'])['N_sum'].to_numpy()
plt.scatter(target_N, target_acc)
plt.xlabel('N_sum (neighborhood density)')
plt.ylabel('accuracy');

In [None]:
# accuracy by word length
# NOTE(review): the pairing is positional; this assumes the sorted 'Phono.orig' strings
# are exactly the sorted target words indexing target_acc.
# .to_numpy() is used without .squeeze(), which would collapse a one-row table to a scalar.
target_len = slex_stats.sort_values(by=['Phono.orig'])['n_phon'].to_numpy()
plt.scatter(target_len, target_acc)
plt.xlabel('n_phon (word length)')
plt.ylabel('accuracy');

In [None]:
# accuracy by Pfreq_sum
# NOTE(review): the pairing is positional; this assumes the sorted 'Phono.orig' strings
# are exactly the sorted target words indexing target_acc.
# .to_numpy() is used without .squeeze(), which would collapse a one-row table to a scalar.
target_Pfreq_sum = slex_stats.sort_values(by=['Phono.orig'])['Pfreq_sum'].to_numpy()
plt.scatter(target_Pfreq_sum, target_acc)
plt.xlabel('Pfreq_sum')
plt.ylabel('accuracy');

In [None]:
# wall-clock end of the run, in US Eastern time
end_time = datetime.now(pytz.timezone('US/Eastern'))
print('End:', end_time.strftime('%c'))