In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
tfkl = tf.keras.layers
import matplotlib

import utils

In [None]:
bhat_mats = []  # one Bhattacharyya distance matrix per trained model; appended to by the training loop
ensemble_sizes = [2, 3, 5, 10, 15, 20, 30, 50]
number_repeat_training_runs = 5  ## the number of times to repeat the experiment for statistics

## "Standard candle" is the codename for the set of data points with which to compute the Bhattacharyya matrices:
## `standard_candle_size` evenly spaced points on the unit circle, ordered by angle.
standard_candle_size = 200
standard_candle_angles = np.linspace(0, 2*np.pi, standard_candle_size, endpoint=False)
pts_standard_candle = np.stack([np.cos(standard_candle_angles), np.sin(standard_candle_angles)], -1)

# VAE training hyperparameters.
number_training_steps = 3000
batch_size = 2048
beta = 3e-2  # weight of the KL term relative to the reconstruction loss
learning_rate = 1e-3  # Adam learning rate for VAE training
data_dimensionality = 2
enc_arch_spec = [256]*2  # hidden-layer widths of the encoder MLP
dec_arch_spec = [256]*2  # hidden-layer widths of the decoder MLP
activation_fn = 'tanh'
number_bottleneck_channels = 1

# Train `number_repeat_training_runs * max(ensemble_sizes)` independent VAEs on points drawn
# uniformly from the unit circle; after each one, store the Bhattacharyya distance matrix
# of its encoder's outputs on the standard-candle points.
training_seed = 0  # fixes weight initialization, batch sampling and reparameterization noise
np.random.seed(training_seed)
tf.random.set_seed(training_seed)
loss_plot_interval = 1  # plot loss curves for every n-th trial; increase to cut down notebook output

for trial in range(number_repeat_training_runs*np.max(ensemble_sizes)):
  # The encoder emits 2*number_bottleneck_channels values per point: means, then log-variances.
  encoder = tf.keras.Sequential([tf.keras.Input((data_dimensionality,))] +
                                [tfkl.Dense(number_units, activation_fn) for number_units in enc_arch_spec] +
                                [tfkl.Dense(2*number_bottleneck_channels)])
  decoder = tf.keras.Sequential([tf.keras.Input((number_bottleneck_channels,))] +
                                [tfkl.Dense(number_units, activation_fn) for number_units in dec_arch_spec] +
                                [tfkl.Dense(data_dimensionality)])
  all_trainable_variables = encoder.trainable_variables + decoder.trainable_variables
  optimizer = tf.keras.optimizers.Adam(learning_rate)
  mse = tf.keras.losses.MeanSquaredError()

  @tf.function
  def train_step(pts):
    """Run one Adam step on a batch of points; return (reconstruction MSE, mean KL)."""
    with tf.GradientTape() as tape:
      embs_mus, embs_logvars = tf.split(encoder(pts), 2, axis=-1)
      # Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over channels, averaged over the batch.
      kl = tf.reduce_mean(tf.reduce_sum(0.5 * (tf.square(embs_mus) + tf.exp(embs_logvars) - embs_logvars - 1.), axis=-1), axis=0)
      # Reparameterized sample. tf.shape (rather than the static .shape) stays valid
      # if the function is ever traced with an unknown batch dimension.
      reparameterized_embs = tf.random.normal(tf.shape(embs_mus), mean=embs_mus, stddev=tf.exp(embs_logvars/2.))
      recon = decoder(reparameterized_embs)
      reconstruction_loss = mse(pts, recon)
      loss = tf.reduce_mean(reconstruction_loss) + beta * kl
    grads = tape.gradient(loss, all_trainable_variables)
    optimizer.apply_gradients(zip(grads, all_trainable_variables))
    return reconstruction_loss, kl

  recon_loss_series, kl_loss_series = [], []
  for step in range(number_training_steps):
    batch_theta = np.random.uniform(0, 2*np.pi, size=batch_size)
    # float32 so the batch matches the model's dtype instead of relying on implicit casts.
    batch_pts = np.stack([np.cos(batch_theta), np.sin(batch_theta)], -1).astype(np.float32)
    recon_loss, kl = train_step(batch_pts)
    recon_loss_series.append(recon_loss.numpy())
    kl_loss_series.append(kl.numpy())

  if trial % loss_plot_interval == 0:
    fig, recon_ax = plt.subplots(figsize=(8, 5))
    recon_ax.plot(recon_loss_series, 'k', lw=2)
    recon_ax.set_ylabel('Recon', fontsize=15)
    recon_ax.set_xlabel('Step', fontsize=15)
    kl_ax = recon_ax.twinx()  # KL on its own y-axis since its scale differs from the reconstruction loss
    kl_ax.plot(kl_loss_series, lw=2)
    kl_ax.set_ylabel('KL', color='b', fontsize=15)
    plt.show()

  standard_candle_embs = encoder(pts_standard_candle)
  mus, logvars = tf.split(standard_candle_embs, 2, -1)
  bhat_mats.append(utils.bhattacharyya_dist_mat(mus, logvars))

  print('Computed bhat mat, onto the next run.')

In [None]:
@tf.function
def compute_vi_similarity(bhat1, bhat2):
  """Similarity score between two Bhattacharyya distance matrices (named after variation of information).

  Each intermediate term is -mean_i logsumexp_j(M[i, j]), where M is a combination of
  the negated inputs (each doubled, or summed together). The result is
  exp(-(2*t12 - t11 - t22)).

  NOTE(review): the VI interpretation comes from the function's name only — confirm
  against the method this notebook implements.

  Args:
    bhat1, bhat2: pairwise Bhattacharyya distance matrices; presumably (N, N) over the
      same points in the same order (assumed — confirm against utils).
  """
  def neg_mean_lse(log_terms):
    # Row-wise logsumexp, averaged over rows, negated.
    return -tf.reduce_mean(tf.math.reduce_logsumexp(log_terms, axis=1))

  t11 = neg_mean_lse(-bhat1*2)
  t22 = neg_mean_lse(-bhat2*2)
  t12 = neg_mean_lse(-bhat1-bhat2)
  exponent = t12*2 - t11 - t22
  return tf.exp(-exponent)

@tf.function
def compute_nmi(bhat1, bhat2):
  """Normalized similarity score between two Bhattacharyya distance matrices (named after NMI).

  Each intermediate term is -mean_i logsumexp_j(M[i, j]) for M built from the negated
  inputs. The result is (t1 + t2 - t12) / sqrt((2*t1 - t11) * (2*t2 - t22)).

  NOTE(review): the normalized-mutual-information interpretation comes from the
  function's name only — confirm against the method this notebook implements.

  Args:
    bhat1, bhat2: pairwise Bhattacharyya distance matrices; presumably (N, N) over the
      same points in the same order (assumed — confirm against utils).
  """
  def neg_mean_lse(log_terms):
    # Row-wise logsumexp, averaged over rows, negated.
    return -tf.reduce_mean(tf.math.reduce_logsumexp(log_terms, axis=1))

  t1 = neg_mean_lse(-bhat1)
  t2 = neg_mean_lse(-bhat2)
  t11 = neg_mean_lse(-bhat1*2)
  t22 = neg_mean_lse(-bhat2*2)
  t12 = neg_mean_lse(-bhat1-bhat2)
  numerator = t1+t2-t12
  denominator = tf.sqrt((2*t1-t11)*(2*t2-t22))
  return numerator / denominator

In [None]:
def uniformity_metric(mus):
  """Ratio of RMS to mean Euclidean distance between consecutive rows of `mus`.

  The rows are treated as a closed loop: the last row is also adjacent to the first.
  The value equals 1 when every consecutive gap has the same length and exceeds 1
  otherwise (RMS >= mean). If all gaps are zero the result is nan (0/0).

  Args:
    mus: array of shape (num_points, dim), rows in loop order.
  Returns:
    The RMS-to-mean gap ratio as a numpy float.
  """
  previous = np.roll(mus, shift=1, axis=0)  # row i of `previous` is row i-1 of `mus` (cyclically)
  gaps = np.linalg.norm(mus - previous, ord=2, axis=-1)
  rms_gap = np.sqrt(np.mean(np.square(gaps)))
  return rms_gap / np.mean(gaps)

In [None]:
# Re-embed the standard-candle points: optimize a fresh set of Gaussian embeddings
# (means and log-variances) so that their Bhattacharyya matrix is, on average over an
# ensemble of trained encoders, as similar as possible to the encoders' matrices.
re_emb_dim = 2
number_opt_steps = 20000
snapshot_interval = 1000  # optimization steps between intermediate plots of the re-embedding
uniformity_metric_values = []
# Deliberately not named `learning_rate`: reusing that name would clobber the Adam
# learning rate of the VAE training cell if that cell were re-run afterwards.
reembedding_learning_rate = 3e0
max_ensemble_size = np.max(ensemble_sizes)
# Each repeat draws from its own disjoint slice of `max_ensemble_size` models; without
# enough matrices the slices below would silently come up short.
assert len(bhat_mats) >= number_repeat_training_runs * max_ensemble_size, (
    f'Expected at least {number_repeat_training_runs * max_ensemble_size} Bhattacharyya matrices, '
    f'found {len(bhat_mats)}; re-run the training cell to completion.')


def plot_reembedding(emb_mus, emb_logvars, imshow_cmap=None, hide_scatter_ticks=False, title=None):
  """Show exp(-Bhattacharyya matrix) of the embeddings next to a scatter of their means.

  Scatter points are coloured by their index (hsv colormap), i.e. by position in the
  standard candle. `imshow_cmap=None` uses matplotlib's default colormap.
  """
  bhat_mat_current = utils.bhattacharyya_dist_mat_tf(emb_mus, emb_logvars)
  plt.figure(figsize=(10, 5))
  plt.subplot(121)
  plt.imshow(tf.exp(-bhat_mat_current), vmin=0, vmax=1, cmap=imshow_cmap)
  plt.axis('off')
  plt.subplot(122)
  plt.scatter(*emb_mus.numpy().T, c=np.linspace(0, 1, standard_candle_size), cmap='hsv')
  if hide_scatter_ticks:
    plt.xticks([]); plt.yticks([])
  if title is not None:
    plt.suptitle(title)
  plt.show()


for sim_method_name, sim_method in zip(['vi', 'nmi'], [compute_vi_similarity, compute_nmi]):
  for ensemble_size in ensemble_sizes:
    for repeat_ind in range(number_repeat_training_runs):
      mus_var = tf.Variable(tf.random.normal((standard_candle_size, re_emb_dim), stddev=0.05), trainable=True)
      logvars_var = tf.Variable(tf.zeros((standard_candle_size, re_emb_dim)), trainable=True)
      all_trainable_variables = [mus_var, logvars_var]
      first_model_ind = repeat_ind * max_ensemble_size
      bhats_to_use = bhat_mats[first_model_ind:first_model_ind + ensemble_size]
      optimizer = tf.keras.optimizers.SGD(reembedding_learning_rate)

      @tf.function
      def compute_sum_dist():
        """One SGD step maximizing the mean similarity to the ensemble; returns the negated summed similarity."""
        with tf.GradientTape() as tape:
          bhat_mat_opt = utils.bhattacharyya_dist_mat_tf(mus_var, logvars_var)
          total_dist = 0.
          for other_bhat in bhats_to_use:  # unrolled at trace time
            dist = sim_method(bhat_mat_opt, tf.cast(other_bhat, tf.float32))
            total_dist -= dist
          loss = total_dist / ensemble_size
        grads = tape.gradient(loss, all_trainable_variables)
        optimizer.apply_gradients(zip(grads, all_trainable_variables))
        return total_dist

      run_label = f'{sim_method_name}, ensemble size {ensemble_size}, repeat {repeat_ind}'
      dist_series = []
      for opt_step in range(number_opt_steps):
        total_dist = compute_sum_dist()
        dist_series.append(float(total_dist))  # plain floats; don't keep tensors alive
        if (opt_step % snapshot_interval) == 0:
          plot_reembedding(mus_var, logvars_var, title=f'{run_label}, step {opt_step}')

      plt.figure(figsize=(7, 4))
      plt.plot(dist_series)
      plt.ylabel('Distance', fontsize=15)
      plt.xlabel('Step', fontsize=15)
      plt.title(run_label)
      plt.show()

      plot_reembedding(mus_var, logvars_var, imshow_cmap='Blues_r', hide_scatter_ticks=True,
                       title=f'{run_label}, final')

      uniformity_metric_values.append(uniformity_metric(mus_var.numpy()))



In [None]:
# Summary: deviation of each re-embedding from evenly spaced neighbours (uniformity - 1)
# as a function of ensemble size, one errorbar series per similarity method.
# Must match the method order used when filling `uniformity_metric_values`.
similarity_method_names = ['vi', 'nmi']
method_markers = ['o', 's']
# Explicit shape (not -1): fails loudly if the optimization cell did not finish,
# instead of silently mis-assigning values to ensemble sizes.
uniformity_by_method = np.reshape(
    uniformity_metric_values,
    [len(similarity_method_names), len(ensemble_sizes), number_repeat_training_runs])

fig, ax = plt.subplots(figsize=(4, 6))
for method_ind, sim_method_name in enumerate(similarity_method_names):
  method_values = uniformity_by_method[method_ind]
  ax.errorbar(ensemble_sizes, np.mean(method_values, -1)-1., yerr=np.std(method_values, -1),
              ls='-', marker=method_markers[method_ind], markersize=12, lw=4, label=sim_method_name)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xticks(ensemble_sizes)
ax.set_xticklabels(ensemble_sizes)
ax.set_xlabel('Ensemble size', fontsize=15)
ax.set_ylabel('Uniformity metric-1', fontsize=15)
ax.set_title('Re-embedding uniformity vs. ensemble size')
ax.legend()
plt.show()