In [None]:
'''
This notebook reproduces the results of Fig. 2, in which similarity measures
are compared across 9 synthetic representation spaces.

To evaluate the stochastic shape metrics, please install the netrep package:
https://github.com/ahwillia/netrep
'''

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
tfkl = tf.keras.layers
from sklearn import cluster
import os, time
import PIL

import tensorflow_datasets as tfds
from matplotlib.gridspec import GridSpec
import scipy.ndimage as nim

from matplotlib.patches import Ellipse

from netrep.metrics import GaussianStochasticMetric

# Ten hex colour strings available for plotting.
# NOTE(review): these look like matplotlib's default 'tab10' cycle -- confirm.
default_colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
]
# Number of data points in each example; the CKA centering matrix further down
# is built with this size.
fingerprint_size = 64  ## the size of the dataset used in these examples

In [None]:
#@title Monte Carlo information evaluation
def monte_carlo_info(mus, logvars, number_random_samples=10):
  """Monte Carlo estimate, in bits, of the information carried by a Gaussian embedding.

  Each data point is represented by a diagonal Gaussian with mean `mus[i]` and
  log variance `logvars[i]`. For every repetition, 2000 data points are drawn
  (with replacement), one embedding sample is taken from each of their
  Gaussians, and the log ratio between the density under the point's own
  Gaussian and the average density over all Gaussians is recorded.

  Args:
    mus: [N, d] array of Gaussian means.
    logvars: [N, d] array of log variances (diagonal covariances).
    number_random_samples: number of independent repetitions to average over.
  Returns:
    The mean log ratio over all draws and repetitions, converted to bits.
  """
  points_per_draw = 2000
  marginal_chunk_size = 10_000
  num_points = mus.shape[0]
  latent_dim = mus.shape[-1]
  log_ratios = []
  for _ in range(number_random_samples):
    picked = np.random.choice(num_points, size=points_per_draw)
    picked_mus = mus[picked]
    picked_logvars = logvars[picked]
    draws = tf.random.normal(shape=(points_per_draw, latent_dim),
                             mean=picked_mus,
                             stddev=tf.exp(picked_logvars/2.),
                             dtype=tf.float64)
    # Density of every draw under the Gaussian it was generated from.
    own_density = compute_likelihoods(draws, picked_mus, picked_logvars, diag=True)
    # Density of every draw under the full mixture, accumulated chunk by chunk
    # so the pairwise tensor stays small.
    mixture_density = np.zeros((points_per_draw))
    for chunk_start in range(0, num_points, marginal_chunk_size):
      chunk_end = min(chunk_start+marginal_chunk_size, num_points)
      mixture_density = mixture_density + compute_likelihoods(
          draws, mus[chunk_start:chunk_end], logvars[chunk_start:chunk_end])
    mixture_density = mixture_density / num_points

    log_ratios.append(tf.math.log(own_density/mixture_density))
  # Natural log -> bits.
  return np.mean(log_ratios)/np.log(2)


@tf.function(experimental_relax_shapes=True)
def compute_likelihoods(samples, mus, logvars, diag=False):
  """Evaluates diagonal-Gaussian densities of samples under a set of Gaussians.

  Args:
    samples: [S, d] float64 tensor of points to evaluate (not cast here).
    mus: [B, d] means of the Gaussians; cast to float64.
    logvars: [B, d] log variances of the Gaussians; cast to float64.
    diag: if True, return only the density of sample i under Gaussian i
      (the diagonal of the [S, B] density matrix).
  Returns:
    [S] tensor: the diagonal of the pairwise density matrix when `diag` is
    True, otherwise the densities summed over all B Gaussians.
  """
  mus = tf.cast(mus, tf.float64)
  logvars = tf.cast(logvars, tf.float64)
  num_samples = tf.shape(samples)[0]
  num_gaussians = tf.shape(mus)[0]
  dim = tf.shape(mus)[-1]
  per_gaussian_shape = [1, num_gaussians, dim]

  # Broadcast to [S, B, d]: offset of every sample from every Gaussian centre,
  # measured in units of that Gaussian's standard deviation.
  offsets = tf.reshape(samples, [num_samples, 1, dim]) - tf.reshape(mus, per_gaussian_shape)
  scaled_offsets = offsets / tf.reshape(tf.exp(logvars/2.), per_gaussian_shape)

  squared_norms = tf.reduce_sum(scaled_offsets**2, axis=-1)
  log_determinants = tf.reshape(tf.reduce_sum(logvars, axis=-1), [1, num_gaussians])
  densities = tf.exp(-squared_norms/2. - log_determinants/2.)
  # (2 pi)^(d/2) normalisation of a d-dimensional Gaussian.
  densities = densities / (2.*np.pi)**(tf.cast(dim, tf.float64)/2.)

  if diag:
    return tf.linalg.diag_part(densities)
  return tf.reduce_sum(densities, axis=-1)

In [None]:
#@title Bhattacharyya-based information evaluation
def bhattacharyya_dist_mat(mus, logvars):
  """Computes Bhattacharyya distances between multivariate Gaussians.
  The Bhattacharyya coefficient is the exponentiated negative distance.

  Because the covariances are diagonal, the quadratic form reduces to a
  variance-weighted sum of squared mean differences, so no [N, N, d, d]
  matrices need to be built.

  Args:
    mus: [N, d] float array of the means of the Gaussians.
    logvars: [N, d] float array of the log variances of the Gaussians (so we're assuming diagonal
    covariance matrices; these are the logs of the diagonal).
  Returns:
    [N, N] array of distances.
  """
  mus = np.asarray(mus)
  logvars = np.asarray(logvars)
  variances = np.exp(logvars)

  # Pairwise quantities via broadcasting: [N, 1, d] against [1, N, d] -> [N, N, d].
  difference_mus = mus[:, np.newaxis] - mus[np.newaxis]
  sigma_diag = 0.5 * (variances[:, np.newaxis] + variances[np.newaxis])

  # (mu_i - mu_j)^T Sigma^{-1} (mu_i - mu_j) / 8 for the diagonal mean covariance Sigma.
  term1 = 0.125 * np.sum(difference_mus**2 / sigma_diag, axis=-1)

  # log-determinants are sums of logs of the diagonal entries.
  log_determinant_sigma = np.sum(np.log(sigma_diag), axis=-1)
  log_determinants = np.sum(logvars, axis=-1)  # [N]
  term2 = 0.5 * (log_determinant_sigma
                 - 0.5 * (log_determinants[:, np.newaxis] + log_determinants[np.newaxis]))
  return term1 + term2

@tf.function(experimental_relax_shapes=True)
def bhattacharyya_dist_mat_tf(mus, logvars):
  """Computes Bhattacharyya distances between multivariate Gaussians.

  Because the covariances are diagonal, the quadratic form reduces to a
  variance-weighted sum of squared mean differences, so no [N, N, d, d]
  matrices need to be built.

  Args:
    mus: [N, d] float array of the means of the Gaussians.
    logvars: [N, d] float array of the log variances of the Gaussians (so we're assuming diagonal
    covariance matrices; these are the logs of the diagonal).
  Returns:
    [N, N] float64 tensor of distances.
  """
  mus = tf.cast(mus, tf.float64)
  logvars = tf.cast(logvars, tf.float64)
  variances = tf.exp(logvars)

  # Pairwise quantities via broadcasting: [N, 1, d] against [1, N, d] -> [N, N, d].
  difference_mus = tf.expand_dims(mus, 1) - tf.expand_dims(mus, 0)
  sigma_diag = 0.5 * (tf.expand_dims(variances, 1) + tf.expand_dims(variances, 0))

  # (mu_i - mu_j)^T Sigma^{-1} (mu_i - mu_j) / 8 for the diagonal mean covariance Sigma.
  term1 = 0.125 * tf.reduce_sum(difference_mus**2 / sigma_diag, axis=-1)

  # log-determinants are sums of logs of the diagonal entries.
  log_determinant_sigma = tf.reduce_sum(tf.math.log(sigma_diag), axis=-1)
  log_determinants = tf.reduce_sum(logvars, axis=-1)  # [N]
  term2 = 0.5 * (log_determinant_sigma
                 - 0.5 * (tf.expand_dims(log_determinants, 1) + tf.expand_dims(log_determinants, 0)))
  return term1 + term2

@tf.function(experimental_relax_shapes=True)
def bhat_info_tf(mus, logvars):
  """Bhattacharyya-based information estimate (in nats) for a Gaussian embedding.

  Args:
    mus: [N, d] means of the Gaussians.
    logvars: [N, d] log variances of the Gaussians (diagonal covariances).
  Returns:
    Scalar tensor: minus the average over data points of the log of each
    point's mean Bhattacharyya coefficient with all points.
  """
  bhat_dist_mat = bhattacharyya_dist_mat_tf(mus, logvars)
  # The inner mean must run over one axis only (per data point); averaging the
  # whole matrix before the log would make the outer mean a no-op and would
  # disagree with the single-space term in compute_nmi_bhat_tf.
  info = -tf.reduce_mean(tf.math.log(tf.reduce_mean(tf.exp(-bhat_dist_mat), axis=1)))
  return info

@tf.function
def compute_nmi_bhat_tf(bhat1, bhat2):
  """Normalised mutual information between two spaces from their Bhattacharyya matrices.

  Args:
    bhat1: [N, N] Bhattacharyya distance matrix of the first space.
    bhat2: [N, N] Bhattacharyya distance matrix of the second space.
  Returns:
    Scalar tensor (i1 + i2 - i12) / sqrt((2*i1 - i11) * (2*i2 - i22)), where
    each term is the information estimate of the corresponding distance matrix
    (i11, i22 use the doubled matrices, i12 uses their sum).
  """
  def info_from_distances(dist_mat):
    # -mean_i log mean_j exp(-dist_mat[i, j])
    return -tf.reduce_mean(tf.math.log(tf.reduce_mean(tf.exp(-dist_mat), axis=1)))

  info_a = info_from_distances(bhat1)
  info_b = info_from_distances(bhat2)
  info_aa = info_from_distances(bhat1*2)
  info_bb = info_from_distances(bhat2*2)
  info_ab = info_from_distances(bhat1+bhat2)

  shared = info_a+info_b-info_ab
  normaliser = tf.sqrt((2*info_a-info_aa)*(2*info_b-info_bb))
  return shared / normaliser

@tf.function
def compute_vi_bhat_tf(bhat1, bhat2):
  """Variation of information between two spaces from their Bhattacharyya matrices.

  Args:
    bhat1: [N, N] Bhattacharyya distance matrix of the first space.
    bhat2: [N, N] Bhattacharyya distance matrix of the second space.
  Returns:
    Scalar tensor 2*i12 - i11 - i22, where i11 and i22 are the information
    estimates of the doubled matrices and i12 that of their sum.
  """
  def info_from_distances(dist_mat):
    # -mean_i log mean_j exp(-dist_mat[i, j])
    return -tf.reduce_mean(tf.math.log(tf.reduce_mean(tf.exp(-dist_mat), axis=1)))

  info_aa = info_from_distances(bhat1*2)
  info_bb = info_from_distances(bhat2*2)
  info_ab = info_from_distances(bhat1+bhat2)
  return 2*info_ab - info_aa - info_bb

In [None]:
#@title CKA with Bhattacharyya matrices

# Centering matrix H = I - (1/n) * ones((n, n)) with n = fingerprint_size.
centering_matrix = np.identity(fingerprint_size) - np.full((fingerprint_size, fingerprint_size), 1./fingerprint_size)
def compute_cka_bhat(bhat1, bhat2):
  """Centered kernel alignment (CKA) between two square similarity matrices.

  The centering matrix is built from the size of the inputs, so the function
  works for any number of data points (not only `fingerprint_size`).

  Args:
    bhat1: [n, n] similarity matrix of the first space.
    bhat2: [n, n] similarity matrix of the second space.
  Returns:
    tr(K1 H K2 H) / sqrt(tr(K1 H K1 H) * tr(K2 H K2 H)) with H = I - ones/n.
  """
  bhat1 = np.asarray(bhat1)
  bhat2 = np.asarray(bhat2)
  n = bhat1.shape[0]
  centering = np.eye(n) - np.ones((n, n))/n
  # Each trace contains the factors K H, so compute them once.
  centered1 = bhat1 @ centering
  centered2 = bhat2 @ centering
  sim11 = np.trace(centered1 @ centered1)
  sim22 = np.trace(centered2 @ centered2)
  sim12 = np.trace(centered1 @ centered2)
  cka = sim12 / np.sqrt(sim11*sim22)
  return cka

@tf.function
def compute_cka_bhat_tf(bhat1, bhat2):
  """Centered kernel alignment (CKA) between two square similarity matrices (TensorFlow).

  The centering matrix is built from the shape and dtype of `bhat1`, so the
  function works for any number of data points (not only `fingerprint_size`).

  Args:
    bhat1: [n, n] similarity matrix of the first space.
    bhat2: [n, n] similarity matrix of the second space (same dtype as bhat1).
  Returns:
    Scalar tensor tr(K1 H K2 H) / sqrt(tr(K1 H K1 H) * tr(K2 H K2 H)) with
    H = I - ones/n.
  """
  bhat1 = tf.convert_to_tensor(bhat1)
  bhat2 = tf.convert_to_tensor(bhat2)
  n = tf.shape(bhat1)[0]
  centering = tf.eye(n, dtype=bhat1.dtype) - tf.ones([n, n], dtype=bhat1.dtype) / tf.cast(n, bhat1.dtype)
  # Each trace contains the factors K H, so compute them once.
  centered1 = bhat1 @ centering
  centered2 = bhat2 @ centering
  sim11 = tf.linalg.trace(centered1 @ centered1)
  sim22 = tf.linalg.trace(centered2 @ centered2)
  sim12 = tf.linalg.trace(centered1 @ centered2)
  cka = sim12 / tf.sqrt(sim11*sim22)
  return cka


In [None]:
#@title Generate the 9 representation spaces
# Number of data points in every synthetic space. It was previously undefined
# (NameError on a fresh kernel); it has to equal `fingerprint_size` because the
# CKA centering matrix is built for that many points.
N = fingerprint_size
sqrt_N = int(np.round(np.sqrt(N)))

cmap = plt.get_cmap('viridis')
alpha = 0.5

x = np.arange(N)

u_mus_all, u_logvars_all = [[], []]

constant_variance_offset = 0.1

# Local, seeded generator so the jittered cluster means (and hence the figure)
# are reproducible without touching NumPy's global random state.
cluster_jitter_rng = np.random.default_rng(0)


def _square_spiral_means(num_points):
  """Returns the [num_points, 2] positions of an outward square spiral walk from the origin."""
  # float32 on purpose: keeps the step products at the same precision as before.
  movements = np.float32([[0, 1],
                          [1, 0],
                          [0, -1],
                          [-1, 0]])
  position = np.zeros(2)
  positions = [position.copy()]
  step_size = 0.5
  movement_ind = 0
  step_ind = 0
  side_length = 1
  for _ in range(num_points-1):
    position += step_size*movements[movement_ind]
    step_ind += 1
    if step_ind == side_length:
      # Turn the corner, lengthen the stride, and grow the side every second turn.
      step_ind = 0
      movement_ind = (movement_ind+1) % 4
      step_size += 0.1
      if not(movement_ind % 2):
        side_length += 1
    positions.append(position.copy())
  return np.array(positions)


###### Spiral: constant variance

spiral_freq = 0.2
spiral_phase = 2*np.pi*np.sqrt(x)*spiral_freq
u_mus = np.sqrt(x).reshape([-1, 1])*np.stack([np.cos(spiral_phase), np.sin(spiral_phase)], -1)
u_logvars = np.zeros_like(u_mus) + constant_variance_offset

u_mus_all.append(u_mus)
u_logvars_all.append(u_logvars)

###### Bloated spiral

u_mus = u_mus_all[0].copy()
u_logvars = u_logvars_all[0].copy() + 1.25

u_mus_all.append(u_mus)
u_logvars_all.append(u_logvars)


###### Bloated bloated spiral

u_mus = u_mus_all[0].copy()
u_logvars = u_logvars_all[0].copy() + 2.5

u_mus_all.append(u_mus)
u_logvars_all.append(u_logvars)


###### Square spiral

u_mus = _square_spiral_means(N)
u_logvars = np.zeros_like(u_mus)

u_mus_all.append(u_mus)
u_logvars_all.append(u_logvars)


########### more variance (same square-spiral means, larger variance)

u_mus = _square_spiral_means(N)
u_logvars = np.zeros_like(u_mus) + 2.5

u_mus_all.append(u_mus)
u_logvars_all.append(u_logvars)

###### 1D line

u_mus = np.linspace(-N/2, N/2, N).reshape([-1, 1])
u_logvars = np.zeros_like(u_mus)-1.

u_mus_all.append(u_mus)
u_logvars_all.append(u_logvars)

###### Discrete: two

u_mus = np.concatenate([np.ones((N//2, 2))*[[-sqrt_N*0.8, 0]],
                        np.ones((N//2, 2))*[[sqrt_N*0.8, 0]]], 0)
u_mus = u_mus + cluster_jitter_rng.standard_normal((N, 2))

u_logvars = np.zeros_like(u_mus)+2

u_mus_all.append(u_mus)
u_logvars_all.append(u_logvars)

###### Discrete: four (no jitter: all points of a cluster share one mean)

corner = sqrt_N*0.7
quadrant_centers = [[-corner, corner],
                    [corner, corner],
                    [corner, -corner],
                    [-corner, -corner]]
u_mus = np.concatenate([np.ones((N//4, 2))*[center] for center in quadrant_centers], 0)
u_logvars = np.zeros_like(u_mus)+2

u_mus_all.append(u_mus)
u_logvars_all.append(u_logvars)

###### Discrete: four, clusters visited in a different order and with a larger vertical variance

quadrant_centers = [[-corner, corner],
                    [-corner, -corner],
                    [corner, -corner],
                    [corner, corner]]
u_mus = np.concatenate([np.ones((N//4, 2))*[center] for center in quadrant_centers], 0)

vert_variance = 3.

u_logvars = np.ones((N, 2))*[[0, vert_variance]]+1
u_mus_all.append(u_mus)
u_logvars_all.append(u_logvars)


###### Plot all spaces on a 3x3 grid

fig, axes = plt.subplots(3, 3, figsize=(8, 8))
for plt_ind, (ax, mus, logvars) in enumerate(zip(axes.flat, u_mus_all, u_logvars_all)):
  if mus.shape[1] == 2:
    # Panels from index 6 on are the discrete clusters: opaque ellipses, wider view.
    is_cluster_panel = plt_ind >= 6
    ellipse_alpha = 1 if is_cluster_panel else alpha
    axis_limit = sqrt_N*(2 if is_cluster_panel else 1.5)
    for i in range(N):
      # One-standard-deviation ellipse for data point i, coloured by its index.
      ell = Ellipse(xy=mus[i],
                    width=2*np.exp(logvars[i, 0]/2.), height=2*np.exp(logvars[i, 1]/2.),
                    facecolor=cmap(i/(N-1)), alpha=ellipse_alpha, edgecolor='k')
      ax.add_artist(ell)
    ax.set_ylim(-axis_limit, axis_limit)
    ax.set_xlim(-axis_limit, axis_limit)
  else:
    # 1D space: draw the densities of a handful of neighbouring points.
    plt_x = np.linspace(-N/2-20, N/2+20, 10000)
    for i in range(20, min(40, N)):
      sig = np.exp(logvars[i]/2.)
      plt_y = np.exp(-np.power((plt_x - mus[i]) / sig, 2.0) / 2) /  (np.sqrt(2.0 * np.pi) * sig)
      ax.plot(plt_x, plt_y, lw=4, color=cmap(i/(N-1)))
    ax.set_xlim(-3, 3)
  ax.axis('off')
fig.tight_layout()

plt.show()

In [None]:
# Now compute the pairwise similarities
num_spaces = len(u_mus_all)
nmis_bhat, vis_bhat = [np.eye(num_spaces), np.eye(num_spaces)]
cka_bhat, cka_reg = [np.eye(num_spaces), np.eye(num_spaces)]
nmis_mc, vis_mc = [np.eye(num_spaces), np.eye(num_spaces)]
number_mc_random_samples = 5

# Per-space quantities depend on one space only, so compute them once here
# instead of once per pair: the MC information of the space, the MC information
# of the space concatenated with itself, and its Bhattacharyya distance matrix.
mc_infos, mc_doubled_infos, bhat_mats = [], [], []
for space_mus, space_logvars in zip(u_mus_all, u_logvars_all):
  mc_infos.append(
      monte_carlo_info(space_mus, space_logvars, number_random_samples=number_mc_random_samples))
  mc_doubled_infos.append(
      monte_carlo_info(np.tile(space_mus, [1, 2]),
                       np.tile(space_logvars, [1, 2]), number_random_samples=number_mc_random_samples))
  bhat_mats.append(bhattacharyya_dist_mat_tf(space_mus, space_logvars))

# Information of each space as a fraction of the log2(N) bits needed to identify a data point.
fractional_infos = [info/np.log2(N) for info in mc_infos]

for embedding_space_ind1 in range(num_spaces):
  u_mus = u_mus_all[embedding_space_ind1]
  u_logvars = u_logvars_all[embedding_space_ind1]
  i1 = mc_infos[embedding_space_ind1]
  i11 = mc_doubled_infos[embedding_space_ind1]
  bhat1 = bhat_mats[embedding_space_ind1]
  for embedding_space_ind2 in range(embedding_space_ind1, num_spaces):
    v_mus = u_mus_all[embedding_space_ind2]
    v_logvars = u_logvars_all[embedding_space_ind2]
    i2 = mc_infos[embedding_space_ind2]
    i22 = mc_doubled_infos[embedding_space_ind2]
    bhat2 = bhat_mats[embedding_space_ind2]

    # Information of the joint space (the two embeddings concatenated).
    i12 = monte_carlo_info(np.concatenate([u_mus, v_mus], 1),
                           np.concatenate([u_logvars, v_logvars], 1), number_random_samples=number_mc_random_samples)

    nmi = (i1+i2-i12) / np.sqrt((2*i1-i11)*(2*i2-i22))
    nmis_mc[embedding_space_ind1, embedding_space_ind2] = nmi
    nmis_mc[embedding_space_ind2, embedding_space_ind1] = nmi

    vi = 2*i12 - i11 - i22
    vis_mc[embedding_space_ind1, embedding_space_ind2] = vi
    vis_mc[embedding_space_ind2, embedding_space_ind1] = vi

    nmi = compute_nmi_bhat_tf(bhat1, bhat2)
    nmis_bhat[embedding_space_ind1, embedding_space_ind2] = nmi
    nmis_bhat[embedding_space_ind2, embedding_space_ind1] = nmi

    vi = compute_vi_bhat_tf(bhat1, bhat2)
    vis_bhat[embedding_space_ind1, embedding_space_ind2] = vi
    vis_bhat[embedding_space_ind2, embedding_space_ind1] = vi

    # CKA operates on the Bhattacharyya coefficients exp(-distance).
    cka = compute_cka_bhat_tf(tf.exp(-bhat1), tf.exp(-bhat2))
    cka_bhat[embedding_space_ind1, embedding_space_ind2] = cka
    cka_bhat[embedding_space_ind2, embedding_space_ind1] = cka


# Stochastic shape metric (netrep): every space is passed as (means, covariances).
one_dim_inds = [ind for ind, space_mus in enumerate(u_mus_all) if space_mus.shape[1] != 2]
Xs = []
for embedding_space_ind in range(num_spaces):
  mus = u_mus_all[embedding_space_ind]
  logvars = u_logvars_all[embedding_space_ind]
  if mus.shape[1] != 2:  ## Since the netrep code does not allow different dimensionalities, just fill the values in with something valid
    mus = u_mus_all[embedding_space_ind-1]
    logvars = u_logvars_all[embedding_space_ind-1]
  # Diagonal covariance matrices: the variances are exp(logvars).
  # (exp(-logvars), used previously, is the precision rather than the covariance.)
  covs = np.apply_along_axis(np.diag, 1, np.exp(logvars))
  X = (mus, covs)
  Xs.append(X)

ct = time.time()
# Named ssm_alpha so it does not overwrite the plotting `alpha` of the previous cell.
ssm_alpha = 1  ## the Wasserstein thing comparing means and covariances
metric = GaussianStochasticMetric(ssm_alpha, init='rand', n_restarts=50)
dist_matrix, _ = metric.pairwise_distances(Xs)
ssm_scores = dist_matrix
## Fill in the NAN values for the spaces that were replaced by a placeholder above
for one_dim_ind in one_dim_inds:
  ssm_scores[one_dim_ind] = np.nan
  ssm_scores[:, one_dim_ind] = np.nan
  ssm_scores[one_dim_ind, one_dim_ind] = 0.
print(f'Computed SSM scores, time taken: {time.time()-ct:.3f} sec')

In [None]:
# Heatmaps of every pairwise (dis)similarity measure, one panel per measure.
labels = ['nmi_mc', 'nmi_bhat', 'cka_bhat', 'vi_mc', 'vi_bhat', 'SSM']
# `ssm_scores` is the name assigned by the previous cell (`SSM_scores` was never defined).
similarities = [nmis_mc, nmis_bhat, cka_bhat, vis_mc, vis_bhat, ssm_scores]
cmaps = ['Blues', 'Blues', 'Blues', 'Reds_r', 'Reds_r', 'Reds_r']
vmins = [0, 0, 0, 0, 0, 0]
vmaxes = [1, 1, 1, 3.2, 3.2, None]

num_spaces = len(u_mus_all)
tick_positions = np.arange(num_spaces)
tick_labels = list('abcdefghi')[:num_spaces]  # one letter per representation space

fig, axes = plt.subplots(2, 3, figsize=(8, 5))
# The loop variable is `cmap_name` so the colormap object called `cmap` in the
# generation cell is not overwritten by a string.
for ax, similarity_values, label, cmap_name, vmin, vmax in zip(axes.flat, similarities, labels, cmaps, vmins, vmaxes):
  image = ax.imshow(np.reshape(similarity_values, [num_spaces, -1]), vmin=vmin, vmax=vmax, cmap=cmap_name)
  fig.colorbar(image, ax=ax)
  ax.set_title(label, fontsize=15)
  ax.set_xticks(tick_positions)
  ax.set_xticklabels(tick_labels)
  ax.set_yticks(tick_positions)
  ax.set_yticklabels(tick_labels)
  ax.tick_params(axis='both', which='both', length=0)
plt.show()