In [None]:
# For Google Colab to access Drive
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
# David Cabezas Berrido

# p2pGAN_models-train.ipynb

# Definition of the models (generator and discriminator) and training

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os

# Project root; every other path below is built from it.
PATH = '.' #'/content/drive/My Drive/TFG-Image-Optimization' # for Google Colab
INPATH = PATH+'/data/jpeg/short/' # short exposure images (network input)
TGPATH = PATH+'/data/jpeg/long/' # long exposure images (target)
MODELS = PATH+'/models/' # for the trained models to be saved

# Dimensions for training images
# Training pairs are first resized to TRAIN_HEIGHT x TRAIN_WIDTH and then randomly cropped
TRAIN_HEIGHT = 1024 
TRAIN_WIDTH = 1536
TRAIN_SQUARE_SIZE = 512 # Length of the random cropped square side

# Dimensions for test images
# Test pairs are resized to TEST_HEIGHT x TEST_WIDTH and used whole (no crop)
TEST_HEIGHT = 3072 #=12*256 # This resolution requires a large RAM
TEST_WIDTH = 4608 #=18*256

CHANNELS = 3 # RGB

In [None]:
# Dataset reading

# Short exposure images, listed in pure Python instead of the IPython-only `!ls`
# shell escape: this also runs outside a notebook and on systems without `ls`,
# and a missing INPATH raises FileNotFoundError instead of silently turning the
# shell's error message into "file names".
# Hidden entries are skipped and names are sorted, mirroring what `ls -1` prints.
imgurls = sorted(name for name in os.listdir(INPATH) if not name.startswith('.'))

# Match each short exposure image with the corresponding long exposure image
# Example:
# Short exposure: fuji-00001-x4.jpg
# Long exposure: fuji-00001.jpg
def urlTarget(url):
    """Return the long exposure file name paired with a short exposure one.

    Only the first two dash-separated fields of `url` are kept, and the
    original extension (text after the last dot) is appended again, e.g.
    'fuji-00001-x4.jpg' -> 'fuji-00001.jpg'.
    """
    extension = url.rsplit('.', 1)[-1]
    fields = url.split('-')
    return f"{fields[0]}-{fields[1]}.{extension}"

# Pair every short exposure file name with the name of its long exposure target
imgurls_tg = list(zip(imgurls, map(urlTarget, imgurls)))

# Train/Test split
def isForTrain(url):
    """Tell whether an image belongs to the training split.

    The decision is taken from the first character of the second
    dash-separated field of the file name: '0' or '1' means training,
    anything else means test.
    """
    leading_digit = url.split('-')[1][0]
    return leading_digit == '0' or leading_digit == '1'

# Split the (short, long) pairs according to the short exposure file name:
# pairs accepted by isForTrain go to training, all the remaining ones to test
tr_urls = list(filter(lambda pair: isForTrain(pair[0]), imgurls_tg))
ts_urls = list(filter(lambda pair: not isForTrain(pair[0]), imgurls_tg))

# Only the training pairs are shuffled (in place); test pairs keep the listing order
np.random.shuffle(tr_urls)

In [None]:
# Dataset processing

# Resizes image to height x width
def resize(input_image, real_image, height, width):
    """Resize both images of a pair to height x width.

    Nearest-neighbor interpolation is used for the two images.
    Returns the pair (resized input, resized target).
    """
    new_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    resized_input = tf.image.resize(input_image, new_size, method=nearest)
    resized_real = tf.image.resize(real_image, new_size, method=nearest)
    return resized_input, resized_real

# Randomly crops a height x width square
def random_crop(input_image, real_image, height, width):
    """Cut the same random height x width window out of both images.

    The two images are stacked along a new leading axis so that a single
    random crop is applied to both, which keeps input and target aligned.
    Returns the pair (cropped input, cropped target).
    """
    stacked_image = tf.stack([input_image, real_image], axis=0)
    # The channel count comes from the notebook-wide CHANNELS constant instead
    # of a hard-coded 3, so it stays consistent with the model definitions.
    cropped_image = tf.image.random_crop(stacked_image, size=[2, height, width, CHANNELS])
    return cropped_image[0], cropped_image[1]

# Randomly mirrors (flips horizontally) both images
def random_flip(input_image, real_image):
    """Mirror both images left-right when a uniform random draw exceeds 0.5.

    Both images are always flipped together, so the pair stays aligned.
    Returns the (possibly flipped) pair.
    """
    do_flip = tf.random.uniform(()) > 0.5
    if do_flip:
        input_image, real_image = (tf.image.flip_left_right(input_image),
                                   tf.image.flip_left_right(real_image))
    return input_image, real_image

# Normalize to [-1,1]
def normalize(inimg,tgimg):
    """Rescale both images with x/127.5-1, which maps [0,255] onto [-1,1].

    Returns the pair (rescaled input, rescaled target).
    """
    def rescale(pixels):
        return pixels/127.5-1

    return rescale(inimg), rescale(tgimg)

@tf.function
def loadTrainImage(filenames):
    """Load and augment one training pair.

    filenames: pair (short exposure name, long exposure name); the first is
        read from INPATH and the second from TGPATH.
    Returns (inimg, tgimg): float32 images normalized to [-1,1], both taken
    from the same random TRAIN_SQUARE_SIZE x TRAIN_SQUARE_SIZE crop of the
    pair resized to TRAIN_HEIGHT x TRAIN_WIDTH, and randomly mirrored together.
    """
    # Read image
    # channels=CHANNELS makes the decoder always output RGB. Without it a
    # grayscale JPEG decodes to a single channel (slicing [..., :3] cannot add
    # the missing ones) and the 3-channel crop and the models would fail.
    inimg = tf.cast(tf.image.decode_jpeg(tf.io.read_file(INPATH+filenames[0]), channels=CHANNELS),tf.float32)
    tgimg = tf.cast(tf.image.decode_jpeg(tf.io.read_file(TGPATH+filenames[1]), channels=CHANNELS),tf.float32)
    inimg, tgimg = resize(inimg, tgimg, TRAIN_HEIGHT, TRAIN_WIDTH)
    inimg, tgimg = random_crop(inimg, tgimg, TRAIN_SQUARE_SIZE, TRAIN_SQUARE_SIZE)
    inimg, tgimg = random_flip(inimg, tgimg)
    inimg, tgimg = normalize(inimg, tgimg)
    return inimg, tgimg
    
def loadTestImage(filenames):
    """Load one test pair, without any random augmentation.

    filenames: pair (short exposure name, long exposure name); the first is
        read from INPATH and the second from TGPATH.
    Returns (inimg, tgimg): float32 images resized to TEST_HEIGHT x TEST_WIDTH
    and normalized to [-1,1].
    """
    # Read image
    # channels=CHANNELS makes the decoder always output RGB. Without it a
    # grayscale JPEG decodes to a single channel (slicing [..., :3] cannot add
    # the missing ones) and the generator would receive a 1-channel input.
    inimg = tf.cast(tf.image.decode_jpeg(tf.io.read_file(INPATH+filenames[0]), channels=CHANNELS),tf.float32)
    tgimg = tf.cast(tf.image.decode_jpeg(tf.io.read_file(TGPATH+filenames[1]), channels=CHANNELS),tf.float32)
    inimg, tgimg = resize(inimg, tgimg, TEST_HEIGHT, TEST_WIDTH)
    inimg, tgimg = normalize(inimg, tgimg)
    return inimg, tgimg
    
# Training pipeline.
# - shuffle: the (short, long) file-name pairs are reshuffled on every pass over
#   the dataset, so each epoch sees a different order (only strings are buffered,
#   so a buffer covering the whole list is cheap). max(..., 1) keeps the buffer
#   size valid even if the list is empty.
# - map: images are decoded and augmented in parallel.
# - batch(1): one pair per step.
# - prefetch: the next pair is prepared while the current training step runs.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(tr_urls)
    .shuffle(buffer_size=max(len(tr_urls), 1), reshuffle_each_iteration=True)
    .map(loadTrainImage, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(1)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Test pipeline: pairs keep their listing order and nothing is prefetched,
# since every TEST_HEIGHT x TEST_WIDTH pair already needs a lot of memory.
test_dataset = (
    tf.data.Dataset.from_tensor_slices(ts_urls)
    .map(loadTestImage, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(1)
)

In [None]:
# Block of 2-3 layers, appears in discriminator and first half of generator (encoder)
def downsample(filters, size, apply_batchnorm=True):
    """Build a stride-2 convolutional block that halves height and width.

    filters: number of output channels of the convolution.
    size: kernel size of the convolution.
    apply_batchnorm: when True, a batch normalization layer follows the
        convolution and the convolution has no bias term.
    Returns a tf.keras.Sequential: Conv2D -> [BatchNormalization] -> LeakyReLU.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    block_layers = [
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=initializer,
                               use_bias=not apply_batchnorm),
    ]
    if apply_batchnorm:
        block_layers.append(tf.keras.layers.BatchNormalization())
    block_layers.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(block_layers)

# Block of 3-4 layers, appear in second half of generator (decoder)
def upsample(filters, size, apply_dropout=False):
    """Build a stride-2 transposed-convolution block that doubles height and width.

    filters: number of output channels of the transposed convolution.
    size: kernel size of the transposed convolution.
    apply_dropout: when True, a Dropout(0.5) layer is inserted after the
        batch normalization.
    Returns a tf.keras.Sequential:
    Conv2DTranspose (no bias) -> BatchNormalization -> [Dropout] -> ReLU.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    block_layers = [
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        block_layers.append(tf.keras.layers.Dropout(0.5))
    block_layers.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(block_layers)

In [None]:
# The generator uses the U-Net architecture: encoder-decoder with skip connections

def Generator():
    """Build the generator: encoder, decoder with skip connections, tanh output.

    The input has CHANNELS channels and unspecified height and width.
    Shapes quoted in the comments assume a 512x512x3 input.
    Returns a tf.keras.Model mapping an image batch to an image batch with
    values in [-1,1].
    """
    inputs = tf.keras.layers.Input(shape=[None,None,CHANNELS])

    # Encoder: 8 stride-2 blocks, each one halves height and width
    # 512 -> 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2
    # channels: 64, 128, 256, 512, 512, 512, 512, 512
    encoder_blocks = [downsample(64, 4, apply_batchnorm=False)]
    encoder_blocks += [downsample(n_filters, 4)
                       for n_filters in (128, 256, 512, 512, 512, 512, 512)]

    # Decoder: 7 stride-2 blocks, each one doubles height and width
    # 2 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256
    # channels after concatenating the skip: 1024, 1024, 1024, 1024, 512, 256, 128
    decoder_blocks = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
    decoder_blocks += [upsample(n_filters, 4) for n_filters in (512, 256, 128, 64)]

    # Final layer: must output same shape (batch size, height, width, channels) as input
    # Pixel values must be in [-1,1] so activation=tanh
    initializer = tf.random_normal_initializer(0., 0.02)
    final_layer = tf.keras.layers.Conv2DTranspose(filters=CHANNELS, kernel_size=4,
                                                  strides=2, padding='same',
                                                  kernel_initializer=initializer,
                                                  activation='tanh') # (bs, 512, 512, 3)

    # Encoder pass, remembering the output of every block
    features = inputs
    encoder_outputs = []
    for block in encoder_blocks:
        features = block(features)
        encoder_outputs.append(features)

    # Decoder pass: the deepest encoder output is the decoder input, so it is
    # left out of the skips; the remaining ones are consumed deepest first
    for block, skip in zip(decoder_blocks, encoder_outputs[-2::-1]):
        features = block(features)
        features = tf.keras.layers.Concatenate()([features, skip])

    output_image = final_layer(features)
    return tf.keras.Model(inputs=inputs, outputs=output_image)

generator=Generator()

In [None]:
# Plot generator architecture (for Jupyter notebook)
# The commented calls are alternative renderings (with/without shapes, nested blocks collapsed or expanded);
# the active one shows shapes and expands the nested downsample/upsample blocks.
# NOTE(review): plot_model presumably needs pydot and Graphviz installed -- confirm.
#tf.keras.utils.plot_model(generator, show_shapes=True, dpi=96)
#tf.keras.utils.plot_model(generator, show_shapes=False, dpi=96)
#tf.keras.utils.plot_model(generator, show_shapes=False, expand_nested=True, dpi=96)
tf.keras.utils.plot_model(generator, show_shapes=True, expand_nested=True, dpi=96)

In [None]:
# PatchGAN:
# Outputs NxN 1-channel image where each pixel classifies a 70x70 portion of input
def Discriminator():
    """Build the discriminator, which scores an (input, target) image pair.

    Both inputs have CHANNELS channels and unspecified height and width.
    The model ends in a 1-filter convolution without activation, so it
    outputs a one-channel map of raw scores (logits).
    Shapes quoted in the comments assume 512x512x3 inputs.
    Returns a tf.keras.Model with inputs [input_image, target_image].
    """
    inp = tf.keras.layers.Input(shape=[None, None, CHANNELS], name='input_image')
    tar = tf.keras.layers.Input(shape=[None, None, CHANNELS], name='target_image')

    # Both images are stacked along the channel axis
    features = tf.keras.layers.concatenate([inp, tar]) # (bs, 512, 512, CHANNELS*2=6)

    # Three stride-2 blocks: (bs, 256, 256, 64) -> (bs, 128, 128, 128) -> (bs, 64, 64, 256)
    for n_filters, use_batchnorm in ((64, False), (128, True), (256, True)):
        features = downsample(n_filters, 4, use_batchnorm)(features)

    initializer = tf.random_normal_initializer(0., 0.02)
    features = tf.keras.layers.ZeroPadding2D()(features) # (bs, 66, 66, 256)
    features = tf.keras.layers.Conv2D(512, 4, strides=1,
                                      kernel_initializer=initializer,
                                      use_bias=False)(features) # (bs, 63, 63, 512)
    features = tf.keras.layers.BatchNormalization()(features) # (bs, 63, 63, 512)
    features = tf.keras.layers.LeakyReLU()(features) # (bs, 63, 63, 512)
    features = tf.keras.layers.ZeroPadding2D()(features) # (bs, 65, 65, 512)

    patch_scores = tf.keras.layers.Conv2D(1, 4, strides=1,
                                          kernel_initializer=initializer)(features) # (bs, 62, 62, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=patch_scores)

discriminator=Discriminator()

In [None]:
# Plot discriminator architecture (for Jupyter notebook)
# The commented calls are alternative renderings (with/without shapes, nested blocks collapsed or expanded);
# the active one shows shapes and expands the nested downsample blocks.
# NOTE(review): plot_model presumably needs pydot and Graphviz installed -- confirm.
#tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=96)
#tf.keras.utils.plot_model(discriminator, show_shapes=False, dpi=96)
#tf.keras.utils.plot_model(discriminator, show_shapes=False, expand_nested=True, dpi=96)
tf.keras.utils.plot_model(discriminator, show_shapes=True, expand_nested=True, dpi=96)

In [None]:
# Binary cross-entropy between the true labels and the discriminator prediction
# from_logits=True: the prediction is given as raw scores (logits); the loss
# itself applies the sigmoid that maps them into [0,1]
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

LAMBDA = 100 # Weight of L1 loss
# Loss function for generator
def generator_loss(disc_generated_output, gen_output, target):
    """Loss of the generator: adversarial term plus LAMBDA times the L1 term.

    disc_generated_output: discriminator output for the generated image.
    gen_output: image produced by the generator.
    target: ground truth image.
    Returns (total loss, adversarial loss, L1 loss).
    """
    # Generator tries to fool discriminator: its output is compared against
    # all-ones ("real") labels, so this term is low when the discriminator
    # gives HIGH scores to the generated image (thinks it is real)
    real_labels = tf.ones_like(disc_generated_output)
    adversarial_term = loss_object(real_labels, disc_generated_output)
    # Mean Absolute Error: prediction and ground truth should look alike
    l1_term = tf.reduce_mean(tf.abs(target - gen_output))
    return adversarial_term + LAMBDA * l1_term, adversarial_term, l1_term

# Loss function for discriminator
def discriminator_loss(disc_real_output, disc_generated_output):
    """Loss of the discriminator: sum of its errors on real and generated pairs.

    disc_real_output: discriminator output for (input, ground truth).
    disc_generated_output: discriminator output for (input, generated image).
    Returns the sum of both cross-entropy terms.
    """
    # Ground truth is labelled with ones: the discriminator should give it high scores
    loss_on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    # Generated image is labelled with zeros: the discriminator should give it low scores
    loss_on_generated = loss_object(tf.zeros_like(disc_generated_output),
                                    disc_generated_output)
    return loss_on_real + loss_on_generated

In [None]:
# Optimizers: Adam algorithm with learning rate 2e-4 and beta_1=0.5, beta_2=0.999 (default)
# the rest of parameters are set to default
# Each model has its own optimizer instance, so their optimizer states are kept apart
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Code for checkpoint saving, to restore the status of both models
# and their respective optimizers
checkpoint_dir = PATH+'/training_checkpoints'
# Prefix passed to checkpoint.save() as file_prefix
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# The checkpoint tracks the four objects that make up the training state
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

# Code for restoring a specific checkpoint or the most recent
# (uncomment one of the lines below to resume a previous training)
#checkpoint.restore(checkpoint_prefix+'-18').assert_consumed()
#checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
# A single train step through one example
@tf.function
def train_step(input_image, target):
    """Run one optimization step of generator and discriminator on one batch.

    input_image: batch of input images.
    target: batch of the corresponding ground truth images.
    Both models are updated in place through their optimizers; nothing is returned.
    """
    # Only the forward pass and the losses are recorded by the tapes
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True) # Generated image
        # Discriminator tries to guess whether the images are real or generated
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)
        # Loss of each model
        gen_total_loss, _, _ = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
    # Gradients are computed and weights updated OUTSIDE the tape context:
    # inside it, the tapes would also record the backward pass and the
    # optimizer updates, which wastes memory and time for no benefit here.
    # Both gradients are computed before any weight is updated.
    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

In [None]:
# Plot input, target and prediction side by side
# to observe the generator performance or its evolution throughout training
def generate_images(model, test_input, tar):
    """Show input, ground truth and model prediction side by side.

    model: model applied to `test_input`.
    test_input: batch of input images; only its first element is shown.
    tar: batch of ground truth images; only its first element is shown.
    """
    # training=True so batchnorm uses the statistics of the test data instead
    # of the moving mean and variance from train data
    # Batch size must be 1 so prediction does not depend on other test examples
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15,15))
    panels = (('Input Image', test_input[0]),
              ('Ground Truth', tar[0]),
              ('Predicted Image', prediction[0]))
    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        # pixel values go from [-1, 1] to [0, 1] to be plotted
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

In [None]:
def fit(train_ds, epochs, test_ds):
    """Train both models for `epochs` passes over `train_ds`.

    train_ds: dataset of (input, target) training batches.
    epochs: number of passes over `train_ds`.
    test_ds: dataset of (input, target) batches; its first batch is plotted
        before every epoch to follow the evolution of the generator.
    A checkpoint is saved every 5 epochs and once more when training ends.
    """
    for epoch in range(epochs):
        # generator evolution throughout training
        for example_input, example_target in test_ds.take(1):
            generate_images(generator, example_input, example_target)
        print("Epoch: ", epoch)
        # Train: iterate through all training examples every single epoch
        for seen, (input_image, target) in enumerate(train_ds, start=1):
            # Progress: number of examples reached in the current epoch,
            # printed every 50 examples with a line break every 1000
            if seen % 50 == 0:
                print(seen,end=', ')
                if seen % 1000 == 0:
                    print()
            train_step(input_image, target)
        print()
        # saving (checkpoint) the model every 5 epochs; the last epoch is
        # skipped here because it is saved right after the loop
        is_last_epoch = epoch+1 >= epochs
        if (epoch+1) % 5 == 0 and not is_last_epoch:
            checkpoint.save(file_prefix = checkpoint_prefix)

    checkpoint.save(file_prefix = checkpoint_prefix) # Saving when finished

In [None]:
# 5 epochs of training
# The total number of epochs was 75
# NOTE(review): presumably reached by running this cell repeatedly, restoring
# the latest checkpoint between sessions -- confirm
fit(train_dataset, 5, test_dataset)

In [None]:
# Save models: SavedModel format
# Each model is written under the MODELS path, using a name without extension
generator.save(MODELS+'GAN-generator')
discriminator.save(MODELS+'GAN-discriminator')

In [None]:
# Save models: HDF5 format
# Each model is written under the MODELS path, as a single .h5 file
generator.save(MODELS+'GAN-generator.h5')
discriminator.save(MODELS+'GAN-discriminator.h5')

In [None]:
# Export generator directly to JavaScript
# this can also be done from the terminal with:
"""
tensorflowjs_converter --input_format keras path/to/GAN-generator.h5 path/to/TFJS_GAN-generator
"""
# NOTE(review): tensorflowjs is imported here rather than with the other imports,
# presumably so that it is only required for this optional export step -- confirm
import tensorflowjs as tfjs
# Only the generator is exported; the result is written under MODELS
tfjs.converters.save_keras_model(generator, MODELS+'TFJS_GAN-generator')