# Experiments reported in "[Domain Conditional Predictors for Domain Adaptation](http://preregister.science/papers_20neurips/45_paper.pdf)"

Copyright 2021 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Preamble

In [None]:
#@test {"skip": true}
!pip install dm-sonnet==2.0.0 --quiet
!pip install tensorflow_addons==0.12 --quiet

In [None]:
#@test {"output": "ignore"}
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa

try:
  import sonnet.v2 as snt
except ModuleNotFoundError:
  import sonnet as snt

#### Colab tested with:

```
       TensorFlow version: 2.4.1
           Sonnet version: 2.0.0
TensorFlow Addons version: 0.12.0
```

In [None]:
#@test {"skip": true}
# Report the installed versions of the three main libraries, right-aligned.
_version_report = (
    ("       TensorFlow version: {}", tf.__version__),
    ("           Sonnet version: {}", snt.__version__),
    ("TensorFlow Addons version: {}", tfa.__version__),
)
for _template, _version in _version_report:
  print(_template.format(_version))

# Data preparation

We define 4 domains by transforming the data on the fly. The transformations are rotation, blurring, color inversion (swapping background and digit intensities), and horizontal flipping.

In [None]:
#@test {"output": "ignore"}
# Examples per mini-batch, and number of synthetic training domains
# (rotation, blur, color inversion, horizontal flip).
batch_size, NUM_DOMAINS = 100, 4

def process_batch_train(images, labels):
  """Maps one MNIST example into a randomly sampled synthetic training domain.

  The grayscale image is expanded to 3 channels and scaled to [0, 1]. A domain
  index is then drawn uniformly from range(NUM_DOMAINS), and the matching
  transformation is applied:
    0: rotation by pi/3 (tfa.image.rotate)
    1: Gaussian blur with an 8x8 filter
    2: color inversion (1 - image)
    3: horizontal flip

  Args:
    images: a single MNIST image (this is mapped before batching).
    labels: its class label, passed through unchanged.

  Returns:
    (images, labels, domain_label), where domain_label is the sampled domain
    index as an int64 scalar.
  """
  rgb = tf.cast(tf.image.grayscale_to_rgb(images), dtype=tf.float32) / 255.

  candidates = tf.convert_to_tensor(list(range(NUM_DOMAINS)))
  # tf.random.categorical expects log-probabilities, hence log(1/NUM_DOMAINS).
  uniform_log_probs = tf.math.log([[1 / NUM_DOMAINS] * NUM_DOMAINS])
  draw = tf.random.categorical(uniform_log_probs, 1)
  domain = candidates[tf.cast(draw[0][0], dtype=tf.int64)]

  # The conditions are tensors; inside Dataset.map autograph converts this
  # chain into graph control flow.
  if tf.math.equal(domain, tf.constant(0)):
    rgb = tfa.image.rotate(rgb, np.pi / 3)
  elif tf.math.equal(domain, tf.constant(1)):
    rgb = tfa.image.gaussian_filter2d(rgb, filter_shape=[8, 8])
  elif tf.math.equal(domain, tf.constant(2)):
    rgb = tf.ones_like(rgb) - rgb
  elif tf.math.equal(domain, tf.constant(3)):
    rgb = tf.image.flip_left_right(rgb)

  return rgb, labels, tf.cast(domain, tf.int64)

def process_batch_test(images, labels):
  """Converts a clean MNIST example to a 3-channel float image in [0, 1].

  No domain transformation is applied and no domain label is returned.
  """
  as_rgb = tf.image.grayscale_to_rgb(images)
  scaled = tf.cast(as_rgb, dtype=tf.float32) / 255.
  return scaled, labels

def mnist(split, multi_domain_test=False):
  """Builds a batched MNIST tf.data pipeline.

  Args:
    split: TFDS split name, e.g. "train" or "test".
    multi_domain_test: for non-train splits, apply the random training-domain
      transformations (and emit domain labels) instead of the clean test
      preprocessing. The "train" split always uses the training transformations.

  Returns:
    A dataset of `batch_size`-sized batches: (images, labels, domain_labels)
    when the training transformations are used, otherwise (images, labels).
  """
  dataset = tfds.load("mnist", split=split, as_supervised=True)

  if split == "train" or multi_domain_test:
    process_batch = process_batch_train
  else:
    process_batch = process_batch_test

  # Cache the raw decoded examples *before* the random domain transformation.
  # Caching after `map` would freeze the domains sampled on the first pass, so
  # later epochs (and `.repeat()`ed test passes) would replay identical images
  # instead of fresh "on the fly" transformations. It would also cache
  # float32 RGB images instead of the much smaller raw images.
  dataset = dataset.cache()
  dataset = dataset.map(process_batch)
  dataset = dataset.batch(batch_size)
  # Prefetch last so it overlaps the whole pipeline with consumption.
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset

# mnist() already batches, so this shuffle permutes the order of whole batches.
mnist_train = mnist(split="train").shuffle(buffer_size=1000)
mnist_test = mnist(split="test")
mnist_test_multidomain = mnist(split="test", multi_domain_test=True)

Look at samples from the training domains. Domain labels map to transformations as follows: rotation → 0, blurring → 1, color inversion → 2, horizontal flip → 3.

In [None]:
#@test {"skip": true}
# plt is also used by the plotting cells further down.
import matplotlib.pyplot as plt
# Take the first batch of the training pipeline and show its first example:
# print the class label and domain label, then display the transformed image.
images, label, domain_label = next(iter(mnist_train))
print(label[0], domain_label[0])
plt.imshow(images[0]);

# Baseline 1: Unconditional model

A baseline model is defined below and referred to as unconditional since it does not take domain labels into account in any way.

In [None]:
class M_unconditional(snt.Module):
  """Plain CNN classifier that ignores domain information entirely.

  Two 5x5 conv layers (10 and 20 output channels), each followed by ReLU,
  then a linear layer producing 10 class logits.
  """

  def __init__(self):
    super().__init__()
    self.hidden1 = snt.Conv2D(output_channels=10, kernel_shape=5, name="hidden1")
    self.hidden2 = snt.Conv2D(output_channels=20, kernel_shape=5, name="hidden2")
    self.flatten = snt.Flatten()
    self.logits = snt.Linear(10, name="logits")

  def __call__(self, images):
    """Returns class logits for a batch of images."""
    features = images
    for conv in (self.hidden1, self.hidden2):
      features = tf.nn.relu(conv(features))
    return self.logits(self.flatten(features))

In [None]:
# Single unconditional-baseline instance, trained and evaluated by the cells below.
m_unconditional = M_unconditional()

Training of the baseline:

In [None]:
#@test {"output": "ignore"}
# SGD optimizer applied only to the unconditional baseline's variables.
opt_unconditional = snt.optimizers.SGD(learning_rate=0.01)

num_epochs = 10  # passes over the training set
loss_log_unconditional = []  # batch-mean loss, one entry per training step

def step(images, labels):
  """Runs one SGD update of the unconditional baseline on a mini-batch.

  Returns:
    The batch-mean class cross-entropy, computed before the update.
  """
  with tf.GradientTape() as tape:
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=m_unconditional(images), labels=labels)
    mean_loss = tf.reduce_mean(per_example)

  # Read the variables only after the forward pass has run.
  variables = m_unconditional.trainable_variables
  opt_unconditional.apply(tape.gradient(mean_loss, variables), variables)
  return mean_loss

# Train for num_epochs passes. The dataset also yields domain labels, but this
# baseline ignores them.
for images, labels, domain_labels in mnist_train.repeat(num_epochs):
  loss_unconditional = step(images, labels)
  loss_log_unconditional.append(loss_unconditional.numpy())

# Loss of the last training batch.
print("\n\nFinal loss: {}".format(loss_log_unconditional[-1]))

In [None]:
REDUCTION_FACTOR = 0.2  ## Factor in [0,1] used to check whether the training loss reduces during training

## Sanity check: the last logged training loss must be below
## REDUCTION_FACTOR times the first one.
_first_loss, _last_loss = loss_log_unconditional[0], loss_log_unconditional[-1]
assert _last_loss < REDUCTION_FACTOR * _first_loss

# Baseline 2: Domain invariant representations

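A DANN-like model in which the domain discriminator is replaced by a multi-class domain classifier over the training domains. The feature extractor is trained to make this classifier's predictions uniform, aiming to induce invariance across training domains.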

In [None]:
#@test {"skip": true}
class DANN_task(snt.Module):
  """Feature extractor plus label head for the domain-invariant baseline.

  Same architecture as the unconditional baseline, but it also returns the
  flattened conv features z, so that a domain classifier can be attached.
  """

  def __init__(self):
    super().__init__()
    self.hidden1 = snt.Conv2D(output_channels=10, kernel_shape=5, name="hidden1")
    self.hidden2 = snt.Conv2D(output_channels=20, kernel_shape=5, name="hidden2")
    self.flatten = snt.Flatten()
    self.logits = snt.Linear(10, name="logits")

  def __call__(self, images):
    """Returns (class_logits, z), where z is the flattened ReLU conv output."""
    h = tf.nn.relu(self.hidden1(images))
    h = tf.nn.relu(self.hidden2(h))
    z = self.flatten(h)
    return self.logits(z), z

In [None]:
#@test {"skip": true}
class DANN_domain(snt.Module):
  """Linear domain classifier applied to a feature vector z."""

  def __init__(self):
    super().__init__()
    self.logits = snt.Linear(NUM_DOMAINS, name="logits")

  def __call__(self, z):
    """Returns logits over the NUM_DOMAINS training domains."""
    return self.logits(z)

In [None]:
#@test {"skip": true}
# Domain-invariant baseline: task network (features z + class head) and a
# domain classifier applied to z.
m_DANN_task = DANN_task()
m_DANN_domain = DANN_domain()

Training of the DANN baseline:

In [None]:
#@test {"skip": true}
# Separate SGD optimizers: opt_task updates the task network (m_DANN_task),
# opt_domain updates the domain classifier (m_DANN_domain).
opt_task = snt.optimizers.SGD(learning_rate=0.01)
opt_domain = snt.optimizers.SGD(learning_rate=0.01)
domain_loss_weight = 0.2 ## Hyperparameter - weight of the domain-confusion term (cross-entropy of the domain prediction vs. uniform) in the task network's loss

num_epochs = 20 ## Doubled the number of epochs to train the task classifier for as many iterations as the other methods since we have alternate updates
loss_log_dann = {'task_loss':[],'domain_loss':[]}  # losses logged separately for the two alternating update types
number_of_iterations = 0  # step counter; its parity selects which network step() updates

def step(images, labels, domain_labels, iteration_count):
  """Performs one alternating DANN-style update on a single mini-batch.

  When `iteration_count` is even, the task network (m_DANN_task) is updated
  with the class cross-entropy plus `domain_loss_weight` times a
  domain-confusion term. When it is odd, only the domain classifier
  (m_DANN_domain) is updated, to predict the true domain from the task
  network's features z.

  Args:
    images: batch of input images.
    labels: integer class labels (used on task updates).
    domain_labels: integer domain labels (used on domain-classifier updates).
    iteration_count: step counter whose parity selects the update.

  Returns:
    (tag, loss): tag is 'task_loss' or 'domain_loss', the key of
    `loss_log_dann` under which the batch-mean loss is logged.
  """
  if iteration_count % 2 == 0:
    with tf.GradientTape() as tape:
      logits_DANN_task, z_DANN = m_DANN_task(images)
      logits_DANN_domain = m_DANN_domain(z_DANN)
      loss_DANN_task = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits_DANN_task, labels=labels)
      # Domain confusion: cross-entropy between the uniform distribution over
      # domains and the predicted domain posterior P(D|z), not the class
      # posterior. It is smallest when the domain classifier's prediction is
      # uniform, which pushes z towards domain invariance.
      uniform_domain_targets = tf.ones_like(logits_DANN_domain) / NUM_DOMAINS
      loss_DANN_domain = tf.nn.softmax_cross_entropy_with_logits(
          logits=logits_DANN_domain, labels=uniform_domain_targets)

      loss_DANN = tf.reduce_mean(loss_DANN_task + domain_loss_weight * loss_DANN_domain)

    # Only the task network is updated here; the domain classifier is fixed.
    params_DANN = m_DANN_task.trainable_variables
    grads_DANN = tape.gradient(loss_DANN, params_DANN)
    opt_task.apply(grads_DANN, params_DANN)
    return 'task_loss', loss_DANN

  with tf.GradientTape() as tape:
    _, z_DANN = m_DANN_task(images)
    logits_DANN_domain_classifier = m_DANN_domain(z_DANN)
    loss_DANN_domain_classifier = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits_DANN_domain_classifier, labels=domain_labels))

  # Only the domain classifier is updated here; the features z are fixed.
  params_DANN_domain_classifier = m_DANN_domain.trainable_variables
  grads_DANN_domain_classifier = tape.gradient(
      loss_DANN_domain_classifier, params_DANN_domain_classifier)
  opt_domain.apply(grads_DANN_domain_classifier, params_DANN_domain_classifier)
  return 'domain_loss', loss_DANN_domain_classifier

# Alternate task/domain updates. The counter is incremented before each call,
# so the first step (count 1) updates the domain classifier.
for images, labels, domain_labels in mnist_train.repeat(num_epochs):
  number_of_iterations += 1
  loss_tag, loss_dann = step(images, labels, domain_labels, number_of_iterations)
  loss_log_dann[loss_tag].append(loss_dann.numpy())

# Last logged loss of each update type.
print("\n\nFinal losses: {} - {}, {} - {}".format('task_loss', loss_log_dann['task_loss'][-1], 'domain_loss', loss_log_dann['domain_loss'][-1]))

# Definition of our models

The models for our proposed setting are defined below. 

*   The FiLM layer simply projects z onto 2 tensors (independent dense layers for each projection) matching the shape of the features. Each such tensor is used for element-wise multiplication and addition with the input features.
*   m_domain corresponds to a domain classifier. It returns the flattened output of its (single) conv. layer, to be used as z, as well as a set of logits over the set of train domains.
*   m_task is the main classifier and it contains FiLM layers that take z as input. Its output corresponds to the set of logits over the labels.

In [None]:
#@test {"skip": true}
class FiLM(snt.Module):
  """Feature-wise affine conditioning layer.

  Two independent linear maps project the conditioning vector z to a
  per-element scale and shift, each reshaped to `features_shape`; the output
  is scale * features + shift.

  Args:
    features_shape: shape of the modulated features excluding the leading
      batch dimension (snt.reshape preserves it by default), e.g. [28, 28, 10].
  """
  def __init__(self, features_shape):
    super().__init__()
    self.features_shape = features_shape
    flat_size = np.prod(features_shape)
    self.hidden_W = snt.Linear(output_size=flat_size)
    self.hidden_B = snt.Linear(output_size=flat_size)

  def __call__(self, features, z):
    scale = snt.reshape(self.hidden_W(z), output_shape=self.features_shape)
    shift = snt.reshape(self.hidden_B(z), output_shape=self.features_shape)
    return scale * features + shift

In [None]:
#@test {"skip": true}
class M_task(snt.Module):
  """Domain-conditional classifier: a CNN whose activations are FiLM-modulated by z.

  Each of the two 5x5 conv layers is followed by ReLU and then a FiLM layer
  conditioned on z. The hard-coded 28x28 feature shapes presumably assume
  28x28 inputs and shape-preserving convolutions (TODO confirm).
  """

  def __init__(self):
    super().__init__()
    self.hidden1 = snt.Conv2D(output_channels=10, kernel_shape=5, name="hidden1")
    self.film1 = FiLM(features_shape=[28,28,10])
    self.hidden2 = snt.Conv2D(output_channels=20, kernel_shape=5, name="hidden2")
    self.film2 = FiLM(features_shape=[28,28,20])
    self.flatten = snt.Flatten()
    self.logits = snt.Linear(10, name="logits")

  def __call__(self, images, z):
    """Returns class logits for `images`, conditioned on the context z."""
    h = images
    for conv, film in ((self.hidden1, self.film1), (self.hidden2, self.film2)):
      h = film(tf.nn.relu(conv(h)), z)
    return self.logits(self.flatten(h))

In [None]:
#@test {"skip": true}
class M_domain(snt.Module):
  """Domain classifier whose hidden features double as the context vector z."""

  def __init__(self):
    super().__init__()
    self.hidden = snt.Conv2D(output_channels=10, kernel_shape=5, name="hidden")
    self.flatten = snt.Flatten()
    self.logits = snt.Linear(NUM_DOMAINS, name="logits")

  def __call__(self, images):
    """Returns (domain_logits, z), where z is the flattened ReLU conv output."""
    z = self.flatten(tf.nn.relu(self.hidden(images)))
    return self.logits(z), z

In [None]:
#@test {"skip": true}
# Proposed model: m_domain predicts the domain and supplies z; m_task is
# conditioned on z.
m_task = M_task()
m_domain = M_domain()

In [None]:
#@test {"skip": true}
# Smoke test of the proposed model on one clean test batch. In notebook order
# this runs before training, so the prediction is not meaningful yet.
images, labels = next(iter(mnist_test))
domain_logits, z = m_domain(images)
logits = m_task(images, z)
  
prediction = tf.argmax(logits[0]).numpy()
actual = labels[0].numpy()
print("Predicted class: {} actual class: {}".format(prediction, actual))
plt.imshow(images[0])

# Training of the proposed model

In [None]:
#@test {"skip": true}
from tqdm import tqdm

# Number of examples in the MNIST training split.
num_images = 60000

def progress_bar(generator):
  """Wraps an iterable of training batches in a tqdm bar counted in images.

  Each yielded item is scaled by `batch_size`. The expected total is
  num_images // batch_size batches per epoch times the global `num_epochs`,
  which is read when this function is called.
  """
  batches_per_epoch = num_images // batch_size
  return tqdm(generator, unit='images', unit_scale=batch_size,
              total=batches_per_epoch * num_epochs)

In [None]:
#@test {"skip": true}
# A single SGD optimizer updates m_task and m_domain jointly.
opt = snt.optimizers.SGD(learning_rate=0.01)

num_epochs = 10

# Batch-mean losses per training step; total = task + domain.
loss_log = {'total_loss':[], 'task_loss':[], 'domain_loss':[]}

def step(images, labels, domain_labels):
  """Joint SGD update of m_task and m_domain on one mini-batch.

  The objective is the batch mean of (class cross-entropy + domain
  cross-entropy). The domain classifier is therefore trained explicitly while
  its features z condition the task network.

  Returns:
    (total_loss, task_loss, domain_loss), each averaged over the batch.
  """
  with tf.GradientTape() as tape:
    domain_logits, z = m_domain(images)
    per_example_task = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=m_task(images, z), labels=labels)
    per_example_domain = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=domain_logits, labels=domain_labels)

    objective = tf.reduce_mean(per_example_task + per_example_domain)
    task_mean = tf.reduce_mean(per_example_task)
    domain_mean = tf.reduce_mean(per_example_domain)

  variables = m_task.trainable_variables + m_domain.trainable_variables
  opt.apply(tape.gradient(objective, variables), variables)
  return objective, task_mean, domain_mean

# Train the proposed model, logging the three batch-mean losses at every step.
for images, labels, domain_labels in progress_bar(mnist_train.repeat(num_epochs)):
  loss, loss_task, loss_domain = step(images, labels, domain_labels)
  loss_log['total_loss'].append(loss.numpy())
  loss_log['task_loss'].append(loss_task.numpy())
  loss_log['domain_loss'].append(loss_domain.numpy())

# Losses of the last training batch.
print("\n\nFinal total loss: {}".format(loss.numpy()))
print("\n\nFinal task loss: {}".format(loss_task.numpy()))
print("\n\nFinal domain loss: {}".format(loss_domain.numpy()))

# Ablation 1: Learned domain-wise context variable z

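Here we consider a case where the context variables z used for conditioning are learned directly from data, as one embedding vector per training domain indexed by the domain label, and the domain predictor is discarded. **Since domain labels are required at prediction time, this only allows for in-domain prediction.**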

In [None]:
#@test {"skip": true}
class M_learned_z(snt.Module):
  """Ablation 1: FiLM-conditioned classifier with a learned per-domain z.

  Instead of being predicted from the image, z is looked up in a trainable
  embedding table (one 128-d vector per training domain) indexed by the
  domain label. Domain labels are therefore needed at prediction time.
  """

  def __init__(self):
    super().__init__()
    self.context = snt.Embed(vocab_size=NUM_DOMAINS, embed_dim=128)
    self.hidden1 = snt.Conv2D(output_channels=10, kernel_shape=5, name="hidden1")
    self.film1 = FiLM(features_shape=[28,28,10])
    self.hidden2 = snt.Conv2D(output_channels=20, kernel_shape=5, name="hidden2")
    self.film2 = FiLM(features_shape=[28,28,20])
    self.flatten = snt.Flatten()
    self.logits = snt.Linear(10, name="logits")

  def __call__(self, images, domain_labels):
    """Returns class logits for `images`, conditioned on their domain labels."""
    z = self.context(domain_labels)
    h = images
    for conv, film in ((self.hidden1, self.film1), (self.hidden2, self.film2)):
      h = film(tf.nn.relu(conv(h)), z)
    return self.logits(self.flatten(h))

In [None]:
#@test {"skip": true}
# Ablation 1 model: z comes from a learned per-domain embedding.
m_learned_z = M_learned_z()

In [None]:
#@test {"skip": true}
# SGD optimizer for all ablation-1 variables, including the domain embeddings.
opt_learned_z = snt.optimizers.SGD(learning_rate=0.01)

num_epochs = 10
loss_log_learned_z = []  # batch-mean task loss per training step

def step(images, labels, domain_labels):
  """One SGD update of ablation 1 (learned per-domain z) on a mini-batch.

  Returns:
    The batch-mean class cross-entropy.
  """
  with tf.GradientTape() as tape:
    predicted = m_learned_z(images, domain_labels)
    mean_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=predicted, labels=labels))

  variables = m_learned_z.trainable_variables
  opt_learned_z.apply(tape.gradient(mean_loss, variables), variables)
  return mean_loss

# Train ablation 1; the true domain labels select each example's z.
for images, labels, domain_labels in mnist_train.repeat(num_epochs):
  loss_learned_z = step(images, labels, domain_labels)
  loss_log_learned_z.append(loss_learned_z.numpy())

print("\n\nFinal loss: {}".format(loss_log_learned_z[-1]))

# Ablation 2: Dropping the domain classification term of the loss

We consider an ablation where the same models as in our conditional predictor are used, but training is carried out with the classification loss only. This gives us a model with the same capacity as ours but no explicit mechanism to account for domain variations in train data. The goal of this ablation is to understand whether the improvement might be simply coming from the added capacity rather than the conditional modeling.

In [None]:
#@test {"skip": true}
# Ablation 2: same architecture as the proposed model, trained without the
# domain-classification loss.
m_task_ablation = M_task()
m_domain_ablation = M_domain()
m_DANN_ablation = DANN_domain() ## Used for evaluating how domain dependent the representations are

In [None]:
#@test {"skip": true}
# SGD optimizer for the ablation-2 task and domain networks.
opt_ablation = snt.optimizers.SGD(learning_rate=0.01)

num_epochs = 10

loss_log_ablation = []  # batch-mean task loss per training step

def step(images, labels, domain_labels):
  """One SGD update for ablation 2: our models, task loss only.

  The domain logits of m_domain_ablation are discarded, so its output head gets
  no training signal; only its features z are used, to condition
  m_task_ablation. `domain_labels` is accepted for interface parity but unused.

  Returns:
    The batch-mean class cross-entropy.
  """
  with tf.GradientTape() as tape:
    _, z = m_domain_ablation(images)
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=m_task_ablation(images, z), labels=labels)
    mean_loss = tf.reduce_mean(per_example)

  variables = m_task_ablation.trainable_variables + m_domain_ablation.trainable_variables
  opt_ablation.apply(tape.gradient(mean_loss, variables), variables)
  return mean_loss

# Train ablation 2 on the task loss only.
for images, labels, domain_labels in mnist_train.repeat(num_epochs):
  loss_ablation = step(images, labels, domain_labels)
  loss_log_ablation.append(loss_ablation.numpy())

print("\n\nFinal task loss: {}".format(loss_ablation.numpy()))

In [None]:
#@test {"skip": true}
# Optimizer intended for the ablation-2 domain probe (m_DANN_ablation).
opt_ablation_domain_classifier = snt.optimizers.SGD(learning_rate=0.01)

num_epochs = 10

log_loss_ablation_domain_classification = []  # batch-mean probe loss per step

def step(images, labels, domain_labels):
  """Trains the domain probe on the (fixed) ablation-2 features for one batch.

  The representation z comes from m_domain_ablation, which is NOT updated
  here: only m_DANN_ablation's variables receive gradients. The probe's
  accuracy indicates how domain dependent the ablation-2 features are.
  `labels` is accepted for interface parity but unused.

  Returns:
    The batch-mean domain cross-entropy of the probe.
  """
  with tf.GradientTape() as tape:
    _, z = m_domain_ablation(images)
    logits_ablation_domain_classification = m_DANN_ablation(z)
    loss_ablation_domain_classification = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits_ablation_domain_classification, labels=domain_labels))

  params_ablation_domain_classification = m_DANN_ablation.trainable_variables
  grads_ablation_domain_classification = tape.gradient(
      loss_ablation_domain_classification, params_ablation_domain_classification)
  # Use the probe's own optimizer, not the ablation task model's opt_ablation.
  opt_ablation_domain_classifier.apply(
      grads_ablation_domain_classification, params_ablation_domain_classification)
  return loss_ablation_domain_classification

# Train the ablation-2 domain probe, logging its batch-mean loss per step.
for images, labels, domain_labels in mnist_train.repeat(num_epochs):
  loss_ablation_domain_classifier = step(images, labels, domain_labels)
  log_loss_ablation_domain_classification.append(loss_ablation_domain_classifier.numpy())

# This loop trains the domain-classification probe, so label its loss as such.
print("\n\nFinal domain classification loss: {}".format(loss_ablation_domain_classifier.numpy()))

# Results

## Plots of training losses

In [None]:
#@test {"skip": true}
f = plt.figure(figsize=(32,8))
# One panel per logged loss of the proposed model, left to right.
_panels = [('total_loss', 'Total Loss'), ('task_loss', 'Task loss'), ('domain_loss', 'Domain loss')]
for position, (log_key, title) in enumerate(_panels, start=1):
  ax = f.add_subplot(1, 3, position)
  ax.plot(loss_log[log_key])
  ax.set_title(title)



In [None]:
#@test {"skip": true}
f = plt.figure(figsize=(8,8))
# Use add_subplot rather than add_axes([1,1,1,1]): add_axes takes
# [left, bottom, width, height] in figure fractions, so that rect would start
# at the figure's top-right corner, entirely outside the canvas.
ax = f.add_subplot(1, 1, 1)
ax.plot(loss_log_unconditional)
ax.set_title('Unconditional baseline - Train Loss')

In [None]:
#@test {"skip": true}
f = plt.figure(figsize=(16,8))
# Left: task-network objective (class CE + weighted domain-confusion term);
# right: domain-classifier loss.
for position, (log_key, title) in enumerate(
    [('task_loss', 'Domain invariant baseline - Task loss (Class. + -Entropy)'),
     ('domain_loss', 'Domain invariant baseline - Domain classification loss')],
    start=1):
  ax = f.add_subplot(1, 2, position)
  ax.plot(loss_log_dann[log_key])
  ax.set_title(title)

In [None]:
#@test {"skip": true}
f = plt.figure(figsize=(16,8))
# Side-by-side task-loss curves of the two ablations.
for position, (series, title) in enumerate(
    [(loss_log_learned_z, 'Ablation 1 - Task loss'),
     (loss_log_ablation, 'Ablation 2 - Task loss')],
    start=1):
  ax = f.add_subplot(1, 2, position)
  ax.plot(series)
  ax.set_title(title)

In [None]:
#@test {"skip": true}
f = plt.figure(figsize=(8,8))
# Use add_subplot rather than add_axes([1,1,1,1]): add_axes takes
# [left, bottom, width, height] in figure fractions, so that rect would start
# at the figure's top-right corner, entirely outside the canvas.
ax = f.add_subplot(1, 1, 1)
ax.plot(log_loss_ablation_domain_classification)
ax.set_title('Ablation 2: Domain classification - Train Loss')

## Out-of-domain evaluations

The original MNIST test set, without any transformations, is used.

In [None]:
#@test {"skip": true}
def _count_correct(logits, targets):
  """Counts rows of `logits` whose argmax equals the corresponding target."""
  return tf.math.count_nonzero(tf.equal(tf.argmax(logits, axis=1), targets))

total = 0
correct = 0
correct_unconditional = 0
correct_adversarial = 0
correct_ablation2 = 0 ## The model corresponding to ablation 1 can only be used with in-domain data (with domain labels)
for images, labels in mnist_test:
  # Proposed model: the domain network supplies z, which conditions the task network.
  _, z = m_domain(images)
  correct += _count_correct(m_task(images, z), labels)
  # Unconditional baseline.
  correct_unconditional += _count_correct(m_unconditional(images), labels)
  # Domain-invariant baseline: only its class logits are needed here.
  logits_adversarial, _ = m_DANN_task(images)
  correct_adversarial += _count_correct(logits_adversarial, labels)
  # Ablation 2: our architecture trained without the domain loss.
  _, z_ablation = m_domain_ablation(images)
  correct_ablation2 += _count_correct(m_task_ablation(images, z_ablation), labels)
  total += images.shape[0]

print("Got %d/%d (%.02f%%) correct" % (correct, total, correct / total * 100.))
print("Unconditional baseline perf.: %d/%d (%.02f%%) correct" % (correct_unconditional, total, correct_unconditional / total * 100.))
print("Adversarial baseline perf.: %d/%d (%.02f%%) correct" % (correct_adversarial, total, correct_adversarial / total * 100.))
print("Ablation 2 perf.: %d/%d (%.02f%%) correct" % (correct_ablation2, total, correct_ablation2 / total * 100.))

## In-domain evaluations and domain prediction

The same random transformations used for the training data are applied to the test data.

In [None]:
#@test {"skip": true}
n_repetitions = 10 ## Going over the test set multiple times to account for multiple transformations
# NOTE(review): repeated passes only see newly sampled transformations if
# mnist() applies its random map after (not before) any cache(); otherwise
# every repetition replays the same transformed images.

# Running counts over all repetitions. Class accuracy: proposed model
# (correct_class), baselines and ablations. Domain accuracy: proposed domain
# classifier, the DANN baseline's domain classifier, and the ablation-2 probe
# (m_DANN_ablation on m_domain_ablation's features).
total = 0
correct_class = 0
correct_unconditional = 0
correct_adversarial = 0
correct_ablation1 = 0
correct_ablation2 = 0
correct_domain = 0
correct_domain_adversarial = 0
correct_domain_ablation = 0
for images, labels, domain_labels in mnist_test_multidomain.repeat(n_repetitions):
  # Forward passes of every model on the same transformed batch. Ablation 1 is
  # given the true domain labels, since it needs them to look up z.
  domain_logits, z = m_domain(images)
  class_logits = m_task(images, z)
  logits_unconditional = m_unconditional(images)
  logits_adversarial, z_adversarial = m_DANN_task(images)
  domain_logits_adversarial = m_DANN_domain(z_adversarial)
  logits_ablation1 = m_learned_z(images, domain_labels)
  _, z_ablation = m_domain_ablation(images)
  domain_logits_ablation = m_DANN_ablation(z_ablation)
  logits_ablation2 = m_task_ablation(images, z_ablation)

  # Argmax predictions (class, then domain).
  predictions_class = tf.argmax(class_logits, axis=1)
  predictions_unconditional = tf.argmax(logits_unconditional, axis=1)
  predictions_adversarial = tf.argmax(logits_adversarial, axis=1)
  predictions_ablation1 = tf.argmax(logits_ablation1, axis=1)
  predictions_ablation2 = tf.argmax(logits_ablation2, axis=1)
  predictions_domain = tf.argmax(domain_logits, axis=1)
  predictions_domain_adversarial = tf.argmax(domain_logits_adversarial, axis=1)
  predictions_domain_ablation = tf.argmax(domain_logits_ablation, axis=1)

  # Accumulate the number of correct predictions in this batch.
  correct_class += tf.math.count_nonzero(tf.equal(predictions_class, labels))
  correct_unconditional += tf.math.count_nonzero(tf.equal(predictions_unconditional, labels))
  correct_adversarial += tf.math.count_nonzero(tf.equal(predictions_adversarial, labels))
  correct_ablation1 += tf.math.count_nonzero(tf.equal(predictions_ablation1, labels))
  correct_ablation2 += tf.math.count_nonzero(tf.equal(predictions_ablation2, labels))
  correct_domain += tf.math.count_nonzero(tf.equal(predictions_domain, domain_labels))
  correct_domain_adversarial += tf.math.count_nonzero(tf.equal(predictions_domain_adversarial, domain_labels))
  correct_domain_ablation += tf.math.count_nonzero(tf.equal(predictions_domain_ablation, domain_labels))
  total += images.shape[0]

print("In domain unconditional baseline perf.: %d/%d (%.02f%%) correct" % (correct_unconditional, total, correct_unconditional / total * 100.))
print("In domain adversarial baseline perf.: %d/%d (%.02f%%) correct" % (correct_adversarial, total, correct_adversarial / total * 100.))
print("In domain ablation 1: %d/%d (%.02f%%) correct" % (correct_ablation1, total, correct_ablation1 / total * 100.))
print("In domain ablation 2: %d/%d (%.02f%%) correct" % (correct_ablation2, total, correct_ablation2 / total * 100.))
print("In domain class predictions: Got %d/%d (%.02f%%) correct" % (correct_class, total, correct_class / total * 100.))
print("\n\nDomain predictions: Got %d/%d (%.02f%%) correct" % (correct_domain, total, correct_domain / total * 100.))
print("Adversarial baseline domain predictions: Got %d/%d (%.02f%%) correct" % (correct_domain_adversarial, total, correct_domain_adversarial / total * 100.))
print("Ablation 2 domain predictions: Got %d/%d (%.02f%%) correct" % (correct_domain_ablation, total, correct_domain_ablation / total * 100.))

In [None]:
#@test {"skip": true}
def sample(correct, rows, cols):
  """Plots clean test images that the proposed model gets right (or wrong).

  Scans the clean MNIST test set in order and shows the first rows * cols
  examples whose prediction correctness equals `correct`. Each panel is titled
  with the predicted and actual class.

  Args:
    correct: True to show correctly classified examples, False for mistakes.
    rows: number of subplot rows.
    cols: number of subplot columns.
  """
  num_panels = rows * cols
  # squeeze=False makes plt.subplots always return a 2-D array of Axes, so
  # ravel() yields a flat sequence for any grid. This includes 1x1, where the
  # default squeezing returns a single, non-indexable Axes.
  f, ax = plt.subplots(rows, cols, squeeze=False)
  ax = ax.ravel()
  f.set_figwidth(14)
  f.set_figheight(4 * rows)

  n = 0
  for images, labels in mnist_test:
    _, z = m_domain(images)
    predictions = tf.argmax(m_task(images, z), axis=1)
    matches = tf.equal(predictions, labels).numpy()
    for i, is_match in enumerate(matches):
      if is_match != correct:
        continue
      ax[n].imshow(images[i])
      # .numpy() so that titles show the class index, not the Tensor repr.
      ax[n].set_title("Prediction:{}\nActual:{}".format(
          predictions[i].numpy(), labels[i].numpy()))
      n += 1
      if n == num_panels:
        break

    if n == num_panels:
      break

## Samples and corresponding predictions

In [None]:
#@test {"skip": true}
# Show 5 correctly classified clean test images in a single row.
sample(correct=True, rows=1, cols=5)

In [None]:
#@test {"skip": true}
# Show 10 misclassified clean test images on a 2x5 grid.
sample(correct=False, rows=2, cols=5)