# pix2pix: Image-to-image translation with a conditional GAN

In [77]:
! pip install pydot



## Import TensorFlow and other libraries

In [1]:
import tensorflow as tf

import os
import pathlib
import time
import datetime
import pydot as pyd
import keras

from matplotlib import pyplot as plt
from IPython import display

keras.utils.vis_utils.pydot = pyd
import cv2
import numpy as np
import uuid
from PIL import Image

Using TensorFlow backend.


Each original image is of size `256 x 512` containing two `256 x 256` images:

In [3]:
# Root directory of the dataset; the training file glob is built from PATH / SUB_PATH.
PATH = '/home/cgac/Documents/agri-train-ds/'
# Sub-folder of PATH whose files are loaded. It is also reused as the output
# sub-folder and as the file-name prefix when generate_images() saves results.
SUB_PATH = "mask-all-in-one"
# Base directory under which generate_images() writes its JPEG files.
OUTPUT_GENERATOR= "/home/cgac/Documents/agri-train-ds/mask_disease_classification_pix2pix/corns"
# Number of images generate_images() produces and saves per call.
number_generate = 5
# sample_image = tf.io.read_file(PATH  + 'train/1.jpg')
# sample_image = tf.io.decode_jpeg(sample_image)
# print(sample_image.shape)

In [80]:
# plt.figure()
# plt.imshow(sample_image)

You need to separate real building facade images from the architecture label images—all of which will be of size `256 x 256`.

Define a function that loads image files and outputs two image tensors:

In [4]:
def load(image_file):
    """Read a side-by-side image pair and split it into two 256x256 tensors.

    The file is decoded as a 3-channel JPEG and cut vertically at half its
    width: the right half becomes the input image and the left half the
    real (target) image. Both halves are then resized to 256x256.

    Args:
        image_file: Path of the JPEG file, as accepted by ``tf.io.read_file``.

    Returns:
        A tuple ``(input_image, real_image)`` of float32 tensors of shape
        ``(256, 256, 3)``; pixel values are left un-normalized.
    """
    image = tf.io.read_file(image_file)
    # Force three channels: a grayscale JPEG would otherwise decode to a single
    # channel and fail later in random_crop, which crops to a fixed depth of 3.
    image = tf.image.decode_jpeg(image, channels=3)

    # Split the picture down the middle of its width.
    w = tf.shape(image)[1] // 2
    input_image = image[:, w:, :]
    real_image = image[:, :w, :]

    # The default (bilinear) resize already returns float32 tensors, so no
    # separate cast is required afterwards.
    input_image = tf.image.resize(input_image, (256, 256),
                                  preserve_aspect_ratio=False, antialias=False)
    real_image = tf.image.resize(real_image, (256, 256),
                                 preserve_aspect_ratio=False, antialias=False)

    return input_image, real_image

Plot a sample of the input (architecture label image) and real (building facade photo) images:

In [82]:
# inp, re = load(PATH + 'train/1.jpg' )
# # Casting to int for matplotlib to display the images
# plt.figure()
# plt.imshow(inp / 255.0)
# plt.figure()
# plt.imshow(re / 255.0)

As described in the [pix2pix paper](https://arxiv.org/abs/1611.07004), you need to apply random jittering and mirroring to preprocess the training set.

Define several functions that:

1. Resize each `256 x 256` image to a larger height and width—`286 x 286`.
2. Randomly crop it back to `256 x 256`.
3. Randomly flip the image horizontally i.e. left to right (random mirroring).
4. Normalize the images to the `[-1, 1]` range.

In [5]:
# Buffer size passed to Dataset.shuffle(). The value 400 is the size of the
# facade training set — NOTE(review): confirm it suits the dataset loaded here.
BUFFER_SIZE = 400
# The batch size of 1 produced better results for the U-Net in the original pix2pix experiment
BATCH_SIZE = 1
# Spatial size of the crops produced by random_crop() and of the images
# returned by load_image_test().
IMG_WIDTH = 256
IMG_HEIGHT = 256

In [6]:
def resize(input_image, real_image, height, width):
  """Resize both images of a pair to ``(height, width)``.

  Nearest-neighbour sampling is used for both images.

  Returns:
    The tuple ``(input_image, real_image)`` after resizing.
  """
  target_size = [height, width]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

  resized_input = tf.image.resize(input_image, target_size, method=nearest)
  resized_real = tf.image.resize(real_image, target_size, method=nearest)
  return resized_input, resized_real

In [7]:
def random_crop(input_image, real_image):
  """Crop the same random IMG_HEIGHT x IMG_WIDTH window out of both images.

  The two images are stacked along a new leading axis so that a single
  random crop selects the identical region from each of them.

  Returns:
    The tuple ``(input_image, real_image)`` after cropping.
  """
  pair = tf.stack([input_image, real_image], axis=0)
  crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
  cropped_pair = tf.image.random_crop(pair, size=crop_shape)
  cropped_input, cropped_real = cropped_pair[0], cropped_pair[1]
  return cropped_input, cropped_real

In [8]:
# Normalizing the images to [-1, 1]
def normalize(input_image, real_image):
  """Rescale both images by dividing by 127.5 and subtracting 1.

  For pixel values in [0, 255] this yields values in [-1, 1].

  Returns:
    The tuple ``(input_image, real_image)`` after rescaling.
  """
  half_range = 127.5
  return input_image / half_range - 1, real_image / half_range - 1

In [9]:
@tf.function()
def random_jitter(input_image, real_image):
  """Apply random jitter and random mirroring to an image pair.

  Both images are enlarged to 286x286, cropped back with the same random
  window via ``random_crop``, and then, when a uniform random draw exceeds
  0.5, flipped left-to-right together.

  Returns:
    The tuple ``(input_image, real_image)`` after augmentation.
  """
  # Enlarge first so that the subsequent crop has room to shift.
  enlarged_pair = resize(input_image, real_image, 286, 286)
  input_image, real_image = random_crop(*enlarged_pair)

  mirror = tf.random.uniform(()) > 0.5
  if mirror:
    # Flip both images so the pair stays aligned.
    input_image, real_image = (tf.image.flip_left_right(input_image),
                               tf.image.flip_left_right(real_image))

  return input_image, real_image

You can inspect some of the preprocessed output:

In [10]:
# plt.figure(figsize=(6, 6))
# for i in range(4):
#   rj_inp, rj_re = random_jitter(inp, re)
#   plt.subplot(2, 2, i + 1)
#   plt.imshow(rj_inp / 255.0)
#   plt.axis('off')
# plt.show()

Having checked that the loading and preprocessing works, let's define a couple of helper functions that load and preprocess the training and test sets:

In [12]:
def load_image_train(image_file):
  """Load one training pair: split the file, augment it, then normalize.

  Returns:
    The tuple ``(input_image, real_image)`` produced by ``normalize``.
  """
  pair = load(image_file)
  pair = random_jitter(*pair)
  input_image, real_image = normalize(*pair)
  return input_image, real_image

In [13]:
def load_image_test(image_file):
  """Load one evaluation pair: split the file, resize, then normalize.

  No random augmentation is applied; the pair is only resized to
  IMG_HEIGHT x IMG_WIDTH before normalization.

  Returns:
    The tuple ``(input_image, real_image)`` produced by ``normalize``.
  """
  pair = load(image_file)
  pair = resize(*pair, IMG_HEIGHT, IMG_WIDTH)
  input_image, real_image = normalize(*pair)
  return input_image, real_image

## Build an input pipeline with `tf.data`

In [None]:

# Build the training pipeline: list every file under PATH/SUB_PATH, load and
# augment each one with load_image_train, then shuffle and batch the pairs.
PATH_IMG = pathlib.Path(PATH)
train_dataset = (
    tf.data.Dataset.list_files(str(PATH_IMG / f'{SUB_PATH}/*.*'))
    .map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [92]:
# for i in  train_dataset.take(1):
#     print(i)

## Build the generator

The generator of your pix2pix cGAN is a _modified_ [U-Net](https://arxiv.org/abs/1505.04597). A U-Net consists of an encoder (downsampler) and decoder (upsampler). (You can find out more about it in the [Image segmentation](https://www.tensorflow.org/tutorials/images/segmentation) tutorial and on the [U-Net project website](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/).)

- Each block in the encoder is: Convolution -> Batch normalization -> Leaky ReLU
- Each block in the decoder is: Transposed convolution -> Batch normalization -> Dropout (applied to the first 3 blocks) -> ReLU
- There are skip connections between the encoder and decoder (as in the U-Net).

Define the downsampler (encoder):

In [14]:
# Number of filters (output channels) of the generator's final Conv2DTranspose layer.
OUTPUT_CHANNELS = 3

In [15]:
def downsample(filters, size, apply_batchnorm=True):
  """Build one encoder block: strided Conv2D -> optional BatchNorm -> LeakyReLU.

  Args:
    filters: Number of convolution filters.
    size: Convolution kernel size.
    apply_batchnorm: Whether to insert a BatchNormalization layer after the
      convolution.

  Returns:
    A ``tf.keras.Sequential`` containing the block's layers.
  """
  conv = tf.keras.layers.Conv2D(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False)

  block_layers = [conv]
  if apply_batchnorm:
    block_layers.append(tf.keras.layers.BatchNormalization())
  block_layers.append(tf.keras.layers.LeakyReLU())

  return tf.keras.Sequential(block_layers)

In [95]:
# down_model = downsample(3, 4)
# down_result = down_model(tf.expand_dims(inp, 0))
# print (down_result.shape)

Define the upsampler (decoder):

In [16]:
def upsample(filters, size, apply_dropout=False):
  """Build one decoder block: Conv2DTranspose -> BatchNorm -> optional Dropout -> ReLU.

  Args:
    filters: Number of transposed-convolution filters.
    size: Kernel size of the transposed convolution.
    apply_dropout: Whether to insert a Dropout(0.5) layer after the batch
      normalization.

  Returns:
    A ``tf.keras.Sequential`` containing the block's layers.
  """
  deconv = tf.keras.layers.Conv2DTranspose(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False)

  block_layers = [deconv, tf.keras.layers.BatchNormalization()]
  if apply_dropout:
    block_layers.append(tf.keras.layers.Dropout(0.5))
  block_layers.append(tf.keras.layers.ReLU())

  return tf.keras.Sequential(block_layers)

In [97]:
# up_model = upsample(3, 4)
# up_result = up_model(down_result)
# print (up_result.shape)

Define the generator with the downsampler and the upsampler:

In [17]:
def Generator():
  """Assemble the U-Net style generator from downsample/upsample blocks.

  Eight encoder blocks are followed by seven decoder blocks; after each
  decoder block its output is concatenated with the matching encoder
  output (skip connection). A final tanh-activated Conv2DTranspose with
  OUTPUT_CHANNELS filters produces the result.

  Returns:
    A ``tf.keras.Model`` taking a ``(256, 256, 3)`` input.
  """
  inputs = tf.keras.layers.Input(shape=[256, 256, 3])

  # Encoder. The first block skips batch normalization.
  # Output sizes: 128x128x64, 64x64x128, 32x32x256, 16x16x512,
  # 8x8x512, 4x4x512, 2x2x512, 1x1x512.
  down_stack = [downsample(64, 4, apply_batchnorm=False)]
  down_stack += [downsample(depth, 4)
                 for depth in (128, 256, 512, 512, 512, 512, 512)]

  # Decoder. Only the first three blocks use dropout.
  # Sizes after each skip concatenation: 2x2x1024, 4x4x1024, 8x8x1024,
  # 16x16x1024, 32x32x512, 64x64x256, 128x128x128.
  up_stack = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
  up_stack += [upsample(depth, 4) for depth in (512, 256, 128, 64)]

  # Final layer -> (batch_size, 256, 256, 3)
  last = tf.keras.layers.Conv2DTranspose(
      OUTPUT_CHANNELS, 4, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      activation='tanh')

  # Run the encoder, remembering every intermediate output.
  x = inputs
  encoder_outputs = []
  for encoder_block in down_stack:
    x = encoder_block(x)
    encoder_outputs.append(x)

  # The last encoder output is the bottleneck itself, so it is not a skip;
  # the remaining outputs are consumed from deepest to shallowest.
  for decoder_block, skip in zip(up_stack, encoder_outputs[-2::-1]):
    x = decoder_block(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  return tf.keras.Model(inputs=inputs, outputs=last(x))

Visualize the generator model architecture:

In [None]:
# Instantiate the generator model and count its layers.
generator = Generator()
len(generator.layers)
# tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64,to_file='pix2pix_figure/gen_net.png')

In [100]:
# Debug check of the value of tf.newaxis (the recorded cell output is None).
print(tf.newaxis)

None


Test the generator:

In [101]:
# gen_output = generator(inp[tf.newaxis, ...], training=False)
# plt.imshow(gen_output[0, ...])

### Define the generator loss

GANs learn a loss that adapts to the data, while cGANs learn a structured loss that penalizes a possible structure that differs from the network output and the target image, as described in the [pix2pix paper](https://arxiv.org/abs/1611.07004).

- The generator loss is a sigmoid cross-entropy loss of the generated images and an **array of ones**.
- The pix2pix paper also mentions the L1 loss, which is a MAE (mean absolute error) between the generated image and the target image.
- This allows the generated image to become structurally similar to the target image.
- The formula to calculate the total generator loss is `gan_loss + LAMBDA * l1_loss`, where `LAMBDA = 100`. This value was decided by the authors of the paper.

In [102]:
# Weight of the L1 term in the total generator loss: gan_loss + LAMBDA * l1_loss.
LAMBDA = 100

In [103]:
# Binary cross-entropy computed from raw logits (from_logits=True); used by
# both generator_loss and discriminator_loss.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [104]:
def generator_loss(disc_generated_output, gen_output, target):
  """Compute the generator loss: adversarial term plus weighted L1 term.

  Args:
    disc_generated_output: Discriminator output for the generated image.
    gen_output: Image produced by the generator.
    target: Image the generator output is compared against.

  Returns:
    The tuple ``(total_gen_loss, gan_loss, l1_loss)`` where
    ``total_gen_loss = gan_loss + LAMBDA * l1_loss``.
  """
  # The generator is rewarded when the discriminator labels its output as real.
  real_labels = tf.ones_like(disc_generated_output)
  gan_loss = loss_object(real_labels, disc_generated_output)

  # Mean absolute error between target and generated image.
  abs_error = tf.abs(target - gen_output)
  l1_loss = tf.reduce_mean(abs_error)

  return gan_loss + (LAMBDA * l1_loss), gan_loss, l1_loss

The training procedure for the generator is as follows:

## Build the discriminator

The discriminator in the pix2pix cGAN is a convolutional PatchGAN classifier—it tries to classify if each image _patch_ is real or not real, as described in the [pix2pix paper](https://arxiv.org/abs/1611.07004).

- Each block in the discriminator is: Convolution -> Batch normalization -> Leaky ReLU.
- The shape of the output after the last layer is `(batch_size, 30, 30, 1)`.
- Each `30 x 30` image patch of the output classifies a `70 x 70` portion of the input image.
- The discriminator receives 2 inputs: 
    - The input image and the target image, which it should classify as real.
    - The input image and the generated image (the output of the generator), which it should classify as fake.
    - Use `tf.concat([inp, tar], axis=-1)` to concatenate these 2 inputs together.

Let's define the discriminator:

In [105]:
def Discriminator():
  """Assemble the PatchGAN discriminator.

  The input image and the target image are concatenated along the channel
  axis, passed through three downsample blocks and two zero-padded
  stride-1 convolutions. The last convolution has one filter and no
  activation.

  Returns:
    A ``tf.keras.Model`` with inputs ``[input_image, target_image]``, each
    of shape ``(256, 256, 3)``.
  """
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
  tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

  # (batch_size, 256, 256, channels*2)
  x = tf.keras.layers.concatenate([inp, tar])

  # Three strided blocks: 128x128x64 (no batchnorm), 64x64x128, 32x32x256.
  for depth, use_batchnorm in ((64, False), (128, True), (256, True)):
    x = downsample(depth, 4, use_batchnorm)(x)

  x = tf.keras.layers.ZeroPadding2D()(x)  # (batch_size, 34, 34, 256)
  x = tf.keras.layers.Conv2D(
      512, 4, strides=1,
      kernel_initializer=initializer,
      use_bias=False)(x)  # (batch_size, 31, 31, 512)
  x = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.LeakyReLU()(x)

  x = tf.keras.layers.ZeroPadding2D()(x)  # (batch_size, 33, 33, 512)
  patch_output = tf.keras.layers.Conv2D(
      1, 4, strides=1,
      kernel_initializer=initializer)(x)  # (batch_size, 30, 30, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=patch_output)

Visualize the discriminator model architecture:

In [106]:
# Instantiate the discriminator model; it is later tracked by the checkpoint object.
discriminator = Discriminator()
# tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64,to_file='pix2pix_figure/dis_net.png')

Test the discriminator:

In [107]:
# disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)
# plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')
# plt.colorbar()

### Define the discriminator loss

- The `discriminator_loss` function takes 2 inputs: **real images** and **generated images**.
- `real_loss` is a sigmoid cross-entropy loss of the **real images** and an **array of ones (since these are the real images)**.
- `generated_loss` is a sigmoid cross-entropy loss of the **generated images** and an **array of zeros (since these are the fake images)**.
- The `total_loss` is the sum of `real_loss` and `generated_loss`.

In [108]:
def discriminator_loss(disc_real_output, disc_generated_output):
  """Compute the discriminator loss over a real and a generated pair.

  Args:
    disc_real_output: Discriminator output for the real (target) image.
    disc_generated_output: Discriminator output for the generated image.

  Returns:
    The sum of the loss on the real output (labelled ones) and the loss on
    the generated output (labelled zeros).
  """
  real_labels = tf.ones_like(disc_real_output)
  fake_labels = tf.zeros_like(disc_generated_output)

  real_loss = loss_object(real_labels, disc_real_output)
  generated_loss = loss_object(fake_labels, disc_generated_output)

  return real_loss + generated_loss

The training procedure for the discriminator is shown below.

To learn more about the architecture and the hyperparameters you can refer to the [pix2pix paper](https://arxiv.org/abs/1611.07004).

## Define the optimizers and a checkpoint-saver


In [109]:
# Separate Adam optimizers for the generator and the discriminator, configured
# identically (2e-4, beta_1=0.5); both are tracked by the checkpoint object.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [110]:
# Directory holding the checkpoint files and the path prefix ("<dir>/ckpt")
# derived from it.
checkpoint_dir = './pix2pix_mask_all_in_one_logs'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Checkpoint object tracking both models and both optimizers; it is the object
# later used to restore the latest saved state.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

## Generate images

Write a function to plot some images during training.

- Pass images from the test set to the generator.
- The generator will then translate the input image into the output.
- The last step is to plot the predictions and _voila_!

Note: `training=True` is intentional here, since you want the batch statistics
while running the model on the test dataset. If you use `training=False`, you get the accumulated statistics learned from the training dataset (which you don't want).

In [111]:
import random
from PIL import Image
def generate_images(model, test_input, tar):
    """Run the generator on one batch several times and save each result as a JPEG.

    ``number_generate`` forward passes are made with ``training=True``; after
    each pass the first image of the prediction is written to
    ``OUTPUT_GENERATOR/SUB_PATH/`` under a UUID-based file name, and the
    result flag of ``cv2.imwrite`` is printed.

    Args:
        model: Generator model, called as ``model(test_input, training=True)``.
        test_input: Batch of input images fed to the model.
        tar: Target batch; accepted for call compatibility but not used.
    """
    full_dir = OUTPUT_GENERATOR + "/" + SUB_PATH
    if not os.path.exists(full_dir):
        print("ful dir", full_dir)
    # exist_ok avoids a crash if the directory appears between the check above
    # and this call.
    os.makedirs(full_dir, exist_ok=True)

    for _ in range(number_generate):
        fname = f"{full_dir}/{SUB_PATH}_{uuid.uuid4()}.jpg"
        prediction = model(test_input, training=True)

        # Map the tanh output from [-1, 1] to [0, 255] and convert explicitly
        # to uint8 instead of relying on imwrite's implicit conversion.
        img = (prediction[0].numpy() * 0.5 + 0.5) * 255
        img = np.clip(np.rint(img), 0, 255).astype(np.uint8)

        # The prediction has three channels, so the matching conversion code
        # is RGB -> BGR (OpenCV writes images in BGR channel order).
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        is_save = cv2.imwrite(fname, img)
        print("save ", is_save)

## Restore the latest checkpoint and test the network

In [112]:
# List the files currently stored in checkpoint_dir.
!ls {checkpoint_dir}

checkpoint		     ckpt-21.data-00000-of-00001
ckpt-10.data-00000-of-00001  ckpt-21.index
ckpt-10.index		     ckpt-22.data-00000-of-00001
ckpt-11.data-00000-of-00001  ckpt-22.index
ckpt-11.index		     ckpt-23.data-00000-of-00001
ckpt-12.data-00000-of-00001  ckpt-23.index
ckpt-12.index		     ckpt-24.data-00000-of-00001
ckpt-13.data-00000-of-00001  ckpt-24.index
ckpt-13.index		     ckpt-2.data-00000-of-00001
ckpt-14.data-00000-of-00001  ckpt-2.index
ckpt-14.index		     ckpt-3.data-00000-of-00001
ckpt-15.data-00000-of-00001  ckpt-3.index
ckpt-15.index		     ckpt-4.data-00000-of-00001
ckpt-16.data-00000-of-00001  ckpt-4.index
ckpt-16.index		     ckpt-5.data-00000-of-00001
ckpt-17.data-00000-of-00001  ckpt-5.index
ckpt-17.index		     ckpt-6.data-00000-of-00001
ckpt-18.data-00000-of-00001  ckpt-6.index
ckpt-18.index		     ckpt-7.data-00000-of-00001
ckpt-19.data-00000-of-00001  ckpt-7.index
ckpt-19.index		     ckpt-8.data-00000-of-00001
ckpt-1.data-00000-of-00001   ckpt-8.index
ckpt-1.index		  

In [113]:
# Restoring the latest checkpoint in checkpoint_dir
# NOTE(review): presumably tf.train.latest_checkpoint returns None when the
# directory holds no checkpoint, in which case the generator would keep its
# freshly initialized weights — confirm, and consider guarding against it.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fced5042d90>

## Generate some images using the test set

In [114]:
# Run the restored generator on every batch of train_dataset and save the
# outputs through generate_images().
# Note: train_dataset is mapped with load_image_train, so each input has been
# randomly jittered and possibly mirrored before it reaches the generator.
for inp, tar in train_dataset.as_numpy_iterator():
    generate_images(generator, inp, tar)
    

ful dir /home/cgac/Documents/agri-train-ds/mask_disease_classification_pix2pix/corns/corn_gray_leaf
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  True
save  Tru