# WGAN for creating cartoon images out of celebrity images
### Authors: Memeti Nurdzane, Wróbel Anna
The aim of this project is to transfer cartoon style onto celebrity images.
It uses two datasets: CelebA (aligned and cropped celebrity faces) and the cartoon10k dataset (https://google.github.io/cartoonset/download.html).
Information about image attributes was not used. In this part we implement and tune a WGAN version of CycleGAN.

The differences from the standard CycleGAN implementation are:
- the activation function of the discriminator's last layer is changed to linear,
- Wasserstein loss is used for the generators and discriminators,
- the discriminator weights are clipped (Lipschitz constraint),
- RMSprop optimizers with no momentum are used.


### Prepare environment

Mount drive and change directories.


In [0]:
# Mount Google Drive so the notebook can read the datasets and persist results.
from google.colab import drive
# drive.flush_and_unmount()
drive.mount('/content/drive')
# Work from the project folder; the relative checkpoint path used later
# resolves against this directory.
%cd /content/drive/My\ Drive/Colab/ANN/
!ls | head

Import tensorflow and check if GPU is available.


In [0]:
import tensorflow as tf
print(tf.__version__)
# Fail fast when TensorFlow does not see any GPU device.
assert len(tf.config.list_physical_devices('GPU')) > 0
print('GPU device name: ' + tf.test.gpu_device_name())
# Print the NVIDIA driver's view of the attached GPU.
!nvidia-smi

Import libraries needed for the project.

In [0]:
import os
import time
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output


## Prepare dataset

Create datasets as tf datasets together with image preprocessing:
- random jitter: the image is resized to 286 x 286 and then randomly cropped to 256 x 256
- random mirroring, the image is randomly flipped horizontally 
- normalization to range -1, 1

In [0]:
import glob
import tensorflow as tf

def load(filelist):
  """Builds a dataset of augmented, normalized images from JPEG file paths.

  Every file is decoded to RGB and resized to 286 x 286. This deterministic
  part is cached inside the pipeline. The random 256 x 256 crop and the random
  horizontal flip run after the cache, so they are re-drawn on every pass over
  the dataset instead of being frozen by the first epoch.

  Args:
    filelist: non-empty sequence of paths to JPEG files.
  Returns:
    Unbatched tf.data.Dataset of float32 images of shape (256, 256, 3) with
    values scaled to [-1, 1].
  Raises:
    ValueError: if `filelist` is empty.
  """
  if len(filelist) == 0:
    raise ValueError('filelist must contain at least one image path')

  def _decode(filename):
    # Deterministic part: read, decode and resize (safe to cache).
    image_string = tf.io.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    return tf.image.resize(image_decoded, [286, 286],
                           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  def _augment(image):
    # Random part: jitter (crop), mirroring, then scaling to [-1, 1].
    cropped_image = tf.image.random_crop(image, size=[256, 256, 3])
    flipped_image = tf.image.random_flip_left_right(cropped_image)
    casted_image = tf.cast(flipped_image, tf.float32)
    return (casted_image / 127.5) - 1

  dataset = tf.data.Dataset.from_tensor_slices(tf.constant(filelist))
  # Cache BEFORE the random ops; caching after them would replay the very
  # same crop/flip for each image in every epoch.
  return dataset.map(_decode).cache().map(_augment)

BUFFER_SIZE = 1000
BATCH_SIZE = 1

n = 1000     # number of train images in one dataset
k = 100      # number of test images in one dataset

# Create cartoon file list, train set (n images) and test set (k images).
# The list is sorted because glob returns files in arbitrary order; without it
# the train/test split could differ between sessions that restore a checkpoint.
filelist = sorted(glob.glob('/content/drive/My Drive/Colab/ANN/cartoons/*.jpg'))
cartoons_train = load(filelist[0:n]).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
cartoons_test = load(filelist[n:n+k]).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# Create celebrity file list, train set (n images) and test set (k images)
filelist = sorted(glob.glob('/content/drive/My Drive/Colab/ANN/celebrities/*.jpg'))
celebrities_train = load(filelist[0:n]).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
celebrities_test = load(filelist[n:n+k]).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)


Check sample images

In [0]:
import matplotlib.pyplot as plt

# Pull one batch from every split; index [0] below picks the first image of
# the batch.
sample_cartoon = next(iter(cartoons_train))
sample_celebrity = next(iter(celebrities_train))
sample_cartoon_test = next(iter(cartoons_test))
sample_celebrity_test = next(iter(celebrities_test))

# plt.title('Cartoon')
# plt.imshow(sample_cartoon[0] * 0.5 + 0.5)

plt.title('Celebrity')
# Map pixel values from [-1, 1] back to [0, 1] for display.
plt.imshow(sample_celebrity[0] * 0.5 + 0.5)



## Build network and define loss functions

Define a modified U-Net network based on the pix2pix architecture from the TensorFlow examples (https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py), adjusted to the needs of our CycleGAN. Instance normalization is used, as is usual for CycleGAN.
For the WGAN implementation the activation function in the discriminator's last layer is changed to linear.

In [0]:
class InstanceNormalization(tf.keras.layers.Layer):
  """Instance normalization with a learned per-channel scale and offset.

  Mean and variance are computed separately for every sample and channel over
  axes 1 and 2 of the input.
  """

  def __init__(self, epsilon=1e-5):
    super().__init__()
    # Added to the variance before taking the reciprocal square root.
    self.epsilon = epsilon

  def build(self, input_shape):
    # One scale and one offset value per entry of the last input axis.
    per_channel = input_shape[-1:]

    self.scale = self.add_weight(
        name='scale',
        shape=per_channel,
        initializer=tf.random_normal_initializer(1., 0.02),
        trainable=True)

    self.offset = self.add_weight(
        name='offset',
        shape=per_channel,
        initializer='zeros',
        trainable=True)

  def call(self, x):
    mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
    normalized = (x - mean) * tf.math.rsqrt(variance + self.epsilon)
    return self.scale * normalized + self.offset

def downsample(filters, size, apply_norm=True):
  """Builds a stride-2 convolution block.

  Conv2D => Instancenorm (optional) => LeakyRelu
  Args:
    filters: number of filters
    size: filter size
    apply_norm: If True, adds the instance norm layer
  Returns: Downsample Sequential Model
  """
  conv = tf.keras.layers.Conv2D(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False)

  # Collect the layers first and wrap them in a Sequential at the end.
  block_layers = [conv]
  if apply_norm:
    block_layers.append(InstanceNormalization())
  block_layers.append(tf.keras.layers.LeakyReLU())

  return tf.keras.Sequential(block_layers)

def upsample(filters, size, apply_dropout=False):
  """Builds a stride-2 transposed-convolution block.

  Conv2DTranspose => Instancenorm => Dropout (optional) => Relu
  Args:
    filters: number of filters
    size: filter size
    apply_dropout: If True, adds a dropout layer with rate 0.5
  Returns: Upsample Sequential Model
  """
  transposed_conv = tf.keras.layers.Conv2DTranspose(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False)

  # Collect the layers first and wrap them in a Sequential at the end.
  block_layers = [transposed_conv, InstanceNormalization()]
  if apply_dropout:
    block_layers.append(tf.keras.layers.Dropout(0.5))
  block_layers.append(tf.keras.layers.ReLU())

  return tf.keras.Sequential(block_layers)

def unet_generator(output_channels):
  """Modified u-net generator model.

  Eight downsample blocks encode a 256 x 256 x 3 input, seven upsample blocks
  decode it again, and a final tanh transposed convolution produces the
  output image. Each upsample output is concatenated with the encoder
  activation of matching resolution (skip connection).

  Args:
    output_channels: number of channels of the generated image
  Returns: Generator model
  """
  # Encoder: the first block has no normalization.
  encoder = [downsample(64, 4, apply_norm=False)]
  encoder += [downsample(width, 4)
              for width in (128, 256, 512, 512, 512, 512, 512)]

  # Decoder: the three innermost blocks use dropout.
  decoder = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
  decoder += [upsample(width, 4) for width in (512, 256, 128, 64)]

  output_layer = tf.keras.layers.Conv2DTranspose(
      output_channels, 4, strides=2,
      padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      activation='tanh')

  inputs = tf.keras.layers.Input(shape=[256, 256, 3])

  # Downsampling through the model, remembering every activation
  activations = []
  x = inputs
  for block in encoder:
    x = block(x)
    activations.append(x)

  # Upsampling and establishing the skip connections. The bottleneck (last
  # activation) is the decoder input, so it is left out and the remaining
  # activations are visited from deepest to shallowest.
  for block, skip in zip(decoder, activations[-2::-1]):
    x = block(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  return tf.keras.Model(inputs=inputs, outputs=output_layer(x))


def discriminator():
  """PatchGan discriminator model with a linear (unbounded) output.

  Three downsample blocks are followed by two zero-padded 4 x 4 convolutions
  with stride 1. The last convolution has a single filter and a linear
  activation, so the model outputs a map of raw critic scores.

  Returns: Discriminator model
  """
  init = tf.random_normal_initializer(0., 0.02)

  image = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')

  # Three stride-2 blocks; only the first one skips the normalization.
  features = image
  for width, with_norm in ((64, False), (128, True), (256, True)):
    features = downsample(width, 4, with_norm)(features)

  features = tf.keras.layers.ZeroPadding2D()(features)
  features = tf.keras.layers.Conv2D(
      512, 4, strides=1, kernel_initializer=init,
      use_bias=False)(features)
  features = InstanceNormalization()(features)
  features = tf.keras.layers.LeakyReLU()(features)

  features = tf.keras.layers.ZeroPadding2D()(features)
  scores = tf.keras.layers.Conv2D(
      1, 4, strides=1,
      kernel_initializer=init, activation='linear')(features)

  return tf.keras.Model(inputs=image, outputs=scores)


Create generators and discriminators with the modified U-Net structure and instance normalization defined above. Wasserstein loss is used to compute the generator and discriminator losses. During training, weight clipping is applied to limit the discriminator weights to the range defined by the threshold.

In [0]:
OUTPUT_CHANNELS = 3

# Generators
# generator_AB translates celebrity (A) -> cartoon (B), generator_BA the reverse.
generator_AB = unet_generator(OUTPUT_CHANNELS)
generator_BA = unet_generator(OUTPUT_CHANNELS)

# Discriminators
# discriminator_A scores domain-A images, discriminator_B scores domain-B images.
discriminator_A = discriminator()
discriminator_B = discriminator()

# Weight of the cycle-consistency loss; the identity loss uses half of it.
LAMBDA = 10

# Define Wasserstein loss
from keras import backend
def wasserstein_loss(true, pred):
  """Wasserstein loss: mean of the element-wise product of labels and scores.

  Args:
    true: label tensor broadcastable to `pred`; the callers in this file pass
      +1 or -1 so that the sign decides whether the score is pushed down or up.
    pred: raw discriminator (critic) output.
  Returns: scalar loss tensor.
  """
  # tf.reduce_mean keeps this consistent with the other losses in this file
  # and does not depend on the standalone `keras` package matching tf.keras.
  return tf.reduce_mean(true * pred)

# Define discriminator loss
def discriminator_loss(real, generated):
  """Critic loss: real scores are labelled -1 and generated scores +1.

  Args:
    real: discriminator output for real images.
    generated: discriminator output for generated images.
  Returns: half of the sum of both Wasserstein terms.
  """
  loss_on_real = wasserstein_loss(-tf.ones_like(real), real)
  loss_on_generated = wasserstein_loss(tf.ones_like(generated), generated)

  return (loss_on_real + loss_on_generated) * 0.5

# Define generator loss
def generator_loss(generated):
  """Adversarial loss of a generator under the Wasserstein objective.

  The discriminator loss labels generated samples +1, i.e. the critic lowers
  their score. The generator has to work against that and raise the score of
  its samples, so it uses the opposite label (-1, the label of real images).

  Args:
    generated: discriminator output for generated images.
  Returns: scalar loss tensor, -mean(generated).
  """
  return wasserstein_loss(-tf.ones_like(generated), generated)

# Define cycle loss
def calc_cycle_loss(real_image, cycled_image):
  """Cycle-consistency loss: LAMBDA times the mean absolute difference.

  Args:
    real_image: original image.
    cycled_image: the image after translating it to the other domain and back.
  Returns: scalar loss tensor.
  """
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * mean_abs_error

# Define identity loss
def identity_loss(real_image, same_image):
  """Identity loss: half of LAMBDA times the mean absolute difference.

  Args:
    real_image: original image.
    same_image: generator output that is expected to reproduce `real_image`.
  Returns: scalar loss tensor.
  """
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * mean_abs_error

# RMSprop optimizers (default momentum of 0). `learning_rate` is the supported
# keyword; `lr` is a deprecated alias that newer Keras versions reject.
lr1 = 0.00005
lr2 = 0.00001
generator_AB_optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr2)  #slower generator
generator_BA_optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr2)
discriminator_A_optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr1) #faster discriminator
discriminator_B_optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr1)

# Weight clipping for the Lipschitz constraint: after clipping, every weight
# lies in [-threshold, threshold].
# threshold = 0.1
threshold = 0.01
def update_weights(model):
  """Clips all weights of every layer of `model` to [-threshold, threshold].

  Args:
    model: object exposing `layers`; each layer must provide `get_weights()`
      and `set_weights()`.
  """
  for layer in model.layers:
    layer.set_weights(
        [np.clip(values, -threshold, threshold)
         for values in layer.get_weights()])


### Model structure
Check the structure of the generators and discriminators.
Both generators have the same structure (as do both discriminators), so it is enough to check only one of each.

In [0]:
# Generator AB structure: draw the layer graph including tensor shapes
tf.keras.utils.plot_model(generator_AB, show_shapes=True, dpi=64)


In [0]:
# Generator AB summary: print the textual layer overview of the model
generator_AB.summary()

In [0]:
# Discriminator A structure: draw the layer graph including tensor shapes
tf.keras.utils.plot_model(discriminator_A, show_shapes=True, dpi=64)

In [0]:
# Discriminator A summary: print the textual layer overview of the model
discriminator_A.summary()

Saving models

In [0]:
# Save checkpoints for results analysis and retraining from a certain point.
# The path is relative to the current working directory.
checkpoint_path = "./checkpoints2/train"

# Track all four models together with their optimizers in one checkpoint.
ckpt = tf.train.Checkpoint(generator_AB=generator_AB,
                           generator_BA=generator_BA,
                           discriminator_A=discriminator_A,
                           discriminator_B=discriminator_B,
                           generator_AB_optimizer=generator_AB_optimizer,
                           generator_BA_optimizer=generator_BA_optimizer,
                           discriminator_A_optimizer=discriminator_A_optimizer,
                           discriminator_B_optimizer=discriminator_B_optimizer)

# Only the two most recent checkpoints are kept.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=2)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print('Latest checkpoint restored!!')

## Training

Define the training step and a function for generating and plotting images.

In [0]:
# Plot image with its prediction and save it as .png file
def generate_images(model, test_input, fname,
                    out_dir='/content/drive/My Drive/Colab/ANN/results_final/'):
  """Plots an input image next to the model prediction and saves the figure.

  Args:
    model: callable that maps a batch of images to a batch of images.
    test_input: batch of images in [-1, 1]; only the first one is shown.
    fname: file name (without extension) of the saved figure.
    out_dir: directory the .png file is written to; created if missing.
  """
  import os  # local import keeps the function usable on its own

  prediction = model(test_input)

  plt.figure(figsize=(12, 12))

  panels = (('Input Image', test_input[0]),
            ('Predicted Image', prediction[0]))
  for position, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(title)
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(image * 0.5 + 0.5)
    plt.axis('off')

  # savefig does not create missing directories, so make sure the target
  # exists even when this is called before the training cell ran `mkdir`.
  os.makedirs(out_dir, exist_ok=True)
  plt.savefig(os.path.join(out_dir, fname + '.png'))
  plt.show()

# Define training step
@tf.function
def train_step(real_A, real_B):
  """Runs one optimization step for both generators and both discriminators.

  A stands for celebrity, B stands for cartoon.
  generator_AB translates A -> B, generator_BA translates B -> A.

  Args:
    real_A: batch of real images from domain A.
    real_B: batch of real images from domain B.
  Returns:
    Tuple (total_gen_AB_loss, total_gen_BA_loss, disc_A_loss, disc_B_loss).
  """
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Forward cycle A -> B -> A
    fake_B = generator_AB(real_A, training=True)
    cycled_A = generator_BA(fake_B, training=True)

    # Backward cycle B -> A -> B
    fake_A = generator_BA(real_B, training=True)
    cycled_B = generator_AB(fake_A, training=True)

    # same_A and same_B are used for identity loss: each generator receives an
    # image that already belongs to its OUTPUT domain and should leave it as is.
    same_A = generator_BA(real_A, training=True)
    same_B = generator_AB(real_B, training=True)

    disc_real_A = discriminator_A(real_A, training=True)
    disc_real_B = discriminator_B(real_B, training=True)

    disc_fake_A = discriminator_A(fake_A, training=True)
    disc_fake_B = discriminator_B(fake_B, training=True)

    # calculate the loss
    gen_AB_loss = generator_loss(disc_fake_B) # adversarial
    gen_BA_loss = generator_loss(disc_fake_A) # adversarial

    total_cycle_loss = (calc_cycle_loss(real_A, cycled_A) +
                        calc_cycle_loss(real_B, cycled_B))

    # Total generator loss = adversarial loss + cycle loss + identity
    total_gen_AB_loss = gen_AB_loss + total_cycle_loss + identity_loss(real_B, same_B)
    total_gen_BA_loss = gen_BA_loss + total_cycle_loss + identity_loss(real_A, same_A)

    disc_A_loss = discriminator_loss(disc_real_A, disc_fake_A)
    disc_B_loss = discriminator_loss(disc_real_B, disc_fake_B)

  # Calculate the gradients for generators and discriminators
  generator_AB_gradients = tape.gradient(total_gen_AB_loss,
                                         generator_AB.trainable_variables)
  generator_BA_gradients = tape.gradient(total_gen_BA_loss,
                                         generator_BA.trainable_variables)

  discriminator_A_gradients = tape.gradient(disc_A_loss,
                                            discriminator_A.trainable_variables)
  discriminator_B_gradients = tape.gradient(disc_B_loss,
                                            discriminator_B.trainable_variables)

  # Apply the gradients to the optimizer
  generator_AB_optimizer.apply_gradients(zip(generator_AB_gradients,
                                             generator_AB.trainable_variables))

  generator_BA_optimizer.apply_gradients(zip(generator_BA_gradients,
                                             generator_BA.trainable_variables))

  discriminator_A_optimizer.apply_gradients(zip(discriminator_A_gradients,
                                                discriminator_A.trainable_variables))

  discriminator_B_optimizer.apply_gradients(zip(discriminator_B_gradients,
                                                discriminator_B.trainable_variables))

  return total_gen_AB_loss, total_gen_BA_loss, disc_A_loss, disc_B_loss

Training is performed for a defined number of EPOCHS, monitoring progress on chosen sample images and recording the loss values.

In [0]:
# if does not exist create folder for storing results
!mkdir -p '/content/drive/My Drive/Colab/ANN/results_final'

# pick two images for training progress monitoring
sample_celeb = next(iter(celebrities_train))
sample_cart = next(iter(cartoons_train))

# initialize arrays for generators and discriminators losses
gen_AB_loss_all = []
gen_BA_loss_all = []
disc_A_loss_all = []
disc_B_loss_all = []

EPOCHS = 20

for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_A, image_B in tf.data.Dataset.zip((celebrities_train, cartoons_train)):

    total_gen_AB_loss, total_gen_BA_loss, disc_A_loss, disc_B_loss = (
        train_step(image_A, image_B))
    
    update_weights(discriminator_A)
    update_weights(discriminator_B)
    
    gen_AB_loss_all.append(total_gen_AB_loss)
    gen_BA_loss_all.append(total_gen_BA_loss)
    disc_A_loss_all.append(disc_A_loss)
    disc_B_loss_all.append(disc_B_loss)

    if n % 5 == 0:
      print ('.', end='')
    n+=1

  # Using a consistent image (sample_celeb and sample_cart ) to monitor 
  # training progress
  generate_images(generator_AB, sample_celeb, 'train_celeb_' + str(epoch))
  generate_images(generator_BA, sample_cart, 'train_cart' + str(epoch))

  # save every 10th checkpoint
  if (epoch + 1) % 10 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))


## Check model performance

Plot the loss functions 

In [0]:
# One curve per recorded loss history: (values, line color, legend label)
loss_curves = (
    (gen_AB_loss_all, 'y', 'gen_AB_loss'),
    (gen_BA_loss_all, 'b', 'gen_BA_loss'),
    (disc_A_loss_all, 'r', 'disc_A_loss'),
    (disc_B_loss_all, 'g', 'disc_B_loss'),
)
for values, color, name in loss_curves:
  plt.plot(values, color, label = name)
plt.legend()

# Store the figure next to the generated sample images
plt.savefig('/content/drive/My Drive/Colab/ANN/results_final/losses.png')


Generate images using test dataset

In [0]:
# Run the trained model on the test dataset: translate three celebrity
# batches to cartoons and three cartoon batches to celebrities. The saved
# figures are numbered from 1.
for i, inp in enumerate(celebrities_test.take(3), start=1):
  generate_images(generator_AB, inp, 'celeb_to_cartoon_test_' + str(i))

for i, inp in enumerate(cartoons_test.take(3), start=1):
  generate_images(generator_BA, inp, 'cartoon_to_celeb_test_' + str(i))
  i+=1