In [2]:
import os
import cv2
import datetime
import numpy as np
import tensorflow as tf
from keras import layers, Model
from keras.optimizers import Adam
from keras.models import load_model
from keras.applications import VGG19, vgg19
from sklearn.model_selection import train_test_split
from keras.losses import BinaryCrossentropy, MeanAbsoluteError, MeanSquaredError
from keras.layers import Input, Conv2D, LeakyReLU, Add, Concatenate, GlobalAveragePooling2D, Dense, BatchNormalization, Lambda, Layer

In [3]:
def load_dataset_as_patches(hr_root, lr_root, patch_size_lr=48, stride=24, scale_factor=2, max_patches_per_image=None):
    """
    Loads HR and LR image patches from separate folders for EDSR training.
    Extracts patches from both LR and HR images maintaining the scale factor relationship.

    Parameters:
        hr_root (str): Root path to HR images.
        lr_root (str): Root path to LR images.
        patch_size_lr (int): Size of LR patches (HR patches will be patch_size_lr * scale_factor).
        scale_factor (int): The scale factor (2, 3, or 4).
        stride (int): Stride for patch extraction.
        max_patches_per_image (int): Maximum patches to extract per image (None for all).

    Returns:
        X (np.ndarray): Low-resolution patches (model input).
        Y (np.ndarray): High-resolution patches (target).
    """
    
    def add_padding(image, patch_size, stride):
        """Add padding to ensure full coverage."""
        
        h, w, c = image.shape
        
        # Calcular cuánto padding se necesita
        pad_h = (patch_size - (h % stride)) % stride if h % stride != 0 else 0
        pad_w = (patch_size - (w % stride)) % stride if w % stride != 0 else 0
        
        # Agregar padding extra para asegurar cobertura completa
        pad_h = max(pad_h, patch_size - stride)
        pad_w = max(pad_w, patch_size - stride)
        
        # Padding reflejado (mirror) para mantener continuidad
        padded_img = np.pad(
            image, 
            ((0, pad_h), (0, pad_w), (0, 0)), 
            mode='reflect'
        )
        
        return padded_img
    
    if not os.path.exists(hr_root) or not os.path.exists(lr_root):
        raise ValueError("Both HR and LR root directories must exist.")
    
    patch_size_hr = patch_size_lr * scale_factor
    X, Y = [], []

    def get_all_image_paths(root):
        image_paths = []
        for dirpath, _, filenames in os.walk(root):
            for filename in filenames:
                if filename.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".tiff")):
                    image_paths.append(os.path.join(dirpath, filename))
        return sorted(image_paths)

    hr_paths = get_all_image_paths(hr_root)
    lr_paths = get_all_image_paths(lr_root)

    # Match HR and LR images by filename
    hr_dict = {os.path.basename(p): p for p in hr_paths}
    lr_dict = {os.path.basename(p): p for p in lr_paths}
    common_filenames = sorted(set(hr_dict.keys()) & set(lr_dict.keys()))

    if not common_filenames:
        raise ValueError("No matching filenames found between HR and LR directories.")

    total_patches = 0
    
    for fname in common_filenames:
        hr_img = cv2.imread(hr_dict[fname], cv2.IMREAD_COLOR)
        lr_img = cv2.imread(lr_dict[fname], cv2.IMREAD_COLOR)

        if hr_img is None or lr_img is None:
            continue

        # Convert to RGB and normalize
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        lr_img = cv2.cvtColor(lr_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

        lr_h, lr_w, _ = lr_img.shape
        hr_h, hr_w, _ = hr_img.shape
        
        # Add padding to ensure full coverage
        hr_img = add_padding(hr_img, patch_size_hr, stride)
        lr_img = add_padding(lr_img, patch_size_lr, stride)

        # Generate patches
        patches_this_image = 0
        
        for i in range(0, lr_h - patch_size_lr + 1, stride):
            for j in range(0, lr_w - patch_size_lr + 1, stride):
                # Extract LR patch
                lr_patch = lr_img[i:i+patch_size_lr, j:j+patch_size_lr]
                
                # Extract corresponding HR patch
                hr_i = i * scale_factor
                hr_j = j * scale_factor
                
                if hr_i + patch_size_hr <= hr_h and hr_j + patch_size_hr <= hr_w:
                    hr_patch = hr_img[hr_i:hr_i+patch_size_hr, hr_j:hr_j+patch_size_hr]
                    
                    X.append(lr_patch)
                    Y.append(hr_patch)
                    
                    patches_this_image += 1
                    total_patches += 1
                    
                    if max_patches_per_image and patches_this_image >= max_patches_per_image:
                        break
            
            if max_patches_per_image and patches_this_image >= max_patches_per_image:
                break

    if not X:
        raise ValueError("No patches could be extracted. Check your patch size and image dimensions.")

    X_array = np.array(X)
    Y_array = np.array(Y)
    
    print(f"Extracted {total_patches} patch pairs from {len(common_filenames)} images")
    print(f"LR patches shape: {X_array.shape}")
    print(f"HR patches shape: {Y_array.shape}")

    return X_array, Y_array

In [4]:
class SpectralNormalization(Layer):
    """
    Spectral Normalization layer for stabilizing GAN training by constraining 
    the Lipschitz constant of the discriminator.
    
    Parameters:
    - layer: The layer to apply spectral normalization to
    - power_iterations: Number of power iterations for spectral norm computation
    
    Returns:
    - Normalized layer output
    """
    def __init__(self, layer, power_iterations=1, **kwargs):
        super().__init__(**kwargs)
        self.layer = layer
        self.power_iterations = power_iterations
        
    def build(self, input_shape):
        self.layer.build(input_shape)
        
        # Get the weight matrix
        if hasattr(self.layer, 'kernel'):
            self.kernel = self.layer.kernel
        else:
            raise ValueError("Layer must have a kernel attribute")
            
        # Initialize u and v vectors for power iteration
        kernel_shape = self.kernel.shape
        self.u = self.add_weight(
            shape=(1, kernel_shape[-1]),
            initializer='random_normal',
            trainable=False,
            name='u'
        )
        
        super().build(input_shape)
    
    def call(self, inputs, training=None):
        # Power iteration method
        u = self.u
        
        for _ in range(self.power_iterations):
            # Reshape kernel for matrix multiplication
            kernel_reshaped = tf.reshape(self.kernel, [-1, self.kernel.shape[-1]])
            
            # v = u @ W^T / ||u @ W^T||
            v = tf.nn.l2_normalize(tf.matmul(u, kernel_reshaped, transpose_b=True))
            
            # u = v @ W / ||v @ W||
            u = tf.nn.l2_normalize(tf.matmul(v, kernel_reshaped))
        
        if training:
            self.u.assign(u)
        
        # Compute spectral norm
        sigma = tf.matmul(tf.matmul(u, kernel_reshaped, transpose_b=True), 
                         tf.transpose(v))
        
        # Normalize the kernel
        self.layer.kernel.assign(self.kernel / sigma)
        
        return self.layer(inputs)

In [None]:
class ESRGAN:
    """
    Enhanced Super-Resolution Generative Adversarial Network (ESRGAN) implementation.
    
    This implementation includes all advanced features: spectral normalization,
    learning rate scheduling, network interpolation, and data augmentation.
    """
    
    def __init__(self, lr_size=(64, 64, 3), scale=2, vgg_layer='block5_conv4'):
        """
        Initialize ESRGAN model.
        
        Parameters:
        - lr_size: Tuple, shape of low-resolution input images (height, width, channels)
        - scale: Int, upscaling factor (2, 4, or 8)
        - vgg_layer: String, VGG layer name for perceptual loss computation
        
        Returns:
        - None (initializes class attributes)
        """
        self.lr_size = lr_size
        self.hr_size = (lr_size[0]*scale, lr_size[1]*scale, 3)
        self.scale = scale
        self.vgg_layer = vgg_layer
        
        # Calculate number of upsampling blocks based on scale
        self.num_upsample_blocks = int(np.log2(scale))
        if 2**self.num_upsample_blocks != scale:
            raise ValueError(f"Scale must be a power of 2 (2, 4, 8). Got {scale}")
        
        # Models will be initialized in setup_model
        self.generator = None
        self.discriminator = None
        self.vgg = None
        self.psnr_generator = None
        
        # Loss functions
        self.bce = BinaryCrossentropy(from_logits=False)
        self.l1 = MeanAbsoluteError()
        self.mse = MeanSquaredError()
        
        # Learning rate configuration
        self.initial_lr = 1e-4
        self.g_optimizer = None
        self.d_optimizer = None
        self.g_lr_schedule = None
        self.d_lr_schedule = None
        self._trained = False

    def setup_model(
            self, 
            num_blocks=23, 
            filters=64, 
            growth_channels=32, 
            beta=0.2, 
            use_spectral_norm=True, 
            from_pretrained=False, 
            pretrained_path=None):
        """
        Setup the ESRGAN model either from scratch or from pretrained weights.
        
        Parameters:
        - from_pretrained: Bool, whether to load from pretrained weights
        - pretrained_path: String, path to pretrained model files (without extension)
        
        Returns:
        - None (initializes model components)
        """
        generator_file = f"{pretrained_path}_generator.h5" if pretrained_path else None
        discriminator_file = f"{pretrained_path}_discriminator.h5" if pretrained_path else None
        psnr_generator_file = f"{pretrained_path}_psnr_generator.h5" if pretrained_path else None

        if from_pretrained and pretrained_path and os.path.isfile(generator_file) and os.path.isfile(discriminator_file):
            print("Loading pretrained models...")
            self.generator = load_model(generator_file, compile=False)
            self.discriminator = load_model(discriminator_file, compile=False)
            if os.path.isfile(psnr_generator_file):
                self.psnr_generator = load_model(psnr_generator_file, compile=False)
            else:
                print("PSNR generator not found, will be created if needed...")
            print("Pretrained models loaded successfully!")
            self._trained = True
        else:
            print("Building new models...")
            self._build_new_models(num_blocks, filters, growth_channels, beta, use_spectral_norm)
            self._trained = True

        # Always build VGG for perceptual loss
        self.vgg = self._build_vgg()
        # Setup optimizers and learning rate schedules
        self._setup_optimizers()

    def _build_new_models(self, num_blocks, filters, growth_channels, beta, use_spectral_norm):
        """
        Build new generator and discriminator models from scratch.
        
        Parameters:
        - None
        
        Returns:
        - None (sets self.generator and self.discriminator)
        """
        self.generator = self._build_generator(num_blocks, filters, growth_channels, beta)
        self.discriminator = self._build_discriminator(use_spectral_norm)

    def _setup_optimizers(self):
        """
        Setup optimizers and learning rate schedules.
        
        Parameters:
        - None
        
        Returns:
        - None (initializes optimizers and schedules)
        """
        # Learning rate schedules
        self.g_lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=self.initial_lr,
            decay_steps=50000,
            decay_rate=0.5,
            staircase=True
        )
        self.d_lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=self.initial_lr,
            decay_steps=50000,
            decay_rate=0.5,
            staircase=True
        )
        
        # Optimizers
        self.g_optimizer = Adam(self.initial_lr, beta_1=0.9)
        self.d_optimizer = Adam(self.initial_lr, beta_1=0.9)

    def _dense_block(self, x, filters, growth_channels, beta):
        """
        Dense block with growth connections for feature reuse.
        
        Parameters:
        - x: Tensor, input feature maps
        - filters: Int, number of output filters
        - growth_channels: Int, number of channels added by each layer
        - beta: Float, residual scaling factor
        
        Returns:
        - Tensor, output feature maps with residual connection
        """
        inputs = x
        concat_layers = [x]
        
        for i in range(5):
            if i > 0:
                x = Concatenate()(concat_layers)
            
            out = Conv2D(growth_channels, 3, padding='same')(x)
            out = LeakyReLU(0.2)(out)
            concat_layers.append(out)
        
        x = Concatenate()(concat_layers)
        out = Conv2D(filters, 1, padding='same')(x)
        out = Lambda(lambda x: x * beta)(out)
        return Add()([out, inputs])

    def _rrdb_block(self, x, filters, growth_channels, beta):
        """
        Residual-in-Residual Dense Block - core building block of ESRGAN generator.
        
        Parameters:
        - x: Tensor, input feature maps
        - filters: Int, number of filters
        - beta: Float, residual scaling factor for stability
        
        Returns:
        - Tensor, output feature maps with nested residual connections
        """
        inputs = x
        
        for _ in range(3):
            x = self._dense_block(x, filters, growth_channels, beta)
        
        x = Lambda(lambda x: x * beta)(x)
        return Add()([inputs, x])

    def _build_generator(self, num_blocks, filters, growth_channels, beta):
        """
        Build the ESRGAN generator network with RRDB blocks.
        
        Parameters:
        - num_blocks: Int, number of RRDB blocks
        - filters: Int, number of base filters
        
        Returns:
        - Model, Keras model for the generator
        """
        inputs = Input(shape=self.lr_size)
        
        x = Conv2D(filters, 3, padding='same')(inputs)
        conv1 = x

        for _ in range(num_blocks):
            x = self._rrdb_block(x, filters, growth_channels, beta)

        x = Conv2D(filters, 3, padding='same')(x)
        x = Add()([x, conv1])

        # Upsampling blocks (adaptive based on scale)
        for _ in range(self.num_upsample_blocks):
            x = Conv2D(filters * 4, 3, padding='same')(x)
            x = Lambda(lambda x: tf.nn.depth_to_space(x, 2))(x)
            x = LeakyReLU(0.2)(x)

        out = Conv2D(3, 3, padding='same', activation='tanh')(x)
        return Model(inputs, out, name="Generator")

    def _build_discriminator(self, use_spectral_norm):
        """
        Build the discriminator network with optional spectral normalization.
        
        Parameters:
        - use_spectral_norm: Bool, whether to apply spectral normalization
        
        Returns:
        - Model, Keras model for the discriminator
        """
        def d_block(x, filters, strides=1, bn=True, sn=True):
            """
            Discriminator convolutional block.
            
            Parameters:
            - x: Tensor, input
            - filters: Int, number of filters
            - strides: Int, convolution stride
            - bn: Bool, whether to use batch normalization
            - sn: Bool, whether to use spectral normalization
            
            Returns:
            - Tensor, processed feature maps
            """
            if sn and use_spectral_norm:
                conv_layer = Conv2D(filters, 3, strides=strides, padding='same')
                x = SpectralNormalization(conv_layer)(x)
            else:
                x = Conv2D(filters, 3, strides=strides, padding='same')(x)
            
            if bn:
                x = BatchNormalization()(x)
            return LeakyReLU(0.2)(x)

        inputs = layers.Input(shape=self.hr_size)
        x = d_block(inputs, 64, bn=False)
        x = d_block(x, 64, strides=2)
        x = d_block(x, 128)
        x = d_block(x, 128, strides=2)
        x = d_block(x, 256)
        x = d_block(x, 256, strides=2)
        x = d_block(x, 512)
        x = d_block(x, 512, strides=2)
        
        x = GlobalAveragePooling2D()(x)
        x = Dense(1024)(x)
        x = LeakyReLU(0.2)(x)
        out = Dense(1)(x)
        
        return Model(inputs, out, name="Discriminator")

    def _build_vgg(self):
        """
        Build VGG19 network for perceptual loss computation.
        
        Parameters:
        - None
        
        Returns:
        - Model, VGG19 model for feature extraction
        """
        vgg = VGG19(weights='imagenet', include_top=False, input_shape=self.hr_size)
        vgg.trainable = False
        
        def preprocess_vgg(x):
            """
            Preprocess images for VGG19 input.
            
            Parameters:
            - x: Tensor, input images in [-1, 1] range
            
            Returns:
            - Tensor, preprocessed images for VGG19
            """
            x = (x + 1.0) * 127.5
            return vgg19.preprocess_input(x)
        
        inputs = Input(shape=self.hr_size)
        x = Lambda(preprocess_vgg)(inputs)
        x = vgg(x)
        features = vgg.get_layer(self.vgg_layer)(x)
        
        return Model(inputs, features)

    def _data_augmentation(self, lr_batch, hr_batch):
        """
        Apply data augmentation to training batches.
        
        Parameters:
        - lr_batch: Tensor, batch of low-resolution images
        - hr_batch: Tensor, batch of high-resolution images
        
        Returns:
        - Tuple of Tensors, augmented (lr_batch, hr_batch)
        """
        batch_size = tf.shape(lr_batch)[0]
        
        # Random horizontal flip
        flip_prob = tf.random.uniform([batch_size, 1, 1, 1])
        lr_batch = tf.where(flip_prob < 0.5, 
                           tf.image.flip_left_right(lr_batch), lr_batch)
        hr_batch = tf.where(flip_prob < 0.5, 
                           tf.image.flip_left_right(hr_batch), hr_batch)
        
        # Random vertical flip
        flip_prob = tf.random.uniform([batch_size, 1, 1, 1])
        lr_batch = tf.where(flip_prob < 0.5, 
                           tf.image.flip_up_down(lr_batch), lr_batch)
        hr_batch = tf.where(flip_prob < 0.5, 
                           tf.image.flip_up_down(hr_batch), hr_batch)
        
        # Random 90-degree rotations
        k = tf.random.uniform([batch_size], maxval=4, dtype=tf.int32)
        lr_batch = tf.image.rot90(lr_batch, k=k[0])  # Simplified for batch
        hr_batch = tf.image.rot90(hr_batch, k=k[0])
        
        return lr_batch, hr_batch

    def _build_psnr_generator(self, num_blocks, filters, growth_channels, beta):
        """
        Build and setup PSNR-oriented generator for network interpolation.
        
        Parameters:
        - num_blocks: Int, number of RRDB blocks in PSNR generator
        - filters: Int, number of filters in PSNR generator
        - growth_channels: Int, number of growth channels in PSNR generator
        - beta: Float, residual scaling factor for PSNR generator
        
        Returns:
        - Function, PSNR training step function
        """
        self.psnr_generator = self._build_generator(num_blocks, filters, growth_channels, beta)
        
        # PSNR training uses only MSE loss
        psnr_optimizer = Adam(1e-4)
        
        @tf.function
        def psnr_train_step(lr, hr):
            """
            Single PSNR training step.
            
            Parameters:
            - lr: Tensor, low-resolution images
            - hr: Tensor, high-resolution images
            
            Returns:
            - Tensor, MSE loss value
            """
            with tf.GradientTape() as tape:
                sr = self.psnr_generator(lr, training=True)
                loss = self.mse(hr, sr)
            
            grads = tape.gradient(loss, self.psnr_generator.trainable_variables)
            psnr_optimizer.apply_gradients(zip(grads, self.psnr_generator.trainable_variables))
            return loss
        
        return psnr_train_step

    def _interpolate_networks(self, alpha=0.2):
        """
        Perform network interpolation between PSNR and GAN models.
        
        Parameters:
        - alpha: Float, interpolation factor (0=full GAN, 1=full PSNR)
        
        Returns:
        - None (modifies generator weights in-place)
        """
        if self.psnr_generator is None:
            raise ValueError("PSNR generator not built. Call build_psnr_generator first.")
        
        # Interpolate weights
        for psnr_var, gan_var in zip(self.psnr_generator.trainable_variables, 
                                    self.generator.trainable_variables):
            interpolated = alpha * psnr_var + (1 - alpha) * gan_var
            gan_var.assign(interpolated)

    def _perceptual_loss(self, hr, sr):
        """
        Compute perceptual loss using VGG19 features.
        
        Parameters:
        - hr: Tensor, high-resolution ground truth images
        - sr: Tensor, super-resolved images
        
        Returns:
        - Tensor, perceptual loss value
        """
        sr_features = self.vgg(sr)
        hr_features = self.vgg(hr)
        return self.l1(hr_features, sr_features)

    def _relativistic_discriminator_loss(self, real_logits, fake_logits):
        """
        Compute relativistic discriminator loss.
        
        Parameters:
        - real_logits: Tensor, discriminator output for real images
        - fake_logits: Tensor, discriminator output for fake images
        
        Returns:
        - Tensor, relativistic discriminator loss
        """
        real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.ones_like(real_logits),
                logits=real_logits - tf.reduce_mean(fake_logits)
            )
        )
        fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(fake_logits),
                logits=fake_logits - tf.reduce_mean(real_logits)
            )
        )
        return real_loss + fake_loss

    def _relativistic_generator_loss(self, real_logits, fake_logits):
        """
        Compute relativistic generator loss.
        
        Parameters:
        - real_logits: Tensor, discriminator output for real images
        - fake_logits: Tensor, discriminator output for fake images
        
        Returns:
        - Tensor, relativistic generator loss
        """
        fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.ones_like(fake_logits),
                logits=fake_logits - tf.reduce_mean(real_logits)
            )
        )
        real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(real_logits),
                logits=real_logits - tf.reduce_mean(fake_logits)
            )
        )
        return fake_loss + real_loss

    def _update_learning_rates(self, step):
        """
        Update learning rates according to schedules.
        
        Parameters:
        - step: Int, current training step
        
        Returns:
        - None (updates optimizer learning rates)
        """
        new_g_lr = self.g_lr_schedule(step)
        new_d_lr = self.d_lr_schedule(step)
        
        self.g_optimizer.learning_rate.assign(new_g_lr)
        self.d_optimizer.learning_rate.assign(new_d_lr)

    @tf.function(reduce_retracing=True)
    def _train_step(self, lr, hr, step, use_augmentation):
        """
        Single training step for ESRGAN.
        
        Parameters:
        - lr: Tensor, batch of low-resolution images
        - hr: Tensor, batch of high-resolution images
        - step: Tensor, current training step
        - use_augmentation: Bool, whether to apply data augmentation
        
        Returns:
        - Dict, training metrics and losses
        """
        if use_augmentation:
            lr, hr = self._data_augmentation(lr, hr)
        
        with tf.GradientTape(persistent=True) as tape:
            sr = self.generator(lr, training=True)

            real_logits = self.discriminator(hr, training=True)
            fake_logits = self.discriminator(sr, training=True)

            # Discriminator loss
            d_loss = self._relativistic_discriminator_loss(real_logits, fake_logits)

            # Generator losses
            perceptual = self._perceptual_loss(hr, sr)
            adv_loss = self._relativistic_generator_loss(real_logits, fake_logits)
            pixel_loss = self.l1(hr, sr)

            # Combined generator loss
            g_loss = 5e-3 * adv_loss + 1.0 * perceptual + 1e-2 * pixel_loss

        # Apply gradients
        grads_g = tape.gradient(g_loss, self.generator.trainable_variables)
        grads_d = tape.gradient(d_loss, self.discriminator.trainable_variables)

        self.g_optimizer.apply_gradients(zip(grads_g, self.generator.trainable_variables))
        self.d_optimizer.apply_gradients(zip(grads_d, self.discriminator.trainable_variables))

        # Update learning rates
        self._update_learning_rates(step)

        return {
            "psnr": tf.image.psnr(hr, sr, max_val=2.0),  # Assuming tanh output [-1, 1]
            "ssim": tf.image.ssim(hr, sr, max_val=2.0),  # Assuming tanh output [-1, 1]
            'g_loss': g_loss,
            'd_loss': d_loss,
            'perceptual_loss': perceptual,
            'pixel_loss': pixel_loss,
            'adv_loss': adv_loss,
            'g_lr': self.g_optimizer.learning_rate,
            'd_lr': self.d_optimizer.learning_rate
        }

    def _train_psnr_phase(self, dataset, epochs, num_blocks, filters, growth_channels, beta):
        """
        Pre-train generator with PSNR (MSE) loss only.
        
        Parameters:
        - dataset: tf.data.Dataset, training dataset
        - epochs: Int, number of PSNR training epochs
        - num_blocks: Int, number of RRDB blocks in PSNR generator
        - filters: Int, number of filters in PSNR generator
        - growth_channels: Int, number of growth channels in PSNR generator
        - beta: Float, residual scaling factor for PSNR generator
        
        Returns:
        - None (trains self.psnr_generator)
        """
        print("Starting PSNR pre-training...")
        psnr_train_step = self._build_psnr_generator(num_blocks, filters, growth_channels, beta)
        
        for epoch in range(epochs):
            print(f"PSNR Epoch {epoch+1}/{epochs}")
            for step, (lr, hr) in enumerate(dataset):
                loss = psnr_train_step(lr, hr)
                if step % 100 == 0:
                    print(f"Step {step}: PSNR Loss: {float(loss):.4f}")

    def fit(
            self, 
            dataset, 
            epochs=100, 
            psnr_epochs=0, 
            interpolation_alpha=0.2, 
            use_agmentation=True, 
            val_dataset=None, 
            val_metrics=None,
            val_max_batches=None, 
            num_blocks=23, 
            filters=64, 
            growth_channels=32, 
            beta=0.2):
        """
        Full training pipeline with optional PSNR pre-training and network interpolation.
        
        Parameters:
        - dataset: tf.data.Dataset, training dataset with (lr, hr) pairs
        - epochs: Int, number of GAN training epochs
        - psnr_epochs: Int, number of PSNR pre-training epochs (0 to skip)
        - interpolation_alpha: Float, network interpolation factor
        - use_agmentation: Bool, whether to apply data augmentation
        - val_dataset: tf.data.Dataset, validation dataset for evaluation
        - val_metrics: List of metric functions for validation
        - val_max_batches: Int or None, maximum number of validation batches to evaluate
        - num_blocks: Int, number of RRDB blocks in generator
        - filters: Int, number of filters in generator
        - growth_channels: Int, number of growth channels in generator
        - beta: Float, residual scaling factor for generator
        
        Returns:
        - None (trains the model)
        """
        
        if self.generator is None:
            raise ValueError("Model not setup. Call setup_model() first.")
            
        global_step = 0
        
        # PSNR pre-training
        if psnr_epochs > 0:
            self._train_psnr_phase(dataset, psnr_epochs, num_blocks, filters, growth_channels, beta)
            # Copy PSNR weights to main generator
            for psnr_var, gan_var in zip(self.psnr_generator.trainable_variables, self.generator.trainable_variables):
                gan_var.assign(psnr_var)
        
        # GAN training
        print("Starting GAN training...")
        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")
            for step, (lr, hr) in enumerate(dataset):
                logs = self._train_step(lr, hr, global_step, use_agmentation)
                global_step += 1
                
                if step % 10 == 0:
                    print({k: f"{tf.reduce_mean(v).numpy():.4f}" for k, v in logs.items()})
                
                # Network interpolation every 1000 steps
                if global_step % 1000 == 0 and self.psnr_generator is not None:
                    self._interpolate_networks(interpolation_alpha)
                    print(f"Applied network interpolation with alpha={interpolation_alpha}")
            
            # Validation step
            if val_dataset is not None:
                self._evaluate(val_dataset, metrics=val_metrics, max_batches=val_max_batches)
                    
        self._trained = True
        
    def _evaluate(self, dataset, metrics=None, max_batches=None):
        """
        Evaluate the trained model on a dataset.

        Parameters:
        - dataset: tf.data.Dataset, dataset of (lr, hr) pairs
        - metrics: list of metric functions, each should accept (y_true, y_pred)
        - max_batches: int or None, maximum number of batches to evaluate

        Returns:
        - results: dict of metric names and their average values
        """
        if not self._trained:
            raise ValueError("Model is not trained or loaded. Please train or load a pretrained model first.")

        if metrics is None:
            # Default metrics: PSNR and L1 loss
            def psnr(y_true, y_pred):
                return tf.image.psnr(y_true, y_pred, max_val=2.0)  # assuming tanh output [-1,1]
            def ssim(y_true, y_pred):
                return tf.image.ssim(y_true, y_pred, max_val=2.0)  # assuming tanh output [-1,1]
            def l1(y_true, y_pred):
                return tf.reduce_mean(tf.abs(y_true - y_pred))
            metrics = [("PSNR", psnr), ("SSIM", ssim), ("L1", l1)]

        results = {name: [] for name, _ in metrics}
        for i, (lr, hr) in enumerate(dataset):
            if max_batches is not None and i >= max_batches:
                break
            sr = self.generator(lr, training=False)
            for name, func in metrics:
                results[name].append(func(hr, sr).numpy())

        # Average results
        avg_results = {name: float(np.mean(vals)) for name, vals in results.items()}
        print("Evaluation results:", avg_results)
        return avg_results

    def super_resolve(self, lr_image):
        """
        Super-resolve a single image or batch of images.
        
        Parameters:
        - lr_image: Tensor, low-resolution image(s) to super-resolve
        
        Returns:
        - Tensor, super-resolved image(s)
        """
        if not self._trained:
            raise ValueError("Model is not trained or loaded. Please train or load a pretrained model first.")
            
        if len(lr_image.shape) == 3:
            lr_image = tf.expand_dims(lr_image, axis=0)
        sr = self.generator(lr_image, training=False)
        if lr_image.shape[0] == 1:
            sr = tf.squeeze(sr, axis=0)
        return sr

    def save(self, directory="models/ESRGAN"):
        """
        Save all model components to disk.
        
        Parameters:
        - path: String, base path for saving models (without extension)
        
        Returns:
        - None (saves models to disk)
        """
        if not self._trained:
            raise ValueError("Model is not trained or loaded. Please train or load a pretrained model first.")
        
        os.makedirs(directory, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        generator_path = os.path.join(directory, f"ESRGAN_x{self.scale}_{timestamp}_generator.h5")
        discriminator_path = os.path.join(directory, f"ESRGAN_x{self.scale}_{timestamp}_discriminator.h5")
        psnr_generator_path = os.path.join(directory, f"ESRGAN_x{self.scale}_{timestamp}_psnr_generator.h5")
        self.generator.save(generator_path)
        self.discriminator.save(discriminator_path)
        if self.psnr_generator is not None:
            self.psnr_generator.save(psnr_generator_path)
        print(f"Generator model saved to {generator_path}")
        print(f"Discriminator model saved to {discriminator_path}")
        print(f"PSNR generator model saved to {psnr_generator_path}")

In [6]:
PATCH_SIZE_LR = 24
STRIDE = 12
SCALE_FACTOR = 2
NUM_BLOCKS = 23
FILTERS = 64
GROWTH_CHANNELS = 32
BETA = 0.2

In [7]:
X, Y = load_dataset_as_patches("../../data/images/HR", "../../data/images/LR", patch_size_lr=PATCH_SIZE_LR, stride=STRIDE, scale_factor=SCALE_FACTOR)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, shuffle=True, random_state=42)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.1, shuffle=True, random_state=42)

print(f"X_train shape: {X_train.shape}, Y_train shape: {Y_train.shape}")
print(f"X_val shape: {X_val.shape}, Y_val shape: {Y_val.shape}")
print(f"X_test shape: {X_test.shape}, Y_test shape: {Y_test.shape}")

Extracted 274104 patch pairs from 846 images
LR patches shape: (274104, 24, 24, 3)
HR patches shape: (274104, 48, 48, 3)
X_train shape: (222023, 24, 24, 3), Y_train shape: (222023, 48, 48, 3)
X_val shape: (24670, 24, 24, 3), Y_val shape: (24670, 48, 48, 3)
X_test shape: (27411, 24, 24, 3), Y_test shape: (27411, 48, 48, 3)


In [8]:
# Use a generator to avoid memory issues
def data_generator(X, Y):
    for x, y in zip(X, Y):
        yield x, y

output_signature = (
    tf.TensorSpec(shape=X_train.shape[1:], dtype=tf.float32),
    tf.TensorSpec(shape=Y_train.shape[1:], dtype=tf.float32)
)

BATCH_SIZE = 4

train_dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(X_train, Y_train),
    output_signature=output_signature
).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

val_dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(X_val, Y_val),
    output_signature=output_signature
).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

test_dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(X_test, Y_test),
    output_signature=output_signature
).batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

In [17]:
model = ESRGAN(lr_size=(PATCH_SIZE_LR, PATCH_SIZE_LR, 3), scale=SCALE_FACTOR)

model.setup_model(
    num_blocks=NUM_BLOCKS, 
    filters=FILTERS, 
    growth_channels=GROWTH_CHANNELS, 
    beta=BETA, 
    use_spectral_norm=True, 
    from_pretrained=False
)

Building new models...


In [18]:
model.fit(
    train_dataset,
    epochs=10,
    interpolation_alpha=0.2,
    psnr_epochs=5, 
    use_agmentation=True,
    val_dataset=val_dataset, 
    num_blocks=NUM_BLOCKS,
    filters=FILTERS,
    growth_channels=GROWTH_CHANNELS,
    beta=BETA
)

Starting PSNR pre-training...
PSNR Epoch 1/5
Step 0: PSNR Loss: 1.9934
Step 100: PSNR Loss: 0.1170
Step 200: PSNR Loss: 0.1068
Step 300: PSNR Loss: 0.1157
Step 400: PSNR Loss: 0.1204
Step 500: PSNR Loss: 0.1498


KeyboardInterrupt: 