# Pix2Pix, aka CGAN (Conditional GAN)
Pix2Pix is a type of GAN used for image-to-image translation tasks. Here, an input image is translated into another image based on some condition.

In this example, we will use the CMP Facade Database. We will use a preprocessed copy of this dataset provided by the authors of the Pix2Pix paper.

In [0]:
# Importing libraries
import tensorflow as tf
import matplotlib.pyplot as plt

from os import path
from datetime import datetime
from tqdm.notebook import tqdm

%matplotlib inline

In [0]:
# Installing / upgrading TensorBoard (-U) quietly (-q); it is used for loss visualization
!pip install -q -U tensorboard

In [0]:
# Global vars & hyper-params
# Spatial size of the random crop (training) and of the resize (test)
IMG_HEIGHT = 256
IMG_WIDTH = 256
# Archive fetched with tf.keras.utils.get_file
URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'

# Shuffle buffer size; the training loop also uses it as the number of
# steps per epoch
BUFFER_SIZE = 400
# Number of image pairs per batch
BATCH_SIZE = 1
# Weight of the L1 term in the generator's total loss
LAMBDA = 100
# Number of iterations of the outer training loop
EPOCHS = 150

## Utility Functions

In [0]:
def show_results(input_img, real_img, gen_img, save_fig=False):
  """
  Plots the first input, real and generated image of a batch side by side.

  Args:
    input_img: Batch of input images; only element 0 is displayed.
    real_img: Batch of target images; only element 0 is displayed.
    gen_img: Batch of generated images; only element 0 is displayed.
    save_fig: When True, the figure is also written to a PNG file in the
      current working directory, named after the current timestamp.
  """
  titles = ["Input Image", "Real Image", "Generated Image"]
  images = [input_img[0], real_img[0], gen_img[0]]
  # The size has to be passed at creation time: assigning `fig.figsize`
  # afterwards only sets an unused attribute and leaves the size unchanged.
  fig = plt.figure(figsize=(15, 15))
  for position, (title, image) in enumerate(zip(titles, images), start=1):
    plt.subplot(1, 3, position)
    plt.title(title)
    # Pixel values should be in the range 0-1 for float values
    # (assumes the images are in [-1, 1] -- TODO confirm against callers)
    plt.imshow(image * 0.5 + 0.5)
    plt.axis("off")

  # Save before showing: some backends release the figure once plt.show()
  # returns, which would leave nothing to write.
  if save_fig:
    f_name = datetime.now().strftime("%Y%m%d-%H%M%S")
    fig.savefig(f"{f_name}.png")

  plt.show()

In [0]:
def load_img(img_path):
  """
  Reads a JPEG file and splits it down the middle into two float32 images.

  Returns:
    (input_img, real_img): the right half of the decoded image is the input,
    the left half is the real image.
  """
  image = tf.image.decode_jpeg(tf.io.read_file(img_path))
  # Column index at which the side-by-side pair is split
  half_width = tf.shape(image)[1] // 2
  left_half = image[:, :half_width, :]
  right_half = image[:, half_width:, :]

  return tf.cast(right_half, tf.float32), tf.cast(left_half, tf.float32)

In [0]:
def resize_image(input_img, real_img, height, width):
  """
  Resizes both images to (height, width) with nearest-neighbour sampling.

  Returns:
    (resized_input_img, resized_real_img)
  """
  target_size = [height, width]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
  return (tf.image.resize(input_img, target_size, method=nearest),
          tf.image.resize(real_img, target_size, method=nearest))

In [0]:
def get_random_crop(input_img, real_img):
  """
  Takes the same random IMG_HEIGHT x IMG_WIDTH x 3 crop from both images.

  The two images are stacked before cropping so that a single random offset
  is applied to the pair.
  """
  pair = tf.stack([input_img, real_img], axis=0)
  crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
  cropped_pair = tf.image.random_crop(pair, crop_shape)

  return cropped_pair[0], cropped_pair[1]

In [0]:
def normalize_img(input_img, real_img):
  """
  Rescales both images so that a value of 0 maps to -1 and 255 maps to 1.

  Returns:
    (normalized_input_img, normalized_real_img)
  """
  return tuple((img / 127.5) - 1 for img in (input_img, real_img))

In [0]:
def preprocess_image(input_img, real_img):
  """
  Training-time augmentation of an (input, real) image pair.

  Steps, in order: normalization, resize to 286 x 286, a shared random crop
  (via get_random_crop) and, with probability 0.5, a left-right flip applied
  to both images.
  """
  # Scale the pixel values first
  input_img, real_img = normalize_img(input_img, real_img)

  # Random jitter: enlarge to 286 X 286 X 3, then crop back down
  input_img, real_img = resize_image(input_img, real_img, 286, 286)
  input_img, real_img = get_random_crop(input_img, real_img)

  # Mirror both images together so the pair stays aligned
  if tf.random.uniform([]) > 0.5:
    input_img = tf.image.flip_left_right(input_img)
    real_img = tf.image.flip_left_right(real_img)

  return input_img, real_img

In [0]:
def prepare_training_data(img_path):
  """
  Loads the image pair stored at img_path and applies the training
  preprocessing to it.
  """
  return preprocess_image(*load_img(img_path))

In [0]:
def prepare_test_data(img_path):
  """
  Loads the image pair stored at img_path, resizes it to
  IMG_HEIGHT x IMG_WIDTH and normalizes it (no random augmentation).
  """
  loaded_pair = load_img(img_path)
  resized_pair = resize_image(*loaded_pair, IMG_HEIGHT, IMG_WIDTH)
  return normalize_img(*resized_pair)

In [0]:
def get_normal_init(mean, std):
  """
  Returns a random-normal initializer with the given mean and standard
  deviation.
  """
  return tf.random_normal_initializer(mean=mean, stddev=std)

In [0]:
def upsample(filters, kernel_size, stride=2, pad="same", apply_dropout=False):
  """
  Builds a decoder block: Conv2DTranspose -> BatchNorm -> [Dropout] -> ReLU.

  Args:
    filters: Number of output channels of the transposed convolution.
    kernel_size: Kernel size of the transposed convolution.
    stride: Stride of the transposed convolution.
    pad: Padding mode of the transposed convolution.
    apply_dropout: When True, a Dropout(0.5) layer is inserted before the ReLU.

  Returns:
    A tf.keras.Sequential holding the layers above.
  """
  block_layers = [
      tf.keras.layers.Conv2DTranspose(
          filters, kernel_size, strides=stride, padding=pad,
          kernel_initializer=get_normal_init(0.0, 0.02),
          use_bias=False),
      tf.keras.layers.BatchNormalization(),
  ]
  if apply_dropout:
    block_layers.append(tf.keras.layers.Dropout(0.5))
  block_layers.append(tf.keras.layers.ReLU())

  return tf.keras.Sequential(block_layers)

In [0]:
def downsample(filters, kernel_size, stride=2, pad="same", apply_bn=True):
  """
  Builds an encoder block: Conv2D -> [BatchNorm] -> LeakyReLU.

  Args:
    filters: Number of output channels of the convolution.
    kernel_size: Kernel size of the convolution.
    stride: Stride of the convolution.
    pad: Padding mode of the convolution.
    apply_bn: When False, the BatchNormalization layer is left out.

  Returns:
    A tf.keras.Sequential holding the layers above.
  """
  block_layers = [
      tf.keras.layers.Conv2D(
          filters, kernel_size, strides=stride, padding=pad,
          kernel_initializer=get_normal_init(0.0, 0.02),
          use_bias=False),
  ]
  if apply_bn:
    block_layers.append(tf.keras.layers.BatchNormalization())
  block_layers.append(tf.keras.layers.LeakyReLU())

  return tf.keras.Sequential(block_layers)

In [0]:
@tf.function
def _train_step(generator, discriminator, input_img, target_img, sum_writer, step):
  """Traced body of one optimization step; `step` must be an int64 tensor."""
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_img, training=True)
    # The discriminator is conditioned on the input image: it sees the input
    # concatenated channel-wise with either the target or the generated image.
    real_output = discriminator(tf.concat([input_img, target_img], axis=3), training=True)
    fake_output = discriminator(tf.concat([input_img, gen_output], axis=3), training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator.loss(
        fake_output, gen_output, target_img)
    disc_loss = discriminator.loss(real_output, fake_output)

  gen_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
  disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

  generator.optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
  discriminator.optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))

  with sum_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=step)
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=step)
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=step)
    tf.summary.scalar('disc_loss', disc_loss, step=step)


def train(generator, discriminator, input_img, target_img, sum_writer, epoch):
  """
  Runs one optimization step of both networks on a single batch and logs
  the losses.

  Args:
    generator: Model exposing `loss(fake_out, gen_img, target_img)` and an
      `optimizer` attribute.
    discriminator: Model exposing `loss(real_out, fake_out)` and an
      `optimizer` attribute.
    input_img: Batch of conditioning images.
    target_img: Batch of ground-truth images.
    sum_writer: Summary writer that receives the four loss scalars.
    epoch: Value used as the `step` of the logged scalars.

  Note:
    Every batch of an epoch is logged with the same step value.
  """
  # A Python int passed straight into a tf.function is treated as a constant,
  # so each new epoch value would retrace the whole graph. Casting it to a
  # tensor here lets one trace serve every epoch.
  step = tf.cast(epoch, tf.int64)
  _train_step(generator, discriminator, input_img, target_img, sum_writer, step)

## Preparing Dataset

In [0]:
# Download the archive (and extract it) into the Keras cache; the extracted
# "facades/" folder is expected next to the downloaded file.
zip_path = tf.keras.utils.get_file(fname="facades.tar.gz", origin=URL, extract=True)
DATASET_PATH = path.join(path.dirname(zip_path), "facades/")

In [0]:
# Training pipeline: list files -> load + augment -> shuffle -> batch
training_data = (
    tf.data.Dataset.list_files(path.join(DATASET_PATH+"train/*.jpg"))
    .map(prepare_training_data,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Test pipeline: list files -> load + resize + normalize -> shuffle -> batch
test_data = (
    tf.data.Dataset.list_files(path.join(DATASET_PATH+"test/*.jpg"))
    .map(prepare_test_data,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

## Models & Losses

In [0]:
class Generator(tf.keras.Model):
  """
  Encoder-decoder generator with skip connections between mirrored
  encoder and decoder blocks.

  The shape comments below assume a 256 x 256 input.
  """

  def __init__(self, out_channels):
    """
    Args:
      out_channels: Number of channels of the generated image.
    """
    super(Generator, self).__init__()
    self.loss_cal = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    self.optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    self.down_stack = [
          downsample(64, 4, apply_bn=False), # (BATCH_SIZE, 128, 128, 64)
          downsample(128, 4), # (BATCH_SIZE, 64, 64, 128)
          downsample(256, 4), # (BATCH_SIZE, 32, 32, 256)
          downsample(512, 4), # (BATCH_SIZE, 16, 16, 512)
          downsample(512, 4), # (BATCH_SIZE, 8, 8, 512)
          downsample(512, 4), # (BATCH_SIZE, 4, 4, 512)
          downsample(512, 4), # (BATCH_SIZE, 2, 2, 512)
          downsample(512, 4) # (BATCH_SIZE, 1, 1, 512)
    ]

    # Shapes are those of each block's own output, i.e. before the skip
    # tensor is concatenated onto it.
    self.up_stack = [
          upsample(512, 4, apply_dropout=True), # (BATCH_SIZE, 2, 2, 512)
          upsample(512, 4, apply_dropout=True), # (BATCH_SIZE, 4, 4, 512)
          upsample(512, 4, apply_dropout=True), # (BATCH_SIZE, 8, 8, 512)
          upsample(512, 4), # (BATCH_SIZE, 16, 16, 512)
          upsample(256, 4), # (BATCH_SIZE, 32, 32, 256)
          upsample(128, 4), # (BATCH_SIZE, 64, 64, 128)
          upsample(64, 4) # (BATCH_SIZE, 128, 128, 64)
    ]

    # Created once here instead of instantiating a new layer on every
    # decoder step of every forward pass.
    self.concat = tf.keras.layers.Concatenate()

    self.out = tf.keras.layers.Conv2DTranspose(out_channels, 4,
                                  strides=2,
                                  padding='same',
                                  kernel_initializer=get_normal_init(0.0, 0.02),
                                  activation='tanh') # (BATCH_SIZE, 256, 256, 3)

  def call(self, inputs, training=None):
    """
    Runs the encoder, then the decoder with skip connections.

    Args:
      inputs: Batch of input images.
      training: Forwarded to every encoder/decoder block so that their
        BatchNormalization and Dropout layers run in the requested mode.

    Returns:
      Batch of generated images (tanh output).
    """
    x = inputs
    skips = []
    # Encoder part
    for down in self.down_stack:
      x = down(x, training=training)
      skips.append(x)

    # The last encoder output is the bottleneck itself, not a skip; the
    # remaining ones are consumed from the deepest to the shallowest.
    skips = reversed(skips[:-1])

    # Decoder part with skip connections
    for up, skip in zip(self.up_stack, skips):
      x = up(x, training=training)
      x = self.concat([x, skip])

    return self.out(x)

  def loss(self, fake_out, gen_img, target_img):
    """
    Computes the generator loss.

    Args:
      fake_out: Discriminator logits for the generated image.
      gen_img: Generated image.
      target_img: Ground-truth image.

    Returns:
      (total_loss, gan_loss, l1_loss) where
      total_loss = gan_loss + LAMBDA * l1_loss.
    """
    # Typical GAN loss: the generator wants its output classified as real
    gan_loss = self.loss_cal(tf.ones_like(fake_out), fake_out)
    # Image construction loss
    l1_loss = tf.reduce_mean(tf.abs(target_img - gen_img))
    total_loss = gan_loss + (LAMBDA * l1_loss)

    return total_loss, gan_loss, l1_loss

In [0]:
class Discriminator(tf.keras.Model):
  """
  Convolutional discriminator that outputs a grid of logits, one per image
  patch, instead of a single score.

  The shape comments below assume a 256 x 256 input.
  """

  def __init__(self):
    super(Discriminator, self).__init__()
    self.loss_cal = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    self.optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    self.down_stack = [
          downsample(64, 4, apply_bn=False), # (BATCH_SIZE, 128, 128, 64)
          downsample(128, 4), # (BATCH_SIZE, 64, 64, 128)
          downsample(256, 4), # (BATCH_SIZE, 32, 32, 256)
          tf.keras.layers.ZeroPadding2D(), # (BATCH_SIZE, 34, 34, 256)
          downsample(512, 4, 1, "valid"), # (BATCH_SIZE, 31, 31, 512)
          tf.keras.layers.ZeroPadding2D(), # (BATCH_SIZE, 33, 33, 512)
    ]

    self.out = tf.keras.layers.Conv2D(1, 4, 
                                  strides=1, 
                                  kernel_initializer=get_normal_init(0.0, 0.02)
                                  ) # (BATCH_SIZE, 30, 30, 1)

  def call(self, inputs, training=None):
    """
    Args:
      inputs: Batch to classify. The training step in this file passes the
        conditioning image concatenated channel-wise with a real or
        generated image.
      training: Forwarded to every block so that BatchNormalization runs in
        the requested mode.

    Returns:
      Unnormalized logits from the final convolution.
    """
    x = inputs
    for down in self.down_stack:
      x = down(x, training=training)

    return self.out(x)

  def loss(self, real_out, fake_out):
    """
    Computes the discriminator loss.

    Args:
      real_out: Logits produced for the real pair.
      fake_out: Logits produced for the generated pair.

    Returns:
      Sum of the cross-entropy of real_out against ones and of fake_out
      against zeros.
    """
    # Labels are shaped after the logits they are compared with
    real_loss = self.loss_cal(tf.ones_like(real_out), real_out)
    fake_loss = self.loss_cal(tf.zeros_like(fake_out), fake_out)
    return real_loss + fake_loss

## Let Training Begin

In [0]:
generator = Generator(3)
discriminator = Discriminator()

log_dir = "logs/"
summary_writer = tf.summary.create_file_writer(
    log_dir + "fit/" + datetime.now().strftime("%Y%m%d-%H%M%S"))

%load_ext tensorboard
%tensorboard --logdir {log_dir}

In [0]:
for epoch in range(EPOCHS):
  print(f"Running epoch: {epoch}")
  # Iterate over the dataset itself instead of calling next() a fixed
  # BUFFER_SIZE times: the loop then stops cleanly whenever the number of
  # batches differs from BUFFER_SIZE (e.g. after changing BATCH_SIZE)
  # rather than raising StopIteration.
  for input_img, target_img in tqdm(training_data):
    train(generator, discriminator, input_img, target_img, summary_writer, epoch)

  # Every 20 epochs: preview one test sample and checkpoint both models
  if (epoch+1) % 20 == 0:
    for input_img, target_img in test_data.take(1):
      # training=True is kept on purpose here, so the preview is generated
      # in the same mode as during training
      gen_img = generator(input_img, training=True)
      show_results(input_img, target_img, gen_img, True)
    # Saving does not depend on the preview sample, so it lives outside
    # the loop above
    generator.save(f"gen_{epoch+1}")
    discriminator.save(f"disc_{epoch+1}")