In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
import matplotlib
from matplotlib.patches import Ellipse

import utils

In [2]:
ensemble_sizes = [2, 3, 5, 10, 15, 20, 30, 50]  ## ensemble sizes swept in the fusion experiment
number_repeat_training_runs = 5  ## the number of times to repeat the experiment for statistics
fingerprint_size = 200  ## number of probe points on the unit circle used to fingerprint each model

## Seed every random source used below (numpy training batches, Keras weight init,
## tf.random reparameterization and fusion init) so a fresh Restart & Run All repeats the
## same experiment (up to any nondeterministic GPU kernels).
random_seed = 0
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

In [None]:
#### Train the ensemble of weak learners, beta-VAEs with 1D latent spaces
#### Should take a couple seconds each
#### Then store the fingerprint matrices that we will use for the model fusion afterward
bhat_mats = []

display_training_curves = False

## Fingerprint probe points: `fingerprint_size` angles evenly spaced on [0, 2*pi), mapped onto the
## unit circle. endpoint=False, so the last point neighbors the first instead of duplicating it.
thetas_fingerprint = np.linspace(0, 2*np.pi, fingerprint_size, endpoint=False)
pts_fingerprint = np.stack([np.cos(thetas_fingerprint), np.sin(thetas_fingerprint)], -1)

number_training_steps = 3000
batch_size = 2048
beta = 3e-2  ## weight of the KL term relative to the reconstruction loss
learning_rate = 1e-3
data_dimensionality = 2
enc_arch_spec = [256]*2
dec_arch_spec = [256]*2
activation_fn = 'tanh'
number_bottleneck_channels = 1

## One pool of trained models; the fusion step slices it into `number_repeat_training_runs`
## disjoint groups of max(ensemble_sizes) models each.
total_number_runs = number_repeat_training_runs*np.max(ensemble_sizes)


def plot_vae_training_curves(recon_loss_series, kl_loss_series):
  """Plot per-step reconstruction loss (black, left axis) and KL (right axis) for one model."""
  plt.figure(figsize=(8, 5))
  plt.plot(recon_loss_series, 'k', lw=2)
  plt.ylabel('Recon', fontsize=15)
  plt.xlabel('Step', fontsize=15)
  plt.gca().twinx().plot(kl_loss_series, lw=2)
  plt.ylabel('KL', color='b', fontsize=15)
  plt.show()


for trial in range(total_number_runs):
  ## The encoder emits 2*number_bottleneck_channels values per point: [means, log-variances].
  encoder = tf.keras.Sequential([tf.keras.Input((data_dimensionality,))] +
                                [tf.keras.layers.Dense(number_units, activation_fn) for number_units in enc_arch_spec] +
                                [tf.keras.layers.Dense(2*number_bottleneck_channels)])
  decoder = tf.keras.Sequential([tf.keras.Input((number_bottleneck_channels,))] +
                                [tf.keras.layers.Dense(number_units, activation_fn) for number_units in dec_arch_spec] +
                                [tf.keras.layers.Dense(data_dimensionality)])
  all_trainable_variables = encoder.trainable_variables + decoder.trainable_variables
  optimizer = tf.keras.optimizers.Adam(learning_rate)
  mse = tf.keras.losses.MeanSquaredError()

  @tf.function
  def train_step(pts):
    """Take one Adam step on the beta-VAE loss for a batch; return (reconstruction loss, KL)."""
    with tf.GradientTape() as tape:
      embs_mus, embs_logvars = tf.split(encoder(pts), 2, axis=-1)
      ## Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over latent dims, averaged over batch.
      kl = tf.reduce_mean(tf.reduce_sum(0.5 * (tf.square(embs_mus) + tf.exp(embs_logvars) - embs_logvars - 1.), axis=-1), axis=0)
      ## Reparameterized sample from the posterior (std = exp(logvar / 2)).
      reparameterized_embs = tf.random.normal(embs_mus.shape, mean=embs_mus, stddev=tf.exp(embs_logvars/2.))
      recon = decoder(reparameterized_embs)
      reconstruction_loss = mse(pts, recon)
      loss = tf.reduce_mean(reconstruction_loss) + beta * kl
    grads = tape.gradient(loss, all_trainable_variables)
    optimizer.apply_gradients(zip(grads, all_trainable_variables))
    return reconstruction_loss, kl

  recon_loss_series, kl_loss_series = [], []
  for step in range(number_training_steps):
    ## Fresh batch of points sampled uniformly on the unit circle.
    batch_theta = np.random.uniform(0, 2*np.pi, size=batch_size)
    batch_pts = np.stack([np.cos(batch_theta), np.sin(batch_theta)], -1)
    recon_loss, kl = train_step(batch_pts)
    if display_training_curves:
      ## `.numpy()` copies the value to the host and blocks until the step has finished,
      ## so only pay for it when the curves will actually be plotted.
      recon_loss_series.append(recon_loss.numpy())
      kl_loss_series.append(kl.numpy())

  if display_training_curves:
    plot_vae_training_curves(recon_loss_series, kl_loss_series)

  ## Encode the fingerprint points and keep only their Bhattacharyya distance matrix;
  ## this is all the fusion step uses from each trained model.
  fingerprint_embs = encoder(pts_fingerprint)
  mus, logvars = tf.split(fingerprint_embs, 2, -1)
  bhat_mats.append(utils.bhattacharyya_dist_mat(mus, logvars))

  print(f'Computed bhat mat for run {len(bhat_mats)}/{total_number_runs}.')

In [7]:
def _fingerprint_information(neg_bhat_sum):
  """Fingerprint information estimate: -mean_i log( mean_j exp(neg_bhat_sum[i, j]) ).

  `neg_bhat_sum` is minus a (sum of) Bhattacharyya distance matrices over the N fingerprint
  points, e.g. -bhat1, -2*bhat1 or -bhat1-bhat2. The estimate is 0 when every entry is 0
  (all posteriors identical) and log(N) when the diagonal is 0 and every off-diagonal entry
  is -inf (posteriors fully separated).

  This is a log-MEAN-exp. A bare logsumexp shifts every estimate by -log(N). That shift
  cancels in VI and only offsets the mutual information by a constant. In the NMI ratio it
  does not cancel: it makes NMI non-positive (exactly -1 for any representation compared
  with itself unless it fully separates the points).

  The reduction runs in float64. In the nearly uninformative regime (e.g. a freshly
  initialized fusion space), the estimates are tiny differences between numbers close to
  log(N), which float32 cannot resolve. The result is cast back to the input dtype.
  """
  x = tf.cast(neg_bhat_sum, tf.float64)
  log_n = tf.math.log(tf.cast(tf.shape(x)[1], tf.float64))
  info = -tf.reduce_mean(tf.math.reduce_logsumexp(x, axis=1) - log_n)
  return tf.cast(info, neg_bhat_sum.dtype)


@tf.function
def compute_vi_similarity(bhat1, bhat2):
  """Return exp(-VI) between two representations from their fingerprint Bhattacharyya matrices.

  VI = H1 + H2 - 2*I(U1;U2). Estimate each H_k as the information shared by two independent
  copies of representation k (kernel exp(-2*bhat_k)) and I(U1;U2) via the joint kernel
  exp(-bhat1-bhat2). With those estimates, VI reduces to 2*i12 - i11 - i22.
  """
  i11 = _fingerprint_information(-bhat1*2)
  i22 = _fingerprint_information(-bhat2*2)
  i12 = _fingerprint_information(-bhat1-bhat2)
  return tf.exp(-(i12*2 - i11 - i22))


@tf.function
def compute_nmi(bhat1, bhat2):
  """Return NMI = I(U1;U2) / sqrt(I(U1;U1) * I(U2;U2)) from fingerprint Bhattacharyya matrices.

  I(U1;U2) is estimated as i1 + i2 - i12. The self-information I(Uk;Uk) is the same
  expression with two copies of representation k, i.e. 2*ik - ikk. Comparing a
  representation with itself gives 1.
  """
  i1 = _fingerprint_information(-bhat1)
  i2 = _fingerprint_information(-bhat2)
  i11 = _fingerprint_information(-bhat1*2)
  i22 = _fingerprint_information(-bhat2*2)
  i12 = _fingerprint_information(-bhat1-bhat2)
  return (i1+i2-i12) / tf.sqrt((2*i1-i11)*(2*i2-i22))


@tf.function
def compute_info(bhat1, bhat2):
  """Return the mutual-information estimate I(U1;U2) = i1 + i2 - i12 from fingerprint matrices."""
  i1 = _fingerprint_information(-bhat1)
  i2 = _fingerprint_information(-bhat2)
  i12 = _fingerprint_information(-bhat1-bhat2)
  return i1+i2-i12

In [8]:
def continuity_metric(mus, logvars, percentile=90, periodic=True):
  """Ratio of the largest to the `percentile`-th percentile Bhattacharyya distance between neighbors.

  Neighbors are consecutive fingerprint points (row i and row i+1 of `mus`/`logvars`). With
  uniform neighbor distances the ratio is 1. A tear in the representation shows up as one
  neighbor distance far above the rest.

  Args:
    mus, logvars: per-fingerprint-point posterior means and log-variances, one row per point
      in fingerprint order (assumed (N, d) arrays accepted by `utils.bhattacharyya_dist_mat`).
    percentile: percentile of the neighbor distances used as the reference scale.
    periodic: also treat the last and first points as neighbors. The fingerprint points are
      evenly spaced around a closed circle (endpoint excluded), so point N-1 is adjacent to
      point 0. Without this pair, a tear located exactly at that seam goes undetected.

  Returns:
    max neighbor distance / `percentile`-th percentile neighbor distance.
  """
  bhat = np.asarray(utils.bhattacharyya_dist_mat(mus, logvars))
  adj_bhat_distances = np.diag(bhat, k=1)
  if periodic:
    adj_bhat_distances = np.append(adj_bhat_distances, bhat[-1, 0])
  return np.max(adj_bhat_distances) / np.percentile(adj_bhat_distances, percentile)

In [None]:
display_periodically_during_training = False

fusion_space_dimensionality = 2
number_opt_steps = 20000
continuity_metric_values = []
fusion_learning_rate = 3e0  ## own name so the VAE training cell's `learning_rate` is not clobbered
cmap = plt.get_cmap('hsv')
gaussian_display_alpha = 0.5
## (name, similarity) pairs; the name is used in plot labels and output file names.
similarity_methods = [('info', compute_info), ('vi', compute_vi_similarity), ('nmi', compute_nmi)]
max_ensemble_size = np.max(ensemble_sizes)
figure_output_dir = 'outs'
os.makedirs(figure_output_dir, exist_ok=True)  ## savefig below fails if the directory is missing


def plot_fused_representation(mus_var, logvars_var, matrix_cmap=None, hide_ticks=False):
  """Draw the current fused fingerprint in a new 1x2 figure (the caller shows/saves it).

  Left: exp(-Bhattacharyya distance) between every pair of fingerprint posteriors, color scale
  fixed to [0, 1]. Right: posterior means colored by fingerprint index (hsv). Each mean gets an
  ellipse whose axes are 2*std along the two coordinates. Assumes a 2D fusion space.
  """
  bhat_mat_opt = utils.bhattacharyya_dist_mat_tf(mus_var, logvars_var)
  ## Convert once, not once per ellipse.
  mus = mus_var.numpy()
  stds = np.exp(logvars_var.numpy()/2.)
  plt.figure(figsize=(10, 5))
  plt.subplot(121)
  plt.imshow(tf.exp(-bhat_mat_opt), vmin=0, vmax=1, cmap=matrix_cmap)
  plt.axis('off')
  plt.subplot(122)
  plt.scatter(*np.float32(mus).T, c=np.linspace(0, 1, fingerprint_size), cmap='hsv')
  for i in range(fingerprint_size):
    ell = Ellipse(xy=mus[i], width=2*stds[i, 0], height=2*stds[i, 1],
                  facecolor=cmap(i/(fingerprint_size-1)), alpha=gaussian_display_alpha, edgecolor='k')
    plt.gca().add_artist(ell)
  if hide_ticks:
    plt.xticks([]); plt.yticks([])


for sim_method_name, sim_method in similarity_methods:
  for ensemble_size in ensemble_sizes:
    for repeat_ind in range(number_repeat_training_runs):
      ## Randomly initialize the posterior parameters for each of the fingerprint points
      mus_var = tf.Variable(tf.random.normal((fingerprint_size, fusion_space_dimensionality), stddev=0.05), trainable=True)
      logvars_var = tf.Variable(tf.zeros((fingerprint_size, fusion_space_dimensionality)), trainable=True)
      fusion_variables = [mus_var, logvars_var]
      ## Each repeat draws from its own disjoint group of max_ensemble_size trained models.
      group_start = repeat_ind*max_ensemble_size
      bhats_to_use = bhat_mats[group_start:group_start+ensemble_size]
      optimizer = tf.keras.optimizers.SGD(fusion_learning_rate)

      @tf.function
      def compute_avg_similarity():
        """Take one SGD step that increases the mean similarity to the ensemble; return that mean."""
        with tf.GradientTape() as tape:
          bhat_mat_opt = utils.bhattacharyya_dist_mat_tf(mus_var, logvars_var)
          total_similarity = 0.
          for other_bhat in bhats_to_use:
            total_similarity += sim_method(bhat_mat_opt, tf.cast(other_bhat, tf.float32))
          avg_similarity = total_similarity / ensemble_size
          loss = -avg_similarity
        grads = tape.gradient(loss, fusion_variables)
        optimizer.apply_gradients(zip(grads, fusion_variables))
        return avg_similarity

      sim_series = []
      for opt_step in range(number_opt_steps):
        sim_series.append(float(compute_avg_similarity()))
        if display_periodically_during_training and (opt_step % 1000) == 0:
          plot_fused_representation(mus_var, logvars_var)
          plt.show()
      plt.figure(figsize=(7, 4))
      plt.plot(sim_series)
      plt.ylabel(f'Similarity: {sim_method_name}', fontsize=15)
      plt.xlabel('Step', fontsize=15)
      plt.show()

      plot_fused_representation(mus_var, logvars_var, matrix_cmap='Blues_r', hide_ticks=True)
      plt.savefig(os.path.join(figure_output_dir, f'SGD_{sim_method_name}_{ensemble_size}_{repeat_ind}.svg'))
      plt.show()

      continuity_metric_values.append(continuity_metric(mus_var.numpy(), logvars_var.numpy()))

np.savez('fusion_results.npz',
         continuity_metric_values=continuity_metric_values,
         ensemble_sizes=ensemble_sizes)

In [None]:
## Continuity vs. ensemble size for each similarity used in the fusion.
## The fusion cell appends one value per (method, ensemble size, repeat), in that nesting order;
## reshaping with explicit sizes fails loudly if the counts do not match instead of mis-grouping.
method_names = ['info', 'vi', 'nmi']
continuity_by_method = np.reshape(continuity_metric_values,
                                  [len(method_names), len(ensemble_sizes), number_repeat_training_runs])

plt.figure(figsize=(8, 4))
for method_ind, sim_method_name in enumerate(method_names):
  ## Nudge each method along the (log) x-axis so the error bars do not overlap.
  x_offset_factor = 1.1**(method_ind-1)
  plt.errorbar(np.array(ensemble_sizes)*x_offset_factor,
               np.mean(continuity_by_method[method_ind], -1),
               yerr=np.std(continuity_by_method[method_ind], -1),
               ls='-', marker='dos'[method_ind],
               markersize=8, lw=2, elinewidth=4, label=sim_method_name)
plt.xscale('log')
plt.yscale('log')
labeled_ticks = [2, 5, 10, 20, 50]
plt.xticks(labeled_ticks, labeled_ticks)
buffer_factor = 1.15  ## padding beyond the smallest/largest ensemble size on the log x-axis
plt.xlim(ensemble_sizes[0]/buffer_factor, ensemble_sizes[-1]*buffer_factor)
plt.ylim(1, 1000)
plt.xlabel('Ensemble size', fontsize=15)
plt.ylabel('Continuity', fontsize=15)
plt.tick_params(width=2, length=4, which='major')
plt.tick_params(width=2, length=3, which='minor')
plt.legend()
plt.tight_layout()
plt.savefig('fusion_performance_comparison.svg')
plt.show()