In [20]:
import os
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from keras import Model
from keras.optimizers import Adam
from keras.models import load_model
from keras.applications import VGG19
from sklearn.model_selection import train_test_split
from keras.applications.vgg19 import preprocess_input
from keras.layers import (
    Input, 
    Conv2D, 
    Add, 
    Lambda, 
    Concatenate, 
    LeakyReLU, 
    BatchNormalization, 
    GlobalAveragePooling2D, 
    Dense
)
from keras.backend import eval, shape, mean, square, binary_crossentropy

In [None]:
def load_dataset_as_patches(hr_root, lr_root, patch_size_lr=48, stride=24, scale_factor=2, max_patches_per_image=None):
    """
    Loads HR and LR image patches from separate folders for EDSR training.
    Extracts patches from both LR and HR images maintaining the scale factor relationship.

    Parameters:
        hr_root (str): Root path to HR images.
        lr_root (str): Root path to LR images.
        patch_size_lr (int): Size of LR patches (HR patches will be patch_size_lr * scale_factor).
        scale_factor (int): The scale factor (2, 3, or 4).
        stride (int): Stride for patch extraction.
        max_patches_per_image (int): Maximum patches to extract per image (None for all).

    Returns:
        X (np.ndarray): Low-resolution patches (model input).
        Y (np.ndarray): High-resolution patches (target).
    """
    
    def add_padding(image: np.ndarray, patch_size, stride):
        """Add padding to ensure full coverage."""
        
        h, w, c = image.shape
        
        # Calcular cuánto padding se necesita
        pad_h = (patch_size - (h % stride)) % stride if h % stride != 0 else 0
        pad_w = (patch_size - (w % stride)) % stride if w % stride != 0 else 0
        
        # Agregar padding extra para asegurar cobertura completa
        pad_h = max(pad_h, patch_size - stride)
        pad_w = max(pad_w, patch_size - stride)
        
        # Padding reflejado (mirror) para mantener continuidad
        padded_img = np.pad(
            image, 
            ((0, pad_h), (0, pad_w), (0, 0)), 
            mode='reflect'
        )
        
        return padded_img
    
    if not os.path.exists(hr_root) or not os.path.exists(lr_root):
        raise ValueError("Both HR and LR root directories must exist.")
    
    patch_size_hr = patch_size_lr * scale_factor
    X, Y = [], []

    def get_all_image_paths(root):
        image_paths = []
        for dirpath, _, filenames in os.walk(root):
            for filename in filenames:
                if filename.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".tiff")):
                    image_paths.append(os.path.join(dirpath, filename))
        return sorted(image_paths)

    hr_paths = get_all_image_paths(hr_root)
    lr_paths = get_all_image_paths(lr_root)

    # Match HR and LR images by filename
    hr_dict = {os.path.basename(p): p for p in hr_paths}
    lr_dict = {os.path.basename(p): p for p in lr_paths}
    common_filenames = sorted(set(hr_dict.keys()) & set(lr_dict.keys()))

    if not common_filenames:
        raise ValueError("No matching filenames found between HR and LR directories.")

    total_patches = 0
    
    for fname in common_filenames:
        hr_img = cv2.imread(hr_dict[fname], cv2.IMREAD_COLOR)
        lr_img = cv2.imread(lr_dict[fname], cv2.IMREAD_COLOR)

        if hr_img is None or lr_img is None:
            continue

        # Convert to RGB and normalize
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        lr_img = cv2.cvtColor(lr_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

        lr_h, lr_w, _ = lr_img.shape
        hr_h, hr_w, _ = hr_img.shape
        
        # Add padding to ensure full coverage
        hr_img = add_padding(hr_img, patch_size_hr, stride)
        lr_img = add_padding(lr_img, patch_size_lr, stride)

        # Generate patches
        patches_this_image = 0
        
        for i in range(0, lr_h - patch_size_lr + 1, stride):
            for j in range(0, lr_w - patch_size_lr + 1, stride):
                # Extract LR patch
                lr_patch = lr_img[i:i+patch_size_lr, j:j+patch_size_lr]
                
                # Extract corresponding HR patch
                hr_i = i * scale_factor
                hr_j = j * scale_factor
                
                if hr_i + patch_size_hr <= hr_h and hr_j + patch_size_hr <= hr_w:
                    hr_patch = hr_img[hr_i:hr_i+patch_size_hr, hr_j:hr_j+patch_size_hr]
                    
                    X.append(lr_patch)
                    Y.append(hr_patch)
                    
                    patches_this_image += 1
                    total_patches += 1
                    
                    if max_patches_per_image and patches_this_image >= max_patches_per_image:
                        break
            
            if max_patches_per_image and patches_this_image >= max_patches_per_image:
                break

    if not X:
        raise ValueError("No patches could be extracted. Check your patch size and image dimensions.")

    X_array = np.array(X)
    Y_array = np.array(Y)
    
    print(f"Extracted {total_patches} patch pairs from {len(common_filenames)} images")
    print(f"LR patches shape: {X_array.shape}")
    print(f"HR patches shape: {Y_array.shape}")

    return X_array, Y_array

In [None]:
class ESRGAN:
    """
    Enhanced Super-Resolution Generative Adversarial Network (ESRGAN) implementation.
    
    This class implements the ESRGAN architecture for image super-resolution,
    including the generator (RRDBNet), discriminator (VGG-style), and training logic.
    """
    
    def __init__(self):
        """
        Initialize ESRGAN model.
        
        Args:
            num_rrdb_blocks: Number of Residual-in-Residual Dense Blocks
        """
        
        # Initialize models
        self.generator = None
        self.discriminator = None
        self.vgg_model = None
        
        # Training parameters
        self.g_optimizer = None
        self.d_optimizer = None
        
    def setup_model(
            self, 
            scale_factor=2, 
            growth_channels=32, 
            num_rrdb_blocks=23, 
            input_shape=(None, None, 3),
            output_shape=(None, None, 3),
            from_trained=False, 
            generator_pretrained_path=None, 
            discriminator_pretrained_path=None):
        """
        Setup the ESRGAN models either from scratch or from pretrained weights.
        
        Args:
            scale_factor: Upscaling factor (2, 4, or 8)
            growth_channels: Number of growth channels in dense blocks
            lr_size: Low resolution image size
            hr_size: High resolution image size (calculated if None)
            channels: Number of image channels
            from_trained: If True, load pretrained models
            generator_pretrained_path: Path to pretrained generator model
            discriminator_pretrained_path: Path to pretrained discriminator model
        """
        
        if from_trained:
            # Check if paths exist
            if generator_pretrained_path is None or not os.path.exists(generator_pretrained_path):
                raise FileNotFoundError(f"Generator pretrained path does not exist: {generator_pretrained_path}")
            if discriminator_pretrained_path is None or not os.path.exists(discriminator_pretrained_path):
                raise FileNotFoundError(f"Discriminator pretrained path does not exist: {discriminator_pretrained_path}")
            
            # Load pretrained models
            self.generator = load_model(generator_pretrained_path)
            self.discriminator = load_model(discriminator_pretrained_path)
            self.vgg_model = self._build_vgg_model(output_shape)
            
            print(f"- Generator loaded from: {generator_pretrained_path}")
            print(f"- Discriminator loaded from: {discriminator_pretrained_path}")
            print("- VGG model built for perceptual loss")
        else:
            self.generator = self._build_generator(input_shape, scale_factor, growth_channels, num_rrdb_blocks)
            self.discriminator = self._build_discriminator(output_shape)
            self.vgg_model = self._build_vgg_model(output_shape)

            self._compile_models()
        
    def _compile_models(self):
        """
        Compile the models with optimizers.
        
        Args:
            g_lr: Generator learning rate
            d_lr: Discriminator learning rate
            beta_1: Beta1 parameter for Adam optimizer
            beta_2: Beta2 parameter for Adam optimizer
        """
        
        self.g_optimizer = Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999)
        self.d_optimizer = Adam(learning_rate=1e-5, beta_1=0.9, beta_2=0.999)
        
        print("=" * 50)
        print("GENERATOR SUMMARY")
        print("=" * 50)
        self.generator.summary()
        
        print("\n" + "=" * 50)
        print("DISCRIMINATOR SUMMARY")
        print("=" * 50)
        self.discriminator.summary()
        
        print("\n" + "=" * 50)
        print("VGG FEATURE EXTRACTOR SUMMARY")
        print("=" * 50)
        self.vgg_model.summary()
        
    def _dense_block(self, x, growth_rate, name="dense_block"):
        """
        Create a dense block.
        
        Args:
            x: Input tensor
            growth_rate: Number of filters to add per layer
            name: Block name
            
        Returns:
            Output tensor
        """
        
        # Store input for skip connection
        input_tensor = x
        input_channels = x.shape[-1]
        
        # First conv layer
        x1 = Conv2D(growth_rate, 3, padding="same", activation="relu", name=f"{name}_conv1")(x)
        x1_concat = Concatenate(name=f"{name}_concat1")([x, x1])
        
        # Second conv layer
        x2 = Conv2D(growth_rate, 3, padding="same", activation="relu", name=f"{name}_conv2")(x1_concat)
        x2_concat = Concatenate(name=f"{name}_concat2")([x, x1, x2])
        
        # Third conv layer
        x3 = Conv2D(growth_rate, 3, padding="same", activation="relu", name=f"{name}_conv3")(x2_concat)
        x3_concat = Concatenate(name=f"{name}_concat3")([x, x1, x2, x3])
        
        # Fourth conv layer
        x4 = Conv2D(growth_rate, 3, padding="same", activation="relu", name=f"{name}_conv4")(x3_concat)
        x4_concat = Concatenate(name=f"{name}_concat4")([x, x1, x2, x3, x4])
        
        # Fifth conv layer (output layer)
        x5 = Conv2D(input_channels, 3, padding="same", name=f"{name}_conv5")(x4_concat)
        
        # Residual scaling
        x5 = Lambda(lambda t: t * 0.2, name=f"{name}_scale")(x5)
        
        # Skip connection
        output = Add(name=f"{name}_add")([input_tensor, x5])
        
        return output
    
    def _rrdb_block(self, x, growth_channels, name="rddb"):
        """
        Create a Residual-in-Residual Dense Block (RRDB).
        
        Args:
            x: Input tensor
            growth_channels: Number of growth channels
            name: Block name
            
        Returns:
            Output tensor
        """
        
        input_tensor = x
        
        # Three dense blocks
        x = self._dense_block(x, growth_channels, f"{name}_dense1")
        x = self._dense_block(x, growth_channels, f"{name}_dense2")
        x = self._dense_block(x, growth_channels, f"{name}_dense3")
        
        # Residual scaling
        x = Lambda(lambda t: t * 0.2, name=f"{name}_scale")(x)
        
        # Skip connection
        output = Add(name=f"{name}_add")([input_tensor, x])
        
        return output
    
    def _upsample_block(self, x, filters, name="upsample"):
        """
        Create an upsampling block using sub-pixel convolution.
        
        Args:
            x: Input tensor
            filters: Number of filters
            name: Block name
            
        Returns:
            Upsampled tensor
        """
        
        x = Conv2D(filters * 4, 3, padding="same", name=f"{name}_conv")(x)
        x = Lambda(lambda t: tf.nn.depth_to_space(t, 2), name=f"{name}_pixelshuffle")(x)
        x = LeakyReLU(alpha=0.2, name=f"{name}_leaky")(x)
        
        return x
    
    def _build_generator(self, input_shape, scale_factor, growth_channels, num_rrdb_blocks):
        """
        Build the generator network (RRDBNet).
        
        Returns:
            Generator model
        """
        
        inputs = Input(shape=input_shape, name="lr_input")
        
        # Initial convolution
        x = Conv2D(64, 3, padding="same", name="initial_conv")(inputs)
        trunk_output = x
        
        # RRDB blocks
        for i in range(num_rrdb_blocks):
            x = self._rrdb_block(x, growth_channels, f"rrdb_{i}")
        
        # Trunk convolution
        x = Conv2D(64, 3, padding="same", name="trunk_conv")(x)
        
        # Trunk connection
        x = Add(name="trunk_add")([trunk_output, x])
        
        # Upsampling blocks
        num_upsample = int(np.log2(scale_factor))
        for i in range(num_upsample):
            x = self._upsample_block(x, 64, f"upsample_{i}")
        
        # Final convolution layers
        x = Conv2D(64, 3, padding="same", activation="relu", name="final_conv1")(x)
        outputs = Conv2D(inputs.shape[-1], 3, padding="same", activation="tanh", name="final_conv2")(x)
        
        model = Model(inputs=inputs, outputs=outputs, name="Generator")
        
        return model
    
    def _build_discriminator(self, output_shape):
        """
        Build the discriminator network (VGG-style).
        
        Returns:
            Discriminator model
        """
        
        inputs = Input(shape=output_shape, name="hr_input")
        
        # Initial convolution
        x = Conv2D(64, 3, padding="same", name="disc_conv1")(inputs)
        x = LeakyReLU(alpha=0.2, name="disc_leaky1")(x)
        
        # Convolutional blocks
        filters = [64, 128, 128, 256, 256, 512, 512]
        strides = [2, 1, 2, 1, 2, 1, 2]
        
        for i, (f, s) in enumerate(zip(filters, strides)):
            x = Conv2D(f, 3, strides=s, padding="same", name=f"disc_conv{i+2}")(x)
            x = BatchNormalization(name=f"disc_bn{i+2}")(x)
            x = LeakyReLU(alpha=0.2, name=f"disc_leaky{i+2}")(x)
        
        # Global average pooling and dense layers
        x = GlobalAveragePooling2D(name="disc_gap")(x)
        x = Dense(1024, name="disc_dense1")(x)
        x = LeakyReLU(alpha=0.2, name="disc_leaky_dense1")(x)
        outputs = Dense(1, activation="sigmoid", name="disc_output")(x)
        
        model = Model(inputs=inputs, outputs=outputs, name="Discriminator")
        
        return model
    
    def _build_vgg_model(self, output_shape):
        """
        Build VGG model for perceptual loss.
        
        Returns:
            VGG model for feature extraction
        """
        
        vgg = VGG19(
            include_top=False, 
            weights="imagenet", 
            input_shape=output_shape
        )
        
        # Extract features from conv5_4 layer
        vgg.trainable = False
        outputs = vgg.get_layer("block5_conv4").output
        
        model = Model(inputs=vgg.input, outputs=outputs, name="VGG_Feature_Extractor")
        
        return model
    
    def _preprocess_vgg_input(self, x):
        """Preprocess input for VGG model."""
        
        # Convert from [-1, 1] to [0, 255]
        x = (x + 1) * 127.5
        
        # Apply VGG preprocessing
        return preprocess_input(x)
    
    def _perceptual_loss(self, hr_real, hr_fake):
        """
        Calculate perceptual loss using VGG features.
        
        Args:
            hr_real: Real high-resolution images
            hr_fake: Generated high-resolution images
            
        Returns:
            Perceptual loss
        """
        
        # Preprocess inputs for VGG
        hr_real_vgg = self._preprocess_vgg_input(hr_real)
        hr_fake_vgg = self._preprocess_vgg_input(hr_fake)
        
        # Extract features
        real_features = self.vgg_model(hr_real_vgg)
        fake_features = self.vgg_model(hr_fake_vgg)
        
        # Calculate MSE loss
        return mean(square(real_features - fake_features))
    
    def _pixel_loss(self, hr_real, hr_fake):
        """
        Calculate pixel-wise L1 loss.
        
        Args:
            hr_real: Real high-resolution images
            hr_fake: Generated high-resolution images
            
        Returns:
            Pixel loss
        """
        
        return mean(abs(hr_real - hr_fake))
    
    def _adversarial_loss(self, y_true, y_pred):
        """
        Calculate adversarial loss.
        
        Args:
            y_true: True labels
            y_pred: Predicted labels
            
        Returns:
            Adversarial loss
        """
        
        return mean(binary_crossentropy(y_true, y_pred))
    
    def _train_step(self, lr_images, hr_images):
        """
        Perform one training step.
        
        Args:
            lr_images: Low-resolution images
            hr_images: High-resolution images
            
        Returns:
            Dictionary containing losses
        """
        
        # Train discriminator
        with tf.GradientTape() as d_tape:
            # Generate fake images
            hr_fake = self.generator(lr_images, training=True)
            
            # Discriminator predictions
            d_real = self.discriminator(hr_images, training=True)
            d_fake = self.discriminator(hr_fake, training=True)
            
            # Discriminator losses
            d_loss_real = self._adversarial_loss(tf.ones_like(d_real), d_real)
            d_loss_fake = self._adversarial_loss(tf.zeros_like(d_fake), d_fake)
            d_loss = d_loss_real + d_loss_fake
        
        # Update discriminator
        d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables)
        self.d_optimizer.apply_gradients(zip(d_grads, self.discriminator.trainable_variables))
        
        # Train generator
        with tf.GradientTape() as g_tape:
            # Generate fake images
            hr_fake = self.generator(lr_images, training=True)
            
            # Discriminator prediction for fake images
            d_fake = self.discriminator(hr_fake, training=True)
            
            # Generator losses
            g_adversarial_loss = self._adversarial_loss(tf.ones_like(d_fake), d_fake)
            g_perceptual_loss = self._perceptual_loss(hr_images, hr_fake)
            g_pixel_loss = self._pixel_loss(hr_images, hr_fake)
            
            # Combined generator loss
            g_loss = g_adversarial_loss + 0.01 * g_perceptual_loss + 0.01 * g_pixel_loss
        
        # Update generator
        g_grads = g_tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(g_grads, self.generator.trainable_variables))
        
        return {
            "g_loss": g_loss,
            "g_adversarial_loss": g_adversarial_loss,
            "g_perceptual_loss": g_perceptual_loss,
            "g_pixel_loss": g_pixel_loss,
            "d_loss": d_loss,
            "d_loss_real": d_loss_real,
            "d_loss_fake": d_loss_fake
        }
    
    def fit(
        self, 
        dataset, 
        number_of_steps=None, 
        epochs=100):
        """
        Train the ESRGAN model.
        
        Args:
            dataset: Training dataset (tf.data.Dataset)
            epochs: Number of training epochs
        """
        
        # Print device info
        physical_devices = tf.config.list_physical_devices('GPU')
        if physical_devices:
            print(f"Training on GPU: {[d.name for d in physical_devices]}")
        else:
            print("Training on CPU")
                
        # Training loop
        for epoch in range(epochs):
            print(f"Epoch {epoch + 1}/{epochs}")
            
            # Initialize metrics
            epoch_losses = {
                "g_loss": [],
                "g_adversarial_loss": [],
                "g_perceptual_loss": [],
                "g_pixel_loss": [],
                "d_loss": [],
                "d_loss_real": [],
                "d_loss_fake": [], 
                "psnr": [], 
                "ssim": []
            }
            
            # Training step
            for step, (lr_batch, hr_batch) in enumerate(dataset):
                losses = self._train_step(lr_batch, hr_batch)
                
                # Accumulate losses
                for key, value in losses.items():
                    epoch_losses[key].append(eval(value))
                    
                # Calculate PSNR and SSIM for this batch
                hr_fake = self.generator(lr_batch, training=False)
                hr_real_eval = (hr_batch + 1.0) / 2.0
                hr_gen_eval = (hr_fake + 1.0) / 2.0
                psnr_score = tf.image.psnr(hr_real_eval, hr_gen_eval, max_val=1.0)
                ssim_score = tf.image.ssim(hr_real_eval, hr_gen_eval, max_val=1.0)
                epoch_losses["psnr"].append(eval(tf.reduce_mean(psnr_score)))
                epoch_losses["ssim"].append(eval(tf.reduce_mean(ssim_score)))
                
                # Print progress
                if number_of_steps:
                    print(
                        f"- Step {step+1}/{number_of_steps}: G_loss={eval(losses['g_loss']):.4f}, "
                        f"D_loss={eval(losses['d_loss']):.4f}, "
                        f"PSNR={epoch_losses['psnr'][-1]:.2f}, SSIM={epoch_losses['ssim'][-1]:.4f}"
                    )
                else:
                    print(
                        f"- Step {step}: G_loss={eval(losses['g_loss']):.4f}, "
                        f"D_loss={eval(losses['d_loss']):.4f}, "
                        f"PSNR={epoch_losses['psnr'][-1]:.2f}, SSIM={epoch_losses['ssim'][-1]:.4f}"
                    )
                    
            # Print epoch summary
            avg_losses = {key: np.mean(values) for key, values in epoch_losses.items()}
            print(f"  Epoch Summary - G_loss: {avg_losses['g_loss']:.4f}, "
                f"D_loss: {avg_losses['d_loss']:.4f}, "
                f"PSNR: {avg_losses['psnr']:.2f}, SSIM: {avg_losses['ssim']:.4f}")
                    
    def evaluate(self, test_dataset):
        """
        Evaluate the trained model with a test dataset.
        
        Args:
            test_dataset: Test dataset (tf.data.Dataset)
            
        Returns:
            Dictionary containing evaluation metrics
        """
        
        print("Evaluating model on test dataset...")
        
        # Initialize metrics
        total_psnr = 0.0
        total_ssim = 0.0
        total_pixel_loss = 0.0
        total_perceptual_loss = 0.0
        num_batches = 0
        
        for lr_batch, hr_batch in test_dataset:
            # Generate high-resolution images
            hr_generated = self.generator(lr_batch, training=False)
            
            # Calculate metrics for each image in the batch
            batch_size = shape(lr_batch)[0]
            
            # Pixel loss
            pixel_loss = self._pixel_loss(hr_batch, hr_generated)
            total_pixel_loss += eval(pixel_loss)
            
            # Perceptual loss
            perceptual_loss = self._perceptual_loss(hr_batch, hr_generated)
            total_perceptual_loss += eval(perceptual_loss)
            
            # Convert to [0, 1] range for PSNR and SSIM
            hr_real_eval = (hr_batch + 1.0) / 2.0
            hr_gen_eval = (hr_generated + 1.0) / 2.0
            
            # Calculate PSNR and SSIM
            psnr_score = tf.image.psnr(hr_real_eval, hr_gen_eval, max_val=1.0)
            ssim_score = tf.image.ssim(hr_real_eval, hr_gen_eval, max_val=1.0)
            
            total_psnr += eval(mean(psnr_score))
            total_ssim += eval(mean(ssim_score))
            
            num_batches += 1
        
        # Calculate averages
        avg_psnr = total_psnr / num_batches
        avg_ssim = total_ssim / num_batches
        avg_pixel_loss = total_pixel_loss / num_batches
        avg_perceptual_loss = total_perceptual_loss / num_batches
        
        metrics = {
            "avg_psnr": avg_psnr,
            "avg_ssim": avg_ssim,
            "avg_pixel_loss": avg_pixel_loss,
            "avg_perceptual_loss": avg_perceptual_loss
        }
        
        print(f"Evaluation Results:")
        print(f"  Average PSNR: {avg_psnr:.4f}")
        print(f"  Average SSIM: {avg_ssim:.4f}")
        print(f"  Average Pixel Loss: {avg_pixel_loss:.4f}")
        print(f"  Average Perceptual Loss: {avg_perceptual_loss:.4f}")
        
        return metrics

In [4]:
PATCH_SIZE_LR = 24
STRIDE = 12
SCALE_FACTOR = 2

In [5]:
X, Y = load_dataset_as_patches("../../data/images/HR", "../../data/images/LR", patch_size_lr=PATCH_SIZE_LR, stride=STRIDE, scale_factor=SCALE_FACTOR)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, shuffle=True, random_state=42)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.1, shuffle=True, random_state=42)

print(f"X_train shape: {X_train.shape}, Y_train shape: {Y_train.shape}")
print(f"X_val shape: {X_val.shape}, Y_val shape: {Y_val.shape}")
print(f"X_test shape: {X_test.shape}, Y_test shape: {Y_test.shape}")

Extracted 23976 patch pairs from 74 images
LR patches shape: (23976, 24, 24, 3)
HR patches shape: (23976, 48, 48, 3)
X_train shape: (19420, 24, 24, 3), Y_train shape: (19420, 48, 48, 3)
X_val shape: (2158, 24, 24, 3), Y_val shape: (2158, 48, 48, 3)
X_test shape: (2398, 24, 24, 3), Y_test shape: (2398, 48, 48, 3)


In [6]:
# Use a generator to avoid memory issues
def data_generator(X, Y):
    for x, y in zip(X, Y):
        yield x, y

output_signature = (
    tf.TensorSpec(shape=X_train.shape[1:], dtype=tf.float32),
    tf.TensorSpec(shape=Y_train.shape[1:], dtype=tf.float32)
)

BATCH_SIZE = 16

train_dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(X_train, Y_train),
    output_signature=output_signature
).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

val_dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(X_val, Y_val),
    output_signature=output_signature
).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

test_dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(X_test, Y_test),
    output_signature=output_signature
).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

In [7]:
model = ESRGAN()

model.setup_model(
    scale_factor=2, 
    growth_channels=32, 
    num_rrdb_blocks=23, 
    input_shape=X_train.shape[1:],
    output_shape=Y_train.shape[1:],
    from_trained=False
)

GENERATOR SUMMARY
Model: "Generator"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 lr_input (InputLayer)          [(None, 24, 24, 3)]  0           []                               
                                                                                                  
 initial_conv (Conv2D)          (None, 24, 24, 64)   1792        ['lr_input[0][0]']               
                                                                                                  
 rrdb_0_dense1_conv1 (Conv2D)   (None, 24, 24, 32)   18464       ['initial_conv[0][0]']           
                                                                                                  
 rrdb_0_dense1_concat1 (Concate  (None, 24, 24, 96)  0           ['initial_conv[0][0]',           
 nate)                                                            'rrdb_

In [8]:
model.fit(train_dataset, len(X_train) // BATCH_SIZE, epochs=1)

Training on GPU: ['/physical_device:GPU:0']
Epoch 1/1
- Step 1/1213: G_loss=2.4271, D_loss=1.3982
- Step 2/1213: G_loss=1.6185, D_loss=1.3859
- Step 3/1213: G_loss=1.4286, D_loss=1.3781
- Step 4/1213: G_loss=1.3974, D_loss=1.3658
- Step 5/1213: G_loss=1.3354, D_loss=1.3695
- Step 6/1213: G_loss=1.3851, D_loss=1.3602
- Step 7/1213: G_loss=1.3088, D_loss=1.3259
- Step 8/1213: G_loss=1.2718, D_loss=1.3403
- Step 9/1213: G_loss=1.2334, D_loss=1.3221
- Step 10/1213: G_loss=1.2091, D_loss=1.3211
- Step 11/1213: G_loss=1.2834, D_loss=1.3104
- Step 12/1213: G_loss=1.2485, D_loss=1.2953
- Step 13/1213: G_loss=1.2126, D_loss=1.2752
- Step 14/1213: G_loss=1.2598, D_loss=1.2739
- Step 15/1213: G_loss=1.2090, D_loss=1.2808
- Step 16/1213: G_loss=1.1784, D_loss=1.2756
- Step 17/1213: G_loss=1.2191, D_loss=1.2747
- Step 18/1213: G_loss=1.1882, D_loss=1.2616
- Step 19/1213: G_loss=1.2298, D_loss=1.2466
- Step 20/1213: G_loss=1.2486, D_loss=1.2380
- Step 21/1213: G_loss=1.2753, D_loss=1.1850
- Step 22/

In [10]:
import datetime
os.makedirs("models", exist_ok=True)
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
generator_path = os.path.join("models", f"ESRGAN_generator_x{SCALE_FACTOR}_{timestamp}.h5")
discriminator_path = os.path.join("models", f"ESRGAN_discriminator_x{SCALE_FACTOR}_{timestamp}.h5")
model.generator.save(generator_path)
model.discriminator.save(discriminator_path)



In [22]:
pretrained_model = ESRGAN()

pretrained_model.setup_model(
    from_trained=True, 
    generator_pretrained_path="models/ESRGAN_generator_x2_20250627_164319.h5", 
    discriminator_pretrained_path="models/ESRGAN_discriminator_x2_20250627_164319.h5"
)

- Generator loaded from: models/ESRGAN_generator_x2_20250627_164319.h5
- Discriminator loaded from: models/ESRGAN_discriminator_x2_20250627_164319.h5
- VGG model built for perceptual loss


In [23]:
pretrained_model.evaluate(test_dataset)

Evaluating model on test dataset...
Evaluation Results:
  Average PSNR: 3.8115
  Average SSIM: 0.0027
  Average Pixel Loss: 1.0958
  Average Perceptual Loss: 118.6404


{'avg_psnr': 3.8114942800278633,
 'avg_ssim': 0.002731919829302386,
 'avg_pixel_loss': 1.0958046497114553,
 'avg_perceptual_loss': 118.64038116659894}