In [None]:
import os
import sys
import argparse
import tensorflow as tf

from tensorflow.keras.losses import MeanSquaredError, BinaryCrossentropy, Reduction
from tensorflow.keras.layers import (
    BatchNormalization, GlobalAvgPool2D, LeakyReLU, Rescaling,
    Conv2D, Dense, PReLU, Add, Input
)
from tensorflow.keras.models import Model, load_model
from tensorflow import GradientTape, concat, zeros, ones, reduce_mean, distribute
from tensorflow.keras.applications import VGG19
from tensorflow.keras.preprocessing.image import array_to_img
from matplotlib.pyplot import subplots, savefig, title, xticks, yticks, show
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from tensorflow.config import experimental_connect_to_cluster
from tensorflow.tpu.experimental import initialize_tpu_system
from tensorflow.keras.optimizers import Adam
from tensorflow.io.gfile import glob

In [None]:
# ---------------------------------------------------------------- data ----
# name of the TFDS dataset used by this project
DATASET = "div2k/bicubic_x4"

# TFRecord shard size (not referenced by the cells shown here) and the
# batch sizes used for training and for inference
SHARD_SIZE = 256
TRAIN_BATCH_SIZE = 64
INFER_BATCH_SIZE = 8

# patch geometry: the low resolution patch is the high resolution patch
# shrunk by SCALING_FACTOR along both spatial axes (channels unchanged)
SCALING_FACTOR = 4
HR_SHAPE = [96, 96, 3]
LR_SHAPE = [HR_SHAPE[0] // SCALING_FACTOR,
            HR_SHAPE[1] // SCALING_FACTOR,
            HR_SHAPE[2]]

# ------------------------------------------------------------ GAN model ---
FEATURE_MAPS = 64
RESIDUAL_BLOCKS = 16
LEAKY_ALPHA = 0.2
DISC_BLOCKS = 4

# ------------------------------------------------------------- training ---
# learning rates and epoch counts for the MSE pretraining stage and the
# adversarial fine-tuning stage
PRETRAIN_LR, FINETUNE_LR = 1e-4, 1e-5
PRETRAIN_EPOCHS, FINETUNE_EPOCHS = 2500, 2500
STEPS_PER_EPOCH = 10

# ---------------------------------------------------------------- paths ---
# raw dataset location
BASE_DATA_PATH = "dataset"
DIV2K_PATH = os.path.join(BASE_DATA_PATH, "div2k")

# TFRecords for GPU training, one sub-directory per split
GPU_BASE_TFR_PATH = "tfrecord"
GPU_DIV2K_TFR_TRAIN_PATH, GPU_DIV2K_TFR_TEST_PATH = (
    os.path.join(GPU_BASE_TFR_PATH, split) for split in ("train", "test"))

# everything the notebook writes lives under this directory
BASE_OUTPUT_PATH = "outputs"

# saved generator models (after pretraining, and after adversarial training)
GPU_PRETRAINED_GENERATOR_MODEL = os.path.join(
    BASE_OUTPUT_PATH, "models", "pretrained_generator")
GPU_GENERATOR_MODEL = os.path.join(BASE_OUTPUT_PATH, "models", "generator")

# rendered prediction images and the comparison grid
BASE_IMAGE_PATH = os.path.join(BASE_OUTPUT_PATH, "images")
GRID_IMAGE_PATH = os.path.join(BASE_IMAGE_PATH, "grid.png")

In [3]:
# sentinel that lets tf.data pick parallelism / prefetch depth at runtime
AUTO = tf.data.AUTOTUNE

def random_crop(lrImage, hrImage, hrCropSize=96, scale=4):
    """Take spatially aligned random crops from a low/high resolution pair.

    A square window of ``hrCropSize // scale`` pixels is sampled uniformly
    inside ``lrImage``; the HR crop covers the same region, i.e. its
    top-left corner is the LR corner multiplied by ``scale`` and its side
    is ``hrCropSize``. The first 3 channels are kept.

    Args:
        lrImage: low resolution image tensor (height, width, channels).
        hrImage: high resolution image tensor of the same content.
        hrCropSize: side length of the HR crop in pixels.
        scale: HR / LR resolution ratio.

    Returns:
        ``(lrCrop, hrCrop)`` tuple of tensors.
    """
    lrSide = hrCropSize // scale
    lrDims = tf.shape(lrImage)[:2]

    # maxval is exclusive, so "+ 1" lets the window touch the right/bottom
    # edge; the column offset is drawn before the row offset
    left = tf.random.uniform(shape=(),
        maxval=lrDims[1] - lrSide + 1, dtype=tf.int32)
    top = tf.random.uniform(shape=(),
        maxval=lrDims[0] - lrSide + 1, dtype=tf.int32)

    lrCrop = tf.slice(lrImage, [top, left, 0], [lrSide, lrSide, 3])
    hrCrop = tf.slice(hrImage, [top * scale, left * scale, 0],
        [hrCropSize, hrCropSize, 3])
    return (lrCrop, hrCrop)

In [4]:
def get_center_crop(lrImage, hrImage, hrCropSize=96, scale=4):
    """Cut aligned crops around the middle of a low/high resolution pair.

    The LR crop (``hrCropSize // scale`` pixels square) is centred on the
    LR image's integer midpoint ``(H // 2, W // 2)``. The HR crop
    (``hrCropSize`` pixels square) is centred on that midpoint multiplied
    by ``scale``. The first 3 channels are kept.

    Args:
        lrImage: low resolution image tensor (height, width, channels).
        hrImage: high resolution image tensor of the same content.
        hrCropSize: side length of the HR crop in pixels.
        scale: HR / LR resolution ratio.

    Returns:
        ``(lrCrop, hrCrop)`` tuple of tensors.
    """
    lrSide = hrCropSize // scale
    lrDims = tf.shape(lrImage)[:2]

    # integer midpoints of the LR image and the matching HR coordinates
    lrMidRow, lrMidCol = lrDims[0] // 2, lrDims[1] // 2
    hrMidRow, hrMidCol = lrMidRow * scale, lrMidCol * scale

    # top-left corners of both windows
    lrBegin = [lrMidRow - (lrSide // 2), lrMidCol - (lrSide // 2), 0]
    hrBegin = [hrMidRow - (hrCropSize // 2), hrMidCol - (hrCropSize // 2), 0]

    return (tf.slice(lrImage, lrBegin, [lrSide, lrSide, 3]),
            tf.slice(hrImage, hrBegin, [hrCropSize, hrCropSize, 3]))

In [None]:
def random_flip(lrImage, hrImage):
    """Mirror a low/high resolution pair left-to-right, together, half the time.

    Args:
        lrImage: low resolution image tensor.
        hrImage: high resolution image tensor.

    Returns:
        ``(lrImage, hrImage)``, either both mirrored or both untouched.
    """
    def _mirrored():
        return (tf.image.flip_left_right(lrImage),
                tf.image.flip_left_right(hrImage))

    def _unchanged():
        return (lrImage, hrImage)

    # one uniform draw in [0, 1) decides for both images: >= 0.5 flips
    coin = tf.random.uniform(shape=(), maxval=1)
    (lrOut, hrOut) = tf.cond(coin >= 0.5, _mirrored, _unchanged)
    return (lrOut, hrOut)

In [None]:
def random_rotate(lrImage, hrImage):
    """Rotate a low/high resolution pair by the same random multiple of 90 degrees.

    Args:
        lrImage: low resolution image tensor.
        hrImage: high resolution image tensor.

    Returns:
        ``(lrImage, hrImage)`` rotated by ``k`` quarter turns, where ``k`` is
        drawn uniformly from {0, 1, 2, 3}.
    """
    # one draw shared by both images so they stay spatially aligned
    quarterTurns = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)

    lrImage = tf.image.rot90(lrImage, quarterTurns)
    hrImage = tf.image.rot90(hrImage, quarterTurns)

    return (lrImage, hrImage)

In [None]:
def read_train_example(example):
    """Decode one serialized training record into an augmented (LR, HR) pair.

    Crop size, scale and output shapes come from the configuration constants
    ``HR_SHAPE``, ``LR_SHAPE`` and ``SCALING_FACTOR``, so editing the config
    cell keeps the input pipeline in sync.

    Args:
        example: scalar string tensor holding a serialized ``tf.train.Example``
            with bytes features ``"lr"`` and ``"hr"``, each a serialized uint8
            tensor.

    Returns:
        ``(lrImage, hrImage)`` uint8 tensors reshaped to ``LR_SHAPE`` and
        ``HR_SHAPE``.
    """
    # feature template: both images are stored as serialized tensors
    feature = {
        "lr": tf.io.FixedLenFeature([], tf.string),
        "hr": tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(example, feature)

    # deserialize the low and high resolution images
    lrImage = tf.io.parse_tensor(example["lr"], out_type=tf.uint8)
    hrImage = tf.io.parse_tensor(example["hr"], out_type=tf.uint8)

    # data augmentation: aligned random crop sized from the config, then a
    # random horizontal flip and a random quarter-turn rotation
    (lrImage, hrImage) = random_crop(lrImage, hrImage,
        hrCropSize=HR_SHAPE[0], scale=SCALING_FACTOR)
    (lrImage, hrImage) = random_flip(lrImage, hrImage)
    (lrImage, hrImage) = random_rotate(lrImage, hrImage)

    # pin the configured patch shapes on the tensors
    lrImage = tf.reshape(lrImage, LR_SHAPE)
    hrImage = tf.reshape(hrImage, HR_SHAPE)

    return (lrImage, hrImage)

In [None]:
def read_test_example(example):
    """Decode one serialized test record into a center-cropped (LR, HR) pair.

    Crop size, scale and output shapes come from the configuration constants
    ``HR_SHAPE``, ``LR_SHAPE`` and ``SCALING_FACTOR``, so editing the config
    cell keeps the input pipeline in sync.

    Args:
        example: scalar string tensor holding a serialized ``tf.train.Example``
            with bytes features ``"lr"`` and ``"hr"``, each a serialized uint8
            tensor.

    Returns:
        ``(lrImage, hrImage)`` uint8 tensors reshaped to ``LR_SHAPE`` and
        ``HR_SHAPE``.
    """
    # feature template: both images are stored as serialized tensors
    feature = {
        "lr": tf.io.FixedLenFeature([], tf.string),
        "hr": tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(example, feature)

    # deserialize the low and high resolution images
    lrImage = tf.io.parse_tensor(example["lr"], out_type=tf.uint8)
    hrImage = tf.io.parse_tensor(example["hr"], out_type=tf.uint8)

    # deterministic center crop (no augmentation for evaluation data)
    (lrImage, hrImage) = get_center_crop(lrImage, hrImage,
        hrCropSize=HR_SHAPE[0], scale=SCALING_FACTOR)

    # pin the configured patch shapes on the tensors
    lrImage = tf.reshape(lrImage, LR_SHAPE)
    hrImage = tf.reshape(hrImage, HR_SHAPE)

    return (lrImage, hrImage)

In [9]:
def load_dataset(filenames, batchSize, train=False):
    """Build an infinitely repeating, batched (LR, HR) dataset from TFRecords.

    Args:
        filenames: TFRecord file path(s) passed to ``tf.data.TFRecordDataset``.
        batchSize: number of examples per batch (also the shuffle buffer size).
        train: when True, records are decoded with random augmentation
            (``read_train_example``). Otherwise they get a deterministic
            center crop (``read_test_example``).

    Returns:
        A ``tf.data.Dataset`` yielding ``(lrImages, hrImages)`` batches of
        exactly ``batchSize`` examples.
    """
    # read the TFRecords, interleaving files in parallel
    dataset = tf.data.TFRecordDataset(filenames,
        num_parallel_reads=AUTO)

    # decode every record with the parser matching the split
    parser = read_train_example if train else read_test_example
    dataset = dataset.map(parser, num_parallel_calls=AUTO)

    # repeat *before* batching so batches may span two passes over the data.
    # Batching first would emit a short batch at the end of every pass.
    # drop_remainder also gives the batch dimension a static size, so
    # downstream code can rely on every batch holding batchSize examples.
    dataset = (dataset
        .shuffle(batchSize)
        .repeat()
        .batch(batchSize, drop_remainder=True)
        .prefetch(AUTO)
    )

    return dataset

## Implementing the SRGAN Loss Functions

In [None]:
class Losses:
    """Loss callables that divide their batch-mean by the replica count.

    Each loss is computed per element (no built-in reduction), averaged with
    ``reduce_mean`` and multiplied by ``1 / numReplicas``. NOTE(review):
    presumably this is so per-replica results combined by a distribution
    strategy amount to a global average. Confirm against how the caller
    aggregates.
    """

    def __init__(self, numReplicas):
        # number of replicas in sync; every loss is scaled by its inverse
        self.numReplicas = numReplicas

    def _replica_scaled_mean(self, unreduced):
        # mean over every element, then the 1 / numReplicas factor
        return reduce_mean(unreduced) * (1. / self.numReplicas)

    def bce_loss(self, real, pred):
        """Binary cross-entropy between labels ``real`` and predictions ``pred``."""
        unreduced = BinaryCrossentropy(reduction=Reduction.NONE)(real, pred)
        return self._replica_scaled_mean(unreduced)

    def mse_loss(self, real, pred):
        """Mean squared error between targets ``real`` and predictions ``pred``."""
        unreduced = MeanSquaredError(reduction=Reduction.NONE)(real, pred)
        return self._replica_scaled_mean(unreduced)

## Implementing the SRGAN

In [None]:
class SRGAN(object):
    """Builders for the SRGAN generator and discriminator (Keras functional API)."""

    @staticmethod
    def _residual_block(features, featureMaps):
        # CONV => BN => PReLU => CONV => BN, added back onto the block input
        branch = Conv2D(featureMaps, 3, padding="same")(features)
        branch = BatchNormalization()(branch)
        branch = PReLU(shared_axes=[1, 2])(branch)
        branch = Conv2D(featureMaps, 3, padding="same")(branch)
        branch = BatchNormalization()(branch)
        return Add()([features, branch])

    @staticmethod
    def _pixel_shuffle_x2(features, filters):
        # CONV => depth_to_space(block size 2) => PReLU: doubles height/width
        # and divides the channel count by 4
        features = Conv2D(filters, 3, padding="same")(features)
        features = tf.nn.depth_to_space(features, 2)
        return PReLU(shared_axes=[1, 2])(features)

    @staticmethod
    def _conv_bn_lrelu(features, filters, leakyAlpha, strides=1):
        # 3x3 CONV (optionally strided) => BN => LeakyReLU
        features = Conv2D(filters, 3, strides=strides,
            padding="same")(features)
        features = BatchNormalization()(features)
        return LeakyReLU(leakyAlpha)(features)

    @staticmethod
    def generator(scalingFactor, featureMaps, residualBlocks):
        """Build the SRGAN generator.

        Inputs are 3-channel images of any spatial size whose values are
        divided by 255 on entry. The output passes through ``tanh`` and is
        mapped to ``[0, 255]``.

        NOTE(review): the spatial upscale is always 4x (two fixed x2 pixel
        shuffles). ``scalingFactor`` only sets the filter counts of the two
        upsampling convolutions, which must be divisible by 4 for
        ``depth_to_space``.

        Args:
            scalingFactor: controls upsampling conv widths (see note).
            featureMaps: channel width of the residual trunk.
            residualBlocks: number of residual blocks (at least one is built).

        Returns:
            An uncompiled ``Model``.
        """
        inputs = Input((None, None, 3))
        scaled = Rescaling(scale=(1.0 / 255.0), offset=0.0)(inputs)

        # shallow feature extraction: 9x9 CONV => PReLU
        head = Conv2D(featureMaps, 9, padding="same")(scaled)
        head = PReLU(shared_axes=[1, 2])(head)

        # residual trunk: the first block is always built, then
        # residualBlocks - 1 further blocks
        trunk = SRGAN._residual_block(head, featureMaps)
        for _ in range(residualBlocks - 1):
            trunk = SRGAN._residual_block(trunk, featureMaps)

        # CONV => BN closing the trunk (no activation) plus the long skip
        trunk = Conv2D(featureMaps, 3, padding="same")(trunk)
        trunk = BatchNormalization()(trunk)
        trunk = Add()([head, trunk])

        # two x2 pixel-shuffle upsampling stages
        upsampled = SRGAN._pixel_shuffle_x2(trunk,
            featureMaps * (scalingFactor // 2))
        upsampled = SRGAN._pixel_shuffle_x2(upsampled,
            featureMaps * scalingFactor)

        # 9x9 CONV with tanh gives [-1, 1]; rescale it to [0, 255]
        outputs = Conv2D(3, 9, padding="same", activation="tanh")(upsampled)
        outputs = Rescaling(scale=127.5, offset=127.5)(outputs)

        return Model(inputs, outputs)

    @staticmethod
    def discriminator(featureMaps, leakyAlpha, discBlocks):
        """Build the SRGAN discriminator.

        Inputs are 3-channel images of any spatial size, mapped with
        ``x / 127.5 - 1`` on entry. The output is one sigmoid probability
        per image.

        Args:
            featureMaps: channel width of the first convolutions; block ``i``
                uses ``featureMaps * 2 ** i``.
            leakyAlpha: LeakyReLU negative slope.
            discBlocks: the loop adds ``discBlocks - 1`` downsampling blocks.

        Returns:
            An uncompiled ``Model``.
        """
        inputs = Input((None, None, 3))
        h = Rescaling(scale=(1.0 / 127.5), offset=-1.0)(inputs)

        # the very first CONV has no batch normalization
        h = Conv2D(featureMaps, 3, padding="same")(h)
        h = LeakyReLU(leakyAlpha)(h)

        h = SRGAN._conv_bn_lrelu(h, featureMaps, leakyAlpha)

        # each block: strided CONV (halves resolution) then a plain CONV,
        # both at doubled width
        for i in range(1, discBlocks):
            width = featureMaps * (2 ** i)
            h = SRGAN._conv_bn_lrelu(h, width, leakyAlpha, strides=2)
            h = SRGAN._conv_bn_lrelu(h, width, leakyAlpha)

        # global pooling => LeakyReLU => sigmoid score
        h = GlobalAvgPool2D()(h)
        h = LeakyReLU(leakyAlpha)(h)
        h = Dense(1, activation="sigmoid")(h)

        return Model(inputs, h)

## Implementing the SRGAN Training Script

In [None]:
class SRGANTraining(Model):
    """Keras model wrapping the adversarial (fine-tuning) stage of SRGAN.

    Keras ``fit`` invokes :meth:`train_step` for each batch. Every step does
    one discriminator update, then one generator update driven by an
    adversarial BCE term plus a VGG perceptual MSE term.
    """

    def __init__(self, generator, discriminator, vgg, batchSize):
        super().__init__()
        # networks trained adversarially
        self.generator = generator
        self.discriminator = discriminator
        # truncated VGG19 used as a feature extractor for the perceptual loss
        self.vgg = vgg
        # kept for backward compatibility; train_step sizes its label
        # tensors from the incoming batch instead
        self.batchSize = batchSize

    def compile(self, gOptimizer, dOptimizer, bceLoss, mseLoss):
        """Store the optimizers and loss callables used by :meth:`train_step`."""
        super().compile()
        self.gOptimizer = gOptimizer
        self.dOptimizer = dOptimizer
        self.bceLoss = bceLoss
        self.mseLoss = mseLoss

    def train_step(self, images):
        """Run one discriminator update followed by one generator update.

        Args:
            images: ``(lrImages, hrImages)`` batch; both are cast to float32.

        Returns:
            Dict with ``dLoss``, ``gTotalLoss``, ``gLoss`` (weighted
            adversarial term) and ``percLoss``.
        """
        (lrImages, hrImages) = images
        lrImages = tf.cast(lrImages, tf.float32)
        hrImages = tf.cast(hrImages, tf.float32)

        # size of the batch actually received, so a smaller batch cannot
        # mismatch the label tensors
        batchSize = tf.shape(lrImages)[0]

        # sub-models are called with training=True: a custom train_step gets
        # no automatic training flag, and without it BatchNormalization would
        # run in inference mode (frozen moving statistics)
        srImages = self.generator(lrImages, training=True)

        # label 0 for generated images, 1 for real high resolution images
        combinedImages = concat([srImages, hrImages], axis=0)
        labels = concat(
            [zeros((batchSize, 1)), ones((batchSize, 1))], axis=0)

        # discriminator update
        with GradientTape() as tape:
            predictions = self.discriminator(combinedImages, training=True)
            dLoss = self.bceLoss(labels, predictions)

        grads = tape.gradient(dLoss,
            self.discriminator.trainable_variables)
        self.dOptimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_variables)
        )

        # generator update: only generator variables receive gradients, so
        # the discriminator weights are not changed here
        misleadingLabels = ones((batchSize, 1))

        with GradientTape() as tape:
            fakeImages = self.generator(lrImages, training=True)
            predictions = self.discriminator(fakeImages, training=True)

            # weighted adversarial loss: reward fooling the discriminator
            gLoss = 1e-3 * self.bceLoss(misleadingLabels, predictions)

            # perceptual loss on VGG19 features, each divided by 12.75
            srVgg = tf.keras.applications.vgg19.preprocess_input(
                fakeImages)
            srVgg = self.vgg(srVgg, training=False) / 12.75
            hrVgg = tf.keras.applications.vgg19.preprocess_input(
                hrImages)
            hrVgg = self.vgg(hrVgg, training=False) / 12.75
            percLoss = self.mseLoss(hrVgg, srVgg)

            gTotalLoss = gLoss + percLoss

        grads = tape.gradient(gTotalLoss,
            self.generator.trainable_variables)
        self.gOptimizer.apply_gradients(zip(grads,
            self.generator.trainable_variables)
        )

        return {"dLoss": dLoss, "gTotalLoss": gTotalLoss,
            "gLoss": gLoss, "percLoss": percLoss}

## Implementing the Final Utility Scripts

In [None]:
class VGG:
    """Factory for the feature extractor used by the SRGAN perceptual loss."""

    @staticmethod
    def build():
        """Return ImageNet VGG19 (no classifier head) cut after ``layers[20]``.

        NOTE(review): in the stock Keras VGG19 layout ``layers[20]`` is
        ``block5_conv4`` (index 0 is the input layer). Confirm if the Keras
        version changes.

        Returns:
            A ``Model`` mapping 3-channel images of any size to the
            activations of that layer.
        """
        backbone = VGG19(input_shape=(None, None, 3), weights="imagenet",
            include_top=False)
        featureLayer = backbone.layers[20]
        return Model(backbone.input, featureLayer.output)

## Assessing the output images

In [None]:
# the following code snippet has been adapted from:
# https://keras.io/examples/vision/super_resolution_sub_pixel
def zoom_into_images(image, imageTitle, zoomFactor=2,
        zoomLimits=(20, 40, 20, 40)):
    """Plot an image with a magnified inset, save it to disk and show it.

    Args:
        image: image array accepted by ``array_to_img`` (rows first).
        imageTitle: figure title; also the file name ``<imageTitle>.png``
            written under ``BASE_IMAGE_PATH``.
        zoomFactor: magnification of the inset.
        zoomLimits: ``(x1, x2, y1, y2)`` pixel window shown in the inset.
    """
    # the rows are flipped and drawn with origin="lower", so the picture
    # appears upright while the y-axis grows upward (matching set_ylim)
    pilImage = array_to_img(image[::-1])

    (fig, ax) = subplots()
    ax.imshow(pilImage, origin="lower")
    ax.set_title(imageTitle)

    # inset anchored at the upper-left corner (loc=2) of the main axes
    axins = zoomed_inset_axes(ax, zoomFactor, loc=2)
    axins.imshow(pilImage, origin="lower")

    # restrict the inset to the requested window
    (x1, x2, y1, y2) = zoomLimits
    axins.set_xlim(x1, x2)
    axins.set_ylim(y1, y2)

    # hide the inset's tick labels through its own Axes instead of relying
    # on pyplot's implicit "current axes"
    axins.tick_params(labelbottom=False, labelleft=False)

    # draw the connector lines between the inset and the zoomed region
    mark_inset(ax, axins, loc1=1, loc2=3, fc="none", ec="blue")

    # build the image path and save the figure to disk
    imagePath = os.path.join(BASE_IMAGE_PATH, f"{imageTitle}.png")
    fig.savefig(imagePath)

    show()

## Training the SRGAN

In [None]:
tf.random.set_seed(42)

# define the multi-gpu strategy
strategy = distribute.MirroredStrategy()

# set the train TFRecords, pretrained generator, and final
# generator model paths to be used for GPU training
tfrTrainPath = GPU_DIV2K_TFR_TRAIN_PATH
pretrainedGenPath = GPU_PRETRAINED_GENERATOR_MODEL
genPath = GPU_GENERATOR_MODEL

# display the number of accelerators
print("[INFO] number of accelerators: {}..."
    .format(strategy.num_replicas_in_sync))

# grab the train TFRecord filenames, sorted so the input file order does
# not depend on the file system's listing order
print("[INFO] grabbing the train TFRecords...")
trainTfr = sorted(glob(os.path.join(tfrTrainPath, "*.tfrec")))

# fail fast: with no files the dataset is empty and there is nothing to
# train on
if not trainTfr:
    raise FileNotFoundError(
        "no *.tfrec files found in {}".format(tfrTrainPath))

# build the div2k training dataset; the global batch is the per-replica
# batch times the number of replicas
print("[INFO] creating train dataset...")
trainDs = load_dataset(filenames=trainTfr, train=True,
    batchSize=TRAIN_BATCH_SIZE * strategy.num_replicas_in_sync)

# stage 1: pretrain the generator alone with the MSE loss
with strategy.scope():
    losses = Losses(numReplicas=strategy.num_replicas_in_sync)

    generator = SRGAN.generator(
        scalingFactor=SCALING_FACTOR,
        featureMaps=FEATURE_MAPS,
        residualBlocks=RESIDUAL_BLOCKS)

    generator.compile(
        optimizer=Adam(learning_rate=PRETRAIN_LR),
        loss=losses.mse_loss)

    print("[INFO] pretraining SRGAN generator...")
    generator.fit(trainDs, epochs=PRETRAIN_EPOCHS,
        steps_per_epoch=STEPS_PER_EPOCH)

# make sure the directories that will hold both saved generators exist
for modelPath in (pretrainedGenPath, genPath):
    os.makedirs(os.path.dirname(modelPath), exist_ok=True)

# save the pretrained generator
print("[INFO] saving the SRGAN pretrained generator to {}..."
    .format(pretrainedGenPath))
generator.save(pretrainedGenPath)

# stage 2: adversarial fine-tuning of the pretrained generator
with strategy.scope():
    losses = Losses(numReplicas=strategy.num_replicas_in_sync)

    # VGG network for the perceptual loss, plus the discriminator
    vgg = VGG.build()

    discriminator = SRGAN.discriminator(
        featureMaps=FEATURE_MAPS,
        leakyAlpha=LEAKY_ALPHA, discBlocks=DISC_BLOCKS)

    # build the SRGAN training model (per-replica batch size) and compile it
    srgan = SRGANTraining(
        generator=generator,
        discriminator=discriminator,
        vgg=vgg,
        batchSize=TRAIN_BATCH_SIZE)

    srgan.compile(
        dOptimizer=Adam(learning_rate=FINETUNE_LR),
        gOptimizer=Adam(learning_rate=FINETUNE_LR),
        bceLoss=losses.bce_loss,
        mseLoss=losses.mse_loss,
    )

    print("[INFO] training SRGAN...")
    srgan.fit(trainDs, epochs=FINETUNE_EPOCHS,
        steps_per_epoch=STEPS_PER_EPOCH)

# save the SRGAN generator
print("[INFO] saving SRGAN generator to {}...".format(genPath))
srgan.generator.save(genPath)

## Creating the inference script for the SRGAN

In [None]:
# define the multi-gpu strategy
strategy = distribute.MirroredStrategy()

# set the test TFRecords path and the paths of the pretrained and fully
# trained generators saved during training
tfrTestPath = GPU_DIV2K_TFR_TEST_PATH
pretrainedGenPath = GPU_PRETRAINED_GENERATOR_MODEL
genPath = GPU_GENERATOR_MODEL

# get the dataset (sorted file list for a stable input order)
print("[INFO] loading the test dataset...")
testTfr = sorted(glob(os.path.join(tfrTestPath, "*.tfrec")))

# fail fast instead of waiting on an empty dataset
if not testTfr:
    raise FileNotFoundError(
        "no *.tfrec files found in {}".format(tfrTestPath))
testDs = load_dataset(testTfr, INFER_BATCH_SIZE, train=False)

# get the first batch of testing images
(lrImage, hrImage) = next(iter(testDs))

# load both generators and predict on the batch
with strategy.scope():
    print("[INFO] loading the pre-trained and fully trained SRGAN model...")
    srganPreGen = load_model(pretrainedGenPath, compile=False)
    srganGen = load_model(genPath, compile=False)

    print("[INFO] making predictions with pre-trained and fully trained SRGAN model...")
    srganPreGenPred = srganPreGen.predict(lrImage)
    srganGenPred = srganGen.predict(lrImage)

# one row per test image: LR input, pretrained SR, SRGAN SR, HR target
print("[INFO] plotting the SRGAN predictions...")
(fig, axes) = subplots(nrows=INFER_BATCH_SIZE, ncols=4,
    figsize=(50, 50))

for (ax, lowRes, srPreIm, srGanIm, highRes) in zip(axes, lrImage,
        srganPreGenPred, srganGenPred, hrImage):
    ax[0].imshow(array_to_img(lowRes))
    ax[0].set_title("Low Resolution Image")

    ax[1].imshow(array_to_img(srPreIm))
    ax[1].set_title("SRGAN Pretrained")

    ax[2].imshow(array_to_img(srGanIm))
    ax[2].set_title("SRGAN")

    ax[3].imshow(array_to_img(highRes))
    ax[3].set_title("High Resolution Image")

# create the output image directory if needed (no error when it exists)
os.makedirs(BASE_IMAGE_PATH, exist_ok=True)

# serialize the results to disk
print("[INFO] saving the SRGAN predictions to disk...")
fig.savefig(GRID_IMAGE_PATH)

# plot the zoomed in images
zoom_into_images(srganPreGenPred[0], "SRGAN Pretrained")
zoom_into_images(srganGenPred[0], "SRGAN")