# Here we run experiments to gather statistical insights into how the models perform

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

import tensorflow as tf
from tensorflow.keras.models import load_model

In [None]:
# Project root; every other path below is derived from it. Set the
# SIS3_PROJECT_DIR environment variable to run the notebook on another machine.
PROJECT_DIR = os.environ.get(
    "SIS3_PROJECT_DIR",
    "/Users/amayakof/Desktop/2025_autumn/deep_learning/SIS/3/project",
)

# Validation set: raw input images live directly in this folder.
VALIDATE_RAW = os.path.join(PROJECT_DIR, "validate")

# Preprocessing functions are the same as preprocessing_img.ipynb; the
# mapping is (re)defined in the filters cell.
PREP_FUNCS = {}

# Where to store GT styled images (generated via the filters).
# NOTE: this folder and the model-output folders below are sub-folders of
# VALIDATE_RAW, so os.listdir(VALIDATE_RAW) also returns them; loops over
# that folder must skip non-image entries.
VALIDATE_GT_DIR = os.path.join(VALIDATE_RAW, "gt")

# Where each model family writes its filtered outputs.
VALIDATE_SIMPLE_DIR = os.path.join(VALIDATE_RAW, "basic")
VALIDATE_UPD_DIR = os.path.join(VALIDATE_RAW, "upd")

for _out_path in (VALIDATE_GT_DIR, VALIDATE_SIMPLE_DIR, VALIDATE_UPD_DIR):
    os.makedirs(_out_path, exist_ok=True)

# Saved autoencoders: simple models at the top level, updated ones in upd/.
MODEL_SIMPLE_DIR = os.path.join(PROJECT_DIR, "models")
MODEL_UPD_DIR = os.path.join(MODEL_SIMPLE_DIR, "upd")

# (width, height) passed to cv2.resize for every image.
IMAGE_SIZE = (256, 256)
# Style names; each has its own model file and output sub-folder.
STYLES = ["blur", "night_vis", "poster", "outline"]

In [None]:
# === Filters (same implementations as in preprocessing_img.ipynb) ===

def strong_gaussian_blur(img):
    """Return *img* heavily blurred with a 45x45 Gaussian kernel, sigmaX=12."""
    kernel_size = (45, 45)
    blurred = cv2.GaussianBlur(img, kernel_size, sigmaX=12)
    return blurred

def posterize_super_cartoon(img, levels=4):
    # Cartoon-style posterization: smooth, quantize colors, then draw the
    # detected edges in black. In this notebook it is called with the uint8
    # BGR image from the GT cell (the result is converted with COLOR_BGR2GRAY).
    # NOTE(review): should stay identical to preprocessing_img.ipynb; the GT is
    # presumably only comparable to model outputs if it matches the targets
    # the models were trained on.
    # Edge-preserving smoothing removes fine texture before quantization.
    smooth = cv2.bilateralFilter(img, d=15, sigmaColor=90, sigmaSpace=90)
    x = smooth.astype(np.float32) / 255.0
    # Floor each channel to a multiple of 1/levels; only an exact 1.0
    # (pixel value 255) lands on the extra level 1.0.
    x = np.floor(x * levels) / levels
    poster = (x * 255).astype(np.uint8)
    gray = cv2.cvtColor(poster, cv2.COLOR_BGR2GRAY)
    # Edges of the posterized image, thickened with a 2x2 dilation.
    edges = cv2.Canny(gray, 80, 180)
    edges = cv2.dilate(edges, np.ones((2,2), np.uint8), iterations=1)
    edges = cv2.cvtColor(edges, cv2.COLOR_GRAY2BGR)
    # Inverted edge mask, so the AND below blackens edge pixels and keeps
    # the posterized colors elsewhere.
    edges_inv = 255 - edges
    return cv2.bitwise_and(poster, edges_inv)

def outlines_soft(img, levels=6):
    # Posterize with `levels` color levels, smooth the result, and blend it
    # with a soft edge layer. Called with the uint8 BGR image from the GT cell.
    # NOTE(review): kept identical to preprocessing_img.ipynb on purpose; read
    # the overflow note below before changing anything.
    base = posterize_super_cartoon(img, levels)
    smooth = cv2.bilateralFilter(base, 9, 50, 50)
    # Edges are detected on the unsmoothed posterized image.
    gray = cv2.cvtColor(base, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 40, 120)
    edges_soft = cv2.GaussianBlur(edges, (5,5), 0)
    edges_rgb = cv2.cvtColor(edges_soft, cv2.COLOR_GRAY2BGR)
    light_edges = 255 - edges_rgb
    # NOTE(review): uint8 overflow. (light_edges * 0.30).astype(np.uint8) is
    # 0..76, and "+ 180" stays uint8, so 76 + 180 = 256 wraps to 0. Pixels with
    # light_edges >= 254 (i.e. away from edges) therefore become 0, not ~255.
    # If the training preprocessing has the same line, fixing it only here
    # would desync the GT from the training targets -- change both or neither.
    light_edges = (light_edges * 0.30).astype(np.uint8) + 180
    return cv2.addWeighted(smooth, 1.0, light_edges, 0.45, 0)

def _identity(img):
    """Return *img* unchanged (night-vis GT is not produced by a filter)."""
    return img


# Style name -> function that turns a uint8 BGR image into its GT.
PREP_FUNCS = dict(
    blur=strong_gaussian_blur,
    poster=posterize_super_cartoon,
    outline=outlines_soft,
    night_vis=_identity,  # Night-vis GT comes from raw folder
)


In [None]:
def to_uint8(img):
    """Convert an RGB float image in [0, 1] to uint8 by clipping and rounding.

    A bare ``(img * 255).astype("uint8")`` truncates, so k/255 can come back
    as k-1 after float round-off. Values outside [0, 1] (possible in model
    outputs) overflow instead of saturating. Arrays that are already uint8
    are returned unchanged; scaling them by 255 would wrap modulo 256.
    """
    arr = np.asarray(img)
    if arr.dtype == np.uint8:
        return arr
    scaled = np.rint(arr.astype(np.float64) * 255.0)
    return np.clip(scaled, 0, 255).astype(np.uint8)


def load_rgb(path):
    """Read *path* as an RGB float32 image in [0, 1], resized to IMAGE_SIZE.

    Returns None when cv2.imread cannot read the path (e.g. a directory or a
    non-image file).
    """
    img = cv2.imread(path)
    if img is None:
        return None
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, IMAGE_SIZE)
    return img.astype("float32") / 255.0


def save_rgb(path, img):
    """Write an RGB image (float in [0, 1], or uint8) to *path*.

    Float values are clipped to [0, 1] and rounded before being written.
    The image is converted to BGR because cv2.imwrite expects that order.
    """
    bgr = cv2.cvtColor(to_uint8(img), cv2.COLOR_RGB2BGR)
    cv2.imwrite(path, bgr)


def compute_metrics(gt, pred):
    """Compare *pred* against the ground truth *gt*.

    Both are RGB float images in [0, 1] with the same shape (as returned by
    load_rgb). SSIM is computed on uint8 grayscale versions. PSNR, MSE, RMSE
    and MAE are computed on the full float RGB arrays.

    Returns:
        (ssim, psnr, mse, rmse, mae) -- the index order the plotting cell uses.
    """
    gt_gray = cv2.cvtColor(to_uint8(gt), cv2.COLOR_RGB2GRAY)
    pred_gray = cv2.cvtColor(to_uint8(pred), cv2.COLOR_RGB2GRAY)

    s = ssim(gt_gray, pred_gray)
    p = psnr(gt, pred, data_range=1.0)
    diff = gt - pred
    mse = np.mean(diff ** 2)
    rmse = np.sqrt(mse)
    mae = np.mean(np.abs(diff))

    return s, p, mse, rmse, mae


In [None]:
print("=== Generating GT validation styles ===")

# Create one output folder per style up front, not once per image.
for style in STYLES:
    os.makedirs(os.path.join(VALIDATE_GT_DIR, style), exist_ok=True)

for fname in sorted(os.listdir(VALIDATE_RAW)):
    # VALIDATE_RAW also contains the gt/ and model-output sub-folders.
    if not fname.lower().endswith((".png", ".jpg", ".jpeg")):
        continue

    # Read straight into uint8 BGR (what the cv2 filters take) at the same
    # size load_rgb uses. The old float round-trip through
    # (img*255).astype("uint8") could truncate k/255*255 down to k-1.
    raw_bgr = cv2.imread(os.path.join(VALIDATE_RAW, fname))
    if raw_bgr is None:
        continue
    img_bgr = cv2.resize(raw_bgr, IMAGE_SIZE)

    stem, ext = os.path.splitext(fname)

    for style in STYLES:
        if style == "night_vis":
            # GT comes from a companion "<name>_night_vis<ext>" file in the
            # raw folder; if it is absent, fall back to the unmodified input.
            nv_path = os.path.join(VALIDATE_RAW, f"{stem}_night_vis{ext}")
            nv_bgr = cv2.imread(nv_path) if os.path.exists(nv_path) else None
            styled = img_bgr if nv_bgr is None else cv2.resize(nv_bgr, IMAGE_SIZE)
        else:
            styled = PREP_FUNCS[style](img_bgr)

        # `styled` is already uint8 BGR, so write it directly. The old route
        # through save_rgb multiplied uint8 values by 255, which wraps modulo
        # 256 and stored an almost-inverted image as GT.
        cv2.imwrite(os.path.join(VALIDATE_GT_DIR, style, fname), styled)

print("GT validation images generated.")


In [None]:
print("=== Running SIMPLE models ===")

# Only image files are fed to the models. VALIDATE_RAW also holds the gt/,
# basic/ and upd/ output folders. The listing is taken once, not per style.
input_names = [
    fname
    for fname in sorted(os.listdir(VALIDATE_RAW))
    if fname.lower().endswith((".png", ".jpg", ".jpeg"))
]

for style in STYLES:
    model_path = os.path.join(MODEL_SIMPLE_DIR, f"autoencoder_{style}.keras")
    if not os.path.exists(model_path):
        print("Missing simple model:", model_path)
        continue

    model = load_model(model_path)

    out_dir = os.path.join(VALIDATE_SIMPLE_DIR, style)
    os.makedirs(out_dir, exist_ok=True)

    for fname in input_names:
        inp = load_rgb(os.path.join(VALIDATE_RAW, fname))
        if inp is None:
            continue

        # verbose=0: no progress bar for every single-image prediction.
        pred = model.predict(np.expand_dims(inp, 0), verbose=0)[0]
        # Clip so out-of-range outputs saturate instead of overflowing uint8.
        save_rgb(os.path.join(out_dir, fname), np.clip(pred, 0.0, 1.0))

print("Simple model validation results saved.")


In [None]:
print("=== Running UPDATED models ===")

# Custom functions referenced by the saved updated models (presumably their
# training losses); load_model needs them to deserialize the models.
def mae_basic(y_true, y_pred):
    """Return the plain mean absolute error."""
    return tf.reduce_mean(tf.abs(y_true - y_pred))


def nightvis_weighted_mae(y_true, y_pred):
    """Return the MAE weighted per pixel by 1 + 4 * mean-over-channels(y_true).

    For targets in [0, 1], bright target pixels weigh up to 5x more than
    dark ones.
    """
    lum = tf.reduce_mean(y_true, axis=-1, keepdims=True)
    weights = 1 + 4 * lum
    return tf.reduce_mean(weights * tf.abs(y_true - y_pred))


custom_objects = {
    "mae_basic": mae_basic,
    "nightvis_weighted_mae": nightvis_weighted_mae,
}

# Only image files are fed to the models. VALIDATE_RAW also holds the gt/,
# basic/ and upd/ output folders. The listing is taken once, not per style.
input_names = [
    fname
    for fname in sorted(os.listdir(VALIDATE_RAW))
    if fname.lower().endswith((".png", ".jpg", ".jpeg"))
]

for style in STYLES:
    model_path = os.path.join(MODEL_UPD_DIR, f"autoencoder_{style}_upd.keras")
    if not os.path.exists(model_path):
        print("Missing updated model:", model_path)
        continue

    model = load_model(model_path, custom_objects=custom_objects)

    out_dir = os.path.join(VALIDATE_UPD_DIR, style)
    os.makedirs(out_dir, exist_ok=True)

    for fname in input_names:
        inp = load_rgb(os.path.join(VALIDATE_RAW, fname))
        if inp is None:
            continue

        # verbose=0: no progress bar for every single-image prediction.
        pred = model.predict(np.expand_dims(inp, 0), verbose=0)[0]
        # Clip so out-of-range outputs saturate instead of overflowing uint8.
        save_rgb(os.path.join(out_dir, fname), np.clip(pred, 0.0, 1.0))

print("Updated model validation results saved.")


In [None]:
# style -> list of (filename, simple_metrics, updated_metrics); each metrics
# tuple is (SSIM, PSNR, MSE, RMSE, MAE) as returned by compute_metrics.
results = dict((style, []) for style in STYLES)

for style in STYLES:
    print(f"\n=== Metrics for style {style} ===")

    # GT first, then the simple and updated outputs for this style.
    folders = [
        os.path.join(root, style)
        for root in (VALIDATE_GT_DIR, VALIDATE_SIMPLE_DIR, VALIDATE_UPD_DIR)
    ]

    for name in sorted(os.listdir(folders[0])):
        loaded = [load_rgb(os.path.join(folder, name)) for folder in folders]
        # Skip images lacking a GT or either model's output.
        if any(image is None for image in loaded):
            continue
        gt_image, simple_image, upd_image = loaded

        scores_simple = compute_metrics(gt_image, simple_image)
        scores_upd = compute_metrics(gt_image, upd_image)
        results[style].append((name, scores_simple, scores_upd))

        print(f"{name}:")
        print("  SIMPLE :", scores_simple)
        print("  UPDATED:", scores_upd)


In [None]:
def plot_metric(style, index, metric_name):
    """Plot one metric per validation image, simple vs. updated model.

    Args:
        style: key into the global ``results`` dict filled by the metrics cell.
        index: position of the metric in each metrics tuple
            (0=SSIM, 1=PSNR, 2=MSE, 3=RMSE, 4=MAE).
        metric_name: label used for the title and the y-axis.
    """
    rows = results[style]
    if not rows:
        # Nothing was evaluated for this style (missing GT or model outputs);
        # skip rather than show an empty figure.
        print(f"No results for style {style}; skipping {metric_name} plot.")
        return

    simple_vals = [simple[index] for _, simple, _ in rows]
    upd_vals = [updated[index] for _, _, updated in rows]

    plt.figure(figsize=(8, 4))
    plt.plot(simple_vals, label="Simple", marker="o")
    plt.plot(upd_vals, label="Updated", marker="o")
    plt.title(f"{style}: {metric_name}")
    plt.xlabel("Image Index")
    plt.ylabel(metric_name)
    plt.legend()
    plt.grid()
    plt.show()

# Metric labels in the index order of compute_metrics' return tuple.
METRIC_NAMES = ("SSIM", "PSNR", "MSE", "RMSE", "MAE")

for style in STYLES:
    for metric_idx, metric_label in enumerate(METRIC_NAMES):
        plot_metric(style, metric_idx, metric_label)
