In [None]:
# Enable mixed precision training
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')


import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, LeakyReLU, PReLU, Add, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.applications import VGG19
import numpy as np
from tensorflow.keras.backend import clear_session
import gc

In [None]:
pip install memory-profiler

In [None]:
from memory_profiler import profile

In [None]:
def residual_block(x, filters):
    """SRGAN residual block: two 3x3 conv + batch-norm stages plus an identity shortcut.

    Layout: conv -> BN -> PReLU -> conv -> BN, after which the block input is
    added back to the result. The shortcut is an identity (no projection), so
    ``x`` must already have ``filters`` channels for the ``Add`` to be valid.

    Args:
        x: Input feature tensor (assumed channels-last with ``filters`` channels).
        filters: Number of 3x3 convolution filters used in both stages.

    Returns:
        The tensor ``x + F(x)``.
    """
    shortcut = x
    out = x
    for stage in (1, 2):
        out = Conv2D(filters, kernel_size=3, strides=1, padding='same')(out)
        out = BatchNormalization(momentum=0.8)(out)
        # Only the first stage is followed by a non-linearity; the second stage
        # feeds the residual sum directly.
        if stage == 1:
            out = PReLU(shared_axes=[1, 2])(out)
    return Add()([shortcut, out])


In [None]:
def build_generator(n_residual_blocks=16, n_upsample_blocks=2):
    """Build the SRGAN generator (SRResNet-style).

    Architecture:
      * 9x9 conv (64 filters) + PReLU feature extractor.
      * ``n_residual_blocks`` residual blocks (see ``residual_block``).
      * 3x3 conv + batch norm, then a long skip connection that adds back the
        features produced *before* the residual blocks (the first conv/PReLU).
      * ``n_upsample_blocks`` stages of 2x ``UpSampling2D`` (nearest-neighbour by
        default) followed by a 3x3 conv with 128 filters + PReLU. Each stage
        doubles height and width, so the default of 2 gives a 4x upscale.
      * A final 9x9 conv to 3 channels with tanh, so outputs lie in [-1, 1].

    Args:
        n_residual_blocks: Number of residual blocks (16 in the SRGAN paper).
        n_upsample_blocks: Number of 2x upsampling stages.

    Returns:
        A ``Model`` mapping an RGB image of any size ``(None, None, 3)`` to an RGB
        image ``2 ** n_upsample_blocks`` times larger in each spatial dimension.
    """
    input_layer = Input(shape=(None, None, 3))

    # Pre-residual feature extraction. shared_axes=[1, 2] shares each PReLU slope
    # across height and width, so the learned slopes are per-channel only and do
    # not depend on the (unknown) input size.
    x = Conv2D(64, kernel_size=9, strides=1, padding='same')(input_layer)
    x = PReLU(shared_axes=[1, 2])(x)
    pre_residual = x  # kept for the long skip connection below

    for _ in range(n_residual_blocks):
        x = residual_block(x, 64)

    # Post-residual conv + BN, then add back the pre-residual features so the
    # residual stack only has to learn a correction to them.
    x = Conv2D(64, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Add()([x, pre_residual])

    # Upsampling stages: each doubles the spatial size, then a conv refines it.
    for _ in range(n_upsample_blocks):
        x = UpSampling2D(size=2)(x)
        x = Conv2D(128, kernel_size=3, strides=1, padding='same')(x)
        x = PReLU(shared_axes=[1, 2])(x)

    # Output layer: 3 channels (RGB) with tanh -> [-1, 1]. It is kept in float32
    # so the final activation is not computed in float16 under the notebook's
    # mixed_float16 global policy.
    output_layer = Conv2D(3, kernel_size=9, strides=1, padding='same',
                          activation='tanh', dtype='float32')(x)

    return Model(input_layer, output_layer)


In [None]:
def build_discriminator(filters=(64, 128)):
    """Build the SRGAN discriminator (real vs. generated HR image classifier).

    Each entry of ``filters`` adds two 3x3 conv blocks: one at stride 1 and one
    at stride 2 (which halves height and width). Every block is conv -> BN ->
    LeakyReLU(0.2), except the very first block, which skips batch norm. The
    default ``(64, 128)`` is the reduced network this notebook trains with;
    ``(64, 128, 256, 512)`` gives the full-depth discriminator from the paper.

    Args:
        filters: Sequence of filter counts, one per stride-1/stride-2 block pair.

    Returns:
        A ``Model`` mapping an RGB image of any size to a sigmoid score of shape
        ``(batch, 1)``.
    """
    def conv_block(x, num_filters, strides=1, bn=True):
        x = Conv2D(num_filters, kernel_size=3, strides=strides, padding='same')(x)
        if bn:
            x = BatchNormalization(momentum=0.8)(x)
        x = LeakyReLU(alpha=0.2)(x)
        return x

    input_layer = Input(shape=(None, None, 3))
    x = input_layer
    for i, num_filters in enumerate(filters):
        x = conv_block(x, num_filters, bn=i > 0)
        x = conv_block(x, num_filters, strides=2)

    # Global average pooling (instead of Flatten) gives a fixed-length vector even
    # though the spatial size of the Input is unknown (None, None, 3).
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(1024)(x)
    x = LeakyReLU(alpha=0.2)(x)
    # Keep the final sigmoid in float32 so it is not computed in float16 under
    # the notebook's mixed_float16 global policy.
    output_layer = tf.keras.layers.Dense(1, activation='sigmoid', dtype='float32')(x)

    return Model(input_layer, output_layer)


In [None]:
def build_vgg(layer_name="block5_conv4"):
    """Build a frozen VGG19 feature extractor for the perceptual (content) loss.

    VGG19's ImageNet weights expect inputs prepared by
    ``tf.keras.applications.vgg19.preprocess_input``: pixel values in [0, 255],
    converted to BGR and mean-centred. The returned model therefore accepts RGB
    images in [-1, 1], the range of the generator's tanh output. It maps them to
    [0, 255], applies that preprocessing, and returns the activations of
    ``layer_name``.

    Args:
        layer_name: VGG19 layer whose output is used as the feature map. The
            default ``"block5_conv4"`` corresponds to the "VGG54" features of
            the SRGAN paper.

    Returns:
        A non-trainable ``Model`` mapping ``(batch, H, W, 3)`` images in [-1, 1]
        to the chosen layer's feature maps.
    """
    vgg = VGG19(weights="imagenet", include_top=False)
    vgg.trainable = False
    feature_model = Model(inputs=vgg.input, outputs=vgg.get_layer(layer_name).output)

    images = Input(shape=(None, None, 3))
    pixels = (images + 1.0) * 127.5  # [-1, 1] -> [0, 255]
    preprocessed = tf.keras.applications.vgg19.preprocess_input(pixels)
    model = Model(inputs=images, outputs=feature_model(preprocessed))
    model.trainable = False
    return model

In [None]:
def compile_srgan(generator, discriminator, vgg):
    """Compile the discriminator, then build and compile the combined SRGAN model.

    The combined model maps an LR image to ``[validity, features]``, where
    ``validity = discriminator(generator(lr))`` and
    ``features = vgg(generator(lr))``. It is trained with binary cross-entropy
    on ``validity`` (weight 1e-3) plus MSE on ``features`` (weight 1). Both
    models use Adam with learning_rate=2e-4 and beta_1=0.5.

    NOTE(review): setting ``discriminator.trainable = False`` *after* compiling
    the discriminator is the classic tf.keras 2 GAN pattern. It relies on
    ``trainable`` being captured when ``compile()`` runs, so the discriminator
    still trains via its own ``train_on_batch`` but is frozen inside the
    combined model. Verify the discriminator still updates if running on Keras 3.

    Returns:
        The compiled combined model.
    """
    def _adam():
        return tf.keras.optimizers.Adam(0.0002, 0.5)

    discriminator.compile(
        optimizer=_adam(),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )

    lr_images = Input(shape=(None, None, 3))
    sr_images = generator(lr_images)
    sr_features = vgg(sr_images)

    # Freeze the discriminator for the combined model only (see docstring).
    discriminator.trainable = False
    sr_validity = discriminator(sr_images)

    combined = Model(inputs=lr_images, outputs=[sr_validity, sr_features])
    combined.compile(
        loss=['binary_crossentropy', 'mse'],
        loss_weights=[1e-3, 1],
        optimizer=_adam(),
    )
    return combined


In [11]:
@profile
def train_srgan(generator, discriminator, srgan_model, epochs, batch_size, train_data,
                vgg=None, save_every=200):
    """Train the SRGAN with alternating discriminator / generator updates.

    Each epoch runs ``len(train_data) // batch_size`` steps. A step does three things:
      1. Samples a random batch of (LR, HR) pairs, drawn with replacement.
      2. Trains the discriminator on real HR images (label 1) and on generated
         images (label 0).
      3. Trains the generator through ``srgan_model`` to make the discriminator
         predict 1 and to match the VGG features of the real HR images.
    The losses of the last step are printed at the end of each epoch.

    NOTE(review): memory_profiler's ``@profile`` traces execution line by line,
    which slows long runs considerably; consider removing it for full training.

    Args:
        generator: LR -> HR generator model.
        discriminator: Compiled discriminator model.
        srgan_model: Compiled combined model from ``compile_srgan``, with
            outputs ``[validity, vgg_features]``.
        epochs: Number of epochs to run.
        batch_size: Pairs per step; must be between 1 and ``len(train_data)``.
        train_data: Sequence of ``(lr_image, hr_image)`` array pairs.
        vgg: Feature extractor for the target HR features. When omitted, the
            notebook-level ``vgg`` global is used, as before.
        save_every: Save the generator to ``srgan_generator_epoch_{N}.keras``
            every ``save_every`` epochs; ``None`` or 0 disables saving.

    Raises:
        ValueError: If ``batch_size`` is not positive or exceeds the number of
            training pairs (no step would run and no losses would exist).
    """
    if vgg is None:
        # Backward compatibility: fall back to the module-level model.
        vgg = globals()["vgg"]
    if batch_size <= 0 or len(train_data) < batch_size:
        raise ValueError(
            f"batch_size must be between 1 and len(train_data)={len(train_data)}, "
            f"got {batch_size}"
        )

    steps_per_epoch = len(train_data) // batch_size
    # Target labels are identical for every step, so allocate them once.
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # clear_session() is deliberately not called here. It resets Keras'
        # global state and is meant for use between model constructions. It does
        # not free models that are still referenced and in use.
        gc.collect()
        for _ in range(steps_per_epoch):
            # Select a random batch of images (with replacement).
            idx = np.random.randint(0, len(train_data), batch_size)
            imgs_lr, imgs_hr = zip(*[train_data[i] for i in idx])
            imgs_lr = np.array(imgs_lr)
            imgs_hr = np.array(imgs_hr)

            # Call the models directly instead of .predict(). predict() is meant
            # for large batched datasets, not per-step calls inside a loop.
            generated_hr = generator(imgs_lr, training=False)

            # Train the discriminator on real and generated images.
            d_loss_real = discriminator.train_on_batch(imgs_hr, real)
            d_loss_fake = discriminator.train_on_batch(generated_hr, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Train the generator (discriminator frozen inside srgan_model).
            image_features = vgg(imgs_hr, training=False)
            g_loss = srgan_model.train_on_batch(imgs_lr, [real, image_features])

        print(f"Epoch: {epoch+1}, D Loss: {d_loss}, G Loss: {g_loss}")

        if save_every and (epoch + 1) % save_every == 0:
            generator.save(f'srgan_generator_epoch_{epoch+1}.keras')


In [None]:
# Build the three networks used by the SRGAN (built left to right).
generator, discriminator, vgg = build_generator(), build_discriminator(), build_vgg()

# Print the architectures of the two trainable networks.
for net in (generator, discriminator):
    net.summary()

# Wire them into the combined generator-training model.
srgan_model = compile_srgan(generator, discriminator, vgg)


In [None]:
# Show the combined model: LR input -> generator -> [discriminator validity, VGG features].
srgan_model.summary()

In [None]:
import os
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array, load_img

def load_images(lr_dir, hr_dir, img_size, max_images=855, scale=4):
    """Load paired low-/high-resolution training images.

    Files in each directory are sorted by name and paired by position: the i-th
    LR file is matched with the i-th HR file. The two directories are assumed to
    hold corresponding images whose names sort identically. ``zip`` silently
    drops any unmatched surplus files.

    Scaling follows the SRGAN paper. LR images are scaled to [0, 1]. HR images
    are scaled to [-1, 1], the same range as the generator's tanh output, so the
    discriminator and the VGG loss compare real and generated images on one scale.

    Args:
        lr_dir: Directory of low-resolution images.
        hr_dir: Directory of high-resolution images.
        img_size: ``(height, width)`` that LR images are resized to.
        max_images: Use at most this many pairs (after sorting); ``None`` for all.
        scale: Upscaling factor; HR images are resized to ``img_size * scale``.
            It should match the generator's upscale factor.

    Returns:
        A list of ``(lr_image, hr_image)`` float arrays of shape ``(h, w, 3)``
        and ``(h * scale, w * scale, 3)``.
    """
    lr_names = sorted(os.listdir(lr_dir))
    hr_names = sorted(os.listdir(hr_dir))
    if max_images is not None:
        lr_names = lr_names[:max_images]
        hr_names = hr_names[:max_images]

    hr_size = (img_size[0] * scale, img_size[1] * scale)
    train_data = []
    for lr_name, hr_name in zip(lr_names, hr_names):
        lr_image = img_to_array(load_img(os.path.join(lr_dir, lr_name), target_size=img_size))
        hr_image = img_to_array(load_img(os.path.join(hr_dir, hr_name), target_size=hr_size))
        # LR -> [0, 1]; HR -> [-1, 1] to match the generator's tanh output range.
        train_data.append((lr_image / 255.0, hr_image / 127.5 - 1.0))

    return train_data


In [None]:
# Low-resolution target size; HR images are loaded at a fixed multiple of it.
img_size = (64, 64)  # Adjust as needed

# Directories holding the paired low- and high-resolution images.
lr_dir, hr_dir = (
    '/kaggle/input/raw-data/Raw Data/low_res',
    '/kaggle/input/raw-data/Raw Data/high_res',
)

# Build the list of (LR, HR) training pairs.
train_data = load_images(lr_dir=lr_dir, hr_dir=hr_dir, img_size=img_size)

In [None]:
from tensorflow.keras.preprocessing.image import img_to_array, load_img, array_to_img

# Fail fast if no (LR, HR) pairs were loaded (e.g. a wrong lr_dir / hr_dir).
# Otherwise training would run zero steps and fail later with a less clear error.
assert len(train_data) > 0, "no (LR, HR) image pairs were loaded; check lr_dir / hr_dir"

# To preview the first low-resolution image, evaluate: array_to_img(train_data[0][0])
len(train_data)

In [None]:
# Train the SRGAN
import warnings

from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model

# Notes:
# - ModelCheckpoint is a fit() callback; it has no effect on the train_on_batch
#   loop inside train_srgan, which saves the generator itself every 200 epochs.
# - To resume, load a saved generator (e.g. load_model('srgan_generator_epoch_200.keras'))
#   *before* calling compile_srgan. srgan_model keeps a reference to the generator
#   it was compiled with, so reassigning `generator` here would not affect it.

# Silence warnings only for the duration of training, rather than globally for
# the rest of the kernel session.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    train_srgan(generator, discriminator, srgan_model, epochs=10000, batch_size=4, train_data=train_data)
