This notebook is adapted (with modifications) from the SRGAN implementation at
https://github.com/krasserm/super-resolution/tree/master

In [1]:
# Mount Google Drive: every dataset, checkpoint and weight path used in this
# notebook lives under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


# Import Libraries

In [3]:
import tensorflow as tf  # software library for high performance numerical computation
from tensorflow.keras.applications.vgg19 import preprocess_input, VGG19   #Instantiates the VGG19 architecture.
import shutil
from tensorflow.keras.layers import Add, BatchNormalization, Conv2D, Dense, Flatten, Input, LeakyReLU, PReLU, Lambda #Basic building blocks of neural network
from tensorflow.keras.models import Model #  A model grouping layers into an object with training/inference features.
from tensorflow.keras.optimizers import Adam  # Importing optimizers
from tensorflow.keras.losses import BinaryCrossentropy, MeanAbsoluteError, MeanSquaredError # importing losses
from tensorflow.keras.metrics import Mean # importing metrics


import os
import matplotlib.pyplot as plt  # for image processing
import numpy as np # for arrays and matrices
from PIL import Image  # for image visualization
import time
# `%matplotlib inline` (below) renders plots directly underneath the cell that creates them
from tensorflow.python.data.experimental import AUTOTUNE # To create multiple processor threa
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay # optimizer decay
%matplotlib inline

# Loading dataset

In [4]:
# Randomly crop spatially aligned patches from a low-/high-resolution image pair
def crop(low_res_img, high_res_img, high_res_crop_size=96, scale=2):
    """Randomly crop spatially aligned patches from an (H, W, C) LR/HR image pair.

    The LR patch is ``high_res_crop_size // scale`` pixels square. The HR patch is
    ``high_res_crop_size`` pixels square and starts at the LR offset times ``scale``,
    so both patches cover the same scene region.

    Returns:
        (low_res_patch, high_res_patch)
    """
    lr_size = high_res_crop_size // scale
    lr_hw = tf.shape(low_res_img)[:2]

    # Top-left corner of the LR patch, chosen so the patch fits inside the image
    # (column offset is sampled first, then row offset).
    col0 = tf.random.uniform(shape=(), maxval=lr_hw[1] - lr_size + 1, dtype=tf.int32)
    row0 = tf.random.uniform(shape=(), maxval=lr_hw[0] - lr_size + 1, dtype=tf.int32)
    hr_col0, hr_row0 = col0 * scale, row0 * scale

    lr_patch = low_res_img[row0:row0 + lr_size, col0:col0 + lr_size]
    hr_patch = high_res_img[hr_row0:hr_row0 + high_res_crop_size,
                            hr_col0:hr_col0 + high_res_crop_size]
    return lr_patch, hr_patch

def random_flip(lr_img, hr_img):
    """With probability 0.5, mirror both images left-to-right (one shared decision)."""
    coin = tf.random.uniform(shape=(), maxval=1)

    def _unchanged():
        return (lr_img, hr_img)

    def _mirrored():
        return (tf.image.flip_left_right(lr_img),
                tf.image.flip_left_right(hr_img))

    return tf.cond(coin < 0.5, _unchanged, _mirrored)

def random_rotate(lr_img, hr_img):
    """Rotate both images by the same random number (0-3) of 90-degree turns."""
    quarter_turns = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
    rotated_lr = tf.image.rot90(lr_img, quarter_turns)
    rotated_hr = tf.image.rot90(hr_img, quarter_turns)
    return rotated_lr, rotated_hr



#Download dataset
def download(file, dir, extract=True):
    """Download a DIV2K archive into ``dir`` and, if requested, extract it there.

    Args:
        file: archive name on the DIV2K server, e.g. ``'DIV2K_train_HR.zip'``.
        dir: destination directory (the name shadows the ``dir`` builtin; kept
            for backward compatibility with existing callers).
        extract: extract the archive into ``dir``. When True the archive is
            deleted afterwards to save space. When False the archive is kept,
            because it is then the only copy of the downloaded data (previously
            it was deleted, making ``extract=False`` a no-op download).
    """
    url = f'http://data.vision.ee.ethz.ch/cvl/DIV2K/{file}'
    tf.keras.utils.get_file(file, url, cache_subdir=os.path.abspath(dir), extract=extract)
    if extract:
        # Only the extracted images are needed; drop the (large) archive.
        os.remove(os.path.join(dir, file))


In [5]:
class main_dataset:
    """DIV2K low-/high-resolution image pairs exposed as a ``tf.data`` pipeline.

    On first use the requested split is downloaded and extracted into ``img_dir``.
    Decoded images are cached to files under ``cach_dir``. ``dataset()`` zips the LR
    and HR streams and optionally applies random crop / rotation / flip augmentation.

    Args:
        selection: ``'train'`` (images 0001-0800) or ``'valid'`` (images 0801-0900).
        img_dir: directory holding the extracted DIV2K folders.
        cach_dir: directory for the ``tf.data`` cache files.
        scale: downscaling factor of the LR images along each axis (default 4).
        downgrade: degradation name used in DIV2K archive/folder names
            (default ``'bicubic'``).

    Raises:
        ValueError: if ``selection`` is not ``'train'`` or ``'valid'``.
    """

    def __init__(self,
                 selection='train',  # 'train' or 'valid'
                 img_dir='/content/drive/MyDrive/GenAIBOOK/SRGAN/dataset/images',  # images directory
                 cach_dir='/content/drive/MyDrive/GenAIBOOK/SRGAN/dataset/caches',  # cache directory
                 scale=4,
                 downgrade='bicubic'):
        # Validate first so a bad argument does not create directories as a side effect.
        if selection == 'train':
            self.image_ids = range(1, 801)
        elif selection == 'valid':
            self.image_ids = range(801, 901)
        else:
            raise ValueError("selection must be 'train' or 'valid'")

        self.scale = scale  # LR images are downscaled by this factor along each axis
        self.downgrade = downgrade  # degradation name, e.g. 'bicubic'
        self.selection = selection
        self.img_dir = img_dir
        self.cach_dir = cach_dir

        os.makedirs(img_dir, exist_ok=True)
        os.makedirs(cach_dir, exist_ok=True)

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.image_ids)

    def dataset(self, batch_size=16, repeat_count=None, transform=True):
        """Batched ``(lr, hr)`` dataset.

        Args:
            batch_size: number of image pairs per batch.
            repeat_count: passed to ``Dataset.repeat``; ``None`` repeats indefinitely.
            transform: apply random crop, rotation and flip to every pair.
        """
        ds = tf.data.Dataset.zip((self.lr_dataset(), self.hr_dataset()))
        if transform:
            ds = ds.map(lambda lr, hr: crop(lr, hr, scale=self.scale), num_parallel_calls=AUTOTUNE)
            ds = ds.map(random_rotate, num_parallel_calls=AUTOTUNE)
            ds = ds.map(random_flip, num_parallel_calls=AUTOTUNE)
        ds = ds.batch(batch_size)
        ds = ds.repeat(repeat_count)
        ds = ds.prefetch(buffer_size=AUTOTUNE)  # overlap input preparation with training
        return ds

    def hr_dataset(self):
        """Decoded high-resolution images (downloaded on first use, cached to file)."""
        if not os.path.exists(os.path.join(self.img_dir, f'DIV2K_{self.selection}_HR')):
            download(f'DIV2K_{self.selection}_HR.zip', self.img_dir, extract=True)
        return self.img_dataset(self.high_res_img_files()).cache(self.high_res_cach_file())

    def lr_dataset(self):
        """Decoded low-resolution images (downloaded on first use, cached to file)."""
        if not os.path.exists(os.path.join(self.img_dir, f'DIV2K_{self.selection}_LR_{self.downgrade}', f'X{self.scale}')):
            download(f'DIV2K_{self.selection}_LR_{self.downgrade}_X{self.scale}.zip', self.img_dir, extract=True)
        return self.img_dataset(self.low_res_img_files()).cache(self.low_res_cach_file())

    def high_res_cach_file(self):
        """Cache file path for the high-resolution images."""
        return os.path.join(self.cach_dir, f'DIV2K_{self.selection}_HR.cache')

    def low_res_cach_file(self):
        """Cache file path for the low-resolution images."""
        return os.path.join(self.cach_dir, f'DIV2K_{self.selection}_LR_{self.downgrade}_X{self.scale}.cache')

    def high_res_cach_index(self):
        """Index file accompanying the high-resolution cache."""
        return f'{self.high_res_cach_file()}.index'

    def low_res_cach_index(self):
        """Index file accompanying the low-resolution cache."""
        return f'{self.low_res_cach_file()}.index'

    def high_res_img_files(self):
        """Paths of the high-resolution PNGs, e.g. ``.../DIV2K_train_HR/0001.png``."""
        img_dir = os.path.join(self.img_dir, f'DIV2K_{self.selection}_HR')
        return [os.path.join(img_dir, f'{image_id:04}.png') for image_id in self.image_ids]

    def low_res_img_files(self):
        """Paths of the low-resolution PNGs, e.g. ``.../DIV2K_train_LR_bicubic/X4/0001x4.png``."""
        img_dir = os.path.join(self.img_dir, f'DIV2K_{self.selection}_LR_{self.downgrade}', f'X{self.scale}')
        return [os.path.join(img_dir, f'{image_id:04}x{self.scale}.png') for image_id in self.image_ids]

    @staticmethod
    def img_dataset(image_files):
        """Dataset of decoded 3-channel images from a list of PNG file paths."""
        ds = tf.data.Dataset.from_tensor_slices(image_files)
        ds = ds.map(tf.io.read_file, num_parallel_calls=AUTOTUNE)
        ds = ds.map(lambda x: tf.image.decode_png(x, channels=3), num_parallel_calls=AUTOTUNE)
        return ds


In [6]:
# Training pairs are cropped/augmented and repeated indefinitely (repeat_count=None);
# validation pairs are cropped/augmented too, but iterated only once per pass.
img_train = main_dataset(selection='train')
img_valid = main_dataset(selection='valid')
train_ds = img_train.dataset(transform=True, batch_size=16)
valid_ds = img_valid.dataset(repeat_count=1, transform=True, batch_size=16)

Downloading data from http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X4.zip
Downloading data from http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
Downloading data from http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X4.zip
Downloading data from http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip


# Training Networks

In [None]:
# Universal class for training
class Univ_Trainer:
    """Generic supervised trainer for a super-resolution model, with checkpointing.

    A ``tf.train.Checkpoint`` tracks four things: the model, an Adam optimizer, the
    global step counter and the most recent validation PSNR. On construction, the
    latest checkpoint in ``checkpoint_dir`` is restored, so training resumes where
    it stopped.

    Args:
        model: Keras model mapping low-resolution batches to super-resolved batches.
        loss: callable ``loss(hr, sr)`` returning a scalar loss tensor.
        learning_rate: learning rate (or schedule) for the Adam optimizer.
        checkpoint_dir: directory managed by ``tf.train.CheckpointManager``.
    """

    def __init__(self,
                 model,
                 loss,
                 learning_rate,
                 checkpoint_dir='/content/drive/MyDrive/GenAIBOOK/SRGAN/ckpt/edsr'):

        self.now = None  # wall-clock start of the current reporting interval
        self.loss = loss
        # Everything needed to resume training is stored on this checkpoint.
        self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                              psnr=tf.Variable(-1.0),
                                              optimizer=Adam(learning_rate),
                                              model=model)
        # Keep only the three most recent checkpoints on disk.
        self.checkpoint_manager = tf.train.CheckpointManager(checkpoint=self.checkpoint,
                                                             directory=checkpoint_dir,
                                                             max_to_keep=3)
        self.restore()

    @property
    def model(self):
        """The model being trained (stored on the checkpoint)."""
        return self.checkpoint.model

    def train(self, train_dataset, valid_dataset, steps, evaluate_every=1000):
        """Train until the checkpointed step counter reaches ``steps``.

        Every ``evaluate_every`` steps the mean training loss and the validation PSNR
        are printed, and a checkpoint is saved.

        Args:
            train_dataset: dataset of ``(lr, hr)`` batches; it should yield at least
                the number of remaining steps (e.g. an indefinitely repeated dataset).
            valid_dataset: finite dataset of ``(lr, hr)`` batches used for PSNR.
            steps: total number of steps, counted across resumed runs.
            evaluate_every: evaluation/checkpoint interval, in steps.
        """
        loss_mean = Mean()

        ckpt_mgr = self.checkpoint_manager
        ckpt = self.checkpoint

        self.now = time.perf_counter()

        # Only run the steps still missing after a restore. Clamp at 0: a negative
        # count can make Dataset.take() yield every element, which never ends on
        # a repeated dataset.
        remaining = max(0, steps - int(ckpt.step.numpy()))

        for lr, hr in train_dataset.take(remaining):
            ckpt.step.assign_add(1)
            step = ckpt.step.numpy()

            loss = self.train_step(lr, hr)
            loss_mean(loss)

            if step % evaluate_every == 0:
                # Mean training loss over the last interval.
                loss_value = loss_mean.result()
                loss_mean.reset_state()

                psnr_value = self.evaluate(valid_dataset)
                duration = time.perf_counter() - self.now
                print(f'{step}/{steps}: loss = {loss_value.numpy():.3f}, PSNR = {psnr_value.numpy():.3f} ({duration:.2f}s)')

                # assign() updates the tracked tf.Variable. Rebinding ckpt.psnr to a
                # plain tensor would not update the value the checkpoint saves.
                ckpt.psnr.assign(psnr_value)
                ckpt_mgr.save()
                # Start timing the next interval.
                self.now = time.perf_counter()

    @tf.function
    def train_step(self, lr, hr):
        """Run one optimisation step on a batch and return the scalar loss."""
        with tf.GradientTape() as tape:
            # The dataset yields integer images; the model works in float32.
            lr = tf.cast(lr, tf.float32)
            hr = tf.cast(hr, tf.float32)

            sr = self.checkpoint.model(lr, training=True)
            loss_value = self.loss(hr, sr)

        gradients = tape.gradient(loss_value, self.checkpoint.model.trainable_variables)
        self.checkpoint.optimizer.apply_gradients(zip(gradients, self.checkpoint.model.trainable_variables))

        return loss_value

    def evaluate(self, dataset):
        """Mean PSNR of the current model on ``dataset`` (delegates to the module-level ``evaluate``)."""
        return evaluate(self.checkpoint.model, dataset)

    def restore(self):
        """Restore the latest checkpoint from ``checkpoint_dir``, if one exists."""
        if self.checkpoint_manager.latest_checkpoint:
            self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
            print(f'Model restored from checkpoint at step {self.checkpoint.step.numpy()}.')


# Generator pre-training: reuses Univ_Trainer with a pixel-wise MSE loss
class SrganGeneratorTrainer(Univ_Trainer):
    """Pre-trains the SRGAN generator on its own with a pixel-wise MSE loss.

    A thin specialisation of Univ_Trainer that fixes the loss to
    ``MeanSquaredError`` and supplies SRGAN-specific training defaults.
    """

    def __init__(self,
                 model,
                 checkpoint_dir,
                 learning_rate=1e-4):
        pixel_loss = MeanSquaredError()
        super().__init__(model,
                         loss=pixel_loss,
                         learning_rate=learning_rate,
                         checkpoint_dir=checkpoint_dir)

    def train(self, train_dataset, valid_dataset, steps=5000, evaluate_every=100):
        """Run generator pre-training (defaults: 5000 steps, evaluate every 100)."""
        super().train(train_dataset, valid_dataset, steps=steps, evaluate_every=evaluate_every)

# Adversarial (SRGAN) training of the generator against the discriminator

class SrganTrainer:
    """Adversarial (SRGAN) fine-tuning of a generator against a discriminator.

    The generator minimises the perceptual loss
    ``content_loss + 0.001 * adversarial_loss``. The content loss is the MSE
    between VGG19 feature maps of the HR and SR images (scaled by 1/12.75). The
    discriminator minimises binary cross-entropy on real (HR) vs. generated (SR)
    images.

    Args:
        generator: Keras model mapping LR batches to SR batches.
        discriminator: Keras model returning a probability (sigmoid) that the
            input is real.
        content_loss: ``'VGG54'`` (VGG19 layer index 20, block5_conv4; the default)
            or ``'VGG22'`` (VGG19 layer index 5, block2_conv2).
        learning_rate: learning rate or schedule shared by both Adam optimizers.

    Raises:
        ValueError: if ``content_loss`` is not one of the supported names.
    """

    # Index into VGG19(include_top=False).layers for each content-loss variant.
    _VGG_OUTPUT_LAYERS = {'VGG22': 5, 'VGG54': 20}

    def __init__(self,
                 generator,
                 discriminator,
                 content_loss='VGG54',
                 learning_rate=PiecewiseConstantDecay(boundaries=[1000], values=[1e-4, 1e-5])):
        if content_loss not in self._VGG_OUTPUT_LAYERS:
            raise ValueError(f"content_loss must be one of {sorted(self._VGG_OUTPUT_LAYERS)}, got {content_loss!r}")
        vgg = VGG19(input_shape=(None, None, 3), include_top=False)
        # The feature layer now follows content_loss (it used to be hard-coded to 20).
        self.vgg = Model(vgg.input, vgg.layers[self._VGG_OUTPUT_LAYERS[content_loss]].output)
        self.content_loss = content_loss
        self.generator = generator
        self.discriminator = discriminator
        self.optimizer_gen = Adam(learning_rate=learning_rate)
        self.optimizer_disc = Adam(learning_rate=learning_rate)
        # Discriminator outputs are probabilities, not logits.
        self.bce_loss = BinaryCrossentropy(from_logits=False)
        self.mse_loss = MeanSquaredError()

    def train(self, train_dataset, steps=2000, log_every=200):
        """Run ``steps`` adversarial updates.

        Running-mean losses are printed (and then reset) every ``log_every`` steps.
        """
        pls_metric = Mean()  # perceptual (generator) loss
        dls_metric = Mean()  # discriminator loss

        for step, (lr_img, hr_img) in enumerate(train_dataset.take(steps), start=1):
            pl, dl = self.train_step(lr_img, hr_img)
            pls_metric(pl)
            dls_metric(dl)

            if step % log_every == 0:
                # Print the current progress
                print(f'{step}/{steps}, perceptual loss = {pls_metric.result():.4f}, discriminator loss = {dls_metric.result():.4f}')
                pls_metric.reset_state()
                dls_metric.reset_state()

    @tf.function
    def train_step(self, lr_output, hr_output):
        """Update generator and discriminator once.

        Returns:
            ``(perceptual_loss, discriminator_loss)``.
        """
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # The dataset yields integer images; the networks work in float32.
            lr_output = tf.cast(lr_output, tf.float32)
            hr_output = tf.cast(hr_output, tf.float32)

            sr_output = self.generator(lr_output, training=True)

            # Discriminator scores for real (HR) and generated (SR) images.
            hr_output_f = self.discriminator(hr_output, training=True)
            sr_output_f = self.discriminator(sr_output, training=True)

            content_loss = self._content_loss(hr_output, sr_output)
            generator_loss = self._generator_loss(sr_output_f)

            # Perceptual loss: weighted sum of content loss and adversarial loss.
            perceptual_loss = content_loss + 0.001 * generator_loss
            discriminator_loss = self._discriminator_loss(hr_output_f, sr_output_f)

        gradients_of_generator = gen_tape.gradient(perceptual_loss, self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(discriminator_loss, self.discriminator.trainable_variables)
        self.optimizer_gen.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
        self.optimizer_disc.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))

        return perceptual_loss, discriminator_loss

    @tf.function
    def _content_loss(self, hr_output, sr_output):
        """MSE between VGG19 feature maps of the HR and SR images."""
        sr = preprocess_input(sr_output)
        hr = preprocess_input(hr_output)
        # NOTE(review): the 1/12.75 feature scaling is carried over from the
        # reference implementation; its derivation is not documented here.
        sr_features = self.vgg(sr) / 12.75
        hr_features = self.vgg(hr) / 12.75
        return self.mse_loss(hr_features, sr_features)

    def _generator_loss(self, sr_output):
        """Adversarial loss: push the discriminator to label SR images as real (1)."""
        return self.bce_loss(tf.ones_like(sr_output), sr_output)

    def _discriminator_loss(self, hr_output, sr_output):
        """BCE for real (HR -> 1) plus BCE for generated (SR -> 0) scores."""
        hr_output_loss = self.bce_loss(tf.ones_like(hr_output), hr_output)
        sr_output_loss = self.bce_loss(tf.zeros_like(sr_output), sr_output)
        return hr_output_loss + sr_output_loss


# Model Definition: helper functions and network architectures

In [None]:
def resolve(model, lr_batch):
    """Run ``model`` on a batch of LR images and return uint8 SR images.

    The raw output is clipped to [0, 255] and rounded to the nearest integer
    before the cast.
    """
    sr_float = model(tf.cast(lr_batch, tf.float32))
    sr_rounded = tf.round(tf.clip_by_value(sr_float, 0, 255))
    return tf.cast(sr_rounded, tf.uint8)

def evaluate(model, dataset):
    """Mean PSNR of ``model`` over every image in ``dataset``.

    Args:
        model: generator mapping LR batches to SR batches (applied via ``resolve``).
        dataset: iterable of ``(lr_batch, hr_batch)`` pairs. HR is presumably uint8,
            matching the dtype ``resolve`` returns; confirm if the pipeline changes.

    Returns:
        Scalar float tensor: PSNR averaged over all images (not just the first
        image of each batch, as before).

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    psnr_values = []
    for lr, hr in dataset:
        sr = resolve(model, lr)
        # psnr() gives one value per image; keep the whole batch. The old code kept
        # only index 0, i.e. 1 of 16 images with the validation batch size used here.
        psnr_values.append(psnr(hr, sr))
    if not psnr_values:
        raise ValueError('evaluate() received an empty dataset')
    return tf.reduce_mean(tf.concat(psnr_values, axis=0))


# Normalize RGB Images [0, 1]
def norm01(x):
    """Scale pixel values from [0, 255] to [0, 1]."""
    max_pixel = 255.0
    return x / max_pixel

# Normalize RGB Images [-1, 1]
def norm11(x):
    """Scale pixel values from [0, 255] to [-1, 1]."""
    half_range = 127.5
    scaled = x / half_range
    return scaled - 1

# Denormalize the images
def denorm11(x):
    """Inverse of norm11: map values in [-1, 1] back to [0, 255]."""
    half_range = 127.5
    shifted = x + 1
    return shifted * half_range

# Metrics: the mean squared error (MSE) and the peak signal-to-noise ratio (PSNR)
# are used to compare reconstructed images against the originals.
def psnr(z1, z2):
    """Peak signal-to-noise ratio between two image tensors whose pixels span [0, 255].

    Computed by ``tf.image.psnr`` over the last three (height, width, channel) axes.
    """
    peak = 255
    return tf.image.psnr(z1, z2, max_val=peak)

def pixel_shuffle(scale):
    """Return a function that moves depth into ``scale`` x ``scale`` spatial blocks.

    Wraps ``tf.nn.depth_to_space``; intended for use inside a Keras ``Lambda`` layer.
    """
    def _depth_to_space(x):
        return tf.nn.depth_to_space(x, scale)
    return _depth_to_space

In [None]:
# Patch sizes in pixels: the HR crop (96) is 4x the LR crop (24).
# high_res_size fixes the discriminator input shape.
low_res_size, high_res_size = 24, 96

# Upsample block (x2 sub-pixel convolution)
def up_sample(x_in, filters_num):
    """Upsample x2: Conv 3x3 -> pixel shuffle (depth_to_space, scale 2) -> PReLU.

    The pixel shuffle folds groups of 2*2 channels into space, so the output has
    ``filters_num // 4`` channels.
    """
    conv = Conv2D(filters_num, kernel_size=3, padding='same')
    shuffle = Lambda(pixel_shuffle(scale=2))
    activation = PReLU(shared_axes=[1, 2])
    return activation(shuffle(conv(x_in)))

# res_block
def resnet_block(x_in, filters_num, momentum=0.7):
    """Residual block: (Conv3x3 -> BN -> PReLU -> Conv3x3 -> BN) + identity skip.

    ``x_in`` must already have ``filters_num`` channels for the final Add to be valid.
    """
    h = x_in
    for conv_idx in range(2):
        h = Conv2D(filters_num, kernel_size=3, padding='same')(h)
        h = BatchNormalization(momentum=momentum)(h)
        if conv_idx == 0:
            # Only the first conv is followed by a non-linearity.
            h = PReLU(shared_axes=[1, 2])(h)
    return Add()([x_in, h])

# Generator architecture
def generator_sr_resnet(filters_num=64, num_res_blocks=16):
    """Build the SRGAN generator (SRResNet) as a Keras functional model.

    Inputs are RGB images of any spatial size, scaled from [0, 255] by norm01.
    The trunk is ``num_res_blocks`` residual blocks plus a long skip connection.
    Two x2 pixel-shuffle stages upscale by 4x overall. A tanh output is mapped
    back to [0, 255] by denorm11.
    """
    inputs = Input(shape=(None, None, 3))
    h = Lambda(norm01)(inputs)

    h = Conv2D(filters_num, kernel_size=9, padding='same')(h)
    trunk_skip = h = PReLU(shared_axes=[1, 2])(h)

    for _ in range(num_res_blocks):
        h = resnet_block(h, filters_num)

    h = Conv2D(filters_num, kernel_size=3, padding='same')(h)
    h = BatchNormalization()(h)
    # Long skip connection around the whole residual trunk.
    h = Add()([trunk_skip, h])

    # Two x2 upsampling stages -> x4 overall.
    for _ in range(2):
        h = up_sample(h, filters_num * 4)

    h = Conv2D(3, kernel_size=9, padding='same', activation='tanh')(h)
    outputs = Lambda(denorm11)(h)

    return Model(inputs, outputs)


# Alias used by the training and test cells: `generator()` builds a fresh SRResNet generator.
generator = generator_sr_resnet

#Discriminator block
def discriminator_block(x_in, filters_num, strides=1, batchnorm=True, momentum=0.8):
    """Conv 3x3 (given strides) -> optional BatchNorm -> LeakyReLU(0.2)."""
    conv_out = Conv2D(filters_num, kernel_size=3, strides=strides, padding='same')(x_in)
    normed = BatchNormalization(momentum=momentum)(conv_out) if batchnorm else conv_out
    return LeakyReLU(alpha=0.2)(normed)

#Discriminator architecture
def discriminator(filters_num=64):
    """Build the SRGAN discriminator for high_res_size x high_res_size RGB patches.

    Architecture:
    - Eight conv blocks. Channel width doubles every two blocks, and every second
      block uses stride 2.
    - A 1024-unit dense layer.
    - A sigmoid "real" probability as the output.
    """
    inputs = Input(shape=(high_res_size, high_res_size, 3))
    h = Lambda(norm11)(inputs)

    # The first block has no batch norm; each width is then a stride-1 block
    # followed by a stride-2 block.
    h = discriminator_block(h, filters_num, batchnorm=False)
    h = discriminator_block(h, filters_num, strides=2)
    for width_multiplier in (2, 4, 8):
        h = discriminator_block(h, filters_num * width_multiplier)
        h = discriminator_block(h, filters_num * width_multiplier, strides=2)

    h = Flatten()(h)
    h = Dense(1024)(h)
    h = LeakyReLU(alpha=0.2)(h)
    real_prob = Dense(1, activation='sigmoid')(h)

    return Model(inputs, real_prob)


# Model Training

In [None]:
# Train the generator
# Stage 1: pre-train the generator alone with pixel-wise MSE
# (resumes from the latest checkpoint in checkpoint_dir, if any).
pre_trainer = SrganGeneratorTrainer(
    model=generator(),
    checkpoint_dir='/content/drive/MyDrive/GenAIBOOK/SRGAN/.ckpt/pre_generator',
)
pre_trainer.train(train_ds, valid_ds.take(100), steps=5000, evaluate_every=100)

# Persist the pre-trained weights; the adversarial stage starts from them.
os.makedirs('/content/drive/MyDrive/GenAIBOOK/SRGAN/weights/srgan/', exist_ok=True)
pre_trainer.model.save_weights('/content/drive/MyDrive/GenAIBOOK/SRGAN/weights/srgan/pre_generator.h5')

In [None]:
# Stage 2: adversarial fine-tuning, starting from the MSE-pre-trained generator weights.
gan_generator = generator() # fresh generator graph
gan_generator.load_weights('/content/drive/MyDrive/GenAIBOOK/SRGAN/weights/srgan/pre_generator.h5') # weights from stage 1
gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator()) # new, untrained discriminator
gan_trainer.train(train_ds, steps=500) # adversarial updates

200/500, perceptual loss = 0.1516, discriminator loss = 0.9117
400/500, perceptual loss = 0.1471, discriminator loss = 0.1896


In [None]:
# Save model weights
# Persist both networks after adversarial training; the test section reloads the generator.
# NOTE(review): plain `.h5` weight files target Keras 2; Keras 3 expects a `.weights.h5` suffix — confirm the TF/Keras version.
gan_trainer.generator.save_weights('/content/drive/MyDrive/GenAIBOOK/SRGAN/weights/srgan/gan_generator.h5')
gan_trainer.discriminator.save_weights('/content/drive/MyDrive/GenAIBOOK/SRGAN/weights/srgan/gan_discriminator.h5')

# Test Model

In [None]:
# Rebuild the generator and load the adversarially fine-tuned weights. With the load
# commented out, this cell silently replaced the trained model with an untrained one.
gan_generator = generator()
gan_generator.load_weights('/content/drive/MyDrive/GenAIBOOK/SRGAN/weights/srgan/gan_generator.h5')

# Test Visualization

In [None]:
def resolve_and_plot(lr_path, model=None):
    """Super-resolve one image file and plot it next to the low-resolution input.

    Args:
        lr_path: path of the low-resolution image to upscale.
        model: generator to apply. Defaults to the notebook-level ``gan_generator``,
            looked up at call time as before.
    """
    if model is None:
        model = gan_generator

    # The generator expects 3-channel input, so convert RGBA / grayscale / palette
    # files. The context manager releases the file handle once pixels are loaded.
    with Image.open(lr_path) as img:
        lr = np.array(img.convert('RGB'))

    gan_sr = resolve(model, tf.expand_dims(lr, axis=0))[0]

    fig, axes = plt.subplots(1, 2, figsize=(20, 10))
    for ax, picture, title in zip(axes, [lr, gan_sr], ['LR', 'SR(GAN)']):
        ax.imshow(picture)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

In [None]:
# Plot every test image in a stable (sorted) order. Skip sub-directories such as
# .ipynb_checkpoints and hidden files, which PIL cannot open as images.
test_image_dir = '/content/drive/MyDrive/GenAIBOOK/SRGAN/test_image/'
for file_name in sorted(os.listdir(test_image_dir)):
    image_path = os.path.join(test_image_dir, file_name)
    if file_name.startswith('.') or not os.path.isfile(image_path):
        continue
    resolve_and_plot(image_path)