# pix2pix: Image-to-image translation with a conditional GAN

## Import TensorFlow and other libraries

In [None]:
import tensorflow as tf
import numpy as np

import os
import shutil
import pathlib
import time
import datetime
from keras_preprocessing.image import ImageDataGenerator

from matplotlib import pyplot as plt
from IPython import display

In [None]:
# Training configuration shared by the cells below.
EPOCHS = 100
# The batch size of 1 produced better results for the U-Net in the original pix2pix experiment
BATCH_SIZE = 1
# Each image is 256x256 in size; IMG_WIDTH also selects the dataset folder and appears in the model name
IMG_WIDTH = 256
IMG_HEIGHT = 256
# Patch-extraction scheme: selects the training sub-folder and is embedded in model/checkpoint names
PATCHES_METHOD = 'independent_patches'   # Choose between independent_patches, overlapped_patches and random_patches

In [None]:
# Dataset locations; folder names encode colour mode, patch geometry and extraction method.
PATH = 'D:/Gerasimos/my_icdar/datasets'
train_set_path = os.path.join(PATH, f'train/RGB/{IMG_WIDTH}by512/{PATCHES_METHOD}/101' )
# NOTE(review): the test folder is always 'independent_patches', whatever PATCHES_METHOD is — confirm this is intended.
test_set_path = os.path.join(PATH, f'test/RGB/{IMG_WIDTH}by512/independent_patches/301_small')

In [None]:
# The shuffle buffer spans the whole training set: one slot per entry in the folder.
# NOTE(review): os.listdir counts every entry, while the datasets only read '*jpg'
# files — keep the folders free of other files or these counts will be too high.
buffer_size = len(os.listdir(train_set_path))
test_set_num = len(os.listdir(test_set_path))
# Batches per epoch: ceiling division, because `Dataset.batch` keeps the last
# partial batch (identical to the plain quotient when BATCH_SIZE == 1).
steps_per_epoch = (buffer_size + BATCH_SIZE - 1) // BATCH_SIZE
# Total number of training steps (batches, not images) for the whole run.
steps = steps_per_epoch * EPOCHS

In [None]:
# Output locations. The generator name records the run's settings
# (input size, filters, number of training images, patch method, optimiser, learning rate, epochs).
model_path = f'D:/Gerasimos/my_icdar/models/{IMG_WIDTH}_inputs/lr0002_aug/pix2pix'
gen_name = model_name = f'pix2pix_input{IMG_WIDTH}_filters64_images{buffer_size}_{PATCHES_METHOD}_optAdam_lr0.0002_eps{EPOCHS}' # TODO: loss and metric
gen_path = os.path.join(model_path, gen_name)
checkpoint_dir = os.path.join(gen_path, f'training_checkpoints_{PATCHES_METHOD}')

In [None]:
# Read one raw training file (input and mask side by side) and print its shape.
sample_image = tf.io.read_file(str(train_set_path  + '/8.jpg'))
sample_image = tf.io.decode_jpeg(sample_image)
print(sample_image.shape)

In [None]:
# Display the raw, unsplit sample.
plt.figure()
plt.imshow(sample_image)
plt.show()

Each file stores an input image and its ground-truth mask side by side. You need to split every file into these two halves—for the `256by512` patches used here, each half is `256 x 256`.

Define a function that loads image files and outputs two image tensors:

In [None]:
def load(image_file):
    """Load one file and split it into an input image and a target mask.

    Each file holds the input image in its left half and the ground-truth
    mask in its right half.

    Args:
        image_file: path (Python string or string tensor) of a JPEG file.

    Returns:
        ``(input_image, target_image)`` as float32 tensors with values in
        [0, 255]. The input keeps three channels, shape ``(H, W // 2, 3)``.
        The target keeps only the first channel of the right half, shape
        ``(H, W - W // 2)``, because the mask is binary.
    """
    image = tf.io.read_file(image_file)
    # Force three channels: a grayscale JPEG would otherwise decode to a single
    # channel, and the 3-channel generator input and the 4-channel stacking in
    # `random_crop` would break. It also makes the channel dimension static
    # inside `tf.data` pipelines.
    image = tf.io.decode_jpeg(image, channels=3)

    half_width = tf.shape(image)[1] // 2
    # NOTE: the original author flagged this input slice as "needs to change".
    input_image = image[:, :half_width, :]
    # Keep a single channel of the right half: the target is binary.
    target_image = image[:, half_width:, 0]

    return tf.cast(input_image, tf.float32), tf.cast(target_image, tf.float32)

Plot a sample input image and its target mask:

In [None]:
# Load one training pair; `inp` and `tar` are reused by the demo cells below.
inp, tar = load(str(train_set_path + '/92.jpg'))
# Divide by 255 to bring the float [0, 255] values into [0, 1] for matplotlib
plt.figure()
plt.imshow(inp / 255.0)
plt.figure()
plt.imshow(tar / 255.0)

As described in the [pix2pix paper](https://arxiv.org/abs/1611.07004), you need to apply random jittering and mirroring to preprocess the training set.

Define several functions that:

1. Resize each `256 x 256` image to a larger height and width—`286 x 286`.
2. Randomly crop it back to `256 x 256`.
3. Randomly flip the image horizontally i.e. left to right (random mirroring).
4. Normalize the images to the `[-1, 1]` range.

In [None]:
def resize(input_image, target_image, height, width):
    """Resize an input image and its target to ``(height, width)``.

    The target is given a trailing channel axis before resizing, so callers
    pass it without one (as produced by `load`). Both images use
    nearest-neighbour interpolation, which copies existing pixel values
    instead of blending them.

    Returns:
        ``(input_image, target_image)`` resized; for a 2-D target the result
        has shape ``(height, width, 1)``.
    """
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    new_size = [height, width]

    resized_input = tf.image.resize(input_image, new_size, method=nearest)
    target_with_channel = tf.expand_dims(target_image, axis=-1)
    resized_target = tf.image.resize(target_with_channel, new_size, method=nearest)

    return resized_input, resized_target

In [None]:
def random_crop(input_image, real_image):
    """Cut the same random ``IMG_HEIGHT x IMG_WIDTH`` window out of both images.

    The images are stacked along the channel axis so that a single random
    crop applies to both. Assumes a 3-channel input and a 1-channel target.

    Returns:
        ``(input_crop, target_crop)`` with shapes ``(IMG_HEIGHT, IMG_WIDTH, 3)``
        and ``(IMG_HEIGHT, IMG_WIDTH, 1)``.
    """
    stacked = tf.concat([input_image, real_image], -1)
    window = tf.image.random_crop(stacked, size=[IMG_HEIGHT, IMG_WIDTH, 4])

    input_crop = window[:, :, :3]
    # Slicing with `3:` keeps the channel axis, giving shape (H, W, 1).
    target_crop = window[:, :, 3:]
    return input_crop, target_crop

In [None]:
def normalize(input_image, real_image):
    """Map pixel values from [0, 255] to [-1, 1] for both images.

    Works on anything supporting ``/`` and ``-`` with scalars (tensors,
    NumPy arrays, Python floats).
    """
    half_range = 127.5
    return input_image / half_range - 1, real_image / half_range - 1

In [None]:
@tf.function()
def random_jitter(input_image, target_image):
    """Apply pix2pix random jitter and mirroring identically to an input/target pair.

    Upscales both images to 286x286 with `resize`, and takes one random
    IMG_HEIGHT x IMG_WIDTH crop from both with `random_crop`. Then, with
    probability 0.5, flips both left-to-right.
    """
    enlarged_input, enlarged_target = resize(input_image, target_image, 286, 286)
    cropped_input, cropped_target = random_crop(enlarged_input, enlarged_target)

    # One coin flip decides mirroring for both images, so they stay aligned.
    should_flip = tf.random.uniform(()) > 0.5
    return tf.cond(
        should_flip,
        lambda: (tf.image.flip_left_right(cropped_input),
                 tf.image.flip_left_right(cropped_target)),
        lambda: (cropped_input, cropped_target))

You can inspect some of the preprocessed output:

In [None]:
# Show four random jitters of the sample input; each call draws a new crop and flip.
plt.figure(figsize=(6, 6))
for i in range(4):
    rj_inp, rj_tar = random_jitter(inp, tar)
    plt.subplot(2, 2, i + 1)
    # Scale [0, 255] to [0, 1] for display.
    plt.imshow(rj_inp / 255.0)
    plt.axis('off')
plt.show()

Having checked that the loading and preprocessing work, let's define a couple of helper functions that load and preprocess the training and test sets:

In [None]:
def load_image_train(image_file):
    """Training pipeline for one file: load, random jitter, then scale to [-1, 1].

    Returns:
        ``(input_image, real_image)`` ready for the networks.
    """
    pair = load(image_file)
    jittered = random_jitter(*pair)
    return normalize(*jittered)

In [None]:
def load_image_test(image_file):
    """Test pipeline for one file: load, resize to IMG_HEIGHT x IMG_WIDTH, scale to [-1, 1].

    No random augmentation is applied, so results are reproducible.

    Returns:
        ``(input_image, real_image)`` ready for the networks.
    """
    loaded_input, loaded_real = load(image_file)
    resized = resize(loaded_input, loaded_real, IMG_HEIGHT, IMG_WIDTH)
    return normalize(*resized)

In [None]:
# Training pipeline: list the JPEG files, load + jitter + normalise each one in parallel,
# shuffle with a buffer the size of the training folder, then batch.
train_dataset = tf.data.Dataset.list_files(train_set_path + '/*jpg')
print(train_dataset)
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.AUTOTUNE)

train_dataset = train_dataset.shuffle(buffer_size)

train_dataset = train_dataset.batch(BATCH_SIZE)

In [None]:
# Test pipeline: deterministic preprocessing, and the whole test folder as a single batch.
# NOTE(review): `list_files` shuffles the file order by default, so the image at a given
# batch index can differ between runs — pass `shuffle=False` if a fixed order is wanted.
test_dataset = tf.data.Dataset.list_files(test_set_path + '/*jpg')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(test_set_num)

In [None]:
def downsample(filters, size, apply_batchnorm=True):
    """Return an encoder block: stride-2 Conv2D -> optional BatchNorm -> LeakyReLU.

    The convolution uses 'same' padding and stride 2, halving the spatial
    size. It has no bias term, and its weights are drawn from N(0, 0.02).

    Args:
        filters: number of output channels.
        size: convolution kernel size.
        apply_batchnorm: insert a BatchNormalization layer after the convolution.
    """
    block_layers = [
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=tf.random_normal_initializer(0., 0.02),
                               use_bias=False),
    ]
    if apply_batchnorm:
        block_layers.append(tf.keras.layers.BatchNormalization())
    block_layers.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(block_layers)

In [None]:
# Sanity check: one downsampling block on the sample input (as a batch of one) halves its spatial size.
down_model = downsample(3, 4)
down_result = down_model(tf.expand_dims(inp, 0))
print (down_result.shape)

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Return a decoder block: stride-2 Conv2DTranspose -> BatchNorm -> optional Dropout(0.5) -> ReLU.

    The transposed convolution uses 'same' padding and stride 2, doubling the
    spatial size. It has no bias term, and its weights are drawn from N(0, 0.02).

    Args:
        filters: number of output channels.
        size: kernel size.
        apply_dropout: insert a Dropout(0.5) layer before the activation.
    """
    block_layers = [
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                        padding='same',
                                        kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                        use_bias=False),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        block_layers.append(tf.keras.layers.Dropout(0.5))
    block_layers.append(tf.keras.layers.ReLU())

    return tf.keras.Sequential(block_layers)

In [None]:
# Sanity check: one upsampling block doubles the spatial size of the downsampled result.
up_model = upsample(3, 4)
up_result = up_model(down_result)
print (up_result.shape)

In [None]:
# The generator predicts a single-channel image, matching the binary target mask.
OUTPUT_CHANNELS = 1

In [None]:
def Generator():
    """Build the U-Net generator.

    The encoder has eight downsampling blocks that reduce a (256, 256, 3)
    input to 1x1. The decoder has seven upsampling blocks, with dropout on
    the first three. Each decoder output is concatenated with the encoder
    output of matching resolution (skip connection). A final stride-2
    transposed convolution with tanh activation restores 256x256 with
    OUTPUT_CHANNELS channels.

    Returns:
        A ``tf.keras.Model`` mapping ``(batch, 256, 256, 3)`` to
        ``(batch, 256, 256, OUTPUT_CHANNELS)``.
    """
    inputs = tf.keras.layers.Input(shape=[256, 256, 3])

    # (filters, apply_batchnorm); spatial size after each block:
    # 128, 64, 32, 16, 8, 4, 2, 1.
    encoder_spec = [(64, False), (128, True), (256, True), (512, True),
                    (512, True), (512, True), (512, True), (512, True)]
    encoder = [downsample(n_filters, 4, apply_batchnorm=use_bn)
               for n_filters, use_bn in encoder_spec]

    # (filters, apply_dropout); spatial size after each block:
    # 2, 4, 8, 16, 32, 64, 128. The skip concatenation then doubles the channels.
    decoder_spec = [(512, True), (512, True), (512, True), (512, False),
                    (256, False), (128, False), (64, False)]
    decoder = [upsample(n_filters, 4, apply_dropout=use_dropout)
               for n_filters, use_dropout in decoder_spec]

    output_layer = tf.keras.layers.Conv2DTranspose(
        OUTPUT_CHANNELS, 4,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        activation='tanh')  # (batch_size, 256, 256, OUTPUT_CHANNELS)

    x = inputs
    encoder_outputs = []
    for block in encoder:
        x = block(x)
        encoder_outputs.append(x)

    # Pair decoder blocks with encoder outputs in reverse order, skipping the
    # 1x1 bottleneck (the last encoder output).
    for block, skip in zip(decoder, encoder_outputs[-2::-1]):
        x = block(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    return tf.keras.Model(inputs=inputs, outputs=output_layer(x))

Visualize the generator model architecture:

In [None]:
# Build the generator and render its architecture with input/output shapes.
generator = Generator()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

In [None]:
# Run the untrained generator on the sample input (inference mode) and show its output.
gen_output = generator(inp[tf.newaxis, ...], training=False)
plt.imshow(gen_output[0, ...])

In [None]:
# Weight of the L1 term in the generator loss.
LAMBDA = 100

In [None]:
# Binary cross-entropy on raw logits; the discriminator's last Conv2D has no activation.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def generator_loss(disc_generated_output, gen_output, target):
    """Compute the pix2pix generator loss.

    Args:
        disc_generated_output: discriminator logits for (input, generated) pairs.
        gen_output: generated images.
        target: ground-truth images.

    Returns:
        ``(total, adversarial, l1)``, where ``total = adversarial + LAMBDA * l1``.
    """
    # Adversarial term: the generator wants its outputs labelled real (1).
    adversarial = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    # Mean absolute error between the ground truth and the generated image.
    reconstruction = tf.reduce_mean(tf.abs(target - gen_output))

    return adversarial + LAMBDA * reconstruction, adversarial, reconstruction

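The generator is trained on the sum of this adversarial loss and `LAMBDA` times the L1 (mean absolute) error between the generated image and the target; the update itself happens in `train_step` below.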

Let's define the discriminator:

In [None]:
def Discriminator():
    """Build the convolutional discriminator.

    It takes an input image (3 channels) and a target or generated mask
    (1 channel). It returns a (30, 30, 1) grid of unnormalised logits, one per
    spatial position of the final convolution.

    Returns:
        A ``tf.keras.Model`` with inputs ``[input_image, target_image]``.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
    tar = tf.keras.layers.Input(shape=[256, 256, 1], name='target_image')

    # Condition on the input by stacking it with the mask: 3 + 1 = 4 channels.
    x = tf.keras.layers.concatenate([inp, tar])  # (batch_size, 256, 256, 4)

    # Three downsampling blocks: 128 -> 64 -> 32 pixels, ending at 256 channels.
    for n_filters, use_bn in ((64, False), (128, True), (256, True)):
        x = downsample(n_filters, 4, use_bn)(x)

    x = tf.keras.layers.ZeroPadding2D()(x)  # (batch_size, 34, 34, 256)
    x = tf.keras.layers.Conv2D(512, 4, strides=1,
                               kernel_initializer=initializer,
                               use_bias=False)(x)  # (batch_size, 31, 31, 512)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.ZeroPadding2D()(x)  # (batch_size, 33, 33, 512)

    logits = tf.keras.layers.Conv2D(1, 4, strides=1,
                                    kernel_initializer=initializer)(x)  # (batch_size, 30, 30, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=logits)

In [None]:
# Build the discriminator and render its architecture with input/output shapes.
discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

In [None]:
# Show the discriminator's output logits for the sample input paired with the untrained generator's output.
disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)
plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
    """Discriminator loss: label real pairs as 1 and generated pairs as 0.

    Args:
        disc_real_output: logits for (input, ground-truth) pairs.
        disc_generated_output: logits for (input, generated) pairs.

    Returns:
        The sum of the two binary cross-entropy terms.
    """
    real_term = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    fake_term = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    return real_term + fake_term

The training procedure for the discriminator is shown below.

To learn more about the architecture and the hyperparameters you can refer to the [pix2pix paper](https://arxiv.org/abs/1611.07004).

![Discriminator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/dis.png?raw=1)


## Define the optimizers and a checkpoint-saver


In [None]:
# One Adam optimiser per network: learning rate 2e-4, beta_1 = 0.5.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Checkpoint tracking both networks and their optimisers; files use the "ckpt" prefix inside checkpoint_dir.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [None]:
def generate_images(model, test_input, tar):
    """Plot the first example of a batch: input, ground truth and thresholded prediction.

    The model is called with ``training=True``, so dropout is active and
    BatchNormalization uses batch statistics.

    NOTE(review): the prediction is thresholded at 0.5 and mapped to {0, 1},
    while the ground truth from `normalize` spans [-1, 1]. Confirm that the
    threshold matches the range the generator is trained toward.

    Args:
        model: the generator.
        test_input: batch of input images.
        tar: batch of ground-truth masks.
    """
    prediction = model(test_input, training=True)
    binary_prediction = np.where(prediction[0] > 0.5, 1, 0)

    panels = (
        ('Input Image', test_input[0]),
        ('Ground Truth', tar[0]),
        ('Predicted Image', binary_prediction),
    )

    plt.figure(figsize=(15, 15))
    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        # Shift values from [-1, 1] to [0, 1] for plotting.
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

In [None]:
def validate(generator, discriminator, input_images, target_images):
    """Compute the generator losses on a validation batch; no weights are updated.

    Both networks run with ``training=True``, as in `train_step`: dropout is
    active and BatchNormalization uses batch statistics.
    NOTE(review): in Keras, ``training=True`` also updates the BatchNormalization
    moving averages with this batch. Confirm that this is acceptable for
    validation data.

    Returns:
        ``(gen_total_loss, gen_gan_loss, gen_l1_loss)`` from `generator_loss`.
    """
    gen_output = generator(input_images, training=True)

    # Only the discriminator's verdict on the generated pairs enters the
    # generator loss. Scoring the real pairs (as the previous version did) was
    # wasted work and an extra BatchNormalization statistics update.
    disc_generated_output = discriminator([input_images, gen_output], training=True)

    return generator_loss(disc_generated_output, gen_output, target_images)

In [None]:
# Show the untrained generator's prediction for the first example of the (single) test batch.
for example_input, example_target in test_dataset.take(1):
    generate_images(generator, example_input, example_target)

In [None]:
# TensorBoard writer; each run logs to its own timestamped folder under <model_path>/logs/fit/.
log_dir= os.path.join(model_path, "logs")

summary_writer = tf.summary.create_file_writer(
  log_dir + "/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

## Training

- For each example input, the generator produces an output.
- The discriminator receives the `input_image` and the generated image as its first input, and the `input_image` and the `target_image` as its second input.
- Next, calculate the generator and the discriminator losses.
- Then, calculate the gradients of each loss with respect to the generator and the discriminator variables, and apply them with the respective optimizers.
- Finally, log the losses to TensorBoard.

In [None]:
@tf.function()
def train_step(input_image, target, step):
    """Run one optimisation step for both networks and log the losses.

    Args:
        input_image: batch of input images.
        target: batch of ground-truth masks.
        step: step counter; summaries are written at ``step // 1000``.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        # Score both the real pair and the generated pair.
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)

        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(
            disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    # Take both gradients before either optimiser changes any weights.
    gen_grads = gen_tape.gradient(gen_total_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    log_step = step // 1000
    with summary_writer.as_default():
        for tag, value in (('gen_total_loss', gen_total_loss),
                           ('gen_gan_loss', gen_gan_loss),
                           ('gen_l1_loss', gen_l1_loss),
                           ('disc_loss', disc_loss)):
            tf.summary.scalar(tag, value, step=log_step)

The actual training loop. It runs for `EPOCHS` epochs. Each epoch consists of:

- One pass over the training set. Before every `train_step`, the same random augmentation (rotation, shift, shear, zoom, flips) is applied to each input and its mask.
- An evaluation on the test set. It displays the generator's prediction for a fixed test example and computes the generator losses.
- Saving the generator whenever the validation `gen_total_loss` improves. Periodic checkpointing is currently disabled.

In [None]:
# Data augmentation applied on the fly in `fit`, through two Keras ImageDataGenerators:
# one for the input images and one for the ground-truth masks.
#
# NOTE: both generators must use exactly the same transformation parameters. Together with
# a shared random seed per call, this makes them apply the same transform to an input and
# its mask. They are separate objects only because the mask generator also needs a
# preprocessing function that re-binarises the mask: the ground truth is binary, but
# rotations, shifts, etc. interpolate and introduce intermediate pixel values.
#
# The masks reaching these generators have already been scaled to [-1, 1] by `normalize`.
# They are therefore re-binarised to {-1, 1}, thresholding at the midpoint 0. This matches the
# range of the (normalised) test targets used for validation and of the generator's tanh
# output. Mapping them to {0, 1} would train against a different target range than the one
# used to select the best model.
_common_datagen_args = dict(
    rotation_range=25,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='reflect')

# Deriving both dicts from the shared one guarantees identical transform parameters.
x_datagen_args = dict(_common_datagen_args)
y_datagen_args = dict(
    _common_datagen_args,
    preprocessing_function=lambda x: np.where(x > 0, 1, -1).astype(x.dtype))

# Instantiate the generators
x_datagen = ImageDataGenerator(**x_datagen_args)
y_datagen = ImageDataGenerator(**y_datagen_args)

In [None]:
def fit(train_ds, test_ds, steps):
    """Train the GAN for ``EPOCHS`` epochs and keep the best generator on disk.

    Each epoch runs one pass of augmented training steps over ``train_ds``.
    It then evaluates the generator on the first batch of ``test_ds``, shows
    its prediction for one fixed test example and, if the validation
    ``gen_total_loss`` improved, saves the generator and deletes the
    previously saved best model.

    Args:
        train_ds: batched dataset of ``(input_image, target)`` pairs.
        test_ds: dataset whose first batch is used for validation.
        steps: expected number of training steps per epoch. Kept for interface
            compatibility; the loop simply iterates over ``train_ds``.
    """
    validate_inputs, validate_targets = next(iter(test_ds.take(1)))
    # Fixed example shown after every epoch; clamped so that a small test set
    # cannot push the index out of range.
    sample_index = min(5, validate_inputs.shape[0] - 1)

    best_valid_loss = float('inf')
    previous_model_dir = None
    global_step = 0
    for epoch in range(EPOCHS):
        for input_image, target in train_ds:
            # Same seed for both generators so an input and its mask get the same transform.
            seed = np.random.randint(10**4)
            aug_input = x_datagen.flow(input_image, seed=seed, batch_size=BATCH_SIZE).next()
            aug_target = y_datagen.flow(target, seed=seed, batch_size=BATCH_SIZE).next()
            # Pass the step as a tensor: a new Python int on every call would make the
            # tf.function retrace at every step. The counter also keeps growing across
            # epochs, so TensorBoard summaries from earlier epochs are not overwritten.
            train_step(aug_input, aug_target, tf.constant(global_step, dtype=tf.int64))
            global_step += 1

        display.clear_output(wait=True)
        generate_images(generator,
                        tf.expand_dims(validate_inputs[sample_index], axis=0),
                        tf.expand_dims(validate_targets[sample_index], axis=0))

        # Validate after training, so the loss describes the weights that may be saved.
        gen_total_loss, gen_gan_loss, gen_l1_loss = validate(
            generator, discriminator, validate_inputs, validate_targets)
        print(f'Epoch {epoch + 1} of {EPOCHS}, gen_total_loss: {gen_total_loss.numpy()}, '
              f'gen_gan_loss: {gen_gan_loss.numpy()}, gen_l1_loss: {gen_l1_loss.numpy()}\n')

        valid_loss = float(gen_total_loss)
        if valid_loss < best_valid_loss:
            # gen_path ends with '..._eps<EPOCHS>'; replace that suffix with the current epoch.
            best_model_dir = gen_path.rsplit('eps', 1)[0] + 'eps' + str(epoch + 1)
            generator.save(best_model_dir)
            print('A new best model was saved')
            best_valid_loss = valid_loss

            # Delete the previous best model and always remember the newest one,
            # so only one "best" directory is kept on disk.
            if previous_model_dir:
                shutil.rmtree(previous_model_dir)
            previous_model_dir = best_model_dir

        # Periodic checkpointing is disabled: it raised an error because the
        # checkpoint file path became too long. To re-enable it:
        # if (epoch + 1) % 10 == 0:
        #     checkpoint.save(file_prefix=checkpoint_prefix)
           

This training loop saves logs that you can view in TensorBoard to monitor the training progress.

If you work on a local machine, you would launch a separate TensorBoard process. When working in a notebook, launch the viewer before starting the training to monitor with TensorBoard.

To launch the viewer paste the following into a code-cell:

In [None]:
# Load the TensorBoard notebook extension and start a viewer on the training logs.
%load_ext tensorboard
%tensorboard --logdir {log_dir}

Finally, run the training loop:

In [None]:
# Train; the validation set is the first (and only) batch of test_dataset.
fit(train_dataset, test_dataset, steps=steps_per_epoch)

If you want to share the TensorBoard results _publicly_, you can upload the logs to [TensorBoard.dev](https://tensorboard.dev/) by copying the following into a code-cell.

Note: This requires a Google account.

```
!tensorboard dev upload --logdir {log_dir}
```

Caution: This command does not terminate. It's designed to continuously upload the results of long-running experiments. Once your data is uploaded you need to stop it using the "interrupt execution" option in your notebook tool.

In [None]:
# Save the final (last-epoch) generator to gen_path and the discriminator to its 'd_model' sub-folder.
# The best generator by validation loss is saved separately during training.
generator.save(gen_path)
discriminator.save(os.path.join(gen_path, 'd_model'))

You can view the [results of a previous run](https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw) of this notebook on [TensorBoard.dev](https://tensorboard.dev/).

TensorBoard.dev is a managed experience for hosting, tracking, and sharing ML experiments with everyone.

It can also be included inline using an `<iframe>`:

In [None]:
# Embed a TensorBoard.dev experiment page inline.
display.IFrame(
    src="https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw",
    width="100%",
    height="1000px")

Interpreting the logs is more subtle when training a GAN (or a cGAN like pix2pix) compared to a simple classification or regression model. Things to look for:

- Check that neither the generator nor the discriminator model has "won". If either the `gen_gan_loss` or the `disc_loss` gets very low, it's an indicator that this model is dominating the other, and you are not successfully training the combined model.
- The value `log(2) = 0.69` is a good reference point for these losses, as it indicates a perplexity of 2 - the discriminator is, on average, equally uncertain about the two options.
- For the `disc_loss`, a value below `0.69` means the discriminator is doing better than random on the combined set of real and generated images.
- For the `gen_gan_loss`, a value below `0.69` means the generator is doing better than random at fooling the discriminator.
- As training progresses, the `gen_l1_loss` should go down.

## Restore the latest checkpoint and test the network

In [None]:
#!ls {checkpoint_dir}
# Open the latest checkpoint in checkpoint_dir for inspection.
# NOTE(review): periodic checkpoint saving in the training loop is commented out, so this
# directory may contain no checkpoint — confirm before relying on it.
reader = tf.train.load_checkpoint(checkpoint_dir)

In [None]:
#model = tf.keras.Model()
#checkpoint = tf.train.Checkpoint(model)
# Replace the in-memory generator with a previously saved one.
# NOTE(review): this relative path ('models/256_inputs/f0002/...') differs from model_path
# ('.../lr0002_aug/pix2pix'); confirm which saved model is intended.
generator = tf.keras.models.load_model('models/256_inputs/f0002/pix2pix/pix2pix_input256_filters64_images561_independent_patches_optAdam_lr0.0002_eps100')

In [None]:
# Display the checkpoint object.
checkpoint

In [None]:
# Restoring the latest checkpoint in checkpoint_dir
# NOTE(review): tf.train.latest_checkpoint returns None when the directory holds no checkpoint.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
# `Checkpoint.read` expects a checkpoint path prefix (as written by `write`/`save`),
# not a directory, so resolve the newest checkpoint in checkpoint_dir first.
checkpoint.read(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
# Run the trained model on a few examples from the test set. test_dataset is a single
# batch holding the whole test set, so split it into one-example batches first.
for example_input, example_target in test_dataset.unbatch().batch(1).take(5):
    generate_images(generator, example_input, example_target)
