In [None]:
import os
import sys
import datetime

import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from keras import Model
from keras.optimizers import Adam
from keras.models import load_model
from keras.applications import VGG19
from sklearn.model_selection import train_test_split
from keras.applications.vgg19 import preprocess_input
from tensorflow_addons.layers import SpectralNormalization
from keras.layers import (
    Input, 
    Conv2D, 
    Add, 
    Lambda, 
    Concatenate, 
    LeakyReLU, 
    BatchNormalization, 
    GlobalAveragePooling2D, 
    Dense, 
    Layer
)
from keras.backend import eval, shape, mean, square, binary_crossentropy

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../../")))
from SRModels.metrics import psnr, ssim
from SRModels.loading_methods import load_dataset_as_patches
from SRModels.data_augmentation import AdvancedAugmentGenerator
from SRModels.constants import ESRGAN_PATCH_SIZE, ESRGAN_STRIDE, RANDOM_SEED, ESRGAN_SCALE_FACTOR


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 

 The versions of TensorFlow you are currently using is 2.10.0 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


In [2]:
class SelfAttention(Layer):
    """
    Self-Attention Layer for 2D feature maps.
    """
    
    def __init__(self, channels, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        
        self.channels = channels

    def build(self, input_shape):
        self.f = Conv2D(self.channels // 8, 1, padding='same', name=self.name + "_f")
        self.g = Conv2D(self.channels // 8, 1, padding='same', name=self.name + "_g")
        self.h = Conv2D(self.channels // 2, 1, padding='same', name=self.name + "_h")
        self.v = Conv2D(self.channels, 1, padding='same', name=self.name + "_v")
        
        super(SelfAttention, self).build(input_shape)

    def call(self, x):
        f = self.f(x)  # [B, H, W, C//8]
        g = self.g(x)  # [B, H, W, C//8]
        h = self.h(x)  # [B, H, W, C//2]

        shape_f = tf.shape(f)
        shape_g = tf.shape(g)
        shape_h = tf.shape(h)

        f_flat = tf.reshape(f, [shape_f[0], -1, shape_f[-1]])  # [B, HW, C//8]
        g_flat = tf.reshape(g, [shape_g[0], -1, shape_g[-1]])  # [B, HW, C//8]
        h_flat = tf.reshape(h, [shape_h[0], -1, shape_h[-1]])  # [B, HW, C//2]

        s = tf.matmul(g_flat, f_flat, transpose_b=True)  # [B, HW, HW]
        beta = tf.nn.softmax(s, axis=-1)  # attention map

        o = tf.matmul(beta, h_flat)  # [B, HW, C//2]
        o = tf.reshape(o, tf.shape(h))  # [B, H, W, C//2]
        o = self.v(o)  # [B, H, W, C]

        x = Add()([x, o])
        
        return x

In [None]:
class ESRGAN:
    """
    Enhanced Super-Resolution Generative Adversarial Network (ESRGAN) implementation.
    
    This class implements the ESRGAN architecture for image super-resolution,
    including the generator (RRDBNet), discriminator (VGG-style), and training logic.
    """
    
    def __init__(self):
        """
        Initialize ESRGAN model.
        
        Args:
            num_rrdb_blocks: Number of Residual-in-Residual Dense Blocks
        """
        
        # Initialize models
        self.generator = None
        self.discriminator = None
        self.vgg_model = None
        
        # Training parameters
        self.g_optimizer = None
        self.d_optimizer = None
        
        self.trained = False
        
    def setup_model(
            self, 
            scale_factor=2, 
            growth_channels=32, 
            num_rrdb_blocks=23, 
            input_shape=(None, None, 3),
            output_shape=(None, None, 3),
            from_trained=False, 
            generator_pretrained_path=None, 
            discriminator_pretrained_path=None):
        """
        Setup the ESRGAN models either from scratch or from pretrained weights.
        
        Args:
            scale_factor: Upscaling factor (2, 4, or 8)
            growth_channels: Number of growth channels in dense blocks
            lr_size: Low resolution image size
            hr_size: High resolution image size (calculated if None)
            channels: Number of image channels
            from_trained: If True, load pretrained models
            generator_pretrained_path: Path to pretrained generator model
            discriminator_pretrained_path: Path to pretrained discriminator model
        """
        
        if from_trained:
            # Check if paths exist
            if generator_pretrained_path is None or not os.path.exists(generator_pretrained_path):
                raise FileNotFoundError(f"Generator pretrained path does not exist: {generator_pretrained_path}")
            if discriminator_pretrained_path is None or not os.path.exists(discriminator_pretrained_path):
                raise FileNotFoundError(f"Discriminator pretrained path does not exist: {discriminator_pretrained_path}")
            
            # Load pretrained models
            self.generator = load_model(generator_pretrained_path)
            self.discriminator = load_model(discriminator_pretrained_path)
            self.vgg_model = self._build_vgg_model(output_shape)
            
            self.trained = True
            
            print(f"- Generator loaded from: {generator_pretrained_path}")
            print(f"- Discriminator loaded from: {discriminator_pretrained_path}")
            print("- VGG model built for perceptual loss")
        else:
            self.generator = self._build_generator(input_shape, scale_factor, growth_channels, num_rrdb_blocks)
            self.discriminator = self._build_discriminator(output_shape)
            self.vgg_model = self._build_vgg_model(output_shape)

            self._compile_models()
        
    def _compile_models(self):
        """
        Compile the models with optimizers.
        
        Args:
            g_lr: Generator learning rate
            d_lr: Discriminator learning rate
            
            beta_1: Beta1 parameter for Adam optimizer
            beta_2: Beta2 parameter for Adam optimizer
        """
        
        self.g_optimizer = Adam(
            learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=1e-4,
                decay_steps=10000,
                decay_rate=0.5,
                staircase=True
            ), 
            beta_1=0.9, 
            beta_2=0.999
        )
        self.d_optimizer = Adam(
            learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=1e-5,
                decay_steps=10000,
                decay_rate=0.5,
                staircase=True
            ), 
            beta_1=0.9, 
            beta_2=0.999
        )
        
        print("=" * 50)
        print("GENERATOR SUMMARY")
        print("=" * 50)
        self.generator.summary()
        
        print("\n" + "=" * 50)
        print("DISCRIMINATOR SUMMARY")
        print("=" * 50)
        self.discriminator.summary()
        
        print("\n" + "=" * 50)
        print("VGG FEATURE EXTRACTOR SUMMARY")
        print("=" * 50)
        self.vgg_model.summary()
        
    def _dense_block(self, x, growth_rate, name="dense_block"):
        """
        Create a dense block.
        
        Args:
            x: Input tensor
            growth_rate: Number of filters to add per layer
            name: Block name
            
        Returns:
            Output tensor
        """
        
        # Store input for skip connection
        input_tensor = x
        input_channels = x.shape[-1]
        
        # First conv layer
        x1 = Conv2D(growth_rate, 3, padding="same", activation="relu", name=f"{name}_conv1")(x)
        x1_concat = Concatenate(name=f"{name}_concat1")([x, x1])
        
        # Second conv layer
        x2 = Conv2D(growth_rate, 3, padding="same", activation="relu", name=f"{name}_conv2")(x1_concat)
        x2_concat = Concatenate(name=f"{name}_concat2")([x, x1, x2])
        
        # Third conv layer
        x3 = Conv2D(growth_rate, 3, padding="same", activation="relu", name=f"{name}_conv3")(x2_concat)
        x3_concat = Concatenate(name=f"{name}_concat3")([x, x1, x2, x3])
        
        # Fourth conv layer
        x4 = Conv2D(growth_rate, 3, padding="same", activation="relu", name=f"{name}_conv4")(x3_concat)
        x4_concat = Concatenate(name=f"{name}_concat4")([x, x1, x2, x3, x4])
        
        # Fifth conv layer (output layer)
        x5 = Conv2D(input_channels, 3, padding="same", name=f"{name}_conv5")(x4_concat)
        
        # Residual scaling
        x5 = Lambda(lambda t: t * 0.2, name=f"{name}_scale")(x5)
        
        # Skip connection
        output = Add(name=f"{name}_add")([input_tensor, x5])
        
        return output
    
    def _rrdb_block(self, x, growth_channels, name="rddb"):
        """
        Create a Residual-in-Residual Dense Block (RRDB).
        
        Args:
            x: Input tensor
            growth_channels: Number of growth channels
            name: Block name
            
        Returns:
            Output tensor
        """
        
        input_tensor = x
        
        # Three dense blocks
        x = self._dense_block(x, growth_channels, f"{name}_dense1")
        x = self._dense_block(x, growth_channels, f"{name}_dense2")
        x = self._dense_block(x, growth_channels, f"{name}_dense3")
        
        # Residual scaling
        x = Lambda(lambda t: t * 0.2, name=f"{name}_scale")(x)
        
        # Skip connection
        output = Add(name=f"{name}_add")([input_tensor, x])
        
        return output
    
    def _upsample_block(self, x, filters, name="upsample"):
        """
        Create an upsampling block using sub-pixel convolution.
        
        Args:
            x: Input tensor
            filters: Number of filters
            name: Block name
            
        Returns:
            Upsampled tensor
        """
        
        x = Conv2D(filters * 4, 3, padding="same", name=f"{name}_conv")(x)
        x = Lambda(lambda t: tf.nn.depth_to_space(t, 2), name=f"{name}_pixelshuffle")(x)
        x = LeakyReLU(alpha=0.2, name=f"{name}_leaky")(x)
        
        return x
    
    def _build_generator(self, input_shape, scale_factor, growth_channels, num_rrdb_blocks):
        """
        Build the generator network (RRDBNet).
        
        Returns:
            Generator model
        """
        
        inputs = Input(shape=input_shape, name="lr_input")
        
        # Initial convolution
        x = Conv2D(64, 3, padding="same", name="initial_conv")(inputs)
        trunk_output = x
        
        # RRDB blocks
        for i in range(num_rrdb_blocks):
            x = self._rrdb_block(x, growth_channels, f"rrdb_{i}")
        
        # Trunk convolution
        x = Conv2D(64, 3, padding="same", name="trunk_conv")(x)
        
        # Trunk connection
        x = Add(name="trunk_add")([trunk_output, x])
        
        # Self-Attention after RRDB trunk
        x = SelfAttention(64, name="self_attention_trunk")(x)
        
        # Upsampling blocks
        num_upsample = int(np.log2(scale_factor))
        for i in range(num_upsample):
            x = self._upsample_block(x, 64, f"upsample_{i}")
            
            # Self-Attention after first upsampling
            if i == 0:
                x = SelfAttention(64, name=f"self_attention_upsample_{i}")(x)
        
        # Final convolution layers
        x = Conv2D(64, 3, padding="same", activation="relu", name="final_conv1")(x)
        outputs = Conv2D(inputs.shape[-1], 3, padding="same", activation="tanh", name="final_conv2")(x)
        
        model = Model(inputs=inputs, outputs=outputs, name="Generator")
        
        return model
    
    def _build_discriminator(self, output_shape):
        """
        Build the discriminator network (VGG-style).
        
        Returns:
            Discriminator model
        """
        
        inputs = Input(shape=output_shape, name="hr_input")
        
        # Initial convolution
        x = SpectralNormalization(Conv2D(64, 3, padding="same", name="disc_conv1"))(inputs)
        x = LeakyReLU(alpha=0.2, name="disc_leaky1")(x)
        
        # Convolutional blocks
        filters = [64, 128, 128, 256, 256, 512, 512]
        strides = [2, 1, 2, 1, 2, 1, 2]
        
        for i, (f, s) in enumerate(zip(filters, strides)):
            x = SpectralNormalization(Conv2D(f, 3, strides=s, padding="same", name=f"disc_conv{i+2}"))(x)
            x = BatchNormalization(name=f"disc_bn{i+2}")(x)
            x = LeakyReLU(alpha=0.2, name=f"disc_leaky{i+2}")(x)
        
        # Global average pooling and dense layers
        x = GlobalAveragePooling2D(name="disc_gap")(x)
        x = SpectralNormalization(Dense(1024, name="disc_dense1"))(x)
        x = LeakyReLU(alpha=0.2, name="disc_leaky_dense1")(x)
        outputs = SpectralNormalization(Dense(1, activation="sigmoid", name="disc_output"))(x)
        
        model = Model(inputs=inputs, outputs=outputs, name="Discriminator")
        
        return model
    
    def _build_vgg_model(self, output_shape):
        """
        Build VGG model for perceptual loss.
        
        Returns:
            VGG model for feature extraction
        """
        
        vgg = VGG19(
            include_top=False, 
            weights="imagenet", 
            input_shape=output_shape
        )
        
        # Extract features from conv5_4 layer
        vgg.trainable = False
        outputs = vgg.get_layer("block5_conv4").output
        
        model = Model(inputs=vgg.input, outputs=outputs, name="VGG_Feature_Extractor")
        
        return model
    
    def _preprocess_vgg_input(self, x):
        """Preprocess input for VGG model."""
        
        # Convert from [-1, 1] to [0, 255]
        x = (x + 1) * 127.5
        
        # Apply VGG preprocessing
        return preprocess_input(x)
    
    def _perceptual_loss(self, hr_real, hr_fake):
        """
        Calculate perceptual loss using VGG features.
        
        Args:
            hr_real: Real high-resolution images
            hr_fake: Generated high-resolution images
            
        Returns:
            Perceptual loss
        """
        
        # Preprocess inputs for VGG
        hr_real_vgg = self._preprocess_vgg_input(hr_real)
        hr_fake_vgg = self._preprocess_vgg_input(hr_fake)
        
        # Extract features
        real_features = self.vgg_model(hr_real_vgg)
        fake_features = self.vgg_model(hr_fake_vgg)
        
        # Calculate MSE loss
        return mean(square(real_features - fake_features))
    
    def _pixel_loss(self, hr_real, hr_fake):
        """
        Calculate pixel-wise L1 loss.
        
        Args:
            hr_real: Real high-resolution images
            hr_fake: Generated high-resolution images
            
        Returns:
            Pixel loss
        """
        
        return mean(abs(hr_real - hr_fake))
    
    def _adversarial_loss(self, y_true, y_pred):
        """
        Calculate adversarial loss.
        
        Args:
            y_true: True labels
            y_pred: Predicted labels
            
        Returns:
            Adversarial loss
        """
        
        return mean(binary_crossentropy(y_true, y_pred))
    
    def _spectral_loss(self, hr_real, hr_fake):
        """Spectral (Fourier) loss for texture preservation."""
        
        # Compute FFT2 for each image in the batch
        hr_real_fft = tf.signal.fft2d(tf.cast(hr_real, tf.complex64))
        hr_fake_fft = tf.signal.fft2d(tf.cast(hr_fake, tf.complex64))
        
        # Use magnitude (abs) for comparison
        real_mag = tf.abs(hr_real_fft)
        fake_mag = tf.abs(hr_fake_fft)
        
        # L1 loss between magnitude spectra
        return tf.reduce_mean(tf.abs(real_mag - fake_mag))
    
    def _train_step(self, lr_images, hr_images):
        """
        Perform one training step.
        
        Args:
            lr_images: Low-resolution images
            hr_images: High-resolution images
            
        Returns:
            Dictionary containing losses
        """
        
        # Train discriminator
        with tf.GradientTape() as d_tape:
            # Generate fake images
            hr_fake = self.generator(lr_images, training=True)
            
            # Discriminator predictions
            d_real = self.discriminator(hr_images, training=True)
            d_fake = self.discriminator(hr_fake, training=True)
            
            # Discriminator losses
            d_loss_real = self._adversarial_loss(tf.ones_like(d_real), d_real)
            d_loss_fake = self._adversarial_loss(tf.zeros_like(d_fake), d_fake)
            d_loss = d_loss_real + d_loss_fake
        
        # Update discriminator
        d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables)
        self.d_optimizer.apply_gradients(zip(d_grads, self.discriminator.trainable_variables))
        
        # Train generator
        with tf.GradientTape() as g_tape:
            # Generate fake images
            hr_fake = self.generator(lr_images, training=True)
            
            # Discriminator prediction for fake images
            d_fake = self.discriminator(hr_fake, training=True)
            
            # Generator losses
            g_adversarial_loss = self._adversarial_loss(tf.ones_like(d_fake), d_fake)
            g_perceptual_loss = self._perceptual_loss(hr_images, hr_fake)
            g_pixel_loss = self._pixel_loss(hr_images, hr_fake)
            g_spectral_loss = self._spectral_loss(hr_images, hr_fake)
            
            # Combined generator loss
            g_loss = (
                g_adversarial_loss 
                + 1.0 * g_perceptual_loss 
                + 100.0 * g_pixel_loss 
                + 1.0 * g_spectral_loss)
        
        # Update generator
        g_grads = g_tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(g_grads, self.generator.trainable_variables))
        
        return {
            "g_loss": g_loss, 
            "g_adversarial_loss": g_adversarial_loss, 
            "g_perceptual_loss": g_perceptual_loss, 
            "g_pixel_loss": g_pixel_loss, 
            "g_spectral_loss": g_spectral_loss, 
            "d_loss": d_loss, 
            "d_loss_real": d_loss_real, 
            "d_loss_fake": d_loss_fake
        }
    
    def fit(
        self,
        X_train=None,
        Y_train=None,
        train_dataset=None,
        X_val=None,
        Y_val=None,
        val_dataset=None,
        epochs=100,
        batch_size=16,
        steps_per_epoch=None,
        val_steps=None,
        use_augmentation=True,
        use_mix=True,
        augment_validation=False,
        normalize=True
    ):
        """
        Train the ESRGAN model con opción de data augmentation avanzada.

        Opciones de entrada:
        - Proporcionar (X_train, Y_train) y opcionalmente (X_val, Y_val)
        - O proporcionar directamente un train_dataset (tf.data.Dataset) ya preparado

        Parámetros:
        X_train, Y_train: ndarrays en rango [0,1]
        train_dataset: tf.data.Dataset que produce (lr, hr) en [0,1] o [-1,1]
        steps_per_epoch: obligatorio si la fuente es infinita (repeat / generador)
        use_augmentation: si True aplica AdvancedAugmentGenerator sobre X_train/Y_train
        use_mix: controla mixup/cutmix dentro del generador avanzado
        augment_validation: aplica augment también a validación (no recomendado habitual)
        normalize: si True convierte batches de [0,1] a [-1,1]
        """
        # Validaciones básicas
        if train_dataset is None and (X_train is None or Y_train is None):
            raise ValueError("Debe aportar (X_train,Y_train) o un train_dataset")
        if use_augmentation and (X_train is None or Y_train is None):
            raise ValueError("Para use_augmentation=True se requieren X_train e Y_train")

        # Info dispositivo
        devices = tf.config.list_physical_devices('GPU')
        if devices:
            print("Training on GPU:", [d.name for d in devices])
        else:
            print("Training on CPU")

        # Construcción del dataset de entrenamiento
        if use_augmentation:
            aug_seq = AdvancedAugmentGenerator(
                X_train, Y_train, batch_size=batch_size,
                shuffle=True, use_mix=use_mix
            )
            
            output_signature = (
                tf.TensorSpec(shape=(None,)+X_train.shape[1:], dtype=tf.float32),
                tf.TensorSpec(shape=(None,)+Y_train.shape[1:], dtype=tf.float32)
            )
            
            def gen_epoch():
                for i in range(len(aug_seq)):
                    yield aug_seq[i]

            train_dataset = tf.data.Dataset.from_generator(
                gen_epoch,
                output_signature=output_signature
            ).repeat()
            
            if steps_per_epoch is None:
                steps_per_epoch = len(aug_seq)
        elif train_dataset is None:
            # Dataset simple desde arrays
            train_dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_train)).shuffle(len(X_train)).batch(batch_size).repeat()
            if steps_per_epoch is None:
                steps_per_epoch = int(np.ceil(len(X_train)/batch_size))
        else:
            # Se proporcionó un dataset externo. Aseguramos batching y repetición si no las trae.
            # (Heurística simple: si no tiene _variant_tensor_attr asumimos que necesita repeat)
            train_dataset = train_dataset.repeat()
            if steps_per_epoch is None:
                raise ValueError("Debe indicar steps_per_epoch cuando aporta un dataset externo")

        # Normalización a [-1,1] si procede
        if normalize:
            train_dataset = train_dataset.map(lambda x,y: (x*2.0 - 1.0, y*2.0 - 1.0), num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

        # Dataset de validación
        val_data_struct = None
        if val_dataset is not None:
            val_data_struct = val_dataset
        elif X_val is not None and Y_val is not None:
            if augment_validation and use_augmentation:
                val_seq = AdvancedAugmentGenerator(
                    X_val, Y_val, batch_size=batch_size,
                    shuffle=False, use_mix=False
                )
                
                output_signature_val = (
                    tf.TensorSpec(shape=(None,)+X_val.shape[1:], dtype=tf.float32),
                    tf.TensorSpec(shape=(None,)+Y_val.shape[1:], dtype=tf.float32)
                )
                def gen_val_batches():
                    for i in range(len(val_seq)):
                        yield val_seq[i]
                
                val_data_struct = tf.data.Dataset.from_generator(
                    gen_val_batches,
                    output_signature=output_signature_val
                )
                
                if val_steps is None:
                    val_steps = len(val_seq)
            else:
                val_data_struct = tf.data.Dataset.from_tensor_slices((X_val, Y_val)).batch(batch_size)
                if val_steps is None:
                    val_steps = int(np.ceil(len(X_val)/batch_size))
        
        if val_data_struct is not None and normalize:
            val_data_struct = val_data_struct.map(lambda x,y: (x*2.0 - 1.0, y*2.0 - 1.0), num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

        # Bucle de entrenamiento
        for epoch in range(epochs):
            print(f"Epoch {epoch + 1}/{epochs}")
            
            # Métricas acumuladas
            epoch_losses = {
                "g_loss": [],
                "g_adversarial_loss": [],
                "g_perceptual_loss": [],
                "g_pixel_loss": [],
                "g_spectral_loss": [], 
                "g_lr": [], 
                "d_loss": [],
                "d_loss_real": [],
                "d_loss_fake": [], 
                "d_lr": [], 
                "psnr": [], 
                "ssim": []
            }

            # Iteración sobre batches
            for step, (lr_batch, hr_batch) in enumerate(train_dataset.take(steps_per_epoch)):
                losses = self._train_step(lr_batch, hr_batch)
                for key, value in losses.items():
                    epoch_losses[key].append(float(value.numpy()))

                # Métricas perceptuales en [0,1]
                hr_fake = self.generator(lr_batch, training=False)
                hr_real_eval = (hr_batch + 1.0) / 2.0
                hr_gen_eval = (hr_fake + 1.0) / 2.0
                psnr_score = tf.reduce_mean(tf.image.psnr(hr_real_eval, hr_gen_eval, max_val=1.0))
                ssim_score = tf.reduce_mean(tf.image.ssim(hr_real_eval, hr_gen_eval, max_val=1.0))
                epoch_losses["psnr"].append(float(psnr_score.numpy()))
                epoch_losses["ssim"].append(float(ssim_score.numpy()))
                epoch_losses["g_lr"].append(float(self.g_optimizer._decayed_lr(tf.float32).numpy()))
                epoch_losses["d_lr"].append(float(self.d_optimizer._decayed_lr(tf.float32).numpy()))

                if (step+1) % 10 == 0 or (step+1) == steps_per_epoch:
                    print(
                        f"  Step {step+1}/{steps_per_epoch} G_loss={epoch_losses['g_loss'][-1]:.4f} "
                        f"D_loss={epoch_losses['d_loss'][-1]:.4f} PSNR={epoch_losses['psnr'][-1]:.2f} "
                        f"SSIM={epoch_losses['ssim'][-1]:.4f}")

            # Resumen epoch
            avg_losses = {k: np.mean(v) for k,v in epoch_losses.items()}
            print(
                f"- Epoch Summary - G_loss: {avg_losses['g_loss']:.4f}, D_loss: {avg_losses['d_loss']:.4f}, "
                f"PSNR: {avg_losses['psnr']:.2f}, SSIM: {avg_losses['ssim']:.4f}")

            # Validación si existe
            if val_data_struct is not None:
                val_psnr, val_ssim = [], []
                for i, (lr_v, hr_v) in enumerate(val_data_struct.take(val_steps)):
                    hr_fake_v = self.generator(lr_v, training=False)
                    hr_real_eval = (hr_v + 1.0) / 2.0
                    hr_gen_eval  = (hr_fake_v + 1.0) / 2.0
                    val_psnr.append(float(tf.reduce_mean(tf.image.psnr(hr_real_eval, hr_gen_eval, 1.0)).numpy()))
                    val_ssim.append(float(tf.reduce_mean(tf.image.ssim(hr_real_eval, hr_gen_eval, 1.0)).numpy()))
                print(f"  Validation -> PSNR: {np.mean(val_psnr):.2f}, SSIM: {np.mean(val_ssim):.4f}")

            self.trained = True
    
    def evaluate(self, test_dataset):
        """
        Evaluate the trained model with a test dataset.
        
        Args:
            test_dataset: Test dataset (tf.data.Dataset)
            
        Returns:
            Dictionary containing evaluation metrics
        """
        
        if not self.trained:
            raise RuntimeError("Model has not been trained.")
        
        print("Evaluating model on test dataset...")
        
        # Initialize metrics
        total_psnr = 0.0
        total_ssim = 0.0
        total_pixel_loss = 0.0
        total_perceptual_loss = 0.0
        num_batches = 0
        
        for lr_batch, hr_batch in test_dataset:
            # Generate high-resolution images
            hr_generated = self.generator(lr_batch, training=False)
            
            # Pixel loss
            pixel_loss = self._pixel_loss(hr_batch, hr_generated)
            total_pixel_loss += eval(pixel_loss)
            
            # Perceptual loss
            perceptual_loss = self._perceptual_loss(hr_batch, hr_generated)
            total_perceptual_loss += eval(perceptual_loss)
            
            # Convert to [0, 1] range for PSNR and SSIM
            hr_real_eval = (hr_batch + 1.0) / 2.0
            hr_gen_eval = (hr_generated + 1.0) / 2.0
            
            # Calculate PSNR and SSIM
            psnr_score = tf.image.psnr(hr_real_eval, hr_gen_eval, max_val=1.0)
            ssim_score = tf.image.ssim(hr_real_eval, hr_gen_eval, max_val=1.0)
            
            total_psnr += eval(mean(psnr_score))
            total_ssim += eval(mean(ssim_score))
            
            num_batches += 1
        
        # Calculate averages
        avg_psnr = total_psnr / num_batches
        avg_ssim = total_ssim / num_batches
        avg_pixel_loss = total_pixel_loss / num_batches
        avg_perceptual_loss = total_perceptual_loss / num_batches
        
        metrics = {
            "avg_psnr": avg_psnr,
            "avg_ssim": avg_ssim,
            "avg_pixel_loss": avg_pixel_loss,
            "avg_perceptual_loss": avg_perceptual_loss
        }
        
        print(f"Evaluation Results:")
        print(f"  Average PSNR: {avg_psnr:.4f}")
        print(f"  Average SSIM: {avg_ssim:.4f}")
        print(f"  Average Pixel Loss: {avg_pixel_loss:.4f}")
        print(f"  Average Perceptual Loss: {avg_perceptual_loss:.4f}")
        
        return metrics
    
    def save(self, directory="models/ESRGAN", scale_factor=2):
        """Save the trained model with a timestamp in the specified directory."""
        
        if not self.trained:
            raise RuntimeError("Cannot save an untrained model.")
        
        os.makedirs(directory, exist_ok=True)
        
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        generator_path = os.path.join(directory, f"ESRGAN_generator_x{scale_factor}_{timestamp}.h5")
        discriminator_path = os.path.join(directory, f"ESRGAN_discriminator_x{scale_factor}_{timestamp}.h5")
        
        self.generator.save(generator_path)
        self.discriminator.save(discriminator_path)
        
        print(f"Generator model saved to {generator_path}")
        print(f"Discriminator model saved to {discriminator_path}")

In [4]:
HR_ROOT = os.path.abspath(os.path.join(os.getcwd(), "../../data/images/HR"))
LR_ROOT = os.path.abspath(os.path.join(os.getcwd(), "../../data/images/LR"))

In [None]:
X, Y = load_dataset_as_patches(HR_ROOT, LR_ROOT, mode="scale", patch_size=ESRGAN_PATCH_SIZE, stride=ESRGAN_STRIDE, scale_factor=ESRGAN_SCALE_FACTOR)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, shuffle=True, random_state=RANDOM_SEED)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.1, shuffle=True, random_state=RANDOM_SEED)

print(f"X_train shape: {X_train.shape}, Y_train shape: {Y_train.shape}")
print(f"X_val shape: {X_val.shape}, Y_val shape: {Y_val.shape}")
print(f"X_test shape: {X_test.shape}, Y_test shape: {Y_test.shape}")

X_train shape: (7873, 24, 24, 3), Y_train shape: (7873, 48, 48, 3)
X_val shape: (875, 24, 24, 3), Y_val shape: (875, 48, 48, 3)
X_test shape: (972, 24, 24, 3), Y_test shape: (972, 48, 48, 3)


In [6]:
BATCH_SIZE = 16
EPOCHS = 1          # Ajusta según necesidad

# Se define también el dataset de test (solo para evaluación posterior)
# Se normaliza a [-1,1] para que coincida con lo usado en entrenamiento
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, Y_test)).batch(BATCH_SIZE)
test_dataset = test_dataset.map(lambda x,y: (x*2.0 - 1.0, y*2.0 - 1.0), num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

In [7]:
model = ESRGAN()

model.setup_model(
    scale_factor=2, 
    growth_channels=32, 
    num_rrdb_blocks=23, 
    input_shape=X_train.shape[1:],
    output_shape=Y_train.shape[1:],
    from_trained=False
)

GENERATOR SUMMARY
Model: "Generator"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 lr_input (InputLayer)          [(None, 24, 24, 3)]  0           []                               
                                                                                                  
 initial_conv (Conv2D)          (None, 24, 24, 64)   1792        ['lr_input[0][0]']               
                                                                                                  
 rrdb_0_dense1_conv1 (Conv2D)   (None, 24, 24, 32)   18464       ['initial_conv[0][0]']           
                                                                                                  
 rrdb_0_dense1_concat1 (Concate  (None, 24, 24, 96)  0           ['initial_conv[0][0]',           
 nate)                                                            'rrdb_

In [8]:
model.fit(
    X_train=X_train, 
    Y_train=Y_train, 
    X_val=X_val, 
    Y_val=Y_val,
    epochs=EPOCHS, 
    batch_size=BATCH_SIZE,
    use_augmentation=True, 
    use_mix=False,
    augment_validation=False
)

Training on GPU: ['/physical_device:GPU:0']
Epoch 1/1
  Step 10/493 G_loss=179.6460 D_loss=1.3576 PSNR=7.93 SSIM=0.0205
  Step 10/493 G_loss=179.6460 D_loss=1.3576 PSNR=7.93 SSIM=0.0205
  Step 20/493 G_loss=142.9199 D_loss=1.3567 PSNR=9.44 SSIM=0.3396
  Step 20/493 G_loss=142.9199 D_loss=1.3567 PSNR=9.44 SSIM=0.3396


KeyboardInterrupt: 

In [None]:
pretrained_model = ESRGAN()

pretrained_model.setup_model(
    from_trained=True, 
    generator_pretrained_path="models/ESRGAN_generator_x2_20250627_164319.h5", 
    discriminator_pretrained_path="models/ESRGAN_discriminator_x2_20250627_164319.h5"
)

- Generator loaded from: models/ESRGAN_generator_x2_20250627_164319.h5
- Discriminator loaded from: models/ESRGAN_discriminator_x2_20250627_164319.h5
- VGG model built for perceptual loss


In [None]:
pretrained_model.evaluate(test_dataset)

Evaluating model on test dataset...
Evaluation Results:
  Average PSNR: 3.8115
  Average SSIM: 0.0027
  Average Pixel Loss: 1.0958
  Average Perceptual Loss: 118.6404


{'avg_psnr': 3.8114942800278633,
 'avg_ssim': 0.002731919829302386,
 'avg_pixel_loss': 1.0958046497114553,
 'avg_perceptual_loss': 118.64038116659894}