In [1]:
import os
import cv2
import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from keras.optimizers import SGD
from keras.models import Model, load_model
from sklearn.model_selection import train_test_split
from keras.layers import Input, Conv2D, Activation, Add
from keras.callbacks import EarlyStopping, ReduceLROnPlateau

In [2]:
def load_full_images(hr_root, lr_root, interpolation=cv2.INTER_CUBIC):
    """
    Loads HR and LR images from separate folders, upscales LR images.

    Parameters:
        hr_root (str): Root path to HR images (e.g., data/images/HR).
        lr_root (str): Root path to LR images (e.g., data/images/LR).
        interpolation: OpenCV interpolation method for upscaling LR images.

    Returns:
        X (np.ndarray): Low-resolution image patches (model input).
        Y (np.ndarray): High-resolution image patches (target).
    """

    if not os.path.exists(hr_root) or not os.path.exists(lr_root):
        raise ValueError("Both HR and LR root directories must exist.")
    if not os.path.isdir(hr_root) or not os.path.isdir(lr_root):
        raise ValueError("Both HR and LR root paths must be directories.")

    X, Y = [], []

    def get_all_image_paths(root):
        image_paths = []
        
        for dirpath, _, filenames in os.walk(root):
            for filename in filenames:
                if filename.lower().endswith((".jpg", ".jpeg", ".png")):
                    image_paths.append(os.path.join(dirpath, filename))
                    
        return sorted(image_paths)

    hr_paths = get_all_image_paths(hr_root)
    lr_paths = get_all_image_paths(lr_root)

    if not hr_paths or not lr_paths:
        raise ValueError("No images found in the specified directories.")

    # Match HR and LR images by filename (after last folder)
    hr_dict = {os.path.basename(p): p for p in hr_paths}
    lr_dict = {os.path.basename(p): p for p in lr_paths}
    common_filenames = sorted(set(hr_dict.keys()) & set(lr_dict.keys()))

    non_common_images = set()

    for fname in common_filenames:
        hr_img = cv2.imread(hr_dict[fname], cv2.IMREAD_COLOR)
        lr_img = cv2.imread(lr_dict[fname], cv2.IMREAD_COLOR)

        if hr_img is None or lr_img is None:
            non_common_images.add(fname)
            continue

        # Normalize
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        lr_img = cv2.cvtColor(lr_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

        # Upscale LR to HR size
        h, w, _ = hr_img.shape
        lr_up = cv2.resize(lr_img, (w, h), interpolation=interpolation)
        
        X.append(lr_up)
        Y.append(hr_img)

    if non_common_images:
        print(f"Skipped {len(non_common_images)} images due to missing files: {', '.join(non_common_images)}")

    return np.array(X), np.array(Y), h, w

In [11]:
def psnr(y_true, y_pred):
    return tf.image.psnr(y_true, y_pred, max_val=1.0)

def ssim(y_true, y_pred):
    return tf.image.ssim(y_true, y_pred, max_val=1.0)

class VDSR:
    def __init__(self):
        self.model = None
        self.trained = False

    def setup_model(self, channels=3, num_layers=20, num_filters=64, learning_rate=0.1, loss="mean_squared_error", from_pretrained=False, pretrained_path=None):
        """Set up the VDSR model, either by loading a pretrained model or building a new one."""
        
        if from_pretrained:
            if pretrained_path is None or not os.path.isfile(pretrained_path):
                raise FileNotFoundError(f"Pretrained model file not found at {pretrained_path}")
            
            self.model = load_model(pretrained_path, custom_objects={"psnr": psnr, "ssim": ssim})
            self.trained = True
            print(f"Loaded pretrained model from {pretrained_path}")
        else:
            self._build_model(channels, num_layers, num_filters)
            self._compile_model(learning_rate, loss)

    def _build_model(self, channels, num_layers, num_filters):
        """Construct the VDSR model architecture using functional API."""
        
        inputs = Input(shape=(None, None, channels), name="input")
        x = Conv2D(num_filters, (3, 3), padding="same", kernel_initializer="he_normal")(inputs)
        x = Activation("relu")(x)

        for _ in range(num_layers - 2):
            x = Conv2D(num_filters, (3, 3), padding="same", kernel_initializer="he_normal")(x)
            x = Activation("relu")(x)

        x = Conv2D(channels, (3, 3), padding="same", kernel_initializer="he_normal")(x)
        outputs = Add(name="output")([x, inputs])

        self.model = Model(inputs, outputs, name="VDSR")

    def _compile_model(self, learning_rate, loss):
        """Compile the model with SGD optimizer and specified loss, including PSNR metric."""
        
        optimizer = SGD(learning_rate=learning_rate, momentum=0.9, clipnorm=0.01)
        self.model.compile(optimizer=optimizer, loss=loss, metrics=[psnr, ssim])
        self.model.summary()

    def fit(self, X_train, Y_train, X_val, Y_val, batch_size=64, epochs=50, use_augmentation=False):
        """Train the model using optional image data augmentation and standard callbacks."""
        
        if self.model is None:
            raise ValueError("Model is not built yet.")

        devices = tf.config.list_physical_devices("GPU")
        if devices:
            print("Training on GPU:", devices[0].name)
        else:
            print("Training on CPU")

        callbacks = [
            EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
            ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3, min_lr=1e-6, verbose=1)
        ]

        if use_augmentation:
            datagen = tf.keras.preprocessing.image.ImageDataGenerator(
                rotation_range=15,
                width_shift_range=0.1,
                height_shift_range=0.1,
                zoom_range=0.1,
                horizontal_flip=True
            )
            train_gen = datagen.flow(X_train, Y_train, batch_size=batch_size)
            val_gen = datagen.flow(X_val, Y_val, batch_size=batch_size)

            self.model.fit(
                train_gen,
                steps_per_epoch=len(X_train) // batch_size,
                epochs=epochs,
                validation_data=val_gen,
                validation_steps=len(X_val) // batch_size,
                callbacks=callbacks
            )
        else:
            self.model.fit(
                X_train, Y_train,
                batch_size=batch_size,
                epochs=epochs,
                validation_data=(X_val, Y_val),
                callbacks=callbacks
            )

        self.trained = True

    def evaluate(self, X_test, Y_test):
        """Evaluate the model on test data and print both loss and PSNR."""
        if not self.trained:
            raise RuntimeError("Model has not been trained.")

        results = self.model.evaluate(X_test, Y_test)
        print(f"Loss: {results[0]:.4f}, PSNR: {results[1]:.2f} dB")
        return results

    def super_resolve_image(self, image_path, hr_h, hr_w, interpolation=cv2.INTER_CUBIC):
        """Performs super-resolution using padding to avoid border issues."""
        
        if not self._trained:
            raise RuntimeError("Model has not been trained.")
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image file not found at {image_path}")

        # Load and normalize original image
        img = cv2.imread(image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0
        
        img_lr_up = cv2.resize(img, (hr_w, hr_h), interpolation=interpolation)

        # Predict
        img_lr_up = np.expand_dims(img_lr_up, axis=0)
        sr_img = self.model.predict(img_lr_up)[0]

        return sr_img

    def save(self, directory="models/VDSR"):
        """Save the trained model with a timestamp in the specified directory."""
        
        if not self.trained:
            raise RuntimeError("Cannot save an untrained model.")

        os.makedirs(directory, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        path = os.path.join(directory, f"VDSR_{timestamp}.h5")
        self.model.save(path)
        print(f"Model saved to {path}")

In [None]:
INTERPOLATION = cv2.INTER_CUBIC
RANDOM_SEED = 42

In [27]:
X, Y, hr_h, hr_w = load_full_images("../../data/images/HR", "../../data/images/LR", interpolation=INTERPOLATION)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, shuffle=True, random_state=RANDOM_SEED)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.1, shuffle=True, random_state=RANDOM_SEED)

In [28]:
print(f"X shape: {X.shape}, Y shape: {Y.shape}")
print(f"X_train shape: {X_train.shape}, Y_train shape: {Y_train.shape}")
print(f"X_val shape: {X_val.shape}, Y_val shape: {Y_val.shape}")
print(f"X_test shape: {X_test.shape}, Y_test shape: {Y_test.shape}")

X shape: (1611, 478, 478, 3), Y shape: (1611, 478, 478, 3)
X_train shape: (1304, 478, 478, 3), Y_train shape: (1304, 478, 478, 3)
X_val shape: (145, 478, 478, 3), Y_val shape: (145, 478, 478, 3)
X_test shape: (162, 478, 478, 3), Y_test shape: (162, 478, 478, 3)


In [29]:
# calculate psnr between X and Y first patches
psnr_value = psnr(Y[0:1], X[0:1])
print(f"PSNR between first patches: {psnr_value.numpy()[0]:.2f} dB")

PSNR between first patches: 29.17 dB


In [30]:
model = VDSR()

model.setup_model(channels=3, num_layers=20, num_filters=64, learning_rate=0.1, loss="mean_squared_error", from_pretrained=False)

Model: "VDSR"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input (InputLayer)             [(None, None, None,  0           []                               
                                 3)]                                                              
                                                                                                  
 conv2d_80 (Conv2D)             (None, None, None,   1792        ['input[0][0]']                  
                                64)                                                               
                                                                                                  
 activation_76 (Activation)     (None, None, None,   0           ['conv2d_80[0][0]']              
                                64)                                                            

In [31]:
model.fit(X_train, Y_train, X_val, Y_val, batch_size=4, epochs=50, use_augmentation=True)

Training on GPU: /physical_device:GPU:0
Epoch 1/50


ResourceExhaustedError: Graph execution error:

Detected at node 'VDSR/activation_94/Relu' defined at (most recent call last):
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\ipykernel_launcher.py", line 18, in <module>
      app.launch_new_instance()
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
      app.start()
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\ipykernel\kernelapp.py", line 739, in start
      self.io_loop.start()
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\tornado\platform\asyncio.py", line 211, in start
      self.asyncio_loop.run_forever()
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\asyncio\base_events.py", line 603, in run_forever
      self._run_once()
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\asyncio\base_events.py", line 1909, in _run_once
      handle._run()
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue
      await self.process_one()
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\ipykernel\kernelbase.py", line 534, in process_one
      await dispatch(*args)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell
      await result
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\ipykernel\ipkernel.py", line 362, in execute_request
      await super().execute_request(stream, ident, parent)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\ipykernel\kernelbase.py", line 778, in execute_request
      reply_content = await reply_content
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\ipykernel\ipkernel.py", line 449, in do_execute
      res = shell.run_cell(
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell
      return super().run_cell(*args, **kwargs)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\IPython\core\interactiveshell.py", line 3077, in run_cell
      result = self._run_cell(
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\IPython\core\interactiveshell.py", line 3132, in _run_cell
      result = runner(coro)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\IPython\core\async_helpers.py", line 128, in _pseudo_sync_runner
      coro.send(None)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\IPython\core\interactiveshell.py", line 3336, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\IPython\core\interactiveshell.py", line 3519, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\IPython\core\interactiveshell.py", line 3579, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\bgmanuel\AppData\Local\Temp\ipykernel_24604\3118361437.py", line 1, in <module>
      model.fit(X_train, Y_train, X_val, Y_val, batch_size=4, epochs=50, use_augmentation=True)
    File "C:\Users\bgmanuel\AppData\Local\Temp\ipykernel_24604\2396356974.py", line 77, in fit
      self.model.fit(
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\engine\training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\engine\training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\engine\training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\engine\training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\engine\training.py", line 993, in train_step
      y_pred = self(x, training=True)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\engine\functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\engine\functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\layers\core\activation.py", line 59, in call
      return self.activation(inputs)
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\activations.py", line 317, in relu
      return backend.relu(
    File "c:\Users\bgmanuel\anaconda3\envs\py310\lib\site-packages\keras\backend.py", line 5366, in relu
      x = tf.nn.relu(x)
Node: 'VDSR/activation_94/Relu'
OOM when allocating tensor with shape[4,64,478,478] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node VDSR/activation_94/Relu}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_34470]

In [10]:
model.evaluate(X_test, Y_test)

Loss: 0.0048, PSNR: 23.21 dB


[0.004818345885723829, 23.213205337524414]

In [None]:
model.save()

In [None]:
pretrained_model = VDSR()

pretrained_model.setup_model(from_pretrained=True, pretrained_path="")

In [None]:
sr_image = pretrained_model.super_resolve_image("low_z_offset.png", hr_h, hr_w, interpolation=INTERPOLATION)
sr_image_uint8 = (sr_image * 255).astype(np.uint8)

plt.imshow(sr_image_uint8)
plt.axis('off')
plt.title("Super-Resolved Image")
plt.tight_layout()
plt.show()