In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D
import tensorflow_datasets as tfds

# Data Preprocessing

## Load Data

In [None]:
def preprocess(example):
    """Turn a DIV2K example into an ``(input, target)`` pair for SRCNN.

    SRCNN refines an image that has already been upscaled classically, so the
    low-resolution image is bicubic-resized to the high-resolution image's
    spatial size before it is fed to the network.

    Args:
        example: A dict from ``tfds.load('div2k/...')`` with ``'lr'`` and
            ``'hr'`` image tensors (presumably uint8 HxWx3 — TODO confirm
            against the TFDS feature spec).

    Returns:
        ``(lr_upscaled, hr)``: both float32, with values in ``[0, 1]`` and
        identical height/width.
    """
    # LR = low-res image (input), HR = high-res image (label).
    # convert_image_dtype rescales integer images into [0, 1].
    lr = tf.image.convert_image_dtype(example['lr'], tf.float32)
    hr = tf.image.convert_image_dtype(example['hr'], tf.float32)

    # Upscale the LR image to exactly HR's size (x4 for this DIV2K config).
    # Resizing to HR's shape, rather than multiplying by a scale factor,
    # guarantees matching shapes regardless of rounding in the LR dimensions.
    lr_upscaled = tf.image.resize(lr, size=tf.shape(hr)[:2], method='bicubic')

    # Bicubic interpolation overshoots near sharp edges, so values can leave
    # [0, 1]; clip so the inputs share the targets' value range.
    lr_upscaled = tf.clip_by_value(lr_upscaled, 0.0, 1.0)

    return lr_upscaled, hr

# Load the pre-defined DIV2K splits (LR images are bicubic-downscaled x4).
# Only the training files are shuffled; validation order stays deterministic
# so validation runs are reproducible.
train_raw = tfds.load('div2k/bicubic_x4', split='train', shuffle_files=True)
val_raw = tfds.load('div2k/bicubic_x4', split='validation', shuffle_files=False)

# shuffle_files only permutes file/shard order. Shuffle individual examples
# too (reshuffled every epoch by default). This happens before `map`, so the
# buffer holds the raw decoded examples rather than the float32 pairs.
# Full-resolution images are large, so the buffer is kept small.
SHUFFLE_BUFFER = 16

# Prepare batches. BATCH_SIZE stays at 1 because the examples are not cropped
# to a common size (presumably DIV2K image sizes vary), and tf.data can only
# batch tensors of identical shape.
BATCH_SIZE = 1

train_data = (
    train_raw
    .shuffle(SHUFFLE_BUFFER)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
val_data = (
    val_raw
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)


# Creating The Model

In [None]:
def SRCNN(channels=3):
    """Build the three-layer SRCNN network (9-1-5 kernels, 64/32 filters).

    The model is fully convolutional with 'same' padding, so it accepts
    images of any height and width and returns an output of the same size.

    Args:
        channels: Number of image channels in the input and the output
            (default 3, i.e. RGB).

    Returns:
        An uncompiled ``Sequential`` Keras model.
    """
    return Sequential([
        # Declare the input explicitly; passing `input_shape` to a layer is
        # discouraged in Keras 3 and emits a warning.
        tf.keras.Input(shape=(None, None, channels)),
        # Patch extraction and representation
        Conv2D(64, (9, 9), activation='relu', padding='same'),
        # Non-linear mapping
        Conv2D(32, (1, 1), activation='relu', padding='same'),
        # Reconstruction: linear activation because this is pixel regression
        Conv2D(channels, (5, 5), activation='linear', padding='same'),
    ])

def psnr(y_true, y_pred):
    """Peak signal-to-noise ratio in dB, assuming images scaled to [0, 1]."""
    return tf.image.psnr(y_true, y_pred, max_val=1.0)


srcnn = SRCNN()
# 'accuracy' is meaningless for continuous pixel regression (it counts exact
# value matches). PSNR is the standard quality metric for super-resolution.
srcnn.compile(optimizer='adam', loss='mean_squared_error', metrics=[psnr])

# Training & Evaluation

In [None]:
EPOCHS = 10

# Keep the History object so loss/metric curves can be inspected or plotted
# later; binding it also stops the notebook from echoing the History repr.
history = srcnn.fit(train_data, validation_data=val_data, epochs=EPOCHS)

# Test Output

In [None]:
def load_and_preprocess_image(path, upscale_size=None):
    """Read an image file into a normalized, batched float32 tensor.

    Args:
        path: Path to an image file (PNG, or any other format
            ``tf.image.decode_image`` supports).
        upscale_size: Optional ``(height, width)``. When given, the image is
            bicubic-resized to this size; otherwise it is used as-is.

    Returns:
        float32 tensor of shape ``(1, h, w, 3)`` with values in ``[0, 1]``.
    """
    # Load image. expand_animations=False guarantees a 3-D (h, w, c) tensor
    # even for animated formats.
    img = tf.io.read_file(path)
    img = tf.image.decode_image(img, channels=3, expand_animations=False)

    # Normalize BEFORE resizing. tf.image.resize returns float32 in the
    # original 0-255 range, and convert_image_dtype does not rescale float
    # input, so resizing first would leave the image un-normalized.
    img = tf.image.convert_image_dtype(img, tf.float32)

    # Resize (optional, if not already the upscaled size)
    if upscale_size is not None:
        img = tf.image.resize(img, upscale_size, method='bicubic')
        # Bicubic overshoot can push values outside [0, 1].
        img = tf.clip_by_value(img, 0.0, 1.0)

    # Add batch dimension: shape (1, h, w, 3)
    return tf.expand_dims(img, axis=0)

# Load the low-res image. No upscale_size is passed, so the file is assumed
# to already be bicubic-upscaled to the target size (NOTE(review): confirm
# for willy_14.png).
input_img = load_and_preprocess_image("willy_14.png")

# Predict high-res version using SRCNN
output = srcnn.predict(input_img)

# Remove the batch dimension and clip. The final layer is linear, so
# predictions can fall outside [0, 1]; imshow would otherwise clip them
# itself and emit a warning.
output_img = np.clip(output[0], 0.0, 1.0)

# Show original (input) and super-resolved (output) side by side
fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(12, 6))
ax_in.set_title("Input (Bicubic Upscaled)")
ax_in.imshow(tf.squeeze(input_img, axis=0))
ax_in.axis('off')
ax_out.set_title("SRCNN Output")
ax_out.imshow(output_img)
ax_out.axis('off')
plt.show()