# InfoGAN with MNIST

* `InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets` [arXiv:1606.03657](https://arxiv.org/abs/1606.03657)
  * Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, Pieter Abbeel
* Implemented with [`tf.keras.layers`](https://www.tensorflow.org/api_docs/python/tf/keras/layers) and [`eager execution`](https://www.tensorflow.org/guide/eager).

## Import modules

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time
import glob

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import PIL
import imageio
from IPython import display

import tensorflow as tf
from tensorflow.keras import layers
tf.enable_eager_execution()

import image_utils as utils

# Log at INFO level so `tf.logging.info` messages are emitted.
tf.logging.set_verbosity(tf.logging.INFO)

# Make only GPU 0 visible to CUDA.
# NOTE(review): this is set after `import tensorflow`; presumably it still takes
# effect because no GPU has been initialized at this point -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

## Setting hyperparameters

In [None]:
# Training Flags (hyperparameter configuration)
model_name = 'infogan'
train_dir = 'train/' + model_name + '/exp1/'  # checkpoints and sample images are written here
max_epochs = 30
save_model_epochs = 10  # a checkpoint is saved every `save_model_epochs` epochs
print_steps = 100  # losses and samples are printed every `print_steps` global steps
save_images_epochs = 1  # a sample image is saved every `save_images_epochs` epochs
batch_size = 256
learning_rate_D = 1e-4
learning_rate_G = 1e-4  # also used for the recognition (Q) network optimizer
k = 1 # number of D steps per G step -- NOTE(review): not referenced by the training loop in this notebook
num_classes = 10 # number of classes for MNIST
num_examples_to_generate = num_classes
noise_dim = 62
categorical_code_dim = num_classes # for MNIST
continuous_code_dim = 2
MNIST_SIZE = utils.MNIST_SIZE

## Load the MNIST dataset

In [None]:
# Fetch MNIST through tf.keras; only the training split is used below.
mnist_splits = tf.keras.datasets.mnist.load_data()
train_data, train_labels = mnist_splits[0]

# Add a trailing channel axis, convert to float32 and divide by 255.
train_data = train_data.reshape(-1, 28, 28, 1).astype('float32') / 255.
train_labels = train_labels.astype(np.int32)

## Set up dataset with `tf.data`

### Create the input pipeline with `tf.data.Dataset`

In [None]:
# Fix the graph-level random seed.
tf.set_random_seed(219)

# for train
N = len(train_data)
# NOTE(review): the next line overrides N, so only the first 320 images are used.
# With batch_size = 256 and drop_remainder=True that yields a single batch per
# epoch. This looks like a leftover debugging shortcut -- confirm before a real
# training run.
N = 320
# Images only: the labels are not part of the dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((train_data[:N]))
train_dataset = train_dataset.shuffle(buffer_size = N)
# drop_remainder=True keeps every batch at exactly `batch_size` examples, which
# matches the fixed-size noise/code batches sampled in the training loop.
train_dataset = train_dataset.batch(batch_size = batch_size, drop_remainder=True)
print(train_dataset)

## Create the generator and discriminator models

In [None]:
class Generator(tf.keras.Model):
  """Generator: maps (noise, categorical code, continuous code) to images.

  Two dense stages are followed by two stride-2 transposed convolutions.
  Layers that feed a batch-normalization layer are created without a bias.
  """

  def __init__(self):
    super(Generator, self).__init__()
    self.fc1 = layers.Dense(units=1024, use_bias=False)
    self.fc1_bn = layers.BatchNormalization()
    self.fc2 = layers.Dense(units=7 * 7 * 128, use_bias=False)
    self.fc2_bn = layers.BatchNormalization()
    self.conv1 = layers.Conv2DTranspose(filters=64, kernel_size=(4, 4), strides=(2, 2),
                                        padding='same', use_bias=False)
    self.conv1_bn = layers.BatchNormalization()
    self.conv2 = layers.Conv2DTranspose(filters=1, kernel_size=(4, 4), strides=(2, 2),
                                        padding='same')

  def call(self, noise_inputs, categorical_code, continuous_code, training=True):
    """Generate images from the concatenated latent inputs.

    Args:
      noise_inputs: random z vectors (62 dim in this notebook).
      categorical_code: one-hot categorical codes (10 dim for MNIST).
      continuous_code: continuous codes (2 dim).
      training (`bool`): forwarded to every batch-normalization layer.

    Returns:
      generated images of shape (28, 28, 1) per example, squashed by a sigmoid.
    """
    # latent: 62 + 10 + 2 = 74 dim
    latent = tf.concat([noise_inputs, categorical_code, continuous_code], axis=1)

    # (74,) -> (1024,)
    hidden = tf.nn.relu(self.fc1_bn(self.fc1(latent), training=training))

    # (1024,) -> (7 * 7 * 128,) -> (7, 7, 128)
    hidden = tf.nn.relu(self.fc2_bn(self.fc2(hidden), training=training))
    feature_map = tf.reshape(hidden, [-1, 7, 7, 128])

    # (7, 7, 128) -> (14, 14, 64)
    feature_map = tf.nn.relu(self.conv1_bn(self.conv1(feature_map), training=training))

    # (14, 14, 64) -> (28, 28, 1)
    return tf.nn.sigmoid(self.conv2(feature_map))

In [None]:
class Discriminator(tf.keras.Model):
  """Discriminator that also returns its last hidden features.

  The hidden features are returned next to the real/fake logits so that the
  recognition network can be stacked on top of them.
  """

  def __init__(self):
    super(Discriminator, self).__init__()
    self.conv1 = layers.Conv2D(filters=64, kernel_size=(4, 4), strides=(2, 2), padding='same')
    self.conv2 = layers.Conv2D(filters=128, kernel_size=(4, 4), strides=(2, 2), padding='same',
                               use_bias=False)
    self.conv2_bn = layers.BatchNormalization()
    self.flatten = layers.Flatten()
    self.fc1 = layers.Dense(units=1024, use_bias=False)
    self.fc1_bn = layers.BatchNormalization()
    self.fc2 = layers.Dense(units=1)

  def call(self, image_inputs, training=True):
    """Score a batch of images.

    Args:
      image_inputs: images of shape (28, 28, 1) per example.
      training (`bool`): forwarded to every batch-normalization layer.

    Returns:
      A pair `(discriminator_logits, features)`: the (1,)-shaped real/fake
      logits and the (1024,)-shaped activations of the last hidden layer.
    """
    # (28, 28, 1) -> (14, 14, 64); the first layer has no batch normalization
    features = tf.nn.leaky_relu(self.conv1(image_inputs))

    # (14, 14, 64) -> (7, 7, 128)
    features = tf.nn.leaky_relu(self.conv2_bn(self.conv2(features), training=training))

    # (7, 7, 128) -> (7 x 7 x 128,) -> (1024,)
    hidden = self.fc1(self.flatten(features))
    hidden = tf.nn.leaky_relu(self.fc1_bn(hidden, training=training))

    # (1024,) -> (1,)
    return self.fc2(hidden), hidden

In [None]:
class RecognitionNetwork(tf.keras.Model):
  """Recognition (Q) head predicting the latent codes from discriminator features."""

  def __init__(self):
    super(RecognitionNetwork, self).__init__()
    self.fc1 = layers.Dense(units=128, use_bias=False)
    self.fc1_bn = layers.BatchNormalization()
    self.fc2 = layers.Dense(units=categorical_code_dim + continuous_code_dim)

  def call(self, inputs, training=True):
    """Predict the codes.

    Args:
      inputs: (1024,)-shaped features per example.
      training (`bool`): forwarded to the batch-normalization layer.

    Returns:
      A pair `(q_logits, q_softmax)`, both with categorical_code_dim +
      continuous_code_dim (10 + 2) units. Note that the softmax is taken over
      all of those units jointly, not over the categorical part only.
    """
    # (1024,) -> (128,)
    hidden = tf.nn.leaky_relu(self.fc1_bn(self.fc1(inputs), training=training))

    # (128,) -> (10 + 2,)
    code_logits = self.fc2(hidden)

    return code_logits, tf.nn.softmax(code_logits)

In [None]:
# Instantiate the three networks; they are used by the loss, checkpoint and
# training cells below.
generator = Generator()
discriminator = Discriminator()
q_network = RecognitionNetwork()

In [None]:
# Defun for performance boost
# Each instance's `call` is replaced by its `tf.contrib.eager.defun`-wrapped version.
generator.call = tf.contrib.eager.defun(generator.call)
discriminator.call = tf.contrib.eager.defun(discriminator.call)
q_network.call = tf.contrib.eager.defun(q_network.call)

## Define the loss functions and the optimizer

In [None]:
def GANLoss(logits, is_real=True):
  """Compute the standard GAN loss (binary cross-entropy) for `logits`.

  Args:
    logits (`Tensor`): raw (pre-sigmoid) discriminator outputs.
    is_real (`bool`): True labels every logit as `1`, False labels it as `0`.

  Returns:
    loss (`0-rank Tensor`): sigmoid cross-entropy between the constant labels
      and `logits`.
  """
  make_labels = tf.ones_like if is_real else tf.zeros_like

  return tf.losses.sigmoid_cross_entropy(multi_class_labels=make_labels(logits),
                                         logits=logits)

In [None]:
def discriminator_loss(real_logits, fake_logits):
  """Return the discriminator loss: real images labelled "1" plus fakes labelled "0"."""
  loss_on_real = GANLoss(logits=real_logits, is_real=True)
  loss_on_fake = GANLoss(logits=fake_logits, is_real=False)
  return loss_on_real + loss_on_fake

In [None]:
def generator_loss(fake_logits):
  """Return the generator's adversarial loss.

  The fake logits are scored against the label "1": the generator is rewarded
  when the discriminator takes its samples for real ones.
  """
  fooling_loss = GANLoss(logits=fake_logits, is_real=True)
  return fooling_loss

In [None]:
def mutual_information_loss(q_logits, q_softmax, categorical_code, continuous_code):
  """Compute the mutual-information surrogate loss of the recognition network.

  Args:
    q_logits: recognition-network outputs; the first `categorical_code_dim`
      columns are the categorical logits, the remaining columns are the
      continuous-code predictions.
    q_softmax: softmax of `q_logits`. Kept in the signature for backward
      compatibility with existing callers; it is intentionally not used
      (see the note below).
    categorical_code: one-hot categorical codes that were fed to the generator.
    continuous_code: continuous codes that were fed to the generator.

  Returns:
    Scalar tensor: softmax cross-entropy on the categorical code plus mean
    squared error on the continuous code.
  """
  # Categorical code: cross-entropy on the categorical slice of the logits.
  categorical_logits = q_logits[:, :categorical_code_dim]
  loss_Q_cate = tf.losses.softmax_cross_entropy(onehot_labels=categorical_code,
                                                logits=categorical_logits)

  # Continuous code: regress from the raw outputs. The softmax values lie in
  # (0, 1) and are coupled to the categorical logits through the shared
  # normalization, so they cannot match continuous targets that may be
  # negative (the training loop samples them from [-1, 1)).
  continuous_predictions = q_logits[:, categorical_code_dim:]
  loss_Q_cont = tf.losses.mean_squared_error(labels=continuous_code,
                                             predictions=continuous_predictions)

  # losses of Recognition network
  return loss_Q_cate + loss_Q_cont

In [None]:
# One Adam optimizer per network, all with beta1=0.5.
discriminator_optimizer = tf.train.AdamOptimizer(learning_rate_D, beta1=0.5)
generator_optimizer = tf.train.AdamOptimizer(learning_rate_G, beta1=0.5)
# The recognition network reuses the generator's learning rate.
q_network_optimizer = tf.train.AdamOptimizer(learning_rate_G, beta1=0.5)

## Checkpoints (Object-based saving)

In [None]:
# Checkpoints are stored in the training directory; create it if it is missing.
checkpoint_dir = train_dir
if not tf.gfile.Exists(checkpoint_dir):
  tf.gfile.MakeDirs(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track all three networks together with their optimizers in one checkpoint object.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 q_network_optimizer=q_network_optimizer,
                                 generator=generator,
                                 discriminator=discriminator,
                                 q_network=q_network)                                

## Training

In [None]:
def _tile_images_in_row(sample_images, count):
  """Lay the first `count` 28x28 images side by side as one (28, count * 28) array."""
  row = sample_images[:count, :]
  row = row.reshape([count, 28, 28])
  # Move the image-row axis first so that the final reshape concatenates the
  # images horizontally instead of interleaving their pixels.
  row = row.swapaxes(0, 1)
  return row.reshape([28, count * 28])


def print_or_save_sample_data(sample_images1, sample_images2, max_print_size=num_examples_to_generate,
                              is_save=False, epoch=None, checkpoint_dir=checkpoint_dir):
  """Show two rows of generated samples and optionally save the figure.

  Args:
    sample_images1 (`np.ndarray`): images for the top row; at least
      `max_print_size` images of 28 x 28 pixels each.
    sample_images2 (`np.ndarray`): images for the bottom row, same layout.
    max_print_size (`int`): number of images shown per row.
    is_save (`bool`): save the figure as a PNG when True and `epoch` is given.
    epoch (`int`): epoch number used in the saved file name.
    checkpoint_dir (`str`): directory the PNG is written to.
  """
  print_images = np.concatenate((_tile_images_in_row(sample_images1, max_print_size),
                                 _tile_images_in_row(sample_images2, max_print_size)),
                                axis=0)

  fig = plt.figure(figsize=(max_print_size, 2))
  plt.axis('off')
  plt.imshow(print_images, cmap='gray')

  if is_save and epoch is not None:
    filepath = os.path.join(checkpoint_dir, 'image_at_epoch_{:04d}.png'.format(epoch))
    plt.savefig(filepath)

  plt.show()
  # This function is called repeatedly during training; release the figure
  # explicitly so figures do not pile up on backends where `show` keeps them open.
  plt.close(fig)

In [None]:
def sampling_images(random_vector_for_generation):
  """Generate two rows of samples that sweep one continuous code each.

  Row 1 fixes the categorical code to digit 3 and sweeps the first continuous
  code over [-2, 2] with the second one held at 0. Row 2 fixes the categorical
  code to digit 7 and sweeps the second continuous code with the first held at 0.

  Args:
    random_vector_for_generation: noise of shape
      [num_examples_to_generate, noise_dim], reused for both rows.

  Returns:
    A pair of `np.ndarray`s (row 1 images, row 2 images).
  """
  # Use the fixed noise unchanged. Multiplying the tensor by
  # `num_examples_to_generate` (as one would replicate a Python list) scales
  # every noise value element-wise, which pushes the samples far away from
  # the noise distribution the generator is trained on.
  sample_noise = tf.convert_to_tensor(random_vector_for_generation)

  # Sweep values for whichever continuous code is varied, and zeros for the other.
  c1 = tf.reshape(tf.linspace(-2.0, 2.0, num=num_examples_to_generate), [num_examples_to_generate, 1])
  c2 = tf.zeros([num_examples_to_generate, 1])

  # Row 1: digit 3, vary the first continuous code.
  sample_number = 3 # 0 ~ 9, actually arbitrary number
  sample_categorical_code = tf.one_hot([sample_number] * num_examples_to_generate, depth=num_classes)
  sample_continuous_code = tf.concat((c1, c2), axis=1)
  sample_images1 = generator(sample_noise, sample_categorical_code, sample_continuous_code, training=False)

  # Row 2: digit 7, vary the second continuous code.
  sample_number = 7 # 0 ~ 9, actually arbitrary number
  sample_categorical_code = tf.one_hot([sample_number] * num_examples_to_generate, depth=num_classes)
  # tf.concat (not np.concatenate) so both rows feed the generator tensors.
  sample_continuous_code = tf.concat((c2, c1), axis=1)
  sample_images2 = generator(sample_noise, sample_categorical_code, sample_continuous_code, training=False)

  return sample_images1.numpy(), sample_images2.numpy()

In [None]:
# Keep the random vector constant for generation (prediction) so that
# it is easier to see the GAN improve from one sample image to the next.
# One noise vector of length `noise_dim` per generated example.
random_vector_for_generation = tf.random_normal([num_examples_to_generate, noise_dim])

In [None]:
tf.logging.info('Start Training.')
global_step = tf.train.get_or_create_global_step()

# NOTE: `sample_condition` is not used by the loop below; it is kept so that
# other cells relying on this name keep working.
sample_condition = tf.eye(num_classes)
sample_condition = tf.reshape(sample_condition, [-1, num_classes])

for epoch in range(max_epochs):

  for images in train_dataset:
    start_time = time.time()

    # Sample the generator inputs: noise from a standard normal distribution,
    # a one-hot categorical code drawn uniformly over the classes (equal
    # logits give a uniform categorical distribution) and a continuous code
    # drawn uniformly from [-1, 1).
    noise = tf.random_normal([batch_size, noise_dim])
    categorical_code = tf.one_hot(
        tf.multinomial([categorical_code_dim * [1. / categorical_code_dim]], batch_size)[0],
        depth=categorical_code_dim)
    continuous_code = tf.random_uniform(shape=[batch_size, continuous_code_dim], minval=-1.0, maxval=1.0)

    # One tape per network, because each (non-persistent) tape can be asked
    # for gradients only once.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape, tf.GradientTape() as q_net_tape:
      generated_images = generator(noise, categorical_code, continuous_code, training=True)

      real_logits, _ = discriminator(images, training=True)
      fake_logits, recog_inputs = discriminator(generated_images, training=True)
      q_logits, q_softmax = q_network(recog_inputs, training=True)

      gen_loss = generator_loss(fake_logits)
      disc_loss = discriminator_loss(real_logits, fake_logits)
      mi_loss = mutual_information_loss(q_logits, q_softmax, categorical_code, continuous_code)
      # InfoGAN objective for the generator: adversarial loss plus the
      # mutual-information term. Without the second term the generator never
      # receives a gradient that ties the latent codes to its images.
      gen_total_loss = gen_loss + mi_loss

    gradients_of_generator = gen_tape.gradient(gen_total_loss, generator.variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)
    gradients_of_q_network = q_net_tape.gradient(mi_loss, q_network.variables)

    # Only the generator update advances `global_step`.
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables),
                                        global_step=global_step)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))
    q_network_optimizer.apply_gradients(zip(gradients_of_q_network, q_network.variables))

    epochs = global_step.numpy() * batch_size / float(N)
    duration = time.time() - start_time

    if global_step.numpy() % print_steps == 0:
      display.clear_output(wait=True)
      examples_per_sec = batch_size / float(duration)
      print("Epochs: {:.2f} global_step: {} loss_D: {:.3f} loss_G: {:.3f} ({:.2f} examples/sec; {:.3f} sec/batch)".format(
                epochs, global_step.numpy(), disc_loss, gen_loss, examples_per_sec, duration))

      sample_images1, sample_images2 = sampling_images(random_vector_for_generation)
      print_or_save_sample_data(sample_images1, sample_images2)

  # saving sample images every save_images_epochs (counted like the checkpoint
  # condition below, i.e. on completed epochs)
  if (epoch + 1) % save_images_epochs == 0:
    display.clear_output(wait=True)
    print("This images are saved at {} epoch".format(epoch+1))
    sample_images1, sample_images2 = sampling_images(random_vector_for_generation)
    print_or_save_sample_data(sample_images1, sample_images2, num_examples_to_generate,
                              is_save=True, epoch=epoch+1, checkpoint_dir=checkpoint_dir)

  # saving (checkpoint) the model every save_model_epochs
  if (epoch + 1) % save_model_epochs == 0:
    checkpoint.save(file_prefix = checkpoint_prefix)

In [None]:
# generating after the final epoch
# `epoch` is the loop variable left over from the training cell, so the image
# is saved under the number of the last trained epoch.
display.clear_output(wait=True)
sample_images1, sample_images2 = sampling_images(random_vector_for_generation)
print_or_save_sample_data(sample_images1, sample_images2, num_examples_to_generate,
                          is_save=True, epoch=epoch+1, checkpoint_dir=checkpoint_dir)

## Restore the latest checkpoint

In [None]:
# restoring the latest checkpoint in checkpoint_dir
# (this restores the networks and optimizers tracked by `checkpoint`)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Display an image using the epoch number

In [None]:
def display_image(epoch_no, checkpoint_dir=checkpoint_dir):
  """Open the sample image that was saved for epoch `epoch_no`.

  Args:
    epoch_no (`int`): epoch number encoded (zero-padded to 4 digits) in the file name.
    checkpoint_dir (`str`): directory that holds the saved images.

  Returns:
    The image opened with `PIL.Image.open`.
  """
  image_name = 'image_at_epoch_{:04d}.png'.format(epoch_no)
  return PIL.Image.open(os.path.join(checkpoint_dir, image_name))

In [None]:
# Show the image saved for the last epoch (`max_epochs`).
display_image(max_epochs, checkpoint_dir)

## Generate a GIF of all the saved images.

In [None]:
# Build '<model_name>.gif' with the helper from `image_utils`.
# NOTE(review): presumably it collects the images saved in `checkpoint_dir` -- confirm in image_utils.
filename = model_name + '.gif'
utils.generate_gif(filename, checkpoint_dir)

In [None]:
# NOTE(review): this opens '<model_name>.gif.png'; presumably `utils.generate_gif`
# also writes a copy of the GIF under that name so the notebook can render it -- confirm.
display.Image(filename=filename + '.png')