Copyright 2021 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Approximate Bijective Correspondence (ABC) with MNIST

ABC seeks a correspondence between input sets of data that have been grouped by an inactive factor of variation. In the case of MNIST, the data are grouped by digit class, leaving writing style as the active factor of variation to embed.

In [None]:
#@title Imports
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow_datasets as tfds
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from sklearn.decomposition import PCA

tfkl = tf.keras.layers

In [None]:
# The number of images of a digit in each of the two sets
# This impacts the level of detail to which the network is sensitive.  You could
# imagine finding correspondence between stacks of 4 would require coarser
# details than stacks of 64.
stack_size = 64

# The dimension of the embedding space
num_latent_dims = 8

# The similarity type to use as the distance metric in embedding space
# Available options:
# l2 : Negative Euclidean distance
# l2sq : Negative squared Euclidean distance
# l1 : Negative L1 distance ('Manhattan' distance)
# linf : ord = inf distance (the max displacement along any coordinate), negated
# cosine : cosine similarity, bounded between -1 and 1
similarity_type = 'l2sq'

# The digit to withhold during training
test_digit = 9

# Name handed to tf.keras.optimizers.get to build the optimizer
optimizer_name = 'adam'
# Learning rate assigned to that optimizer
lr = 1e-4
# Number of paired stacks drawn from the training data, i.e. the number of gradient steps
num_steps = 500
temperature = 1.  # essentially ineffective unless using cosine similarity, just sets the length scale in embedding space
imgs_to_plot = 400  # for displaying the embeddings via pca
# True: show the PCA plot at each report; False: print the averaged loss instead
output_plots_during_training = True
# Report every this many steps; also the number of most recent losses averaged per report
output_loss_every = 100

In [None]:
#@title Load MNIST into 10 digit-specific datasets
def _scale_image_keep_label(example):
  """Return (image cast to float32 and divided by 255, label) for one example."""
  image = tf.cast(example['image'], tf.float32)/255.
  return image, example['label']


def _image_streams_by_digit(labelled):
  """Split a dataset of (image, label) pairs into ten image-only datasets."""
  streams = []
  for digit in range(10):
    # The filter is built right away, while `digit` still holds this value.
    only_this_digit = labelled.filter(lambda image, label: label == digit)
    streams.append(only_this_digit.map(lambda image, label: image))
  return streams


# Training split: one endlessly repeating, shuffled stream of stacks per digit.
dset = tfds.load('mnist', split='train').map(
    _scale_image_keep_label,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = [images.shuffle(1000).repeat().batch(stack_size)
      for images in _image_streams_by_digit(dset)]

# Test split: a single unshuffled pass, batched into stacks.
dset_test = tfds.load('mnist', split='test').map(
    _scale_image_keep_label,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = [images.batch(stack_size)
           for images in _image_streams_by_digit(dset_test)]

# Combine stacks from different digits randomly (sometimes a digit is paired with itself but this does not derail training).
# The held-out digit is left out of the mix.
# The shape of each element is [2, stack_size, 28, 28, 1].
training_streams = ds[:test_digit] + ds[test_digit+1:]
combined_dset = tf.data.experimental.sample_from_datasets(training_streams).batch(2)

In [None]:
#@title The data is grouped by class label; this is all the supervision needed to learn about writing style.
# One row per digit, showing the first eight images of a single stack.
for digit_stream in ds:
  for stack in digit_stream.take(1):
    plt.figure(figsize=(9, 1))
    for panel in range(1, 9):
      plt.subplot(1, 8, panel)
      plt.imshow(stack[panel-1, ..., 0], cmap='binary')
      plt.xticks([])
      plt.yticks([])
    plt.show()

In [None]:
# The embedding model: a small convolutional network mapping a 28x28x1 image
# to a point in the num_latent_dims-dimensional embedding space.
model = tf.keras.Sequential([
    tfkl.Input(shape=(28, 28, 1)),
    tfkl.Conv2D(32, 3, activation='relu'),
    tfkl.Conv2D(32, 3, activation='relu'),
    tfkl.Conv2D(32, 3, activation='relu', strides=2),
    tfkl.Conv2D(32, 3, activation='relu'),
    tfkl.Conv2D(32, 3, activation='relu'),
    tfkl.Flatten(),
    tfkl.Dense(128, activation='relu'),
    # Linear output so the embedding coordinates are unconstrained.
    tfkl.Dense(num_latent_dims, activation='linear'),
])
# summary() prints the table itself and returns None, so wrapping it in
# print() would add a stray "None" line to the output.
model.summary()

In [None]:
#@title Run PCA on the output of the untrained model (for comparison).
# Draw the stacks once (the stream is shuffled, so a second pass would yield
# different images) and reuse them for both the images and their embeddings.
held_out_stacks = list(ds[test_digit].take(imgs_to_plot//stack_size))
plot_imgs = tf.concat(held_out_stacks, 0)
embs_pre = tf.concat(
    [model(stack, training=False) for stack in held_out_stacks], 0)

# Project the untrained embeddings onto their first two principal components.
pca_pre = PCA(n_components=2)
pca_pre.fit(embs_pre)
print('PCA2 explained variance before training:', pca_pre.explained_variance_ratio_)
t_pre = pca_pre.transform(embs_pre)

In [None]:
#@title The heart of ABC: helper functions for computing the loss
# Many were copied+modified from Dwibedi et al. (2019).
@tf.function
def pairwise_l2_distance(embs1, embs2):
  """Squared Euclidean distance between every row of embs1 and every row of embs2.

  Despite the name, no square root is taken; callers that need the true L2
  distance apply the square root themselves.

  Args:
    embs1: embeddings of shape [stack_size, num_latent_dims].
    embs2: embeddings of shape [stack_size, num_latent_dims].

  Returns:
    The full [stack_size, stack_size] matrix of squared distances, where
    entry (i, j) compares embs1[i] with embs2[j].
  """
  # Uses |a - b|^2 = |a|^2 + |b|^2 - 2 a.b, evaluated for all pairs at once.
  sq_norms_1 = tf.reshape(tf.reduce_sum(tf.square(embs1), 1), [-1, 1])
  sq_norms_2 = tf.reshape(tf.reduce_sum(tf.square(embs2), 1), [1, -1])
  dot_products = tf.matmul(embs1, embs2, transpose_b=True)
  # Clamp at zero so no negative values are returned.
  return tf.maximum(sq_norms_1 + sq_norms_2 - 2.0 * dot_products, 0.0)

@tf.function
def pairwise_l1_distance(embs1, embs2):
  """L1 ('Manhattan') distance between every row of embs1 and every row of embs2.

  Args:
    embs1: embeddings of shape [n1, num_latent_dims].
    embs2: embeddings of shape [n2, num_latent_dims].

  Returns:
    A [n1, n2] matrix whose entry (i, j) is sum_k |embs1[i, k] - embs2[j, k]|.
  """
  # Broadcasting [n1, 1, d] against [1, n2, d] yields every pairwise
  # difference without tiling embs1 or needing a statically known n2.
  diffs = tf.expand_dims(embs1, 1) - tf.expand_dims(embs2, 0)
  return tf.reduce_sum(tf.abs(diffs), -1)

@tf.function
def pairwise_linf_distance(embs1, embs2):
  """L-infinity distance between every row of embs1 and every row of embs2.

  Args:
    embs1: embeddings of shape [n1, num_latent_dims].
    embs2: embeddings of shape [n2, num_latent_dims].

  Returns:
    A [n1, n2] matrix whose entry (i, j) is max_k |embs1[i, k] - embs2[j, k]|,
    i.e. the largest displacement along any single coordinate.
  """
  # Broadcasting [n1, 1, d] against [1, n2, d] yields every pairwise
  # difference without tiling embs1 or needing a statically known n2.
  diffs = tf.expand_dims(embs1, 1) - tf.expand_dims(embs2, 0)
  return tf.reduce_max(tf.abs(diffs), -1)

@tf.function
def get_scaled_similarity(embs1, embs2, similarity_type, temperature):
  """Pairwise similarity between two sets of embeddings, divided by temperature.

  Distances are negated so that, for every similarity type, a larger value
  means "more similar".

  Args:
    embs1: embeddings of shape [n1, num_latent_dims].
    embs2: embeddings of shape [n2, num_latent_dims].
    similarity_type: one of 'l2sq', 'l2', 'l1', 'linf' or 'cosine'.
    temperature: scalar the similarity matrix is divided by.

  Returns:
    A [n1, n2] matrix of scaled similarities.

  Raises:
    ValueError: if similarity_type is not one of the supported names.
  """
  if similarity_type == 'l2sq':
    similarity = -1.0 * pairwise_l2_distance(embs1, embs2)
  elif similarity_type == 'l2':
    # Small offset so the square root (and its gradient) stays finite when two
    # embeddings coincide.  Defined here because no `eps` exists elsewhere.
    eps = 1e-9
    similarity = -1.0 * tf.sqrt(pairwise_l2_distance(embs1, embs2) + eps)
  elif similarity_type == 'l1':
    similarity = -1.0 * pairwise_l1_distance(embs1, embs2)
  elif similarity_type == 'linf':
    similarity = -1.0 * pairwise_linf_distance(embs1, embs2)
  elif similarity_type == 'cosine':
    # Cosine similarity is the dot product of the L2-normalised embeddings.
    embs1, _ = tf.linalg.normalize(embs1, ord=2, axis=-1)
    embs2, _ = tf.linalg.normalize(embs2, ord=2, axis=-1)
    similarity = tf.matmul(embs1, embs2, transpose_b=True)
  else:
    raise ValueError('Unknown similarity type: {}'.format(similarity_type))

  similarity /= temperature
  return similarity

@tf.function
def align_pair_of_sequences(embs1, embs2, similarity_type, temperature):
  """Cycle-consistency loss from embs1 to embs2 and back.

  Each row of embs1 is mapped to a soft nearest neighbor built from the rows
  of embs2; that neighbor is then compared against every row of embs1, and
  the loss is the cross-entropy of landing back on the row it started from.

  Args:
    embs1: embeddings of the first stack, one row per image.
    embs2: embeddings of the second stack, one row per image.
    similarity_type: similarity name forwarded to get_scaled_similarity.
    temperature: temperature forwarded to get_scaled_similarity.

  Returns:
    Scalar tensor: the mean cross-entropy over the rows of embs1.
  """
  # Outbound leg: softmax over embs2 gives the weights of the soft neighbor.
  outbound_logits = get_scaled_similarity(
      embs1, embs2, similarity_type, temperature)
  neighbor_weights = tf.nn.softmax(outbound_logits, axis=1)
  soft_neighbors = tf.matmul(neighbor_weights, embs2)

  # Return leg: similarities of each soft neighbor to every row of embs1.
  return_logits = get_scaled_similarity(
      soft_neighbors, embs1, similarity_type, temperature)

  # Row i should cycle back to index i.
  targets = tf.range(tf.shape(embs1)[0])
  per_row_loss = tf.keras.losses.sparse_categorical_crossentropy(
      targets, return_logits, from_logits=True)
  return tf.reduce_mean(per_row_loss)

In [None]:
#@title display_pca_embs definition
def display_pca_embs(model, plot_imgs, step, avg_loss):
  """Plot the images at the 2-D PCA projection of their current embeddings.

  Args:
    model: the embedding model, called in inference mode.
    plot_imgs: the images to embed and draw, indexed along the first axis.
    step: training step shown in the plot title.
    avg_loss: loss value shown in the plot title.
  """
  # Embed in chunks of stack_size (a notebook-level setting).
  chunk_starts = range(0, len(plot_imgs), stack_size)
  embeddings = tf.concat(
      [model(plot_imgs[lo:lo+stack_size], training=False) for lo in chunk_starts],
      0)

  projector = PCA(n_components=2)
  projector.fit(embeddings)
  coords = projector.transform(embeddings)

  plt.figure(figsize=(8, 8))
  zoom_factor = 1.  # scales the size of the individual digit images
  ax = plt.gca()
  for img, xy in zip(plot_imgs, coords):
    # Four channels: three copies of the inverted image, then the image itself
    # as the last channel (presumably read as RGBA, so the background is clear).
    thumbnail = OffsetImage(tf.concat([1-img, 1-img, 1-img, img], -1),
                            zoom=zoom_factor)
    ax.add_artist(AnnotationBbox(thumbnail, xy, frameon=False))
    # Zero-size marker: this is just so the axes bound the images.
    plt.scatter(xy[0], xy[1], s=0)

  variance_ratio = projector.explained_variance_ratio_
  plt.xlabel('PC0, var {:.3f}'.format(variance_ratio[0]), fontsize=14.)
  plt.ylabel('PC1, var {:.3f}'.format(variance_ratio[1]), fontsize=14.)
  plt.title('Step {}, ABC Loss = {:.3f}'.format(step, avg_loss), fontsize=16.)

  plt.show()

In [None]:
# Train the embedder
opt = tf.keras.optimizers.get(optimizer_name)
# `learning_rate` is the canonical optimizer attribute; `lr` is only a
# deprecated alias, and assigning to it may not change the rate on newer Keras.
opt.learning_rate = lr

losses = []


def _report_progress(step):
  """Show the PCA plot, or print the loss, averaged over the latest window."""
  recent_loss = np.average(losses[-output_loss_every:])
  if output_plots_during_training:
    display_pca_embs(model, plot_imgs, step, recent_loss)
  else:
    print('Step {} Loss: {:.2f}'.format(step, recent_loss))


for step, paired_stacks in enumerate(combined_dset.take(num_steps)):
  with tf.GradientTape() as tape:
    # paired_stacks holds two stacks; embed each one.
    embs1 = model(paired_stacks[0], training=True)
    embs2 = model(paired_stacks[1], training=True)
    # Symmetrise the loss: cycle from stack 1 through stack 2, and vice versa.
    loss = align_pair_of_sequences(embs1, embs2, similarity_type, temperature)
    loss += align_pair_of_sequences(embs2, embs1, similarity_type, temperature)
  grads = tape.gradient(loss, model.trainable_variables)
  opt.apply_gradients(zip(grads, model.trainable_variables))
  losses.append(loss.numpy())

  if step % output_loss_every == 0:
    _report_progress(step)

# Final report after the last step; skipped if no step ran (e.g. num_steps
# is 0), in which case `step` would be undefined and there is no loss to show.
if losses:
  _report_progress(step)
print('Training completed.')

In [None]:
#@title Run PCA on the output of the trained model
# Embed plot_imgs in chunks of stack_size and join the results once, instead
# of re-concatenating the growing tensor on every iteration.
embs_post = tf.concat(
    [model(plot_imgs[start_ind:start_ind+stack_size], training=False)
     for start_ind in range(0, len(plot_imgs), stack_size)],
    0)

# Project the trained embeddings onto their first two principal components.
pca_post = PCA(n_components=2)
pca_post.fit(embs_post)
print('PCA2 explained variance after training:', pca_post.explained_variance_ratio_)
t_post = pca_post.transform(embs_post)

In [None]:
#@title Compare the PCA embeddings before and after training.
plt.figure(figsize=(16, 8))
zoom_factor = 1.  # scales the size of the individual digit images

# (subplot position, 2-D coordinates, fitted PCA, title) for each panel.
panels = (
    (121, t_pre, pca_pre, 'Before training'),
    (122, t_post, pca_post, 'After training'),
)
for subplot_code, coords, fitted_pca, panel_title in panels:
  plt.subplot(subplot_code)
  ax = plt.gca()
  for img_id, img in enumerate(plot_imgs):
    # Four channels: three copies of the inverted image, then the image itself.
    rgba = tf.concat([1-img, 1-img, 1-img, img], -1)
    thumbnail = OffsetImage(rgba, zoom=zoom_factor)
    ax.add_artist(AnnotationBbox(thumbnail, coords[img_id], frameon=False))
    # Zero-size marker: this is just so the axes bound the images.
    plt.scatter(coords[img_id, 0], coords[img_id, 1], s=0)
  variance_ratio = fitted_pca.explained_variance_ratio_
  plt.xlabel('PC0, var = {:.3f}'.format(variance_ratio[0]), fontsize=14.)
  plt.ylabel('PC1, var = {:.3f}'.format(variance_ratio[1]), fontsize=14.)
  plt.title(panel_title, fontsize=16.)

plt.show()

In [None]:
#@title Check out other digits (0s and 1s are often easier to decipher).
imgs_to_plot = 200
digits_to_plot = [0, 1, 2]
zoom_factor = 1.  # scales the size of the individual digit images

plt.figure(figsize=(18, 6))
for plot_id, digit_id in enumerate(digits_to_plot):
  # Draw the stacks once (the stream is shuffled) and reuse them for both the
  # images and their embeddings.
  stacks = list(ds[digit_id].take(imgs_to_plot//stack_size))
  imgs = tf.concat(stacks, 0)
  embs = tf.concat([model(stack, training=False) for stack in stacks], 0)

  # Reuse the PCA fitted on the held-out digit so all panels share the same
  # axes.  Kept in its own variable so the earlier t_post is not overwritten.
  digit_coords = pca_post.transform(embs)

  # One column per requested digit rather than a hard-coded three.
  plt.subplot(1, len(digits_to_plot), plot_id+1)
  ax = plt.gca()
  for img_id, img in enumerate(imgs):
    # Four channels: three copies of the inverted image, then the image itself.
    rgba = tf.concat([1-img, 1-img, 1-img, img], -1)
    thumbnail = OffsetImage(rgba, zoom=zoom_factor)
    ax.add_artist(AnnotationBbox(thumbnail, digit_coords[img_id], frameon=False))
    # Zero-size marker so the axes bound the images.
    plt.scatter(digit_coords[img_id, 0], digit_coords[img_id, 1], s=0.)

plt.show()

In [None]:
# Perform retrieval using test digits.  ds_test is not shuffled, so this takes
# the leading test images of each digit, up to num_images_to_use of them.
num_images_to_use = 512
all_imgs = []
all_embs = []
for d in ds_test:
  stacks = list(d.take(num_images_to_use//stack_size))
  # One tensor of images and one of embeddings per digit.
  all_imgs.append(tf.concat(stacks, 0))
  all_embs.append(tf.concat([model(stack, training=False) for stack in stacks], 0))

# Group by the nearest example of each digit: row `digit` uses that digit's
# first image as the query, and column `other_digit` shows the closest image
# of that class in embedding space.
plt.figure(figsize=(10, 10))
for digit in range(10):
  query_emb = all_embs[digit][0:1]
  for other_digit in range(10):
    # Squared L2 distances, shape [1, number of images of other_digit]; the
    # argmin is the same as for the unsquared distance.
    dists = pairwise_l2_distance(query_emb, all_embs[other_digit])
    min_ind = tf.argmin(dists[0])
    plt.subplot(10, 10, digit*10 + other_digit + 1)
    plt.imshow(all_imgs[other_digit][min_ind, ..., 0], cmap='binary')
    plt.xticks([])
    plt.yticks([])
    ax = plt.gca()
    # Frame only the diagonal cells, where query and candidates share a class.
    border_width = 5. if digit == other_digit else 0.
    plt.setp(ax.spines.values(), color='#ccdad1', linewidth=border_width)

plt.show()