# StarGAN [with celebA dataset]

* `StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation`, [arXiv:1711.09020](https://arxiv.org/abs/1711.09020)
  * Yunjey Choi, Minje Choi, Munyoung Kim, Jung-Woo Ha, Sunghun Kim, Jaegul Choo
  
* This code runs on TensorFlow 2.0
* Implemented by [`tf.keras.layers`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/layers) [`tf.losses`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/losses)

## Import modules

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import sys
import time
import glob

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import PIL
import imageio
from IPython import display
import urllib.request
import zipfile

import tensorflow as tf
from tensorflow.keras import layers

sys.path.append(os.path.dirname(os.path.abspath('.')))
from utils.image_utils import *
from utils.ops import *

os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
# Display the installed TensorFlow version.
tf.__version__

## Setting hyperparameters

In [None]:
# Training Flags (hyperparameter configuration)
model_name = 'stargan'
# Directory used for checkpoints (see `checkpoint_dir`).
train_dir = os.path.join('train', model_name, 'exp1')
dataset_name = 'celebA'
assert dataset_name in ['celebA']

# Learning-rate schedule: the helpers get_lr_D / get_lr_G keep the rate constant for
# `constant_lr_epochs` epochs worth of steps, then lower it over `decay_lr_epochs` epochs.
constant_lr_epochs = 10
decay_lr_epochs = 10
max_epochs = constant_lr_epochs + decay_lr_epochs
save_model_epochs = 2   # checkpoint period, in epochs
print_steps = 10        # logging period, in global steps
save_images_epochs = 1  # sample-image saving period, in epochs
batch_size = 16
learning_rate_D = 1e-4
learning_rate_G = 1e-4
k = 1 # number of discriminator steps taken before each generator step
num_examples_to_generate = 1  # batch size of the test dataset

BUFFER_SIZE = 10000  # shuffle buffer size
IMG_SIZE = 128       # height and width of the (square) images
num_domain = 5       # number of domains, i.e. depth of the one-hot labels
LAMBDA_class = 1            # weight of the domain classification loss
LAMBDA_reconstruction = 10  # weight of the cycle (reconstruction) loss
gp_lambda = 10              # weight of the gradient penalty

## Load the dataset

You can download celebA dataset from [here](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/). 

As mentioned in the [paper](https://arxiv.org/abs/1703.10593) we apply random jittering and mirroring to the training dataset.
* In random jittering, the image is resized to 286 x 286 and then randomly cropped to 256 x 256
* In random mirroring, the image is randomly flipped horizontally i.e left to right.

#### In this notebook, random data is created as a stand-in for the real dataset

In [None]:
N = 2000


def _random_split(num_examples):
  """Draw `num_examples` random images in [-1, 1) and random integer domain labels."""
  fake_images = np.random.uniform(
      low=-1., high=1., size=[num_examples, IMG_SIZE, IMG_SIZE, 3]).astype(np.float32)
  # Uniform floats in [0, num_domain) truncated to int32 give labels 0 .. num_domain - 1.
  fake_labels = np.random.uniform(
      low=0, high=num_domain, size=[num_examples]).astype(np.int32)
  return fake_images, fake_labels


train_images, train_labels = _random_split(N)

test_images, test_labels = _random_split(N)

## Set up dataset with `tf.data`

### Use tf.data to create batches, map(do preprocessing) and shuffle the dataset

In [None]:
def preprocessing(image, label):
  """Return the image unchanged together with the one-hot encoding of its label.

  Args:
    image: image tensor, passed through untouched.
    label: integer domain index.

  Returns:
    (image, one_hot_label) where the one-hot vector has depth `num_domain`.
  """
  return image, tf.one_hot(label, depth=num_domain)

In [None]:
# Training pipeline: slice -> shuffle -> one-hot labels -> fixed-size batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    .shuffle(BUFFER_SIZE)
    .map(preprocessing)
    .batch(batch_size, drop_remainder=True)
)

In [None]:
# Test pipeline: same steps as training, batched by `num_examples_to_generate`.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_images, test_labels))
    .shuffle(BUFFER_SIZE)
    .map(preprocessing)
    .batch(num_examples_to_generate, drop_remainder=True)
)

## Write the generator and discriminator models

### Generator

* The architecture of generator is similar to [Johnson's architecture](https://arxiv.org/abs/1603.08155).
* Conv block in the generator is (Conv -> InstanceNorm -> ReLU)
* Res block in the generator is (Conv -> IN -> ReLU -> Conv -> IN -> add X -> ReLU)
* ConvTranspose block in the generator is (Transposed Conv -> IN -> ReLU) (except last layer: tanh)

In [None]:
class InstanceNormalization(layers.Layer):
  """Instance normalization for rank-4 (batch, height, width, channels) tensors.

  Each sample is normalized per channel over its spatial axes (1, 2) and then
  scaled and shifted by the learned per-channel `gamma` and `beta`.

  Args:
    epsilon: small constant added to the variance for numerical stability.
    **kwargs: standard `tf.keras.layers.Layer` keyword arguments (e.g. `name`).
  """
  def __init__(self, epsilon=1e-5, **kwargs):
    super(InstanceNormalization, self).__init__(**kwargs)
    self.epsilon = epsilon

  def build(self, input_shape):
    # dimension_value() returns a plain int (or None) under both TF1 and TF2 shape behavior.
    channels = tf.compat.dimension_value(tf.TensorShape(input_shape)[-1])
    if channels is None:
      raise ValueError('The channel (last) dimension of the input must be defined.')
    param_shape = (channels,)
    # Per-channel scale and offset.
    self.gamma = self.add_weight(name='gamma',
                                 shape=param_shape,
                                 initializer='ones',
                                 trainable=True)
    self.beta = self.add_weight(name='beta',
                                shape=param_shape,
                                initializer='zeros',
                                trainable=True)
    # Make sure to call the `build` method at the end
    super(InstanceNormalization, self).build(input_shape)

  def call(self, inputs):
    # Statistics are taken over the spatial axes only, separately per sample and channel.
    mean, variance = tf.nn.moments(inputs, [1, 2], keepdims=True)
    normalized = (inputs - mean) / tf.sqrt(variance + self.epsilon)
    return self.gamma * normalized + self.beta

  def get_config(self):
    # Expose `epsilon` so the layer can be re-created from its config.
    config = super(InstanceNormalization, self).get_config()
    config.update({'epsilon': self.epsilon})
    return config

In [None]:
class Conv(tf.keras.Model):
  """Convolution block: Conv2D -> optional instance norm -> optional activation.

  Args:
    filters: number of output channels.
    size: side length of the square kernel.
    strides: convolution stride.
    padding: padding mode forwarded to `layers.Conv2D`.
    activation: one of 'relu', 'tanh', 'leaky_relu', 'none'.
    apply_norm: 'instance' to normalize after the convolution, 'none' to skip it.
  """
  def __init__(self, filters, size, strides=1, padding='same',
               activation='relu', apply_norm='instance'):
    super(Conv, self).__init__()
    assert apply_norm in ['instance', 'none']
    self.apply_norm = apply_norm
    assert activation in ['relu', 'tanh', 'leaky_relu', 'none']
    self.activation = activation

    # The conv bias is only enabled when no normalization follows.
    self.conv = layers.Conv2D(filters=filters,
                              kernel_size=(size, size),
                              strides=strides,
                              padding=padding,
                              kernel_initializer=tf.random_normal_initializer(0., 0.02),
                              use_bias=(self.apply_norm == 'none'))

    if self.apply_norm == 'instance':
      self.instancenorm = InstanceNormalization()

  def call(self, x):
    out = self.conv(x)

    if self.apply_norm == 'instance':
      out = self.instancenorm(out)

    # Early returns per activation; 'none' falls through untouched.
    if self.activation == 'relu':
      return tf.nn.relu(out)
    if self.activation == 'tanh':
      return tf.nn.tanh(out)
    if self.activation == 'leaky_relu':
      return tf.nn.leaky_relu(out, alpha=0.01)
    return out

In [None]:
class ResBlock(tf.keras.Model):
  """Residual block: Conv(relu) -> Conv(no activation), added to the input, then ReLU.

  Args:
    filters: number of channels of both convolutions (must match the input channels
      for the skip connection to add up).
    size: side length of the square kernels.
  """
  def __init__(self, filters, size):
    super(ResBlock, self).__init__()
    self.conv1 = Conv(filters, size, activation='relu')
    self.conv2 = Conv(filters, size, activation='none')

  def call(self, x):
    residual = self.conv2(self.conv1(x))
    return tf.nn.relu(x + residual)

In [None]:
class ConvTranspose(tf.keras.Model):
  """Upsampling block: stride-2 Conv2DTranspose -> optional instance norm -> ReLU.

  Args:
    filters: number of output channels.
    size: side length of the square kernel.
    apply_norm: 'instance' to normalize after the transposed convolution, 'none' to skip it.
  """
  def __init__(self, filters, size, apply_norm='instance'):
    super(ConvTranspose, self).__init__()
    assert apply_norm in ['instance', 'none']
    self.apply_norm = apply_norm
    self.up_conv = layers.Conv2DTranspose(
        filters=filters,
        kernel_size=(size, size),
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False)

    if self.apply_norm == 'instance':
      self.instancenorm = InstanceNormalization()

  def call(self, x):
    upsampled = self.up_conv(x)
    if self.apply_norm == 'instance':
      upsampled = self.instancenorm(upsampled)
    return tf.nn.relu(upsampled)

In [None]:
class Generator(tf.keras.Model):
  """Generator: translates an image to the domain given by a label vector.

  The label vector is broadcast over the spatial grid and concatenated to the image
  channels; the result goes through 3 Conv blocks (two of them stride 2), 6 residual
  blocks, 2 transposed-conv blocks and a final tanh Conv block.
  """
  def __init__(self):
    super(Generator, self).__init__()
    self.down1 = Conv(64, 7)
    self.down2 = Conv(128, 4, 2)
    self.down3 = Conv(256, 4, 2)

    self.res1 = ResBlock(256, 3)
    self.res2 = ResBlock(256, 3)
    self.res3 = ResBlock(256, 3)
    self.res4 = ResBlock(256, 3)
    self.res5 = ResBlock(256, 3)
    self.res6 = ResBlock(256, 3)

    self.up1 = ConvTranspose(128, 4)
    self.up2 = ConvTranspose(64, 3)
    # NOTE(review): Conv defaults to apply_norm='instance', so instance norm is applied
    # right before tanh here — confirm this is intended for the output layer.
    self.last = Conv(3, 7, activation='tanh')

  def call(self, images, labels):
    """Generate images in the domain described by `labels`.

    Args:
      images: (bs, H, W, 3) input images.
      labels: (bs, num_domain) domain label vectors.

    Returns:
      (bs, H, W, 3) generated images in [-1, 1] (tanh output).
    """
    # Broadcast the labels over the spatial grid: (bs, num_domain) -> (bs, H, W, num_domain).
    # Using ones_like on an image slice keeps this valid for an unknown batch size
    # and for any spatial size, instead of relying on images.shape[0] and IMG_SIZE.
    labels = tf.cast(labels, images.dtype)
    label_maps = labels[:, tf.newaxis, tf.newaxis, :] * tf.ones_like(images[..., :1])
    x = tf.concat([images, label_maps], axis=3)  # (bs, H, W, 3 + num_domain)

    # Shape comments below assume H = W = 128.
    x1 = self.down1(x)     # x1 shape: (bs, 128, 128, 64)
    x2 = self.down2(x1)    # x2 shape: (bs, 64, 64, 128)
    x3 = self.down3(x2)    # x3 shape: (bs, 32, 32, 256)

    x4 = self.res1(x3)     # x4 shape: (bs, 32, 32, 256)
    x5 = self.res2(x4)     # x5 shape: (bs, 32, 32, 256)
    x6 = self.res3(x5)     # x6 shape: (bs, 32, 32, 256)
    x7 = self.res4(x6)     # x7 shape: (bs, 32, 32, 256)
    x8 = self.res5(x7)     # x8 shape: (bs, 32, 32, 256)
    x9 = self.res6(x8)     # x9 shape: (bs, 32, 32, 256)

    x10 = self.up1(x9)     # x10 shape: (bs, 64, 64, 128)
    x11 = self.up2(x10)    # x11 shape: (bs, 128, 128, 64)
    generated_images = self.last(x11) # generated_images shape: (bs, 128, 128, 3)

    return generated_images

In [None]:
# Grab one training batch; `images` and `labels` stay bound after the loop and are
# used by the model smoke tests in the following cells.
for images, labels in train_dataset.take(1):
  pass

In [None]:
# Create the generator and run it once on the sample batch (this also builds its weights).
generator = Generator()

gen_output = generator(images, labels)
# The generator ends in tanh, so its output lies in [-1, 1]; rescale to [0, 1],
# the range imshow expects for float RGB data (values outside it are clipped).
plt.imshow((gen_output[0, ...] + 1.) / 2.)

### Discriminator

* The Discriminator is a variation of PatchGAN.
* Each block in the discriminator is (Conv -> Leaky ReLU), **NO** normalization
* The shape of the output after the last layer is (batch_size, 2, 2, 1)

To learn more about the architecture and the hyperparameters you can refer the [paper](https://arxiv.org/abs/1711.09020).

In [None]:
class Discriminator(tf.keras.Model):
  """Discriminator with two heads: a real/fake score map and domain class logits.

  Six stride-2 Conv blocks (leaky ReLU, no normalization) are followed by a
  3x3 'same' convolution for the source head and a 2x2 'valid' convolution for
  the classification head.
  """
  def __init__(self):
    super(Discriminator, self).__init__()
    self.down1 = Conv(64, 4, 2, activation='leaky_relu', apply_norm='none')
    self.down2 = Conv(128, 4, 2, activation='leaky_relu', apply_norm='none')
    self.down3 = Conv(256, 4, 2, activation='leaky_relu', apply_norm='none')
    self.down4 = Conv(512, 4, 2, activation='leaky_relu', apply_norm='none')
    self.down5 = Conv(1024, 4, 2, activation='leaky_relu', apply_norm='none')
    self.down6 = Conv(2048, 4, 2, activation='leaky_relu', apply_norm='none')

    self.source = Conv(1, 3, activation='none', apply_norm='none')
    # One logit per domain (was hard-coded to 5, which silently diverges from
    # the one-hot label depth if `num_domain` is changed).
    # The 2x2 'valid' kernel collapses the 2x2 feature map of a 128x128 input to 1x1.
    self.classification = Conv(num_domain, 2, padding='valid', activation='none', apply_norm='none')

  @tf.function
  def call(self, x):
    """Score a batch of images.

    Args:
      x: (bs, 128, 128, 3) images.

    Returns:
      disc_logits: (bs, 2, 2, 1) real/fake logits.
      classification_logits: (bs, num_domain) domain logits.
    """
    # x shape == (bs, 128, 128, 3)
    x = self.down1(x) # (bs, 64, 64, 64)
    x = self.down2(x) # (bs, 32, 32, 128)
    x = self.down3(x) # (bs, 16, 16, 256)
    x = self.down4(x) # (bs, 8, 8, 512)
    x = self.down5(x) # (bs, 4, 4, 1024)
    x = self.down6(x) # (bs, 2, 2, 2048)

    disc_logits = self.source(x)                   # (bs, 2, 2, 1)
    classification_logits = self.classification(x) # (bs, 1, 1, num_domain)
    classification_logits = tf.squeeze(classification_logits, axis=[1, 2])

    return disc_logits, classification_logits

In [None]:
# Create the discriminator and run it once on the sample batch.
discriminator = Discriminator()

#disc_out = discriminator(images[tf.newaxis,...], training=False)
# disc_out1: real/fake logit map, disc_out2: domain classification logits.
disc_out1, disc_out2 = discriminator(images)
print(disc_out2[0])
# Show the logit map of the first image on a diverging color scale.
plt.imshow(disc_out1[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()

## Model summary

In [None]:
# Print the Keras summary of the generator.
generator.summary()

In [None]:
# Print the Keras summary of the discriminator.
discriminator.summary()

## Define the loss functions and the optimizer

* **Discriminator loss**
  * The discriminator loss function takes the discriminator outputs for the real images and for the generated images, plus the original domain labels of the real images.
  * real_loss is the Wasserstein (WGAN) loss on the real images: the negative mean of the real logits.
  * fake_loss is the Wasserstein (WGAN) loss on the generated images: the mean of the fake logits.
  * domain_class_loss is a sigmoid cross entropy loss between the class logits of the real images and their original domain labels.
  * The total loss is real_loss + fake_loss + LAMBDA_class * domain_class_loss; the training step additionally adds gp_lambda * gradient_penalty.
* **Generator loss**
  * gan_loss is the negative mean of the fake logits, i.e. the generator is trained to make the discriminator score generated images as real.
  * domain_class_loss is a sigmoid cross entropy loss between the class logits of the generated images and the target domain labels.
  * cycle_loss is the L1 loss, i.e. the MAE (mean absolute error), between the input image and the image translated to the target domain and back to the original domain.
  * The formula to calculate the total generator loss = gan_loss + LAMBDA_class * domain_class_loss + LAMBDA_reconstruction * cycle_loss, where LAMBDA_class = 1 and LAMBDA_reconstruction = 10.

In [None]:
# Shared Keras loss objects; they are called as loss(y_true, y_pred).
# Binary cross entropy expects raw logits as predictions.
bce_object = tf.losses.BinaryCrossentropy(from_logits=True)
mse_object = tf.losses.MeanSquaredError()
mae_object = tf.losses.MeanAbsoluteError()

In [None]:
def GANLoss(logits, is_real=True, use_lsgan=True):
  """Compute the standard GAN or the LSGAN loss of `logits` against constant labels.

  Args:
    logits (`Tensor`): discriminator logits.
    is_real (`bool`): True compares against all-ones labels, False against all-zeros.
    use_lsgan (`bool`): True uses mean squared error on sigmoid(logits) (LSGAN),
      False uses binary cross entropy on the raw logits (standard GAN).

  Returns:
    loss (`0-rank Tensor`): the selected loss value.
  """
  labels = tf.ones_like(logits) if is_real else tf.zeros_like(logits)

  if not use_lsgan:
    return bce_object(y_true=labels, y_pred=logits)
  return mse_object(y_true=labels, y_pred=tf.nn.sigmoid(logits))

In [None]:
def WGANLoss(logits, is_real=True):
  """Compute the Wasserstein GAN loss.

  Args:
    logits (`Tensor`): discriminator (critic) logits.
    is_real (`bool`): True returns `-mean(logits)`, False returns `+mean(logits)`.

  Returns:
    loss (`0-rank Tensor`): the WGAN loss value.
  """
  mean_logit = tf.reduce_mean(logits)
  return -mean_logit if is_real else mean_logit

In [None]:
def discriminator_loss(real_logits, fake_logits, real_class_logits, original_labels):
  """Discriminator loss: WGAN terms plus weighted domain classification on real images.

  Args:
    real_logits: discriminator source logits for real images.
    fake_logits: discriminator source logits for generated images.
    real_class_logits: discriminator class logits for real images.
    original_labels: domain labels of the real images (same shape as the class logits).

  Returns:
    Scalar loss: real_loss + fake_loss + LAMBDA_class * domain_class_loss.
    The gradient penalty is not included here.
  """
  # Wasserstein terms: push real scores up (-mean) and fake scores down (+mean).
  real_loss = WGANLoss(logits=real_logits, is_real=True)
  fake_loss = WGANLoss(logits=fake_logits, is_real=False)

  # Domain classification loss. Keras losses are called as (y_true, y_pred);
  # the labels are the targets and the logits are the predictions.
  domain_class_loss = bce_object(y_true=original_labels, y_pred=real_class_logits)

  return real_loss + fake_loss + (LAMBDA_class * domain_class_loss)

In [None]:
def cycle_consistency_loss(X, X2Y2X):
  """Return the L1 (mean absolute error) distance between `X` and its reconstruction.

  Args:
    X: original images.
    X2Y2X: images translated to another domain and back.

  Returns:
    Scalar mean absolute error. (An L2 variant would use `mse_object` instead.)
  """
  return mae_object(y_true=X, y_pred=X2Y2X)

In [None]:
def generator_loss(fake_logits, fake_class_logits, target_domain, input_images, generated_images_o2t2o):
  """Generator loss: WGAN term + weighted domain classification + weighted cycle loss.

  Args:
    fake_logits: discriminator source logits for generated images.
    fake_class_logits: discriminator class logits for generated images.
    target_domain: target domain labels the images were translated to.
    input_images: original input images.
    generated_images_o2t2o: images translated to the target domain and back.

  Returns:
    Scalar loss: gan_loss + LAMBDA_class * domain_class_loss
    + LAMBDA_reconstruction * cycle_loss.
  """
  # The generator wants its outputs scored like real images, hence is_real=True.
  gan_loss = WGANLoss(logits=fake_logits, is_real=True)

  # Domain classification loss. Keras losses are called as (y_true, y_pred);
  # the target labels are the targets and the logits are the predictions.
  domain_class_loss = bce_object(y_true=target_domain, y_pred=fake_class_logits)

  # mean absolute error between the input and its reconstruction
  cycle_loss = cycle_consistency_loss(input_images, generated_images_o2t2o)

  return gan_loss + (LAMBDA_class * domain_class_loss) + (LAMBDA_reconstruction * cycle_loss)

### Define learning rate decay functions

In [None]:
# Non-trainable step counter; advanced by the training loop and read by the
# learning-rate helpers.
global_step = tf.Variable(0, trainable=False)

In [None]:
lr_D = learning_rate_D
def get_lr_D(global_step):
  """Return the discriminator learning rate for `global_step`.

  The rate stays at `learning_rate_D` for the first `constant_lr_epochs` epochs
  worth of steps and then decays linearly towards 0 over `decay_lr_epochs` epochs.
  The value is computed from the step alone, so calling this function several
  times (or not at all) for a given step does not change the schedule, and the
  result never drops below 0.

  Args:
    global_step: step counter; a `tf.Variable`/tensor exposing `.numpy()` or a plain int.

  Returns:
    The learning rate as a Python float (also stored in the global `lr_D`).
  """
  global lr_D
  step = int(global_step.numpy()) if hasattr(global_step, 'numpy') else int(global_step)
  num_steps_per_epoch = int(N / batch_size)
  constant_steps = num_steps_per_epoch * constant_lr_epochs
  decay_steps = num_steps_per_epoch * decay_lr_epochs

  if step <= constant_steps or decay_steps <= 0:
    # Constant phase (or no decay phase configured at all).
    lr_D = learning_rate_D
  else:
    # Linear decay, clamped so the rate cannot become negative.
    remaining = max(0., 1. - (step - constant_steps) / float(decay_steps))
    lr_D = learning_rate_D * remaining
  return lr_D

In [None]:
lr_G = learning_rate_G
def get_lr_G(global_step):
  """Return the generator learning rate for `global_step`.

  The rate stays at `learning_rate_G` for the first `constant_lr_epochs` epochs
  worth of steps and then decays linearly towards 0 over `decay_lr_epochs` epochs.
  The value is computed from the step alone, so calling this function several
  times (or not at all) for a given step does not change the schedule, and the
  result never drops below 0.

  Args:
    global_step: step counter; a `tf.Variable`/tensor exposing `.numpy()` or a plain int.

  Returns:
    The learning rate as a Python float (also stored in the global `lr_G`).
  """
  global lr_G
  step = int(global_step.numpy()) if hasattr(global_step, 'numpy') else int(global_step)
  num_steps_per_epoch = int(N / batch_size)
  constant_steps = num_steps_per_epoch * constant_lr_epochs
  decay_steps = num_steps_per_epoch * decay_lr_epochs

  if step <= constant_steps or decay_steps <= 0:
    # Constant phase (or no decay phase configured at all).
    lr_G = learning_rate_G
  else:
    # Linear decay, clamped so the rate cannot become negative.
    remaining = max(0., 1. - (step - constant_steps) / float(decay_steps))
    lr_G = learning_rate_G * remaining
  return lr_G

In [None]:
# NOTE: get_lr_D / get_lr_G are evaluated once, right here, so each optimizer starts
# from a plain float; the rate only changes later if a new value is assigned to it.
discriminator_optimizer = tf.keras.optimizers.Adam(
    learning_rate=get_lr_D(global_step), beta_1=0.5)
generator_optimizer = tf.keras.optimizers.Adam(
    learning_rate=get_lr_G(global_step), beta_1=0.5)

## Checkpoints (Object-based saving)

In [None]:
# Checkpoints are written under `train_dir`; create the directory on first use.
checkpoint_dir = train_dir
if not tf.io.gfile.exists(checkpoint_dir):
  tf.io.gfile.makedirs(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track both models and both optimizers so they are saved and restored together.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

## Define generate_and_print_or_save functions

In [None]:
def generate_and_print_or_save(inputs, lables, target_domain=None,
                               is_save=False, epoch=None, checkpoint_dir=checkpoint_dir):
  """Translate `inputs` to a target domain and back, then show or save the results.

  Args:
    inputs: batch of input images.
    lables: one-hot original-domain labels of `inputs` (parameter name kept as is
      for backward compatibility).
    target_domain: one-hot target-domain labels; drawn uniformly at random when None.
    is_save: forwarded to the plotting helper to save instead of only displaying.
    epoch: epoch number forwarded to the plotting helper.
    checkpoint_dir: directory forwarded to the plotting helper.
  """
  n = inputs.shape[0]
  if target_domain is None:
    target_domain = tf.random.uniform(shape=[n], minval=0, maxval=num_domain, dtype=tf.int32)
    target_domain = tf.one_hot(target_domain, depth=num_domain)

  assert n == target_domain.shape[0]
  generated_images_o2t = generator(inputs, target_domain)
  generated_images_o2t2o = generator(generated_images_o2t, lables)

  # Show the images that were actually translated. The previous code always passed
  # the global `const_test_inputs`, which mismatched the outputs whenever another
  # batch was given.
  print_or_save_sample_images_pix2pix(inputs, generated_images_o2t, generated_images_o2t2o,
                                      model_name='stargan', name=None,
                                      is_save=is_save, epoch=epoch, checkpoint_dir=checkpoint_dir)

In [None]:
# Keep one fixed test batch (images and labels) for generation (prediction) so
# it is easier to see the improvement of the model across epochs.
for inputs, labels in test_dataset.take(1):
  const_test_inputs = inputs
  const_test_labels = labels
  
# Fixed random target domains, one per constant test image, as one-hot vectors.
const_target_domains = tf.random.uniform(shape=[const_test_inputs.shape[0]], minval=0, maxval=num_domain, dtype=tf.int32)
const_target_domains = tf.one_hot(const_target_domains, depth=num_domain)

In [None]:
# Check the pipeline on the constant test data: original -> target domain -> original
generate_and_print_or_save(const_test_inputs, const_test_labels, const_target_domains)

## Training

### Define training one step function

In [None]:
@tf.function()
def discriminator_train_step(input_images, labels):
  """Run one discriminator update (WGAN loss + domain classification + gradient penalty).

  Args:
    input_images: batch of real images.
    labels: one-hot original-domain labels of `input_images`.

  Returns:
    (gen_loss, disc_loss): the generator loss (reported only, no generator update
    happens here) and the discriminator loss that was minimized.
  """
  # Use the actual batch size of the inputs rather than the global `batch_size`.
  num_images = tf.shape(input_images)[0]

  # generating target domain
  target_domain = tf.random.uniform(shape=[num_images], minval=0, maxval=num_domain, dtype=tf.int32)
  target_domain = tf.one_hot(target_domain, depth=num_domain)

  # The generator is not updated in this step, so its forward passes do not need
  # to be recorded on the discriminator tape.
  # Image generation from original domain to target domain
  generated_images_o2t = generator(input_images, target_domain)
  # Image generation from target domain to original domain
  generated_images_o2t2o = generator(generated_images_o2t, labels)

  with tf.GradientTape() as disc_tape:
    real_logits, real_class_logits = discriminator(input_images)
    fake_logits, fake_class_logits = discriminator(generated_images_o2t)

    # interpolation of x hat for gradient penalty : epsilon * real image + (1 - epsilon) * generated image
    # One epsilon per sample, broadcast over height, width and channels.
    # The real images must be this step's `input_images`, not the notebook-global `images`.
    epsilon = tf.random.uniform([num_images, 1, 1, 1])
    interpolated_images_4gp = epsilon * input_images + (1. - epsilon) * generated_images_o2t
    with tf.GradientTape() as gp_tape:
      gp_tape.watch(interpolated_images_4gp)
      interpolated_images_logits, _ = discriminator(interpolated_images_4gp)

    # Taken inside `disc_tape` so the penalty itself is differentiable w.r.t. the weights.
    gradients_of_interpolated_images = gp_tape.gradient(interpolated_images_logits, interpolated_images_4gp)
    # The small constant keeps the sqrt gradient finite when a gradient norm is exactly 0.
    norm_grads = tf.sqrt(
        tf.reduce_sum(tf.square(gradients_of_interpolated_images), axis=[1, 2, 3]) + 1e-12)
    gradient_penalty_loss = tf.reduce_mean(tf.square(norm_grads - 1.))

    disc_loss = discriminator_loss(real_logits, fake_logits, real_class_logits, labels) + \
                    gp_lambda * gradient_penalty_loss

  # Reported for logging only.
  gen_loss = generator_loss(fake_logits, fake_class_logits, target_domain, input_images, generated_images_o2t2o)

  gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
  discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

  return gen_loss, disc_loss

In [None]:
@tf.function()
def generator_train_step(input_images, labels):
  """Run one generator update.

  Args:
    input_images: batch of real images.
    labels: one-hot original-domain labels of `input_images`.

  Returns:
    gen_loss: the generator loss that was minimized (callers may ignore it).
  """
  # Use the actual batch size of the inputs rather than the global `batch_size`.
  num_images = tf.shape(input_images)[0]

  # generating target domain
  target_domain = tf.random.uniform(shape=[num_images], minval=0, maxval=num_domain, dtype=tf.int32)
  target_domain = tf.one_hot(target_domain, depth=num_domain)

  with tf.GradientTape() as gen_tape:
    # Image generation from original domain to target domain
    generated_images_o2t = generator(input_images, target_domain)
    # Image generation from target domain to original domain
    generated_images_o2t2o = generator(generated_images_o2t, labels)

    # Only the scores of the generated images enter the generator loss, so the
    # discriminator pass on the real images (whose outputs were unused) is dropped.
    fake_logits, fake_class_logits = discriminator(generated_images_o2t)

    gen_loss = generator_loss(fake_logits, fake_class_logits, target_domain, input_images, generated_images_o2t2o)

  gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
  generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))

  return gen_loss

### Training until max_epochs

In [None]:
print('Start Training.')
num_batches_per_epoch = int(N / batch_size)
num_learning_critic = 0
for epoch in range(max_epochs):

  for step, (images, labels) in enumerate(train_dataset):
    start_time = time.time()

    # Push the scheduled learning rates into the optimizers. The optimizers were
    # built from a plain float, so without this the schedule is never applied.
    discriminator_optimizer.lr.assign(get_lr_D(global_step))
    generator_optimizer.lr.assign(get_lr_G(global_step))

    # Alternate: `k` discriminator steps, then one generator step.
    if num_learning_critic < k:
      gen_loss, disc_loss = discriminator_train_step(images, labels)
      num_learning_critic += 1
    else:
      generator_train_step(images, labels)
      num_learning_critic = 0

    # Count every consumed batch. The schedule thresholds are expressed in batches
    # per epoch (N / batch_size); counting discriminator steps only meant the decay
    # phase was never reached within `max_epochs`.
    global_step.assign_add(1)

    # print the result images every print_steps
    if global_step.numpy() % print_steps == 0:
      epochs = epoch + step / float(num_batches_per_epoch)
      duration = time.time() - start_time
      examples_per_sec = batch_size / float(duration)
      display.clear_output(wait=True)
      print("Epochs: {:.2f} lr: {:.3g}, {:.3g}, global_step: {} loss_D: {:.3g} loss_G: {:.3g} ({:.2f} examples/sec; {:.3f} sec/batch)".format(
                epochs, generator_optimizer.lr.numpy(), discriminator_optimizer.lr.numpy(), global_step.numpy(), disc_loss, gen_loss, examples_per_sec, duration))
      # generate image to target domain for test_dataset
      for test_inputs, test_labels in test_dataset.take(1):
        generate_and_print_or_save(test_inputs, test_labels)

  # saving the result image files every save_images_epochs
  if (epoch + 1) % save_images_epochs == 0:
    display.clear_output(wait=True)
    print("This images are saved at {} epoch".format(epoch+1))
    generate_and_print_or_save(const_test_inputs, const_test_labels, const_target_domains,
                               is_save=True, epoch=epoch+1, checkpoint_dir=checkpoint_dir)

  # saving (checkpoint) the model every save_epochs
  if (epoch + 1) % save_model_epochs == 0:
    checkpoint.save(file_prefix=checkpoint_prefix)

print('Training Done.')

In [None]:
# Clear the cell output and show the translation of the constant test batch
# after the final epoch.
display.clear_output(wait=True)
generate_and_print_or_save(const_test_inputs, const_test_labels, const_target_domains)

## Restore the latest checkpoint

In [None]:
# Restore the models and optimizers from the latest checkpoint found in checkpoint_dir.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Display an image using the epoch number

In [None]:
# Presumably shows the sample image saved for the last epoch — display_image comes
# from the star-imported utils; verify there.
display_image(max_epochs, checkpoint_dir=checkpoint_dir)

## Generate a GIF of all the saved images.

In [None]:
# Output name: '<model_name>_<dataset_name>.gif', i.e. 'stargan_celebA.gif'.
filename = model_name + '_' + dataset_name + '.gif'
# generate_gif comes from the star-imported utils; presumably it assembles the
# images saved under checkpoint_dir into the GIF — verify there.
generate_gif(filename, checkpoint_dir)

In [None]:
# NOTE(review): this opens 'stargan_celebA.gif.png'; presumably generate_gif() also
# writes that '.png' copy for notebook display — verify in utils.image_utils.
display.Image(filename=filename + '.png')