In [None]:
%matplotlib widget

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from mpl_toolkits.mplot3d import Axes3D
from scipy.cluster import hierarchy
from scipy.spatial import distance
from sklearn.manifold import MDS, TSNE, Isomap, LocallyLinearEmbedding as LLE
from sklearn.decomposition import PCA

import disjoint_domain as dd

Get some basic information for use below (e.g., item and context names).

In [None]:
# Sizes describing the network layout, as reported by the `dd` module
items_per_domain = dd.ITEMS_PER_DOMAIN
(ctx_per_domain,
 n_domains,
 n_items,
 n_ctx,
 attrs_per_context) = dd.get_net_defaults()

# Individual item/context tensors (and their names) for evaluating the network
items, item_names = dd.get_items(n_domains=n_domains)
contexts, context_names = dd.get_contexts(
    n_domains=n_domains,
    ctx_per_domain=ctx_per_domain,
)

In [None]:
# Load data from previous runs
results_file = 'data/dd_snaps_2020-11-06_01-03-28.npz'

# NOTE(review): allow_pickle=True is required because the snapshot/report dicts are
# stored as 0-d object arrays (hence the .item() calls). Unpickling can execute
# arbitrary code, so only load result files that come from a trusted source.
with np.load(results_file, allow_pickle=True) as res:
    snaps = res['snapshots'].item()
    snap_epochs = res['snap_epochs']

    reports = res['reports'].item()
    report_freq = res['report_freq']

# One epoch value per report. The report arrays are averaged over axis 0 below,
# so the number of reports is the size of the LAST axis; len() would return the
# size of the averaged-away axis instead. For a 1-D array both agree.
n_reports = np.shape(reports['loss'])[-1]
report_epochs = np.arange(0, report_freq * n_reports, report_freq)

# mean over iterations
def mean_dists(snapshots):
    """
    Compute run-averaged pairwise distances between snapshot representations.

    Parameters
    ----------
    snapshots : array-like, shape (num_runs, num_snaps, num_items, num_rep)
        Representation vectors (length num_rep) for every item at every
        snapshot of every run.

    Returns
    -------
    mean_dists_all : ndarray, shape (num_snaps * num_items, num_snaps * num_items)
        Distances (scipy ``pdist`` default metric) between every
        (snapshot, item) pair, averaged over runs while ignoring NaNs.
        Rows/columns are ordered snapshot-major.
    mean_dists_snaps : ndarray, shape (num_snaps, num_items, num_items)
        The diagonal blocks of ``mean_dists_all``: item-by-item distances
        within each individual snapshot.
    """
    snapshots = np.asarray(snapshots)
    num_runs, num_snaps, num_items, num_rep = snapshots.shape

    # calculate full pdist: stack all snapshots of a run into one long list of points
    snapshots_flat = snapshots.reshape(num_runs, num_snaps * num_items, num_rep)
    dists_all = np.stack([distance.pdist(run_reprs) for run_reprs in snapshots_flat])
    mean_dists_all = distance.squareform(np.nanmean(dists_all, axis=0))

    # get pdist for individual snapshots
    # (iterate over the local num_snaps; the notebook-level n_snaps is not
    # defined until a later cell)
    mean_dists_snaps = np.empty((num_snaps, num_items, num_items))
    for kS in range(num_snaps):
        this_slice = slice(kS * num_items, (kS + 1) * num_items)
        mean_dists_snaps[kS] = mean_dists_all[this_slice, this_slice]

    return mean_dists_all, mean_dists_snaps

# Run-averaged distance matrices for item and context representations
# (full trajectory matrix plus one matrix per snapshot)
mean_idists_all, mean_idists_snaps = mean_dists(snaps['item'])
mean_cdists_all, mean_cdists_snaps = mean_dists(snaps['context'])

# Loss and accuracy averaged over axis 0 of the report arrays
mean_loss, mean_acc = (np.mean(reports[key], axis=0) for key in ('loss', 'accuracy'))

# Drop any figures left over from previous executions
plt.close('all')

In [None]:
# Item RSA matrix
ind_to_plot = -1  # snapshot index to display; -1 selects the last one

fig, ax = plt.subplots(figsize=(7, 7))
ax.matshow(mean_idists_snaps[ind_to_plot])

# Label both axes with the item names
ax.set_xticks(range(n_items))
ax.set_xticklabels(item_names)
ax.set_yticks(range(n_items))
ax.set_yticklabels(item_names)

# matshow puts x ticks on top by default; move them to the bottom and rotate
ax.tick_params(
    axis='x',
    top=False, bottom=True,
    labeltop=False, labelbottom=True,
    labelrotation=45,
)

ax.set_title('RSA for item representations after training')

plt.show()

In [None]:
# Dendrogram
# linkage expects condensed distances, so convert the square matrix first
item_dists_condensed = distance.squareform(mean_idists_snaps[ind_to_plot])
z = hierarchy.linkage(item_dists_condensed, optimal_ordering=True)

plt.figure(figsize=(8, 3))
hierarchy.dendrogram(z, labels=item_names, count_sort=True)
plt.show()

In [None]:
# Context RSA matrix

# permutation to put all 1s together, etc.
# NOTE(review): assumes contexts are indexed domain-major (all contexts of
# domain 0, then domain 1, ...), which is how the MDS trajectory cells walk
# through them. With that ordering the index grid must be laid out as
# (n_domains, ctx_per_domain) and read column by column; the transposed layout
# only gives the same permutation when n_domains == ctx_per_domain.
inds = np.reshape(np.arange(n_ctx, dtype=int), (n_domains, ctx_per_domain))
perm = inds.T.ravel()

_, ax = plt.subplots(figsize=(4, 4))
# Apply the same reordering to rows, columns and labels
mean_cdists_reordered = mean_cdists_snaps[ind_to_plot][np.ix_(perm, perm)]
cnames_reordered = [context_names[i] for i in perm]

ax.matshow(mean_cdists_reordered)
ax.set_xticks(range(n_ctx))
ax.set_xticklabels(cnames_reordered)
# matshow puts x ticks on top by default; move them to the bottom and rotate
ax.tick_params(axis='x', top=False, bottom=True, labeltop=False, labelbottom=True, labelrotation=45)
ax.set_yticks(range(n_ctx))
ax.set_yticklabels(cnames_reordered)

ax.set_title('RSA for context representations after training')

plt.show()

In [None]:
# Context dendrogram
# linkage expects condensed distances, so convert the reordered square matrix first
ctx_dists_condensed = distance.squareform(mean_cdists_reordered)
z = hierarchy.linkage(ctx_dists_condensed, optimal_ordering=True)

plt.figure(figsize=(8, 3))
hierarchy.dendrogram(z, labels=cnames_reordered, count_sort=True)
plt.show()

In [None]:
# Composite plot of item RSAs

snap_stride = 6  # show every 6th snapshot, one per panel

# Keep the figure handle: tight_layout() below must act on THIS figure,
# not on whatever `fig` happens to be left over from an earlier cell.
fig, ax_all = plt.subplots(4, 2, figsize=(10, 18))
ax_all = ax_all.ravel()

for k, ax in enumerate(ax_all):
    ind = k * snap_stride
    ax.matshow(mean_idists_snaps[ind])
    ax.set_xticks(range(n_items))
    ax.set_xticklabels(item_names)
    # matshow puts x ticks on top by default; move them to the bottom and rotate
    ax.tick_params(axis='x', top=False, bottom=True, labeltop=False, labelbottom=True, labelrotation=45)
    ax.set_yticks(range(n_items))
    ax.set_yticklabels(item_names)
    ax.set_title(f'Item RSA after {snap_epochs[ind]} epochs')

fig.tight_layout()
plt.show()

In [None]:
# MDS of item representations over time
item_dims = 3
# MDS starts from a random configuration; fix the seed so that re-running the
# notebook reproduces the same embedding (and therefore the same figure).
embedding = MDS(n_components=item_dims, dissimilarity='precomputed', random_state=0)

# have to fit transform to all items over all snapshots at once - could take a while
n_snaps = len(mean_idists_snaps)
reprs_embedded = embedding.fit_transform(mean_idists_all)
# rows are ordered snapshot-major, so split them back into (snapshot, item, dim)
reprs_embedded = reprs_embedded.reshape((n_snaps, -1, item_dims))

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111, projection=('3d' if item_dims==3 else None))

markers = ['o', 's', '*']

# One trajectory per item: colour encodes the domain, marker the item group.
# reprs_embedded[:, kI].T yields one coordinate array per embedding dimension.
kI = 0
for _, col in zip(range(n_domains), mcolors.TABLEAU_COLORS):
    for rel_item in range(items_per_domain):
        linestyle = markers[dd.item_group(rel_item)] + '-'
        ax.plot(*reprs_embedded[:, kI].T, linestyle,
                label=item_names[kI], markersize=4, color=col, linewidth=0.5)
        kI += 1

# add start and end markers
kI = 0
for _, col in zip(range(n_domains), mcolors.TABLEAU_COLORS):
    for rel_item in range(items_per_domain):
        marker = markers[dd.item_group(rel_item)]

        # first snapshot gets a green outline, last snapshot a black one;
        # the smaller domain-coloured marker is drawn on top of each
        for snap_ind, outline in ((0, 'g'), (-1, 'k')):
            ax.plot(*reprs_embedded[snap_ind, kI], outline + marker, markersize=8)
            ax.plot(*reprs_embedded[snap_ind, kI], marker, markersize=5, color=col)

        # special point
        ax.plot(*reprs_embedded[32, kI], 'm' + marker, markersize=10)

        kI += 1

#ax.legend()
ax.set_title('MDS of item representations; color = domain, marker = item type')
plt.show()

In [None]:
# MDS of context representation over time
ctx_dims = 3
# MDS starts from a random configuration; fix the seed so that re-running the
# notebook reproduces the same embedding (and therefore the same figure).
embedding = MDS(n_components=ctx_dims, dissimilarity='precomputed', random_state=0)

n_snaps = len(mean_cdists_snaps)
ctx_reprs_embedded = embedding.fit_transform(mean_cdists_all)
# rows are ordered snapshot-major, so split them back into (snapshot, context, dim)
ctx_reprs_embedded = ctx_reprs_embedded.reshape((n_snaps, -1, ctx_dims))

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111, projection=('3d' if ctx_dims==3 else None))

kI = 0
markers = ['o', 's', '*', '^']

# One trajectory per context: colour encodes the domain, marker the context
# position within its domain.
for _, col in zip(range(n_domains), mcolors.TABLEAU_COLORS):
    for rel_ctx in range(ctx_per_domain):
        linestyle = markers[rel_ctx] + '-'
        # label with the CONTEXT name (these lines are contexts, not items)
        ax.plot(*[ctx_reprs_embedded[:, kI, d] for d in range(ctx_dims)],
               linestyle, label=context_names[kI], markersize=4, color=col,
               linewidth=0.5)
        kI += 1

# add start and end markers
kI = 0
for _, col in zip(range(n_domains), mcolors.TABLEAU_COLORS):
    for rel_ctx in range(ctx_per_domain):
        marker = markers[rel_ctx]
        # first snapshot: green outline with the domain colour on top
        ax.plot(*[ctx_reprs_embedded[0, kI, d] for d in range(ctx_dims)], 'g' + marker,
           markersize=8)
        ax.plot(*[ctx_reprs_embedded[0, kI, d] for d in range(ctx_dims)], marker,
           markersize=5, color=col)
        # last snapshot: black outline with the domain colour on top
        ax.plot(*[ctx_reprs_embedded[-1, kI, d] for d in range(ctx_dims)], 'k' + marker,
           markersize=8)
        ax.plot(*[ctx_reprs_embedded[-1, kI, d] for d in range(ctx_dims)], marker,
           markersize=5, color=col)

        # special point
        # NOTE(review): snapshot index 33 here, while the item plot highlights
        # index 32 - confirm which snapshot is intended.
        ax.plot(*[ctx_reprs_embedded[33, kI, d] for d in range(ctx_dims)], 'm' + marker,
           markersize=10)

        kI += 1

#ax.legend()
ax.set_title('MDS of context representations; color = domain, marker = context type')
plt.show()

In [None]:
# Plot loss and accuracy
_, ax = plt.subplots(2, 1)

# Both curves share the same x-axis: the epoch at which each report was taken.
# (range(0, report_freq) has report_freq points, not one point per report.)
ax[0].plot(report_epochs, mean_loss, '.-')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Mean loss')

ax[1].plot(report_epochs, mean_acc, '.-')
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Mean accuracy')

plt.show()

In [None]:
# Show the snapshot epochs, then the index of the first snapshot taken after epoch 700
print(snap_epochs)
first_late_snap = np.nonzero(snap_epochs > 700)[0][0]
print(first_late_snap)

In [None]:
# Show the snapshot epochs, then the index of the first snapshot taken after epoch 700
print(snap_epochs)
first_late_snap = np.nonzero(snap_epochs > 700)[0][0]
print(first_late_snap)

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111, projection=('3d' if item_dims==3 else None))

kI = 0
markers = ['o', 's', '*']

# One trajectory per item: colour encodes the domain, marker the item group.
# The inner loop runs over the items of ONE domain; n_items is the total item
# count (it sizes the RSA tick labels), so looping over it here would push kI
# past the end of the item axis.
for _, col in zip(range(n_domains), mcolors.TABLEAU_COLORS):
    for rel_item in range(items_per_domain):
        linestyle = markers[dd.item_group(rel_item)] + '-'
        ax.plot(*[reprs_embedded[:, kI, d] for d in range(item_dims)],
               linestyle, label=item_names[kI], markersize=4, color=col,
               linewidth=0.5)
        kI += 1

# add start and end markers
kI = 0
for _, col in zip(range(n_domains), mcolors.TABLEAU_COLORS):
    for rel_item in range(items_per_domain):
        marker = markers[dd.item_group(rel_item)]
        # first snapshot: green outline with the domain colour on top
        ax.plot(*[reprs_embedded[0, kI, d] for d in range(item_dims)], 'g' + marker,
           markersize=8)
        ax.plot(*[reprs_embedded[0, kI, d] for d in range(item_dims)], marker,
           markersize=5, color=col)
        # last snapshot: black outline with the domain colour on top
        ax.plot(*[reprs_embedded[-1, kI, d] for d in range(item_dims)], 'k' + marker,
           markersize=8)
        ax.plot(*[reprs_embedded[-1, kI, d] for d in range(item_dims)], marker,
           markersize=5, color=col)

        # special point
        ax.plot(*[reprs_embedded[32, kI, d] for d in range(item_dims)], 'm' + marker,
           markersize=10)

        kI += 1

#ax.legend()
ax.set_title('MDS of item representations; color = domain, marker = item type')
plt.show()

In [None]:
# MDS of context representation over time
ctx_dims = 3
# MDS starts from a random configuration; fix the seed so that re-running the
# notebook reproduces the same embedding (and therefore the same figure).
embedding = MDS(n_components=ctx_dims, dissimilarity='precomputed', random_state=0)

n_snaps = len(mean_cdists_snaps)
ctx_reprs_embedded = embedding.fit_transform(mean_cdists_all)
# rows are ordered snapshot-major, so split them back into (snapshot, context, dim)
ctx_reprs_embedded = ctx_reprs_embedded.reshape((n_snaps, -1, ctx_dims))

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111, projection=('3d' if ctx_dims==3 else None))

kI = 0
markers = ['o', 's', '*', '^']

# One trajectory per context: colour encodes the domain, marker the context
# position within its domain (n_ctx // n_domains contexts per domain).
for _, col in zip(range(n_domains), mcolors.TABLEAU_COLORS):
    for rel_ctx in range(n_ctx // n_domains):
        linestyle = markers[rel_ctx] + '-'
        # label with the CONTEXT name (these lines are contexts, not items)
        ax.plot(*[ctx_reprs_embedded[:, kI, d] for d in range(ctx_dims)],
               linestyle, label=context_names[kI], markersize=4, color=col,
               linewidth=0.5)
        kI += 1

# add start and end markers
kI = 0
for _, col in zip(range(n_domains), mcolors.TABLEAU_COLORS):
    for rel_ctx in range(n_ctx // n_domains):
        marker = markers[rel_ctx]
        # first snapshot: green outline with the domain colour on top
        ax.plot(*[ctx_reprs_embedded[0, kI, d] for d in range(ctx_dims)], 'g' + marker,
           markersize=8)
        ax.plot(*[ctx_reprs_embedded[0, kI, d] for d in range(ctx_dims)], marker,
           markersize=5, color=col)
        # last snapshot: black outline with the domain colour on top
        ax.plot(*[ctx_reprs_embedded[-1, kI, d] for d in range(ctx_dims)], 'k' + marker,
           markersize=8)
        ax.plot(*[ctx_reprs_embedded[-1, kI, d] for d in range(ctx_dims)], marker,
           markersize=5, color=col)

        # special point
        # NOTE(review): snapshot index 33 here, while the item plot highlights
        # index 32 - confirm which snapshot is intended.
        ax.plot(*[ctx_reprs_embedded[33, kI, d] for d in range(ctx_dims)], 'm' + marker,
           markersize=10)

        kI += 1

#ax.legend()
ax.set_title('MDS of context representations; color = domain, marker = context type')
plt.show()

In [None]:
# Plot loss and accuracy
fig, ax = plt.subplots(2, 1)

# Top panel: loss, bottom panel: accuracy; both against the report epochs
for panel, series, ylabel in zip(ax, (mean_loss, mean_acc), ('Mean loss', 'Mean accuracy')):
    panel.plot(report_epochs, series, '.-')
    panel.set_xlabel('Epoch')
    panel.set_ylabel(ylabel)

plt.show()

In [None]:
# Show the snapshot epochs, then the index of the first snapshot taken after epoch 700
print(snap_epochs)
first_late_snap = np.nonzero(snap_epochs > 700)[0][0]
print(first_late_snap)
