In [None]:
import tensorflow as tf
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow import keras

# Switch Matplotlib to the non-interactive 'Agg' backend (this is a pyplot
# call, not an environment variable).
plt.switch_backend('Agg')

# --- Configuration Constants ---
# Upscaling factor between low-resolution and high-resolution patches.
SCALE_FACTOR = 4
# Number of image pairs per training batch.
BATCH_SIZE = 8
# Side length (pixels) of the square low-resolution patch.
LR_SIZE = 96
# Side length (pixels) of the square high-resolution patch.
HR_SIZE = LR_SIZE * SCALE_FACTOR # 384
# Sample count used when the simulated (random) fallback dataset is generated.
TRAIN_SAMPLES = 800
EPOCHS = 30 # Updated to 30 epochs to match user's expected run duration
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm whether it is still needed.
PIXEL_LOSS_WEIGHT = 0.001

# Set logging level to error to reduce noise
# NOTE(review): this is assigned after `import tensorflow` has already run;
# it may need to be set before the import to take effect — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

print(f"TensorFlow Version: {tf.__version__}")
print("-" * 50)

# ----------------------------------------------------------------------
# 1. KERAS OPTIMIZER AND CALLBACK CONFIGURATION
# ----------------------------------------------------------------------

# Using Adam optimizer with a piecewise constant learning rate schedule.
# This module-level instance is the optimizer passed to model.compile() in
# create_sr_model().
optim_edsr = keras.optimizers.Adam(
    learning_rate=keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[5000], # After 5000 steps
        values=[1e-4, 5e-5] # Starts at 1e-4, drops to 5e-5
    ),
)

# Checkpoint the best model weights based on the monitored training loss
best_weights_checkpoint_path = "best-model.weights.h5"

# NOTE: Monitoring "loss" as we do not have a separate validation set configured.
# Only weights are written (save_weights_only=True), and only when the
# monitored value improves (save_best_only=True), checked once per epoch.
save_best_cb = keras.callbacks.ModelCheckpoint(
    filepath=best_weights_checkpoint_path,
    monitor="loss",
    save_best_only=True,
    save_weights_only=True,
    save_freq="epoch",
)

# ----------------------------------------------------------------------
# 1.5. QUANTITATIVE METRICS (PSNR and SSIM)
# ----------------------------------------------------------------------

def psnr_metric(y_true, y_pred):
    """Return the PSNR between reference and predicted images.

    The peak value handed to ``tf.image.psnr`` is 1.0, i.e. the inputs are
    treated as images normalized to [0, 1]. No reduction is applied here; the
    result of ``tf.image.psnr`` is returned as-is.

    Args:
        y_true: Ground-truth image tensor.
        y_pred: Predicted image tensor.
    """
    peak_value = 1.0
    return tf.image.psnr(y_true, y_pred, max_val=peak_value)

def ssim_metric(y_true, y_pred):
    """Return the mean SSIM between reference and predicted images.

    ``tf.image.ssim`` is evaluated with ``max_val=1.0`` (inputs treated as
    normalized to [0, 1]) and its output is averaged into a single value.

    Args:
        y_true: Ground-truth image tensor.
        y_pred: Predicted image tensor.
    """
    ssim_values = tf.image.ssim(y_true, y_pred, max_val=1.0)
    return tf.reduce_mean(ssim_values)

# ----------------------------------------------------------------------
# 2. DATA AUGMENTATION FUNCTIONS (TensorFlow Operators)
# ----------------------------------------------------------------------

def flip_left_right(lowres_img, highres_img):
    """Randomly mirror an LR/HR pair horizontally, applying the same choice to both.

    Args:
        lowres_img: Low-resolution image tensor.
        highres_img: High-resolution image tensor.

    Returns:
        A ``(lowres, highres)`` tuple: the inputs unchanged when the random
        draw is below 0.5, otherwise both images flipped left-to-right.
    """
    # Single scalar draw from [0, 1) so both images share the same decision.
    coin = tf.random.uniform(shape=(), maxval=1)

    def _keep_as_is():
        return lowres_img, highres_img

    def _mirror_both():
        return (
            tf.image.flip_left_right(lowres_img),
            tf.image.flip_left_right(highres_img),
        )

    return tf.cond(coin < 0.5, _keep_as_is, _mirror_both)


def random_rotate(lowres_img, highres_img):
    """Rotate an LR/HR pair by the same random multiple of 90 degrees.

    Args:
        lowres_img: Low-resolution image tensor.
        highres_img: High-resolution image tensor.

    Returns:
        A ``(lowres, highres)`` tuple with both images rotated by the same
        number of quarter turns.
    """
    # Integer draw in {0, 1, 2, 3}: how many 90-degree rotations to apply.
    quarter_turns = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)

    rotated_lowres = tf.image.rot90(lowres_img, quarter_turns)
    rotated_highres = tf.image.rot90(highres_img, quarter_turns)
    return rotated_lowres, rotated_highres

# ----------------------------------------------------------------------
# 3. DATASET LOADING (Updated with Augmentation)
# ----------------------------------------------------------------------

def load_or_simulate_dataset(num_samples, batch_size, lr_shape, hr_shape):
    """
    Builds the batched training dataset of (LR patch, HR patch) pairs.

    The DIV2K ``bicubic_x4`` dataset is loaded via TFDS, filtered to images
    large enough for one patch, randomly cropped into aligned LR/HR patches
    (sizes taken from the module constants LR_SIZE / HR_SIZE / SCALE_FACTOR),
    augmented with random flip and rotation, then shuffled, batched and
    prefetched. If anything in that setup raises, a dataset of uniform random
    noise is returned instead.

    Args:
        num_samples: Number of samples to generate; used only by the
            simulated fallback.
        batch_size: Batch size applied to the returned dataset.
        lr_shape: Shape of one simulated LR sample; used only by the fallback.
        hr_shape: Shape of one simulated HR sample; used only by the fallback.

    Returns:
        A batched, prefetched ``tf.data.Dataset`` yielding ``(lr, hr)`` tuples
        of float32 tensors.
    """
    try:
        print("Attempting to load real DIV2K dataset via TFDS...")

        train_ds, info = tfds.load(
            'div2k/bicubic_x4',
            split='train',
            as_supervised=True,
            with_info=True
        )

        # With as_supervised=True the div2k builder yields (lr, hr) tuples,
        # so the low-resolution image is the FIRST element.
        def filter_min_size(lr_img, hr_img):
            lr_dims = tf.shape(lr_img)
            hr_dims = tf.shape(hr_img)
            lr_large_enough = tf.logical_and(lr_dims[0] >= LR_SIZE, lr_dims[1] >= LR_SIZE)
            hr_large_enough = tf.logical_and(hr_dims[0] >= HR_SIZE, hr_dims[1] >= HR_SIZE)
            # Both images must be able to hold a full patch.
            return tf.logical_and(lr_large_enough, hr_large_enough)

        initial_samples = info.splits['train'].num_examples
        train_ds = train_ds.filter(filter_min_size)

        print(f"Initial TFDS samples: {initial_samples}")

        def preprocess_image_pair(lr_img, hr_img):
            # 1. Normalize to [0, 1]
            hr_img = tf.image.convert_image_dtype(hr_img, tf.float32)
            lr_img = tf.image.convert_image_dtype(lr_img, tf.float32)

            # 2. Random cropping/patch extraction
            lr_dims = tf.shape(lr_img)
            hr_dims = tf.shape(hr_img)

            # The offset is sampled in LR coordinates and scaled up for the HR
            # crop, which keeps the two patches pixel-aligned. The upper bound
            # honours both images so neither crop can run out of bounds.
            max_lr_offset_h = tf.minimum(
                lr_dims[0] - LR_SIZE, (hr_dims[0] - HR_SIZE) // SCALE_FACTOR
            )
            max_lr_offset_w = tf.minimum(
                lr_dims[1] - LR_SIZE, (hr_dims[1] - HR_SIZE) // SCALE_FACTOR
            )

            # maxval is exclusive for integer draws, hence the +1.
            lr_offset_h = tf.random.uniform(shape=[], minval=0, maxval=max_lr_offset_h + 1, dtype=tf.int32)
            lr_offset_w = tf.random.uniform(shape=[], minval=0, maxval=max_lr_offset_w + 1, dtype=tf.int32)

            hr_offset_h = lr_offset_h * SCALE_FACTOR
            hr_offset_w = lr_offset_w * SCALE_FACTOR

            lr_patch = tf.image.crop_to_bounding_box(lr_img, lr_offset_h, lr_offset_w, LR_SIZE, LR_SIZE)
            hr_patch = tf.image.crop_to_bounding_box(hr_img, hr_offset_h, hr_offset_w, HR_SIZE, HR_SIZE)

            # 3. Data Augmentation (Flip and Rotate)
            lr_patch, hr_patch = flip_left_right(lr_patch, hr_patch)
            lr_patch, hr_patch = random_rotate(lr_patch, hr_patch)

            return lr_patch, hr_patch

        # Apply preprocessing, shuffling, and batching
        dataset = train_ds.map(preprocess_image_pair, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.shuffle(buffer_size=10).batch(batch_size).prefetch(tf.data.AUTOTUNE)

        print(f"TFDS DIV2K dataset processing configured with BATCH_SIZE={batch_size}.")
        return dataset

    except Exception as e:
        # Deliberate best-effort: any failure while setting up the real
        # pipeline falls back to random data so the script can still run.
        print(f"TFDS Loading failed: {e}. Falling back to simulation.")
        # --- Simulated Data Generation (Fallback Path) ---
        print(f"Creating SIMULATED DIV2K dataset: {num_samples} samples, batch size {batch_size}")
        lr_data = np.random.rand(num_samples, *lr_shape).astype(np.float32)
        hr_data = np.random.rand(num_samples, *hr_shape).astype(np.float32)
        dataset = tf.data.Dataset.from_tensor_slices((lr_data, hr_data))
        dataset = dataset.shuffle(buffer_size=100).batch(batch_size).prefetch(tf.data.AUTOTUNE)
        return dataset

# ----------------------------------------------------------------------
# 4. KERAS MODEL DEFINITION (Updated with Custom Optimizer)
# ----------------------------------------------------------------------

def ResidualBlock(x):
    """Apply one residual block to ``x`` and return the resulting tensor.

    The block is two 3x3, 64-filter convolutions (ReLU on the first only),
    whose output is added to the block input and passed through LeakyReLU.
    """
    shortcut = x
    features = tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu')(x)
    features = tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='same')(features)
    # Skip connection: add the untouched block input back in.
    merged = tf.keras.layers.Add()([features, shortcut])
    return tf.keras.layers.LeakyReLU()(merged)

def create_sr_model(scale=SCALE_FACTOR, lr_size=LR_SIZE):
    """Build and compile the residual super-resolution Keras model.

    Args:
        scale: Upsampling factor applied by the UpSampling2D stage.
        lr_size: Side length of the square, 3-channel low-resolution input.

    Returns:
        A ``tf.keras.Model`` compiled with MAE loss, the module-level
        ``optim_edsr`` optimizer, and the PSNR/SSIM metric functions.
    """
    print("\nDefining Keras Super-Resolution Model (Multi-Block Residual Style - 5 Blocks)...")

    layers = tf.keras.layers
    num_residual_blocks = 5

    lr_input = tf.keras.Input(shape=(lr_size, lr_size, 3))

    # Stage 1: initial 9x9 feature extraction; kept for the global skip.
    shallow_features = layers.Conv2D(filters=64, kernel_size=9, padding='same', activation='relu')(lr_input)

    # Stage 2: stack of residual blocks.
    deep_features = shallow_features
    for _ in range(num_residual_blocks):
        deep_features = ResidualBlock(deep_features)

    # Global residual: one more conv, then add the shallow features back.
    deep_features = layers.Conv2D(filters=64, kernel_size=3, padding='same')(deep_features)
    fused = layers.Add()([deep_features, shallow_features])

    # Stage 3: nearest-neighbour upscaling by `scale` in both dimensions.
    upscaled = layers.UpSampling2D(size=(scale, scale), interpolation='nearest')(fused)

    # Stage 4: reconstruct 3 output channels; sigmoid keeps values in [0, 1].
    sr_output = layers.Conv2D(filters=3, kernel_size=5, padding='same', activation='sigmoid')(upscaled)

    model = tf.keras.Model(inputs=lr_input, outputs=sr_output, name="SR_Residual_Model_MAE_Loss")

    model.compile(optimizer=optim_edsr, loss='mae', metrics=[psnr_metric, ssim_metric])
    print("SR Model defined and compiled successfully using **MAE Loss**, Custom Adam Optimizer, and PSNR/SSIM Metrics.")

    return model

# ----------------------------------------------------------------------
# 5. EVALUATION COMPONENTS
# ----------------------------------------------------------------------

def predict_super_resolution(model, lr_image):
    """Run ``model`` on a single low-resolution image and return a uint8 result.

    Args:
        model: Keras model whose ``predict`` is called on a batch of one.
        lr_image: Single image with values on a 0-255 scale; it is cast to
            float32 and divided by 255 before being fed to the model.

    Returns:
        The model output with the batch axis removed, scaled by 255, clipped
        to [0, 255] and cast to ``tf.uint8``.
    """
    print("Generating Super-Resolution prediction using the Keras model...")

    # Scale to [0, 1] and add the leading batch axis in one step.
    model_input = tf.expand_dims(tf.cast(lr_image, tf.float32) / 255.0, axis=0)

    prediction = model.predict(model_input, verbose=0)
    sr_unit_range = tf.squeeze(prediction, axis=0)

    # Back to the 0-255 range for display.
    sr_scaled = tf.clip_by_value(sr_unit_range * 255.0, 0, 255)
    sr_image = tf.cast(sr_scaled, tf.uint8)

    print(f"  SR Image shape: {sr_image.shape}")
    return sr_image

def plot_lr_sr_ad_hoc(lr_image, sr_image, index, scale=SCALE_FACTOR):
    """
    Saves a side-by-side LR input vs SR output comparison to a PNG file.

    Args:
        lr_image: Low-resolution image tensor (must support ``.numpy()``).
        sr_image: Super-resolved image tensor (must support ``.numpy()``).
        index: Zero-based sample index; the title and file name use index + 1.
        scale: Factor by which the LR image is enlarged for display.

    The file is written as ``sr_comparison_sample_<index+1>.png`` in the
    current working directory.
    """
    # Scale height and width independently so non-square inputs keep their
    # aspect ratio in the preview.
    lr_height, lr_width = lr_image.shape[0], lr_image.shape[1]
    display_size = [lr_height * scale, lr_width * scale]

    # Create the figure
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    try:
        # Resize LR for visualization only
        lr_display = tf.image.resize(lr_image.numpy(), display_size, method='nearest').numpy().astype(np.uint8)

        axes[0].imshow(lr_display)
        axes[0].set_title(f"Sample {index+1}: Low Resolution Input ({lr_height}x{lr_width})", fontsize=10)
        axes[0].axis("off")

        axes[1].imshow(sr_image.numpy().astype(np.uint8))
        axes[1].set_title(f"Sample {index+1}: Super-Resolution Output (x{scale})", fontsize=10)
        axes[1].axis("off")

        fig.tight_layout()

        # Save the plot to a file instead of trying to show it interactively
        filepath = f"sr_comparison_sample_{index+1}.png"
        fig.savefig(filepath)
    finally:
        # Always release the figure, even if rendering or saving failed.
        plt.close(fig)

    print(f"Plot saved to {filepath}")


def run_ad_hoc_evaluation(model, dataset, num_samples=8):
    """
    Runs a visual-only LR-vs-SR comparison on the first image of each batch.

    No HR ground truth is used and no metrics are computed. One image is taken
    from each of the first ``num_samples`` batches of ``dataset``, upscaled
    with ``model``, and the comparison plot is saved to a file.

    Args:
        model: Trained Keras model passed to predict_super_resolution().
        dataset: Batched dataset yielding ``(lr_batch, hr_batch)`` tuples with
            LR values in [0, 1].
        num_samples: Number of batches (and therefore plots) to process.
    """
    print(f"\n--- Starting Ad-Hoc Visual Evaluation ({num_samples} Samples, Model: MAE Loss) ---")

    # take() already caps the number of batches, so no extra bound check is needed.
    for i, (lr_batch, _) in enumerate(dataset.take(num_samples)):
        # First image of the batch, converted to uint8. Round before casting:
        # a plain cast truncates, so a value like 2.9999998 would become 2.
        lowres_scaled = tf.clip_by_value(lr_batch[0] * 255.0, 0.0, 255.0)
        lowres_img = tf.cast(tf.round(lowres_scaled), tf.uint8)

        # Upscale the image
        sr_img = predict_super_resolution(model, lowres_img)

        # Plot the LR and SR results, which now saves them as files
        plot_lr_sr_ad_hoc(lowres_img, sr_img, i)

    print("--- Ad-Hoc Visual Evaluation Complete. ---")

# ----------------------------------------------------------------------
# 6. EXECUTION AND EVALUATION
# ----------------------------------------------------------------------

# Script entry point: build the dataset, build the model, train it, then save
# LR-vs-SR comparison plots.
if __name__ == '__main__':
    # 1. Initialize the dataset.
    # Despite the variable name, this holds the real DIV2K pipeline when TFDS
    # loading succeeds; random data is used only on the fallback path.
    simulated_train_dataset = load_or_simulate_dataset(
        num_samples=TRAIN_SAMPLES,
        batch_size=BATCH_SIZE,
        lr_shape=(LR_SIZE, LR_SIZE, 3),
        hr_shape=(HR_SIZE, HR_SIZE, 3)
    )

    # 2. Create the fully defined SR model
    sr_model = create_sr_model()

    print("\n" * 2)

    # 3. Training
    print(f"--- Starting Training ({EPOCHS} Epochs with MAE Loss, Custom LR Schedule, and Checkpointing) ---")
    sr_model.fit(
        simulated_train_dataset,
        epochs=EPOCHS,
        callbacks=[save_best_cb], # Pass the checkpoint callback here
        verbose=1
    )
    print("--- Training Complete. Starting Evaluation. ---")

    # 4. Evaluation: runs on the training dataset itself (no separate
    # validation split is configured). Any failure is reported and the
    # process exits with status 1.
    try:
        # Run the ad-hoc visual evaluation
        run_ad_hoc_evaluation(sr_model, simulated_train_dataset, num_samples=8)

        print("\nModel Evaluation successfully completed. ")

    except Exception as e:
        print(f"\nFATAL ERROR: The script failed unexpectedly during sample retrieval or evaluation.")
        print(f"Error detail: {e}")
        sys.exit(1)


TensorFlow Version: 2.19.0
--------------------------------------------------
Attempting to load real DIV2K dataset via TFDS...




Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/div2k/bicubic_x4/2.0.0...
EXTRACTING {'train_lr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X4.zip', 'valid_lr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X4.zip', 'train_hr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip', 'valid_hr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip'}


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]