**IMPORTING LIBRARIES AND PACKAGES**

In [None]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import tensorflow as tf
import datetime
from IPython.display import clear_output

**LOADING DATASET**

* using 256x256 images from MSCOCO dataset
* X is array containing input images generated by resizing original images to 128x128 size
* Y is array containing original images of 256x256 size
* Only 1000 training images (plus 20 held-out test images) are used out of the ~80K available, for faster computation

In [None]:
# Build paired arrays from the first `data_size + test_size` images:
#   X -> network inputs, the images downscaled to 128x128
#   Y -> targets, the 256x256 images
# Both are scaled from [0, 255] to [-1, 1] to match the generator's tanh output.
X = []
Y = []

data_size = 1000  # number of training pairs
test_size = 20    # number of held-out pairs taken right after the training pairs
dirname = "../input/mscoco/mscoco_resized/train2014"

# os.listdir order is arbitrary; sort so the same subset is selected on every run.
filenames = sorted(os.listdir(dirname))[:data_size + test_size]
for filename in tqdm(filenames):
    # The context manager closes the file handle that Image.open keeps open.
    with Image.open(os.path.join(dirname, filename)) as src:
        # Some COCO images may be grayscale (or another mode); force 3 channels so
        # every array is (H, W, 3) and the np.array stacking below succeeds.
        target = src.convert('RGB')
    # Guard against any image that is not already 256x256, which would otherwise
    # break the stacking / reshape below.
    if target.size != (256, 256):
        target = target.resize((256, 256))
    Y.append(np.array(target))
    X.append(np.array(target.resize((128, 128))))

X = np.array(X, dtype = 'float32')
Y = np.array(Y, dtype = 'float32')

X = (X/127.5)-1
Y = (Y/127.5)-1

# The singleton axis makes each element a ready-to-feed batch of one sample.
X_train = X[0:data_size].reshape(-1,1,128,128,3)
Y_train = Y[0:data_size].reshape(-1,1,256,256,3)

X_test = X[data_size:].reshape(-1,1,128,128,3)
Y_test = Y[data_size:].reshape(-1,1,256,256,3)

print(X_train.shape)
print(Y_train.shape)

print(X_test.shape)
print(Y_test.shape)


**CREATING GENERATOR MODEL**

* A GAN uses two models: a generator model that produces the output, and a discriminator model that classifies whether an image was produced by the generator (fake) or taken from the dataset (real)

* The generator model is a U-Net, a neural network architecture commonly used for image-to-image tasks. It has three major components: downsampling blocks, upsampling blocks and skip connections.

* Downsampling blocks halve the spatial size of the feature maps at each level until they reach a 1x1x512 bottleneck. Upsampling blocks map the bottleneck back up to an image, and a final stride-4 transposed convolution makes the 256x256 output twice the size of the 128x128 input. Skip connections concatenate each upsampling block's output with the downsampling output of the same resolution.

* Generator and discriminator compete against each other.

In [None]:
# Downsampling block: Conv2D (stride 2) -> optional BatchNorm -> LeakyReLU

def downsample(filters, size, apply_batchnorm = True):
    """Return a Sequential block that halves the spatial size of its input.

    Args:
        filters: number of output channels of the convolution.
        size: square kernel size.
        apply_batchnorm: insert BatchNormalization after the convolution.
    """
    layers = [
        tf.keras.layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        )
    ]
    if apply_batchnorm:
        layers.append(tf.keras.layers.BatchNormalization())
    layers.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(layers)

# Upsampling block: Conv2DTranspose (stride 2) -> BatchNorm -> optional Dropout -> ReLU

def upsample(filters, size, apply_dropout = True):
    """Return a Sequential block that doubles the spatial size of its input.

    Args:
        filters: number of output channels of the transposed convolution.
        size: square kernel size.
        apply_dropout: insert Dropout(0.5) before the activation (on by default).
    """
    layers = [
        tf.keras.layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        layers.append(tf.keras.layers.Dropout(0.5))
    layers.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(layers)
    

In [None]:
def Generator():
    """Build the U-Net generator: (bs,128,128,3) -> (bs,256,256,3), values in [-1, 1].

    Seven stride-2 downsampling blocks shrink the input to a 1x1x512
    bottleneck. Six stride-2 upsampling blocks grow it back to 64x64, each
    followed by concatenation with the encoder activation of the same
    resolution (skip connection). A final stride-4 transposed convolution with
    tanh activation produces the 256x256 output (twice the input size).
    """
    inputs = tf.keras.layers.Input(shape = [128,128,3])

    # Encoder: each block halves height and width (stride 2).
    # Shapes are (bs, height, width, channels); bs = batch size.
    down_stack = [
        downsample(64, 4, apply_batchnorm = False), #(bs,64,64,64)
        downsample(128, 4), #(bs,32,32,128)
        downsample(256, 4), #(bs,16,16,256)
        downsample(512, 4), #(bs,8,8,512)
        downsample(512, 4), #(bs,4,4,512)
        downsample(512, 4), #(bs,2,2,512)
        downsample(512, 4), #(bs,1,1,512)
    ]

    # Decoder: each block doubles height and width. Shapes below are AFTER the
    # skip concatenation (upsample filters + skip channels).
    # upsample() defaults to apply_dropout=True, so dropout must be disabled
    # explicitly for the last three blocks; only the first three should use it.
    up_stack = [
        upsample(512, 4, apply_dropout = True), #(bs,2,2,1024)  = 512 + 512
        upsample(512, 4, apply_dropout = True), #(bs,4,4,1024)  = 512 + 512
        upsample(512, 4, apply_dropout = True), #(bs,8,8,1024)  = 512 + 512
        upsample(512, 4, apply_dropout = False), #(bs,16,16,768)  = 512 + 256
        upsample(256, 4, apply_dropout = False), #(bs,32,32,384)  = 256 + 128
        upsample(128, 4, apply_dropout = False), #(bs,64,64,192)  = 128 + 64
    ]

    initializer = tf.random_normal_initializer(0. , 0.02)
    # 64x64 -> 256x256 in a single stride-4 step; tanh keeps pixels in [-1, 1].
    last = tf.keras.layers.Conv2DTranspose(3,4,strides=4,padding='same',kernel_initializer = initializer,activation = 'tanh') #(bs,256,256,3)

    x = inputs

    # Run the encoder, keeping every activation for the skip connections.
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The 1x1 bottleneck has no mirror; pair the remaining activations in
    # reverse order with the decoder blocks (2x2 first, 64x64 last).
    skips = reversed(skips[:-1])

    for up,skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x,skip])

    x = last(x)

    return tf.keras.Model(inputs = inputs, outputs = x)

In [None]:
# Instantiate the U-Net and render its layer graph with tensor shapes.
generator = Generator()
tf.keras.utils.plot_model(model=generator, show_shapes=True)

**GENERATOR LOSS**

* Generator loss consists of two components
* L1 loss which is mean absolute error between the generated image and target image to make generated images structurally similar to target images
* GAN loss which is binary crossentropy loss of discriminator's output on generated images and array of ones.
* Total loss = GAN loss + (LAMBDA * L1 loss)

In [None]:
LAMBDA = 500  # weight of the L1 reconstruction term relative to the GAN term
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits = True)

def generator_loss(disc_generated_output, gen_output, target):
    """Total generator loss = GAN loss + LAMBDA * L1 loss.

    Args:
        disc_generated_output: discriminator logits for generated images.
        gen_output: images produced by the generator.
        target: ground-truth images.
    """
    # Adversarial term: the generator wants every logit scored as "real" (label 1).
    adversarial = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    # Reconstruction term: mean absolute pixel error against the ground truth.
    reconstruction = tf.reduce_mean(tf.abs(target - gen_output))
    return adversarial + (LAMBDA * reconstruction)

**CREATING DISCRIMINATOR MODEL**

* Discriminator model is a PatchGAN
* In a PatchGAN, the output is not a single real/fake score but a grid of logits (here 30x30x1); each value classifies one patch of the (input image, candidate image) pair as real or fake
* Model consists of downsampling blocks : Conv->BatchNorm->LeakyReLU
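* It receives two inputs: the 128x128 input image (bicubically resized to 256x256 inside the model) and a 256x256 candidate image. The pair (input image, generated image) should be classified as fake, and the pair (input image, target image) as real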


In [None]:
def Discriminator():
    """Build a PatchGAN discriminator conditioned on the low-resolution input.

    Inputs are 'input_image' (bs,128,128,3) and 'target_image' (bs,256,256,3).
    The input is bicubically resized to 256x256 and concatenated with the
    candidate image along channels. The output is a (bs,30,30,1) map of logits.
    """
    init = tf.random_normal_initializer(0., 0.02)

    low_res = tf.keras.layers.Input(shape=[128,128,3], name='input_image')
    candidate = tf.keras.layers.Input(shape=[256,256,3], name='target_image')

    upscaled = tf.image.resize(low_res, (256,256), method = 'bicubic')

    h = tf.keras.layers.concatenate([upscaled, candidate])  # (bs,256,256,6)
    # Three stride-2 blocks: 256 -> 128 -> 64 -> 32 (no BatchNorm on the first)
    for n_filters, use_bn in ((64, False), (128, True), (256, True)):
        h = downsample(n_filters, 4, use_bn)(h)
    # h is now (bs,32,32,256)

    h = tf.keras.layers.ZeroPadding2D()(h)  # (bs,34,34,256)
    h = tf.keras.layers.Conv2D(512, 4, strides=1, kernel_initializer=init, use_bias=False)(h)  # (bs,31,31,512)
    h = tf.keras.layers.BatchNormalization()(h)
    h = tf.keras.layers.LeakyReLU()(h)
    h = tf.keras.layers.ZeroPadding2D()(h)  # (bs,33,33,512)
    logits = tf.keras.layers.Conv2D(1, 4, strides=1, kernel_initializer=init)(h)  # (bs,30,30,1)

    return tf.keras.Model(inputs=[low_res, candidate], outputs=logits)

In [None]:
# Instantiate the PatchGAN and render its layer graph with tensor shapes.
discriminator = Discriminator()
tf.keras.utils.plot_model(model=discriminator, show_shapes=True)

**DISCRIMINATOR LOSS**

* It takes two inputs: the discriminator's output on real (input, target) pairs and its output on generated (input, generated) pairs. It has two components: real loss and generated loss
* real loss is sigmoid cross entropy loss of real image output and array of ones
* generated loss is sigmoid cross entropy loss of generated image output and array of zeros
* total loss is sum of real loss and generated loss

In [None]:
def discriminator_loss(disc_real_output, disc_gen_output):
    """Sum of real loss (real logits vs. ones) and generated loss (generated logits vs. zeros)."""
    return (
        loss_object(tf.ones_like(disc_real_output), disc_real_output)
        + loss_object(tf.zeros_like(disc_gen_output), disc_gen_output)
    )

**OPTIMIZERS AND CHECKPOINT SAVER**

In [None]:
# One Adam optimizer per network (learning rate 2e-4, beta_1 0.5).
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

# Checkpoints are written as ./training_checkpoints/ckpt-<n>.
checkpoint_dir = "./training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)


**FUNCTION TO GENERATE IMAGES**

In [None]:
def _to_uint8(img):
    """Round and clip a float image in pixel units to a displayable uint8 array."""
    return np.clip(np.rint(np.asarray(img)), 0, 255).astype('uint8')

def generate_images(model, inp, tar):
    """Show the input, the model's prediction and the target side by side.

    Args:
        model: generator; it is called on `inp` rescaled to [-1, 1].
        inp: input batch in pixel units [0, 255], e.g. (1,128,128,3).
        tar: ground-truth batch in pixel units [0, 255], e.g. (1,256,256,3).

    Only the first element of each batch is displayed.
    """
    inp_normalized = (inp/127.5)-1
    # training=True runs BatchNorm/Dropout in the same mode as train_step.
    pred = model(inp_normalized, training = True)
    # Map the tanh output from [-1, 1] back to pixel units [0, 255].
    pred = (np.asarray(pred)+1)*127.5

    # Previously the prediction was overwritten with the resized target here, so
    # the "prediction" panel showed ground truth; display the real output instead.
    display_list = [_to_uint8(inp[0]), _to_uint8(pred[0]), _to_uint8(tar[0])]
    title_list = ['input','prediction','target']

    fig, axes = plt.subplots(1, 3, figsize = (20,20))
    for ax, title, img in zip(axes, title_list, display_list):
        ax.set_title(title)
        ax.imshow(img)
        ax.axis('off')

    plt.show()




**TRAINING**

In [None]:
EPOCHS = 10
@tf.function
def train_step(inp, tar, epoch=None):
    """Run one adversarial update of the generator and the discriminator.

    Args:
        inp: input batch; fit() feeds (1,128,128,3) arrays in [-1, 1].
        tar: matching target batch; fit() feeds (1,256,256,3) arrays in [-1, 1].
        epoch: unused; kept optional so existing three-argument calls still work.
            Avoid passing a changing Python int: tf.function builds a new graph
            for every distinct Python scalar value (one retrace per epoch).
    """
    # Separate tapes because each network is optimized against its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(inp, training = True)

        disc_real_output = discriminator([inp,tar], training = True)
        disc_gen_output = discriminator([inp,gen_output], training = True)

        gen_loss = generator_loss(disc_gen_output, gen_output, tar)
        disc_loss = discriminator_loss(disc_real_output, disc_gen_output)

    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))

def fit(X_train, Y_train, X_test, Y_test, epochs, checkpoint_every=10):
    """Train the GAN one sample at a time and checkpoint the models.

    Args:
        X_train, Y_train: paired inputs/targets; each element is one batch.
        X_test, Y_test: not used during training; accepted for call compatibility.
        epochs: number of passes over the training pairs.
        checkpoint_every: save a checkpoint after every this many epochs
            (0 or None disables periodic saves). A final checkpoint is always
            written unless the last epoch was just saved.
    """
    last_saved_epoch = None
    n_steps = min(len(X_train), len(Y_train))
    for epoch in range(epochs):
        clear_output(wait = True)
        print("Epoch : ", epoch)

        # total= lets tqdm draw a real progress bar (zip has no len()).
        for inp, tar in tqdm(zip(X_train, Y_train), total=n_steps):
            # epoch is deliberately not passed: a changing Python int would
            # force tf.function to retrace train_step every epoch.
            train_step(inp, tar)

        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
            last_saved_epoch = epoch

    # Final save, skipped only when the periodic save already captured the last epoch.
    if last_saved_epoch != epochs - 1:
        checkpoint.save(file_prefix=checkpoint_prefix)

fit(X_train, Y_train, X_test, Y_test, EPOCHS)

In [None]:
# Preview the second held-out pair, mapped from [-1, 1] back to pixel units.
generate_images(
    model=generator,
    inp=(X_test[1] + 1) * 127.5,
    tar=(Y_test[1] + 1) * 127.5,
)