In [None]:
# =========================
# Cell 1 — Setup & Imports
# =========================
import os, sys, tempfile, time, json
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

# Make sure Python can see your uploaded modules
sys.path.append("/mnt/data")

import parameters as P
import utilities as U
import watermarking as W
import attacks as A

# Echo which module objects were picked up and the parameters.py defaults in effect.
print("Modules loaded:", *(P, U, W, A))
print("Defaults:", {"DWT_LEVELS": P.DWT_LEVELS, "QIM_STEP": P.QIM_STEP, "WM_SIZE": P.WATERMARK_SIZE})


In [None]:
# =========================================
# Cell 2 — User Config
# =========================================
# 1) List as many host images as you like (RGB preferred; 512–1080 on H and W).
#    Paths are built with forward slashes (as_posix) so they resolve on both Windows
#    and POSIX kernels; backslash literals such as r"images\800.jpg" only work on Windows.
#    They are kept as plain strings so Cell 6 can JSON-serialise them.
IMAGE_DIR = Path("images")
HOST_PATHS = [
    (IMAGE_DIR / name).as_posix()
    for name in ("800.jpg", "512.jpg", "i512.png", "i800.png")
]

# 2) One watermark image used across ALL host images
WATERMARK_PATH = (IMAGE_DIR / "cat-sm.jpg").as_posix()

# 3) Hyperparameter grid (we will set into your modules dynamically)
DWT_LEVELS_GRID = [1, 2, 3]
QIM_STEPS_GRID  = [40, 60, 80, 100, 120]

# 4) Attack settings
JPEG_Q      = 75     # JPEG quality factor
CROP_RATIO  = 0.05   # fraction of image area removed per crop patch
CROP_PATCH  = 1      # number of crop patches
ROT_ANGLE   = 5.0    # rotation in degrees

# 5) Algorithms to compare (both use your embed; extract differs)
#    - "basic"     -> W.extract_watermark
#    - "enhanced"  -> W.extract_watermark_enhanced (small-angle search)
ALGORITHMS = ["basic", "enhanced"]

# 6) Output directory to save CSV and plots
OUTDIR = Path("./tuning_outputs")
OUTDIR.mkdir(parents=True, exist_ok=True)

print("Hosts:", len(HOST_PATHS), "Watermark:", WATERMARK_PATH)


In [None]:
# =========================================
# Cell 3 — Small helpers (I/O, metrics)
# =========================================
def load_rgb(path):
    """Open the image at ``path``, convert it to RGB and return it as a uint8 numpy array."""
    with Image.open(path) as img:
        rgb_img = img.convert("RGB")
    return np.array(rgb_img, dtype=np.uint8)

def rgb_to_ycbcr_arrays(rgb_u8):
    """
    Split an RGB uint8 array into YCbCr planes using the project's own converter.

    Reuses ``U.to_ycbcr_arrays`` for consistency with the embedding code; per the
    original note it returns (Y float64, Cb uint8, Cr uint8).

    The PIL mode is inferred from the array (the explicit ``mode=`` argument of
    ``Image.fromarray`` is deprecated in recent Pillow). The result is normalised with
    ``.convert("RGB")``, matching ``load_rgb``; this is a no-op for HxWx3 uint8 input.
    """
    pil = Image.fromarray(np.asarray(rgb_u8, dtype=np.uint8)).convert("RGB")
    return U.to_ycbcr_arrays(pil)

def ycbcr_to_rgb(Y_float, Cb_u8, Cr_u8):
    """Merge Y/Cb/Cr planes back via ``U.from_ycbcr_arrays`` and return a uint8 RGB array."""
    merged = U.from_ycbcr_arrays(Y_float, Cb_u8, Cr_u8)
    return np.array(merged, dtype=np.uint8)

def prepare_wm_bits(path, size=None):
    """
    Load the watermark image at ``path`` and binarise it with ``U.prepare_watermark_bits``.

    ``size`` defaults to ``P.WATERMARK_SIZE`` when omitted. Per the original note, the
    helper applies an Otsu threshold and resizes to size×size.
    """
    wm_size = P.WATERMARK_SIZE if size is None else size
    with Image.open(path) as img:
        wm_rgb = img.convert("RGB")
    return U.prepare_watermark_bits(wm_rgb, size=wm_size)

def bits_metrics(true_bits_2d, pred_bits_2d):
    """
    Compare reference and extracted watermark bits.

    Both inputs are cast to uint8 and flattened before scoring.
    Returns ``(ber, ncc)`` from ``U.bit_error_rate`` and ``U.normalized_cross_correlation``.
    """
    truth, pred = (bits.astype(np.uint8).ravel() for bits in (true_bits_2d, pred_bits_2d))
    return U.bit_error_rate(truth, pred), U.normalized_cross_correlation(truth, pred)

def set_hparams(levels, step):
    """
    Push one grid point's hyperparameters into the project modules.

    IMPORTANT: per the original note, the watermarking/utility modules import constants
    from parameters.py by value, so changing only ``P`` would not reach them. Both
    constants are therefore overwritten on *all three* modules (P, W and U).

    Parameters
    ----------
    levels : int-like
        Number of DWT decomposition levels (coerced with ``int``).
    step : float-like
        QIM quantisation step Δ (coerced with ``float``).

    Notes
    -----
    NOTE(review): this only affects code that reads these names as module globals at
    call time. A function that captured them as default-argument values when it was
    defined keeps the old values; confirm in watermarking.py / utilities.py.
    Other parameters (e.g. the wavelet) are left unchanged.
    """
    levels = int(levels)
    step = float(step)
    # Previously U only received DWT_LEVELS, so a by-value QIM_STEP copy in utilities
    # would have silently ignored the grid's step.
    for module in (P, W, U):
        module.DWT_LEVELS = levels
        module.QIM_STEP = step


In [None]:
# =========================================
# Cell 4 — One run for (host, levels, step, algorithm)
# =========================================
def run_one_case(host_rgb, wm_bits2d, alg, levels, step,
                 jpeg_q=75, crop_ratio=0.05, crop_patch=1, rot_angle=5.0,
                 rot_fill=(255, 255, 255)):
    """
    Embed a watermark into one host image and score robustness for one grid point.

    Embeds with W.embed_watermark (your function), then extracts from the clean
    watermarked luma and from JPEG-, crop- and rotation-attacked copies, using
    the extractor selected by ``alg``:
      - "basic":    W.extract_watermark
      - "enhanced": W.extract_watermark_enhanced   (with small-angle search)

    Parameters
    ----------
    host_rgb : np.ndarray
        uint8 RGB host image (as produced by ``load_rgb``).
    wm_bits2d : np.ndarray
        Watermark bits (as produced by ``prepare_wm_bits``).
    alg : str
        "basic" or "enhanced".
    levels, step :
        DWT levels and QIM step, pushed into the modules via ``set_hparams``.
    jpeg_q, crop_ratio, crop_patch, rot_angle :
        Attack settings forwarded to the helpers in ``attacks``.
    rot_fill : tuple
        Fill colour passed to ``A.rotation_attack``. The default matches the previous
        hard-coded white.

    Returns
    -------
    (pd.DataFrame, dict)
        - One row per scenario with columns alg, dwt_levels, qim_step, scenario, psnr,
          ber, ncc. ``psnr`` is only filled for "clean".
        - A dict holding ``wm_rgb`` and the extracted bits per scenario.

    Raises
    ------
    ValueError
        If ``alg`` is not a known algorithm. Previously any unknown value silently
        ran the "enhanced" extractor.
    """
    # Resolve the extractor once, by attribute name, so every scenario uses the same
    # function and W only needs the attribute that is actually selected.
    extractor_names = {
        "basic": "extract_watermark",
        "enhanced": "extract_watermark_enhanced",
    }
    if alg not in extractor_names:
        raise ValueError(
            f"Unknown algorithm {alg!r}; expected one of {sorted(extractor_names)}"
        )
    extract_fn = getattr(W, extractor_names[alg])

    # --- Set hyperparameters into your modules
    set_hparams(levels, step)

    def extract_bits(Y_in):
        """Run the selected extractor on a luma plane."""
        return extract_fn(Y_in, wm_size=P.WATERMARK_SIZE)

    def extract_bits_from_rgb(rgb_u8):
        """Convert an RGB array to YCbCr and extract from its luma; chroma is not used."""
        Y_in, _, _ = rgb_to_ycbcr_arrays(rgb_u8)
        return extract_bits(Y_in)

    rows = []

    def add_row(scenario, psnr, ber, ncc):
        rows.append(dict(
            alg=alg, dwt_levels=levels, qim_step=step, scenario=scenario,
            psnr=psnr, ber=ber, ncc=ncc
        ))

    # --- Prepare channels and embed (your function)
    Y, Cb, Cr = rgb_to_ycbcr_arrays(host_rgb)
    Y_wm = W.embed_watermark(Y, wm_bits2d)
    wm_rgb = ycbcr_to_rgb(Y_wm, Cb, Cr)

    # --- PSNR (imperceptibility), measured on the uint8 RGB watermarked image
    psnr_clean = U.psnr(host_rgb, wm_rgb)

    # --- Clean extraction
    # NOTE(review): this reads the float luma Y_wm directly, i.e. *before* the uint8 RGB
    # round trip that wm_rgb (and every attack below) goes through. The "clean" BER may
    # therefore differ from extraction on the saved image.
    bits_clean = extract_bits(Y_wm)
    add_row("clean", psnr_clean, *bits_metrics(wm_bits2d, bits_clean))

    # --- JPEG attack (your helper expects file paths)
    with tempfile.TemporaryDirectory() as td:
        src = Path(td, "wm.png")
        out = Path(td, "jpeg.jpg")
        Image.fromarray(wm_rgb).save(src)
        A.jpeg_attack(str(src), str(out), quality=jpeg_q)
        jpeg_rgb = load_rgb(str(out))
    bits_jpeg = extract_bits_from_rgb(jpeg_rgb)
    add_row(f"jpeg_q{jpeg_q}", None, *bits_metrics(wm_bits2d, bits_jpeg))

    # --- Crop attack (keeps size)
    crop_pil = A.crop_attack(np.array(wm_rgb, dtype=np.uint8),
                             area_ratio=crop_ratio, num_patches=crop_patch)
    bits_crop = extract_bits_from_rgb(np.array(crop_pil, dtype=np.uint8))
    add_row(f"crop_{crop_ratio:.02f}x{crop_patch}", None, *bits_metrics(wm_bits2d, bits_crop))

    # --- Rotation attack; "enhanced" is expected to help here, but the selected alg is respected
    rot_pil = A.rotation_attack(Image.fromarray(wm_rgb), angle=rot_angle, fill_color=rot_fill)
    bits_rot = extract_bits_from_rgb(np.array(rot_pil, dtype=np.uint8))
    add_row(f"rot_{rot_angle:+.0f}", None, *bits_metrics(wm_bits2d, bits_rot))

    return pd.DataFrame(rows), dict(
        wm_rgb=wm_rgb,
        bits_clean=bits_clean, bits_jpeg=bits_jpeg, bits_crop=bits_crop, bits_rot=bits_rot
    )


In [None]:
# =========================================
# Cell 5 — Run grid on multiple hosts
# =========================================
# Prepare watermark bits once (same across hosts)
wm_bits = prepare_wm_bits(WATERMARK_PATH, size=P.WATERMARK_SIZE)

# Columns every per-case frame carries. They are also used to build an empty frame when
# nothing ran, so Cell 7 still finds the columns it sorts and groups by.
RESULT_COLUMNS = ["host", "alg", "dwt_levels", "qim_step", "scenario", "psnr", "ber", "ncc"]

# Middle grid point whose outputs are stashed per (host, alg) for Cell 11.
MID_LEVEL = DWT_LEVELS_GRID[len(DWT_LEVELS_GRID) // 2]
MID_STEP = QIM_STEPS_GRID[len(QIM_STEPS_GRID) // 2]

all_rows = []
per_host_examples = {}  # optional stash for quick inspection

for host_path in HOST_PATHS:
    host_name = Path(host_path).name
    try:
        host_rgb = load_rgb(host_path)
    except Exception as e:
        print(f"[SKIP] Cannot load {host_name}: {e}")
        continue

    print(f"\n=== HOST: {host_name} ===")
    for alg in ALGORITHMS:
        for L in DWT_LEVELS_GRID:
            for step in QIM_STEPS_GRID:
                try:
                    df, ex = run_one_case(host_rgb, wm_bits, alg, L, step,
                                          jpeg_q=JPEG_Q, crop_ratio=CROP_RATIO,
                                          crop_patch=CROP_PATCH, rot_angle=ROT_ANGLE)
                    df.insert(0, "host", host_name)
                    all_rows.append(df)

                    if ((host_name, alg) not in per_host_examples
                            and L == MID_LEVEL and step == MID_STEP):
                        per_host_examples[(host_name, alg)] = ex
                except Exception as e:
                    # Best-effort: record the failure and keep sweeping the grid.
                    print(f"[WARN] {host_name} | {alg} | L={L} | Δ={step}: {e}")
                    # NaN (not None) keeps the metric columns float64 after concat.
                    all_rows.append(pd.DataFrame([dict(
                        host=host_name, alg=alg, dwt_levels=L, qim_step=step,
                        scenario="error", psnr=np.nan, ber=np.nan, ncc=np.nan, error=str(e)
                    )]))

# pd.concat([]) raises, e.g. when every host failed to load.
results = (pd.concat(all_rows, ignore_index=True) if all_rows
           else pd.DataFrame(columns=RESULT_COLUMNS))
print("Total rows:", len(results))
results.head()


In [None]:
# =========================================
# Cell 6 — Save raw results (CSV + JSON config)
# =========================================
# Timestamped file names so repeated runs never overwrite each other.
ts = time.strftime("%Y%m%d-%H%M%S")
csv_path = OUTDIR / f"tuning_results_{ts}.csv"
cfg_path = OUTDIR / f"tuning_config_{ts}.json"

results.to_csv(csv_path, index=False)

# Everything needed to reproduce the sweep. WATERMARK_SIZE is recorded because
# extraction and metrics depend on it; set_hparams never changes it.
run_config = dict(
    hosts=HOST_PATHS,
    watermark=WATERMARK_PATH,
    watermark_size=P.WATERMARK_SIZE,
    dwt_levels_grid=DWT_LEVELS_GRID,
    qim_steps_grid=QIM_STEPS_GRID,
    attacks=dict(jpeg_q=JPEG_Q, crop_ratio=CROP_RATIO, crop_patch=CROP_PATCH, rot_angle=ROT_ANGLE),
    algorithms=ALGORITHMS,
)
with open(cfg_path, "w", encoding="utf-8") as f:
    # default=str guards against non-JSON scalars (e.g. numpy ints) in the config.
    json.dump(run_config, f, indent=2, default=str)

print("Saved:", csv_path, "and", cfg_path)


In [None]:
# =========================================
# Cell 7 — Quick summary tables
# =========================================
# Coerce metrics to numeric first: rows carrying None make object-dtype columns,
# which the mean aggregation below should not have to cope with.
METRIC_COLS = ["psnr", "ber", "ncc"]
results = (results
           .assign(**{col: pd.to_numeric(results[col], errors="coerce") for col in METRIC_COLS})
           .sort_values(by=["host", "alg", "scenario", "dwt_levels", "qim_step"])
           .reset_index(drop=True))
display(results.head(20))

# Avg across hosts (per alg, scenario, hyperparams); failed cases are excluded.
avg_over_hosts = (results.loc[results["scenario"] != "error"]
                  .groupby(["alg", "scenario", "dwt_levels", "qim_step"], as_index=False)
                  .agg(psnr_mean=("psnr", "mean"),
                       ber_mean=("ber", "mean"),
                       ncc_mean=("ncc", "mean")))
display(avg_over_hosts.head(20))


In [None]:
# =========================================
# Cell 8 — BER heatmaps (avg across hosts)
# =========================================
# Scenario labels; these must match the f-strings used in run_one_case.
scenarios = ["clean", f"jpeg_q{JPEG_Q}", f"crop_{CROP_RATIO:.02f}x{CROP_PATCH}", f"rot_{ROT_ANGLE:+.0f}"]

for alg in ALGORITHMS:
    for sc in scenarios:
        sub = avg_over_hosts[(avg_over_hosts.alg == alg) & (avg_over_hosts.scenario == sc)]
        if sub.empty:
            continue
        # Rows: DWT levels, columns: QIM steps, cells: BER averaged over hosts.
        pivot = sub.pivot(index="dwt_levels", columns="qim_step", values="ber_mean")

        fig, ax = plt.subplots(figsize=(6, 4))
        ax.set_title(f"Avg BER heatmap — {alg} — {sc}")
        im = ax.imshow(pivot.values, aspect="auto")
        ax.set_xticks(range(len(pivot.columns)))
        ax.set_xticklabels(pivot.columns)
        ax.set_yticks(range(len(pivot.index)))
        ax.set_yticklabels(pivot.index)
        ax.set_xlabel("QIM step (Δ)")
        ax.set_ylabel("DWT levels")
        fig.colorbar(im, ax=ax, label="BER (mean across hosts)")

        fig_path = OUTDIR / f"heatmap_ber_{alg}_{sc}.png"
        fig.savefig(fig_path, dpi=150, bbox_inches="tight")
        plt.show()
        # Release the figure; otherwise non-inline backends accumulate open figures in this loop.
        plt.close(fig)
        print("Saved plot:", fig_path)


In [None]:
# =========================================
# Cell 9 — PSNR curves vs QIM step (per DWT level, avg across hosts)
# =========================================
# PSNR only exists for the "clean" scenario, so one curve per (alg, DWT level).
for alg in ALGORITHMS:
    sub = avg_over_hosts[(avg_over_hosts.alg == alg) & (avg_over_hosts.scenario == "clean")]
    if sub.empty:
        continue
    for L in sorted(sub.dwt_levels.unique()):
        # Sort explicitly so the line is drawn left to right regardless of upstream ordering.
        chunk = sub[sub.dwt_levels == L].sort_values("qim_step")

        fig, ax = plt.subplots(figsize=(6, 4))
        ax.set_title(f"PSNR vs Δ — {alg} — clean — L={L}")
        ax.plot(chunk.qim_step.values, chunk.psnr_mean.values, marker="o")
        ax.set_xlabel("QIM step (Δ)")
        ax.set_ylabel("PSNR (dB)")
        ax.grid(True, alpha=0.4)

        fig_path = OUTDIR / f"psnr_curve_{alg}_L{L}.png"
        fig.savefig(fig_path, dpi=150, bbox_inches="tight")
        plt.show()
        plt.close(fig)
        print("Saved plot:", fig_path)


In [None]:
# =========================================
# Cell 10 — BER/NCC vs QIM step (per DWT level, avg across hosts)
# =========================================
def plot_metric_curve(chunk, metric_col, ylabel, title, fig_path):
    """
    Plot one averaged metric against QIM step, save it to ``fig_path``, show it and close it.

    ``chunk`` is a slice of ``avg_over_hosts`` with a ``qim_step`` column and ``metric_col``.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.set_title(title)
    ax.plot(chunk.qim_step.values, chunk[metric_col].values, marker="o")
    ax.set_xlabel("QIM step (Δ)")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.4)
    fig.savefig(fig_path, dpi=150, bbox_inches="tight")
    plt.show()
    # Many figures are produced in the loop below; close each to bound memory.
    plt.close(fig)
    print("Saved plot:", fig_path)


for alg in ALGORITHMS:
    for sc in scenarios:
        sub = avg_over_hosts[(avg_over_hosts.alg == alg) & (avg_over_hosts.scenario == sc)]
        if sub.empty:
            continue
        for L in sorted(sub.dwt_levels.unique()):
            chunk = sub[sub.dwt_levels == L].sort_values("qim_step")
            plot_metric_curve(chunk, "ber_mean", "BER (mean across hosts)",
                              f"BER vs Δ — {alg} — {sc} — L={L}",
                              OUTDIR / f"ber_curve_{alg}_{sc}_L{L}.png")
            plot_metric_curve(chunk, "ncc_mean", "NCC (mean across hosts)",
                              f"NCC vs Δ — {alg} — {sc} — L={L}",
                              OUTDIR / f"ncc_curve_{alg}_{sc}_L{L}.png")


In [None]:
# =========================================
# Cell 11 — (Optional) Save a few example outputs
# =========================================
# For each (host, alg), we stashed a mid-grid example. Save the extracted watermark mosaics for report.
# Panels, left to right: clean, JPEG, crop, rotation.
MOSAIC_KEYS = ["bits_clean", "bits_jpeg", "bits_crop", "bits_rot"]
wm_size = P.WATERMARK_SIZE

for (host_name, alg), ex in per_host_examples.items():
    grid = Image.new("L", (wm_size * len(MOSAIC_KEYS), wm_size))
    for k, key in enumerate(MOSAIC_KEYS):
        # Map bits {0, 1} to {0, 255} so the panels are visible as black/white.
        bits = np.asarray(ex[key]).astype(np.uint8) * 255
        if bits.ndim == 1:
            # Tolerate extractors that return a flat bit vector.
            bits = bits.reshape(wm_size, wm_size)
        # A 2-D uint8 array is inferred as mode "L"; the explicit mode= argument of
        # Image.fromarray is deprecated in recent Pillow.
        grid.paste(Image.fromarray(bits), (k * wm_size, 0))
    out = OUTDIR / f"example_bits_{host_name}_{alg}.png"
    grid.save(out)
    print("Saved mosaic:", out)

print("All done.")
