In [23]:
%matplotlib widget

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import importlib

import dd_analysis as dda
#importlib.reload(dda)

<module 'dd_analysis' from 'D:\\neuroscience\\concept_learning_IS\\disjoint-domain\\dd_analysis.py'>

In [10]:
# Load data from previous runs
res_paths = {
    'Original': 'data/dd_res_2020-12-03_01-31-32.npz',
    'Merged repr': 'data/merged_repr_dd_res_2020-12-04_05-15-35.npz',
#    'Small item repr': 'data/small_item_repr_dd_res_2020-11-24_18-48-00.npz',
#    'More compressed': 'data/all_ratios_0.5_dd_res_2020-11-24_21-17-07.npz',
#    'Half-size HL': 'data/half_hidden_longer_dd_res_2020-11-25_01-06-37.npz',
    'No item repr': 'data/no_item_repr_reallocate_dd_res_2020-12-04_06-30-32.npz',
    'No context repr': 'data/no_ctx_repr_reallocate_dd_res_2020-12-04_07-48-18.npz',
    'No repr': 'data/no_repr_reallocate_dd_res_2020-12-03_03-24-49.npz',
#     'Short original': 'data/short_save_params_dd_res_2020-11-26_11-44-16.npz',
#     'Short no item repr': 'data/short_save_params_no_item_repr_dd_res_2020-11-26_11-58-59.npz',
#     'Short no context repr': 'data/short_save_params_no_ctx_repr_dd_res_2020-11-26_12-05-47.npz'
}

res_data = {name: dda.get_result_means(path) for name, path in res_paths.items()}

In [11]:
# Get ready to look at some particular runs
curr_runs = ['Original', 'Merged repr', 'No item repr', 'No context repr', 'No repr']
curr_sets = {name: data for name, data in res_data.items() if name in curr_runs}
n_curr = len(curr_runs)

# short_runs = ['Short original', 'Short no item repr', 'Short no context repr']
# short_sets = {name: data for name, data in res_data.items() if name in short_runs}

plt.close('all')

In [4]:
# Plot loss, accuracy, and test accuracy
fig, axs = dda.make_plot_grid(3, 1, ax_dims=(7, 3), ravel=True)

for (label, res), col in zip(curr_sets.items(), mcolors.TABLEAU_COLORS):
    dda.plot_report(axs[0], res, 'loss', label=label)
    dda.plot_report(axs[1], res, 'weighted_acc', label=label)
    dda.plot_report(axs[2], res, 'test_weighted_acc', label=label)

for ax in axs:
    ax.legend()

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [25]:
# Item RSA matrices
inds_to_plot = [0, 1, 4, 11, 12, 30, 50, -1]
fig, axss = dda.auto_subplots(len(inds_to_plot), n_curr, ax_dims=(3, 3))
fig.suptitle('Mean RSA for items (hidden layer)')

for ind, axs in zip(inds_to_plot, axss):
    for (label, res), ax in zip(curr_sets.items(), axs):
        dda.plot_rsa(ax, res, 'item_hidden', ind, title_addon=label)#, item_order='group-outer')

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [62]:
# Corresponding dendrograms
fig, axss = dda.auto_subplots(len(inds_to_plot), n_curr, ax_dims=(4, 3))
fig.suptitle('Item representation similarity')

for ind, axs in zip(inds_to_plot, axss):
    for (label, res), ax in zip(curr_sets.items(), axs):
        dda.plot_repr_dendrogram(ax, res, 'item', ind, title_addon=label)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [24]:
# Look at RSA projections onto group, domain, and item models
# First plot the models
n_domains = curr_sets['Original']['net_params']['n_domains']
group_model = dda.get_group_model_rsa(n_domains)
domain_model = dda.get_domain_model_rsa(n_domains)
item_model = dda.get_item_model_rsa(n_domains)

fig, axs = dda.make_plot_grid(3, 3, ax_dims=(4, 4), ravel=True)
for ax, mtype, model in zip(axs, ['Cross-domain group', 'Domain', 'Item'], [group_model, domain_model, item_model]):
    dda.plot_rsa(ax, curr_sets['Original'], 'item', 0, rsa_mat=model)
    ax.set_title(mtype + ' model RDM')
    
fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
# Item hidden layer projections
fig, axs = dda.make_plot_grid(3, 1, ax_dims=(9, 3), ravel=True)

snap_type = 'item_hidden'
model_types=['group', 'domain', 'individual']

for (label, res), col in zip(curr_sets.items(), mcolors.TABLEAU_COLORS):
    dda.plot_rdm_projections(res, snap_type, model_types, axs, color=col, label=label)
    
for ax in axs:
    dda.outside_legend(ax)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [20]:
# Context hidden layer projections
fig, axs = dda.make_plot_grid(2, 1, ax_dims=(9, 3), ravel=True)

snap_type = 'context_hidden'
model_types=['domain', 'individual']

for (label, res), col in zip(curr_sets.items(), mcolors.TABLEAU_COLORS):
    dda.plot_rdm_projections(res, snap_type, model_types, axs, color=col, label=label)
    
for ax in axs:
    dda.outside_legend(ax)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [21]:
# Item representation layer projections
fig, axs = dda.make_plot_grid(3, 1, ax_dims=(9, 3), ravel=True)

snap_type = 'item'
model_types=['group', 'domain', 'individual']

for (label, res), col in zip(curr_sets.items(), mcolors.TABLEAU_COLORS):
    # only plot if we actually have an item representation layer
    if 'item' in res['repr_dists']:
        dda.plot_rdm_projections(res, snap_type, model_types, axs, color=col, label=label)
    
for ax in axs:
    dda.outside_legend(ax)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [22]:
# Context representation layer projections
fig, axs = dda.make_plot_grid(2, 1, ax_dims=(9, 3), ravel=True)

snap_type = 'context'
model_types=['domain', 'individual']

for (label, res), col in zip(curr_sets.items(), mcolors.TABLEAU_COLORS):
    # only plot if we actually have a context representation layer
    if 'context' in res['repr_dists']:
        dda.plot_rdm_projections(res, snap_type, model_types, axs, color=col, label=label)
    
for ax in axs:
    dda.outside_legend(ax)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [47]:
# Context RSA matrix
inds_to_plot = [0, 1, 4, 11, 12, 30, 50, -1]
fig, axss = dda.auto_subplots(len(inds_to_plot), n_curr, ax_dims=(4, 4))
fig.suptitle('Mean RSA for contexts')

#ctx_sets = {name: res for name, res in curr_sets.items() if name in ['Original', 'Merged repr', 'No item repr']}

for ind, axs in zip(inds_to_plot, axss):
    for (label, res), ax in zip(curr_sets.items(), axs):
        dda.plot_rsa(ax, res, 'context_hidden', ind, title_addon=label, item_order='domain-inner')

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [60]:
# Corresponding context dendrograms
fig, axss = dda.auto_subplots(len(inds_to_plot), n_curr, ax_dims=(4, 3))
fig.suptitle('Context representation similarity')

for ind, axs in zip(inds_to_plot, axss):
    for (label, res), ax in zip(curr_sets.items(), axs):
        dda.plot_repr_dendrogram(ax, res, 'context', ind, title_addon=label)

fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [79]:
# Item rep RSA over all snapshots (useful maybe for finding interesting change points)
fig, axs = dda.make_plot_grid(len(curr_sets), 2, ax_dims=(5, 5))

for ax, (label, res) in zip(axs, curr_sets.items()):
    n_snaps = len(res['snap_epochs'])
    snap_freq = res['train_params']['snap_freq']
    im = ax.imshow(res['repr_dists']['item']['all'],
                   extent=(-0.5, n_snaps * snap_freq-0.5, n_snaps * snap_freq-0.5, -0.5))
    ax.set_xticks(res['snap_epochs'][::4])
    ax.tick_params(axis='x', labelrotation=45)
    ax.set_yticks(res['snap_epochs'][::4])
    ax.set_xlabel('Epochs')
    ax.set_title(f'Item repr distances over training ({label})')
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    
fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [35]:
# MDS of item representations over time
for label, res in curr_sets.items():
    dda.plot_repr_trajectories(res, 'item', dims=3, title_label=label)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

KeyError: 'item'

In [81]:
for label, res in curr_sets.items():
    dda.plot_repr_trajectories(res, 'context', dims=2, title_label=label)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [48]:
import importlib
importlib.reload(dda)

fig, axs = dda.make_plot_grid(len(curr_runs), 2, ax_dims=(6, 15))
run_num = 0
snap_index = 5

for (name, res), ax in zip(short_sets.items(), axs.ravel()):
    dda.plot_hl_input_pattern_correlations(ax, res, run_num, snap_index, title_label=name)
    
fig.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …