In [2]:
import tensorflow as tf
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow import keras


plt.switch_backend('Agg')

# --- Configuration Constants ---
SCALE_FACTOR = 4
BATCH_SIZE = 8
LR_SIZE = 96
HR_SIZE = LR_SIZE * SCALE_FACTOR
TRAIN_SAMPLES = 800
EPOCHS = 30
PIXEL_LOSS_WEIGHT = 0.001

# Set logging level to error to reduce noise
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

print(f"TensorFlow Version: {tf.__version__}")
print("-" * 50)

# ----------------------------------------------------------------------
# 1. KERAS OPTIMIZER AND CALLBACK CONFIGURATION
# ----------------------------------------------------------------------

# Using Adam optimizer with a piecewise constant learning rate schedule.

optim_edsr = keras.optimizers.Adam(
    learning_rate=keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[5000],
        values=[1e-5, 5e-6]
    ),
)

# Checkpoint the best model weights based on validation loss
best_weights_checkpoint_path = "best-model.weights.h5"


save_best_cb = keras.callbacks.ModelCheckpoint(
    filepath=best_weights_checkpoint_path,
    monitor="loss",
    save_best_only=True,
    save_weights_only=True,
    save_freq="epoch",
)

# ----------------------------------------------------------------------
# 1.5. QUANTITATIVE METRICS (PSNR and SSIM)
# ----------------------------------------------------------------------

def psnr_metric(y_true, y_pred):
    """Peak Signal-to-Noise Ratio (PSNR) calculated over normalized [0, 1] images."""
    # PSNR is calculated over normalized image data [0, 1]

    return tf.image.psnr(y_true, y_pred, max_val=1.0)

def ssim_metric(y_true, y_pred):
    """Structural Similarity Index Measure (SSIM) calculated over normalized [0, 1] images."""
    # SSIM is calculated over normalized image data [0, 1]

    return tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))

# ----------------------------------------------------------------------
# 2. DATA AUGMENTATION FUNCTIONS (TensorFlow Operators)
# ----------------------------------------------------------------------

def flip_left_right(lowres_img, highres_img):
    """Flips Images to left and right."""

    # Outputs random values from a uniform distribution in between 0 to 1
    rn = tf.random.uniform(shape=(), maxval=1)

    # If rn is less than 0.5 it returns original lowres_img and highres_img
    # If rn is greater than 0.5 it returns the flipped image
    return tf.cond(
        rn < 0.5,
        lambda: (lowres_img, highres_img),
        lambda: (
            tf.image.flip_left_right(lowres_img),
            tf.image.flip_left_right(highres_img),
        ),
    )


def random_rotate(lowres_img, highres_img):
    """Rotates Images by 90 degrees."""

    # Outputs random values from uniform distribution in between 0 and 3.
    rn = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)

    # rn signifies number of times the image(s) are rotated by 90 degrees
    return tf.image.rot90(lowres_img, rn), tf.image.rot90(highres_img, rn)

# ----------------------------------------------------------------------
# 3. DATASET LOADING (Updated with Augmentation)
# ----------------------------------------------------------------------

def load_or_simulate_dataset(num_samples, batch_size, lr_shape, hr_shape):
    """
    Loads the real DIV2K dataset via TFDS and applies preprocessing and augmentation.
    """
    try:
        print("Attempting to load real DIV2K dataset via TFDS...")

        train_ds, info = tfds.load(
            'div2k/bicubic_x4',
            split='train',
            as_supervised=True,
            with_info=True
        )

        def filter_min_size(hr_img, lr_img):
            hr_shape = tf.shape(hr_img)
            hr_h, hr_w = hr_shape[0], hr_shape[1]
            return tf.logical_and(hr_h >= HR_SIZE, hr_w >= HR_SIZE)

        initial_samples = info.splits['train'].num_examples
        train_ds = train_ds.filter(filter_min_size)

        print(f"Initial TFDS samples: {initial_samples}")

        def preprocess_image_pair(hr_img, lr_img):

            hr_img = tf.image.convert_image_dtype(hr_img, tf.float32)
            lr_img = tf.image.convert_image_dtype(lr_img, tf.float32)


            hr_shape = tf.shape(hr_img)
            hr_h, hr_w = hr_shape[0], hr_shape[1]

            max_offset_h = hr_h - HR_SIZE
            max_offset_w = hr_w - HR_SIZE

            offset_h = tf.random.uniform(shape=[], minval=0, maxval=max_offset_h + 1, dtype=tf.int32)
            offset_w = tf.random.uniform(shape=[], minval=0, maxval=max_offset_w + 1, dtype=tf.int32)

            offset_h = (offset_h // SCALE_FACTOR) * SCALE_FACTOR
            offset_w = (offset_w // SCALE_FACTOR) * SCALE_FACTOR

            hr_patch = tf.image.crop_to_bounding_box(hr_img, offset_h, offset_w, HR_SIZE, HR_SIZE)

            lr_offset_h = offset_h // SCALE_FACTOR
            lr_offset_w = offset_w // SCALE_FACTOR
            lr_patch = tf.image.crop_to_bounding_box(lr_img, lr_offset_h, lr_offset_w, LR_SIZE, LR_SIZE)


            lr_patch, hr_patch = flip_left_right(lr_patch, hr_patch)
            lr_patch, hr_patch = random_rotate(lr_patch, hr_patch)

            return lr_patch, hr_patch

        # Apply preprocessing, shuffling, and batching
        dataset = train_ds.map(preprocess_image_pair, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.shuffle(buffer_size=10).batch(batch_size).prefetch(tf.data.AUTOTUNE)

        print(f"TFDS DIV2K dataset processing configured with BATCH_SIZE={batch_size}.")
        return dataset

    except Exception as e:
        print(f"TFDS Loading failed: {e}. Falling back to simulation.")

        print(f"Creating SIMULATED DIV2K dataset: {num_samples} samples, batch size {batch_size}")
        lr_data = np.random.rand(num_samples, *lr_shape).astype(np.float32)
        hr_data = np.random.rand(num_samples, *hr_shape).astype(np.float32)
        dataset = tf.data.Dataset.from_tensor_slices((lr_data, hr_data))
        dataset = dataset.shuffle(buffer_size=100).batch(batch_size).prefetch(tf.data.AUTOTUNE)
        return dataset

#Load the BSD100 Test Set
def load_test_dataset(scale=SCALE_FACTOR):
    """
    Loads the BSD100 dataset for testing, generating LR images via bicubic downsampling
    and pairing them with the HR ground truth.
    """
    try:
        print("\nAttempting to load BSD100 (B100) Test Set via TFDS...")


        try:
            test_ds = tfds.load('bsd100', split='test', as_supervised=True)
        except:
            # Fallback for datasets without a 'test' split (e.g., Set5/14)
            test_ds = tfds.load('bsd100', split='all', as_supervised=True)

        def preprocess_test_image(hr_img, _):

            hr_img = tf.image.convert_image_dtype(hr_img, tf.float32)

            #Generate LR image using Bicubic downsampling (as per Unified Preprocessing standard)
            hr_h = tf.shape(hr_img)[0]
            hr_w = tf.shape(hr_img)[1]

            # Calculate LR dimensions
            lr_h = hr_h // scale
            lr_w = hr_w // scale

            # Downsample using Bicubic method
            lr_img = tf.image.resize(
                hr_img,
                size=[lr_h, lr_w],
                method=tf.image.ResizeMethod.BICUBIC
            )

            # Return LR input and HR ground truth
            return lr_img, hr_img

        # Apply preprocessing (downsampling and normalization)
        test_dataset = test_ds.map(preprocess_test_image, num_parallel_calls=tf.data.AUTOTUNE)

        # Batch size of 1 is typical for evaluation to handle variable image sizes
        test_dataset = test_dataset.batch(1).prefetch(tf.data.AUTOTUNE)

        print("BSD100 Test Set configured for evaluation (LR generated via Bicubic downsampling).")
        return test_dataset

    except Exception as e:
        print(f"TFDS BSD100 Loading failed: {e}. Returning None.")
        return None

# ----------------------------------------------------------------------
# 4. KERAS MODEL DEFINITION
# ----------------------------------------------------------------------

def ResidualBlock(x):
    """A standard Residual Block structure."""
    x_res = x
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu')(x)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='same')(x)
    x = tf.keras.layers.Add()([x, x_res])
    x = tf.keras.layers.LeakyReLU()(x)
    return x

def create_sr_model(scale=SCALE_FACTOR, lr_size=LR_SIZE):
    """Defines a concrete Super-Resolution model structure."""
    print("\nDefining Keras Super-Resolution Model (Multi-Block Residual Style - 5 Blocks)...")


    input_tensor = tf.keras.Input(shape=(None, None, 3))

    # 1. Feature Extraction (Initial Conv)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=9, padding='same', activation='relu')(input_tensor)
    global_res = x

    # 2. Residual Blocks
    for _ in range(5):
        x = ResidualBlock(x)

    # Global Residual Skip Connection
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='same')(x)
    x = tf.keras.layers.Add()([x, global_res])

    # 3. Upscaling (Sub-pixel Convolution is a better alternative but UpSampling2D is simpler here)
    x = tf.keras.layers.UpSampling2D(size=(scale, scale), interpolation='nearest')(x)

    # 4. Reconstruction (Final layer)
    output_tensor = tf.keras.layers.Conv2D(filters=3, kernel_size=5, padding='same', activation='sigmoid')(x)

    model = tf.keras.Model(inputs=input_tensor, outputs=output_tensor, name="SR_Residual_Model_MAE_Loss")

    # Compile the model using MAE Loss, the custom Adam optimizer, and quantitative metrics
    model.compile(optimizer=optim_edsr, loss='mae', metrics=[psnr_metric, ssim_metric])
    print("SR Model defined and compiled successfully using **MAE Loss**, Custom Adam Optimizer, and PSNR/SSIM Metrics.")

    return model

# ----------------------------------------------------------------------
# 5. EVALUATION COMPONENTS
# ----------------------------------------------------------------------

def predict_super_resolution(model, lr_image):
    """
    Uses the trained Keras model to generate a super-resolved image.
    Accepts normalized [0, 1] tensor.
    """
    # Ensure it's batched
    lr_batched = tf.expand_dims(lr_image, axis=0)

    # Model prediction (outputs float [0, 1])
    sr_float_batched = model.predict(lr_batched, verbose=0)
    sr_float = tf.squeeze(sr_float_batched, axis=0)

    # Convert back to uint8 [0, 255] for display
    sr_image_uint8 = tf.cast(tf.clip_by_value(sr_float * 255.0, 0, 255), tf.uint8)

    return sr_image_uint8, sr_float # Return both uint8 for display and float for metrics

def plot_lr_sr_hr(lr_image, sr_image, hr_image, psnr, ssim, index, scale=SCALE_FACTOR):
    """
    Displays the LR Input, SR Output, and HR Ground Truth comparison by saving the plot to a file.
    Includes calculated metrics for the current sample.
    """
    # Determine the display size
    HR_SIZE_LOCAL = hr_image.shape[0]

    # Create the figure
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Resize LR for visualization only (using nearest neighbor for clarity)
    lr_display = tf.image.resize(tf.cast(lr_image * 255.0, tf.uint8).numpy(), [HR_SIZE_LOCAL, HR_SIZE_LOCAL], method='nearest').numpy().astype(np.uint8)

    # Convert HR and SR (uint8) for display
    hr_display = tf.cast(hr_image * 255.0, tf.uint8).numpy()
    sr_display = sr_image.numpy()

    axes[0].imshow(lr_display)
    axes[0].set_title(f"Low Resolution Input (x{scale} Bicubic)", fontsize=10)
    axes[0].axis("off")

    axes[1].imshow(sr_display.astype(np.uint8))
    axes[1].set_title(f"SR Output (PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f})", fontsize=10)
    axes[1].axis("off")

    axes[2].imshow(hr_display.astype(np.uint8))
    axes[2].set_title(f"High Resolution Ground Truth", fontsize=10)
    axes[2].axis("off")


    plt.suptitle(f"BSD100 Test Sample {index+1}", fontsize=12)
    plt.tight_layout()

    # Save the plot to a file instead of trying to show it interactively
    filepath = f"bsd100_test_comparison_sample_{index+1}.png"
    plt.savefig(filepath)
    plt.close(fig) # Close the figure to free up memory
    print(f"Plot saved to {filepath}")


def run_test_evaluation_bsd100(model, test_dataset, num_samples=5):
    """
    Runs evaluation on the BSD100 dataset, calculating metrics and saving plots.
    """
    print(f"\n--- Starting BSD100 Test Evaluation ({num_samples} Samples) ---")

    total_psnr = 0.0
    total_ssim = 0.0
    count = 0

    # Iterate over the first few samples for visual plotting
    for i, (lr_batch, hr_batch) in enumerate(test_dataset.take(num_samples)):
        if i >= num_samples:
            break

        # Extract the single image from the batch
        lr_img_norm = lr_batch[0] # Normalized LR [0, 1]
        hr_img_norm = hr_batch[0] # Normalized HR [0, 1]

        # Upscale the image
        sr_img_uint8, sr_img_norm = predict_super_resolution(model, lr_img_norm)

        # Calculate metrics for this sample
        current_psnr = psnr_metric(hr_img_norm, sr_img_norm).numpy()
        # tf.image.ssim returns a single value if both inputs have the same shape
        current_ssim = tf.image.ssim(hr_img_norm, sr_img_norm, max_val=1.0).numpy()

        # Accumulate metrics
        total_psnr += current_psnr
        total_ssim += current_ssim
        count += 1

        # Plot the LR, SR, and HR results, including metrics
        plot_lr_sr_hr(lr_img_norm, sr_img_uint8, hr_img_norm, current_psnr, current_ssim, i)

    # Calculate and print final mean metrics if data was processed
    if count > 0:
        mean_psnr = total_psnr / count
        mean_ssim = total_ssim / count
        print(f"\n--- RESULTS ON BSD100 TEST SET (First {count} Samples) ---")
        print(f"Mean PSNR: {mean_psnr:.4f} dB")
        print(f"Mean SSIM: {mean_ssim:.4f}")
    else:
        print("No BSD100 test samples were available for evaluation.")

    print("--- BSD100 Test Evaluation Complete. ---")

# ----------------------------------------------------------------------
# 6. EXECUTION AND EVALUATION
# ----------------------------------------------------------------------

if __name__ == '__main__':
    # 1. Initialize the training dataset (DIV2K)
    train_dataset = load_or_simulate_dataset(
        num_samples=TRAIN_SAMPLES,
        batch_size=BATCH_SIZE,
        lr_shape=(LR_SIZE, LR_SIZE, 3),
        hr_shape=(HR_SIZE, HR_SIZE, 3)
    )

    # 1.5. Initialize the test dataset (BSD100)
    test_dataset = load_test_dataset()

    # 2. Create the fully defined SR model
    sr_model = create_sr_model()

    print("\n" * 2)

    # 3. Training
    print(f"--- Starting Training ({EPOCHS} Epochs with MAE Loss, Custom LR Schedule, and Checkpointing) ---")
    sr_model.fit(
        train_dataset,
        epochs=EPOCHS,
        callbacks=[save_best_cb], # Pass the checkpoint callback here
        verbose=1
    )
    print("--- Training Complete. Starting Evaluation. ---")

    try:
        # 4. Evaluation using BSD100
        if test_dataset:
            # Run the formal test evaluation on BSD100
            run_test_evaluation_bsd100(sr_model, test_dataset, num_samples=5)
        else:
            # Fallback to ad-hoc visual evaluation on training data if BSD100 failed to load
            print("\nWARNING: Could not load BSD100. Running ad-hoc visual check on training data.")
            # We'll keep a simplified version of the old function name for this fallback
            def run_ad_hoc_evaluation_fallback(model, dataset, num_samples=8):
                print(f"\n--- Starting Ad-Hoc Visual Evaluation ({num_samples} Samples, Model: MAE Loss) ---")
                for i, (lr_batch, _) in enumerate(dataset.take(num_samples)):
                    if i >= num_samples: break
                    lowres_img = lr_batch[0]
                    # Note: predict_super_resolution now takes normalized input
                    sr_img_uint8, _ = predict_super_resolution(model, lowres_img)

                    # Create a simple LR vs SR plot (using the old plotting function structure)
                    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
                    # Resizing LR input to match SR output size for visualization
                    lr_display = tf.image.resize(tf.cast(lowres_img * 255.0, tf.uint8).numpy(), [sr_img_uint8.shape[0], sr_img_uint8.shape[1]], method='nearest').numpy().astype(np.uint8)
                    axes[0].imshow(lr_display)
                    axes[0].set_title(f"Low Resolution Input", fontsize=10)
                    axes[0].axis("off")
                    axes[1].imshow(sr_img_uint8.numpy().astype(np.uint8))
                    axes[1].set_title(f"Super-Resolution Output", fontsize=10)
                    axes[1].axis("off")
                    plt.tight_layout()
                    filepath = f"sr_comparison_sample_{i+1}_fallback.png"
                    plt.savefig(filepath); plt.close(fig)
                    print(f"Fallback plot saved to {filepath}")
                print("--- Ad-Hoc Visual Evaluation Complete. ---")

            run_ad_hoc_evaluation_fallback(sr_model, train_dataset, num_samples=8)


        print("\nModel Evaluation successfully completed. ")

    except Exception as e:
        print(f"\nFATAL ERROR: The script failed unexpectedly during evaluation.")
        print(f"Error detail: {e}")
        sys.exit(1)


TensorFlow Version: 2.19.0
--------------------------------------------------
Attempting to load real DIV2K dataset via TFDS...
Initial TFDS samples: 800
TFDS DIV2K dataset processing configured with BATCH_SIZE=8.

Attempting to load BSD100 (B100) Test Set via TFDS...
TFDS BSD100 Loading failed: Dataset bsd100 not found.
Available datasets:
	- abstract_reasoning
	- accentdb
	- aeslc
	- aflw2k3d
	- ag_news_subset
	- ai2_arc
	- ai2_arc_with_ir
	- ai2dcaption
	- aloha_mobile
	- amazon_us_reviews
	- anli
	- answer_equivalence
	- arc
	- asimov_dilemmas_auto_val
	- asimov_dilemmas_scifi_train
	- asimov_dilemmas_scifi_val
	- asimov_injury_val
	- asimov_multimodal_auto_val
	- asimov_multimodal_manual_val
	- asqa
	- asset
	- assin2
	- asu_table_top_converted_externally_to_rlds
	- austin_buds_dataset_converted_externally_to_rlds
	- austin_sailor_dataset_converted_externally_to_rlds
	- austin_sirius_dataset_converted_externally_to_rlds
	- bair_robot_pushing_small
	- bc_z
	- bccd
	- beans
	- bee_d



[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 3s/step - loss: 0.2339 - psnr_metric: 11.5255 - ssim_metric: 0.2827
Epoch 3/30
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m48s[0m 3s/step - loss: 0.2341 - psnr_metric: 11.4948 - ssim_metric: 0.2740
Epoch 4/30
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 3s/step - loss: 0.2288 - psnr_metric: 11.6845 - ssim_metric: 0.2808
Epoch 5/30
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 3s/step - loss: 0.2329 - psnr_metric: 11.5414 - ssim_metric: 0.2800
Epoch 6/30
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 3s/step - loss: 0.2290 - psnr_metric: 11.6957 - ssim_metric: 0.2823
Epoch 7/30
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m85s[0m 3s/step - loss: 0.2305 - psnr_metric: 11.6256 - ssim_metric: 0.2815
Epoch 8/30
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 3s/step - loss: 0.2278 - psnr_metric: 11.7359 - ssim_metric: 0.2805
Epo