# Final Code with the Normal Resizing Code

In [None]:
"""
SVD–QIM Watermarking (Color host, blind, no host resizing at embed)

Core idea
---------
1) Work in the Y (luma) channel for perceptual robustness.
2) Apply multi-level DWT; operate only on detail bands (default) for imperceptibility.
3) For each 8×8 DCT block in selected bands/tiles, take a mid-band patch,
   do SVD, and quantize ONE singular value (σ_k) via QIM (Σ-only embedding).
4) Use randomized block assignment (seeded by secret key + band/tile tags) and repetition
   + soft voting at extraction for robustness (e.g., compression and small local cutouts).

This file contains ONLY the SVD_QIM method (the original DCT-QIM branch is removed).
"""

import os
import math
import cv2
import numpy as np
from math import log10, sqrt
from PIL import Image
import pywt
from scipy.fftpack import dct, idct

# ===============================
# Utility: PSNR & BER
# ===============================

def psnr(img_a: np.ndarray, img_b: np.ndarray) -> float:
    """
    Peak signal-to-noise ratio (in dB) between two images, using 255 as peak.

    Both arrays are promoted to float64 before subtraction so that uint8
    arithmetic cannot wrap around. When the mean squared error is
    (numerically) zero, a fixed 100.0 dB is returned instead of infinity.
    """
    residual = img_a.astype(np.float64) - img_b.astype(np.float64)
    mean_sq_err = np.mean(residual ** 2)
    if mean_sq_err <= 1e-12:
        # Identical images: avoid dividing by zero below.
        return 100.0
    rmse = math.sqrt(mean_sq_err)
    return 20.0 * math.log10(255.0 / rmse)

def bit_error_rate(true_bits, pred_bits, threshold=None, return_counts=False):
    """
    Bit error rate between two equally-shaped binary (or grayscale) arrays.

    Each input is binarized independently with ``value > cut``. When
    ``threshold`` is None the cut is 0.5 if that array's maximum is <= 1.0,
    otherwise 127; when ``threshold`` is given it is used for both arrays.

    Returns the BER as a float, or ``(ber, stats)`` when ``return_counts`` is
    true, where ``stats`` holds errors/total/accuracy and the confusion counts
    (tp, tn, fp, fn) with ``true_bits`` as the reference.

    Raises ValueError when the two shapes differ.
    """
    reference = np.asarray(true_bits)
    estimate = np.asarray(pred_bits)
    if reference.shape != estimate.shape:
        raise ValueError(f"Shape mismatch: {reference.shape} vs {estimate.shape}")

    def _to_bits(values):
        # Pick the cut per array unless the caller fixed one explicitly.
        if threshold is not None:
            cut = threshold
        elif values.max() <= 1.0:
            cut = 0.5
        else:
            cut = 127
        return (values > cut).astype(np.uint8)

    ref_bits = _to_bits(reference)
    est_bits = _to_bits(estimate)

    mismatches = np.count_nonzero(ref_bits != est_bits)
    total = ref_bits.size
    ber = mismatches / total

    if return_counts:
        ref_one = ref_bits == 1
        est_one = est_bits == 1
        stats = {
            "errors": mismatches, "total": total, "accuracy": 1.0 - ber,
            "tp": np.count_nonzero(ref_one & est_one),
            "tn": np.count_nonzero(~ref_one & ~est_one),
            "fp": np.count_nonzero(~ref_one & est_one),
            "fn": np.count_nonzero(ref_one & ~est_one),
        }
        return ber, stats
    return ber

# ===============================
# I/O: Keep original size, YCbCr (we embed in Y only)
# ===============================

def load_rgb_keep_size(path: str):
    """
    Read an image file and split it into Y, Cb and Cr planes without resizing.

    The file is opened, converted to RGB and then to YCbCr. Returns a tuple
    ``(Y, Cb, Cr, size)`` where Y is a float64 array, Cb and Cr are uint8
    arrays, and ``size`` is the PIL size tuple of the RGB image, i.e. (W, H).
    Only Y is meant to be modified; Cb/Cr are kept to rebuild the color image.
    """
    rgb_image = Image.open(path).convert("RGB")
    y_band, cb_band, cr_band = rgb_image.convert("YCbCr").split()
    return (
        np.array(y_band, dtype=np.float64),
        np.array(cb_band, dtype=np.uint8),
        np.array(cr_band, dtype=np.uint8),
        rgb_image.size,  # (W, H)
    )

def save_rgb_from_y(y_luma: np.ndarray, cb_chroma: np.ndarray, cr_chroma: np.ndarray, out_path: str):
    """
    Recombine a luma plane with the given chroma planes and write an RGB file.

    ``y_luma`` may be floating point: it is clipped to [0, 255] and rounded to
    the nearest integer before being cast to uint8. The three planes are
    merged as YCbCr, converted to RGB, and saved to ``out_path``.
    """
    luma_u8 = np.rint(np.clip(y_luma, 0, 255)).astype(np.uint8)
    planes = [Image.fromarray(plane) for plane in (luma_u8, cb_chroma, cr_chroma)]
    Image.merge("YCbCr", planes).convert("RGB").save(out_path)

# ===============================
# DWT helpers
# ===============================

def dwt2(image_2d: np.ndarray, wavelet: str = 'haar', level: int = 1, mode: str = 'periodization'):
    """
    Multi-level 2D wavelet decomposition; a thin wrapper over ``pywt.wavedec2``.

    Returns whatever ``pywt.wavedec2`` returns for the given wavelet, level
    and signal-extension mode (the rest of this file indexes it as
    ``[approx, (d1, d2, d3), ...]``).
    """
    coefficients = pywt.wavedec2(image_2d, wavelet, mode=mode, level=level)
    return coefficients

def idwt2(coeffs, wavelet: str = 'haar', mode: str = 'periodization'):
    """
    Multi-level 2D wavelet reconstruction; a thin wrapper over ``pywt.waverec2``.

    ``wavelet`` and ``mode`` should match the values used for decomposition.
    """
    reconstructed = pywt.waverec2(coeffs, wavelet, mode=mode)
    return reconstructed

def wavelet_subband_map(coeffs):
    """
    Name the first two entries of a pywt 2D coefficient list.

    ``coeffs[0]`` is exposed as "LL" and the three arrays of ``coeffs[1]`` as
    "LH", "HL" and "HH" (in that order). Deeper entries of ``coeffs``, if any,
    are ignored. The arrays are returned by reference, not copied.
    """
    approx = coeffs[0]
    first, second, third = coeffs[1]
    return {"LL": approx, "LH": first, "HL": second, "HH": third}

def update_wavelet_subband(coeffs, band: str, new_arr: np.ndarray):
    """
    Replace one named subband inside ``coeffs`` in place.

    "LL" overwrites ``coeffs[0]``; "LH", "HL" and "HH" replace the first,
    second and third element of the ``coeffs[1]`` tuple respectively. Any
    other band name leaves ``coeffs`` untouched. Returns None.
    """
    first, second, third = coeffs[1]
    if band == "LL":
        coeffs[0] = new_arr
        return
    details = [first, second, third]
    slot = {"LH": 0, "HL": 1, "HH": 2}.get(band)
    if slot is None:
        # Unknown band names are silently ignored.
        return
    details[slot] = new_arr
    coeffs[1] = tuple(details)

# ===============================
# Block DCT with padding to multiple of 8 (no host resize)
# ===============================

def dct2d_blocks_8x8(full_arr: np.ndarray):
    """
    Orthonormal 2D DCT applied independently to every 8×8 block (JPEG-like).

    The input is edge-padded on the bottom/right to the next multiple of 8 in
    each dimension; nothing is resized.

    Returns ``(D, H, W)`` where ``D`` is the float64 coefficient array of the
    padded size and ``H``, ``W`` are the original dimensions, to be handed to
    ``idct2d_blocks_8x8`` for cropping.
    """
    H, W = full_arr.shape
    H8 = (H + 7) // 8 * 8
    W8 = (W + 7) // 8 * 8
    if H8 == 0 or W8 == 0:
        # Nothing to transform; keep the (empty) padded shape contract.
        return np.empty((H8, W8), dtype=np.float64), H, W

    padded = np.pad(full_arr, ((0, H8 - H), (0, W8 - W)), mode='edge').astype(np.float64)

    # View as (block_row, row_in_block, block_col, col_in_block) so that one
    # axis-wise transform handles every block at once instead of a Python loop.
    blocks = padded.reshape(H8 // 8, 8, W8 // 8, 8)
    transformed = dct(dct(blocks, axis=1, norm="ortho"), axis=3, norm="ortho")
    return transformed.reshape(H8, W8), H, W

def idct2d_blocks_8x8(D: np.ndarray, H: int, W: int):
    """
    Inverse of ``dct2d_blocks_8x8``: per-8×8-block orthonormal inverse DCT.

    ``D`` must have dimensions that are multiples of 8 (as produced by the
    forward transform). The reconstruction is cropped to the top-left
    ``H``×``W`` region, removing the padding added by the forward transform.

    Raises ValueError if ``D``'s shape is not a multiple of 8 in both axes.
    """
    H8, W8 = D.shape
    if H8 % 8 or W8 % 8:
        raise ValueError(f"DCT array shape must be a multiple of 8, got {D.shape}")
    if H8 == 0 or W8 == 0:
        return np.empty_like(D)[:H, :W]

    # Same (block_row, row_in_block, block_col, col_in_block) view as the
    # forward transform, inverted along the two in-block axes.
    blocks = np.asarray(D).reshape(H8 // 8, 8, W8 // 8, 8)
    restored = idct(idct(blocks, axis=1, norm="ortho"), axis=3, norm="ortho")
    return restored.reshape(H8, W8)[:H, :W]

# ===============================
# QIM primitives (for Σ quantization inside SVD)
# ===============================

def embed_bit_by_qim(coeff_value: float, bit: int, step: float) -> float:
    """
    Scalar QIM: move ``coeff_value`` onto the lattice point that encodes ``bit``.

    The value stays inside the quantization cell it already occupies
    (``floor(coeff_value / step)``) and is placed at a fixed offset in it:
    - bit == 0  → 0.25 * step into the cell
    - otherwise → 0.75 * step into the cell

    Note this is the target within the current cell, which is not always the
    globally nearest target for that bit.
    """
    cell_index = np.floor(coeff_value / step)
    offset = 0.25 if bit == 0 else 0.75
    return (cell_index + offset) * step

def qim_log_likelihood_ratio(coeff_value: float, step: float) -> float:
    """
    Soft decision for a QIM-embedded value.

    Returns a score in [-0.5, 0.5]: > 0 means the value is closer to a bit=1
    target (0.75 * step into a cell), < 0 means closer to a bit=0 target
    (0.25 * step into a cell).

    Distances are measured on the periodic lattice, i.e. they wrap across the
    cell boundary. A value sitting right at a multiple of ``step`` is equally
    far from the bit=1 target just below and the bit=0 target just above, so
    its score is ~0 rather than a full-confidence ±0.5 vote. The sign (hard
    decision) is the same as with un-wrapped distances.
    """
    position = coeff_value / step
    frac = position - np.floor(position)

    # Circular distance to each target within the unit-length cell.
    dist_to_zero = np.abs(frac - 0.25)
    dist_to_zero = np.minimum(dist_to_zero, 1.0 - dist_to_zero)
    dist_to_one = np.abs(frac - 0.75)
    dist_to_one = np.minimum(dist_to_one, 1.0 - dist_to_one)
    return dist_to_zero - dist_to_one

# ===============================
# SVD-on-DCT-block helpers (Σ-only QIM)
# ===============================

def embed_bit_in_block_svd_qim(
    dct_block_8x8: np.ndarray,
    bit: int,
    step: float,
    svd_index: int = 0,
    patch_rows=(2, 6),
    patch_cols=(2, 6)
) -> None:
    """
    Embed one bit into a DCT block by quantizing a single singular value.

    The block is modified in place; nothing is returned.

    Procedure:
    - Take the sub-patch ``[patch_rows[0]:patch_rows[1], patch_cols[0]:patch_cols[1]]``
      (rows 2..5, cols 2..5 by default).
    - Decompose it as U · diag(S) · Vᵀ.
    - Replace S[svd_index] (clamped to the last available index) with its
      QIM-quantized value for ``bit``; U and Vᵀ are left untouched.
    - Rebuild the patch from the modified factors and write it back.
    """
    (row_lo, row_hi), (col_lo, col_hi) = patch_rows, patch_cols
    region = (slice(row_lo, row_hi), slice(col_lo, col_hi))

    left, sigma, right_t = np.linalg.svd(dct_block_8x8[region], full_matrices=False)
    target = min(svd_index, len(sigma) - 1)
    sigma[target] = embed_bit_by_qim(sigma[target], int(bit), step)

    # Scaling the columns of U by sigma equals U @ diag(sigma).
    dct_block_8x8[region] = (left * sigma) @ right_t

def svd_qim_llr_for_block(
    dct_block_8x8: np.ndarray,
    step: float,
    svd_index: int = 0,
    patch_rows=(2, 6),
    patch_cols=(2, 6)
) -> float:
    """
    Soft evidence for the bit carried by one DCT block.

    Reads the same sub-patch and singular-value index that
    ``embed_bit_in_block_svd_qim`` writes, and returns the QIM soft score of
    that singular value (> 0 favors bit 1, < 0 favors bit 0). The block is not
    modified.
    """
    r0, r1 = patch_rows
    c0, c1 = patch_cols
    # Only the singular values are needed, so skip computing U and Vt.
    singular_values = np.linalg.svd(dct_block_8x8[r0:r1, c0:c1], compute_uv=False)
    k = min(svd_index, len(singular_values) - 1)
    return qim_log_likelihood_ratio(singular_values[k], step)

# ===============================
# Tiling (for crop/small cutout robustness)
# ===============================

def compute_tile_slices(H: int, W: int, tiles=(2, 2)):
    """
    Partition an H×W area into a ``tiles[0]`` × ``tiles[1]`` grid.

    Cut positions are ``(H * r) // rows`` and ``(W * c) // cols``, so tiles
    cover the whole area without overlap even when the size is not divisible.

    Returns a row-major list of ``(row_slice, col_slice, tile_row, tile_col)``.
    """
    n_tile_rows, n_tile_cols = tiles
    row_cuts = [(H * i) // n_tile_rows for i in range(n_tile_rows)] + [H]
    col_cuts = [(W * j) // n_tile_cols for j in range(n_tile_cols)] + [W]
    return [
        (
            slice(row_cuts[i], row_cuts[i + 1]),
            slice(col_cuts[j], col_cuts[j + 1]),
            i,
            j,
        )
        for i in range(n_tile_rows)
        for j in range(n_tile_cols)
    ]

# ===============================
# Single-level SVD-QIM (kept for completeness)
# ===============================

# Per-band constants mixed into the RNG seed so that each subband gets its own
# block permutation for the same secret key.
BAND_TAG = {"LL": 101, "LH": 102, "HL": 103, "HH": 104}

def embed_bits_tiled_y_svd_qim(
    y_luma: np.ndarray,
    bits: list,
    secret_key: int,
    quant_step: float,
    svd_index: int = 0,
    svd_patch=((2, 6), (2, 6)),
    bands=("LL", "LH", "HL"),
    wavelet='haar',
    dwt_level=1,
    tiles=(2, 2),
    repeat=0
):
    """
    Embed a bitstream into selected DWT subbands of Y using SVD-QIM on 8×8 DCT blocks.

    For every requested band and every tile of that band, the tile is block-DCT
    transformed, its blocks are visited in a key-seeded random order, and each
    payload bit is written into ``rep`` consecutive blocks of that order. The
    whole payload is therefore repeated once per (band, tile).

    Parameters
    ----------
    y_luma : 2D array, the luma plane.
    bits : sequence of 0/1 payload bits.
    secret_key : integer seed material for the block permutation.
    quant_step : QIM step applied to the chosen singular value.
    svd_index, svd_patch : which singular value of which DCT sub-patch to quantize.
    bands : names among "LL", "LH", "HL", "HH" (first two entries of the DWT result).
    wavelet, dwt_level : forwarded to the DWT.
    tiles : (rows, cols) grid applied to each band.
    repeat : blocks per bit; 0 (or a value that does not fit) means
        ``max(1, n_blocks // n_bits)``.

    Returns
    -------
    The watermarked luma plane as a float array with the same shape as
    ``y_luma``. If a tile has fewer blocks than bits, only the leading bits
    are embedded in that tile.
    """
    in_h, in_w = np.shape(y_luma)
    n_bits = len(bits)
    if n_bits == 0:
        # Nothing to embed; also avoids dividing by zero when sizing repetitions.
        return np.array(y_luma, dtype=np.float64)

    coeffs = dwt2(y_luma, wavelet, dwt_level)
    bands_map = wavelet_subband_map(coeffs)
    pr, pc = svd_patch

    for band in bands:
        sb = bands_map[band]
        for sl_r, sl_c, tile_r, tile_c in compute_tile_slices(*sb.shape, tiles=tiles):
            D, H0, W0 = dct2d_blocks_8x8(sb[sl_r, sl_c])
            n_cols = D.shape[1] // 8
            n_blocks = (D.shape[0] // 8) * n_cols

            # Honour an explicit repeat only when the full payload fits;
            # otherwise spread the payload over all available blocks.
            per_bit = max(1, n_blocks // n_bits)
            rep = repeat if 0 < repeat <= per_bit else per_bit

            # RNG shuffle for block positions (seeded by key+tile+band)
            seed = (int(secret_key) ^ (int(tile_r) << 8) ^ (int(tile_c) << 4) ^ BAND_TAG.get(band, 99)) & 0xFFFFFFFF
            rng = np.random.RandomState(seed)
            idx = np.arange(n_blocks)
            rng.shuffle(idx)

            for bit_idx, bit in enumerate(bits):
                start = bit_idx * rep
                if start >= n_blocks:
                    break  # tile is full; remaining bits are not embedded here
                for k in idx[start:start + rep]:
                    r, c = divmod(int(k), n_cols)
                    i, j = r * 8, c * 8
                    embed_bit_in_block_svd_qim(
                        D[i:i+8, j:j+8], int(bit), quant_step,
                        svd_index=svd_index, patch_rows=pr, patch_cols=pc
                    )

            sb[sl_r, sl_c] = idct2d_blocks_8x8(D, H0, W0)

    for band in bands:
        update_wavelet_subband(coeffs, band, bands_map[band])

    y_watermarked = idwt2(coeffs, wavelet)
    # The inverse transform can come back larger than an odd-sized input;
    # crop so the output always matches the host dimensions.
    return y_watermarked[:in_h, :in_w]

def extract_bits_tiled_y_svd_qim(
    y_luma: np.ndarray,
    n_bits: int,
    secret_key: int,
    quant_step: float,
    svd_index: int = 0,
    svd_patch=((2, 6), (2, 6)),
    bands=("LL", "LH", "HL"),
    wavelet='haar',
    dwt_level=1,
    tiles=(2, 2),
    repeat=0,
):
    """
    Extract a bitstream using soft voting over randomized, repeated block positions.

    Mirrors ``embed_bits_tiled_y_svd_qim``: the same bands, tiles, seed and
    block order are regenerated, and the QIM soft score of every block
    assigned to a bit is summed across all bands and tiles.

    ``repeat`` must equal the value used at embedding time (0 means "fill the
    tile", i.e. ``max(1, n_blocks // n_bits)`` blocks per bit); the effective
    repetition is derived with the same rule as the embedder.

    Returns a list of ``n_bits`` values in {0, 1}. A bit whose summed score is
    >= 0 (including a bit that received no evidence at all) decodes as 1.
    An ``n_bits`` of 0 or less yields an empty list.
    """
    if n_bits <= 0:
        return []

    coeffs = dwt2(y_luma, wavelet, dwt_level)
    bands_map = wavelet_subband_map(coeffs)
    soft_sum = np.zeros(n_bits, dtype=np.float64)
    pr, pc = svd_patch

    for band in bands:
        sb = bands_map[band]
        for sl_r, sl_c, tile_r, tile_c in compute_tile_slices(*sb.shape, tiles=tiles):
            D, _, _ = dct2d_blocks_8x8(sb[sl_r, sl_c])
            n_cols = D.shape[1] // 8
            n_blocks = (D.shape[0] // 8) * n_cols

            # Same repetition rule as the embedder so block positions line up.
            per_bit = max(1, n_blocks // n_bits)
            rep = repeat if 0 < repeat <= per_bit else per_bit

            seed = (int(secret_key) ^ (int(tile_r) << 8) ^ (int(tile_c) << 4) ^ BAND_TAG.get(band, 99)) & 0xFFFFFFFF
            rng = np.random.RandomState(seed)
            idx = np.arange(n_blocks)
            rng.shuffle(idx)

            for bit_idx in range(n_bits):
                start = bit_idx * rep
                if start >= n_blocks:
                    break  # no blocks left in this tile for the remaining bits
                score = 0.0
                for k in idx[start:start + rep]:
                    r, c = divmod(int(k), n_cols)
                    i, j = r * 8, c * 8
                    score += svd_qim_llr_for_block(
                        D[i:i+8, j:j+8], quant_step,
                        svd_index=svd_index, patch_rows=pr, patch_cols=pc
                    )
                soft_sum[bit_idx] += score

    # Soft decision across all bands/tiles:
    return (soft_sum >= 0).astype(np.uint8).tolist()

# ===============================
# Multi-level variants (LEVEL>=1)
# ===============================

def bands_all_levels(coeffs, use_LL=True, use_HV=True, use_D=False):
    """
    Flatten a multi-level coefficient list into ``(level, band_name, array)`` triples.

    With ``L = len(coeffs) - 1``, ``coeffs[1]`` is labelled level L and
    ``coeffs[L]`` level 1. Output order: optional ('LL' at level L), then for
    each level from L down to 1 the optional 'H' and 'V' bands followed by the
    optional 'D' band. Arrays are returned by reference.
    """
    deepest = len(coeffs) - 1
    selected = [(deepest, 'LL', coeffs[0])] if use_LL else []
    for offset, (horiz, vert, diag) in enumerate(coeffs[1:]):
        level = deepest - offset
        if use_HV:
            selected += [(level, 'H', horiz), (level, 'V', vert)]
        if use_D:
            selected.append((level, 'D', diag))
    return selected

def set_band_at_level(coeffs, level_index, band_name, new_arr):
    """
    Store ``new_arr`` as band ``band_name`` of level ``level_index``, in place.

    Uses the level numbering of ``bands_all_levels`` (level L lives at
    ``coeffs[1]``). 'LL' replaces ``coeffs[0]`` regardless of the level;
    'H', 'V', 'D' replace the first, second, third element of the level's
    tuple. Other band names change nothing.
    """
    if band_name == 'LL':
        coeffs[0] = new_arr
        return
    slot = len(coeffs) - level_index
    horiz, vert, diag = coeffs[slot]
    replacements = {
        'H': (new_arr, vert, diag),
        'V': (horiz, new_arr, diag),
        'D': (horiz, vert, new_arr),
    }
    if band_name in replacements:
        coeffs[slot] = replacements[band_name]

def embed_bits_y_multilevel_svd_qim(
    y_luma: np.ndarray,
    bits: list,
    secret_key: int,
    quant_step: float,
    svd_index: int = 0,
    svd_patch=((2, 6), (2, 6)),
    wavelet='haar',
    dwt_levels=2,
    tiles=(2, 2),
    repeat=0,
    include_LL=False,
    include_D=False
):
    """
    Multi-level SVD-QIM embedding across the H/V bands of every DWT level.

    The luma plane is edge-padded to a multiple of ``max(2**dwt_levels, 8)``,
    decomposed, and the full payload is embedded once per (level, band, tile)
    using a key-seeded block permutation. H and V bands are always used; the
    deepest LL band and the D bands are added with ``include_LL`` /
    ``include_D``. The result is cropped back to the input shape.

    Parameters
    ----------
    y_luma : 2D array, the luma plane.
    bits : sequence of 0/1 payload bits.
    secret_key : integer seed material for the block permutation.
    quant_step : QIM step applied to the chosen singular value.
    svd_index, svd_patch : which singular value of which DCT sub-patch to quantize.
    wavelet, dwt_levels : forwarded to the DWT.
    tiles : (rows, cols) grid applied to each band.
    repeat : blocks per bit; 0 (or a value that does not fit) means
        ``max(1, n_blocks // n_bits)``.

    Returns
    -------
    Watermarked luma plane (float) with the same shape as ``y_luma``. If a
    tile has fewer blocks than bits, only the leading bits are embedded in it.
    """
    n_bits = len(bits)
    if n_bits == 0:
        # Nothing to embed; also avoids dividing by zero when sizing repetitions.
        return np.array(y_luma, dtype=np.float64)

    # Pad the host so the decomposition sees dimensions that are a multiple of
    # both the DWT stride and the 8-pixel DCT block size.
    mult = max(1 << dwt_levels, 8)
    H0, W0 = y_luma.shape
    Hm = (H0 + mult - 1) // mult * mult
    Wm = (W0 + mult - 1) // mult * mult
    y_pad = np.pad(y_luma, ((0, Hm - H0), (0, Wm - W0)), mode='edge')

    coeffs = dwt2(y_pad, wavelet, dwt_levels)
    target_bands = bands_all_levels(coeffs, use_LL=include_LL, use_HV=True, use_D=include_D)
    pr, pc = svd_patch
    band_tags = {'LL': 201, 'H': 202, 'V': 203, 'D': 204}

    for lvl, bname, sb in target_bands:
        tag = band_tags[bname]
        for sl_r, sl_c, tile_r, tile_c in compute_tile_slices(*sb.shape, tiles=tiles):
            D, Ht, Wt = dct2d_blocks_8x8(sb[sl_r, sl_c])
            n_cols = D.shape[1] // 8
            n_blocks = (D.shape[0] // 8) * n_cols

            # Honour an explicit repeat only when the full payload fits;
            # otherwise spread the payload over all available blocks.
            per_bit = max(1, n_blocks // n_bits)
            rep = repeat if 0 < repeat <= per_bit else per_bit

            # Distinct permutation per key, tile, level and band.
            seed = (int(secret_key) ^ (int(tile_r) << 9) ^ (int(tile_c) << 5) ^ (lvl << 1) ^ tag) & 0xFFFFFFFF
            rng = np.random.RandomState(seed)
            idx = np.arange(n_blocks)
            rng.shuffle(idx)

            for bit_idx, bit in enumerate(bits):
                start = bit_idx * rep
                if start >= n_blocks:
                    break  # tile is full; remaining bits are not embedded here
                for k in idx[start:start + rep]:
                    r, c = divmod(int(k), n_cols)
                    i, j = r * 8, c * 8
                    embed_bit_in_block_svd_qim(
                        D[i:i+8, j:j+8], int(bit), quant_step,
                        svd_index=svd_index, patch_rows=pr, patch_cols=pc
                    )

            sb[sl_r, sl_c] = idct2d_blocks_8x8(D, Ht, Wt)
        set_band_at_level(coeffs, lvl, bname, sb)

    y_wm = idwt2(coeffs, wavelet)
    return y_wm[:H0, :W0]

def extract_bits_y_multilevel_svd_qim(
    y_luma: np.ndarray,
    n_bits: int,
    secret_key: int,
    quant_step: float,
    svd_index: int = 0,
    svd_patch=((2, 6), (2, 6)),
    wavelet='haar',
    dwt_levels=2,
    tiles=(2, 2),
    include_LL=False,
    include_D=False,
    repeat=0
):
    """
    Multi-level extraction with soft voting across all used bands/tiles/levels.

    Mirrors ``embed_bits_y_multilevel_svd_qim``: the same padding, bands,
    tiles, seeds and block order are regenerated, and the QIM soft score of
    every block assigned to a bit is summed over all of them.

    ``repeat`` must equal the value used at embedding time (0 means "fill the
    tile", i.e. ``max(1, n_blocks // n_bits)`` blocks per bit); the effective
    repetition is derived with the same rule as the embedder.

    Returns a list of ``n_bits`` values in {0, 1}. A bit whose summed score is
    >= 0 (including a bit that received no evidence at all) decodes as 1.
    An ``n_bits`` of 0 or less yields an empty list.
    """
    if n_bits <= 0:
        return []

    # Same padding rule as the embedder so the decomposition grids coincide.
    mult = max(1 << dwt_levels, 8)
    H0, W0 = y_luma.shape
    Hm = (H0 + mult - 1) // mult * mult
    Wm = (W0 + mult - 1) // mult * mult
    y_pad = np.pad(y_luma, ((0, Hm - H0), (0, Wm - W0)), mode='edge')

    coeffs = dwt2(y_pad, wavelet, dwt_levels)
    soft_sum = np.zeros(n_bits, dtype=np.float64)
    target_bands = bands_all_levels(coeffs, use_LL=include_LL, use_HV=True, use_D=include_D)
    pr, pc = svd_patch
    band_tags = {'LL': 201, 'H': 202, 'V': 203, 'D': 204}

    for lvl, bname, sb in target_bands:
        tag = band_tags[bname]
        for sl_r, sl_c, tile_r, tile_c in compute_tile_slices(*sb.shape, tiles=tiles):
            D, _, _ = dct2d_blocks_8x8(sb[sl_r, sl_c])
            n_cols = D.shape[1] // 8
            n_blocks = (D.shape[0] // 8) * n_cols

            # Same repetition rule as the embedder so block positions line up.
            per_bit = max(1, n_blocks // n_bits)
            rep = repeat if 0 < repeat <= per_bit else per_bit

            seed = (int(secret_key) ^ (int(tile_r) << 9) ^ (int(tile_c) << 5) ^ (lvl << 1) ^ tag) & 0xFFFFFFFF
            rng = np.random.RandomState(seed)
            idx = np.arange(n_blocks)
            rng.shuffle(idx)

            for bit_idx in range(n_bits):
                start = bit_idx * rep
                if start >= n_blocks:
                    break  # no blocks left in this tile for the remaining bits
                score = 0.0
                for k in idx[start:start + rep]:
                    r, c = divmod(int(k), n_cols)
                    i, j = r * 8, c * 8
                    score += svd_qim_llr_for_block(
                        D[i:i+8, j:j+8], quant_step,
                        svd_index=svd_index, patch_rows=pr, patch_cols=pc
                    )
                soft_sum[bit_idx] += score

    return (soft_sum >= 0).astype(np.uint8).tolist()

# ===============================
# Attack helpers (for testing)
# ===============================

def save_jpeg(input_path: str, out_path: str, quality: int = 75):
    """
    Re-encode an image as JPEG (compression attack).

    The source is converted to RGB and written with the given ``quality``,
    chroma subsampling disabled (``subsampling=0``) and ``optimize=False``.
    The JPEG format is requested explicitly so the output is JPEG-compressed
    whatever extension ``out_path`` carries.
    """
    # Context manager closes the source file handle deterministically.
    with Image.open(input_path) as src:
        rgb = src.convert("RGB")
    rgb.save(out_path, format="JPEG", quality=quality, subsampling=0, optimize=False)

def save_small_random_cutout(
    input_path: str,
    out_path: str,
    area_ratio: float = 0.01,
    num_patches: int = 1,
    shape: str = "rect",          # 'rect' or 'circle'
    fill: str = "noise",          # 'noise'|'black'|'avg'|'blur'|'inpaint'
    blur_kernel: int = 11,
    seed: int | None = None
):
    """
    Overwrite one or more small regions of an image while keeping its size.

    Simulates tiny local defects (stickers, dust, specks) without any global
    resizing or cropping.

    Parameters
    ----------
    input_path, out_path : source and destination image files.
    area_ratio : target patch area as a fraction of the image area; the patch
        is ``sqrt(area_ratio)`` of the width and of the height.
    num_patches : number of patches (at least one is always applied).
    shape : "rect" for a rectangle; any other value gives a disc centred on
        the patch with radius ``0.5 * max(w, h)``.
    fill : how the region is replaced: random noise, black, the current image
        mean color, a Gaussian-blurred copy, or OpenCV Telea inpainting.
    blur_kernel : Gaussian kernel size for ``fill="blur"`` (made odd if even).
    seed : seed for patch positions and noise values.

    Raises
    ------
    ValueError : if ``fill`` is not one of the supported modes.
    """
    if fill not in ("noise", "black", "avg", "blur", "inpaint"):
        raise ValueError("fill must be one of {'noise','black','avg','blur','inpaint'}")

    with Image.open(input_path) as src:
        arr = np.array(src.convert("RGB"), dtype=np.uint8)
    H, W, C = arr.shape
    rng = np.random.default_rng(seed)

    # Patch size ~ sqrt(area_ratio) * (W,H) → area ≈ area_ratio * W * H
    w = max(1, int(round(np.sqrt(area_ratio) * W)))
    h = max(1, int(round(np.sqrt(area_ratio) * H)))
    yy, xx = np.ogrid[:H, :W]

    for _ in range(max(1, int(num_patches))):
        # The upper bound of rng.integers is exclusive: +1 lets the patch
        # touch the right/bottom edge as well.
        x0 = int(rng.integers(0, max(1, W - w + 1)))
        y0 = int(rng.integers(0, max(1, H - h + 1)))

        if shape == "rect":
            hit = np.zeros((H, W), dtype=bool)
            hit[y0:y0 + h, x0:x0 + w] = True
        else:
            cy, cx = y0 + h // 2, x0 + w // 2
            radius = int(0.5 * max(w, h))
            hit = (yy - cy) ** 2 + (xx - cx) ** 2 <= radius * radius

        if fill == "black":
            arr[hit] = 0
        elif fill == "avg":
            # Mean of the current image, including earlier patches.
            arr[hit] = arr.reshape(-1, C).mean(axis=0).astype(np.uint8)
        elif fill == "noise":
            # Draw noise only for the pixels actually replaced.
            arr[hit] = rng.integers(0, 256, size=(int(hit.sum()), C), dtype=np.uint8)
        elif fill == "blur":
            k = blur_kernel if blur_kernel % 2 == 1 else blur_kernel + 1
            blurred = cv2.GaussianBlur(arr, (k, k), 0)
            arr[hit] = blurred[hit]
        else:  # "inpaint"
            mask = hit.astype(np.uint8) * 255
            bgr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
            bgr = cv2.inpaint(bgr, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
            arr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    Image.fromarray(arr).save(out_path)

# --- Resize (down→up) resampling attack, size restored to original ---
def save_resize_attack(
    input_path: str,
    out_path: str,
    scale: float = 0.75,
    interpolation: str = "bicubic"  # 'nearest'|'bilinear'|'bicubic'|'lanczos'
):
    """
    Resampling attack: resize by ``scale``, then resize back to the original size.

    The same interpolation is used in both directions; an unrecognized
    ``interpolation`` name falls back to bicubic. Each intermediate dimension
    is at least 1 pixel.

    Raises
    ------
    FileNotFoundError : if the input image cannot be read.
    OSError : if the output image cannot be written.
    """
    interp_map = {
        "nearest": cv2.INTER_NEAREST,
        "bilinear": cv2.INTER_LINEAR,
        "bicubic": cv2.INTER_CUBIC,
        "lanczos": cv2.INTER_LANCZOS4,
    }
    interp = interp_map.get(interpolation, cv2.INTER_CUBIC)

    img = cv2.imread(input_path, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(input_path)
    H, W = img.shape[:2]

    W2 = max(1, int(round(W * scale)))
    H2 = max(1, int(round(H * scale)))

    small = cv2.resize(img, (W2, H2), interpolation=interp)
    # Back to the original size (this creates typical resampling artifacts).
    back = cv2.resize(small, (W, H), interpolation=interp)

    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(out_path, back):
        raise OSError(f"Could not write image to {out_path}")

# --- Rotation attack, keeps the original size via warpAffine + border handling ---
def save_rotation_attack(
    input_path: str,
    out_path: str,
    angle_deg: float = 5.0,
    interpolation: str = "bicubic",    # 'nearest'|'bilinear'|'bicubic'|'lanczos'
    border: str = "replicate",         # 'replicate'|'reflect'|'constant_white'|'constant_black'
):
    """
    Rotation attack: rotate about the image centre, keeping the original size.

    The rotation matrix comes from ``cv2.getRotationMatrix2D`` with unit scale
    and is applied with ``cv2.warpAffine`` onto a canvas of the input size, so
    corners that rotate out of frame are lost and uncovered areas are filled
    according to ``border``. Unrecognized ``interpolation`` / ``border`` names
    fall back to bicubic / replicate.

    Raises
    ------
    FileNotFoundError : if the input image cannot be read.
    OSError : if the output image cannot be written.
    """
    interp_map = {
        "nearest": cv2.INTER_NEAREST,
        "bilinear": cv2.INTER_LINEAR,
        "bicubic": cv2.INTER_CUBIC,
        "lanczos": cv2.INTER_LANCZOS4,
    }
    border_map = {
        "replicate": (cv2.BORDER_REPLICATE, None),
        "reflect": (cv2.BORDER_REFLECT_101, None),
        "constant_white": (cv2.BORDER_CONSTANT, (255, 255, 255)),
        "constant_black": (cv2.BORDER_CONSTANT, (0, 0, 0)),
    }
    interp = interp_map.get(interpolation, cv2.INTER_CUBIC)
    bmode, bvalue = border_map.get(border, (cv2.BORDER_REPLICATE, None))

    img = cv2.imread(input_path, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(input_path)
    H, W = img.shape[:2]

    M = cv2.getRotationMatrix2D((W / 2.0, H / 2.0), angle_deg, 1.0)
    rotated = cv2.warpAffine(
        img, M, (W, H),
        flags=interp,
        borderMode=bmode,
        # A value is always passed; it only matters for constant borders.
        borderValue=(0, 0, 0) if bvalue is None else bvalue
    )

    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(out_path, rotated):
        raise OSError(f"Could not write image to {out_path}")

# ===============================
# Helper: extraction with small angle / scale search
# ===============================
def extract_with_rigid_search(
    y_luma, n_bits, secret_key, base_step,
    target_shape=None,                 # NEW
    angles=range(-7, 8, 1),
    scales=None,                       # NEW
    **kw
):
    """
    Try extraction over a grid of rotations and down/up rescalings.

    If ``target_shape`` (H, W) is given and differs from the input shape, the
    input is first resized to it (bicubic). Then, for every angle (degrees)
    and every scale, the plane is rotated about its centre, optionally scaled
    by the factor and resized back to (W, H), and passed to
    ``extract_bits_y_multilevel_svd_qim`` together with ``**kw``.

    Returns the bit list of the best-scoring candidate, or None when no
    candidate was evaluated.

    NOTE(review): the score is (#ones - #zeros) of the extracted bits, so the
    search prefers the candidate with the most 1-bits. That does not measure
    extraction reliability; a confidence-based score would need soft values
    from the extractor.
    """
    # Default: denser near 1.0 plus the 0.75 case
    scale_candidates = scales
    if scale_candidates is None:
        scale_candidates = [1.00, 0.98, 0.96, 0.94, 0.92, 0.90, 0.88, 0.85, 0.80, 0.75]

    # Bring the input onto the expected (original) grid first, if requested.
    if target_shape is not None:
        want_h, want_w = target_shape
        if y_luma.shape != (want_h, want_w):
            y_luma = cv2.resize(y_luma, (want_w, want_h), interpolation=cv2.INTER_CUBIC)

    height, width = y_luma.shape

    def _rescaled(plane, factor):
        # A factor of exactly 1 skips the resample round-trip entirely.
        if abs(factor - 1.0) < 1e-12:
            return plane
        shrunk = cv2.resize(plane, (0, 0), fx=factor, fy=factor, interpolation=cv2.INTER_CUBIC)
        return cv2.resize(shrunk, (width, height), interpolation=cv2.INTER_CUBIC)

    winner = None
    winner_score = -1e9
    for angle in angles:
        rotation = cv2.getRotationMatrix2D((width/2, height/2), angle, 1.0)
        rotated = cv2.warpAffine(
            y_luma, rotation, (width, height),
            flags=cv2.INTER_CUBIC,
            borderMode=cv2.BORDER_REPLICATE
        )
        for factor in scale_candidates:
            candidate = extract_bits_y_multilevel_svd_qim(
                _rescaled(rotated, factor), n_bits, secret_key, base_step, **kw
            )
            ones = sum(1 for b in candidate if b)
            tally = ones - (len(candidate) - ones)
            if tally > winner_score:
                winner_score = tally
                winner = candidate

    return winner




# ===============================
# Demo / Main (focus: JPEG compression & small cutout attacks)
# ===============================
# Runs the full demo only when executed as a script (not on import):
# embed → save → extract clean → run each attack → extract → report metrics.
if __name__ == "__main__":
    # ---- Paths ----
    HOST_PATH = "./images/original2.png"   # color host, any size (no resize at embed)
    WM_PATH   = "./images/watermark.jpg"      # watermark will be binarized to 32×32

    os.makedirs("./result", exist_ok=True)
    os.makedirs("./images", exist_ok=True)

    # ---- Parameters ----
    secret_key     = 1234567890
    quant_step     = 60.0             # QIM step for σ_k (tune vs PSNR/robustness)
    quant_step_LL = 30.0
    include_LL    = True           # whether to use LL band (more robust, less imperceptible)
    wavelet_name   = 'haar'
    dwt_levels     = 4
    watermark_size = 32
    tile_grid      = (2, 2)

    # SVD settings
    svd_index      = 0                # which σ_k to quantize (0 = largest)
    svd_patch_rc   = ((2, 6), (2, 6)) # mid-band patch inside each 8×8 DCT block

    # ---- Load host (Y, Cb, Cr) ----
    y_luma, cb_chroma, cr_chroma, _ = load_rgb_keep_size(HOST_PATH)

    # ---- Prepare watermark bits ----
    wm_gray = Image.open(WM_PATH).convert('L').resize((watermark_size, watermark_size), Image.LANCZOS)
    wm_bin  = (np.array(wm_gray, dtype=np.uint8) > 127).astype(np.uint8)
    payload_bits = wm_bin.reshape(-1).tolist()

    # ---- Embed (multi-level, details + LL) ----
    y_wm = embed_bits_y_multilevel_svd_qim(
        y_luma, payload_bits,
        secret_key=secret_key,
        quant_step=quant_step,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=tile_grid,
        repeat=0,
        include_LL=include_LL,   # now True
        include_D=False
    )

    # One extra light pass only into deepest LL
    # NOTE(review): embed_bits_y_multilevel_svd_qim always embeds into the H/V
    # bands as well, so this pass is not LL-only: it re-quantizes H/V blocks
    # with a different key and step, which can disturb the first pass.
    y_wm = embed_bits_y_multilevel_svd_qim(
        y_wm, payload_bits,
        secret_key=secret_key ^ 0x55AA,   # different key
        quant_step=quant_step_LL,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=(1, 1),
        repeat=0,
        include_LL=True,
        include_D=False
    )

    WATERMARKED_PATH = "./images/x-watermarked-ml.png"
    save_rgb_from_y(y_wm, cb_chroma, cr_chroma, WATERMARKED_PATH)

    # ---- Extract (clean) ----
    y2, _, _, _ = load_rgb_keep_size(WATERMARKED_PATH)

    # NOTE(review): the clean/JPEG/cutout extractions below pass
    # include_LL=False although embedding used include_LL=True, so the LL
    # evidence is not used there; the resize/rotation extractions do use it.
    bits_clean = extract_bits_y_multilevel_svd_qim(
        y2, watermark_size * watermark_size,
        secret_key=secret_key,
        quant_step=quant_step,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=tile_grid,
        include_LL=False,
        include_D=False
    )
    wm_rec = np.array(bits_clean, dtype=np.uint8).reshape(watermark_size, watermark_size)
    Image.fromarray((wm_rec * 255).astype(np.uint8)).save("./result/Extracted_Clean_MultiLevel.png")

    # --- JPEG compression attack ---
    ATTACK_JPEG = "./images/attack_q75-Multilevel.jpg"
    save_jpeg(WATERMARKED_PATH, ATTACK_JPEG, quality=75)
    y_jpeg, _, _, _ = load_rgb_keep_size(ATTACK_JPEG)

    bits_jpeg = extract_bits_y_multilevel_svd_qim(
        y_jpeg, watermark_size * watermark_size,
        secret_key=secret_key,
        quant_step=quant_step,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=tile_grid,
        include_LL=False,
        include_D=False
    )
    wm_jpeg = np.array(bits_jpeg, dtype=np.uint8).reshape(watermark_size, watermark_size)
    Image.fromarray((wm_jpeg * 255).astype(np.uint8)).save("./result/Extracted_JPEG_q75-Multilevel.png")

    # --- Small random cutout attack (looks like “tiny noise specks”) ---
    ATTACK_SMALL = "./images/attack_small_cutout.png"
    save_small_random_cutout(
        WATERMARKED_PATH, ATTACK_SMALL,
        area_ratio=0.001,   # ~0.1% of image area per patch
        num_patches=50,     # many tiny specks
        shape="rect",
        fill="noise",
        seed=secret_key
    )
    y_sc, _, _, _ = load_rgb_keep_size(ATTACK_SMALL)

    bits_sc = extract_bits_y_multilevel_svd_qim(
        y_sc, watermark_size * watermark_size,
        secret_key=secret_key,
        quant_step=quant_step,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=tile_grid,
        include_LL=False,
        include_D=False
    )
    wm_sc = np.array(bits_sc, dtype=np.uint8).reshape(watermark_size, watermark_size)
    Image.fromarray((wm_sc * 255).astype(np.uint8)).save("./result/Extracted_Small_Cutout.png")

    # --- Resize (downsample→upsample) attack ---
    ATTACK_RESIZE = "./images/attack_resize_0p75.png"
    save_resize_attack(WATERMARKED_PATH, ATTACK_RESIZE, scale=0.75, interpolation="bicubic")
    y_rs, _, _, _ = load_rgb_keep_size(ATTACK_RESIZE)

    # After clean extract (y2 is the watermarked image's Y)
    orig_shape = y2.shape

    bits_rs = extract_with_rigid_search(
        y_rs, watermark_size * watermark_size,
        secret_key, quant_step,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=tile_grid,
        include_LL=include_LL,
        include_D=False,
        target_shape=orig_shape        # NEW: force the attacked image back to original grid
    )

    wm_rs = np.array(bits_rs, dtype=np.uint8).reshape(watermark_size, watermark_size)
    Image.fromarray((wm_rs * 255).astype(np.uint8)).save("./result/Extracted_Resize_0p75.png")

    # --- Rotation attack (keep original size) ---
    ATTACK_ROT = "./images/attack_rotate_5deg.png"
    save_rotation_attack(WATERMARKED_PATH, ATTACK_ROT, angle_deg=5.0, interpolation="bicubic", border="replicate")
    y_rot, _, _, _ = load_rgb_keep_size(ATTACK_ROT)

    bits_rot = extract_with_rigid_search(
        y_rot, watermark_size * watermark_size,
        secret_key, quant_step,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=tile_grid,
        include_LL=include_LL,
        include_D=False,
        target_shape=orig_shape        # same size anyway, but keeps logic uniform
    )

    wm_rot = np.array(bits_rot, dtype=np.uint8).reshape(watermark_size, watermark_size)
    Image.fromarray((wm_rot * 255).astype(np.uint8)).save("./result/Extracted_Rotate_5deg.png")

    # ---- Metrics ----
    host_bgr = cv2.imread(HOST_PATH)
    watermarked_bgr = cv2.imread(WATERMARKED_PATH)
    PSNR = psnr(host_bgr, watermarked_bgr)
    print(f"PSNR (host vs watermarked): {PSNR:.4f} dB")

    ber_clean, stats_clean = bit_error_rate(wm_bin, wm_rec, return_counts=True)
    print("\nClean extraction vs original")
    print(f"BER: {ber_clean:.4f} | Acc: {stats_clean['accuracy']:.4f} | Errors: {stats_clean['errors']}/{stats_clean['total']}")

    ber_jpeg, stats_jpeg = bit_error_rate(wm_bin, wm_jpeg, return_counts=True)
    print("\nJPEG q=75 vs original")
    print(f"BER: {ber_jpeg:.4f} | Acc: {stats_jpeg['accuracy']:.4f} | Errors: {stats_jpeg['errors']}/{stats_jpeg['total']}")

    ber_sc, stats_sc = bit_error_rate(wm_bin, wm_sc, return_counts=True)
    print("\nSmall cutout vs original")
    print(f"BER: {ber_sc:.4f} | Acc: {stats_sc['accuracy']:.4f} | Errors: {stats_sc['errors']}/{stats_sc['total']}")

    ber_rs, stats_rs = bit_error_rate(wm_bin, wm_rs, return_counts=True)
    print("\nResize (0.75x down -> back up) vs original")
    print(f"BER: {ber_rs:.4f} | Acc: {stats_rs['accuracy']:.4f} | Errors: {stats_rs['errors']}/{stats_rs['total']}")

    ber_rot, stats_rot = bit_error_rate(wm_bin, wm_rot, return_counts=True)
    print("\nRotation (+5°) vs original")
    print(f"BER: {ber_rot:.4f} | Acc: {stats_rot['accuracy']:.4f} | Errors: {stats_rot['errors']}/{stats_rot['total']}")

    print("\nMethod: SVD_QIM (Σ-only on mid-band patch of each 8×8 DCT block)")
    print("Watermarked image:", WATERMARKED_PATH)
    print("Extracted images saved in ./result")

# Test WhatsApp Code

In [None]:
"""
SVD–QIM Watermarking (Color host, blind, no host resizing at embed)

Core idea
---------
1) Work in the Y (luma) channel for perceptual robustness.
2) Apply multi-level DWT; operate only on detail bands (default) for imperceptibility.
3) For each 8×8 DCT block in selected bands/tiles, take a mid-band patch,
   do SVD, and quantize ONE singular value (σ_k) via QIM (Σ-only embedding).
4) Use randomized block assignment (seeded by secret key + band/tile tags) and repetition
   + soft voting at extraction for robustness (e.g., compression and small local cutouts).

This file contains ONLY the SVD_QIM method (the original DCT-QIM branch is removed).
"""

import os
import math
import cv2
import numpy as np
from math import log10, sqrt
from PIL import Image
import pywt
from scipy.fftpack import dct, idct

# ===============================
# Utility: PSNR & BER
# ===============================

def psnr(img_a: np.ndarray, img_b: np.ndarray) -> float:
    """
    Return the peak signal-to-noise ratio (in dB) between two same-shaped images.

    Both inputs are promoted to float64 before differencing and the peak value
    is taken as 255. When the mean squared error is at most 1e-12 the images
    are treated as identical and a capped value of 100.0 is returned.
    """
    difference = img_a.astype(np.float64) - img_b.astype(np.float64)
    mean_sq_err = np.mean(np.square(difference))
    if mean_sq_err <= 1e-12:
        # Avoid log10(inf) for (numerically) identical images.
        return 100.0
    rmse = math.sqrt(mean_sq_err)
    return 20.0 * math.log10(255.0 / rmse)

def bit_error_rate(true_bits, pred_bits, threshold=None, return_counts=False):
    """
    Return the bit error rate between two equally shaped bit (or gray) arrays.

    Each input is binarized independently with ``value > cutoff``. The cutoff
    is ``threshold`` when given; otherwise 0.5 if the array's maximum is <= 1.0
    and 127 if it is larger.

    Returns
    -------
    float
        ``errors / total`` when ``return_counts`` is false.
    (float, dict)
        When ``return_counts`` is true, the BER plus a dict with the keys
        "errors", "total", "accuracy", "tp", "tn", "fp", "fn".

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape.
    """
    truth = np.asarray(true_bits)
    guess = np.asarray(pred_bits)
    if truth.shape != guess.shape:
        raise ValueError(f"Shape mismatch: {truth.shape} vs {guess.shape}")

    def to_binary(values):
        if threshold is not None:
            cutoff = threshold
        elif values.max() <= 1.0:
            cutoff = 0.5
        else:
            cutoff = 127
        return (values > cutoff).astype(np.uint8)

    truth_bin = to_binary(truth)
    guess_bin = to_binary(guess)
    errors = np.count_nonzero(truth_bin != guess_bin)
    total = truth_bin.size
    ber = errors / total

    if return_counts:
        is_one_true = truth_bin == 1
        is_one_pred = guess_bin == 1
        stats = {
            "errors": errors,
            "total": total,
            "accuracy": 1.0 - ber,
            "tp": np.count_nonzero(is_one_true & is_one_pred),
            "tn": np.count_nonzero(~is_one_true & ~is_one_pred),
            "fp": np.count_nonzero(~is_one_true & is_one_pred),
            "fn": np.count_nonzero(is_one_true & ~is_one_pred),
        }
        return ber, stats
    return ber

# ===============================
# I/O: Keep original size, YCbCr (we embed in Y only)
# ===============================

def load_rgb_keep_size(path: str):
    """
    Open an image, convert it RGB -> YCbCr, and split it into channel arrays.

    Returns a 4-tuple ``(y, cb, cr, size)`` where ``y`` is a float64 array of
    channel 0, ``cb`` and ``cr`` are uint8 arrays of channels 1 and 2, and
    ``size`` is the PIL ``Image.size`` of the RGB image, i.e. (W, H).
    The image is never resized.
    """
    rgb_image = Image.open(path).convert("RGB")
    ycbcr_image = rgb_image.convert("YCbCr")
    planes = [ycbcr_image.getchannel(index) for index in range(3)]
    y_luma = np.array(planes[0], dtype=np.float64)
    cb_chroma, cr_chroma = (np.array(plane, dtype=np.uint8) for plane in planes[1:])
    return y_luma, cb_chroma, cr_chroma, rgb_image.size  # (W,H)

def save_rgb_from_y(y_luma: np.ndarray, cb_chroma: np.ndarray, cr_chroma: np.ndarray, out_path: str):
    """
    Merge a Y plane with the given Cb/Cr planes and save the result as RGB.

    ``y_luma`` may be floating point: it is clipped to [0, 255], rounded to the
    nearest integer and cast to uint8 before merging. The chroma planes are
    passed to PIL as they are.
    """
    luma_u8 = np.clip(y_luma, 0, 255)
    luma_u8 = np.rint(luma_u8).astype(np.uint8)
    planes = tuple(Image.fromarray(plane) for plane in (luma_u8, cb_chroma, cr_chroma))
    Image.merge("YCbCr", planes).convert("RGB").save(out_path)

# ===============================
# DWT helpers
# ===============================

def dwt2(image_2d: np.ndarray, wavelet: str = 'haar', level: int = 1, mode: str = 'periodization'):
    """Decompose a 2D array with ``pywt.wavedec2`` and return its coefficient list."""
    decomposition = pywt.wavedec2(image_2d, wavelet=wavelet, mode=mode, level=level)
    return decomposition

def idwt2(coeffs, wavelet: str = 'haar', mode: str = 'periodization'):
    """Reconstruct a 2D array from a coefficient list via ``pywt.waverec2``."""
    reconstruction = pywt.waverec2(coeffs, mode=mode, wavelet=wavelet)
    return reconstruction

def wavelet_subband_map(coeffs):
    """
    Name the first two entries of a pywt 2D coefficient list.

    ``coeffs[0]`` becomes "LL" and the three members of ``coeffs[1]`` become
    "LH", "HL" and "HH" in that order. Only these two entries are read, so the
    mapping describes a level-1 decomposition.
    """
    approximation = coeffs[0]
    first_detail, second_detail, third_detail = coeffs[1]
    return {
        "LL": approximation,
        "LH": first_detail,
        "HL": second_detail,
        "HH": third_detail,
    }

def update_wavelet_subband(coeffs, band: str, new_arr: np.ndarray):
    """
    Replace one named subband inside a coefficient list, in place.

    "LL" overwrites ``coeffs[0]``; "LH", "HL" and "HH" rebuild the tuple stored
    at ``coeffs[1]`` with the first, second or third member replaced.
    Any other band name leaves ``coeffs`` untouched.
    """
    first_detail, second_detail, third_detail = coeffs[1]
    if band == "LL":
        coeffs[0] = new_arr
        return
    rebuilt = {
        "LH": (new_arr, second_detail, third_detail),
        "HL": (first_detail, new_arr, third_detail),
        "HH": (first_detail, second_detail, new_arr),
    }
    if band in rebuilt:
        coeffs[1] = rebuilt[band]

# ===============================
# Block DCT with padding to multiple of 8 (no host resize)
# ===============================

def dct2d_blocks_8x8(full_arr: np.ndarray):
    """
    Compute an orthonormal 2D DCT independently on every 8×8 block.

    The input is first edge-padded on the bottom/right so both dimensions are
    multiples of 8. Returns ``(D, H, W)`` where ``D`` is the padded-size float64
    coefficient array and ``H``, ``W`` are the input's original dimensions,
    which the inverse needs to crop the padding away.
    """
    rows, cols = full_arr.shape
    pad_rows = -rows % 8
    pad_cols = -cols % 8
    padded = np.pad(full_arr, ((0, pad_rows), (0, pad_cols)), mode='edge').astype(np.float64)

    def block_dct(block):
        # Separable 2D transform: one 1D DCT per axis.
        return dct(dct(block.T, norm="ortho").T, norm="ortho")

    coefficients = np.empty(padded.shape, dtype=np.float64)
    for top in range(0, padded.shape[0], 8):
        for left in range(0, padded.shape[1], 8):
            window = (slice(top, top + 8), slice(left, left + 8))
            coefficients[window] = block_dct(padded[window])
    return coefficients, rows, cols

def idct2d_blocks_8x8(D: np.ndarray, H: int, W: int):
    """
    Invert ``dct2d_blocks_8x8``: per-block orthonormal inverse DCT, then crop.

    ``D`` is walked in 8×8 steps, each block is inverse-transformed, and the
    top-left ``H``×``W`` region of the result is returned.
    """
    def block_idct(block):
        # Separable 2D inverse: one 1D inverse DCT per axis.
        return idct(idct(block.T, norm="ortho").T, norm="ortho")

    restored = np.empty_like(D)
    padded_rows, padded_cols = D.shape
    for top in range(0, padded_rows, 8):
        for left in range(0, padded_cols, 8):
            window = (slice(top, top + 8), slice(left, left + 8))
            restored[window] = block_idct(D[window])
    return restored[:H, :W]

# ===============================
# QIM primitives (for Σ quantization inside SVD)
# ===============================

def embed_bit_by_qim(coeff_value: float, bit: int, step: float) -> float:
    """
    Embed one bit into a scalar with quantization index modulation (QIM).

    The value is moved inside the quantization cell that already contains it,
    i.e. ``[q*step, (q+1)*step)`` with ``q = floor(coeff_value / step)``:
    - bit == 0 → ``(q + 0.25) * step``
    - any other bit → ``(q + 0.75) * step``

    The cell is chosen by flooring, so the result is not necessarily the
    closest point carrying the requested bit.
    """
    cell_index = np.floor(coeff_value / step)
    if bit == 0:
        offset = 0.25
    else:
        offset = 0.75
    return (cell_index + offset) * step

def qim_log_likelihood_ratio(coeff_value: float, step: float) -> float:
    """
    Return a soft score for the bit carried by a QIM-quantized scalar.

    The value's fractional position inside its quantization cell is compared
    with the two embedding targets (0.25 for bit 0, 0.75 for bit 1). The result
    is ``|frac - 0.25| - |frac - 0.75|``: positive when the value sits closer
    to the bit-1 target, negative when closer to the bit-0 target. It is a
    distance difference rather than a probabilistic log-likelihood ratio.
    """
    scaled = coeff_value / step
    position_in_cell = scaled - np.floor(scaled)
    distance_to_bit0 = abs(position_in_cell - 0.25)
    distance_to_bit1 = abs(position_in_cell - 0.75)
    return distance_to_bit0 - distance_to_bit1

# ===============================
# SVD-on-DCT-block helpers (Σ-only QIM)
# ===============================

def embed_bit_in_block_svd_qim(
    dct_block_8x8: np.ndarray,
    bit: int,
    step: float,
    svd_index: int = 0,
    patch_rows=(2, 6),
    patch_cols=(2, 6)
) -> None:
    """
    Embed one bit into a DCT block by quantizing a single singular value.

    The block is modified in place; nothing is returned.

    Procedure:
    1) Slice the patch ``[patch_rows[0]:patch_rows[1], patch_cols[0]:patch_cols[1]]``
       out of the block (rows 2..5, cols 2..5 by default).
    2) Factor it as ``U @ diag(S) @ Vt`` (reduced SVD).
    3) Replace ``S[k]`` with its QIM-quantized value for ``bit``, where
       ``k = min(svd_index, len(S) - 1)``.
    4) Rebuild the patch from the modified factors and store it back.
    """
    row_lo, row_hi = patch_rows
    col_lo, col_hi = patch_cols
    region = (slice(row_lo, row_hi), slice(col_lo, col_hi))

    left_vecs, sing_vals, right_vecs_t = np.linalg.svd(dct_block_8x8[region], full_matrices=False)
    target = min(svd_index, len(sing_vals) - 1)
    sing_vals[target] = embed_bit_by_qim(sing_vals[target], int(bit), step)

    # Column-scaling U by S is the same product as U @ diag(S).
    dct_block_8x8[region] = (left_vecs * sing_vals) @ right_vecs_t

def svd_qim_llr_for_block(
    dct_block_8x8: np.ndarray,
    step: float,
    svd_index: int = 0,
    patch_rows=(2, 6),
    patch_cols=(2, 6)
) -> float:
    """
    Compute the soft bit score of one DCT block.

    The same patch used at embed time is sliced out, decomposed with a reduced
    SVD, and the singular value at ``min(svd_index, len(S) - 1)`` is scored with
    ``qim_log_likelihood_ratio``. The block is not modified.
    """
    row_lo, row_hi = patch_rows
    col_lo, col_hi = patch_cols
    region = dct_block_8x8[row_lo:row_hi, col_lo:col_hi]
    _, sing_vals, _ = np.linalg.svd(region, full_matrices=False)
    target = min(svd_index, len(sing_vals) - 1)
    return qim_log_likelihood_ratio(sing_vals[target], step)

# ===============================
# Tiling (for crop/small cutout robustness)
# ===============================

def compute_tile_slices(H: int, W: int, tiles=(2, 2)):
    """
    Partition an H×W area into a ``tiles[0]`` × ``tiles[1]`` grid.

    Returns a row-major list of ``(rows_slice, cols_slice, tile_r, tile_c)``
    tuples. Interior boundaries are placed at ``(H * r) // tile_rows`` and
    ``(W * c) // tile_cols``, so the tiles cover the whole area without overlap
    even when the dimensions are not divisible by the grid size.
    """
    tile_rows, tile_cols = tiles
    row_cuts = [0, *((H * r) // tile_rows for r in range(1, tile_rows)), H]
    col_cuts = [0, *((W * c) // tile_cols for c in range(1, tile_cols)), W]
    return [
        (slice(row_cuts[r], row_cuts[r + 1]), slice(col_cuts[c], col_cuts[c + 1]), r, c)
        for r in range(tile_rows)
        for c in range(tile_cols)
    ]

# ===============================
# Single-level SVD-QIM (kept for completeness)
# ===============================

# Per-subband integer tags XOR-ed into the RNG seed by the single-level
# embed/extract routines, so each subband gets its own block shuffle.
BAND_TAG = {"LL": 101, "LH": 102, "HL": 103, "HH": 104}

def embed_bits_tiled_y_svd_qim(
    y_luma: np.ndarray,
    bits: list,
    secret_key: int,
    quant_step: float,
    svd_index: int = 0,
    svd_patch=((2, 6), (2, 6)),
    bands=("LL", "LH", "HL"),
    wavelet='haar',
    dwt_level=1,
    tiles=(2, 2),
    repeat=0
):
    """
    Embed a bitstream into named DWT subbands of Y via SVD-QIM on 8×8 DCT blocks.

    For every requested band and every tile of that band:
    - the tile is transformed with the block DCT,
    - the block order is shuffled with an RNG seeded from the key, tile
      position and band tag,
    - consecutive runs of ``rep`` shuffled blocks carry bit 0, bit 1, ... until
      either the bits or the blocks run out,
    - the tile is inverse-transformed and written back into the subband.

    ``rep`` is ``repeat`` when positive, otherwise ``max(1, n_blocks // n_bits)``;
    it falls back to the latter whenever ``n_bits * rep`` exceeds the number of
    blocks in the tile.

    Returns the watermarked Y array produced by the inverse DWT.
    """
    coeffs = dwt2(y_luma, wavelet, dwt_level)
    bands_map = wavelet_subband_map(coeffs)
    patch_rows, patch_cols = svd_patch
    n_bits = len(bits)

    for band in bands:
        subband = bands_map[band]
        for rows_sl, cols_sl, tile_r, tile_c in compute_tile_slices(*subband.shape, tiles=tiles):
            dct_tile, tile_h, tile_w = dct2d_blocks_8x8(subband[rows_sl, cols_sl])
            blocks_per_row = dct_tile.shape[1] // 8
            n_blocks = (dct_tile.shape[0] // 8) * blocks_per_row

            # Blocks spent on each bit (repeat == 0 means "use as many as fit").
            rep = repeat if repeat > 0 else max(1, n_blocks // n_bits)
            if n_bits * rep > n_blocks:
                rep = max(1, n_blocks // n_bits)

            # Key-, tile- and band-dependent permutation of the block indices.
            seed = (int(secret_key) ^ (int(tile_r) << 8) ^ (int(tile_c) << 4) ^ BAND_TAG.get(band, 99)) & 0xFFFFFFFF
            order = np.arange(n_blocks)
            np.random.RandomState(seed).shuffle(order)

            # The block at shuffled position p carries bit p // rep.
            for slot, block_id in enumerate(order[:n_bits * rep]):
                block_row, block_col = divmod(int(block_id), blocks_per_row)
                top, left = block_row * 8, block_col * 8
                embed_bit_in_block_svd_qim(
                    dct_tile[top:top + 8, left:left + 8], int(bits[slot // rep]), quant_step,
                    svd_index=svd_index, patch_rows=patch_rows, patch_cols=patch_cols
                )

            subband[rows_sl, cols_sl] = idct2d_blocks_8x8(dct_tile, tile_h, tile_w)

        update_wavelet_subband(coeffs, band, subband)

    return idwt2(coeffs, wavelet)

def extract_bits_tiled_y_svd_qim(
    y_luma: np.ndarray,
    n_bits: int,
    secret_key: int,
    quant_step: float,
    svd_index: int = 0,
    svd_patch=((2, 6), (2, 6)),
    bands=("LL", "LH", "HL"),
    wavelet='haar',
    dwt_level=1,
    tiles=(2, 2),
    repeat=0
):
    """
    Extract a bitstream using soft voting over randomized, repeated block positions.

    The block shuffle and the blocks-per-bit count are recomputed exactly as in
    ``embed_bits_tiled_y_svd_qim``; per-block soft scores are summed per bit
    across all bands and tiles and the sign of each sum decides the bit
    (a sum of exactly 0 decodes as 1).

    Parameters
    ----------
    repeat : int
        Must equal the ``repeat`` value used at embed time. The default 0
        matches the embedder's default ("fill all blocks"). Without mirroring
        a non-zero embed-time ``repeat`` the bit-to-block mapping would not
        line up.

    Returns
    -------
    list of int
        ``n_bits`` values in {0, 1}.

    Raises
    ------
    ValueError
        If ``n_bits`` is smaller than 1.
    """
    if n_bits < 1:
        raise ValueError(f"n_bits must be >= 1, got {n_bits}")

    coeffs = dwt2(y_luma, wavelet, dwt_level)
    bands_map = wavelet_subband_map(coeffs)
    soft_sum = np.zeros(n_bits, dtype=np.float64)
    pr, pc = svd_patch

    for band in bands:
        sb = bands_map[band]
        for sl_r, sl_c, tile_r, tile_c in compute_tile_slices(*sb.shape, tiles=tiles):
            D, _, _ = dct2d_blocks_8x8(sb[sl_r, sl_c])
            n_cols = D.shape[1] // 8
            n_blocks = (D.shape[0] // 8) * n_cols

            # Same blocks-per-bit rule as the embedder (including its fallback).
            rep = repeat if repeat > 0 else max(1, n_blocks // n_bits)
            if n_bits * rep > n_blocks:
                rep = max(1, n_blocks // n_bits)

            # Same key/tile/band-dependent block permutation as the embedder.
            seed = (int(secret_key) ^ (int(tile_r) << 8) ^ (int(tile_c) << 4) ^ BAND_TAG.get(band, 99)) & 0xFFFFFFFF
            rng = np.random.RandomState(seed)
            idx = np.arange(n_blocks)
            rng.shuffle(idx)

            pos = 0
            for bit_idx in range(n_bits):
                if pos >= n_blocks:
                    # Tile exhausted: remaining bits get no vote from it.
                    break
                score = 0.0
                for k in idx[pos:pos + rep]:
                    r, c = divmod(int(k), n_cols)
                    i, j = r * 8, c * 8
                    score += svd_qim_llr_for_block(
                        D[i:i+8, j:j+8], quant_step,
                        svd_index=svd_index, patch_rows=pr, patch_cols=pc
                    )
                pos += rep
                soft_sum[bit_idx] += score

    # Soft decision across all bands/tiles:
    out_bits = (soft_sum >= 0).astype(np.uint8).tolist()
    return out_bits

# ===============================
# Multi-level variants (LEVEL>=1)
# ===============================

def bands_all_levels(coeffs, use_LL=True, use_HV=True, use_D=False):
    """
    List the subbands of a multi-level coefficient list as (level, name, array).

    With ``L = len(coeffs) - 1``, ``coeffs[0]`` is reported as ``(L, 'LL', ...)``
    when ``use_LL`` is set. The detail tuples follow from the deepest level
    (``coeffs[1]``, level L) to the finest (``coeffs[L]``, level 1); each
    contributes 'H' and 'V' when ``use_HV`` is set and 'D' when ``use_D`` is set.
    The arrays are the objects stored in ``coeffs``, not copies.
    """
    deepest = len(coeffs) - 1
    selected = [(deepest, 'LL', coeffs[0])] if use_LL else []
    for position in range(1, deepest + 1):
        level = deepest - position + 1
        horizontal, vertical, diagonal = coeffs[position]
        if use_HV:
            selected += [(level, 'H', horizontal), (level, 'V', vertical)]
        if use_D:
            selected.append((level, 'D', diagonal))
    return selected

def set_band_at_level(coeffs, level_index, band_name, new_arr):
    """
    Store ``new_arr`` as one band of a multi-level coefficient list, in place.

    'LL' overwrites ``coeffs[0]``. For 'H', 'V' or 'D' the detail tuple of
    level ``level_index`` (stored at ``coeffs[L - level_index + 1]`` with
    ``L = len(coeffs) - 1``) is rebuilt with the matching member replaced.
    Other band names change nothing.
    """
    if band_name == 'LL':
        coeffs[0] = new_arr
        return
    slot = (len(coeffs) - 1) - level_index + 1
    horizontal, vertical, diagonal = coeffs[slot]
    members = [horizontal, vertical, diagonal]
    member_index = {'H': 0, 'V': 1, 'D': 2}.get(band_name)
    if member_index is not None:
        members[member_index] = new_arr
        coeffs[slot] = tuple(members)

def embed_bits_y_multilevel_svd_qim(
    y_luma: np.ndarray,
    bits: list,
    secret_key: int,
    quant_step: float,
    svd_index: int = 0,
    svd_patch=((2, 6), (2, 6)),
    wavelet='haar',
    dwt_levels=2,
    tiles=(2, 2),
    repeat=0,
    include_LL=False,
    include_D=False
):
    """
    Embed a bitstream across all levels of a multi-level DWT of Y (SVD-QIM).

    Y is edge-padded on the bottom/right to a multiple of ``max(2**dwt_levels, 8)``
    and decomposed. The H and V bands of every level are always used; the
    deepest LL band and the D bands are added when ``include_LL`` /
    ``include_D`` are set. Each band is tiled, every tile goes through the
    block DCT, and consecutive runs of ``rep`` key-shuffled blocks carry
    bit 0, bit 1, ... until bits or blocks run out.

    ``rep`` is ``repeat`` when positive, otherwise ``max(1, n_blocks // n_bits)``;
    it falls back to the latter whenever ``n_bits * rep`` exceeds the number of
    blocks in the tile.

    Returns the watermarked Y, cropped back to the input's shape.
    """
    # Pad so the transform grid is the same for any image of this size.
    multiple = max(1 << dwt_levels, 8)
    orig_h, orig_w = y_luma.shape
    y_pad = np.pad(y_luma, ((0, -orig_h % multiple), (0, -orig_w % multiple)), mode='edge')

    coeffs = dwt2(y_pad, wavelet, dwt_levels)
    target_bands = bands_all_levels(coeffs, use_LL=include_LL, use_HV=True, use_D=include_D)
    patch_rows, patch_cols = svd_patch
    n_bits = len(bits)
    band_tags = {'LL': 201, 'H': 202, 'V': 203, 'D': 204}

    for level, band_name, subband in target_bands:
        for rows_sl, cols_sl, tile_r, tile_c in compute_tile_slices(*subband.shape, tiles=tiles):
            dct_tile, tile_h, tile_w = dct2d_blocks_8x8(subband[rows_sl, cols_sl])
            blocks_per_row = dct_tile.shape[1] // 8
            n_blocks = (dct_tile.shape[0] // 8) * blocks_per_row

            # Blocks spent on each bit (repeat == 0 means "use as many as fit").
            rep = repeat if repeat > 0 else max(1, n_blocks // n_bits)
            if n_bits * rep > n_blocks:
                rep = max(1, n_blocks // n_bits)

            # Key-, tile-, level- and band-dependent block permutation.
            tag = band_tags[band_name]
            seed = (int(secret_key) ^ (int(tile_r) << 9) ^ (int(tile_c) << 5) ^ (level << 1) ^ tag) & 0xFFFFFFFF
            order = np.arange(n_blocks)
            np.random.RandomState(seed).shuffle(order)

            # The block at shuffled position p carries bit p // rep.
            for slot, block_id in enumerate(order[:n_bits * rep]):
                block_row, block_col = divmod(int(block_id), blocks_per_row)
                top, left = block_row * 8, block_col * 8
                embed_bit_in_block_svd_qim(
                    dct_tile[top:top + 8, left:left + 8], int(bits[slot // rep]), quant_step,
                    svd_index=svd_index, patch_rows=patch_rows, patch_cols=patch_cols
                )

            subband[rows_sl, cols_sl] = idct2d_blocks_8x8(dct_tile, tile_h, tile_w)
        set_band_at_level(coeffs, level, band_name, subband)

    watermarked = idwt2(coeffs, wavelet)
    return watermarked[:orig_h, :orig_w]

def extract_bits_y_multilevel_svd_qim(
    y_luma: np.ndarray,
    n_bits: int,
    secret_key: int,
    quant_step: float,
    svd_index: int = 0,
    svd_patch=((2, 6), (2, 6)),
    wavelet='haar',
    dwt_levels=2,
    tiles=(2, 2),
    include_LL=False,
    include_D=False,
    repeat=0
):
    """
    Multi-level extraction with soft voting across all used bands/tiles/levels.

    Y is padded and decomposed exactly as in ``embed_bits_y_multilevel_svd_qim``;
    the block shuffle and the blocks-per-bit count are recomputed the same way.
    Per-block soft scores are summed per bit and the sign of each sum decides
    the bit (a sum of exactly 0 decodes as 1).

    Parameters
    ----------
    include_LL, include_D : bool
        Select the deepest LL band and the D bands in addition to H/V, which
        are always read.
    repeat : int
        Must equal the ``repeat`` value used at embed time. The default 0
        matches the embedder's default ("fill all blocks"). Without mirroring
        a non-zero embed-time ``repeat`` the bit-to-block mapping would not
        line up.

    Returns
    -------
    list of int
        ``n_bits`` values in {0, 1}.

    Raises
    ------
    ValueError
        If ``n_bits`` is smaller than 1.
    """
    if n_bits < 1:
        raise ValueError(f"n_bits must be >= 1, got {n_bits}")

    # Same bottom/right edge padding as the embedder.
    mult = max(1 << dwt_levels, 8)
    H0, W0 = y_luma.shape
    y_pad = np.pad(y_luma, ((0, -H0 % mult), (0, -W0 % mult)), mode='edge')

    coeffs = dwt2(y_pad, wavelet, dwt_levels)
    soft_sum = np.zeros(n_bits, dtype=np.float64)
    target_bands = bands_all_levels(coeffs, use_LL=include_LL, use_HV=True, use_D=include_D)
    pr, pc = svd_patch
    band_tags = {'LL': 201, 'H': 202, 'V': 203, 'D': 204}

    for lvl, bname, sb in target_bands:
        tag = band_tags[bname]
        for sl_r, sl_c, tile_r, tile_c in compute_tile_slices(*sb.shape, tiles=tiles):
            D, _, _ = dct2d_blocks_8x8(sb[sl_r, sl_c])
            n_cols = D.shape[1] // 8
            n_blocks = (D.shape[0] // 8) * n_cols

            # Same blocks-per-bit rule as the embedder (including its fallback).
            rep = repeat if repeat > 0 else max(1, n_blocks // n_bits)
            if n_bits * rep > n_blocks:
                rep = max(1, n_blocks // n_bits)

            # Same key/tile/level/band-dependent block permutation as the embedder.
            seed = (int(secret_key) ^ (int(tile_r) << 9) ^ (int(tile_c) << 5) ^ (lvl << 1) ^ tag) & 0xFFFFFFFF
            rng = np.random.RandomState(seed)
            idx = np.arange(n_blocks)
            rng.shuffle(idx)

            pos = 0
            for bit_idx in range(n_bits):
                if pos >= n_blocks:
                    # Tile exhausted: remaining bits get no vote from it.
                    break
                score = 0.0
                for k in idx[pos:pos + rep]:
                    r, c = divmod(int(k), n_cols)
                    i, j = r * 8, c * 8
                    score += svd_qim_llr_for_block(
                        D[i:i+8, j:j+8], quant_step,
                        svd_index=svd_index, patch_rows=pr, patch_cols=pc
                    )
                pos += rep
                soft_sum[bit_idx] += score

    out_bits = (soft_sum >= 0).astype(np.uint8).tolist()
    return out_bits

# ===============================
# Attack helpers (for testing)
# ===============================

def save_jpeg(input_path: str, out_path: str, quality: int = 75):
    """
    Re-save an image as RGB with the given quality setting.

    The image is written with ``subsampling=0`` and ``optimize=False`` so that
    only ``quality`` varies between runs.
    """
    source = Image.open(input_path).convert("RGB")
    source.save(out_path, quality=quality, subsampling=0, optimize=False)

def save_small_random_cutout(
    input_path: str,
    out_path: str,
    area_ratio: float = 0.01,
    num_patches: int = 1,
    shape: str = "rect",          # 'rect' or 'circle'
    fill: str = "noise",          # 'noise'|'black'|'avg'|'blur'|'inpaint'
    blur_kernel: int = 11,
    seed: int | None = None
):
    """
    Overwrite small random regions of an image and save it at the same size.

    At least one patch is placed (``max(1, int(num_patches))``). Every patch is
    ``sqrt(area_ratio)`` of the image width and height, so its bounding box
    covers roughly ``area_ratio`` of the image. Any ``shape`` other than "rect"
    produces a circle centred in that box with radius ``max(w, h) / 2``.

    ``fill`` selects the replacement content:
    - "noise":   uniform random bytes
    - "black":   zeros
    - "avg":     the per-channel mean of the current image
    - "blur":    a Gaussian-blurred copy (kernel forced to an odd size)
    - "inpaint": OpenCV Telea inpainting

    Raises ValueError for any other ``fill``. Positions and noise come from
    ``np.random.default_rng(seed)``.
    """
    arr = np.array(Image.open(input_path).convert("RGB"), dtype=np.uint8)
    H, W, C = arr.shape
    rng = np.random.default_rng(seed)

    def build_mask(x0, y0, w, h):
        # uint8 mask (255 inside the patch), the format cv2.inpaint is given below.
        mask = np.zeros((H, W), dtype=np.uint8)
        if shape == "rect":
            mask[y0:y0 + h, x0:x0 + w] = 255
        else:
            cy, cx = y0 + h // 2, x0 + w // 2
            radius = int(0.5 * max(w, h))
            yy, xx = np.ogrid[:H, :W]
            mask[(yy - cy) ** 2 + (xx - cx) ** 2 <= radius * radius] = 255
        return mask

    patch_count = max(1, int(num_patches))
    # Patch size ~ sqrt(area_ratio) * (W,H) → area ≈ area_ratio * W * H
    side_fraction = np.sqrt(area_ratio)
    patch_w = max(1, int(round(side_fraction * W)))
    patch_h = max(1, int(round(side_fraction * H)))

    for _ in range(patch_count):
        # Draw order (x, then y, then noise) determines the result for a given seed.
        x0 = int(rng.integers(0, max(1, W - patch_w)))
        y0 = int(rng.integers(0, max(1, H - patch_h)))
        mask = build_mask(x0, y0, patch_w, patch_h)
        hole = mask == 255

        if fill == "noise":
            noise = rng.integers(0, 256, size=(H, W, C), dtype=np.uint8)
            arr[hole] = noise[hole]
        elif fill == "black":
            arr[hole] = 0
        elif fill == "avg":
            arr[hole] = arr.reshape(-1, C).mean(axis=0).astype(np.uint8)
        elif fill == "blur":
            if blur_kernel % 2 == 1:
                kernel = blur_kernel
            else:
                kernel = blur_kernel + 1
            arr[hole] = cv2.GaussianBlur(arr, (kernel, kernel), 0)[hole]
        elif fill == "inpaint":
            as_bgr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
            as_bgr = cv2.inpaint(as_bgr, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
            arr = cv2.cvtColor(as_bgr, cv2.COLOR_BGR2RGB)
        else:
            raise ValueError("fill must be one of {'noise','black','avg','blur','inpaint'}")

    Image.fromarray(arr).save(out_path)

def _unsharp(y: np.ndarray, k: float = 0.7, radius: int = 3) -> np.ndarray:
    """
    Sharpen ``y`` with an unsharp mask and clip the result to [0, 255].

    The output is ``(1 + k) * y - k * blur(y)``, where the blur is a Gaussian
    whose kernel size is ``radius`` bumped to the next odd number if needed.
    """
    if radius % 2 == 1:
        kernel = radius
    else:
        kernel = radius + 1
    blurred = cv2.GaussianBlur(y, (kernel, kernel), 0)
    sharpened = cv2.addWeighted(y, 1.0 + k, blurred, -k, 0.0)
    return np.clip(sharpened, 0, 255)

def _try_extract(y_img, n_bits, secret_key, base_step, extractor_kwargs):
    """
    Run the multi-level extractor once and score the decoded bits.

    Returns ``(bits, score)`` where ``score`` is (#ones - #zeros) of the
    decoded bit list.
    """
    bits = extract_bits_y_multilevel_svd_qim(
        y_img, n_bits, secret_key, base_step, **extractor_kwargs
    )
    # NOTE(review): this score only reflects the balance of ones vs zeros in
    # the output, not extraction confidence — confirm this is the intended
    # ranking criterion for the callers.
    ones = sum(1 for b in bits if b)
    zeros = len(bits) - ones
    return bits, ones - zeros


# --- Resize (down→up) resampling attack, size restored to original ---
def save_resize_attack(
    input_path: str,
    out_path: str,
    scale: float = 0.75,
    interpolation: str = "bicubic"  # 'nearest'|'bilinear'|'bicubic'|'lanczos'
):
    """
    Resample an image to ``scale`` of its size and back, then save it.

    Both resizes use the same OpenCV interpolation; an unrecognised
    ``interpolation`` name falls back to bicubic. The saved image has the
    original dimensions.

    Raises FileNotFoundError when OpenCV cannot read ``input_path``.
    """
    flags_by_name = {
        "nearest": cv2.INTER_NEAREST,
        "bilinear": cv2.INTER_LINEAR,
        "bicubic": cv2.INTER_CUBIC,
        "lanczos": cv2.INTER_LANCZOS4,
    }
    if interpolation in flags_by_name:
        interp = flags_by_name[interpolation]
    else:
        interp = cv2.INTER_CUBIC

    original = cv2.imread(input_path, cv2.IMREAD_COLOR)
    if original is None:
        raise FileNotFoundError(input_path)
    height, width = original.shape[:2]

    # Never shrink below one pixel in either direction.
    small_size = (max(1, int(round(width * scale))), max(1, int(round(height * scale))))

    shrunk = cv2.resize(original, small_size, interpolation=interp)
    # Going back to the original size is what leaves the resampling artifacts.
    restored = cv2.resize(shrunk, (width, height), interpolation=interp)
    cv2.imwrite(out_path, restored)

# --- Rotation attack, keeps the original size via warpAffine + border handling ---
def save_rotation_attack(
    input_path: str,
    out_path: str,
    angle_deg: float = 5.0,
    interpolation: str = "bicubic",    # 'nearest'|'bilinear'|'bicubic'|'lanczos'
    border: str = "replicate",         # 'replicate'|'reflect'|'constant_white'|'constant_black'
):
    """
    Rotate an image about its centre by ``angle_deg`` and save it at the same size.

    Unrecognised ``interpolation`` names fall back to bicubic and unrecognised
    ``border`` names fall back to replicate.

    Raises FileNotFoundError when OpenCV cannot read ``input_path``.
    """
    flags_by_name = {
        "nearest": cv2.INTER_NEAREST,
        "bilinear": cv2.INTER_LINEAR,
        "bicubic": cv2.INTER_CUBIC,
        "lanczos": cv2.INTER_LANCZOS4,
    }
    # (border mode, fill value); modes without their own colour are given (0, 0, 0).
    borders_by_name = {
        "replicate": (cv2.BORDER_REPLICATE, (0, 0, 0)),
        "reflect": (cv2.BORDER_REFLECT_101, (0, 0, 0)),
        "constant_white": (cv2.BORDER_CONSTANT, (255, 255, 255)),
        "constant_black": (cv2.BORDER_CONSTANT, (0, 0, 0)),
    }
    interp = flags_by_name.get(interpolation, cv2.INTER_CUBIC)
    border_mode, border_value = borders_by_name.get(border, (cv2.BORDER_REPLICATE, (0, 0, 0)))

    original = cv2.imread(input_path, cv2.IMREAD_COLOR)
    if original is None:
        raise FileNotFoundError(input_path)
    height, width = original.shape[:2]

    centre = (width / 2.0, height / 2.0)
    rotation = cv2.getRotationMatrix2D(centre, angle_deg, 1.0)
    rotated = cv2.warpAffine(
        original, rotation, (width, height),
        flags=interp,
        borderMode=border_mode,
        borderValue=border_value
    )
    cv2.imwrite(out_path, rotated)

# ===============================
# Helper: extraction with small angle / scale search
# ===============================
def extract_with_rigid_search(
    y_luma, n_bits, secret_key, base_step, *,
    target_shape=None,
    angles=range(-7, 8, 1),
    coarse_scales=(1.00, 0.98, 0.96, 0.94, 0.92, 0.90, 0.88, 0.86, 0.84, 0.82, 0.80, 0.78, 0.76, 0.74),
    refine_halfwidth=0.03,    # ±3% around best
    refine_step=0.002,        # 0.2% steps
    try_unsharp=True,
    **extractor_kwargs
):
    """
    Try extraction under a grid of rotations and resamplings; keep the best.

    Steps:
    1) If ``target_shape`` is given and differs from the input shape, resize
       the input to it (bicubic).
    2) Coarse pass: for every angle in ``angles`` and scale in
       ``coarse_scales``, rotate the image about its centre, resample it to
       that scale and back to full size (scale 1.0 skips the resampling), and
       run ``_try_extract`` on the result and, if ``try_unsharp``, on its
       unsharp-masked version.
    3) Refinement: at the best coarse angle, sweep scales in ``refine_step``
       increments over ``[max(0.70, s - refine_halfwidth), min(1.02, s + refine_halfwidth)]``.

    A candidate replaces the current best only when its ``_try_extract`` score
    is strictly higher. Remaining keyword arguments are forwarded to the
    extractor. Returns the bit list of the best candidate.
    """
    # Normalize to expected grid
    if target_shape is not None and y_luma.shape != tuple(target_shape):
        y_luma = cv2.resize(y_luma, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_CUBIC)

    H, W = y_luma.shape

    def rotate(angle):
        matrix = cv2.getRotationMatrix2D((W/2, H/2), angle, 1.0)
        return cv2.warpAffine(y_luma, matrix, (W, H), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)

    def resample(image, scale):
        if abs(scale - 1.0) < 1e-12:
            return image
        shrunk = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
        return cv2.resize(shrunk, (W, H), interpolation=cv2.INTER_CUBIC)

    def candidates(image):
        yield image
        if try_unsharp:
            yield _unsharp(image)

    best_score, best_bits = -1e9, None
    best_angle, best_scale = 0, 1.0

    # Coarse
    for angle in angles:
        rotated = rotate(angle)
        for scale in coarse_scales:
            for candidate in candidates(resample(rotated, scale)):
                bits, score = _try_extract(candidate, n_bits, secret_key, base_step, extractor_kwargs)
                if score > best_score:
                    best_score, best_bits = score, bits
                    best_angle, best_scale = angle, scale

    # Refine around the best scale
    rotated = rotate(best_angle)
    scale_hi = min(1.02, best_scale + refine_halfwidth)
    scale = max(0.70, best_scale - refine_halfwidth)
    while scale <= scale_hi + 1e-12:
        for candidate in candidates(resample(rotated, scale)):
            bits, score = _try_extract(candidate, n_bits, secret_key, base_step, extractor_kwargs)
            if score > best_score:
                best_score, best_bits = score, bits
        scale += refine_step

    return best_bits


# ===============================
# Demo / Main (focus: JPEG compression & small cutout attacks)
# ===============================
# Experiment script: embeds the watermark in memory, then runs extraction on an
# image file that must already exist at WATERMARKED_PATH, followed by four
# attacks (JPEG, small cutouts, resize, rotation) and BER reports for each.
if __name__ == "__main__":
    # ---- Paths ----
    HOST_PATH = "./images/original.jpeg"   # color host, any size (no resize at embed)
    WM_PATH   = "./images/watermark.jpg"      # watermark will be binarized to 32×32

    os.makedirs("./result", exist_ok=True)
    os.makedirs("./images", exist_ok=True)

    # ---- Parameters ----
    secret_key     = 1234567890
    quant_step     = 90.0             # QIM step for σ_k (tune vs PSNR/robustness)
    quant_step_LL = 45.0              # QIM step of the second embedding pass below
    include_LL    = True           # whether to use LL band (more robust, less imperceptible)
    wavelet_name   = 'haar'
    dwt_levels     = 4
    watermark_size = 32
    tile_grid      = (2, 2)

    # SVD settings
    svd_index      = 0                # which σ_k to quantize (0 = largest)
    svd_patch_rc   = ((2, 6), (2, 6)) # mid-band patch inside each 8×8 DCT block

    # ---- Load host (Y, Cb, Cr) ----
    y_luma, cb_chroma, cr_chroma, _ = load_rgb_keep_size(HOST_PATH)

    # ---- Prepare watermark bits ----
    # Resize to watermark_size×watermark_size, threshold at 127, flatten row-major.
    wm_gray = Image.open(WM_PATH).convert('L').resize((watermark_size, watermark_size), Image.LANCZOS)
    wm_bin  = (np.array(wm_gray, dtype=np.uint8) > 127).astype(np.uint8)
    payload_bits = wm_bin.reshape(-1).tolist()

    # ---- Embed (multi-level, details + LL) ----
    y_wm = embed_bits_y_multilevel_svd_qim(
        y_luma, payload_bits,
        secret_key=secret_key,
        quant_step=quant_step,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=tile_grid,
        repeat=0,
        include_LL=include_LL,   # now True
        include_D=False
    )

    # One extra light pass only into deepest LL
    # NOTE(review): embed_bits_y_multilevel_svd_qim always processes the H/V
    # bands (it passes use_HV=True), so this pass also re-quantizes H/V blocks,
    # not only LL. Nothing in this script extracts with the 0x55AA key.
    y_wm = embed_bits_y_multilevel_svd_qim(
        y_wm, payload_bits,
        secret_key=secret_key ^ 0x55AA,   # different key
        quant_step=quant_step_LL,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=(1, 1),
        repeat=0,
        include_LL=True,
        include_D=False
)


    WATERMARKED_PATH = "./images/atkpic.jpg"
    # NOTE(review): the save below is disabled, so y_wm is never written to
    # disk; everything from here on reads whatever file already exists at
    # WATERMARKED_PATH.
    # save_rgb_from_y(y_wm, cb_chroma, cr_chroma, WATERMARKED_PATH)

    # ---- Extract (clean) ----
    y2, _, _, _ = load_rgb_keep_size(WATERMARKED_PATH)

    # NOTE(review): embedding above used include_LL=True, while this and the
    # JPEG / cutout extractions pass include_LL=False (the resize and rotation
    # extractions pass include_LL=include_LL) — confirm the mismatch is intended.
    bits_clean = extract_bits_y_multilevel_svd_qim(
        y2, watermark_size * watermark_size,
        secret_key=secret_key,
        quant_step=quant_step,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=tile_grid,
        include_LL=False,
        include_D=False
    )
    # Bits back to a watermark_size×watermark_size image (0/255) for inspection.
    wm_rec = np.array(bits_clean, dtype=np.uint8).reshape(watermark_size, watermark_size)
    Image.fromarray((wm_rec * 255).astype(np.uint8)).save("./result/Extracted_Clean_MultiLevel.png")



    # --- JPEG compression attack ---
    ATTACK_JPEG = "./images/attack_q75-Multilevel.jpg"
    save_jpeg(WATERMARKED_PATH, ATTACK_JPEG, quality=75)
    y_jpeg, _, _, _ = load_rgb_keep_size(ATTACK_JPEG)

    bits_jpeg = extract_bits_y_multilevel_svd_qim(
        y_jpeg, watermark_size * watermark_size,
        secret_key=secret_key,
        quant_step=quant_step,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=tile_grid,
        include_LL=False,
        include_D=False
    )
    wm_jpeg = np.array(bits_jpeg, dtype=np.uint8).reshape(watermark_size, watermark_size)
    Image.fromarray((wm_jpeg * 255).astype(np.uint8)).save("./result/Extracted_JPEG_q75-Multilevel.png")

    # --- Small random cutout attack (looks like “tiny noise specks”) ---
    ATTACK_SMALL = "./images/attack_small_cutout.png"
    save_small_random_cutout(
        WATERMARKED_PATH, ATTACK_SMALL,
        area_ratio=0.001,   # ~0.1% of image area per patch
        num_patches=50,     # many tiny specks
        shape="rect",
        fill="noise",
        seed=secret_key     # fixed seed: the same patches are produced on every run
    )
    y_sc, _, _, _ = load_rgb_keep_size(ATTACK_SMALL)

    bits_sc = extract_bits_y_multilevel_svd_qim(
        y_sc, watermark_size * watermark_size,
        secret_key=secret_key,
        quant_step=quant_step,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=tile_grid,
        include_LL=False,
        include_D=False
    )
    wm_sc = np.array(bits_sc, dtype=np.uint8).reshape(watermark_size, watermark_size)
    Image.fromarray((wm_sc * 255).astype(np.uint8)).save("./result/Extracted_Small_Cutout.png")

    # --- Resize (downsample→upsample) attack ---
    ATTACK_RESIZE = "./images/attack_resize_0p75.png"
    save_resize_attack(WATERMARKED_PATH, ATTACK_RESIZE, scale=0.75, interpolation="bicubic")
    y_rs, _, _, _ = load_rgb_keep_size(ATTACK_RESIZE)

    # After clean extract (y2 is the watermarked image's Y)
    orig_shape = y2.shape

    # The rigid search tries several rotations/scales and keeps the candidate
    # with the highest _try_extract score.
    bits_rs = extract_with_rigid_search(
        y_rs, watermark_size * watermark_size,
        secret_key, quant_step,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=tile_grid,
        include_LL=include_LL,
        include_D=False,
        target_shape=orig_shape        # NEW: force the attacked image back to original grid
    )

    wm_rs = np.array(bits_rs, dtype=np.uint8).reshape(watermark_size, watermark_size)
    Image.fromarray((wm_rs * 255).astype(np.uint8)).save("./result/Extracted_Resize_0p75.png")

    # --- Rotation attack (keep original size) ---
    ATTACK_ROT = "./images/attack_rotate_5deg.png"
    save_rotation_attack(WATERMARKED_PATH, ATTACK_ROT, angle_deg=5.0, interpolation="bicubic", border="replicate")
    y_rot, _, _, _ = load_rgb_keep_size(ATTACK_ROT)

    bits_rot = extract_with_rigid_search(
        y_rot, watermark_size * watermark_size,
        secret_key, quant_step,
        svd_index=svd_index,
        svd_patch=svd_patch_rc,
        wavelet=wavelet_name,
        dwt_levels=dwt_levels,
        tiles=tile_grid,
        include_LL=include_LL,
        include_D=False,
        target_shape=orig_shape        # same size anyway, but keeps logic uniform
    )

    wm_rot = np.array(bits_rot, dtype=np.uint8).reshape(watermark_size, watermark_size)
    Image.fromarray((wm_rot * 255).astype(np.uint8)).save("./result/Extracted_Rotate_5deg.png")


    # ---- Metrics ----
    # NOTE(review): both images are loaded only for the PSNR computation, which
    # is currently commented out, so these two reads have no effect on the output.
    host_bgr = cv2.imread(HOST_PATH)
    watermarked_bgr = cv2.imread(WATERMARKED_PATH)
    # PSNR = psnr(host_bgr, watermarked_bgr)
    # print(f"PSNR (host vs watermarked): {PSNR:.4f} dB")

    # Every BER below compares an extracted bitmap with the binarized watermark wm_bin.
    ber_clean, stats_clean = bit_error_rate(wm_bin, wm_rec, return_counts=True)
    print("\nClean extraction vs original")
    print(f"BER: {ber_clean:.4f} | Acc: {stats_clean['accuracy']:.4f} | Errors: {stats_clean['errors']}/{stats_clean['total']}")

    ber_jpeg, stats_jpeg = bit_error_rate(wm_bin, wm_jpeg, return_counts=True)
    print("\nJPEG q=75 vs original")
    print(f"BER: {ber_jpeg:.4f} | Acc: {stats_jpeg['accuracy']:.4f} | Errors: {stats_jpeg['errors']}/{stats_jpeg['total']}")

    ber_sc, stats_sc = bit_error_rate(wm_bin, wm_sc, return_counts=True)
    print("\nSmall cutout vs original")
    print(f"BER: {ber_sc:.4f} | Acc: {stats_sc['accuracy']:.4f} | Errors: {stats_sc['errors']}/{stats_sc['total']}")

    ber_rs, stats_rs = bit_error_rate(wm_bin, wm_rs, return_counts=True)
    print("\nResize (0.75x down -> back up) vs original")
    print(f"BER: {ber_rs:.4f} | Acc: {stats_rs['accuracy']:.4f} | Errors: {stats_rs['errors']}/{stats_rs['total']}")

    ber_rot, stats_rot = bit_error_rate(wm_bin, wm_rot, return_counts=True)
    print("\nRotation (+5°) vs original")
    print(f"BER: {ber_rot:.4f} | Acc: {stats_rot['accuracy']:.4f} | Errors: {stats_rot['errors']}/{stats_rot['total']}")

    print("\nMethod: SVD_QIM (Σ-only on mid-band patch of each 8×8 DCT block)")
    print("Watermarked image:", WATERMARKED_PATH)
    print("Extracted images saved in ./result")


Clean extraction vs original
BER: 0.5752 | Acc: 0.4248 | Errors: 589/1024

JPEG q=75 vs original
BER: 0.5742 | Acc: 0.4258 | Errors: 588/1024

Small cutout vs original
BER: 0.5723 | Acc: 0.4277 | Errors: 586/1024

Resize (0.75x down -> back up) vs original
BER: 0.5713 | Acc: 0.4287 | Errors: 585/1024

Rotation (+5°) vs original
BER: 0.5693 | Acc: 0.4307 | Errors: 583/1024

Method: SVD_QIM (Σ-only on mid-band patch of each 8×8 DCT block)
Watermarked image: ./images/atkpic.jpg
Extracted images saved in ./result
