# Super-Resolution and Image Restoration (img2img task) with SRRGAN 

*Ciao*! Welcome to this personal project, where I add a Generative Adversarial Network (GAN) to the U-Net from my previous [notebook](https://www.kaggle.com/code/marcorosato/super-resolution-and-restoration-of-images-u-net). I also extended the degradation function: it now applies more kinds of **degradation**, with randomized parameters for **training** and fixed parameters for **validation**.
I use the **DF2K** dataset for training and validation, and **BSDS100** for testing.

The goal is to enhance a U-Net generator with an adversarial loss from a discriminator to produce sharper and more realistic images. 

**SRRGAN** stands for "Super-Resolution and Restoration GAN".

## Image Data Lifecycle
To fully understand the pipeline, which includes different functions and normalizations, it's helpful to trace the "journey" of an image. 
### 1. Training & Validation:
Training and validation are handled by the `tf.data` pipeline via the `load_and_preprocess_train` and `load_and_preprocess_eval` functions.

1. **File Path**: the process starts with a string representing the path to an image file (e.g. "/kaggle/input/df2kdata/DF2K_train_HR/000001.png")
2. **Read & Decode**:
   *  `tf.io.read_file` reads the raw binary data of the image.
   *  `tf.io.decode_png` decodes this data into a tf.Tensor of shape (H, W, 3) with pixel values in the range [0, 255].
   *  `tf.image.convert_image_dtype(img, tf.float32)` converts the data type to float32 and normalizes the pixel values to the range [0, 1].
3. **Patch Extraction and Augmentation**: a **256x256** High-Resolution (HR) patch is extracted from the full image (in training mode, augmentations are applied).
4. **Degradation**: the `degradation_pipeline_tf` function is called with the [0, 1] HR patch. All operations are performed directly on tensors:
     * The pipeline works on float32 tensors in the [0, 1] range (the JPEG step temporarily converts to uint8 and back).
     * Degradations (downsampling, blur, noise, JPEG) are applied using `tf.image` and `keras_cv` operations.
     * The final Low-Resolution (LR) patch is clipped to ensure its values remain within the [0, 1] range.
       
   The output is a 128x128 LR patch as a float32 tensor in the [0, 1] range.
5. **Final Normalization for Model Input**: both the LR and HR patches are scaled to match the model's expected input and output range (the generator ends with a `tanh` activation function):
    * LR Patch (Model Input `X`): scaled from [0, 1] to [-1, 1]
    * HR Patch (Ground Truth `y_true`): scaled from [0, 1] to [-1, 1]

### 2. Inference & Visualization:
This process generates a super-resolved & restored image from the test set and displays the results.


1. **Prepare Model Input**: load an HR image ([0, 1]) and run it through the degradation pipeline (`degrade_full_image`) to get an LR image ([0, 1]). This is then normalized to [-1, 1] to serve as the input for the generator model.
2. **Model Prediction**: `model.predict()` is called on the [-1, 1] normalized LR input. Because the generator ends with a `tanh` activation, the model's output (the predicted SR image) is in the range [-1, 1].
3. **Denormalization for Visualization**: `Matplotlib` can directly render float arrays with values in the [0, 1] range. To prepare images for plotting, the `denorm` function is used. It converts tensors from the [-1, 1] range to the [0, 1] range.
4. **Display**: done by `plt.imshow()` to show a comparison of three images:
    * LR input, converted from [-1, 1] to [0, 1] by the denorm function.
    * SR predicted output, converted from [-1, 1] to [0, 1] by the denorm function.
    * HR ground truth, which was already kept in the [0, 1] range.
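
The range conversions traced above can be sketched in NumPy. This is a minimal illustration rather than the notebook's own code: `to_model_range` is a hypothetical helper name, and `denorm` is assumed here to be the inverse affine map with clipping (the notebook's actual `denorm` may differ in detail).

```python
import numpy as np

def to_model_range(x01):
    """Map values from [0, 1] to [-1, 1], the range of a tanh output."""
    return x01 * 2.0 - 1.0

def denorm(x11):
    """Assumed inverse: map [-1, 1] back to [0, 1], clipping any overshoot."""
    return np.clip((x11 + 1.0) / 2.0, 0.0, 1.0)

hr = np.array([0.0, 0.25, 0.5, 1.0], dtype=np.float32)
model_in = to_model_range(hr)   # what the generator would receive
restored = denorm(model_in)     # what Matplotlib would be given
```

The round trip returns the original values, which is why the LR input and SR output can be plotted next to an HR image that never left the [0, 1] range.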

In [None]:
# --- Install

!pip install tqdm keras-cv protobuf==3.20.3 -q

In [None]:
# --- LPIPS for Tensorflow
# The standard PyTorch implementation is not used here because its dependencies conflict with the existing CUDA/cuDNN setup.
# If you have any suggestions on how to use it here, please let me know!

import sys

!git clone https://github.com/Image-X-Institute/lpips_torch2tf.git

sys.path.append("/kaggle/working/lpips_torch2tf")

In [None]:
# --- Libraries

import os
import shutil
import glob
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import random
import math
import tensorflow as tf
import keras_cv
from types import SimpleNamespace
from matplotlib.gridspec import GridSpec
from dev_src.loss_fns import lpips_base_tf
from tqdm import tqdm
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
from tensorflow.keras.models import load_model
from tensorflow.keras import metrics
from tensorflow.keras.saving import register_keras_serializable
from tensorflow.keras.layers import SpectralNormalization
from tensorflow.keras.callbacks import CSVLogger
from sklearn.model_selection import train_test_split

In [None]:
### --- 1 Configuration

# 1.1 Seeds
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)

# 1.2 Set the global policy for mixed precision
policy = tf.keras.mixed_precision.Policy("mixed_float16")
tf.keras.mixed_precision.set_global_policy(policy)

# 1.3 Patch dimensions for High-Resolution (HR) and Low-Resolution (LR) images
UPSCALE_FACTOR = 2                          # Scale factor (2x super-resolution)
PATCH_H = 256                               # Height of the HR image patches
PATCH_W = 256                               # Width of the HR image patches
PATCH_LR_H = PATCH_H // UPSCALE_FACTOR      # Height of the LR image patches (128)
PATCH_LR_W = PATCH_W // UPSCALE_FACTOR      # Width of the LR image patches (128)
CH = 3                                      # RGB

# 1.4 Training
BATCH_SIZE      = 64
EPOCHS          = 100
SAMPLE_SIZE     = None                 # None = full data

# 1.5 Loss function weights 
PERCEPTUAL_WEIGHT_PRE = 1e-3

PIXEL_WEIGHT_GAN = 20.0
PERCEPTUAL_WEIGHT_GAN = 5e-2
GLOBAL_GAN_WEIGHT = 5e-2

# Data Loading and Preprocessing
This section prepares the dataset for the model. The pipeline creates paired low-resolution (LR) and high-resolution (HR) image patches for training.

The key stages are:
1.  **Path Collection and Splitting:** Gather the file paths of all images and split them into training, validation, and test sets.
2.  **Dataset Sanity Check:** Verify the minimum dimensions of the images to ensure the patch extraction is valid.
3.  **Degradation Function:** A custom function simulates real-world image degradation to create realistic LR images from HR patches. This includes blur, downsampling, noise and JPEG compression.
4.  **"tf.data" Pipeline:** An input pipeline using "tf.data.Dataset" that handles image loading, augmentation, degradation, batching, and prefetching.

## Image Path Collection and Dataset Splitting
Scan the dataset directory to collect the file paths of all available PNG images. 
-   **Training Set:** DF2K_train_HR.
-   **Validation Set:** DF2K_valid_HR.
-   **Test Set:** BSDS100.

In [None]:
# --- 2 Collection of image paths

train_folder = "/kaggle/input/df2kdata/DF2K_train_HR"
vali_folder = "/kaggle/input/df2kdata/DF2K_valid_HR"
test_folder = "/kaggle/input/super-resolution-benchmarks/BSDS100/BSDS100"

train_paths = []
val_paths = []
test_paths = []

train = os.path.join(train_folder, "*.png")             # All the images are in PNG format
for f in sorted(glob.glob(train)):                      # Sorted so the order is reproducible
    if os.path.getsize(f) > 0:                          # Skip empty files
        train_paths.append(f)

validation = os.path.join(vali_folder, "*.png")
for f in sorted(glob.glob(validation)):
    if os.path.getsize(f) > 0:
        val_paths.append(f)

test = os.path.join(test_folder, "*.png")
for f in sorted(glob.glob(test)):
    if os.path.getsize(f) > 0:
        test_paths.append(f)

print(f"Training images from DF2K train: {len(train_paths)}")
print(f"Validation images from DF2K val: {len(val_paths)}")
print(f"Test images from BSDS100: {len(test_paths)}")

## Dataset Sanity Check
Perform a sanity check on the image dimensions across the entire DF2K dataset. This step iterates through all images to find the minimum height and width. This is done to confirm that my chosen HR patch size (256x256) can be safely extracted from every image.

In [None]:
# --- 3 Minimum height and width between all images.
# This step helps to understand the variability in image sizes and ensure that the chosen patch size (256x256) is feasible
# by checking if all images are at least as large as the desired patch size.

min_height = np.inf
min_width  = np.inf

for path in train_paths + val_paths:
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)      # Read image by keeping its original channel count
    if img is None:
        print(f"Can't read this image: {path}")
        continue

    h, w = img.shape[:2]                              # Get height and width
    if h < min_height:
        min_height = h
    if w < min_width:
        min_width = w

if min_height == np.inf or min_width == np.inf:
    print("No valid image")
else:
    print(f"Minimum height across all images : {min_height}")
    print(f"Minimum width across all images : {min_width}")

## Degradation Function - Blind Restoration Task
To train the super-resolution model, we need pairs of low-resolution (LR) and high-resolution (HR) images. These are produced by `degradation_pipeline_tf`, a function that synthesizes an LR image from an HR source tensor (in the range [0, 1]).

A key design principle of this pipeline is its dual-mode operation (controlled by an `is_training` flag):


* Training Mode (Stochastic): When training, the pipeline applies a sequence of degradations with randomized parameters and, for the post-downsampling steps, a randomized order. Exposing the model to varied degradations is intended to improve generalization.

* Validation/Testing Mode (Deterministic): During evaluation, the pipeline uses a fixed set of parameters and a fixed order, so that performance is measured under the same degradation settings across epochs. Note that the Gaussian noise itself is still drawn at random on each call; only its standard deviation is fixed.


The degradation sequence is as follows:


1. **Blur (First-Order Degradation):** This step simulates physical lens imperfections.
    * *Training:* A blur type is chosen randomly (50% chance each):
        * Gaussian Blur: KerasCV's `RandomGaussianBlur` with `kernel_size=(3, 7)` and a sigma factor sampled from 0.2 to 1.2. KerasCV documents `kernel_size` as the kernel's x and y dimensions, so the kernel shape is fixed and only sigma is random.
        * Box Blur: Implemented via a depthwise convolution with a random odd-sized kernel (3x3 or 5x5).
    * *Validation:* A single, fixed Gaussian blur is applied (5x5 kernel, sigma 1.0).
2. **Downsampling:** The blurred HR image is downscaled by a factor of 2.
    * *Training:* The interpolation method is chosen randomly from Bilinear, Bicubic, and Area to simulate different downsampling algorithms.
    * *Validation:* A fixed Area interpolation is used, which is generally preferred for downscaling as it reduces aliasing.
3. **Post-Downsampling Degradations (Second-Order):** After downsampling, a sequence of noise and compression artifacts is applied. The order of these operations is also randomized during training.
    * *Training:*
       * Order: Randomly applies either (Noise → JPEG) or (JPEG → Noise).
       * Noise: Randomly adds one of two types:
            * Gaussian Noise: With a random standard deviation (from 1/255 to 15/255).
            * Poisson Noise: Simulates shot noise with a random scale factor (from 10.0 to 60.0).
       * JPEG Compression: Applied with an 80% probability. If applied, the quality is a random integer between 60 and 95.
    * *Validation:*
        * Order: A fixed order is used: Noise → JPEG.
        * Noise: Fixed Gaussian noise is added (standard deviation of 10/255).
        * JPEG Compression: Fixed JPEG compression is applied with a quality of 75.
4. **Clipping:** The final LR image's pixel values are clipped to the [0, 1] range to ensure a valid output.
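
The dual-mode behaviour described above can be summarized as a parameter sampler. This is a simplified, framework-free sketch: `sample_degradation_params` and its dictionary keys are hypothetical names, the ranges are copied from the list above (blur kernel settings are omitted for brevity), and the actual pipeline samples with TensorFlow ops inside `tf.cond` rather than with Python's `random` module.

```python
import random

def sample_degradation_params(is_training, rng=random):
    """Return the degradation settings for one patch (hypothetical helper)."""
    if not is_training:
        # Validation/test: fixed parameters and a fixed order.
        return {
            "blur": "gaussian",
            "resize": "area",
            "order": ("noise", "jpeg"),
            "noise": ("gaussian", 10 / 255),
            "jpeg_quality": 75,
        }
    # Training: every choice is drawn at random.
    if rng.random() < 0.5:
        noise = ("gaussian", rng.uniform(1 / 255, 15 / 255))  # std dev
    else:
        noise = ("poisson", rng.uniform(10.0, 60.0))          # scale factor
    return {
        "blur": rng.choice(["gaussian", "box"]),
        "resize": rng.choice(["bilinear", "bicubic", "area"]),
        "order": rng.choice([("noise", "jpeg"), ("jpeg", "noise")]),
        "noise": noise,
        # JPEG is applied with 80% probability; None means "skipped".
        "jpeg_quality": rng.randint(60, 95) if rng.random() < 0.8 else None,
    }

val_params = sample_degradation_params(is_training=False)
train_params = sample_degradation_params(is_training=True, rng=random.Random(0))
```

Calling the function twice with `is_training=False` returns identical settings, whereas each training call may differ in blur type, interpolation method, noise type and strength, JPEG quality, and operation order.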

In [None]:
# --- 4 Degradation Function for TRAINING/VALIDATION
# This function creates a low-resolution image from a high-resolution one.
# It takes an HR image tensor in the [0, 1] range and returns a degraded LR version, also in [0, 1].

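# Random blur initialization
# Note: KerasCV documents `kernel_size` as the kernel's (x, y) dimensions, not a
# range to sample from, so (3, 7) gives a fixed 3x7 kernel; only sigma is random.
random_blur_layer = keras_cv.layers.RandomGaussianBlur(
    kernel_size = (3, 7),   # Kernel dimensions (x, y)
    factor = (0.2, 1.2)     # Sigma is sampled uniformly from this range
)

# Fixed blur initialization
fixed_blur_layer = keras_cv.layers.RandomGaussianBlur(
    kernel_size=(5, 5),     # A fixed 5x5 kernel
    factor=(1.0, 1.0)       # Equal bounds pin sigma to 1.0 (a bare float would sample from [0, 1.0])
)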

# Box blur
box_blur_kernel_size = (3, 5)


@tf.function
def degradation_pipeline_tf(hr_patch, is_training):
    """
    Applies a degradation pipeline that is random for training and
    deterministic for validation.
    """
    # Ensure is_training is a TensorFlow boolean tensor for tf.cond
    is_training = tf.convert_to_tensor(is_training, dtype=tf.bool)


    # --- 4.1 Blur 
    # 4.1.1 Gaussian blur (lens softness)
    def apply_random_gaussian_blur(img):
        # The KerasCV layer needs a batch dimension (added here, removed afterwards)
        original_dtype = img.dtype
        img_batched = tf.expand_dims(img, axis=0)
        blurred_img = random_blur_layer(img_batched, training=True)
        return tf.cast(blurred_img[0], dtype=original_dtype)
    
    # 4.1.2 Box blur
    def apply_random_box_blur(image):
    
        # 1. Generate the random kernel size as a TensorFlow tensor
        kernel_size = tf.random.uniform((), 
                                    minval=box_blur_kernel_size[0], 
                                    maxval=box_blur_kernel_size[1], 
                                    dtype=tf.int32)
        # Ensure the kernel size is odd
        kernel_size = kernel_size + (1 - kernel_size % 2) 
    
        # 2. Get the number of channels from the image shape
        channels = tf.shape(image)[-1]
    
        # 3. Create the box blur kernel dynamically
        # The value for each pixel in the kernel should be 1 / (size*size) to average
        kernel_value = 1.0 / tf.cast(kernel_size * kernel_size, image.dtype)
        # Create the kernel tensor for depthwise convolution
        # Shape is [height, width, in_channels, channel_multiplier=1]
        kernel = tf.fill([kernel_size, kernel_size, channels, 1], kernel_value)
    
        # Add a batch dimension to the input image for the convolution op
        image_batched = tf.expand_dims(image, axis=0)
    
        # 4. Apply the convolution
        blurred_image = tf.nn.depthwise_conv2d(
            image_batched,
            kernel,
            strides=[1, 1, 1, 1],
            padding="SAME"
        )
    
        # Remove the batch dimension before returning
        return tf.squeeze(blurred_image, axis=0)
    
    
    # 4.1.3 Fixed Gaussian blur
    def apply_fixed_blur(img):
        original_dtype = img.dtype
        img_batched = tf.expand_dims(img, axis=0)
        blurred_img = fixed_blur_layer(img_batched, training=False) # Ensure it's deterministic
        return tf.cast(blurred_img[0], dtype=original_dtype)
    
    # 4.1.4 Choice of what blur to apply
    def apply_randomized_blur(img):
        blur_type = tf.random.uniform(())
        return tf.cond(blur_type < 0.5,     # 50% gaussian blur and 50% box blur
                       lambda: apply_random_gaussian_blur(img),
                       lambda: apply_random_box_blur(img))
    
    # 4.1.5 Apply the blur to the HR patch
    hr_patch = tf.cond(is_training,
                       lambda: apply_randomized_blur(hr_patch),  # If training is True
                       lambda: apply_fixed_blur(hr_patch))       # If training is False
    
    
    
    
    
    # --- 4.2 Downsample the image by the upscale factor (random method in training, area interpolation in validation)
    lr_h = PATCH_H // UPSCALE_FACTOR
    lr_w = PATCH_W // UPSCALE_FACTOR
    
    def random_resize(img):
        rand_interpo = tf.random.uniform((), 0, 3, dtype=tf.int32)
        # Use the pre-calculated static dimensions
        return tf.switch_case(rand_interpo, {
            0: lambda: tf.image.resize(img, [lr_h, lr_w], method=tf.image.ResizeMethod.BILINEAR),
            1: lambda: tf.image.resize(img, [lr_h, lr_w], method=tf.image.ResizeMethod.BICUBIC),
            2: lambda: tf.image.resize(img, [lr_h, lr_w], method=tf.image.ResizeMethod.AREA),
        })
    
    def fixed_resize(img):
        # Use a consistent method for validation, with the pre-calculated static dimensions
        return tf.image.resize(img, [lr_h, lr_w], method=tf.image.ResizeMethod.AREA)
    
    lr_patch = tf.cond(is_training,
                       lambda: random_resize(hr_patch),
                       lambda: fixed_resize(hr_patch))
    
    
    
    # --- 4.3 Random order Noise (Gaussian and Poisson) and JPEG
    # --- 4.3.1 Noise for training
    def apply_random_noise(img):
        # 4.3.1.1 Gaussian
        def add_gaussian_noise():
            noise_std = tf.random.uniform([], 1/255.0, 15/255.0)
            noise = tf.random.normal(tf.shape(img), 0.0, noise_std, dtype=img.dtype)
            return img + noise
        # 4.3.1.2 Poisson
        def add_poisson_noise():
            noise_scale = tf.random.uniform([], 10.0, 60.0)
            noisy_image = tf.random.poisson(shape=[], lam=img * noise_scale) / noise_scale
            return tf.cast(noisy_image, img.dtype)
    
        return tf.cond(tf.random.uniform(()) < 0.5,
                       true_fn=add_gaussian_noise,
                       false_fn=add_poisson_noise)
    # 4.3.2 Noise for val
    def apply_fixed_noise(img):
        noise_std = tf.constant(10/255.0)
        noise = tf.random.normal(tf.shape(img), 0.0, noise_std, dtype=img.dtype)
        return img + noise
    
    # --- 4.3.3 JPEG 
    # 4.3.3.1 JPEG for training
    def apply_random_jpeg(img):
        img_float32 = tf.cast(img, tf.float32)
        img_uint8 = tf.image.convert_image_dtype(img_float32, tf.uint8, saturate=True)
        def apply_jpeg_with_known_shape():
            jpeg_img = tf.image.random_jpeg_quality(img_uint8, min_jpeg_quality=60, max_jpeg_quality=95)
            return tf.ensure_shape(jpeg_img, img_uint8.shape)
    
        img_jpeg = tf.cond(tf.random.uniform([], 0.0, 1.0) < 0.8,
                           true_fn=apply_jpeg_with_known_shape,
                           false_fn=lambda: img_uint8)
        return tf.image.convert_image_dtype(img_jpeg, img.dtype)
    
    # 4.3.3.2 JPEG for val
    def apply_fixed_jpeg(img):
        jpeg_quality = 75
        image_float32 = tf.cast(img, tf.float32)
        img_uint8 = tf.image.convert_image_dtype(image_float32, tf.uint8, saturate=True)
        img_jpeg = tf.image.adjust_jpeg_quality(img_uint8, jpeg_quality=jpeg_quality)
        return tf.image.convert_image_dtype(img_jpeg, img.dtype)
        
    # --- 4.4 Logic for random order
    def apply_post_degradations_randomized(img):
        # 4.4.1 Noise and then Jpeg
        def order_noise_then_jpeg(x):
            x = apply_random_noise(x)
            x = apply_random_jpeg(x)
            return x
        # 4.4.2 Jpeg and then Noise
        def order_jpeg_then_noise(x):
            x = apply_random_jpeg(x)
            x = apply_random_noise(x)
            return x
        
        # 4.4.3 Pick one of the two orders with 50% probability each
        return tf.cond(tf.random.uniform(()) < 0.5,
                       lambda: order_noise_then_jpeg(img),
                       lambda: order_jpeg_then_noise(img))
    
    def apply_post_degradations_fixed(img):
        # Fixed order for validation
        img = apply_fixed_noise(img)
        img = apply_fixed_jpeg(img)
        return img
    
    # 4.4.4 Apply post downsample degradation
    lr_patch = tf.cond(is_training,
                       lambda: apply_post_degradations_randomized(lr_patch),
                       lambda: apply_post_degradations_fixed(lr_patch))
    
    
    
    
    
    # 4.5 Clip the final result to ensure all pixel values are in the valid [0, 1] range
    lr_patch_final = tf.clip_by_value(lr_patch, 0.0, 1.0)
    
    return lr_patch_final

In [None]:
# --- 4.6 Degradation Function for TEST: Inference & Final Visualization

@tf.function
def degrade_full_image(hr_image):
    """
    Applies a deterministic degradation pipeline to a full-sized HR image,
    preserving its aspect ratio.
    """
    # 1. Calculate target LR dimensions based on the input HR image shape
    hr_shape = tf.shape(hr_image)
    lr_h = hr_shape[0] // UPSCALE_FACTOR
    lr_w = hr_shape[1] // UPSCALE_FACTOR

    # 2. Apply the same FIXED blur as in the validation pipeline
    def apply_fixed_blur(img):
        original_dtype = img.dtype
        img_batched = tf.expand_dims(img, axis=0)
        blurred_img = fixed_blur_layer(img_batched, training=False)
        return tf.cast(blurred_img[0], dtype=original_dtype)
    
    hr_blurred = apply_fixed_blur(hr_image)

    # 3. Downsample using the CALCULATED dimensions to preserve aspect ratio
    lr_patch = tf.image.resize(hr_blurred, [lr_h, lr_w], method=tf.image.ResizeMethod.AREA)

    # 4. Apply the same FIXED post-degradations (noise -> JPEG)
    def apply_fixed_noise(img):
        noise_std = tf.constant(10/255.0)
        noise = tf.random.normal(tf.shape(img), 0.0, noise_std, dtype=img.dtype)
        return img + noise
    
    def apply_fixed_jpeg(img):
        jpeg_quality = 75
        image_float32 = tf.cast(img, tf.float32)
        img_uint8 = tf.image.convert_image_dtype(image_float32, tf.uint8, saturate=True)
        img_jpeg = tf.image.adjust_jpeg_quality(img_uint8, jpeg_quality=jpeg_quality)
        return tf.image.convert_image_dtype(img_jpeg, img.dtype)

    lr_patch = apply_fixed_noise(lr_patch)
    lr_patch = apply_fixed_jpeg(lr_patch)
    
    # 5. Clip the final result
    lr_patch_final = tf.clip_by_value(lr_patch, 0.0, 1.0)
    
    return lr_patch_final

In [None]:
# --- 5 Visualization of the degradation function

%matplotlib inline

# 5.1 Select a random image from train
random_path = random.choice(train_paths)
hr_full = tf.image.convert_image_dtype(tf.io.decode_png(tf.io.read_file(random_path), channels=CH), tf.float32)
hr_patch = tf.image.resize_with_crop_or_pad(hr_full, PATCH_H, PATCH_W)

# 5.2 Apply  degradation
lr_patch_degraded = degradation_pipeline_tf(hr_patch, is_training=tf.constant(True))

plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.title("HR full [0, 1]")
plt.imshow(hr_full.numpy().astype('float32'))
plt.axis("off")
plt.subplot(1, 3, 2)
plt.title("HR Patch [0, 1]")
plt.imshow(hr_patch.numpy().astype('float32'))
plt.axis("off")
plt.subplot(1, 3, 3)
plt.title("LR Patch [0, 1] (degraded)")
plt.imshow(lr_patch_degraded.numpy().astype('float32'))
plt.axis("off")
plt.tight_layout()
plt.show()

## "tf.data" Preprocessing Function

The `load_and_preprocess_train` and `load_and_preprocess_eval` functions form the core of the data pipeline. The first function processes the training data with augmentations, while the second handles validation/test data with deterministic cropping.

 1.  **Load and Normalize to [0, 1]:** Reads an image file and converts it to float32 in the [0, 1] range.
 2.  **Patch Extraction and Augmentation:** Extracts a 256x256 HR patch ([0, 1]): a random crop followed by augmentations in training, a center crop in validation/test.
 3.  **Degradation:** The HR patch is passed to the `degradation_pipeline_tf` function to generate the corresponding LR patch ([0, 1]).
 4.  **Final Normalization to [-1, 1]:** Both the LR and HR patches are normalized to the [-1, 1] range, which is the format the model expects for training.

In [None]:
# --- 6 Data loading and preprocessing functions

# --- This function is ONLY for the training dataset
@tf.function
def load_and_preprocess_train(path):
    # 6.1 Read and decode HR image to [0, 1]
    img = tf.io.read_file(path)
    img = tf.io.decode_png(img, channels=CH)
    img = tf.image.convert_image_dtype(img, tf.float32)


    # --- DATA AUGMENTATION ---
    
    # 6.2 Extract HR 256×256 patch and apply augmentations (still [0, 1])
    # Crop
    hr_patch = tf.image.random_crop(img, [PATCH_H, PATCH_W, CH])
    # Horizontal Flip         
    hr_patch = tf.image.random_flip_left_right(hr_patch)
    # Vertical Flip
    hr_patch = tf.image.random_flip_up_down(hr_patch)
    # Rotation (0, 90, 180, 270 degrees)
    k = tf.random.uniform([], 0, 4, tf.int32)
    hr_patch = tf.image.rot90(hr_patch, k)
    # --- Colour Augmentations 
    hr_patch = tf.image.random_brightness(hr_patch, max_delta=0.05)
    hr_patch = tf.image.random_contrast(hr_patch, lower=0.9, upper=1.1)
    hr_patch = tf.image.random_saturation(hr_patch, lower=0.9, upper=1.1)
    hr_patch = tf.image.random_hue(hr_patch, max_delta=0.06)
    hr_patch = tf.clip_by_value(hr_patch, 0.0, 1.0)
    
    # --- END DATA AUGMENTATION ---
    
    
    # 6.3 Generate LR patch [0, 1] using the degradation pipeline (with is_training=True)
    lr_patch_0_1 = degradation_pipeline_tf(hr_patch, is_training=tf.constant(True))

    # 6.4 Normalize both patches to [-1, 1] for the model
    lr_patch = (lr_patch_0_1 * 2.0) - 1.0
    hr_patch_norm = (hr_patch * 2.0) - 1.0

    # 6.5 Set static shapes
    lr_patch.set_shape([PATCH_LR_H, PATCH_LR_W, CH])
    hr_patch_norm.set_shape([PATCH_H, PATCH_W, CH])

    return lr_patch, hr_patch_norm




# --- This function is ONLY for the validation/test dataset
@tf.function
def load_and_preprocess_eval(path):
    # 6.1 Read and decode HR image to [0, 1]
    img = tf.io.read_file(path)
    img = tf.io.decode_png(img, channels=CH)
    img = tf.image.convert_image_dtype(img, tf.float32)

    # 6.2 For validation/testing, just do a center crop
    hr_patch = tf.image.resize_with_crop_or_pad(img, PATCH_H, PATCH_W)

    # 6.3 Generate LR patch [0, 1] using the degradation pipeline (with is_training=False)
    lr_patch_0_1 = degradation_pipeline_tf(hr_patch, is_training=tf.constant(False))

    # 6.4 Normalize both patches to [-1, 1] for the model
    lr_patch = (lr_patch_0_1 * 2.0) - 1.0
    hr_patch_norm = (hr_patch * 2.0) - 1.0

    # 6.5 Set static shapes
    lr_patch.set_shape([PATCH_LR_H, PATCH_LR_W, CH])
    hr_patch_norm.set_shape([PATCH_H, PATCH_W, CH])

    return lr_patch, hr_patch_norm

## Building the Datasets
Using the preprocessing functions defined above, I now build the final datasets for training, validation, and testing.

In [None]:
# --- 7 Dataset creation

def make_dataset(paths, training):
    ds = tf.data.Dataset.from_tensor_slices(paths)
    if training:
        ds = ds.shuffle(buffer_size=len(paths))
        # Use the training preprocessor
        map_func = load_and_preprocess_train
    else:
        # Use the evaluation preprocessor
        map_func = load_and_preprocess_eval
        
    ds = ds.map(map_func, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(BATCH_SIZE)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds

if SAMPLE_SIZE is not None:
    train_paths = train_paths[:SAMPLE_SIZE]
    val_paths   = val_paths[:SAMPLE_SIZE//5]
    test_paths  = test_paths[:SAMPLE_SIZE//5]

# Creation
train_ds = make_dataset(train_paths, training=True)
val_ds   = make_dataset(val_paths,   training=False)
test_ds  = make_dataset(test_paths,  training=False)

print(f"Train: {len(train_paths)}, Val: {len(val_paths)}, Test: {len(test_paths)}")
print(train_ds)
print(val_ds)
print(test_ds)

# Model Architectures

### Generator (U-Net)
The U-Net is a convolutional neural network characterized by its U-shaped encoder-decoder architecture. Its defining feature is the skip connections: they connect encoder layers with the corresponding decoder layers, allowing the decoder to recover spatial information lost during downsampling.

My model uses a "hybrid" U-Net architecture. The standard encoder-decoder structure with skip connections is kept because it suits image restoration tasks: it fuses features from multiple spatial scales. To enhance its super-resolution capabilities, I replaced the traditional simple bottleneck with a deep stack of residual blocks. This design is inspired by state-of-the-art super-resolution networks and acts as a feature extractor, focusing on learning the residual high-frequency information.

Key features include:
* **Squeeze-and-Excitation (channel attention):** learns to give more weight to the most informative channels (features);
* **Spatial Attention:** highlights important spatial locations (helps the model focus on edges and textures);
* **PixelShuffle:** a sub-pixel upsampling layer, used as an alternative to `Conv2DTranspose`;
* **Global residual:** a final residual connection where a simple upsampled version of the input is added to the network's output.
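
To make the PixelShuffle rearrangement concrete, here is a small NumPy sketch of the channel-to-space mapping that `tf.nn.depth_to_space` performs, written for a single channels-last image. `pixel_shuffle_np` is a hypothetical helper for illustration only; the generator itself relies on the `PixelShuffle` layer built on `tf.nn.depth_to_space`.

```python
import numpy as np

def pixel_shuffle_np(x, block_size):
    """Rearrange (H, W, C * r * r) into (H * r, W * r, C), with r = block_size."""
    h, w, c_in = x.shape
    r = block_size
    c_out = c_in // (r * r)
    x = x.reshape(h, w, r, r, c_out)   # split the channels into an r x r grid
    x = x.transpose(0, 2, 1, 3, 4)     # interleave: (h, r, w, r, c_out)
    return x.reshape(h * r, w * r, c_out)

# A 2x2 feature map with 4 channels becomes a 4x4 map with 1 channel.
features = np.arange(2 * 2 * 4, dtype=np.float32).reshape(2, 2, 4)
upscaled = pixel_shuffle_np(features, block_size=2)
```

Each group of `r * r` channels at one low-resolution location is laid out as an `r x r` block of output pixels, so the preceding convolution learns the upsampling rather than a fixed interpolation kernel.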


### Discriminator (PatchGAN)
The discriminator's job is to distinguish between real HR images and the fake SR images produced by the generator. Instead of classifying the entire image with a single "real" or "fake" label, a **PatchGAN** discriminator classifies N x N patches of the input image. This encourages the generator to produce realistic details across the entire image, rather than just getting the global structure right. The output is a feature map where each "pixel" corresponds to a verdict on a patch of the original image.
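
As a rough illustration of the "one verdict per patch" idea, the sketch below computes the size of the discriminator's output grid after a stack of strided convolutions. The layer configuration (three stride-2 convolutions followed by a stride-1 one, all with `padding="same"`) is an assumption chosen for the example, not necessarily this notebook's discriminator, and `patchgan_output_size` is a hypothetical helper.

```python
import math

def patchgan_output_size(input_size, strides):
    """Spatial size after a stack of convolutions with "same" padding."""
    size = input_size
    for s in strides:
        size = math.ceil(size / s)   # "same" padding: out = ceil(in / stride)
    return size

# Hypothetical discriminator: three stride-2 convolutions, then a stride-1 one.
grid = patchgan_output_size(256, strides=[2, 2, 2, 1])
num_verdicts = grid * grid
```

With these assumed strides, a 256x256 input yields a 32x32 map, so the adversarial loss averages over 1,024 real/fake scores per image instead of a single one.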

In [None]:
# --- 8.1 Generator (U-Net) Architecture

NUM_HEAD_BLOCKS = 4
NUM_BOTTLENECK_BLOCKS = 5
NUM_TAIL_BLOCKS = 2

# Squeeze-and-Excitation block
# "Squeeze-and-Excitation Networks" (SENet), Jie Hu et al., CVPR 2018
def se_block(input_tensor, ratio=8):
    channels = input_tensor.shape[-1]
    se = layers.GlobalAveragePooling2D()(input_tensor)                                   # Squeeze: Aggregates spatial information into a single value per channel
    se = layers.Reshape((1, 1, channels))(se)                                            # Vector of shape (1,1,C)
    se = layers.Conv2D(channels // ratio, 1, activation = "relu", use_bias=True)(se)     # Excitation: Learn a nonlinear gating function to weight the channels.
    se = layers.Conv2D(channels, 1, activation = "sigmoid", use_bias=True)(se)           # 1x1 convolution that restores the number of channels back to the original C
    return layers.Multiply()([input_tensor, se])                                         # Apply weights to the input tensor


# Spatial Attention Block
# CBAM: Convolutional Block Attention Module by "Woo, S., Park, J., Lee, JY., Kweon, I.S. (2018)"
@register_keras_serializable()
class SpatialAttentionBlock(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.concat = layers.Concatenate(axis=3)
        self.multiply = layers.Multiply()

    def build(self, input_shape):
        self.conv = layers.Conv2D(
            filters=1, 
            kernel_size=7, 
            padding='same', 
            activation='sigmoid'
        )
        super().build(input_shape)

    def call(self, inputs):
        avg_pool = tf.reduce_mean(inputs, axis=3, keepdims=True)       # Average pool
        max_pool = tf.reduce_max(inputs, axis=3, keepdims=True)        # Max pool
        concat = self.concat([avg_pool, max_pool])                     # Concatenate
        attention_map = self.conv(concat)                              # Convolution
        return self.multiply([inputs, attention_map])

    def get_config(self):                                              # Save
        base_config = super().get_config()
        return base_config

        
# PixelShuffle
# Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network by "Wenzhe Shi et al., 2016 (CVPR)"
@register_keras_serializable()
class PixelShuffle(layers.Layer): 
    def __init__(self, block_size, **kwargs):                       # Store the configuration of the layer (runs only once)
        super().__init__(**kwargs)
        self.block_size = block_size                                # It takes the block_size (upscale factor) and saves it as self.block_size so it can be used later

    def call(self, inputs):                                         # Runs every time you pass a tensor through the layer
        return tf.nn.depth_to_space(inputs, self.block_size)        # This is the upsample part (uses the self.block_size that was saved during __init__)

    def get_config(self):                                           # Saving the Model (model.save())
        config = super().get_config()
        config.update({"block_size": self.block_size})              # It returns a dictionary of the layer's configuration
        return config


# Attention block in Head & Tail
def head_tail_block(x, filters=64):
    identity = x
    
    # Process 
    x = layers.Conv2D(filters, 3, padding="same", use_bias=False)(x)
    x = layers.GroupNormalization(groups=32)(x)
    x = layers.LeakyReLU(0.1)(x)
    
    # Attention
    x = se_block(x)
    x = SpatialAttentionBlock()(x)
    
    # Residual connection
    return layers.Add()([identity, x])





# Encoder 
def enc_block(x, filters):
    x = layers.Conv2D(filters, 3, padding = "same", use_bias=False)(x)
    x = layers.GroupNormalization(groups=32)(x)
    x = layers.LeakyReLU(0.1)(x)
    x = layers.Conv2D(filters, 3, strides = 2, padding = "same", use_bias=False)(x)      # "strides=2" -> downsample by a factor of 2
    x = layers.GroupNormalization(groups=32)(x)
    x = layers.LeakyReLU(0.1)(x)
    x = se_block(x)
    return x

# Bottleneck
def bot_block(x, filters):
    identity = x
    b = layers.Conv2D(filters, 3, padding = "same", use_bias=False)(x)
    b = layers.GroupNormalization(groups=32)(b)
    b = layers.LeakyReLU(0.1)(b)
    b = layers.SpatialDropout2D(0.05)(b)                    # Regularization
    residual_bottleneck = layers.Conv2D(filters, 3, padding = "same", use_bias=False)(b)
    residual_bottleneck = layers.GroupNormalization(groups=32)(residual_bottleneck)
    b = layers.Add()([x, residual_bottleneck])              # Residual connection
    b = layers.LeakyReLU(0.1)(b)
    b_channel_att = se_block(b)
    b_spatial_att = SpatialAttentionBlock()(b_channel_att) 
    return b_spatial_att
     
# Decoder
def dec_block(x, skip, filters):
    x = layers.Conv2D(filters * (UPSCALE_FACTOR**2), 3, padding="same", use_bias=True)(x)   # Before upsample increase the number of channels
    x = PixelShuffle(block_size=UPSCALE_FACTOR)(x)           # Upscale by Pixelshuffle
    x = layers.LeakyReLU(0.1)(x)
    x = layers.Concatenate()([x, skip])                      # Skip connection (core of the U-Net)
    x = layers.Conv2D(filters, 3, padding="same", use_bias=False)(x)
    x = layers.GroupNormalization(groups=32)(x)    
    x = layers.LeakyReLU(0.1)(x)
    x = layers.SpatialDropout2D(0.05)(x)                     # Regularization
    x = layers.Conv2D(filters, 3, padding="same", use_bias=False)(x)
    x = layers.GroupNormalization(groups=32)(x)
    x = layers.LeakyReLU(0.1)(x)
    return x





def build_unet_sr_generator():
    inputs = layers.Input((None, None, CH))  

    # Global Residual Learning (START).
    # "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" by "Christian Ledig et al., 2017 (CVPR)"
    upsampled_lr_base = layers.UpSampling2D(size = (UPSCALE_FACTOR, UPSCALE_FACTOR), interpolation = "bilinear")(inputs)

    s0_initial = layers.Conv2D(64, 3, padding ="same", use_bias=True)(inputs)
    s0_initial = layers.LeakyReLU(0.1)(s0_initial)

    # Head block
    head_features = s0_initial
    for _ in range(NUM_HEAD_BLOCKS):
        head_features = head_tail_block(head_features, filters=64)
    head_output = layers.Conv2D(64, 3, padding="same")(head_features)
    
    s0_sum = layers.Add()([head_output, s0_initial])
    s0_final = layers.LeakyReLU(0.1)(s0_sum)

    # Encoder
    e1 = enc_block(s0_final, 64)                        
    e2 = enc_block(e1, 128)                       
    e3 = enc_block(e2, 256)                       

    # Bottleneck (RCAM & others style)
    b = e3
    for _ in range(NUM_BOTTLENECK_BLOCKS):
        b = bot_block(b, 256)

    # Decoder
    d3 = dec_block(b, e2, 128)                  
    d2 = dec_block(d3, e1, 64)
    d1 = dec_block(d2, s0_final, 64)

    # Tail block
    tail_refined = d1
    for _ in range(NUM_TAIL_BLOCKS): 
        tail_refined = head_tail_block(tail_refined, filters=64)
    
    # Last upsample
    x = layers.Conv2D(CH * (UPSCALE_FACTOR**2), 3, padding="same")(tail_refined)     # Before upsample increase the number of channels 
    residual_output = PixelShuffle(block_size=UPSCALE_FACTOR)(x)                     # Upscale by Pixelshuffle 
    
    # Global Residual Learning (FINISH)
    unactivated_output = layers.Add(dtype=tf.float32)([upsampled_lr_base, residual_output])

    # The final activation
    outputs = layers.Activation("tanh", dtype=tf.float32)(unactivated_output)

    return keras.Model(inputs, outputs, name="UNet_SR_Generator")

# 3 Generator
generator_u_net = build_unet_sr_generator()
generator_u_net.summary()

### PatchGAN

Instead of classifying the entire image as real/fake (a single output), PatchGAN classifies N x N patches of the input image as real or fake.
This encourages the generator to produce realistic details across the entire image. The output is a single-channel feature map (e.g., 32x32x1 for a 256x256 input with the architecture below) in which each "pixel" is the discriminator's raw prediction about the "realness" of the patch of the input image that corresponds to that location in the output grid.
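
To make the patch idea concrete, the sketch below traces the output size and receptive field through a stack of five 4x4 convolutions with strides 2, 2, 2, 1, 1 and "same" padding, which is the layout of the discriminator built in the next code cell. The helper name `patchgan_geometry` and the 256x256 input size are assumptions for illustration; the real input size is whatever `PATCH_H` and `PATCH_W` are set to.

```python
import math

def patchgan_geometry(input_size, layers):
    """Return (output_size, receptive_field) for a stack of 'same'-padded convs.

    layers is a list of (kernel_size, stride) tuples.
    """
    size = input_size
    receptive_field = 1   # how many input pixels one output value sees
    jump = 1              # distance in input pixels between adjacent outputs
    for kernel, stride in layers:
        size = math.ceil(size / stride)          # 'same' padding: ceil(n / stride)
        receptive_field += (kernel - 1) * jump
        jump *= stride
    return size, receptive_field

# Kernel size and stride of the five convolutions in the discriminator
disc_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
out_size, rf = patchgan_geometry(256, disc_layers)
print(out_size, rf)
```

For a 256x256 input this gives a 32x32 output grid, and each output value depends on a 70x70 region of the input.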

The output layer has **NO** activation function (such as `sigmoid`). This is because the loss function, **Least Squares GAN (LSGAN)** loss, operates directly on the raw output values. It penalizes the discriminator according to how far its predictions are from the target labels (1 for real, 0 for fake) using a mean squared error objective, which tends to provide more stable gradients than traditional cross-entropy.

`SpectralNormalization` wraps a target layer and constrains the spectral norm (the largest singular value) of its weights, which helps keep the discriminator's training stable.

In [None]:
# --- 8.2 Discriminator (PatchGAN) Architecture 

def build_discriminator(input_shape=(PATCH_H, PATCH_W, CH)):
    # dtype="float32" for stability
    initializer = tf.random_normal_initializer(0., 0.02)
    inputs = layers.Input(shape = input_shape, name = "discriminator_input", dtype = "float32")
 
    # Block 1: C64
    x = layers.Conv2D(64, 4, strides = 2, padding = "same",
                      kernel_initializer = initializer, use_bias = False, dtype = "float32")(inputs)
    x = layers.LeakyReLU(0.2, dtype="float32")(x)

    # Block 2: C128

    conv2 = layers.Conv2D(128, 4, strides = 2, padding="same",
                          kernel_initializer = initializer, use_bias = False, dtype = "float32")
    x = SpectralNormalization(conv2, dtype = "float32")(x)
    x = layers.LeakyReLU(0.2, dtype="float32")(x)

    # Block 3: C256
    conv3 = layers.Conv2D(256, 4, strides = 2, padding = "same",
                          kernel_initializer = initializer, use_bias = False, dtype = "float32")
    x = SpectralNormalization(conv3, dtype = "float32")(x)
    x = layers.LeakyReLU(0.2, dtype="float32")(x)

    # Block 4: C512
    conv4 = layers.Conv2D(512, 4, strides = 1, padding = "same",
                          kernel_initializer = initializer, use_bias = False, dtype = "float32")
    x = SpectralNormalization(conv4, dtype = "float32")(x)
    x = layers.LeakyReLU(0.2, dtype = "float32")(x)

    # Output
    conv_out = layers.Conv2D(1, 4, strides = 1, padding = "same",
                             kernel_initializer = initializer, dtype = "float32")
    patch_output = SpectralNormalization(conv_out, dtype = "float32")(x)

    model = keras.Model(inputs, patch_output, name = "PatchGAN_Discriminator_SN_float32")

    return model


discriminator = build_discriminator()
discriminator.summary()

## Loss Functions and Metrics
The loss functions used to train the models:
 
### Pixel and Perceptual Loss (for the Generator)
 *   **Mean Absolute Error (MAE):** A simple pixel-wise loss that measures the absolute difference between the predicted image and the ground truth. It's good for overall color and structure but can lead to blurry results.
 *   **Perceptual Loss:** This loss operates in a feature space, not pixel space. I use a pre-trained VGG19 network to extract features from both the predicted and ground truth images. By minimizing the MAE between these feature maps, the generator is encouraged to produce images that are *perceptually* similar to the ground truth, capturing high-level structures and textures more effectively.
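
As a small illustration of the range handling this involves: the generator works in [-1, 1], while the ImageNet-trained VGG19 expects the input produced by `preprocess_input` (channels reordered from RGB to BGR, then a per-channel ImageNet mean subtracted, with no scaling). The sketch below reproduces that mapping for a single pixel in plain Python. The function name `to_vgg_input` is hypothetical and the mean values are an assumption based on the constants Keras documents for VGG-style preprocessing; the notebook itself relies on `preprocess_input`, not on this helper.

```python
IMAGENET_BGR_MEAN = (103.939, 116.779, 123.68)   # assumed: Keras "caffe"-mode channel means

def to_vgg_input(rgb_pixel):
    """Map one RGB pixel from [-1, 1] to VGG-style BGR, mean-centered values."""
    r, g, b = [((v + 1.0) / 2.0) * 255.0 for v in rgb_pixel]   # [-1, 1] -> [0, 255]
    bgr = (b, g, r)                                             # RGB -> BGR
    return tuple(c - m for c, m in zip(bgr, IMAGENET_BGR_MEAN))

pixel = (-1.0, 0.0, 1.0)   # minimum red, mid-level green, maximum blue
vgg_pixel = to_vgg_input(pixel)
print(vgg_pixel)
```

The same two steps (denormalize to [0, 255], then preprocess) are applied to both the prediction and the ground truth before their feature maps are compared.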
 
### Adversarial Loss (for GAN training)
Instead of using traditional Binary Cross-Entropy, this model uses the **Least Squares GAN (LSGAN)** loss.

 *   **Discriminator Loss:** The discriminator's goal is to output values close to 1 for real images and 0 for fake images. Its loss is the mean squared error between its predictions and these target values. In the implementation, the target for real images is 0.9 rather than 1.0 (label smoothing).

 *   **Generator GAN Loss:** The generator's goal is to fool the discriminator. To do this, it tries to generate images that the discriminator will classify as real (i.e. output a score of 1). Its loss is the mean squared error between the discriminator's predictions on fake images and a target of all ones.

 
 The generator's **total loss** in the GAN setup is a weighted sum of pixel loss, perceptual loss, and this adversarial loss.
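
Written out as formulas, with $x$ the LR input, $y$ the HR target, $G$ the generator and $D$ the discriminator, the objectives described above are as follows. The symbols $\lambda_{\text{pix}}$, $\lambda_{\text{perc}}$ and $\lambda_{\text{adv}}$ are shorthand introduced here for the three loss weights, not names used in the code.

$$
\mathcal{L}_D = \tfrac{1}{2}\Big(\mathbb{E}\big[(D(y) - 1)^2\big] + \mathbb{E}\big[D(G(x))^2\big]\Big)
$$

$$
\mathcal{L}_{\text{adv}} = \mathbb{E}\big[(D(G(x)) - 1)^2\big]
$$

$$
\mathcal{L}_G = \lambda_{\text{pix}}\,\mathcal{L}_{\text{MAE}} + \lambda_{\text{perc}}\,\mathcal{L}_{\text{perc}} + \lambda_{\text{adv}}\,\mathcal{L}_{\text{adv}}
$$

The code in this notebook replaces the real target 1 in $\mathcal{L}_D$ with 0.9 as a form of label smoothing.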

In [None]:
# --- 9 Loss Functions

pixel_loss_fn = keras.losses.MeanAbsoluteError()

# 9.1 Base: Pixel-By-Pixel loss (MAE)
def mae_only_loss(y_true, y_pred):
    return pixel_loss_fn(y_true, y_pred)

# 9.2 Advanced: MAE + Perceptual (VGG19) loss
vgg = VGG19(include_top = False, weights = "imagenet")   # "include_top = False" to load VGG19 without classification layers
vgg.trainable = False                                    # "trainable = False" to freeze its weights (I only want to extract features)

# 9.2 Layers used for feature extraction
vgg_layers = ["block2_conv2", "block3_conv4", "block4_conv4", "block5_conv4"]
outputs = [vgg.get_layer(name).output for name in vgg_layers]

# 9.2.1 Feature extraction 
feat_ext_multi  = keras.Model(
    inputs = vgg.input,
    outputs = outputs       
)

def perceptual_multi_layer(y_true, y_pred):              # y_true (ground truth) and y_pred (model output) are in range [-1, 1]
    y_true_f32 = tf.cast(y_true, tf.float32)             # Cast inputs to float32 before passing them to VGG
    y_pred_f32 = tf.cast(y_pred, tf.float32)
    
    y_true_0_255 = ((y_true_f32 + 1.0) / 2.0) * 255.0    # Denormalize from [-1, 1] to [0, 255], the range "preprocess_input" expects
    y_pred_0_255 = ((y_pred_f32 + 1.0) / 2.0) * 255.0
   
    yt_vgg = preprocess_input(y_true_0_255)              # Apply the VGG19 input preprocessing
    yp_vgg = preprocess_input(y_pred_0_255)
    features_true = feat_ext_multi(yt_vgg)               # Extract features from VGG19 for true and predicted images
    features_pred = feat_ext_multi(yp_vgg)
    
    # Calculate the loss for each layer and add them up
    total_perceptual_loss = 0.0
    layer_weights = [0.303, 0.303, 0.242, 0.151]                # One weight per entry in vgg_layers: shallower layers get more weight

    for i in range(len(features_true)):
        layer_loss = pixel_loss_fn(features_true[i], features_pred[i])     # MAE between the feature maps of this layer
        total_perceptual_loss += layer_loss * layer_weights[i]             # Scale each layer loss by its weight

    return total_perceptual_loss


# 9.3 Combined loss MAE + Perceptual Loss
def content_loss_perceptual(y_true, y_pred):
    pixel_l = pixel_loss_fn(y_true, y_pred)
    perceptual_l = perceptual_multi_layer(y_true, y_pred)
    return pixel_l + PERCEPTUAL_WEIGHT_PRE * perceptual_l

In [None]:
# --- 9.4 Adversarial (GAN) Loss functions 

# 9.4.1 Least Squares GAN (LSGAN) Loss
# Instead of using Binary Cross-Entropy (which can lead to vanishing gradients when the
# discriminator becomes too confident), we use the Least Squares loss. This loss function
# penalizes predictions that are far from the target label (1 for real, 0 for fake)
# using a Mean Squared Error objective.

# 9.4.2 Discriminator Loss (LSGAN)

def discriminator_loss(real_output, fake_output):
    """
    Calculates the Least Squares loss for the discriminator.
    The discriminator's goal is to output values close to 1 for real images
    and close to 0 for fake images. We use Mean Squared Error for this.
    """
    
    real_output = tf.cast(real_output, tf.float32)
    fake_output = tf.cast(fake_output, tf.float32)
    
    # Loss for real images: Measures how far the discriminator's predictions
    # are from the real target. The target is 0.9 instead of 1.0 (label smoothing).
    real_loss = tf.reduce_mean(tf.square(real_output - 0.9))

    # Loss for fake images: Measures how far the discriminator's predictions
    # are from the target of 0.0.
    fake_loss = tf.reduce_mean(tf.square(fake_output))

    # The total loss is the average of the two
    return 0.5 * (real_loss + fake_loss)
 

# 9.4.3 Generator loss

def generator_gan_loss(fake_output):
    """
    Calculates the Least Squares loss for the generator.
    The generator's goal is to fool the discriminator. It wants the discriminator
    to score its fake images as 1.0 (real). This loss measures how far the
    discriminator's predictions for the fake images are from this target.
    """
    fake_output = tf.cast(fake_output, tf.float32)
    
    # By minimizing this loss, the generator learns to produce images that
    # the discriminator scores as close to 1.0 as possible
    return tf.reduce_mean(tf.square(fake_output - 1.0))

## SRRGAN Custom Model
To handle the two-part training process of a GAN (updating the discriminator and the generator separately), I created a custom `keras.Model` subclass that stays compatible with Keras APIs such as `.fit()` and callbacks.
 
### Training Step Logic:
1.  **Train Discriminator:**
    1.   Generate a batch of fake (SR) images using the generator.
    2.   The discriminator makes predictions for both real (HR) and fake (SR) images.
    3.   Calculate the `discriminator_loss`.
    4.   Compute gradients and update ONLY the discriminator's weights.
 
2.  **Train Generator:**
     1.   Generate a new batch of fake images *inside a new gradient tape*.
     2.   Get the discriminator's verdict on these fake images.
     3.   Calculate the generator's total loss: a combination of `content_loss_perceptual` (pixel + VGG) and `generator_gan_loss`.
     4.   Compute gradients and update only the generator's weights.
 
3.  **Update Metrics:** Update and log all relevant metrics (losses, PSNR, SSIM).
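
The alternating update can be illustrated with a deliberately tiny scalar "GAN" in plain Python, where the generator is `fake = w * x`, the discriminator is `score = a * y + b`, and the gradients are written out by hand. Everything here (the function names `d_step` and `g_step`, the scalar models, the learning rate and the loss weights) is a hypothetical stand-in rather than the SRRGAN code, and the perceptual term is left out. The point is only that step 1 changes the discriminator parameters while step 2 changes the generator parameter.

```python
def d_step(w, a, b, x, y, lr=0.1):
    """One LSGAN discriminator update; the generator weight w is only read."""
    fake = w * x
    real_score = a * y + b
    fake_score = a * fake + b
    # d_loss = 0.5 * ((real_score - 0.9)**2 + fake_score**2)
    grad_a = (real_score - 0.9) * y + fake_score * fake
    grad_b = (real_score - 0.9) + fake_score
    return a - lr * grad_a, b - lr * grad_b

def g_step(w, a, b, x, y, lr=0.1, pixel_weight=1.0, gan_weight=0.1):
    """One generator update; the discriminator parameters a, b are only read."""
    fake = w * x
    fake_score = a * fake + b
    # total = pixel_weight * |fake - y| + gan_weight * (fake_score - 1)**2
    sign = 1.0 if fake > y else (-1.0 if fake < y else 0.0)
    grad_w = pixel_weight * sign * x + gan_weight * 2.0 * (fake_score - 1.0) * a * x
    return w - lr * grad_w

w, a, b = 0.5, 0.2, 0.0      # initial generator and discriminator parameters
x, y = 1.0, 2.0              # one "LR" input and its "HR" target

a, b = d_step(w, a, b, x, y)     # step 1: update the discriminator only
w = g_step(w, a, b, x, y)        # step 2: update the generator only
print(a, b, w)
```

After one round the generator output `w * x` has moved from 0.5 towards the target 2.0, driven mostly by the pixel term and slightly by the adversarial term.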

In [None]:
# --- 9.5 SRRGAN Custom Model Class 

@register_keras_serializable()       # Allows Keras to save and load the custom model
class SRRGAN(keras.Model):
    """
    Custom Keras Model that encapsulates the SRRGAN training logic.
    This class implements the two-part update rule for GANs (updating discriminator and generator separately).
    It also handles metric tracking and serialization.
    """
    def __init__(self, generator, discriminator, g_optimizer, d_optimizer,
                 pixel_weight = PIXEL_WEIGHT_GAN, gan_weight = GLOBAL_GAN_WEIGHT , perceptual_weight = PERCEPTUAL_WEIGHT_GAN, **kwargs):
        super().__init__(**kwargs)
        
        # --- Core Components ---
        # The SRRGAN model holds the generator, discriminator, and their optimizers
        # as internal attributes.
        self.generator = generator
        self.discriminator = discriminator
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer
        
        # --- Hyperparameters ---
        # Weights for the different components of the generator's total loss.
        self.pixel_weight = pixel_weight             # Use the constructor argument (it defaults to PIXEL_WEIGHT_GAN)
        self.gan_weight = gan_weight                 
        self.perceptual_weight = perceptual_weight

        # --- Metric Trackers ---
        # Keras metrics to track the mean of different loss values over each epoch.
        # Raw losses
        self.g_total_loss_tracker = metrics.Mean(name="g_total_loss")
        self.g_pixel_loss_tracker = metrics.Mean(name="g_pixel_loss")
        self.g_perceptual_loss_tracker = metrics.Mean(name="g_perceptual_loss")
        self.g_gan_loss_tracker = metrics.Mean(name="g_gan_loss")
        self.d_loss_tracker = metrics.Mean(name="d_loss")
        
        # Weighted losses
        self.g_pixel_loss_weighted_tracker = metrics.Mean(name="g_pixel_loss_weighted")
        self.g_perceptual_loss_weighted_tracker = metrics.Mean(name="g_perceptual_loss_weighted")
        self.g_gan_loss_weighted_tracker = metrics.Mean(name="g_gan_loss_weighted")

        # standard metrics
        self.psnr_tracker = metrics.Mean(name="psnr")
        self.ssim_tracker = metrics.Mean(name="ssim")

    def get_config(self):
        """
        Enables the model to be saved and loaded.
        This method returns a dictionary containing the configurations of all
        its components, which Keras uses during serialization.
        """
        base = super().get_config()
        base.update({
            # "serialize_keras_object" converts the models and optimizers into a
            # savable dictionary format.
            "generator": keras.saving.serialize_keras_object(self.generator),
            "discriminator": keras.saving.serialize_keras_object(self.discriminator),
            "g_optimizer": keras.saving.serialize_keras_object(self.g_optimizer),
            "d_optimizer": keras.saving.serialize_keras_object(self.d_optimizer),
            "pixel_weight": float(self.pixel_weight),
            "gan_weight": float(self.gan_weight),
            "perceptual_weight": float(self.perceptual_weight),
        })
        return base

    @classmethod
    def from_config(cls, config):
        """
        Creates an SRRGAN model instance from a configuration dictionary.
        This method is the counterpart to `get_config`, used by Keras to
        reconstruct the model when loading it from a file.
        """
        # "deserialize_keras_object" reconstructs the Python objects from their
        # dictionary representations.
        generator = keras.saving.deserialize_keras_object(config.pop("generator"))
        discriminator = keras.saving.deserialize_keras_object(config.pop("discriminator"))
        g_optimizer = keras.saving.deserialize_keras_object(config.pop("g_optimizer"))
        d_optimizer = keras.saving.deserialize_keras_object(config.pop("d_optimizer"))

        # The remaining items in the config are passed to the class constructor.
        return cls(
            generator=generator,
            discriminator=discriminator,
            g_optimizer=g_optimizer,
            d_optimizer=d_optimizer,
            **config
        )
    
    
    @property
    def metrics(self):
        return [
            self.g_total_loss_tracker, 
            self.d_loss_tracker,
            # Raw losses
            self.g_pixel_loss_tracker,
            self.g_perceptual_loss_tracker, 
            self.g_gan_loss_tracker,
            # Weighted losses
            self.g_pixel_loss_weighted_tracker,
            self.g_perceptual_loss_weighted_tracker,
            self.g_gan_loss_weighted_tracker,
            # Standard metrics
            self.psnr_tracker, 
            self.ssim_tracker,
        ]

    def train_step(self, data):
        """
        Defines the logic for one training step (one batch of data).
        """
        # Unpack the data.
        lr_images, hr_images = data

        # --- 1. Train the Discriminator ---
        # GradientTape to record operations for automatic differentiation.
        with tf.GradientTape() as tape:
            # Generate a batch of fake images.
            sr_images = self.generator(lr_images, training=True)            
            # Get the discriminator's predictions for both real and fake images.
            real_output = self.discriminator(hr_images, training=True)
            fake_output = self.discriminator(sr_images, training=True)
            # Calculate the discriminator's loss.
            d_loss = discriminator_loss(real_output, fake_output)
        
        # Compute the gradients of the loss with respect to the discriminator's weights.
        d_grads = tape.gradient(d_loss, self.discriminator.trainable_variables)
        # Apply the gradients to update the discriminator's weights.
        self.d_optimizer.apply_gradients(zip(d_grads, self.discriminator.trainable_variables))

        # --- 2. Train the Generator ---
        with tf.GradientTape() as tape:
            # Generate a batch of fake images.
            sr_images = self.generator(lr_images, training=True)
            # Get the discriminator's predictions
            fake_output = self.discriminator(sr_images, training=True)
            
            # Calculate raw loss components.            
            pixel_l = pixel_loss_fn(hr_images, sr_images)
            perceptual_l = perceptual_multi_layer(hr_images, sr_images)
            gan_l = generator_gan_loss(fake_output)

            # Weighted loss components
            pixel_l_w = self.pixel_weight * tf.cast(pixel_l, tf.float32)
            perceptual_l_w = self.perceptual_weight * tf.cast(perceptual_l, tf.float32)
            gan_l_w = self.gan_weight * tf.cast(gan_l, tf.float32)

            # Combine the weighted losses into the total loss        
            g_total_loss = pixel_l_w + perceptual_l_w + gan_l_w
            

        # Compute gradients and update the generator's weights.
        g_grads = tape.gradient(g_total_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(g_grads, self.generator.trainable_variables))

        # --- 3. Update Metrics ---
        self.d_loss_tracker.update_state(d_loss)
        self.g_total_loss_tracker.update_state(g_total_loss)
        
        # Update raw loss 
        self.g_pixel_loss_tracker.update_state(pixel_l)
        self.g_perceptual_loss_tracker.update_state(perceptual_l)
        self.g_gan_loss_tracker.update_state(gan_l)
        
        # Update weighted
        self.g_pixel_loss_weighted_tracker.update_state(pixel_l_w)
        self.g_perceptual_loss_weighted_tracker.update_state(perceptual_l_w)
        self.g_gan_loss_weighted_tracker.update_state(gan_l_w)
        
        # Update standard metrics
        self.psnr_tracker.update_state(psnr_metric(hr_images, sr_images))
        self.ssim_tracker.update_state(ssim_metric(hr_images, sr_images))

        # Return a dictionary of the current metric values. Keras displays this.
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """
        Defines the logic for one evaluation step.
        Only the generator's performance is evaluated;
        the discriminator is not involved.
        """
        lr_images, hr_images = data
        
        # Generate SR images in inference mode.
        sr_images = self.generator(lr_images, training=False)
        # Update evaluation metrics.
        self.psnr_tracker.update_state(psnr_metric(hr_images, sr_images))
        self.ssim_tracker.update_state(ssim_metric(hr_images, sr_images))
        pixel_l = pixel_loss_fn(hr_images, sr_images)
        perceptual_l = perceptual_multi_layer(hr_images, sr_images)
        self.g_pixel_loss_tracker.update_state(pixel_l)
        self.g_perceptual_loss_tracker.update_state(perceptual_l)

        # Return a dictionary of the final evaluation metrics.
        return {
            "psnr": self.psnr_tracker.result(),
            "ssim": self.ssim_tracker.result(),
            "pixel_loss": self.g_pixel_loss_tracker.result(),
            "perceptual_loss": self.g_perceptual_loss_tracker.result()
        }

## Model Training
This section compiles and trains the models with our different loss function configurations.
 
### Training Strategy:
A three-stage training strategy is applied:
1.  **MAE Only:** Train the generator using only the MAE loss. This provides a baseline and establishes a good initial weight configuration.
2.  **MAE + Perceptual:** Continue pre-training the generator by adding the perceptual loss. The result serves as a strong pre-trained model for the final GAN step.
3.  **SRRGAN (Fine-tuning):** Initialize the generator with the weights from the MAE + Perceptual model, then train it adversarially against the discriminator.

In [None]:
# --- 10 Additional Metrics

def denorm_for_metric(tensor):          # Denormalize from [-1, 1] to [0, 1] for standard metric calculation.
    return (tensor + 1.0) / 2.0

def psnr_metric(y_true, y_pred):
    return tf.image.psnr(denorm_for_metric(y_true), denorm_for_metric(y_pred), max_val=1.0)

def ssim_metric(y_true, y_pred):
    return tf.image.ssim(denorm_for_metric(y_true), denorm_for_metric(y_pred), max_val=1.0)

In [None]:
# --- 11.1 Standard Plots

def plot_standard_metrics(history, title_prefix=""):
    epochs = range(1, len(history.history["loss"]) + 1)
    
    plt.figure(figsize=(15, 4))

    # Loss
    plt.subplot(1, 3, 1)
    plt.plot(epochs, history.history["loss"], label="Train Loss")
    plt.plot(epochs, history.history["val_loss"], label="Val Loss")
    plt.title(f"{title_prefix} Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss Value")
    plt.legend()

    # PSNR
    plt.subplot(1, 3, 2)
    plt.plot(epochs, history.history["psnr_metric"], label="Train PSNR")
    plt.plot(epochs, history.history["val_psnr_metric"], label="Val PSNR")
    plt.title(f"{title_prefix} PSNR")
    plt.xlabel("Epoch")
    plt.ylabel("PSNR (dB)")
    plt.legend()

    # SSIM
    plt.subplot(1, 3, 3)
    plt.plot(epochs, history.history["ssim_metric"], label="Train SSIM")
    plt.plot(epochs, history.history["val_ssim_metric"], label="Val SSIM")
    plt.title(f"{title_prefix} SSIM")
    plt.xlabel("Epoch")
    plt.ylabel("SSIM")
    plt.legend()

    plt.tight_layout()
    plt.show()


# --- 11.2 GAN Plots

def plot_gan_metrics(history, title_prefix="SRRGAN"):
    epochs = range(1, len(history.history["g_total_loss"]) + 1)

    plt.figure(figsize=(18, 8))

    # Weighted loss components
    plt.subplot(2, 2, 1)
    plt.plot(epochs, history.history["g_total_loss"], label="Total Gen. Loss", linewidth=2)
    plt.plot(epochs, history.history["g_pixel_loss_weighted"], label="Weighted Pixel Loss", linestyle='--')
    plt.plot(epochs, history.history["g_perceptual_loss_weighted"], label="Weighted Perceptual Loss", linestyle='-.')
    plt.plot(epochs, history.history["g_gan_loss_weighted"], label="Weighted GAN Loss", linestyle=':')
    plt.title(f"{title_prefix} Weighted Generator Loss Components (Train)") 
    plt.xlabel("Epoch")
    plt.ylabel("Loss Value")
    plt.legend()

    # Discriminator Loss
    plt.subplot(2, 2, 2)
    plt.plot(epochs, history.history["d_loss"], label="Disc. Loss", color='red')
    plt.title(f"{title_prefix} Discriminator Loss (Train)")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()

    # PSNR
    plt.subplot(2, 2, 3)
    plt.plot(epochs, history.history["psnr"], label="Train PSNR")
    plt.plot(epochs, history.history["val_psnr"], label="Val PSNR")
    plt.title(f"{title_prefix} PSNR")
    plt.xlabel("Epoch")
    plt.ylabel("PSNR (dB)")
    plt.legend()

    # SSIM
    plt.subplot(2, 2, 4)
    plt.plot(epochs, history.history["ssim"], label="Train SSIM")
    plt.plot(epochs, history.history["val_ssim"], label="Val SSIM")
    plt.title(f"{title_prefix} SSIM")
    plt.xlabel("Epoch")
    plt.ylabel("SSIM")
    plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
# --- 12.1 Callbacks

def callbacks_mae (checkpoint_path, patience_es = 15, patience_lr = 5):
    return [
        keras.callbacks.ModelCheckpoint(                                        # Save only the model with the best validation loss
            filepath = checkpoint_path,
            save_best_only = True,
            monitor = "val_loss",
            verbose = 1
        ),
        keras.callbacks.EarlyStopping(                                          # Stop training if there is no improvement
            patience = patience_es,
            restore_best_weights = True,
            monitor = "val_loss",
            verbose = 1
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor = 'val_loss', 
            factor = 0.5,        
            patience = patience_lr, 
            verbose = 1,
            min_lr = 1e-6         
        )
    ]

def callbacks_perceptual  (checkpoint_path, patience_es = 16, patience_lr = 4):
    return [
        keras.callbacks.ModelCheckpoint(                                        # Save only the model with the best validation loss
            filepath = checkpoint_path,
            save_best_only = True,
            monitor = "val_loss",
            verbose = 1
        ),
        keras.callbacks.EarlyStopping(                                          # Stop training if there is no improvement
            patience = patience_es,
            restore_best_weights = True,
            monitor = "val_loss",
            verbose = 1
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor = 'val_loss', 
            factor = 0.5,        
            patience = patience_lr, 
            verbose = 1,
            min_lr = 5e-8         
        )
    ]


# 12.2 Checkpoint Position
CHECKPOINT_MAE        = "checkpoint_mae.keras"
CHECKPOINT_PERCEPTUAL = "checkpoint_perceptual.keras"


# 12.3 callbacks creation
callbacks_mae        = callbacks_mae(CHECKPOINT_MAE)
callbacks_perceptual = callbacks_perceptual(CHECKPOINT_PERCEPTUAL)


# --- 12.4 GAN callbacks

steps_per_epoch = len(train_ds)

gan_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath = "gan_checkpoints/SRRGAN_epoch_{epoch:02d}.keras",
    save_weights_only = False,
    save_best_only = False,               # Save PERIODICALLY, not just the best
    save_freq = steps_per_epoch * 4       # Save every 4 epochs
)


class GANMonitor(keras.callbacks.Callback):
    """
    Callback that generates, saves and displays sample SR images during training.
    """
    def __init__(self, val_dataset, num_samples=3, frequency=1, log_dir="gan_images"):
        super().__init__()
        val_batch = next(iter(val_dataset))
        self.lr_images = val_batch[0][:num_samples]
        self.hr_images = val_batch[1][:num_samples]
        self.num_samples = num_samples
        self.frequency = frequency  
        
        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)          # Create the output folder if it does not exist yet

    def denorm(self, img):
        img = tf.cast(img, tf.float32)
        return tf.clip_by_value((img + 1.0) / 2.0, 0, 1)

    def on_epoch_end(self, epoch, logs=None):
        # Check if the current epoch is a multiple of the frequency
        # The `+ 1` is because epochs are 0-indexed in the callback.
        if (epoch + 1) % self.frequency == 0:
            generator = self.model.generator
            sr_images_pred = generator.predict(self.lr_images, verbose=0)

            print(f"\n--- Generating images for epoch {epoch + 1} ---")
            for i in range(self.num_samples):
                fig, axes = plt.subplots(1, 3, figsize=(15, 5))
                
                axes[0].imshow(self.denorm(self.lr_images[i]))
                axes[0].set_title("Low-Res Input")
                axes[0].axis("off")
                
                axes[1].imshow(self.denorm(sr_images_pred[i]))
                axes[1].set_title("Super-Res (Generated)")
                axes[1].axis("off")

                axes[2].imshow(self.denorm(self.hr_images[i]))
                axes[2].set_title("High-Res (Ground Truth)")
                axes[2].axis("off")

                fig.suptitle(f'Epoch: {epoch + 1}', fontsize=16)
                
                filepath = os.path.join(self.log_dir, f'epoch_{epoch+1:03d}_sample_{i+1}.png')
                plt.savefig(filepath)
                
                plt.show()
                plt.close(fig)

In [None]:
# --- 13 Optimizers: schedule, mixed precision & custom-objects dictionary

# --- Mae & Mae + Perceptual

# 13.1 Learning rate
initial_lr_mae = 5e-4
initial_lr_perceptual = 1e-5

# 13.2 Adam optimizer 
optimizer_mae = keras.optimizers.Adam(
    learning_rate = initial_lr_mae,
    weight_decay = 1e-4
)

optimizer_perceptual = keras.optimizers.Adam(
    learning_rate = initial_lr_perceptual,
    weight_decay = 5e-5
)


# 13.3 Add Mixed precision

optimizer_mae = tf.keras.mixed_precision.LossScaleOptimizer(optimizer_mae)
optimizer_perceptual = tf.keras.mixed_precision.LossScaleOptimizer(optimizer_perceptual)





# --- GAN networks 

# Number of training steps
steps_per_epoch = len(train_ds)
total_steps = steps_per_epoch * EPOCHS 

# Schedule for generator and discriminator
g_lr_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate = 5e-5,   # Generator LR
    decay_steps = total_steps,
    alpha=0.0
)

d_lr_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate = 2.5e-5,  # Discriminator LR 
    decay_steps = total_steps,
    alpha = 0.0
)


# Optimizers with schedule. One for the generator and one for the discriminator
g_optimizer_gan = keras.optimizers.Adam(learning_rate = g_lr_schedule,
                                        beta_1 = 0.5,
                                        clipnorm = 1.0
                                       )

d_optimizer_gan = keras.optimizers.Adam(learning_rate = d_lr_schedule,
                                        beta_1 = 0.5,
                                        clipnorm=1.0
                                       )


# --- 13.5 Custom objects

# Define custom objects (so Keras knows how to load the model)
custom_objects = {
    "mae_only_loss": mae_only_loss,
    "content_loss_perceptual": content_loss_perceptual,
    "perceptual_multi_layer" : perceptual_multi_layer,
    "pixel_loss_fn": pixel_loss_fn,
    "psnr_metric": psnr_metric,
    "ssim_metric": ssim_metric,
    "PixelShuffle": PixelShuffle,
    "SRRGAN": SRRGAN,
    "SpatialAttentionBlock": SpatialAttentionBlock
}

### MAE --- Baseline

In [None]:
# --- 14.1 Compile U-Net with MAE

generator_u_net.compile(
    optimizer = optimizer_mae,    
    loss = mae_only_loss,
    metrics = [
        psnr_metric,
        ssim_metric,
        mae_only_loss
    ]
)

In [None]:
# --- 14.2 Train U-Net with MAE (training was done in a previous version)

#csv_logger_mae = CSVLogger("history_log_mae.csv")

#history_mae  = generator_u_net.fit(train_ds,
#                              validation_data = val_ds,
#                              epochs = EPOCHS,
#                              callbacks = callbacks_mae + [csv_logger_mae]
#                                )

In [None]:
# --- 14.3 Plot Train vs Val

# Load the history from the CSV file
history_df_mae = pd.read_csv("/kaggle/input/history-mae-and-mae-perceptual/history_log_mae.csv")

mock_history_mae = SimpleNamespace(history = history_df_mae.to_dict('list'))

# Plot
plot_standard_metrics(mock_history_mae, title_prefix="U-Net with MAE")

In [None]:
# --- 14.4 Save U-Net with MAE (Save was done in Version "Pre-Train Mae & Mae+Perceptual")

#generator_u_net.save("model_mae.keras")

### MAE + perceptual --- & Pre-Training

In [None]:
# --- 15.1 Continue pre-training

#generator_u_net = build_unet_sr_generator()
#generator_u_net = load_model("/kaggle/input/model_mae/keras/default/1/model_mae.keras", custom_objects=custom_objects)

In [None]:
# --- 15.2 Compile U-Net with MAE + Perceptual

generator_u_net.compile(
    optimizer = optimizer_perceptual,     
    loss = content_loss_perceptual,
    metrics = [
        perceptual_multi_layer,
        pixel_loss_fn,
        psnr_metric,
        ssim_metric
    ]
)

In [None]:
# --- 15.3 Train U-Net with MAE + Perceptual (training was done in a previous version)

#csv_logger_perceptual = CSVLogger("history_log_perceptual.csv")

#history_perceptual  = generator_u_net.fit(train_ds,
#                                  validation_data = val_ds,
#                                  epochs = EPOCHS,
#                                  callbacks = callbacks_perceptual + [csv_logger_perceptual]
#                                         )

In [None]:
# --- 15.4 Plot Train vs Val

# Load the history from the CSV file
history_df_perceptual = pd.read_csv("/kaggle/input/history-mae-and-mae-perceptual/history_log_perceptual.csv")

mock_history_perceptual = SimpleNamespace(history = history_df_perceptual.to_dict('list'))

# Plot
plot_standard_metrics(mock_history_perceptual, title_prefix = "U-Net with Mae + Perceptual Loss")

In [None]:
# --- 15.5 Save U-Net with MAE + Perceptual (Save was done in Version "Pre-Train Mae & Mae+Perceptual")

#generator_u_net.save("model_perceptual.keras")

### MAE + perceptual + GAN --- Fine-Tuning

In [None]:
# --- 16.1 Fine-Tune

generator_u_net = build_unet_sr_generator()
generator_u_net = load_model("/kaggle/input/model-mae-and-maeperceptual/keras/default/1/model_perceptual.keras", custom_objects=custom_objects)

In [None]:
# --- 16.2 Compile SRRGAN 
  
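# Note: this rebinds the name `SRRGAN` from the class to the instance, so this cell
# cannot be re-run without first re-running the cell that defines the class.
SRRGAN = SRRGAN(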
    generator = generator_u_net, 
    discriminator = discriminator,
    g_optimizer = g_optimizer_gan,
    d_optimizer = d_optimizer_gan
)

SRRGAN.compile(optimizer=g_optimizer_gan)

In [None]:
# --- 16.3 Train SRRGAN  

# Image visualization (frequency = number of epochs between two rounds of generated samples)
gan_monitor_callback = GANMonitor(val_dataset=val_ds, num_samples=3, frequency=5)

csv_logger_gan = CSVLogger("history_log_gan.csv")

history_gan = SRRGAN.fit(
    train_ds,
    validation_data = val_ds,
    epochs = EPOCHS,
    callbacks = [gan_checkpoint_callback, gan_monitor_callback, csv_logger_gan]
)

In [None]:
# --- 16.4 Plot GAN training history

# Load the history from the CSV file
history_df_gan = pd.read_csv("history_log_gan.csv")

mock_history_gan = SimpleNamespace(history = history_df_gan.to_dict('list'))

# Plot
plot_gan_metrics(mock_history_gan, title_prefix = "SRRGAN")

In [None]:
# --- 16.5 Save the final generator model

SRRGAN.generator.save("model_gan.keras")

### GAN Evaluation
Selecting a **GAN** checkpoint based on its loss can be unreliable, because the adversarial losses of generator and discriminator do not track image quality directly.
To find the best model, I will evaluate the generator from each checkpoint (saved every 4 epochs) using the **LPIPS** metric. The best model is the one with the **lowest** LPIPS value.

The choice of LPIPS over PSNR or SSIM follows from the purpose of the GAN: to create images that are *perceptually realistic*, not just pixel-accurate. Since GAN training is unstable, the generator can start producing "bad" artifacts along the way. PSNR averages the error over the whole image, so it can miss artifacts confined to small regions. LPIPS, on the other hand, compares images in a deep feature space. This correlates better with **human judgments** of perceptual similarity than PSNR or SSIM, and it is more sensitive to unnatural patterns and textures.
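
To make the point about localized artifacts concrete, here is a minimal NumPy sketch. It is not the notebook's `psnr_metric`: the helper `psnr` and the arrays `diffuse` and `localized` are hypothetical names introduced only for this illustration, and the "image" is random data in [0, 1]. Two distortions with the same total squared error get the same PSNR (about 40 dB), even though one is a faint uniform offset and the other is a clearly visible 8x8 block.

```python
import numpy as np

def psnr(reference, test, max_val=1.0):
    """Peak signal-to-noise ratio in dB for images in [0, max_val]."""
    mse = np.mean((reference - test) ** 2)
    return 10.0 * np.log10(max_val ** 2 / mse)

rng = np.random.default_rng(0)
reference = rng.uniform(0.2, 0.8, size=(64, 64))  # stand-in for a grayscale image

# Distortion A: a tiny error spread over every pixel.
diffuse = reference + 0.01

# Distortion B: the same total squared error, concentrated in one 8x8 block.
# 64*64 pixels * 0.01^2 == 8*8 pixels * e^2  ->  e = 0.08
localized = reference.copy()
localized[:8, :8] += 0.08

psnr_diffuse = psnr(reference, diffuse)
psnr_localized = psnr(reference, localized)
print(f"diffuse: {psnr_diffuse:.2f} dB, localized: {psnr_localized:.2f} dB")
```

PSNR cannot tell these two cases apart, which is the motivation for ranking checkpoints with a feature-space metric instead.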

In [None]:
# --- 16.6 Setting for Gan evaluation

# Directory of saved checkpoints
CHECKPOINT_DIR = "gan_checkpoints"

# List all saved checkpoints, sorted by filename
checkpoint_paths = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "*.keras")))

print(f"Found {len(checkpoint_paths)} checkpoints to evaluate.")

In [None]:
# --- 16.7 Create LPIPS under float32 policy (Notebook's policy is float16, but the metric was built in float32)

orig_policy = tf.keras.mixed_precision.global_policy().name
tf.keras.mixed_precision.set_global_policy("float32")
lpips_loss_fn = lpips_base_tf.LPIPS(base='vgg16', pre_norm=False)
print("LPIPS compute dtype:", lpips_loss_fn.compute_dtype)
tf.keras.mixed_precision.set_global_policy(orig_policy)  # restore mixed policy

In [None]:
# --- 16.8 Evaluation loop 

results = {}

for ckpt_path in checkpoint_paths:
    print(f"\n--- Evaluating checkpoint: {os.path.basename(ckpt_path)} ---")
    model = load_model(ckpt_path, custom_objects=custom_objects)
    # The generator is the part to evaluate
    generator = model.generator

    lpips_scores = []
    for lr_batch, hr_batch in tqdm(test_ds, desc="Processing", leave=False):
        sr_batch = generator(lr_batch, training=False)

        # Cast inputs to float32 so LPIPS receives float32 tensors (better to do it again...)
        hr_batch = tf.cast(hr_batch, tf.float32)
        sr_batch = tf.cast(sr_batch, tf.float32)

        # Compute LPIPS
        lpips_score_batch = lpips_loss_fn(hr_batch, sr_batch)
        lpips_scores.extend(tf.reshape(lpips_score_batch, [-1]).numpy().tolist())

    avg_lpips = float(np.mean(lpips_scores))
    results[os.path.basename(ckpt_path)] = {"LPIPS": avg_lpips}

In [None]:
# --- 16.9 Gan final results

sorted_results = sorted(results.items(), key=lambda item: item[1]['LPIPS'])

best_gan_filename = sorted_results[0][0]
best_gan_path = os.path.join(CHECKPOINT_DIR, best_gan_filename)

print(f"\nBest GAN model identified: {best_gan_filename}")
print(f"   LPIPS: {sorted_results[0][1]['LPIPS']:.4f} (Lower is better)")

print("\n--- Full Ranking (by LPIPS) ---")
print(f"{'Checkpoint':<25} | {'LPIPS':<10}")
print("-" * 40)
for name, scores in sorted_results:
    print(f"{name:<25} | {scores['LPIPS']:.4f}")
print("-" * 40)

# Evaluation & Visualization

In [None]:
# --- 17 Load models + Evaluation

# 17.1 load models
model_mae = load_model("/kaggle/input/model-mae-and-maeperceptual/keras/default/1/model_mae.keras",               custom_objects=custom_objects,   safe_mode=False) # Mae
model_perceptual = load_model("/kaggle/input/model-mae-and-maeperceptual/keras/default/1/model_perceptual.keras", custom_objects=custom_objects,   safe_mode=False) # Mae + Perceptual

best_gan_model_full = load_model(best_gan_path, custom_objects=custom_objects)  # GAN
model_gan = best_gan_model_full.generator

# 17.2 Evaluation of each model on the test data
res_mae = model_mae.evaluate(test_ds, return_dict=True)
res_perceptual = model_perceptual.evaluate(test_ds, return_dict=True)

# 17.3 The GAN generator was not compiled with a loss... so re-compile it for evaluation
model_gan.compile(loss=mae_only_loss, metrics=[psnr_metric, ssim_metric, pixel_loss_fn, perceptual_multi_layer])
res_gan = model_gan.evaluate(test_ds, return_dict=True)
print("\nModel evaluated.")

## Evaluation

PSNR and SSIM don't always align with human perception of image quality. The LPIPS (Learned Perceptual Image Patch Similarity) metric is designed to be a better "proxy" for how humans perceive differences.

*    **PSNR: A higher PSNR score is better**, indicating that the generated image is closer to the ground truth in terms of pixel-level accuracy.
*    **SSIM: A higher SSIM score is better**, indicating that the generated image is closer to the ground truth in terms of local structure (luminance, contrast and structural patterns).
*    **LPIPS: A lower score is better**, indicating that the generated image is perceptually closer to the ground truth.
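
As a reminder of why a higher PSNR means better pixel-level accuracy, the standard definition is sketched below. Here $\mathrm{MAX}$ is the largest possible pixel value (1 for images in [0, 1]; the exact range used inside `psnr_metric` is not shown in this section, so treat it as an assumption), $y_i$ and $\hat{y}_i$ are ground-truth and generated pixel values, and $N$ is the number of values compared:

$$
\mathrm{MSE} = \frac{1}{N}\sum_{i=1}^{N}\left(y_i - \hat{y}_i\right)^2, \qquad
\mathrm{PSNR} = 10\,\log_{10}\!\left(\frac{\mathrm{MAX}^2}{\mathrm{MSE}}\right)
$$

Because of the logarithm, halving the MSE raises the PSNR by about 3 dB.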

In [None]:
# --- Evaluation

# 18.1 As before, ensure the policy is float32 for the LPIPS calculation.
orig_policy = tf.keras.mixed_precision.global_policy().name
tf.keras.mixed_precision.set_global_policy("float32")
lpips_loss_fn = lpips_base_tf.LPIPS(base='vgg16', pre_norm=False)
tf.keras.mixed_precision.set_global_policy(orig_policy)

def evaluate_lpips(model, dataset):
    """
    Calculates the average LPIPS score for a given model on a dataset.
    """
    all_lpips_scores = []
    for lr_batch, hr_batch in tqdm(dataset, desc=f"LPIPS for {model.name}"):
        sr_batch = model(lr_batch, training=False)
        
        # Cast to float32 for LPIPS metric
        hr_batch_f32 = tf.cast(hr_batch, tf.float32)
        sr_batch_f32 = tf.cast(sr_batch, tf.float32)

        lpips_scores = lpips_loss_fn(hr_batch_f32, sr_batch_f32)
        all_lpips_scores.extend(tf.reshape(lpips_scores, [-1]).numpy())
        
    return float(np.mean(all_lpips_scores))

# 18.2 Calculate Mae for all models
mae_val_mae_model = res_mae['loss'] 
mae_val_perceptual_model = res_perceptual['mean_absolute_error'] 
mae_val_gan_model = res_gan['mean_absolute_error']

# 18.3 Calculate LPIPS for all models
lpips_mae = evaluate_lpips(model_mae, test_ds)
lpips_perceptual = evaluate_lpips(model_perceptual, test_ds)
lpips_gan = evaluate_lpips(model_gan, test_ds)

# 18.4 Display Final Comparison Table
print("\n" + "="*70)
print(" " * 21 + "FINAL MODEL COMPARISON")
print("="*70)
print(f'{"Metric":<12} | {"U-Net (MAE)":<15} | {"U-Net (Perceptual)":<20} | {"SRRGAN":<15}')
print("-"*70)
print(f'{"PSNR":<12} | {res_mae["psnr_metric"]:.4f}{" ":>8} | {res_perceptual["psnr_metric"]:.4f}{" ":>13} | {res_gan["psnr_metric"]:.4f}')
print(f'{"SSIM":<12} | {res_mae["ssim_metric"]:.4f}{" ":>9} | {res_perceptual["ssim_metric"]:.4f}{" ":>14} | {res_gan["ssim_metric"]:.4f}')
print(f'{"MAE (Pixel)":<12} | {mae_val_mae_model:.4f}{" ":>9} | {mae_val_perceptual_model:.4f}{" ":>14} | {mae_val_gan_model:.4f}')
print(f'{"LPIPS":<12} | {lpips_mae:.4f}{" ":>9} | {lpips_perceptual:.4f}{" ":>14} | {lpips_gan:.4f}')
print("="*70)

## Patch Visualization

In [None]:
# --- 19 Preparation for the Patch Visualization

# 19.1 Denormalization
def denorm(img):
    img_f32 = tf.cast(img, tf.float32)
    return tf.clip_by_value((img_f32 + 1.0) / 2.0, 0, 1)

# 19.2 Select random images from test data
num_images_to_show = 10
sample_paths = random.sample(test_paths, num_images_to_show)
plt.figure(figsize=(50, 100)) # Adjusted figsize for better aspect ratio

# 19.3.1 Loop through each of the selected images
for i, sample_path_hr in enumerate(sample_paths):
    print(f"--- Processing patch from image {i+1}/{num_images_to_show} ---")

    # 19.3.2 Preparation steps are now inside the loop
    hr_full_image = tf.image.convert_image_dtype(                                  # Get image
        tf.io.decode_png(tf.io.read_file(sample_path_hr), channels=CH),
        tf.float32
    )
    hr_patch_gt = tf.image.random_crop(hr_full_image, size=[PATCH_H, PATCH_W, CH])                    # Use random crop for more variety
    lr_patch_degraded_0_1 = degradation_pipeline_tf(hr_patch_gt, is_training=False).numpy()           # Apply degradation
    lr_patch_input = (lr_patch_degraded_0_1 * 2.0) - 1.0 
    
    # 19.3.3 Model predictions
    sr_patch_predicted_mae = model_mae.predict(lr_patch_input[None, ...], verbose=0)[0]
    sr_patch_predicted_perceptual = model_perceptual.predict(lr_patch_input[None, ...], verbose=0)[0]
    sr_patch_predicted_gan = model_gan.predict(lr_patch_input[None, ...], verbose=0)[0]

    # 19.3.4 Plotting row of patch images 
    
    # 19.3.4.1 LR Input
    ax = plt.subplot(num_images_to_show, 5, i * 5 + 1)
    ax.imshow(denorm(lr_patch_input))
    ax.axis("off")
    if i == 0: ax.set_title(f"LR Input\n{PATCH_LR_H}x{PATCH_LR_W}", fontsize = 40)  

    # 19.3.4.2 U-Net with MAE
    ax = plt.subplot(num_images_to_show, 5, i * 5 + 2)
    ax.imshow(denorm(sr_patch_predicted_mae))
    ax.axis("off")
    if i == 0: ax.set_title(f"U-Net MAE\n{PATCH_H}x{PATCH_W}", fontsize = 40)

    # 19.3.4.3 U-Net with Perceptual Loss
    ax = plt.subplot(num_images_to_show, 5, i * 5 + 3)
    ax.imshow(denorm(sr_patch_predicted_perceptual))
    ax.axis("off")
    if i == 0: ax.set_title(f"U-Net Perceptual\n{PATCH_H}x{PATCH_W}", fontsize = 40)

    # 19.3.4.4 SRRGAN
    ax = plt.subplot(num_images_to_show, 5, i * 5 + 4)
    ax.imshow(denorm(sr_patch_predicted_gan))
    ax.axis("off")
    if i == 0: ax.set_title(f"SRRGAN\n{PATCH_H}x{PATCH_W}", fontsize = 40)

    # 19.3.4.5 Original HR Patch (Ground Truth)
    ax = plt.subplot(num_images_to_show, 5, i * 5 + 5)
    ax.imshow(hr_patch_gt)
    ax.axis("off")
    if i == 0: ax.set_title(f"Original HR (GT)\n{PATCH_H}x{PATCH_W}", fontsize = 40)
    
# 19.4 Show the final plot
plt.subplots_adjust(wspace=0.05, hspace=0.05)

plt.savefig("model_comparison_patch.png", dpi=400, bbox_inches='tight') 
plt.show()

### Full image inference

The following function handles inference on full-sized images, which can be tricky. Here is the general "story":
* **Padding (20.1-20.3)**: An image's dimensions might not line up with the patch grid (one patch plus a whole number of strides), which would leave part of the image uncovered when tiling. To solve this, padding is added to **ensure** the padded dimensions fit the grid exactly. This is done using `mode="reflect"`, which mirrors pixels at the edge to create a seamless border.
* **Extract patches (20.4)**: `tf.image.extract_patches` performs a sliding window operation, returning a single batch of all the 128x128 LR patches.
* **Predict (20.5)**: The model **predicts** on all patches in a single `predict` call.
* **Reconstruction & Blending (20.6-20.10)**: Simply placing predicted HR patches side-by-side can create visible seams because the edges might not perfectly match. The solution is to **blend** the patches in their overlapping **areas**.
    * A **"Hann Window" (20.7)** is used for blending. It has a value of 1.0 in the center and smoothly fades towards 0.0 at the edges.
    * Each predicted HR patch is **multiplied** by the Hann window **(20.8)** before being added to the final output canvas (`final_hr`). This way, the center of each patch contributes fully, while the edges contribute less.
    * In overlapping regions, a pixel receives contributions from multiple patches. A `weight_map` **(20.9)** keeps track of the sum of these Hann window values at every pixel. Dividing by this map performs a weighted average of the predictions, creating a smooth transition.
    * Finally, the initial padding is **cropped** off **(20.10)**.
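
The steps above can be sketched in one dimension, which makes the arithmetic easy to check by hand. This is a simplified illustration, not the notebook's function: `tile_layout` and `blend_1d` are hypothetical helpers, the "model" is the identity (no upscaling), and the window is floored at a small positive value, which is an assumption of this sketch. A plain `np.hanning(n)` is exactly zero at both endpoints, so without the floor the outermost samples of the canvas would receive no weight at all.

```python
import math
import numpy as np

def tile_layout(length, patch_size, overlap):
    """Return (stride, n_patches, pad) so that the patch grid covers `length`."""
    stride = patch_size - overlap
    n_patches = math.ceil(max(0, length - patch_size) / stride) + 1
    pad = (n_patches - 1) * stride + patch_size - length
    return stride, n_patches, pad

def blend_1d(signal, patch_size=16, overlap=8):
    """Tile a 1D signal, 'predict' each tile with an identity model, blend back."""
    stride, n_patches, pad = tile_layout(len(signal), patch_size, overlap)
    padded = np.pad(signal, (0, pad), mode="reflect")

    # Hann window, floored so that every sample keeps a strictly positive weight.
    window = np.maximum(np.hanning(patch_size), 1e-3)

    canvas = np.zeros_like(padded)
    weights = np.zeros_like(padded)
    for i in range(n_patches):
        start = i * stride
        tile = padded[start:start + patch_size]        # identity "model" output
        canvas[start:start + patch_size] += tile * window
        weights[start:start + patch_size] += window

    # Weighted average in the overlaps, then crop the padding off again.
    return (canvas / weights)[:len(signal)]

signal = np.linspace(-1.0, 1.0, 50)
stride, n_patches, pad = tile_layout(len(signal), patch_size=16, overlap=8)
print(stride, n_patches, pad)   # 8 6 6: six patches cover 5 * 8 + 16 = 56 samples
restored = blend_1d(signal)
print(np.allclose(restored, signal))
```

With an identity model the blended result reproduces the input, which is a quick sanity check that padding, accumulation, normalization and cropping are consistent with each other. With a real model the overlapping predictions differ slightly, and the weighted average is what hides the seams.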

In [None]:
# --- 20 Full Image Inference Functions

def predict_full_image_tiled(model, lr_image_norm, patch_size=128, overlap=32, batch_size=8):
    """
    model: the trained super-resolution model.
    lr_image_norm: the low-resolution input image normalized to [-1, 1].
    patch_size: the size of LR patches to predict on.
    overlap: the overlap between patches.
    batch_size: batch size for model prediction.
    """
    # 20.1 Get image dimensions and calculate stride
    lr_h, lr_w, C = lr_image_norm.shape
    stride = patch_size - overlap

    # 20.2 Compute the padding needed so that the patch grid (patch_size + k * stride) fits exactly
    # This ensures that the entire image is covered by patches.
    n_patches_h = math.ceil(max(0, lr_h - patch_size) / stride) + 1
    n_patches_w = math.ceil(max(0, lr_w - patch_size) / stride) + 1
    pad_h = (n_patches_h - 1) * stride + patch_size - lr_h
    pad_w = (n_patches_w - 1) * stride + patch_size - lr_w
    
    # 20.3 Apply reflection padding to minimize edge artifacts
    lr_padded = np.pad(lr_image_norm, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')

    # 20.4 Extract patches
    # The input image is expanded to a 4D tensor (batch, height, width, channels)
    lr_padded_tensor = tf.convert_to_tensor(lr_padded[None, ...], dtype=tf.float32)  # shape (1, H, W, C)
    patches = tf.image.extract_patches(
        images=lr_padded_tensor,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, stride, stride, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'
    )
    # Reshape the extracted patches into a batch that the model can process
    patches = tf.reshape(patches, [-1, patch_size, patch_size, C])
    patches = tf.cast(patches, tf.float32)

    # 20.5 Run model prediction on the batch of patches (already in  [-1, 1] range)
    preds_norm = model.predict(patches, batch_size=batch_size, verbose=0)   # Output is [-1, 1]

    # 20.6 Reconstruct the full HR image from the predicted HR patches
    patch_hr_size = patch_size * UPSCALE_FACTOR
    stride_hr = stride * UPSCALE_FACTOR
    hr_h, hr_w = lr_padded.shape[0] * UPSCALE_FACTOR, lr_padded.shape[1] * UPSCALE_FACTOR
    
    # Create an empty canvas for the final image and a weight map for blending
    final_hr = np.zeros((hr_h, hr_w, C), dtype=np.float32)
    weight_map = np.zeros_like(final_hr)

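    # 20.7 Create a 2D Hann window for smooth blending
    # The window gives more weight to the pixels in the center of a patch and less to the edges.
    hann_1d = np.hanning(patch_hr_size)
    # np.hanning is exactly 0 at its endpoints, so the outermost rows/columns of the canvas
    # would get zero weight and come out as 0 (mid-gray after denormalization).
    # A small floor keeps every weight strictly positive.
    window_2d = np.maximum(np.outer(hann_1d, hann_1d), 1e-3)[..., None]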

    # 20.8 Loop through patches, apply the window, and add them to the canvas
    idx = 0
    for iy in range(n_patches_h):
        for ix in range(n_patches_w):
            y, x = iy * stride_hr, ix * stride_hr
            # Add the predicted patch, weighted by the Hann window
            final_hr[y:y+patch_hr_size, x:x+patch_hr_size, :] += preds_norm[idx] * window_2d
            # Add the window itself to the weight map
            weight_map[y:y+patch_hr_size, x:x+patch_hr_size, :] += window_2d
            idx += 1

    # 20.9 Normalize the reconstructed image by the weight map
    # This averages the overlapping regions weighted by the Hann window.
    final_hr = final_hr / (weight_map + 1e-8)
    
    # 20.10 Crop the image back to its original size, removing the padding
    orig_hr_h, orig_hr_w = lr_h * UPSCALE_FACTOR, lr_w * UPSCALE_FACTOR
    final_hr_cropped = final_hr[:orig_hr_h, :orig_hr_w, :]

    return final_hr_cropped # The output is still in [-1, 1] range

## Full Image Visualization

In [None]:
# --- 21 Final Full Image Visualization

# 21.1 Loaded model group
models_to_visualize = {
    "U-Net MAE": model_mae,
    "U-Net Perceptual": model_perceptual,
    "SRRGAN": model_gan
}

# 21.2 Select random images
num_images_to_show = 10 
sample_paths = random.sample(test_paths, num_images_to_show)

# 21.3 Create a figure and a grid of subplots
fig, axes = plt.subplots(
    nrows=num_images_to_show,
    ncols=5,
    figsize=(25, 5 * num_images_to_show) 
)

if num_images_to_show == 1:
    axes = np.array([axes])

# 21.4 Loop for prediction and plotting 
for i, sample_path_full in enumerate(sample_paths):
    print(f"\n--- Processing image {i+1}/{num_images_to_show}: {sample_path_full.split('/')[-1]} ---")

    # 21.5 Load and prepare the HR image
    hr_full_original = tf.image.convert_image_dtype(tf.io.decode_png(tf.io.read_file(sample_path_full), channels=CH), tf.float32).numpy()
    hr_full_tensor = tf.convert_to_tensor(hr_full_original, dtype=tf.float32)

    # Create the LR input, preserving the aspect ratio
    lr_full_input_degraded = degrade_full_image(hr_full_tensor).numpy()

    lr_model_input = (lr_full_input_degraded * 2.0) - 1.0

    # 21.6 Generate super-resolved images
    sr_images = {}
    for name, model in models_to_visualize.items():
        sr_image_norm = predict_full_image_tiled(model, lr_model_input, patch_size=PATCH_LR_W, overlap=32, batch_size=BATCH_SIZE)
        sr_images[name] = np.clip((sr_image_norm + 1.0) / 2.0, 0.0, 1.0).astype('float32')
    
    # 21.7 Plotting the images
    axes[i, 0].imshow(lr_full_input_degraded.astype('float32'))
    
    for j, (name, sr_img) in enumerate(sr_images.items()):
        axes[i, j + 1].imshow(sr_img)

    axes[i, 4].imshow(hr_full_original.astype('float32'))

# 21.8 Set titles and turn off axes
for i in range(num_images_to_show):
    for j in range(5):
        axes[i, j].axis("off")
        if i == 0:
            if j == 0:
                axes[i, j].set_title("Degraded LR Input", fontsize=20)
            elif j == 4:
                axes[i, j].set_title("Original High-Res (GT)", fontsize=20)
            else:
                model_name = list(models_to_visualize.keys())[j-1]
                axes[i, j].set_title(model_name, fontsize=20)

plt.tight_layout(pad=0.1, w_pad=0.5, h_pad=0.5)
plt.savefig("model_comparison_FULL_final.png", dpi=300, bbox_inches='tight') 
plt.show()

## Comparison with Bilinear Upsampling

In [None]:
# --- Qualitative Comparison
# Visualize several test images side by side: degraded LR input, bilinear upsampling, SRRGAN output and ground truth

num_images_to_show = 10 
models_to_visualize = {"SRRGAN": model_gan} 

# Figure
fig, axes = plt.subplots(
    nrows=num_images_to_show,
    ncols=4,
    figsize=(20, 5 * num_images_to_show)
)

# Ensure axes is a 2D array even if num_images_to_show is 1, for consistent indexing
if num_images_to_show == 1:
    axes = np.array([axes])

# Select random sample
sample_paths = random.sample(test_paths, num_images_to_show)

for i, sample_path in enumerate(sample_paths):
    print(f"--- Generating comparison for image {i+1}/{num_images_to_show}: {os.path.basename(sample_path)} ---")

    # 1 Original HR Image (float32 - [0, 1])
    hr_original_tensor = tf.image.convert_image_dtype(
        tf.io.decode_png(tf.io.read_file(sample_path), channels=CH),
        tf.float32
    )
    hr_original_np = hr_original_tensor.numpy()
    hr_shape = tf.shape(hr_original_tensor)[:2] 

    # 2 Low-Resolution Degraded Image ([0, 1] range)
    lr_degraded_tensor = degrade_full_image(hr_original_tensor)
    lr_degraded_np = lr_degraded_tensor.numpy()

    # 3 Create a simple Bilinear Upsampled Version
    # Resize the LR image back up to the original HR dimensions using bilinear interpolation
    hr_bilinear_tensor = tf.image.resize(
        lr_degraded_tensor,
        hr_shape,
        method=tf.image.ResizeMethod.BILINEAR
    )
    # Clip values to the valid [0, 1] range for plotting (kept as a NumPy array)
    hr_bilinear_np = np.clip(hr_bilinear_tensor.numpy(), 0.0, 1.0)

    # 4 GAN Inference
    # 4.1 Normalize the LR image from [0, 1] to [-1, 1] for the model input
    lr_model_input = (lr_degraded_np * 2.0) - 1.0
    # 4.2 Predict full images
    sr_gan_norm_np = predict_full_image_tiled(
        model_gan,
        lr_model_input,
        patch_size=PATCH_LR_W,    
        overlap=32,               
        batch_size=BATCH_SIZE     
    )
    # 4.3 Denormalize the GAN output from [-1, 1] back to [0, 1] for plotting
    sr_gan_np = np.clip((sr_gan_norm_np + 1.0) / 2.0, 0, 1)

    
    # 5 Plot results 
    
    # Low-Resolution Degraded Input
    axes[i, 0].imshow(lr_degraded_np)
    axes[i, 0].axis('off')

    # Simple Bilinear Upsample
    axes[i, 1].imshow(hr_bilinear_np)
    axes[i, 1].axis('off')

    # GAN Inference Output
    axes[i, 2].imshow(sr_gan_np)
    axes[i, 2].axis('off')

    # Original High-Resolution (Ground Truth)
    axes[i, 3].imshow(hr_original_np)
    axes[i, 3].axis('off')

    # Titles for the first row of images
    if i == 0:
        axes[i, 0].set_title(f"Low-Res Degraded", fontsize=16)
        axes[i, 1].set_title(f"Upsampled (Bilinear)", fontsize=16)
        axes[i, 2].set_title(f"SRRGAN Inference", fontsize=16)
        axes[i, 3].set_title(f"Original HR (GT)", fontsize=16)

# Layout 
plt.tight_layout(pad=0.5)

# Save
plt.savefig("final_model_comparison.png", dpi=300, bbox_inches='tight')

# Display 
plt.show()