In [None]:
!gdown '1LU5Lcf-wPR9UnOfWNVyXuZRxDQeOzXzX'
!unzip 'demo_accessories.zip'

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
tfkl = tf.keras.layers

import utils

from sklearn import cluster
import os

import tensorflow_datasets as tfds
from matplotlib.gridspec import GridSpec

## Hex codes of matplotlib's default 10-color cycle ('tab10'); used to shade the OPTICS
## clusters in the reachability plots, cycling through the list by cluster id.
default_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

# Compute distinguishability matrices for the channels of a single trained model

In [None]:
## Create a fixed set of "standard candle" images: every channel's distinguishability
## matrix is computed on these same images, so the matrices can be compared.
dataset_name = 'fashion_mnist'
number_standard_candle_examples = 300
SHUFFLE_SEED = 0  # fixes which images are drawn, so re-running the notebook gives the same set
dset, dset_info = tfds.load(dataset_name, with_info=True, decoders={
        'image': tfds.decode.SkipDecoding()})
dset = dset['train']

dset = dset.map(lambda example: (example['image'], example['label']))

## Shuffle while the images are still encoded bytes (cheap), then decode.
## NOTE(review): the 1_000_000 buffer is meant to exceed the split size so the shuffle is
## uniform over the whole split — confirm if dataset_name is changed.
dset = dset.shuffle(1_000_000, seed=SHUFFLE_SEED)

dset = dset.map(
  lambda image, label: (dset_info.features['image'].decode_example(image), label))
## convert_image_dtype rescales integer images to floats in [0, 1]
dset = dset.map(lambda image, label: (tf.image.convert_image_dtype(image, tf.float32), label))

standard_candle_images, standard_candle_labels = next(iter(dset.batch(number_standard_candle_examples)))
## Later cells reshape per-channel encodings to exactly this many rows.
assert standard_candle_images.shape[0] == number_standard_candle_examples, (
  f'Expected {number_standard_candle_examples} images, got {standard_candle_images.shape[0]}')

In [None]:
## Rebuild the trained VAE's encoder/decoder and restore their weights from a checkpoint.
## The layer stacks must match the ones used at training time: tf.train.Checkpoint matches
## variables by their position in the object graph, not by layer name.
model_dir = 'trained_fashion_mnist_beta4/'
print(f'Loading {model_dir}')
number_bottleneck_channels = 10
## Fixed by the architecture: the encoder downsamples 28 -> 14 -> 7 and the decoder
## upsamples 7 -> 14 -> 28.
image_side = 28
image_channels = standard_candle_images.shape[-1]
assert tuple(standard_candle_images.shape[1:3]) == (image_side, image_side), (
  f'Expected {image_side}x{image_side} images, got {tuple(standard_candle_images.shape[1:3])}')

encoder = tf.keras.Sequential(
  [
  tfkl.Input((image_side, image_side, image_channels)),
  tfkl.Conv2D(32, 4, strides=2, activation='relu', padding='same'),
  tfkl.Conv2D(64, 4, strides=2, activation='relu', padding='same'),
  tfkl.Flatten(),
  tfkl.Dense(256, 'relu'),
  # Output is split in half below: first half posterior means, second half log-variances
  tfkl.Dense(2*number_bottleneck_channels)
  ])

decoder = tf.keras.Sequential(
  [
  tfkl.Input((number_bottleneck_channels,)),
  tfkl.Dense(7*7*32, 'relu'),
  tfkl.Reshape([7, 7, 32]),
  tfkl.Conv2DTranspose(64, 4, strides=2, padding='same', activation='relu'),
  tfkl.Conv2DTranspose(32, 4, strides=2, padding='same', activation='relu'),
  # No activation: outputs are logits, squashed with a sigmoid only for display
  tfkl.Conv2DTranspose(image_channels, 4, padding='same')
  ])

checkpoint = tf.train.Checkpoint(step=tf.Variable(0), encoder=encoder, decoder=decoder)
checkpoint_directory = os.path.join(model_dir, "training_checkpoints")
checkpoint_ind = 1
## expect_partial() silences warnings about checkpointed values that have no counterpart
## here (presumably e.g. optimizer state); only the encoder/decoder weights are needed.
checkpoint.restore(os.path.join(checkpoint_directory, f'ckpt-{checkpoint_ind}')).expect_partial()

## Run a few samples through the model to check visually that it was restored properly.
NUMBER_DISPLAY_EXAMPLES = 10
embs = encoder(standard_candle_images[:NUMBER_DISPLAY_EXAMPLES])
mus, logvars = tf.split(embs, 2, -1)
## Reparameterized sample z ~ N(mu, exp(logvar)), i.e. stddev = exp(logvar / 2)
reparam = tf.random.normal(mus.shape, mus, tf.exp(logvars/2.))
recon = decoder(reparam)
plt.figure(figsize=(NUMBER_DISPLAY_EXAMPLES, 2))
for img_ind in range(NUMBER_DISPLAY_EXAMPLES):
  # Top row: inputs
  plt.subplot(2, NUMBER_DISPLAY_EXAMPLES, 1+img_ind)
  plt.imshow(standard_candle_images[img_ind])
  plt.axis('off')

  # Bottom row: reconstructions
  plt.subplot(2, NUMBER_DISPLAY_EXAMPLES, NUMBER_DISPLAY_EXAMPLES+1+img_ind)
  plt.imshow(tf.nn.sigmoid(recon[img_ind]))
  plt.axis('off')
plt.suptitle('Inputs (top) and reconstructions (bottom)', y=1.02, fontsize=12)
plt.show()

In [None]:
## Per-channel distinguishability of the standard-candle images.
## For each bottleneck channel, every image's encoding is a 1-D Gaussian with mean `mu`
## and log-variance `logvar` (the encoder output split in half). utils.bhattacharyya_dist_mat
## turns those Gaussians into an image-by-image distance matrix, and exp(-distance) is drawn
## on a fixed [0, 1] color scale. Assuming the helper returns non-negative Bhattacharyya
## distances, values near 1 mean "hard to tell apart" and values near 0 mean "distinguishable".
inches_per_subplot = 2
number_shown = 100  # only the top-left 100x100 corner of each matrix is drawn
fig, axes = plt.subplots(1, number_bottleneck_channels, squeeze=False,
                         figsize=(inches_per_subplot*number_bottleneck_channels, inches_per_subplot))
standard_candle_embs = encoder(standard_candle_images)
embs_mus, embs_logvars = tf.split(standard_candle_embs, 2, -1)
mus_np, logvars_np = embs_mus.numpy(), embs_logvars.numpy()
column_shape = (number_standard_candle_examples, 1)

for channel_id, ax in enumerate(axes[0]):
  channel_mus = np.reshape(mus_np[:, channel_id], column_shape)
  channel_logvars = np.reshape(logvars_np[:, channel_id], column_shape)
  bhat_distance_mat = utils.bhattacharyya_dist_mat(channel_mus, channel_logvars)

  similarity = np.exp(-bhat_distance_mat)
  ax.imshow(similarity[:number_shown, :number_shown], vmin=0, vmax=1)
  ax.set_xticks([])
  ax.set_yticks([])
  ax.set_xlabel(f'Channel {channel_id}', fontsize=14)

fig.suptitle('Bhattacharyya distinguishability matrices for each channel', y=0.95, fontsize=16)
plt.show()

# Compute pairwise NMI and VI between channels from a set of distinguishability matrices, and assess the ensemble's structure

In [None]:
## Load precomputed distinguishability data for a whole ensemble of channels.
## Every entry but the last is a channel's distance matrix. The last one describes the class
## label and was stored as a 0/1 distinguishability matrix rather than a distance. It is turned
## into a distance with -log, clipped at 1e-8 so that zeros become a large finite distance.
stored_mats = np.load('bhats.npy')
channel_distance_mats, bhats_labels = stored_mats[:-1], stored_mats[-1:]
bhats_labels_distances = -np.log(np.clip(bhats_labels, 1e-8, None))
bhat_distance_mats = np.concatenate((channel_distance_mats, bhats_labels_distances), axis=0)
print(f'Distinguishability matrices shape: {bhat_distance_mats.shape}')

print('Computing pairwise similarity between the channels')
pairwise_nmi, pairwise_vi = utils.compute_pairwise_similarities(bhat_distance_mats)

## Side-by-side heatmaps of the two channel-vs-channel measures.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, (label, pairwise_mat) in zip(axes, (('NMI', pairwise_nmi), ('VI', pairwise_vi))):
  heatmap = ax.imshow(pairwise_mat)
  ax.set_xticks([])
  ax.set_yticks([])
  ax.set_title(label, fontsize=16)
  fig.colorbar(heatmap, ax=ax)
plt.show()

In [None]:
## Cluster the ensemble's channels with OPTICS, once using VI as the distance and once using
## -log(NMI). Each result is drawn as one figure:
##   top:    OPTICS reachability plot, each found cluster shaded in its own color
##   left:   each channel's NMI with the class label
##   center: the channel-vs-channel matrix, rows/columns in OPTICS ordering
##   right:  colorbar for the center matrix
OPTICS_MIN_SAMPLES = 20
NMI_FLOOR = 1e-4  # keeps -log(NMI) finite for channel pairs that share ~no information

## The last row/column of the pairwise matrices compares each channel with the class label.
## It is split off under NEW names instead of overwriting pairwise_nmi / pairwise_vi, so that
## re-running this cell does not strip off another channel.
nmi_with_labels = pairwise_nmi[-1]
vi_with_labels = pairwise_vi[-1]
channel_nmi = pairwise_nmi[:-1, :-1]
channel_vi = pairwise_vi[:-1, :-1]
number_channels = channel_nmi.shape[0]
channel_positions = np.arange(number_channels)
## imshow centers pixel i at coordinate i, so the matrix spans [-0.5, n - 0.5]. The side panels
## use the same limits so that channel i lines up with matrix row/column i.
panel_limits = (-0.5, number_channels - 0.5)

## (label, distance fed to OPTICS, matrix displayed). OPTICS with a precomputed metric rejects
## negative distances, so VI is clipped at 0 and NMI is capped at 1 before taking -log (rounding
## can push NMI slightly above 1).
distance_options = [
  ('VI', np.clip(channel_vi, 0, None), channel_vi),
  ('NMI', -np.log(np.clip(channel_nmi, NMI_FLOOR, 1.)), channel_nmi),
]

for distance_label, distance_mat, display_mat in distance_options:
  clustering = cluster.OPTICS(min_samples=OPTICS_MIN_SAMPLES, metric='precomputed',
                              cluster_method='xi', max_eps=np.inf).fit(distance_mat)
  index_reorder = clustering.ordering_
  labels = clustering.labels_[index_reorder]  # -1 = not assigned to any cluster
  reachability = clustering.reachability_[index_reorder]

  fig = plt.figure(figsize=(8, 8))

  # Reachability plot; shade each cluster across the full width of its columns
  ax1 = fig.add_axes([0.3, 0.71, 0.6, 0.2])
  for cluster_id in np.unique(labels[labels >= 0]):
    members = channel_positions[labels == cluster_id]
    ax1.axvspan(members.min() - 0.5, members.max() + 0.5,
                color=default_colors[cluster_id % len(default_colors)], alpha=0.6)
  ax1.fill_between(channel_positions, reachability, color='#454a4a', zorder=100)
  ax1.set_xlim(*panel_limits)
  ax1.set_ylabel(f'reachability, {distance_label}', fontsize=14)
  ax1.set_ylim(0, None)
  ax1.set_xticks([])

  # NMI of each (reordered) channel with the class label
  ax2 = fig.add_axes([0.09, 0.1, 0.2, 0.6])
  ax2.plot(nmi_with_labels[index_reorder], channel_positions, lw=3)
  ax2.set_ylim(*panel_limits)
  ax2.set_xlim(0, 1)
  ax2.set_xlabel('NMI w class label', fontsize=14)
  ax2.invert_xaxis()
  ax2.invert_yaxis()  # row 0 at the top, matching imshow(origin='upper')
  ax2.set_yticks([])

  # Channel-vs-channel matrix in OPTICS ordering
  axmatrix = fig.add_axes([0.3, 0.1, 0.6, 0.6])
  reordered_mat = display_mat[np.ix_(index_reorder, index_reorder)]
  im = axmatrix.imshow(reordered_mat, aspect='auto', origin='upper', cmap=plt.cm.YlGnBu)
  axmatrix.set_xticks([])
  axmatrix.set_yticks([])

  # Colorbar
  axcolor = fig.add_axes([0.91, 0.1, 0.02, 0.6])
  cb = fig.colorbar(im, cax=axcolor)
  cb.set_label(label=distance_label, fontsize=14)
  fig.suptitle('Fashion-MNIST ensemble structure', y=0.94, fontsize=16)

  plt.show()