In [1]:
%matplotlib widget

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import dd_analysis as dda


In [2]:
# Load results saved by earlier training runs (via dd_analysis.get_result_means),
# keyed by a short display label used in plot titles and legends below.
res_paths = {
    'Original': 'data/dd_res_2020-11-24_04-03-27.npz',
    'Merged repr': 'data/merged_dd_res_2020-11-24_04-37-58.npz',
}

res_data = {
    run_label: dda.get_result_means(npz_path)
    for run_label, npz_path in res_paths.items()
}

In [3]:
# Choose which loaded runs to compare in the plots below, in display order.
curr_runs = ['Original', 'Merged repr']
# Index res_data directly so a misspelled run name raises a KeyError instead of
# being silently dropped, and so plots follow the order given in curr_runs.
curr_sets = {name: res_data[name] for name in curr_runs}
# Count the runs actually selected (duplicates in curr_runs collapse in the dict);
# this is used as a subplot-grid dimension by the cells below.
n_curr = len(curr_sets)
# Close any figures left over from earlier executions of the plotting cells.
plt.close('all')

In [4]:
# Plot loss, accuracy, and test accuracy: one panel per report, one line per run.
report_names = ['loss', 'accuracy', 'test_accuracy']
fig, axs = dda.make_plot_grid(len(report_names), 1, ax_dims=(6, 2), ravel=True)

# Each run is drawn on every panel in the same order, so matplotlib's per-axes
# color cycle gives a run the same color in all panels.
for label, res in curr_sets.items():
    for ax, report_name in zip(axs, report_names):
        dda.plot_report(ax, res, report_name, label=label)

for ax in axs:
    ax.legend()

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [26]:
# Item RSA matrices: one row per index in inds_to_plot, one column per selected run.
# inds_to_plot is also read by the item dendrogram cell that follows.
inds_to_plot = range(0, 10, 2)
fig, axss = dda.auto_subplots(len(inds_to_plot), n_curr, ax_dims=(4, 4))
fig.suptitle('Mean RSA for items')

for plot_ind, row_axs in zip(inds_to_plot, axss):
    for ax, (run_label, run_res) in zip(row_axs, curr_sets.items()):
        # plot_rsa also accepts item_order (e.g. 'type-outer') to try another ordering.
        rsa_image = dda.plot_rsa(ax, run_res, 'item', plot_ind, title_addon=run_label)
        fig.colorbar(rsa_image, ax=ax)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [7]:
# Dendrograms matching the item RSA matrices above (same inds_to_plot, one column per run).
# NOTE(review): inds_to_plot is set in the item RSA cell and later overwritten by the
# context RSA cell, so re-run the item RSA cell before re-running this one.
fig, axss = dda.auto_subplots(len(inds_to_plot), n_curr, ax_dims=(4, 3))
fig.suptitle('Item representation similarity')

for plot_ind, row_axs in zip(inds_to_plot, axss):
    for ax, (run_label, run_res) in zip(row_axs, curr_sets.items()):
        dda.plot_repr_dendrogram(ax, run_res, 'item', plot_ind, title_addon=run_label)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
# Context RSA matrix at index -1, items ordered with item_order='domain-inner'.
# Note: this reassigns inds_to_plot (the context dendrogram cell below reads it),
# replacing the item indices set earlier.
inds_to_plot = [-1]
fig, axss = dda.auto_subplots(len(inds_to_plot), n_curr, ax_dims=(4, 4))
fig.suptitle('Mean RSA for contexts')

for ind, axs in zip(inds_to_plot, axss):
    for (label, res), ax in zip(curr_sets.items(), axs):
        im = dda.plot_rsa(ax, res, 'context', ind, title_addon=label, item_order='domain-inner')
        # A colorbar per panel, as in the item RSA figure, so values can be read off.
        fig.colorbar(im, ax=ax)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
# Dendrograms matching the context RSA matrices above (same inds_to_plot, one column per run).
fig, axss = dda.auto_subplots(len(inds_to_plot), n_curr, ax_dims=(4, 3))
fig.suptitle('Context representation similarity')

for plot_ind, row_axs in zip(inds_to_plot, axss):
    for ax, (run_label, run_res) in zip(row_axs, curr_sets.items()):
        dda.plot_repr_dendrogram(ax, run_res, 'context', plot_ind, title_addon=run_label)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [44]:
# Item rep RSA over all snapshots (useful maybe for finding interesting change points)
fig, axs = dda.make_plot_grid(len(curr_sets), 2, ax_dims=(5, 5))
tick_step = 4  # label every 4th snapshot to keep the axes readable

for ax, (label, res) in zip(axs, curr_sets.items()):
    snap_epochs = res['snap_epochs']
    n_snaps = len(snap_epochs)
    # Lay the image out in snapshot-index units so snapshot i is centered at i, then
    # label ticks with the epoch each snapshot was taken at. This keeps ticks centered
    # on their snapshot even when snapshots are not exactly one epoch apart (an
    # epoch-unit extent of n_snaps * snap_freq only lines up when snap_freq == 1).
    im = ax.imshow(res['repr_dists']['item']['all'],
                   extent=(-0.5, n_snaps - 0.5, n_snaps - 0.5, -0.5))
    tick_inds = list(range(0, n_snaps, tick_step))
    tick_labels = snap_epochs[::tick_step]
    ax.set_xticks(tick_inds)
    ax.set_xticklabels(tick_labels)
    ax.tick_params(axis='x', labelrotation=45)
    ax.set_yticks(tick_inds)
    ax.set_yticklabels(tick_labels)
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Epochs')
    ax.set_title(f'Item repr distances over training ({label})')
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [36]:
# MDS of item representations over time: one trajectory figure (dims=3) per selected run.
for run_label, run_res in curr_sets.items():
    dda.plot_repr_trajectories(run_res, 'item', dims=3, title_label=run_label)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [48]:
# MDS of context representations over time: one trajectory figure (dims=2) per selected run.
# To pick up edits to dd_analysis while working, enable `%load_ext autoreload` /
# `%autoreload 2` in the import cell rather than calling importlib.reload mid-notebook.
for label, res in curr_sets.items():
    dda.plot_repr_trajectories(res, 'context', dims=2, title_label=label)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …