# Last frame generation using CGAN, based on pix2pix

In this project I train a conditional GAN that generates the last frame of a clip from its first frame, using clips from the flying shapes data set.  
The constants in the configuration block (below the imports) can be changed to alter the behavior of the GAN.  
The data set is a derivative of the flying shapes data set that adds colour and one more shape.  


Refs:  
Flying shapes - https://arxiv.org/abs/1807.00703  
pix2pix - https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py  

In [None]:
# utils
import glob
import sys
import os
import time
import random
import datetime
import functools

# for image display
import IPython.display as display
import PIL.Image

# ML, vector util
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator


### Configuration block

In [None]:
# constant variables
# dataset locations; each directory is globbed for '*.png' frames by the data generators
training_data_dir = "/FlyingObjectDataset_10K/training/image/"
validation_data_dir = "/FlyingObjectDataset_10K/validation/image/"
testing_data_dir = "/FlyingObjectDataset_10K/testing/image/"   
GPU = 0 # GPU ID exposed through CUDA_VISIBLE_DEVICES; -1 selects the CPU
DROPOUT_PROB = 0.5  # intended dropout probability - NOTE(review): confirm it is wired into the dropout layers
IMAGE_WIDTH = 128  # width every loaded frame is resized to
IMAGE_HEIGHT = 128  # height every loaded frame is resized to
IMAGE_CHANNEL = 3  # colour channels per frame
EPOCHS = 12  # number of training epochs handed to P2pGAN
BATCH_SIZE = 5  # (input, target) pairs per batch
SEQUENCE_LENGTH = 10  # presumably frames per clip; not referenced in the visible code - TODO confirm
LEARNING_RATE_GEN = 1e-4  # Adam learning rate of the generator
LEARNING_RATE_DISC = 1e-4  # Adam learning rate of the discriminator
DATA_AUGMENTATION = True  # when True, the data generators apply randomJitter to every pair
DATASET_LENGTH = 300  # batches drawn per training epoch
VAL_DATASET_LENGTH = 50  # batches drawn per validation pass


In [None]:

# GPU configuration
# Select the device through CUDA_VISIBLE_DEVICES: GPU >= 0 exposes that single GPU,
# GPU == -1 hides all GPUs so TensorFlow falls back to the CPU.
if GPU >=0:
    print("creating network model using gpu " + str(GPU))
    os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU)
elif GPU >=-1:
    print("creating network model using cpu ")  
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Cap the memory TensorFlow may allocate on the first visible GPU (in MB).
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)])
    except RuntimeError as err:
        # virtual devices can only be configured before the GPU is initialised,
        # e.g. this cell was re-run in a live kernel; keep the existing configuration
        print(err)
else:
    # no GPU visible (CPU mode or no CUDA device): nothing to limit
    print("no visible GPU, skipping GPU memory limit")


### Loading data set

In [None]:
def loadImage(image_dir):
    """
        Load a PNG image from disk and resize it to the configured size.

        Pixel values are NOT rescaled here; mapping to [-1, 1] is done
        separately by normalize().

        Args:
        image_dir: path of the PNG file to read.
        Returns:
        float32 tensor of shape (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL).
    """

    raw_bytes = tf.io.read_file(image_dir)
    # force the configured channel count so grayscale or RGBA files still match
    # the fixed channel dimension the cropping and the networks expect
    temp_image = tf.image.decode_png(raw_bytes, channels=IMAGE_CHANNEL)

    # change to specified size
    temp_image = tf.image.resize(temp_image, [IMAGE_HEIGHT, IMAGE_WIDTH])
    temp_image = tf.cast(temp_image, tf.float32)

    return temp_image

def normalize(temp_image):
    """
        Linearly rescale pixel values so that 0 maps to -1 and 255 maps to 1.
    """
    return temp_image / 127.5 - 1

def showTensorImage(tensor_img):
    """
        Show an image tensor inline in the notebook.

        Args:
        tensor_img: tensor with a leading axis of size 1 and values
                    scaled as produced by normalize().
    """

    # drop the leading axis, undo the normalization and hand the pixels to PIL
    single_image = tf.squeeze(tensor_img, 0)
    pil_image = PIL.Image.fromarray(tensorToNumpy(single_image))
    display.display(pil_image)
    
def tensorToNumpy(tensor):
    """
        Undo normalize(): map -1 to 0 and 1 to 255, cast to uint8 and
        return the result as a numpy array.
    """
    rescaled = tf.cast((tensor + 1) * 127.5, tf.uint8)
    return np.array(rescaled)
    
def randomJitter(first, last):
    """Random jittering, applied identically to both frames of a pair.
        Enlarges both frames by 20 pixels per dimension (148 x 148 for the
        default 128 x 128 size), takes the same random
        IMAGE_HEIGHT x IMAGE_WIDTH crop from both and, with probability 0.5,
        mirrors both horizontally.
        Args:
        first: first frame, (height, width, IMAGE_CHANNEL)
        last: last frame, same shape as first
        Returns:
        jittered first frame, jittered last frame
    """
    # derive the enlarged size from the configured size so the crop below always fits
    jitter_size = [IMAGE_HEIGHT + 20, IMAGE_WIDTH + 20]
    input_image1 = tf.image.resize(first, jitter_size,
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    input_image2 = tf.image.resize(last, jitter_size,
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # stack the pair so one random crop window is shared by both frames
    stacked_image = tf.stack([input_image1, input_image2], axis=0)

    cropped_image = tf.image.random_crop(
            stacked_image, size=[2, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])

    input_image1, input_image2 = cropped_image[0], cropped_image[1]

    if tf.random.uniform(()) > 0.5:
        # random mirroring, both frames together so the motion stays consistent
        input_image1 = tf.image.flip_left_right(input_image1)
        input_image2 = tf.image.flip_left_right(input_image2)

    return input_image1, input_image2

In [None]:
def data_eren_generator(data_dir):
    """Endlessly yield (frame, last frame) image pairs for every clip in data_dir.

        Every frame of a clip except the last one is paired with that clip's
        last frame, so a clip of N frames contributes N-1 pairs.

        File names are expected to end in '_<frame number>.png'; a frame number
        of 1 marks the start of a new clip. Clips with a single frame are skipped.

        NOTE(review): the previous implementation never paired the first frame
        of a clip and paired the last frame with itself; the bookkeeping
        (append followed by dropping the last element) suggests that was
        unintended - confirm.

        Args:
        data_dir: directory containing the PNG frames.
        Yields:
        (frame, last frame) tensors, augmented if DATA_AUGMENTATION is set and
        scaled by normalize(). Pairs are reshuffled on every pass.
        Raises:
        ValueError: if data_dir holds no clip with at least two frames.
    """

    # parse '<clip>_<frame number>.png'; sorting on (clip, numeric frame) keeps the
    # frames of a clip in order even when frame numbers are not zero-padded
    frames = []
    for path in glob.glob(data_dir + '/*.png'):
        stem = os.path.basename(path)[:-len('.png')]
        clip_name, _, frame_no = stem.rpartition('_')
        frames.append((clip_name, int(frame_no), path))
    frames.sort()

    # split the ordered frame list into clips, a new clip starts at frame number 1
    clips = []
    for _, frame_no, path in frames:
        if frame_no == 1 or not clips:
            clips.append([])
        clips[-1].append(path)

    pair_dirs = [(frame, clip[-1]) for clip in clips for frame in clip[:-1]]
    if not pair_dirs:
        # without this check the loop below would spin forever without yielding
        raise ValueError('no clip with at least two frames found in ' + str(data_dir))

    while True:
        random.shuffle(pair_dirs)

        # iterate over image dirs and load them using tf.io
        for first_frame, last_frame in pair_dirs:

            f1, f2 = loadImage(first_frame), loadImage(last_frame)
            if DATA_AUGMENTATION:
                f1, f2 = randomJitter(f1, f2)
            f1, f2 = normalize(f1), normalize(f2)

            yield f1, f2

In [None]:
def data_task_generator(data_dir):
    """Endlessly yield (first frame, last frame) image pairs, one pair per clip.

        File names are expected to end in '_<frame number>.png'; a frame number
        of 1 marks the start of a new clip. Clips with a single frame have no
        distinct last frame and are skipped. The final clip in the directory is
        included as well.

        Args:
        data_dir: directory containing the PNG frames.
        Yields:
        (first frame, last frame) tensors, augmented if DATA_AUGMENTATION is set
        and scaled by normalize(). Pairs are reshuffled on every pass.
        Raises:
        ValueError: if data_dir holds no clip with at least two frames.
    """

    # parse '<clip>_<frame number>.png'; sorting on (clip, numeric frame) keeps the
    # frames of a clip in order even when frame numbers are not zero-padded
    frames = []
    for path in glob.glob(data_dir + '/*.png'):
        stem = os.path.basename(path)[:-len('.png')]
        clip_name, _, frame_no = stem.rpartition('_')
        frames.append((clip_name, int(frame_no), path))
    frames.sort()

    # split the ordered frame list into clips, a new clip starts at frame number 1
    clips = []
    for _, frame_no, path in frames:
        if frame_no == 1 or not clips:
            clips.append([])
        clips[-1].append(path)

    pair_dirs = [(clip[0], clip[-1]) for clip in clips if len(clip) > 1]
    if not pair_dirs:
        # without this check the loop below would spin forever without yielding
        raise ValueError('no clip with at least two frames found in ' + str(data_dir))

    while True:
        random.shuffle(pair_dirs)

        # iterate over image dirs and load them using tf.io
        for first_frame, last_frame in pair_dirs:

            f1, f2 = loadImage(first_frame), loadImage(last_frame)
            if DATA_AUGMENTATION:
                f1, f2 = randomJitter(f1, f2)
            f1, f2 = normalize(f1), normalize(f2)

            yield f1, f2

## GAN utility

In [None]:
class InstanceNormalization(tf.keras.layers.Layer):
    """Instance Normalization Layer (https://arxiv.org/abs/1607.08022).

    Normalizes over axes 1 and 2 of the input and applies a learned scale and
    offset whose shape is the size of the last input axis.
    """

    def __init__(self, epsilon=1e-5):
        super().__init__()
        # added to the variance before the reciprocal square root
        self.epsilon = epsilon

    def build(self, input_shape):
        channel_shape = input_shape[-1:]

        # scale starts around 1, offset at 0
        self.scale = self.add_weight(
            name='scale',
            shape=channel_shape,
            initializer=tf.random_normal_initializer(1., 0.02),
            trainable=True)
        self.offset = self.add_weight(
            name='offset',
            shape=channel_shape,
            initializer='zeros',
            trainable=True)

    def call(self, x):
        # statistics over axes 1 and 2, kept broadcastable against x
        mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        normalized = (x - mean) * tf.math.rsqrt(variance + self.epsilon)
        return self.scale * normalized + self.offset


def downsample(filters, size, norm_type='batchnorm', apply_norm=True):
    """Build a stride-2 downsampling block.
        Conv2D => (optional) normalization => LeakyReLU(0.2)
        Args:
        filters: number of filters
        size: filter size
        norm_type: Normalization type; either 'batchnorm' or 'instancenorm'
                   (case-insensitive). Any other value adds no normalization.
        apply_norm: If True, adds the normalization layer
        Returns:
        Downsample Sequential Model
    """
    block = tf.keras.Sequential()
    block.add(tf.keras.layers.Conv2D(
        filters, size, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02)))

    # norm_type is only inspected when normalization is requested
    norm_key = norm_type.lower() if apply_norm else None
    if norm_key == 'batchnorm':
        block.add(tf.keras.layers.BatchNormalization())
    elif norm_key == 'instancenorm':
        block.add(InstanceNormalization())

    block.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    return block

def upsample(filters, size, norm_type='batchnorm', apply_dropout=False):
    """Upsamples an input.
        Conv2DTranspose => normalization => (optional) Dropout => LeakyReLU(0.2)
        Args:
        filters: number of filters
        size: filter size
        norm_type: Normalization type; either 'batchnorm' or 'instancenorm'
                   (case-insensitive). Any other value adds no normalization.
        apply_dropout: If True, adds a dropout layer with rate DROPOUT_PROB
        Returns:
        Upsample Sequential Model
    """

    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                        padding='same',
                                        kernel_initializer=initializer))

    norm_key = norm_type.lower()
    if norm_key == 'batchnorm':
        result.add(tf.keras.layers.BatchNormalization())
    elif norm_key == 'instancenorm':
        result.add(InstanceNormalization())

    if apply_dropout:
        # use the configured rate instead of a second hard-coded 0.5
        result.add(tf.keras.layers.Dropout(DROPOUT_PROB))

    result.add(tf.keras.layers.LeakyReLU(alpha=0.2))

    return result

In [None]:
def unet_generator(output_channels, input_shape, norm_type='batchnorm'):
    """Modified u-net generator model (https://arxiv.org/abs/1611.07004).
        Four stride-2 downsampling blocks, three stride-2 upsampling blocks with
        skip connections, and a final stride-2 transposed convolution with tanh,
        so the output has the spatial size of the input.
        Args:
        output_channels: Output channels
        input_shape: shape of one input image, without the batch axis
        norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'.
        Returns:
        Generator model
    """

    # shape comments assume a 128 x 128 input
    down_stack = [
        downsample(128, 4, norm_type, apply_norm=False),  # (bs, 64, 64, 128)
        downsample(256, 4, norm_type),  # (bs, 32, 32, 256)
        downsample(256, 4, norm_type),  # (bs, 16, 16, 256)
        downsample(512, 4, norm_type),  # (bs, 8, 8, 512)
    ]

    # one block per skip connection; the last doubling of resolution is done by `last`
    up_stack = [
        upsample(512, 4, norm_type, apply_dropout=True),  # (bs, 16, 16, 512)
        upsample(256, 4, norm_type, apply_dropout=True),  # (bs, 32, 32, 256)
        upsample(256, 4, norm_type),  # (bs, 64, 64, 256)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = tf.keras.layers.Conv2DTranspose(
            output_channels, 4, strides=2,
            padding='same', kernel_initializer=initializer,
            activation='tanh')  # (bs, 128, 128, output_channels)

    concat = tf.keras.layers.Concatenate()

    inputs = tf.keras.layers.Input(shape=input_shape)
    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # the bottleneck output is not a skip; the rest are used deepest first
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = concat([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
def discriminator(input_shape):
    """PatchGan discriminator model (https://arxiv.org/abs/1611.07004).
        Three stride-2 4x4 convolutions with ReLU followed by a stride-2
        single-filter 4x4 convolution with a sigmoid activation.
        Args:
        input_shape: shape of the input image.
        Returns:
        Discriminator model
    """

    inp = tf.keras.layers.Input(shape=input_shape, name='input_image')

    # down-stacks: 128, 256 and 512 filters, each halving the resolution
    features = inp
    for filters in (128, 256, 512):
        features = tf.keras.layers.Conv2D(
            filters, (4, 4), strides=2, padding='same')(features)
        features = tf.keras.layers.Activation('relu')(features)

    # one sigmoid score per patch
    patch_scores = tf.keras.layers.Conv2D(
        1, (4, 4), strides=2, padding='same', activation='sigmoid')(features)

    return tf.keras.Model(inputs=inp, outputs=patch_scores)

def c_discriminator(input_shape):
    """PatchGan discriminator model (https://arxiv.org/abs/1611.07004).
        with conditioned input: the input image and the image to judge are
        concatenated along the channel axis before the convolutions.
        Args:
        input_shape: shape of the input image (the target has the same shape).
        Returns:
        Discriminator model taking [input image, target image]
    """

    inp = tf.keras.layers.Input(shape=input_shape, name='input_image')
    tar = tf.keras.layers.Input(shape=input_shape, name='target_image')

    # (bs, height, width, channels*2)
    features = tf.keras.layers.Concatenate(axis=-1)([inp, tar])

    # down-stacks: 128, 256 and 512 filters, each halving the resolution
    for filters in (128, 256, 512):
        features = tf.keras.layers.Conv2D(
            filters, (4, 4), strides=2, padding='same')(features)
        features = tf.keras.layers.Activation('relu')(features)

    # one sigmoid score per patch
    patch_scores = tf.keras.layers.Conv2D(
        1, (4, 4), strides=2, padding='same', activation='sigmoid')(features)

    return tf.keras.Model(inputs=[inp, tar], outputs=patch_scores)

## GAN class implementation

In [None]:
class P2pGAN(object):
    """
        Class implementation of a pix2pix GAN with specified optimizers, loss-functions.

        The generator is a u-net (unet_generator), the discriminator a
        conditioned PatchGAN (c_discriminator). The generator is trained on the
        adversarial binary cross-entropy plus a mean-squared-error term
        weighted by lambda_value.
    """

    def __init__(self, epochs, input_shape):
        """
        Args:
          epochs: number of epochs train() runs.
          input_shape: shape of one image, without the batch axis.
        """

        self.epochs = epochs
        self.lambda_value = 100
        # c_discriminator ends in a sigmoid, so the loss receives probabilities,
        # not logits
        self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=False)
        self.generator_optimizer = tf.keras.optimizers.Adam(LEARNING_RATE_GEN)
        self.discriminator_optimizer = tf.keras.optimizers.Adam(LEARNING_RATE_DISC)
        self.generator = unet_generator(3, input_shape)
        self.discriminator = c_discriminator(input_shape)

        # track the networks as well, otherwise a checkpoint only holds optimizer state
        self.checkpoint = tf.train.Checkpoint(
            generator_optimizer=self.generator_optimizer,
            discriminator_optimizer=self.discriminator_optimizer,
            generator=self.generator,
            discriminator=self.discriminator)

    @staticmethod
    def _next_batch(data_gen):
        """Draw BATCH_SIZE (input, target) pairs from data_gen and stack them into two batch tensors."""
        pairs = [next(data_gen) for _ in range(BATCH_SIZE)]
        input_images = tf.convert_to_tensor([pair[0] for pair in pairs])
        target_images = tf.convert_to_tensor([pair[1] for pair in pairs])
        return input_images, target_images

    def train(self, train_gen, val_gen, checkpoint_pr):
        """Train the GAN for self.epochs epochs.
        Each epoch draws DATASET_LENGTH batches from train_gen, then runs one
        validation pass and prints the mean losses.
        Args:
          train_gen: generator yielding (input, target) training pairs.
          val_gen: generator yielding (input, target) validation pairs.
          checkpoint_pr: prefix in which the checkpoints are stored.
        Returns:
          List with one ((gen bce, gen mse), (disc real, disc fake)) tuple of
          mean training losses per epoch.
        """

        loss_list = []

        for epoch in range(1, self.epochs+1):

            # total loss variables
            t_disc_real_loss = 0
            t_disc_fake_loss = 0
            t_gen_bce_loss   = 0
            t_gen_mse_loss   = 0

            for _ in range(DATASET_LENGTH):

                input_images, target_images = self._next_batch(train_gen)

                # update both networks on every batch
                disc_loss, gen_loss = self.train_step(input_images,
                                                      target_images,
                                                      True,
                                                      True)

                # update totals
                t_disc_real_loss += disc_loss[0]
                t_disc_fake_loss += disc_loss[1]
                t_gen_bce_loss   += gen_loss[0]
                t_gen_mse_loss   += gen_loss[1]

            # update loss totals, avg
            tgl = (t_gen_bce_loss/DATASET_LENGTH, t_gen_mse_loss/DATASET_LENGTH)
            tdl = (t_disc_real_loss/DATASET_LENGTH, t_disc_fake_loss/DATASET_LENGTH)
            loss_list.append((tgl, tdl))

            # print epoch status
            val_str = self.validate(val_gen)
            template = 'Epoch {}, Generator loss {}, Discriminator Loss {}'
            print(template.format(epoch, tgl[0]+tgl[1], tdl[0]+tdl[1]) + val_str)

            # saving (checkpoint) the model every 20 epochs; epoch is 1-based here
            if epoch % 20 == 0:
                self.checkpoint.save(file_prefix=checkpoint_pr)

        return loss_list

    def validate(self, data_gen):
        """One validation pass over VAL_DATASET_LENGTH batches.
           This function does not update any weights.
        Args:
          data_gen: data generator, validation set.
        Returns:
          Formatted string with the mean generator and discriminator loss.
        """

        t_disc_loss, t_gen_loss = 0, 0
        for _ in range(VAL_DATASET_LENGTH):
            input_images, target_images = self._next_batch(data_gen)

            # run w/o updating gradients
            disc_loss, gen_loss = self.train_step(input_images,
                                                  target_images,
                                                  False,
                                                  False)
            t_disc_loss += disc_loss[0] + disc_loss[1]
            t_gen_loss  += gen_loss[0] + gen_loss[1]
        # return validation string
        template = '\nValidation, Generator loss {}, Discriminator Loss {}\n'
        return template.format(t_gen_loss/VAL_DATASET_LENGTH, t_disc_loss/VAL_DATASET_LENGTH)

    def train_step(self, input_image, target_image, update_gen=True, update_disc=True):
        """One step over the generator and discriminator model.
        Args:
          input_image: Input Image, batch
          target_image: Target image, batch
          update_gen: Boolean, update the generator weights. Also selects
                      training mode (batch statistics, dropout) for the generator.
          update_disc: Boolean, update the discriminator weights.
        Returns:
          (disc real loss, disc fake loss), (gen bce loss, weighted gen mse loss)
        """

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:

            # Generate a fake last frame; without an explicit training flag Keras
            # runs BatchNormalization and Dropout in inference mode
            gen_output = self.generator(input_image, training=update_gen)

            # Discriminator infer fake and real frame
            disc_real_output = self.discriminator(
                [input_image, target_image], training=update_disc)
            disc_gen_output = self.discriminator(
                [input_image, gen_output], training=update_disc)

            # Get loss
            d_real_loss, d_gen_loss = self.discriminator_loss_bce(
                disc_real_output, disc_gen_output)
            gan_bce_loss, gan_mse_loss = self.generator_loss(
                disc_gen_output, gen_output, target_image)

        # gradients are taken outside the tape context so the backward pass is not recorded
        if update_gen:
            generator_gradients = gen_tape.gradient(
                gan_bce_loss + gan_mse_loss,
                self.generator.trainable_variables)
        if update_disc:
            discriminator_gradients = disc_tape.gradient(
                d_real_loss + d_gen_loss,
                self.discriminator.trainable_variables)

        if update_gen:
            self.generator_optimizer.apply_gradients(zip(
                generator_gradients, self.generator.trainable_variables))
        if update_disc:
            self.discriminator_optimizer.apply_gradients(zip(
                discriminator_gradients, self.discriminator.trainable_variables))

        return (d_real_loss, d_gen_loss), (gan_bce_loss, gan_mse_loss)

    def discriminator_loss(self, disc_real_output, disc_gen_output):
        """Calculates a mean-score discriminator loss on both fake and real samples.
        Not used by train_step, which calls discriminator_loss_bce.
        Args:
          disc_real_output: Discriminator output, real image
          disc_gen_output: Discriminator output, fake image
        Returns:
          (negated mean score on real samples, mean score on fake samples)
        """

        # calc loss on real/fake inputs to discriminator
        real_loss      = -tf.reduce_mean(disc_real_output)
        generated_loss = tf.reduce_mean(disc_gen_output)

        return real_loss, generated_loss

    def generator_loss(self, disc_gen_output, gen_output, target):
        """Calculates generator loss, with the chosen loss function and
           MeanSquareError on faked image and target image.
        Args:
          disc_gen_output: Discriminator output, fake image
          gen_output: faked image
          target: Target
        Returns:
          (adversarial loss, mean squared error scaled by lambda_value)
        """
        # calc loss from discriminator: the generator wants fakes judged as real
        gan_loss = self.loss_object(tf.ones_like(
            disc_gen_output), disc_gen_output)

        # mean square error
        mse_loss = tf.reduce_mean(tf.pow(target - gen_output, 2))

        return gan_loss, mse_loss*self.lambda_value

    def discriminator_loss_bce(self, disc_real_output, disc_gen_output):
        """Calculates discriminator loss, with chosen loss function(BinaryCrossentropy)
        Args:
          disc_real_output: Discriminator output, real image
          disc_gen_output: Discriminator output, fake image
        Returns:
          (loss on real samples, loss on fake samples)
        """
        # calc loss from discriminator, real images should be scored 1
        real_loss = self.loss_object(
            tf.ones_like(disc_real_output), disc_real_output)

        # calc loss from discriminator, fake images should be scored 0
        generated_loss = self.loss_object(tf.zeros_like(
            disc_gen_output), disc_gen_output)

        return real_loss, generated_loss

In [None]:
# The following blocks are for testing
# One timestamped folder per run; it receives the checkpoints and the saved output images.
current_log_folder = './log/ckpt-{}/'.format(str(datetime.datetime.now()))
# makedirs also creates the './log' parent when it does not exist yet
os.makedirs(current_log_folder, exist_ok=True)
image_count = 0
# build the GAN for the configured image size instead of a second hard-coded shape
p2p = P2pGAN(EPOCHS, (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL))
t1 = []

In [None]:
# Train on the training set, validating on the validation set after every epoch;
# the returned per-epoch losses are plotted further down
gan_losses = p2p.train(
    train_gen=data_task_generator(training_data_dir),
    val_gen=data_task_generator(validation_data_dir),
    checkpoint_pr=current_log_folder,
)

In [None]:
# network used for the predictions below: the generator trained above
# (uncomment the next line to use a saved one instead)
generator = p2p.generator
# generator = tf.keras.models.load_model('./log/keras-gen')
# stream of testing samples; swap in the other data generator here if wanted
dg = data_eren_generator(testing_data_dir)

In [None]:
# draw the next testing sample and give both frames a leading batch axis of size 1
f1, f2 = (tf.expand_dims(frame, 0) for frame in next(dg))
# predict the last frame from the input frame
img = generator(f1)

In [None]:
# round-up predicted colors: values above 0.5 become 1.0, everything else 0.0
# (vectorized, same result as visiting every element in a Python loop)
# NOTE(review): the prediction comes from a tanh, so in this scale 0.0 is mid-gray
# rather than black, and img_tens is not used by the display code below - confirm intent.
img_np = img.numpy()
img_np = (img_np > 0.5).astype(img_np.dtype)
img_tens = tf.convert_to_tensor(img_np, tf.float32)

# display input/target/pred side by side and save the strip to the log folder
img_both = np.hstack([tensorToNumpy(tf.squeeze(frame, 0)) for frame in (f1, f2, img)])
result = PIL.Image.fromarray(img_both)
image_count += 1
result.save(current_log_folder +'out{}.bmp'.format(image_count))
display.display(result)

In [None]:
import matplotlib.pyplot as plt
# epoch list
x = list(range(1, len(gan_losses) + 1))
# loss lists; every gan_losses entry is ((G bce, G mse), (D real, D fake))
y1 = [float(t[0][0]) for t in gan_losses]
y2 = [float(t[0][1]) for t in gan_losses]
y3 = [float(t[1][0]) for t in gan_losses]
y4 = [float(t[1][1]) for t in gan_losses]
fig, axs = plt.subplots(2,2, figsize=(10,4))
# one panel per loss, filled row by row
panels = zip(axs.flat,
             (y1, y2, y3, y4),
             ('G bce Loss', 'G mse Loss', 'D real Loss', 'D fake Loss'))
for ax, losses, label in panels:
    ax.plot(x, losses)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(label)

fig.tight_layout()
# save before show(): once the figure has been shown, savefig can write an empty image
fig.savefig('colapse_to_white.png')
plt.show()

In [None]:
# save
#tf.keras.models.save_model(p2p.discriminator, './log/keras-disc/')
#tf.keras.models.save_model(p2p.generator, './log/keras-gen/')