In [11]:
import os
import cv2
import keras
import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from keras.layers import Conv2D, InputLayer
from keras.optimizers import Adam
from keras.models import Sequential, load_model
from sklearn.model_selection import train_test_split
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
def load_images(folder_path, scale=2, patch_size=33, stride=14):
    """
    Loads RGB high-resolution images and generates matching low-resolution patches.

    Each HR image is bicubically downscaled by ``scale`` and then upscaled back to
    its original size; spatially aligned patches are cut from both versions.
    Files are visited in sorted order so the output is deterministic, and images
    smaller than ``patch_size`` in either dimension are skipped.

    Parameters:
        folder_path (str): Path to folder containing HR images.
        scale (int): Downscaling factor (e.g., 2, 3, 4).
        patch_size (int): Size of output patch (default 33x33).
        stride (int): Step size between patches.

    Returns:
        X (np.ndarray): float32 low-resolution RGB patches in [0, 1] with shape
            (N, patch_size, patch_size, 3) (model input).
        Y (np.ndarray): float32 high-resolution RGB patches, same shape (target).
        When no usable image is found both arrays have N == 0 but keep the 4-D shape.
    """

    X, Y = [], []

    # os.listdir order is filesystem-dependent; sort it so the patch order (and any
    # seeded train/test split built on top of it) is reproducible.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
            continue

        img_path = os.path.join(folder_path, filename)
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)  # OpenCV decodes to BGR
        if img is None:
            continue  # unreadable or corrupt file

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0

        h, w = img.shape[:2]
        if h < patch_size or w < patch_size:
            # No full patch fits; skipping also avoids a zero-size resize below.
            continue

        # Simulate LR image by downscaling and then upscaling
        lr_size = (max(1, w // scale), max(1, h // scale))
        img_lr = cv2.resize(img, lr_size, interpolation=cv2.INTER_CUBIC)
        img_lr_up = cv2.resize(img_lr, (w, h), interpolation=cv2.INTER_CUBIC)

        # Extract aligned patches from the LR-upscaled and HR images
        for i in range(0, h - patch_size + 1, stride):
            for j in range(0, w - patch_size + 1, stride):
                X.append(img_lr_up[i:i + patch_size, j:j + patch_size, :])
                Y.append(img[i:i + patch_size, j:j + patch_size, :])

    if not X:
        empty = np.empty((0, patch_size, patch_size, 3), dtype=np.float32)
        return empty, empty.copy()

    return np.stack(X), np.stack(Y)

In [None]:
class SRCNNModel:
    """Wrapper around an SRCNN-style Keras network (9x9 -> 1x1 -> 5x5 convolutions).

    Lifecycle: ``setup_model`` (build + compile, or load a saved model) -> ``fit``
    -> ``evaluate`` / ``super_resolve_image`` -> ``save``. The network maps a
    bicubically upscaled low-resolution patch to a high-resolution patch with the
    same spatial size and channel count.
    """

    def __init__(self):
        self.model = None       # Keras model, populated by setup_model()
        self._trained = False   # True after fit() or after loading a pretrained model

    @staticmethod
    def _psnr(y_true, y_pred):
        """Return per-image PSNR (dB) between ``y_true`` and ``y_pred``.

        Pixel values are assumed to lie in [0, 1] (``max_val=1.0``). This must be a
        staticmethod: ``self._psnr`` is handed to Keras as a two-argument metric, and
        a plain method would receive ``self`` as an extra positional argument.
        """
        max_pixel = 1.0
        y_true = tf.cast(y_true, y_pred.dtype)  # tf.image.psnr requires matching dtypes
        return tf.image.psnr(y_true, y_pred, max_val=max_pixel)

    def setup_model(self, input_shape=(33, 33, 3), learning_rate=1e-4, loss="mean_squared_error", from_pretrained=False, pretrained_path=None):
        """Sets up the model: either loads pretrained or builds + compiles a new model.

        Parameters:
            input_shape (tuple): (height, width, channels) of input patches. The
                default matches the RGB patches produced by ``load_images``; the
                output has the same channel count as the input.
            learning_rate (float): Adam learning rate for a newly built model.
            loss (str): Keras loss identifier for a newly built model.
            from_pretrained (bool): Load ``pretrained_path`` instead of building.
            pretrained_path (str): Path to a saved Keras model file.

        Raises:
            FileNotFoundError: If ``from_pretrained`` and the file does not exist.
        """

        if from_pretrained:
            if pretrained_path is None or not os.path.isfile(pretrained_path):
                raise FileNotFoundError(f"Pretrained model file not found at {pretrained_path}")

            # Keras records a function metric under its __name__ ("_psnr");
            # "psnr" is kept as an alias for the previously registered key.
            custom_objects = {"psnr": self._psnr, "_psnr": self._psnr}
            self.model = load_model(pretrained_path, custom_objects=custom_objects)
            print(f"Loaded pretrained model from {pretrained_path}")
            self._trained = True
        else:
            self._build_model(input_shape)
            self._compile_model(learning_rate, loss)

    def _build_model(self, input_shape):
        """Builds the SRCNN model using the Sequential API.

        The last layer emits ``input_shape[-1]`` channels so that the prediction has
        the same shape as the target patches.
        """

        channels = input_shape[-1]
        self.model = Sequential([
            keras.Input(shape=input_shape),
            Conv2D(64, (9, 9), activation="relu", padding="same"),
            Conv2D(32, (1, 1), activation="relu", padding="same"),
            Conv2D(channels, (5, 5), activation="linear", padding="same")
        ])

    def _compile_model(self, learning_rate, loss):
        """Compiles the model with Adam and a PSNR metric."""

        optimizer = Adam(learning_rate=learning_rate)
        self.model.compile(optimizer=optimizer, loss=loss, metrics=[self._psnr])

    def fit(self, X, Y, batch_size=16, epochs=10, validation_split=0.1, use_augmentation=False):
        """Trains the model with optional data augmentation and callbacks.

        Parameters:
            X (np.ndarray): Model inputs (upscaled LR patches).
            Y (np.ndarray): Targets (HR patches), same shape as ``X``.
            batch_size (int): Mini-batch size.
            epochs (int): Maximum number of epochs (EarlyStopping may stop earlier).
            validation_split (float): Fraction of samples, taken from the end of the
                arrays as Keras does, held out for validation.
            use_augmentation (bool): If True, apply the same random flips and
                90-degree rotations to each (input, target) pair.

        Returns:
            keras.callbacks.History: The training history.

        Raises:
            ValueError: If the model has not been set up, or (augmentation path)
                ``validation_split`` is outside [0, 1).
        """

        if self.model is None:
            raise ValueError("Model has not been set up.")

        devices = tf.config.list_physical_devices("GPU")
        if devices:
            print("Training on GPU:", devices[0].name)
        else:
            print("Training on CPU")

        # Callbacks
        callbacks = [
            EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
            ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3, min_lr=1e-6, verbose=1)
        ]

        if use_augmentation:
            # Keras rejects `validation_split` for dataset/generator inputs, so split manually.
            X_train, Y_train, validation_data = self._split_validation(X, Y, validation_split)
            # Input and target are augmented together so LR/HR pixels stay aligned;
            # ImageDataGenerator.flow(X, Y) would transform X only.
            train_ds = (
                tf.data.Dataset.from_tensor_slices((X_train, Y_train))
                .shuffle(len(X_train), reshuffle_each_iteration=True)
                .map(self._augment_pair, num_parallel_calls=tf.data.AUTOTUNE)
                .batch(batch_size)
                .prefetch(tf.data.AUTOTUNE)
            )
            history = self.model.fit(
                train_ds,
                epochs=epochs,
                validation_data=validation_data,
                callbacks=callbacks
            )
        else:
            history = self.model.fit(
                X, Y,
                batch_size=batch_size,
                epochs=epochs,
                validation_split=validation_split,
                callbacks=callbacks
            )

        self._trained = True
        return history

    @staticmethod
    def _split_validation(X, Y, validation_split):
        """Split off the trailing ``validation_split`` fraction (Keras' own rule).

        Returns ``(X_train, Y_train, validation_data)``; ``validation_data`` is
        ``(X_val, Y_val)``, or ``None`` when ``validation_split`` is 0.
        """
        if not 0.0 <= validation_split < 1.0:
            raise ValueError(f"validation_split must be in [0, 1), got {validation_split}")
        if validation_split == 0:
            return X, Y, None
        split_at = int(np.floor(len(X) * (1.0 - validation_split)))
        return X[:split_at], Y[:split_at], (X[split_at:], Y[split_at:])

    @staticmethod
    def _augment_pair(x, y):
        """Apply one random flip/90-degree-rotation draw to an (input, target) pair.

        Only interpolation-free transforms are used so the HR target is not blurred.
        Assumes square patches, since rot90 swaps height and width.
        """
        channels = x.shape[-1]
        stacked = tf.concat([x, y], axis=-1)  # one tensor => one random draw for both
        stacked = tf.image.random_flip_left_right(stacked)
        stacked = tf.image.random_flip_up_down(stacked)
        k = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
        stacked = tf.image.rot90(stacked, k=k)
        # rot90 with a tensor `k` drops the static H/W; restore it for batching.
        x_aug = tf.ensure_shape(stacked[..., :channels], x.shape)
        y_aug = tf.ensure_shape(stacked[..., channels:], y.shape)
        return x_aug, y_aug

    def evaluate(self, X_test, Y_test):
        """Evaluates the model and prints loss and PSNR.

        Assumes the model was compiled with exactly one metric (PSNR), as
        ``_compile_model`` does, so ``results`` is ``[loss, psnr]``.
        """

        if not self._trained:
            raise RuntimeError("Model has not been trained.")

        results = self.model.evaluate(X_test, Y_test)
        print(f"Loss: {results[0]:.4f}, PSNR: {results[1]:.2f} dB")

        return results

    def super_resolve_image(self, image_path, scale=2, patch_size=33, stride=14, simulate_lr=True):
        """Performs super-resolution on a single image.

        Parameters:
            image_path (str): Path to the image file.
            scale (int): Resolution factor.
            patch_size (int): Side of the square patches fed to the model.
            stride (int): Step between patch origins; overlaps are averaged.
            simulate_lr (bool): If True (default), treat the file as HR, degrade it by
                downscaling/upscaling by ``scale`` and restore it (output has the
                input's size). If False, treat the file as already LR and upscale it
                by ``scale`` before restoration (output is ``scale`` times larger).

        Returns:
            np.ndarray: float32 RGB image in [0, 1].

        Raises:
            RuntimeError: If the model has not been trained/loaded.
            FileNotFoundError: If ``image_path`` does not exist.
            ValueError: If the file cannot be decoded as an image.
        """

        if not self._trained:
            raise RuntimeError("Model has not been trained.")
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image file not found at {image_path}")

        # Load and normalize original image
        img = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError(f"Could not decode image at {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0

        h, w = img.shape[:2]
        if simulate_lr:
            # Downscale and upscale to simulate LR input (max(1, ...) guards tiny images)
            lr_size = (max(1, w // scale), max(1, h // scale))
            img_lr = cv2.resize(img, lr_size, interpolation=cv2.INTER_CUBIC)
            model_input = cv2.resize(img_lr, (w, h), interpolation=cv2.INTER_CUBIC)
        else:
            # Input is already LR: only upscale to the target resolution
            model_input = cv2.resize(img, (w * scale, h * scale), interpolation=cv2.INTER_CUBIC)

        return self._predict_tiled(model_input, patch_size, stride)

    def _predict_tiled(self, image, patch_size, stride):
        """Run the model over overlapping patches of ``image`` and blend the results.

        Images smaller than a patch are edge-padded first; the output is cropped
        back to the input's height and width.
        """
        h, w = image.shape[:2]
        pad_h, pad_w = max(0, patch_size - h), max(0, patch_size - w)
        if pad_h or pad_w:
            image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode="edge")

        patches, positions = self._extract_patches(image, patch_size, stride)
        preds = np.asarray(self.model.predict(patches, batch_size=16))
        blended = self._reconstruct_from_patches(preds, positions, image.shape[:2], patch_size)
        return blended[:h, :w]

    @staticmethod
    def _patch_starts(length, patch_size, stride):
        """Patch origins along one axis, always including one flush with the far edge.

        Without the final origin, pixels past the last stride-aligned patch would
        never be covered and would reconstruct as zeros.
        """
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        if length <= patch_size:
            return [0]
        starts = list(range(0, length - patch_size + 1, stride))
        if starts[-1] != length - patch_size:
            starts.append(length - patch_size)
        return starts

    @classmethod
    def _extract_patches(cls, image, patch_size, stride):
        """Cut square patches covering all of ``image``; return (patches, positions)."""
        h, w = image.shape[:2]
        positions = [
            (i, j)
            for i in cls._patch_starts(h, patch_size, stride)
            for j in cls._patch_starts(w, patch_size, stride)
        ]
        patches = np.stack([image[i:i + patch_size, j:j + patch_size] for i, j in positions])
        return patches, positions

    @staticmethod
    def _reconstruct_from_patches(patches, positions, image_hw, patch_size):
        """Average overlapping patches back into an (H, W, C) image clipped to [0, 1]."""
        h, w = image_hw
        reconstructed = np.zeros((h, w, patches.shape[-1]), dtype=np.float32)
        weight = np.zeros((h, w, 1), dtype=np.float32)

        for patch, (i, j) in zip(patches, positions):
            reconstructed[i:i + patch_size, j:j + patch_size] += patch
            weight[i:i + patch_size, j:j + patch_size] += 1.0

        reconstructed /= np.maximum(weight, 1e-8)  # avoid division by zero
        return np.clip(reconstructed, 0, 1)

    def save(self, directory="models"):
        """Saves the model to a timestamped .h5 file and returns its path."""

        if not self._trained:
            raise RuntimeError("Cannot save an untrained model.")

        os.makedirs(directory, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(directory, f"SRCNN_{timestamp}.h5")
        self.model.save(filepath)
        print(f"Model saved to {filepath}")
        return filepath

In [None]:
# Build a fresh SRCNN; the explicit 3-channel input matches the RGB patches
# produced by load_images.
model = SRCNNModel()

model.setup_model(input_shape=(33, 33, 3))

In [None]:
# Build (LR, HR) patch pairs and hold out 10% for testing.
# A fixed random_state makes the split reproducible across Restart & Run All.
X, Y = load_images("path_to_dataset")
assert len(X) > 0, "No image patches were loaded - check the dataset path"
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, random_state=42)

In [None]:
# Train with augmentation; the EarlyStopping callback inside fit() may end training before 50 epochs.
model.fit(X_train, Y_train, batch_size=16, epochs=50, use_augmentation=True)

In [None]:
# Report loss and PSNR on the held-out patches.
model.evaluate(X_test=X_test, Y_test=Y_test)

In [None]:
# Persist the trained model as models/SRCNN_<timestamp>.h5
model.save(directory="models")

In [None]:
# Inference: load a previously saved model instead of training from scratch.
pretrained_path = "path_to_pretrained_model.h5"
model = SRCNNModel()
model.setup_model(pretrained_path=pretrained_path, from_pretrained=True)

In [None]:
sr_image = model.super_resolve_image("path/to/new_image.jpg")
# Round (not truncate) when mapping [0, 1] floats to 8-bit, so pixel values
# are not systematically biased downward.
sr_image_uint8 = np.round(np.clip(sr_image, 0.0, 1.0) * 255).astype(np.uint8)

fig, ax = plt.subplots()
ax.imshow(sr_image_uint8)
ax.axis("off")
ax.set_title("Super-Resolved Image")
fig.tight_layout()
plt.show()