# **ESRGAN implementation**
- Final implementation

**Importing necessary modules**

In [None]:
import tensorflow as tf
from matplotlib import pyplot as plt
import os
import time
import cv2
import numpy as np
from tensorflow.keras.layers import Add,Concatenate,LeakyReLU,Conv2D,Lambda,UpSampling2D
from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.applications.vgg19 import preprocess_input

**Data preprocessing**

In [None]:
# Directory that holds the high-resolution training images.
PATH = r"C:\Users\kbr91\Documents\archive\data"
# Epoch count; this name is reassigned further down before training starts.
EPOCHS = 5

def random_crop(input_image, crop_size=96):
    """Return a random square crop of ``input_image``.

    Args:
        input_image: Array-like image indexed as ``[height, width, ...]`` that
            exposes a ``shape`` attribute.
        crop_size: Side length of the crop in pixels (defaults to 96).

    Returns:
        The ``crop_size`` x ``crop_size`` window of ``input_image`` at a
        uniformly chosen offset; trailing dimensions are kept unchanged.

    Raises:
        ValueError: If the image is smaller than ``crop_size`` along its
            height or width.
    """
    height, width = input_image.shape[0], input_image.shape[1]
    if height < crop_size or width < crop_size:
        raise ValueError(
            "image of size {}x{} is smaller than the {}x{} crop".format(
                height, width, crop_size, crop_size))

    # randint's upper bound is exclusive: the +1 keeps the last valid offset
    # reachable and lets an image of exactly crop_size pixels through.
    start_height = np.random.randint(0, height - crop_size + 1)
    start_width = np.random.randint(0, width - crop_size + 1)

    return input_image[start_height:start_height + crop_size,
                       start_width:start_width + crop_size]

def load_hr(image_file):
    """Load a PNG file and return a random high-resolution crop of it.

    Args:
        image_file: Path of the PNG image to read.

    Returns:
        A NumPy array holding the crop produced by ``random_crop``.
    """
    raw_bytes = tf.io.read_file(image_file)
    # Force three channels so grayscale or RGBA files still yield the RGB
    # layout that the 3-channel generator, discriminator and VGG inputs need.
    decoded = tf.image.decode_png(raw_bytes, channels=3)
    pixels = np.asarray(decoded)

    return random_crop(pixels)

def load_lr(hr_image):
    """Derive the low-resolution counterpart of a high-resolution crop.

    The image is smoothed with a 3x3 ``cv2.blur`` kernel and then resized
    to 24x24 pixels with ``cv2.resize``.
    """
    blurred = cv2.blur(hr_image, (3, 3))
    return cv2.resize(blurred, (24, 24))

def normalize(image):
    """Convert ``image`` to a float32 tensor rescaled as ``x / 127.5 - 1``."""
    as_float = tf.convert_to_tensor(image, dtype=tf.float32)
    return as_float / 127.5 - 1


# Path of every entry found in the dataset directory.
model_training_dataset = [PATH + '/' + file_name for file_name in os.listdir(PATH)]

# One random HR crop per file, then the LR image derived from that same crop.
model_training_hr_dataset = [load_hr(image_path) for image_path in model_training_dataset]
model_training_lr_dataset = [load_lr(hr_crop) for hr_crop in model_training_hr_dataset]

# Normalise every image and stack each set into a single tensor.
model_training_hr_dataset = tf.convert_to_tensor(
    [normalize(hr_crop) for hr_crop in model_training_hr_dataset])
model_training_lr_dataset = tf.convert_to_tensor(
    [normalize(lr_image) for lr_image in model_training_lr_dataset])

In [None]:
# Sanity check: display the first five HR crops, each followed by its LR version.
# Adding 0.5 to half the value maps the normalised pixels back into [0, 1] for imshow.
for sample_index in range(5):
    for dataset in (model_training_hr_dataset, model_training_lr_dataset):
        plt.imshow(dataset[sample_index] / 2 + 0.5)
        plt.show()

**Building the Generator**

In [None]:
def dense_block(inpt):
    """Apply a small densely connected convolution block with a scaled residual.

    Two 3x3, 64-filter convolutions (each followed by LeakyReLU(0.2)) are
    chained through concatenations; the output of a final 3x3, 64-filter
    convolution is multiplied by 0.2 and added back to ``inpt``.

    NOTE(review): the second concatenation receives ``inpt`` both directly and
    inside the first concatenation's output, so the input channels appear
    twice -- confirm this is intended before changing it, since the kernel
    shapes (and saved checkpoints) depend on it.
    """
    first = Conv2D(64, kernel_size=3, strides=1, padding='same')(inpt)
    first = LeakyReLU(0.2)(first)
    first = Concatenate()([inpt, first])

    second = Conv2D(64, kernel_size=3, strides=1, padding='same')(first)
    second = LeakyReLU(0.2)(second)
    second = Concatenate()([inpt, first, second])

    # Residual scaling: only a fifth of the block's output is added back.
    residual = Conv2D(64, kernel_size=3, strides=1, padding='same')(second) * 0.2
    return Add()([residual, inpt])

def RRDB(inpt):
    """Residual-in-residual block: two chained dense blocks plus a scaled skip.

    The output of the second ``dense_block`` is multiplied by 0.2 and added
    to the block input.
    """
    scaled = dense_block(dense_block(inpt)) * 0.2
    return Add()([scaled, inpt])

def buildGenerator():
    """Build the generator that maps 24x24x3 inputs to 96x96x3 images.

    The input is upsampled by a factor of 4 first, so every convolution
    (all with 'same' padding) operates at the 96x96 output resolution.
    """
    low_res = tf.keras.Input(shape=[24, 24, 3])
    upsampled = UpSampling2D(4)(low_res)

    features = Conv2D(64, kernel_size=3, strides=1, padding='same')(upsampled)
    features = RRDB(features)
    features = Conv2D(filters=256, kernel_size=3, strides=1, padding='same',
                      activation='relu')(features)

    # Project the features back to three image channels.
    image = Conv2D(3, kernel_size=3, strides=1, padding='same')(features)
    return tf.keras.Model(inputs=low_res, outputs=image)




In [None]:
# Instantiate the generator and print its layer-by-layer summary.
gen_model = buildGenerator()
gen_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 24, 24, 3)]  0                                            
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D)    (None, 96, 96, 3)    0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 96, 96, 64)   1792        up_sampling2d[0][0]              
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 96, 96, 64)   36928       conv2d[0][0]                     
______________________________________________________________________________________________

**Building the Relativistic Discriminator**

In [None]:
from keras import backend as K
def build_disc_model():
    """Build the relativistic discriminator.

    The model takes three inputs:
      * a 24x24x3 low-resolution image, upsampled x4 and used as conditioning,
      * a 96x96x3 image to be scored,
      * a 96x96x3 reference image.

    The scored and reference images are each concatenated with the upsampled
    low-resolution image and passed through their own stack of four
    Conv2D -> LeakyReLU -> BatchNormalization stages (the two stacks do not
    share weights). The global mean of the reference features is subtracted
    from the scored features, and a final 4x4 convolution with a sigmoid
    turns the result into a single-channel score map.

    Returns:
        A ``tf.keras.Model`` named ``'disc_model'``.
    """
    # These hyper-parameters were declared but never handed to the layers,
    # which therefore fell back to the Keras defaults; they are applied now.
    leakyrelu_alpha = 0.2
    momentum = 0.8

    input_0 = tf.keras.layers.Input(shape=(24, 24, 3))
    input_0_upscale = UpSampling2D(4)(input_0)

    input_1 = tf.keras.layers.Input(shape=(96, 96, 3))
    input_2 = tf.keras.layers.Input(shape=(96, 96, 3))

    x = tf.keras.layers.concatenate([input_0_upscale, input_1])
    y = tf.keras.layers.concatenate([input_0_upscale, input_2])
    for _ in range(4):
        x = Conv2D(64, kernel_size=6, strides=1, padding='same')(x)
        y = Conv2D(64, kernel_size=6, strides=1, padding='same')(y)
        x = LeakyReLU(leakyrelu_alpha)(x)
        y = LeakyReLU(leakyrelu_alpha)(y)
        x = tf.keras.layers.BatchNormalization(momentum=momentum)(x)
        y = tf.keras.layers.BatchNormalization(momentum=momentum)(y)

    # Relativistic part: features of the scored image relative to the mean
    # (over every axis, batch included) of the reference features.
    logits = x - K.mean(y)
    # 4x4 convolution with sigmoid producing the per-pixel score map
    # (a convolution, not a fully connected layer).
    output = Conv2D(1, 4, activation='sigmoid', padding='same')(logits)

    model = tf.keras.Model(inputs=[input_0, input_1, input_2],
                           outputs=[output], name='disc_model')

    return model

In [None]:
# Instantiate the discriminator and print its layer-by-layer summary.
disc_model = build_disc_model()
disc_model.summary()

Model: "discriminator"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 24, 24, 3)]  0                                            
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 96, 96, 3)    0           input_2[0][0]                    
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None, 96, 96, 3)]  0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            [(None, 96, 96, 3)]  0                                            
______________________________________________________________________________________

**Building the ESRGAN**

In [None]:
def relativistic_loss(disc_real, disc_gen):
    """Return relativistic scores ``[fake, real]`` for two discriminator outputs.

    Each output is shifted by the mean of the other one and squashed with a
    sigmoid, so despite their names the returned values lie in (0, 1).
    """
    mean_real = K.mean(disc_real)
    mean_gen = K.mean(disc_gen)

    fake_logits = K.sigmoid(disc_gen - mean_real)
    real_logits = K.sigmoid(disc_real - mean_gen)

    return [fake_logits, real_logits]

In [None]:

def disc_model_loss(fake_logits, real_logits):
    """Discriminator loss: mean binary cross-entropy with fake -> 0 and real -> 1."""
    fake_term = K.binary_crossentropy(K.zeros_like(fake_logits), fake_logits)
    real_term = K.binary_crossentropy(K.ones_like(real_logits), real_logits)
    return K.mean(fake_term + real_term)

def gen_model_loss(fake_logits, real_logits):
    """Generator loss: mean binary cross-entropy with real -> 0 and fake -> 1.

    This is the discriminator objective with the two targets swapped.
    """
    real_term = K.binary_crossentropy(K.zeros_like(real_logits), real_logits)
    fake_term = K.binary_crossentropy(K.ones_like(fake_logits), fake_logits)
    return K.mean(real_term + fake_term)


In [None]:
# Separate Adam optimizers (learning rate 2e-4, beta_1 = 0.5) for the
# generator and the discriminator.
gen_model_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
disc_model_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)


# VGG19 feature extractor for the VGG (perceptual) loss, built without its
# classifier head and for the 96x96x3 high-resolution image size.
vgg = VGG19(include_top = False, input_shape=(96,96,3))

# Loss weights: Lambda scales the adversarial term and Eeta the L1 pixel term
# of the generator loss. Both are meant to be changed between training
# cycles; the initial values let the GAN first learn a rough approximation of
# the images.
# NOTE(review): `Lambda` rebinds the name of the Keras `Lambda` layer imported
# at the top of the file, so that layer class is no longer reachable under
# this name once this cell has run.
Lambda = 0.05
Eeta = 1
# Replaces the EPOCHS = 5 assigned in the data-preprocessing section.
EPOCHS = 50


def model_training_step(input_lr_image, target, epoch):
    """Run one optimisation step for both the generator and the discriminator.

    The generator loss combines the adversarial term (weighted by ``Lambda``),
    the mean absolute pixel error (weighted by ``Eeta``) and a VGG feature
    loss (weighted by 100). Uses the module-level ``gen_model``,
    ``disc_model``, ``vgg`` and the two optimizers.

    Args:
        input_lr_image: Batch of low-resolution images fed to the generator.
        target: Matching batch of high-resolution ground-truth images.
        epoch: Current epoch index; accepted for the caller but not used.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = gen_model(input_lr_image, training=True)

        # The discriminator scores its second input relative to its third,
        # so the real and fake scores need the two images in opposite order.
        # (Both calls previously used the same order and were identical.)
        real_logits = disc_model([input_lr_image, target, gen_output], training=True)
        fake_logits = disc_model([input_lr_image, gen_output, target], training=True)

        gen_loss = Lambda * gen_model_loss(fake_logits, real_logits)
        gen_loss += Eeta * tf.reduce_mean(tf.abs(target - gen_output))

        # NOTE(review): the images here come from `normalize` (x / 127.5 - 1),
        # while Keras documents `preprocess_input` for 0-255 pixel values --
        # confirm whether they should be rescaled before this call.
        feature_gen = vgg(preprocess_input(gen_output))
        # Copy first so preprocessing never touches the caller's target batch.
        feature_real = vgg(preprocess_input(np.copy(target)))
        # mean_squared_error only averages the last axis; reduce the rest so
        # the generator loss stays a scalar.
        vgg_loss = tf.reduce_mean(
            tf.keras.losses.mean_squared_error(feature_gen, feature_real))
        gen_loss += 100 * vgg_loss

        disc_loss = disc_model_loss(fake_logits, real_logits)

    gen_variables = gen_model.trainable_variables
    disc_variables = disc_model.trainable_variables

    gen_model_gradients = gen_tape.gradient(gen_loss, gen_variables)
    disc_model_gradients = disc_tape.gradient(disc_loss, disc_variables)

    gen_model_optimizer.apply_gradients(zip(gen_model_gradients, gen_variables))
    disc_model_optimizer.apply_gradients(zip(disc_model_gradients, disc_variables))

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5


In [None]:
def fit(model_training_lr, model_training_hr, model_traininging_cycles,
        batch_size=4, steps_per_epoch=50):
    """Train the GAN and show sample results after every epoch.

    Relies on the module-level ``gen_model``, ``checkpoint`` and
    ``checkpoint_prefix``, which must exist by the time this is called.

    Args:
        model_training_lr: Low-resolution images, indexable along the first axis.
        model_training_hr: High-resolution images aligned with ``model_training_lr``.
        model_traininging_cycles: Number of epochs to run.
        batch_size: Images per training step (default 4).
        steps_per_epoch: Maximum number of steps per epoch (default 50); an
            epoch ends early once the data runs out.
    """
    num_samples = len(model_training_lr)

    for epoch in range(model_traininging_cycles):
        start = time.time()
        print(".")
        for step in range(steps_per_epoch):
            begin = step * batch_size
            # Stop instead of feeding empty slices when the dataset holds
            # fewer than batch_size * steps_per_epoch images.
            if begin >= num_samples:
                break
            end = begin + batch_size
            model_training_step(model_training_lr[begin:end],
                                model_training_hr[begin:end], epoch)

        # Persist the models and optimizers every fifth epoch.
        if (epoch + 1) % 5 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)

        # Preview on the data that was actually passed in (not on globals):
        # LR input, HR target and generator output side by side.
        generated = gen_model(model_training_lr[0:5])
        for i in range(len(generated)):
            plt.subplot(1, 3, 1)
            plt.imshow(model_training_lr[i] / 2 + 0.5)
            plt.subplot(1, 3, 2)
            plt.imshow(model_training_hr[i] / 2 + 0.5)
            plt.subplot(1, 3, 3)
            plt.imshow(generated[i] / 2 + 0.5)
            plt.show()

        print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                        time.time()-start))
        



# Checkpoints bundle both models with their optimizers so training can resume.
checkpoint_dir = './ESRGAN_checkpoints/'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# The checkpoint API lives in `tf.train`; `tf.model_training` does not exist.
checkpoint = tf.train.Checkpoint(
    gen_model_optimizer=gen_model_optimizer,
    disc_model_optimizer=disc_model_optimizer,
    gen_model=gen_model,
    disc_model=disc_model,
)

# Resume from the newest checkpoint in checkpoint_dir when there is one;
# latest_checkpoint returns None otherwise and restore(None) loads nothing.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

fit(model_training_lr_dataset, model_training_hr_dataset, EPOCHS)


In [None]:

import random
# Show LR input, HR target and generator output side by side, processing the
# dataset in batches of 10 and looking at no more than the first 800 images.
preview_batch_size = 10
preview_limit = min(len(model_training_lr_dataset), 80 * preview_batch_size)

for batch_start in range(0, preview_limit, preview_batch_size):
    generated = gen_model(
        model_training_lr_dataset[batch_start:batch_start + preview_batch_size])
    # Iterate over what was actually generated: this covers the last image of
    # each batch (previously skipped) and a shorter final batch.
    for j in range(len(generated)):
        f, ax = plt.subplots(1, 3)
        ax[0].imshow(model_training_lr_dataset[batch_start + j] / 2 + 0.5)
        ax[0].set_title("LR image")
        ax[0].axis('off')
        ax[1].imshow(model_training_hr_dataset[batch_start + j] / 2 + 0.5)
        ax[1].set_title("HR image")
        ax[1].axis('off')
        ax[2].imshow(generated[j] / 2 + 0.5)
        ax[2].set_title("Generated image")
        ax[2].axis('off')
        plt.show()

In [None]:
# Save a final checkpoint of both models and their optimizers.
checkpoint.save(file_prefix = checkpoint_prefix)

'./ESRGAN_checkpoints/ckpt-4'