In [1]:
import os
import cv2
import datetime
import numpy as np
import tensorflow as tf
from keras.optimizers import Adam
from keras.models import Model, load_model
from sklearn.model_selection import train_test_split
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.layers import Input, Conv2D, Activation, Add, Lambda

In [2]:
def load_dataset(hr_root, lr_root, scale_factor=2):
    """
    Loads HR and LR images from separate folders for EDSR training.
    No upscaling is performed on LR images as EDSR handles this internally.

    Parameters:
        hr_root (str): Root path to HR images (e.g., data/images/HR).
        lr_root (str): Root path to LR images (e.g., data/images/LR).
        scale_factor (int): The scale factor (2, 3, or 4) to verify image dimensions.

    Returns:
        X (np.ndarray): Low-resolution images (model input).
        Y (np.ndarray): High-resolution images (target).
    """

    if not os.path.exists(hr_root) or not os.path.exists(lr_root):
        raise ValueError("Both HR and LR root directories must exist.")
    if not os.path.isdir(hr_root) or not os.path.isdir(lr_root):
        raise ValueError("Both HR and LR root paths must be directories.")

    X, Y = [], []

    def get_all_image_paths(root):
        image_paths = []
        
        for dirpath, _, filenames in os.walk(root):
            for filename in filenames:
                if filename.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".tiff")):
                    image_paths.append(os.path.join(dirpath, filename))
                    
        return sorted(image_paths)

    hr_paths = get_all_image_paths(hr_root)
    lr_paths = get_all_image_paths(lr_root)

    if not hr_paths or not lr_paths:
        raise ValueError("No images found in the specified directories.")

    # Match HR and LR images by filename (after last folder)
    hr_dict = {os.path.basename(p): p for p in hr_paths}
    lr_dict = {os.path.basename(p): p for p in lr_paths}
    common_filenames = sorted(set(hr_dict.keys()) & set(lr_dict.keys()))

    if not common_filenames:
        raise ValueError("No matching filenames found between HR and LR directories.")

    non_common_images = set()
    dimension_mismatches = []

    for fname in common_filenames:
        hr_img = cv2.imread(hr_dict[fname], cv2.IMREAD_COLOR)
        lr_img = cv2.imread(lr_dict[fname], cv2.IMREAD_COLOR)

        if hr_img is None or lr_img is None:
            non_common_images.add(fname)
            continue

        # Convert to RGB and normalize
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        lr_img = cv2.cvtColor(lr_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

        # Verify that dimensions match the scale factor
        hr_h, hr_w, _ = hr_img.shape
        lr_h, lr_w, _ = lr_img.shape
        
        expected_lr_h = hr_h // scale_factor
        expected_lr_w = hr_w // scale_factor
        
        if lr_h != expected_lr_h or lr_w != expected_lr_w:
            dimension_mismatches.append(
                f"{fname}: HR({hr_h}x{hr_w}) -> Expected LR({expected_lr_h}x{expected_lr_w}), "
                f"Got LR({lr_h}x{lr_w})"
            )
            continue
        
        # Add to dataset (LR as input, HR as target)
        X.append(lr_img)
        Y.append(hr_img)

    if non_common_images:
        print(f"Skipped {len(non_common_images)} images due to loading errors: {', '.join(list(non_common_images)[:5])}{'...' if len(non_common_images) > 5 else ''}")

    if dimension_mismatches:
        print(f"Skipped {len(dimension_mismatches)} images due to dimension mismatches:")
        for mismatch in dimension_mismatches[:3]:  # Show first 3 mismatches
            print(f"  {mismatch}")
        if len(dimension_mismatches) > 3:
            print(f"  ... and {len(dimension_mismatches) - 3} more")

    if not X:
        raise ValueError("No valid image pairs found. Check your data and scale factor.")

    X_array = np.array(X)
    Y_array = np.array(Y)
    
    print(f"Loaded {len(X)} image pairs")
    print(f"LR images shape: {X_array.shape}")
    print(f"HR images shape: {Y_array.shape}")
    
    # Return shapes for reference
    lr_shape = (X_array.shape[1], X_array.shape[2])  # (height, width)
    hr_shape = (Y_array.shape[1], Y_array.shape[2])  # (height, width)

    return X_array, Y_array

In [3]:
def load_patches_edsr(hr_root, lr_root, patch_size_lr=48, scale_factor=2, stride=24, 
                     max_patches_per_image=None):
    """
    Loads HR and LR image patches from separate folders for EDSR training.
    Extracts patches from both LR and HR images maintaining the scale factor relationship.

    Parameters:
        hr_root (str): Root path to HR images.
        lr_root (str): Root path to LR images.
        patch_size_lr (int): Size of LR patches (HR patches will be patch_size_lr * scale_factor).
        scale_factor (int): The scale factor (2, 3, or 4).
        stride (int): Stride for patch extraction.
        max_patches_per_image (int): Maximum patches to extract per image (None for all).

    Returns:
        X (np.ndarray): Low-resolution patches (model input).
        Y (np.ndarray): High-resolution patches (target).
    """
    
    if not os.path.exists(hr_root) or not os.path.exists(lr_root):
        raise ValueError("Both HR and LR root directories must exist.")
    
    patch_size_hr = patch_size_lr * scale_factor
    X, Y = [], []

    def get_all_image_paths(root):
        image_paths = []
        for dirpath, _, filenames in os.walk(root):
            for filename in filenames:
                if filename.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".tiff")):
                    image_paths.append(os.path.join(dirpath, filename))
        return sorted(image_paths)

    hr_paths = get_all_image_paths(hr_root)
    lr_paths = get_all_image_paths(lr_root)

    # Match HR and LR images by filename
    hr_dict = {os.path.basename(p): p for p in hr_paths}
    lr_dict = {os.path.basename(p): p for p in lr_paths}
    common_filenames = sorted(set(hr_dict.keys()) & set(lr_dict.keys()))

    if not common_filenames:
        raise ValueError("No matching filenames found between HR and LR directories.")

    total_patches = 0
    
    for fname in common_filenames:
        hr_img = cv2.imread(hr_dict[fname], cv2.IMREAD_COLOR)
        lr_img = cv2.imread(lr_dict[fname], cv2.IMREAD_COLOR)

        if hr_img is None or lr_img is None:
            continue

        # Convert to RGB and normalize
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        lr_img = cv2.cvtColor(lr_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

        lr_h, lr_w, _ = lr_img.shape
        hr_h, hr_w, _ = hr_img.shape

        # Extract patches
        patches_this_image = 0
        
        for i in range(0, lr_h - patch_size_lr + 1, stride):
            for j in range(0, lr_w - patch_size_lr + 1, stride):
                # Extract LR patch
                lr_patch = lr_img[i:i+patch_size_lr, j:j+patch_size_lr]
                
                # Extract corresponding HR patch
                hr_i = i * scale_factor
                hr_j = j * scale_factor
                
                if hr_i + patch_size_hr <= hr_h and hr_j + patch_size_hr <= hr_w:
                    hr_patch = hr_img[hr_i:hr_i+patch_size_hr, hr_j:hr_j+patch_size_hr]
                    
                    X.append(lr_patch)
                    Y.append(hr_patch)
                    patches_this_image += 1
                    total_patches += 1
                    
                    if max_patches_per_image and patches_this_image >= max_patches_per_image:
                        break
            
            if max_patches_per_image and patches_this_image >= max_patches_per_image:
                break

    if not X:
        raise ValueError("No patches could be extracted. Check your patch size and image dimensions.")

    X_array = np.array(X)
    Y_array = np.array(Y)
    
    print(f"Extracted {total_patches} patch pairs from {len(common_filenames)} images")
    print(f"LR patches shape: {X_array.shape}")
    print(f"HR patches shape: {Y_array.shape}")

    return X_array, Y_array

In [4]:
def psnr(y_true, y_pred):
    return tf.image.psnr(y_true, y_pred, max_val=1.0)

def ssim(y_true, y_pred):
    return tf.image.ssim(y_true, y_pred, max_val=1.0)

class EDSR:
    def __init__(self):
        self.model = None
        self.trained = False
        self.scale_factor = 2

    def setup_model(self, scale_factor=2, channels=3, num_res_blocks=16, num_filters=64, 
                   res_scaling=0.1, learning_rate=1e-4, loss="mean_absolute_error", 
                   from_pretrained=False, pretrained_path=None):
        """Set up the EDSR model, either by loading a pretrained model or building a new one."""
        
        self.scale_factor = scale_factor
        
        if from_pretrained:
            if pretrained_path is None or not os.path.isfile(pretrained_path):
                raise FileNotFoundError(f"Pretrained model file not found at {pretrained_path}")
            
            self.model = load_model(pretrained_path, custom_objects={"psnr": psnr, "ssim": ssim})
            self.trained = True
            print(f"Loaded pretrained model from {pretrained_path}")
        else:
            self._build_model(scale_factor, channels, num_res_blocks, num_filters, res_scaling)
            self._compile_model(learning_rate, loss)

    def _residual_block(self, x, num_filters, res_scaling):
        """Build a residual block without batch normalization (key feature of EDSR)."""
        
        shortcut = x
        
        # First conv layer
        x = Conv2D(num_filters, (3, 3), padding="same", kernel_initializer="he_normal")(x)
        x = Activation("relu")(x)
        
        # Second conv layer
        x = Conv2D(num_filters, (3, 3), padding="same", kernel_initializer="he_normal")(x)
        
        # Scale the residual
        if res_scaling != 1.0:
            x = Lambda(lambda t: t * res_scaling)(x)
        
        # Add shortcut connection
        x = Add()([x, shortcut])
        
        return x

    def _upsampling_block(self, x, scale_factor, num_filters):
        """Create upsampling block using sub-pixel convolution."""
        
        if scale_factor == 2:
            x = Conv2D(num_filters * 4, (3, 3), padding="same", kernel_initializer="he_normal")(x)
            x = Lambda(lambda x: tf.nn.depth_to_space(x, 2))(x)
        elif scale_factor == 3:
            x = Conv2D(num_filters * 9, (3, 3), padding="same", kernel_initializer="he_normal")(x)
            x = Lambda(lambda x: tf.nn.depth_to_space(x, 3))(x)
        elif scale_factor == 4:
            # Two 2x upsampling blocks
            x = Conv2D(num_filters * 4, (3, 3), padding="same", kernel_initializer="he_normal")(x)
            x = Lambda(lambda x: tf.nn.depth_to_space(x, 2))(x)
            x = Conv2D(num_filters * 4, (3, 3), padding="same", kernel_initializer="he_normal")(x)
            x = Lambda(lambda x: tf.nn.depth_to_space(x, 2))(x)
        else:
            raise ValueError(f"Scale factor {scale_factor} not supported. Use 2, 3, or 4.")
        
        return x

    def _build_model(self, scale_factor, channels, num_res_blocks, num_filters, res_scaling):
        """Construct the EDSR model architecture using functional API."""
        
        inputs = Input(shape=(None, None, channels), name="input")
        
        # Initial convolution (head)
        x = Conv2D(num_filters, (3, 3), padding="same", kernel_initializer="he_normal")(inputs)
        
        # Store for global residual connection
        head_output = x
        
        # Residual blocks (body)
        for i in range(num_res_blocks):
            x = self._residual_block(x, num_filters, res_scaling)
        
        # Final convolution of the body
        x = Conv2D(num_filters, (3, 3), padding="same", kernel_initializer="he_normal")(x)
        
        # Global residual connection
        x = Add()([x, head_output])
        
        # Upsampling blocks (tail)
        x = self._upsampling_block(x, scale_factor, num_filters)
        
        # Final convolution to produce RGB output
        outputs = Conv2D(channels, (3, 3), padding="same", kernel_initializer="he_normal")(x)
        
        self.model = Model(inputs, outputs, name="EDSR")

    def _compile_model(self, learning_rate, loss):
        """Compile the model with Adam optimizer and specified loss, including PSNR and SSIM metrics."""
        
        optimizer = Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
        self.model.compile(optimizer=optimizer, loss=loss, metrics=[psnr, ssim])
        self.model.summary()

    def fit(self, X_train, Y_train, X_val, Y_val, batch_size=16, epochs=300, use_augmentation=False):
        """Train the model using optional image data augmentation and standard callbacks."""
        
        if self.model is None:
            raise ValueError("Model is not built yet.")

        devices = tf.config.list_physical_devices("GPU")
        if devices:
            print("Training on GPU:", devices[0].name)
        else:
            print("Training on CPU")

        callbacks = [
            EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
            ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5, min_lr=1e-7, verbose=1)
        ]

        if use_augmentation:
            datagen = tf.keras.preprocessing.image.ImageDataGenerator(
                rotation_range=10,
                width_shift_range=0.05,
                height_shift_range=0.05,
                horizontal_flip=True
            )
            train_gen = datagen.flow(X_train, Y_train, batch_size=batch_size)
            val_gen = datagen.flow(X_val, Y_val, batch_size=batch_size)

            self.model.fit(
                train_gen,
                steps_per_epoch=len(X_train) // batch_size,
                epochs=epochs,
                validation_data=val_gen,
                validation_steps=len(X_val) // batch_size,
                callbacks=callbacks
            )
        else:
            self.model.fit(
                X_train, Y_train,
                batch_size=batch_size,
                epochs=epochs,
                validation_data=(X_val, Y_val),
                callbacks=callbacks
            )

        self.trained = True

    def evaluate(self, X_test, Y_test):
        """Evaluate the model on test data and print loss, PSNR, and SSIM."""
        if not self.trained:
            raise RuntimeError("Model has not been trained.")

        results = self.model.evaluate(X_test, Y_test)
        print(f"Loss: {results[0]:.4f}, PSNR: {results[1]:.2f} dB, SSIM: {results[2]:.4f}")
        return results

    def super_resolve_image(self, image_path, interpolation=cv2.INTER_CUBIC):
        """Performs super-resolution on a single image."""
        
        if not self.trained:
            raise RuntimeError("Model has not been trained.")
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image file not found at {image_path}")

        # Load and normalize original image
        img = cv2.imread(image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0
        
        # For EDSR, we input the low-resolution image directly
        # The model will upscale it by the scale factor
        original_shape = img.shape[:2]
        
        # Add batch dimension
        img_batch = np.expand_dims(img, axis=0)
        
        # Predict
        sr_img = self.model.predict(img_batch)[0]
        
        # Clip values to valid range
        sr_img = np.clip(sr_img, 0.0, 1.0)

        return sr_img

    def super_resolve_batch(self, images):
        """Performs super-resolution on a batch of images."""
        
        if not self.trained:
            raise RuntimeError("Model has not been trained.")
        
        if images.ndim != 4:
            raise ValueError("Input should be a 4D array (batch, height, width, channels)")
        
        # Normalize if not already normalized
        if images.max() > 1.0:
            images = images.astype(np.float32) / 255.0
        
        # Predict
        sr_images = self.model.predict(images)
        
        # Clip values to valid range
        sr_images = np.clip(sr_images, 0.0, 1.0)
        
        return sr_images

    def save(self, directory="models/EDSR"):
        """Save the trained model with a timestamp in the specified directory."""
        
        if not self.trained:
            raise RuntimeError("Cannot save an untrained model.")

        os.makedirs(directory, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        path = os.path.join(directory, f"EDSR_x{self.scale_factor}_{timestamp}.h5")
        self.model.save(path)
        print(f"Model saved to {path}")
        return path

In [14]:
X, Y = load_dataset("../../data/images/HR", "../../data/images/LR")
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, shuffle=True, random_state=42)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.1, shuffle=True, random_state=42)

print(f"X_train shape: {X_train.shape}, Y_train shape: {Y_train.shape}")
print(f"X_val shape: {X_val.shape}, Y_val shape: {Y_val.shape}")
print(f"X_test shape: {X_test.shape}, Y_test shape: {Y_test.shape}")

Loaded 1611 image pairs
LR images shape: (1611, 119, 119, 3)
HR images shape: (1611, 239, 239, 3)
X_train shape: (1304, 119, 119, 3), Y_train shape: (1304, 239, 239, 3)
X_val shape: (145, 119, 119, 3), Y_val shape: (145, 239, 239, 3)
X_test shape: (162, 119, 119, 3), Y_test shape: (162, 239, 239, 3)


In [5]:
X, Y = load_patches_edsr("../../data/images/HR", "../../data/images/LR", patch_size_lr=24, scale_factor=2, stride=12)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, shuffle=True, random_state=42)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.1, shuffle=True, random_state=42)

print(f"X_train shape: {X_train.shape}, Y_train shape: {Y_train.shape}")
print(f"X_val shape: {X_val.shape}, Y_val shape: {Y_val.shape}")
print(f"X_test shape: {X_test.shape}, Y_test shape: {Y_test.shape}")

Extracted 170748 patch pairs from 527 images
LR patches shape: (170748, 24, 24, 3)
HR patches shape: (170748, 48, 48, 3)
X_train shape: (138305, 24, 24, 3), Y_train shape: (138305, 48, 48, 3)
X_val shape: (15368, 24, 24, 3), Y_val shape: (15368, 48, 48, 3)
X_test shape: (17075, 24, 24, 3), Y_test shape: (17075, 48, 48, 3)


In [6]:
model = EDSR()

model.setup_model(scale_factor=2, num_res_blocks=16, num_filters=64)

Model: "EDSR"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input (InputLayer)             [(None, None, None,  0           []                               
                                 3)]                                                              
                                                                                                  
 conv2d (Conv2D)                (None, None, None,   1792        ['input[0][0]']                  
                                64)                                                               
                                                                                                  
 conv2d_1 (Conv2D)              (None, None, None,   36928       ['conv2d[0][0]']                 
                                64)                                                            

In [7]:
model.fit(X_train, Y_train, X_val, Y_val, batch_size=16, epochs=50, use_augmentation=True)

Training on GPU: /physical_device:GPU:0
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50

KeyboardInterrupt: 

In [None]:
model.evaluate(X_test, Y_test)