# Conditional GAN with MNIST

* MNIST data를 가지고 **Conditional GAN**을 `tf.contrib.slim`을 이용하여 만들어보자.
  * 참고 [TensorFlow slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim)
* Based on DCGAN
* Original paper: Mehdi Mirza and Simon Osindero, "Conditional Generative Adversarial Nets"
* Reference code
  * [ilguyi's gans.tensorflow.slim](https://github.com/ilguyi/gans.tensorflow.slim)

## Import modules

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import tensorflow as tf

slim = tf.contrib.slim
tf.logging.set_verbosity(tf.logging.INFO)

# Uncomment to fix the random seeds.
#tf.set_random_seed(219)
#np.random.seed(219)

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
# Only expose GPU 0 to TensorFlow. NOTE(review): this must be set before the
# first session initializes CUDA; presumably fine here because no session has
# been created yet at this point in the notebook.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
# Training Flags (hyperparameter configuration)
train_dir = 'train/cgan/exp1/'  # checkpoints and TensorBoard summaries are written here
# The dataset is repeated max_epochs times, but every training step consumes
# k + 1 batches (k for the Discriminator update(s), 1 for the Generator update),
# so with k = 1 this is effectively 20 epochs of training steps.
max_epochs = 20 * 2
#max_epochs = 20
save_epochs = 10  # save a checkpoint every save_epochs (step-based) epochs
summary_steps = 2500  # write TensorBoard summaries every summary_steps global steps
print_steps = 1  # print losses and sample images every print_steps global steps
batch_size = 64
learning_rate_D = 0.0002
learning_rate_G = 0.001
k = 1  # number of Discriminator updates per Generator update
# One sample image per MNIST class: must equal the number of classes (10),
# because the sample images are conditioned on an identity matrix of that size.
num_samples = 10

## Import MNIST

In [None]:
# Load the MNIST training split from tf.keras (the test split is discarded).
(train_data, train_labels), _ = \
    tf.keras.datasets.mnist.load_data()

# Scale pixels from [0, 255] to [0, 1] so real images share the range of the
# Generator's sigmoid output. float32 matches the dtype the model consumes.
train_data = (train_data / 255.).astype(np.float32)
# `np.int` was deprecated in NumPy 1.20 and removed in 1.24; use an explicit
# integer type for the class labels.
train_labels = np.asarray(train_labels, dtype=np.int64)

## Set up dataset with `tf.data`

### create input pipeline with `tf.data.Dataset`

In [None]:
# for train: shuffle -> repeat for max_epochs passes -> fixed-size batches.
# The last incomplete batch is dropped because each batch is paired with
# exactly `batch_size` random z vectors inside the model.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_data, train_labels))
    .shuffle(buffer_size=10000)
    .repeat(count=max_epochs)
    .apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
)
print(train_dataset)

## Create the model

In [None]:
class CGAN(object):
  """Conditional Generative Adversarial Networks.

  Implementation based on http://arxiv.org/abs/1411.1784

  "Conditional Generative Adversarial Nets"
  Mehdi Mirza, Simon Osindero

  Both the Generator and the Discriminator are conditioned on a one-hot class
  vector. Generator() and Discriminator() store their intermediate tensors on
  `self` (`inputs`, `layer1` ... `layer4`); because build() constructs each
  network more than once, those attributes hold the tensors of the most
  recent call only.
  """

  def __init__(self, mode, train_dataset, test_dataset=None):
    """Basic setup.

    Args:
      mode (`string`): "train" or "generate".
      train_dataset (`tf.data.Dataset`): train_dataset.
      test_dataset (`tf.data.Dataset`): test_dataset (stored; not used by build()).

    Note:
      `batch_size` and `num_samples` are taken from the module-level
      hyperparameters of the same names.
    """
    assert mode in ["train", "generate"]
    self.mode = mode

    # hyper-parameters for model
    self.x_dim = 28  # image height/width
    self.z_dim = 100
    self.y_dim = 10  # number of classes on MNIST
    self.batch_size = batch_size
    self.num_samples = num_samples
    self.train_dataset = train_dataset
    self.test_dataset = test_dataset

    # Global step Tensor.
    self.global_step = None

    print('The mode is %s.' % self.mode)
    print('complete initializing model.')


  def build_random_z_inputs(self):
    """Build random latent vectors z, uniform in [-1, 1).

    Returns:
      self.random_z (`4-rank Tensor` with [batch_size, 1, 1, z_dim]):
          latent vectors for training.
      self.sample_random_z (`4-rank Tensor` with [num_samples, 1, 1, z_dim]):
          latent vectors for the sample images.
    """
    with tf.variable_scope('random_z'):
      self.random_z = tf.random_uniform([self.batch_size, 1, 1, self.z_dim],
                                        minval=-1.0, maxval=1.0)
      self.sample_random_z = tf.random_uniform([self.num_samples, 1, 1, self.z_dim],
                                               minval=-1.0, maxval=1.0)

    return self.random_z, self.sample_random_z


  def read_MNIST(self, dataset):
    """Read a batch of MNIST and build its one-hot conditional vector.

    Args:
      dataset (`tf.data.Dataset` format): MNIST dataset yielding (images, labels).

    Returns:
      self.mnist (`4-rank Tensor` with [batch, x_dim, x_dim, 1]): MNIST images.
      self.condition (`4-rank Tensor` with [batch, 1, 1, y_dim]): one-hot labels.
    """
    with tf.variable_scope('mnist'):
      iterator = dataset.make_one_shot_iterator()

      self.mnist, self.condition = iterator.get_next()
      self.mnist = tf.cast(self.mnist, dtype=tf.float32)
      # [batch, x_dim, x_dim] -> [batch, x_dim, x_dim, 1]
      self.mnist = tf.expand_dims(self.mnist, axis=3)
      # integer labels -> one-hot, shaped [batch, 1, 1, y_dim]
      self.condition = tf.one_hot(self.condition, depth=self.y_dim, dtype=tf.float32)
      self.condition = tf.reshape(self.condition, shape=[-1, 1, 1, self.y_dim])

    return self.mnist, self.condition


  def Generator(self, random_z, condition, is_training=True, reuse=False):
    """Generator setup.

    Args:
      random_z (`4-rank Tensor` with [batch_size, 1, 1, z_dim]):
          latent vector which size is generally 100 dim.
      condition (`4-rank Tensor` with [batch_size, 1, 1, y_dim]):
          one-hot conditional vector; its size is the number of MNIST classes.
      is_training (`bool`): batch-norm mode (batch statistics when True,
          moving averages when False).
      reuse (`bool`): whether to reuse the variables of a previous call.

    Returns:
      generated_data (`4-rank Tensor` with [batch_size, 28, 28, 1]):
          generated images in [0, 1] (sigmoid output).
    """
    with tf.variable_scope('Generator', reuse=reuse):
      batch_norm_params = {'decay': 0.9,
                           'epsilon': 0.001,
                           'is_training': is_training,
                           'scope': 'batch_norm'}
      with slim.arg_scope([slim.conv2d_transpose],
                          kernel_size=[4, 4],
                          stride=[2, 2],
                          normalizer_fn=slim.batch_norm,
                          normalizer_params=batch_norm_params):
        # Use full conv2d_transpose instead of projection and reshape.
        # random_z (z_dim = 100) and condition (y_dim = 10) are stacked along
        # the channel axis.
        self.inputs = tf.concat([random_z, condition], axis=3)
        # inputs: 1 x 1 x 110 dim
        # outputs: 3 x 3 x 256 dim
        self.layer1 = slim.conv2d_transpose(inputs=self.inputs,
                                            num_outputs=256,
                                            kernel_size=[3, 3],
                                            padding='VALID',
                                            scope='layer1')
        # inputs: 3 x 3 x 256 dim
        # outputs: 7 x 7 x 128 dim
        self.layer2 = slim.conv2d_transpose(inputs=self.layer1,
                                            num_outputs=128,
                                            kernel_size=[3, 3],
                                            padding='VALID',
                                            scope='layer2')
        # inputs: 7 x 7 x 128 dim
        # outputs: 14 x 14 x 64 dim
        self.layer3 = slim.conv2d_transpose(inputs=self.layer2,
                                            num_outputs=64,
                                            scope='layer3')
        # inputs: 14 x 14 x 64 dim
        # outputs: 28 x 28 x 1 dim (no batch norm; sigmoid maps to [0, 1])
        self.layer4 = slim.conv2d_transpose(inputs=self.layer3,
                                            num_outputs=1,
                                            normalizer_fn=None,
                                            activation_fn=tf.sigmoid,
                                            scope='layer4')
        # generated_data = outputs: 28 x 28 x 1 dim
        generated_data = self.layer4

        return generated_data


  def Discriminator(self, data, condition, reuse=False):
    """Discriminator setup.

    Args:
      data (`4-rank Tensor` with [batch, x_dim, x_dim, 1]): real or generated images.
      condition (`4-rank Tensor` with [batch, 1, 1, y_dim]):
          one-hot conditional vector; its size is the number of MNIST classes.
      reuse (`bool`): whether to reuse the variables of a previous call.

    Returns:
      logits (`2-rank Tensor` with [batch, 1]): one logit per example.
    """
    with tf.variable_scope('Discriminator', reuse=reuse):
      batch_norm_params = {'decay': 0.9,
                           'epsilon': 0.001,
                           'scope': 'batch_norm'}
      with slim.arg_scope([slim.conv2d],
                          kernel_size=[4, 4],
                          stride=[2, 2],
                          activation_fn=tf.nn.leaky_relu,
                          normalizer_fn=slim.batch_norm,
                          normalizer_params=batch_norm_params):
        # Copy the one-hot condition to every pixel and stack it with the image
        # as extra channels. tf.tile follows the batch size of `condition`, so
        # this works for any batch size, not only self.batch_size.
        condition_map = tf.tile(condition, [1, self.x_dim, self.x_dim, 1])
        self.inputs = tf.concat([data, condition_map], axis=3)
        # inputs: 28 x 28 x (1 + 10) dim
        # outputs: 14 x 14 x 64 dim (no batch norm on the first layer)
        self.layer1 = slim.conv2d(inputs=self.inputs,
                                  num_outputs=64,
                                  normalizer_fn=None,
                                  scope='layer1')
        # inputs: 14 x 14 x 64 dim
        # outputs: 7 x 7 x 128 dim
        self.layer2 = slim.conv2d(inputs=self.layer1,
                                  num_outputs=128,
                                  scope='layer2')
        # inputs: 7 x 7 x 128 dim
        # outputs: 3 x 3 x 256 dim
        self.layer3 = slim.conv2d(inputs=self.layer2,
                                  num_outputs=256,
                                  kernel_size=[3, 3],
                                  padding='VALID',
                                  scope='layer3')
        # inputs: 3 x 3 x 256 dim
        # outputs: 1 x 1 x 1 dim (raw logit: no normalization, no activation)
        self.layer4 = slim.conv2d(inputs=self.layer3,
                                  num_outputs=1,
                                  kernel_size=[3, 3],
                                  stride=[1, 1],
                                  padding='VALID',
                                  normalizer_fn=None,
                                  activation_fn=None,
                                  scope='layer4')
        # [batch, 1, 1, 1] -> [batch, 1]
        discriminator_logits = tf.squeeze(self.layer4, axis=[1, 2])

        return discriminator_logits


  def setup_global_step(self):
    """Sets up the global step Tensor (train mode only)."""
    if self.mode == "train":
      self.global_step = tf.train.get_or_create_global_step()

      print('complete setup global_step.')


  def GANLoss(self, logits, is_real=True, scope=None):
    """Computes the standard GAN loss between `logits` and constant labels.

    Args:
      logits (`Tensor`): discriminator logits.
      is_real (`bool`): True means `1` labeling, False means `0` labeling.
      scope (`string`): optional name scope for the loss ops.

    Returns:
      loss (`0-rank Tensor`): the standard GAN loss value (binary cross-entropy).
    """
    if is_real:
      labels = tf.ones_like(logits)
    else:
      labels = tf.zeros_like(logits)

    loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=labels,
                                           logits=logits,
                                           scope=scope)

    return loss


  def build(self):
    """Creates all ops for training or generate."""
    self.setup_global_step()

    if self.mode == "generate":
      pass

    else:
      # generating random vector
      self.random_z, self.sample_random_z = self.build_random_z_inputs()
      # read dataset and create conditional vector
      self.real_data, self.condition = self.read_MNIST(self.train_dataset)

      # generate images from random vector z, conditioned on the real labels
      self.generated_data = self.Generator(self.random_z, self.condition)

      # discriminate real data
      self.real_logits = self.Discriminator(self.real_data, self.condition)
      # discriminate fake (generated) data with the same Discriminator weights
      self.fake_logits = self.Discriminator(self.generated_data, self.condition, reuse=True)

      # losses of real with label "1"
      self.loss_real = self.GANLoss(logits=self.real_logits, is_real=True, scope='loss_D_real')
      # losses of fake with label "0"
      self.loss_fake = self.GANLoss(logits=self.fake_logits, is_real=False, scope='loss_D_fake')

      # losses of Discriminator
      with tf.variable_scope('loss_D'):
        self.loss_Discriminator = self.loss_real + self.loss_fake
      # losses of Generator: label fakes "1" so G learns to fool the Discriminator
      self.loss_Generator = self.GANLoss(logits=self.fake_logits, is_real=True, scope='loss_G')

      # Separate the trainable variables of each network (used as var_list)
      self.D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator')
      self.G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator')


      # generating images for sample: one image per class (identity matrix as
      # conditions); batch norm uses its moving averages here
      self.sample_condition = tf.eye(self.y_dim)
      self.sample_condition = tf.reshape(self.sample_condition, [-1, 1, 1, self.y_dim])
      self.sample_data = self.Generator(self.sample_random_z, self.sample_condition, is_training=False, reuse=True)

      # write summaries
      # Add loss summaries
      tf.summary.scalar('losses/loss_Discriminator', self.loss_Discriminator)
      tf.summary.scalar('losses/loss_Generator', self.loss_Generator)

      # Add histogram summaries
      for var in self.D_vars:
        tf.summary.histogram(var.op.name, var)
      for var in self.G_vars:
        tf.summary.histogram(var.op.name, var)

      # Add image summaries
      tf.summary.image('random_images', self.generated_data, max_outputs=4)
      tf.summary.image('real_images', self.real_data)

    print('complete model build.\n')

## Define plot function

In [None]:
def print_sample_data(sample_data, max_print=num_samples):
  """Show up to `max_print` generated 28 x 28 images side by side in one row.

  Args:
    sample_data (`np.ndarray`): generated images; each of the first
      `max_print` entries must hold 28 * 28 values (e.g. [N, 28, 28, 1]).
    max_print (`int`): maximum number of images to show.
  """
  # Never request more images than were provided, otherwise the reshape
  # below fails with a size mismatch.
  max_print = min(max_print, len(sample_data))
  if max_print == 0:
    return
  print_images = sample_data[:max_print, :]
  print_images = print_images.reshape([max_print, 28, 28])
  # [n, 28, 28] -> [28, n, 28] -> [28, n * 28]: lay the images out horizontally
  print_images = print_images.swapaxes(0, 1)
  print_images = print_images.reshape([28, max_print * 28])

  plt.figure(figsize=(max_print, 1))
  plt.axis('off')
  plt.imshow(print_images, cmap='gray')
  plt.show()

## Build a model

In [None]:
# Build the training graph: input pipeline, Generator, Discriminator, losses,
# sample images and summaries.
model = CGAN(mode="train", train_dataset=train_dataset)
model.build()

# show info for trainable variables
t_vars = tf.trainable_variables()
slim.model_analyzer.analyze_vars(t_vars, print_info=True)

In [None]:
# Separate Adam optimizers (and learning rates) for the Discriminator and the
# Generator, since the two networks are updated in separate steps.
opt_D = tf.train.AdamOptimizer(learning_rate=learning_rate_D, beta1=0.5)
opt_G = tf.train.AdamOptimizer(learning_rate=learning_rate_G, beta1=0.5)

In [None]:
# Run each network's batch-norm update ops (UPDATE_OPS under its scope)
# together with that network's optimizer step. D and G are optimized over
# disjoint variable lists, and only the G step increments the global step.
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='Discriminator')):
  opt_D_op = opt_D.minimize(model.loss_Discriminator, var_list=model.D_vars)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='Generator')):
  opt_G_op = opt_G.minimize(model.loss_Generator, global_step=model.global_step,
                            var_list=model.G_vars)

### Assign `tf.summary.FileWriter`

In [None]:
# TensorBoard event files (graph and summaries) are written to train_dir.
graph_location = train_dir
print('Saving graph to: %s' % graph_location)
train_writer = tf.summary.FileWriter(graph_location)
train_writer.add_graph(tf.get_default_graph())

### `tf.summary`

In [None]:
# One op for every summary registered during model.build(). Because it
# includes the 'real_images' summary of the input batch, running it also
# reads a batch from the input pipeline.
summary_op = tf.summary.merge_all()

### `tf.train.Saver`

In [None]:
# Save all global variables (weights, optimizer slots, batch-norm moving
# averages, global step); keep up to 1000 checkpoints.
saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000)

### `tf.Session` and train

In [None]:
sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
with tf.Session(config=sess_config) as sess:
  sess.run(tf.global_variables_initializer())
  tf.logging.info('Start Session.')

  num_examples = len(train_data)

  # save loss values for plot: each row is [epochs, loss_D, loss_G]
  loss_history = []
  pre_epochs = 0
  # Initialized up front so the end-of-data handler below can always log and
  # save, even if the input pipeline runs dry before the first full step.
  global_step_ = 0
  epochs = 0.0
  loss_D = float('nan')
  while True:
    try:
      start_time = time.time()

      # k Discriminator updates, then one Generator update. Each of these
      # runs reads a new batch from the input pipeline.
      for _ in range(k):
        _, loss_D = sess.run([opt_D_op, model.loss_Discriminator])
      _, loss_G = sess.run([opt_G_op, model.loss_Generator])
      duration = time.time() - start_time
      # Read the step only after the update has finished: fetching the
      # variable in the same run as the op that increments it is unordered,
      # so it could return either the old or the new value.
      global_step_ = sess.run(model.global_step)

      # "epochs" counts Generator steps; each step consumes k + 1 batches.
      epochs = global_step_ * batch_size / float(num_examples)

      if global_step_ % print_steps == 0:
        examples_per_sec = batch_size / float(duration)
        print("Epochs: {:.2f} global_step: {} loss_D: {:.3f} loss_G: {:.3f} ({:.2f} examples/sec; {:.3f} sec/batch)".format(
                  epochs, global_step_, loss_D, loss_G, examples_per_sec, duration))

        loss_history.append([epochs, loss_D, loss_G])

        # print sample data (one generated image per class)
        sample_data = sess.run(model.sample_data)
        print_sample_data(sample_data)

      # write summaries periodically
      if global_step_ % summary_steps == 0:
        summary_str = sess.run(summary_op)
        train_writer.add_summary(summary_str, global_step=global_step_)

      # save model checkpoint once each time a multiple of save_epochs is reached
      if int(epochs) % save_epochs == 0 and pre_epochs != int(epochs):
        tf.logging.info('Saving model with global step {} (= {} epochs) to disk.'.format(global_step_, int(epochs)))
        saver.save(sess, train_dir + 'model.ckpt', global_step=global_step_)
        pre_epochs = int(epochs)

    except tf.errors.OutOfRangeError:
      # The repeated dataset is exhausted: save a final checkpoint and stop.
      print("End of dataset")  # ==> "End of dataset"
      tf.logging.info('Saving model with global step {} (= {} epochs) to disk.'.format(global_step_, int(epochs)))
      saver.save(sess, train_dir + 'model.ckpt', global_step=global_step_)
      break

  tf.logging.info('complete training...')

## Plot loss functions

In [None]:
# Rows are [epochs, loss_D, loss_G]. reshape(-1, 3) yields a (0, 3) array when
# nothing was recorded, so the column indexing below cannot raise IndexError.
loss_history = np.asarray(loss_history).reshape(-1, 3)

plt.plot(loss_history[:, 0], loss_history[:, 1], label='loss_D')
plt.plot(loss_history[:, 0], loss_history[:, 2], label='loss_G')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend(loc='upper right')
plt.show()