In [None]:
from google.colab import drive
# Mount Google Drive into this Colab runtime so the dataset directories and the
# model save path under /content/drive/MyDrive (used by the training cell) are
# reachable as ordinary filesystem paths.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os
from glob import glob
from PIL import Image
import matplotlib.pyplot as plt

# Patch geometry: models train on square PATCH_SIZE x PATCH_SIZE crops with
# CHANNELS colour channels (images are decoded to this many channels).
PATCH_SIZE, CHANNELS = 128, 3

# Optimisation schedule for model.fit.
BATCH_SIZE, EPOCHS = 16, 15

# Paired dataset directories on the mounted Drive. Inputs and targets are
# matched by identical filename. Both values end with "/" because
# load_paired_data builds the target path by plain string concatenation.
INPUT_DIR = "/content/drive/MyDrive/Colab Notebooks/FocusAI/fivek_raw_input/"
TARGET_DIR = "/content/drive/MyDrive/Colab Notebooks/FocusAI/fivek_expert_c_target/"

def load_and_preprocess_image(path):
    """Read an image file and return it as float32 scaled to [0, 1].

    The image keeps its native resolution (no resizing). Because
    ``expand_animations=False`` is passed, the decoded result is always a
    single 3-D frame, and it is forced to ``CHANNELS`` channels.

    Args:
        path: Scalar string tensor (or Python str) holding the file path.

    Returns:
        A ``(height, width, CHANNELS)`` float32 tensor. Values are the decoded
        8-bit pixel values divided by 255.
    """
    encoded = tf.io.read_file(path)
    pixels = tf.image.decode_image(
        encoded,
        channels=CHANNELS,
        expand_animations=False,
    )
    return tf.cast(pixels, tf.float32) / 255.0


def load_paired_data(input_path):
    """Build one pixel-aligned (input, target) training patch from an input path.

    The target image is found by appending the input's filename to
    ``TARGET_DIR``. Both images are center-cropped to their common
    height/width. Then a single random ``PATCH_SIZE`` x ``PATCH_SIZE`` window
    is cut from the stacked pair, so input and target share the same location.

    Both images must be at least ``PATCH_SIZE`` in each dimension, otherwise
    ``tf.image.random_crop`` fails. ``main`` filters the file list for this.

    Args:
        input_path: Scalar string tensor with the path of the input image.

    Returns:
        Tuple ``(input_patch, target_patch)``. Each is a
        ``(PATCH_SIZE, PATCH_SIZE, CHANNELS)`` float32 tensor produced by
        ``load_and_preprocess_image``.
    """
    def _center_crop(image, height, width):
        # Trim equally from opposite sides. Integer division leaves any odd
        # leftover pixel on the bottom/right edge.
        dims = tf.shape(image)
        return tf.image.crop_to_bounding_box(
            image,
            (dims[0] - height) // 2,
            (dims[1] - width) // 2,
            height,
            width,
        )

    basename = tf.strings.split(input_path, os.path.sep)[-1]
    target_file = tf.strings.join([TARGET_DIR, basename])

    source = load_and_preprocess_image(input_path)
    reference = load_and_preprocess_image(target_file)

    source_dims = tf.shape(source)
    reference_dims = tf.shape(reference)
    common_h = tf.minimum(source_dims[0], reference_dims[0])
    common_w = tf.minimum(source_dims[1], reference_dims[1])

    aligned_pair = tf.stack(
        [
            _center_crop(source, common_h, common_w),
            _center_crop(reference, common_h, common_w),
        ],
        axis=0,
    )

    # A single random_crop over the stacked pair picks one window shared by both.
    patches = tf.image.random_crop(
        aligned_pair, size=[2, PATCH_SIZE, PATCH_SIZE, CHANNELS]
    )
    return patches[0], patches[1]


def psnr_metric(y_true, y_pred):
    """Peak signal-to-noise ratio (dB) for images whose pixel range is [0, 1].

    Returns one PSNR value per image in the batch, as produced by
    ``tf.image.psnr``.
    """
    peak_value = 1.0
    return tf.image.psnr(y_true, y_pred, max_val=peak_value)


def ssim_metric(y_true, y_pred):
    """Structural similarity index for images whose pixel range is [0, 1].

    Returns one SSIM value per image in the batch, as produced by
    ``tf.image.ssim`` with its default window settings.
    """
    peak_value = 1.0
    return tf.image.ssim(y_true, y_pred, max_val=peak_value)



def residual_block(x, filters):
    """Apply two 3x3 conv + batch-norm stages with an identity skip connection.

    The first stage's convolution applies ReLU before its batch norm; the
    second stage's convolution is linear. The branch output is added to ``x``
    and passed through a final ReLU.

    Args:
        x: Input feature tensor. The skip path has no projection, so its
            channel count must equal ``filters`` for the ``Add`` to be valid.
        filters: Number of filters in each 3x3 convolution.

    Returns:
        Tensor with the same shape as ``x``, because of 'same' padding and
        stride 1.
    """
    branch = x
    for stage_activation in ('relu', None):
        branch = layers.Conv2D(
            filters,
            kernel_size=3,
            padding='same',
            activation=stage_activation,
        )(branch)
        branch = layers.BatchNormalization()(branch)
    merged = layers.Add()([x, branch])
    return layers.Activation('relu')(merged)


def create_enhancement_model(input_shape):
    """Build the fully-convolutional image enhancement network.

    Architecture: a 64-filter 3x3 stem, eight ``residual_block``s at 64
    filters, then a linear 3x3 conv down to ``CHANNELS`` that produces a
    correction. The correction is added to the input image and squashed with
    a sigmoid.

    Args:
        input_shape: Shape of one input image, e.g.
            ``(PATCH_SIZE, PATCH_SIZE, CHANNELS)``.

    Returns:
        An uncompiled ``keras.Model`` mapping images to enhanced images in (0, 1).
    """
    width, depth = 64, 8

    image_in = keras.Input(shape=input_shape)
    features = layers.Conv2D(width, kernel_size=3, padding='same', activation='relu')(image_in)
    for _ in range(depth):
        features = residual_block(features, width)

    correction = layers.Conv2D(CHANNELS, kernel_size=3, padding='same', activation='linear')(features)
    # NOTE(review): the sigmoid is applied after the skip-add, so with a zero
    # correction the output is sigmoid(input), i.e. in [0.5, 0.73] for input in
    # [0, 1], rather than the input itself. The network has to learn
    # logit(target) - input. Consider whether clipping was intended instead.
    summed = layers.Add()([image_in, correction])
    enhanced = layers.Activation('sigmoid')(summed)
    return keras.Model(inputs=image_in, outputs=enhanced)

def _collect_valid_paths(input_paths):
    """Return the input paths whose image and same-named target are both usable.

    A path is kept when the input opens with PIL, a file with the same name
    exists in ``TARGET_DIR``, and both images are at least ``PATCH_SIZE``
    pixels in each dimension. Every rejected path is reported with a
    "Skipping ..." line.
    """
    valid_paths = []
    for path in input_paths:
        filename = os.path.basename(path)
        target_path = os.path.join(TARGET_DIR, filename)
        try:
            # Context managers close the file handles; Image.open only reads
            # the header, which is all the size check needs.
            with Image.open(path) as img:
                if img.width < PATCH_SIZE or img.height < PATCH_SIZE:
                    print(f"Skipping {path} (too small)")
                    continue

            if not os.path.exists(target_path):
                print(f"Skipping {path} (no target found)")
                continue

            with Image.open(target_path) as target_img:
                if target_img.width < PATCH_SIZE or target_img.height < PATCH_SIZE:
                    print(f"Skipping {path} (target is too small)")
                    continue
        except Exception as e:
            # Best-effort scan: any unreadable file is reported and skipped
            # rather than aborting the whole run.
            print(f"Skipping {path}: {e}")
            continue
        valid_paths.append(path)
    return valid_paths


def main():
    """Train the patch-based enhancement model on paired input/target images.

    Steps:
      1. Collect ``*.jpg`` / ``*.png`` inputs from ``INPUT_DIR`` that have a
         same-named target in ``TARGET_DIR`` and are large enough to crop a
         ``PATCH_SIZE`` patch from.
      2. Shuffle with a fixed seed, then split 90% train / 10% test.
      3. Train for ``EPOCHS`` epochs (the test split doubles as validation
         data), save the model to Drive, and print the final test metrics.

    Prints an error and returns early if no images or no valid training pairs
    are found.
    """
    import random  # used only for the reproducible train/test split

    train_fraction = 0.9
    split_seed = 42

    # glob() yields filesystem order, which is not guaranteed to be stable
    # between sessions. Sorting makes the seeded shuffle below reproducible.
    input_paths = sorted(
        glob(os.path.join(INPUT_DIR, "*.jpg")) + glob(os.path.join(INPUT_DIR, "*.png"))
    )
    if not input_paths:
        print(f"ERROR: No image files found in '{INPUT_DIR}'.")
        return

    valid_paths = _collect_valid_paths(input_paths)

    # Shuffle before slicing so the held-out set is a random sample rather
    # than whichever files happen to come last in listing order.
    random.Random(split_seed).shuffle(valid_paths)
    split_index = int(train_fraction * len(valid_paths))
    train_paths = valid_paths[:split_index]
    test_paths = valid_paths[split_index:]

    if not train_paths:
        print("ERROR: No valid training data found. Check paths and image sizes.")
        return

    train_ds = (
        tf.data.Dataset.from_tensor_slices(train_paths)
        .map(load_paired_data, num_parallel_calls=tf.data.AUTOTUNE)
        # NOTE(review): .cache() comes after the random-crop map, so the patch
        # drawn for each image in epoch 1 is replayed unchanged in every later
        # epoch; the random crop effectively runs once. Moving the crop after
        # the cache would restore per-epoch augmentation, but would require
        # caching full-resolution decoded pairs, whose memory cost is unknown.
        .cache()
        .shuffle(buffer_size=100)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    # Caching the test pipeline freezes one patch per image, which keeps the
    # validation/evaluation numbers comparable across epochs.
    test_ds = (
        tf.data.Dataset.from_tensor_slices(test_paths)
        .map(load_paired_data, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(1)
        .cache()
        .prefetch(tf.data.AUTOTUNE)
    )

    model = create_enhancement_model(input_shape=(PATCH_SIZE, PATCH_SIZE, CHANNELS))
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4),
                  loss='mean_absolute_error',
                  metrics=['mse', psnr_metric, ssim_metric])
    model.summary()

    model.fit(train_ds, epochs=EPOCHS, validation_data=test_ds)

    model_save_path = "/content/drive/MyDrive/Colab Notebooks/FocusAI/enhancer_patch_model_2.keras"
    model.save(model_save_path)

    # test_ds was also the validation set during fit, so these values should
    # closely match the final epoch's val_* metrics.
    results = model.evaluate(test_ds, verbose=1, return_dict=True)
    for metric_name, value in results.items():
        print(f"test {metric_name}: {value:.4f}")


# Entry point. In a notebook cell __name__ is also "__main__", so running this
# cell starts training immediately.
if __name__ == "__main__":
    main()

Epoch 1/15
[1m282/282[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2109s[0m 7s/step - loss: 0.1619 - mse: 0.0515 - psnr_metric: 15.4711 - ssim_metric: 0.3585 - val_loss: 0.1263 - val_mse: 0.0235 - val_psnr_metric: 17.7336 - val_ssim_metric: 0.7091
Epoch 2/15
[1m282/282[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 252ms/step - loss: 0.1001 - mse: 0.0184 - psnr_metric: 20.0741 - ssim_metric: 0.6763 - val_loss: 0.0967 - val_mse: 0.0154 - val_psnr_metric: 19.9106 - val_ssim_metric: 0.7162
Epoch 3/15
[1m282/282[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 262ms/step - loss: 0.0997 - mse: 0.0180 - psnr_metric: 20.3913 - ssim_metric: 0.7118 - val_loss: 0.0759 - val_mse: 0.0111 - val_psnr_metric: 22.5799 - val_ssim_metric: 0.7739
Epoch 4/15
[1m282/282[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 271ms/step - loss: 0.0957 - mse: 0.0165 - psnr_metric: 20.8395 - ssim_metric: 0.7304 - val_loss: 0.0736 - val_mse: 0.0105 - val_psnr_metric: 22.8378 - val_ssim_metr