In [3]:
"""
Slide-256 Full-Image Inference (PRNU-CNN / ELA-CNN / CLIP-LogReg + LR Fusion)
-----------------------------------------------------------------------------
目標：用 256×256 覆蓋整張圖（可重疊），對每一塊做推論，輸出：
  1) 每塊的 logit 與機率、座標、預測
  2) 疊加可視化（格子上色 + 半透明熱度覆蓋）
  3) 整張圖的彙總結果（多種匯總策略）

特點：
- 非 .jpg 來源會先做臨時 .jpg（EXIF 方向修正、RGB、4:4:4, Q=95），以利 ELA/CLIP；
  PRNU 直接用 numpy tile 計算，不受 jpg 轉檔影響。
- CLIP 針對每個 tile 直接用 256×256 JPG（4:4:4, Q=95），與 Center-256 規格一致。
- 融合順序固定 [PRNU, ELA, CLIP]，優先用 sklearn LR（fusion_lr.pkl），否則 fallback 權重。

可直接放到 Jupyter；下方 __main__ 給出用法範例。
"""
from __future__ import annotations

import os
import math
import json
import uuid
from io import BytesIO
from pathlib import Path
from typing import Dict, Any, List, Tuple

# CUDA/encoding 建議
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:128")
os.environ.setdefault("PYTHONUTF8","1")
os.environ.setdefault("PYTHONIOENCODING","utf-8")

import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageChops, ImageOps, ImageDraw, ImageFont
from skimage import io as skio
from skimage.util import img_as_float32
from skimage.restoration import denoise_wavelet

# 路徑（照你的專案）
SCRIPT_ROOT     = Path("/home/yaya/ai-detect-proj/Script")
SAVED_MODELS    = SCRIPT_ROOT / "saved_models"

PRNU_MODEL_PATH = SAVED_MODELS / "prnu_fastcnn_best.pt"
ELA_MODEL_PATH  = SAVED_MODELS / "ela_fastcnn_best.pt"

CLIP_LOGREG_PKL = SAVED_MODELS / "clip_logreg_gpu.pkl"    # cuML LogisticRegression（建議）
CLIP_SVM_CUML   = SAVED_MODELS / "clip_svm_gpu.pkl"       # cuML SVM（回退）
CLIP_SVM_TORCH  = SAVED_MODELS / "clip_svm_gpu_torch.pt"  # Torch 線性 SVM（回退）
CLIP_PLATT_PKL  = SAVED_MODELS / "clip_platt.pkl"         #（可選）Platt 標定

FUSER_PKL       = SAVED_MODELS / "fusion_lr.pkl"
FUSER_META      = SAVED_MODELS / "fusion_lr_meta.json"

# 參數
SEED            = 42
TILE            = 256      # tile 尺寸
STRIDE          = 128      # 步長（128 = 50% 重疊）
ELA_QUALITY     = 90
ELA_SCALE       = 15
ELA_FEASZ       = 128
PRNU_MODE       = "soft"
PRNU_WAVELET    = "db8"
PRNU_Q_MODE     = "per_file"
PRNU_Q_PERC     = 0.999
PRNU_Q_SAMPLES  = 4096

FORCE_JPG_NONJPG = True
FORCE_JPG_FOR    = {"ela": True, "clip": True, "prnu": False}  # prnu 不依賴 jpg

JPEG_FORCE_QUALITY     = 95
JPEG_FORCE_SUBSAMPLING = 0     # 4:4:4

CLIP_BACKBONE   = "ViT-L-14"
CLIP_PRETRAINED = {"ViT-L-14":"laion2b_s32b_b82k","ViT-B-32":"laion400m_e32","ViT-L-14-336":"laion2b_s32b_b82k"}.get(CLIP_BACKBONE, "laion2b_s32b_b82k")

# 裝置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========= 小工具 =========
def _safe_logit(p: float, eps: float = 1e-6, clamp: float = 20.0) -> float:
    p = float(np.clip(p, eps, 1.0 - eps))
    z = math.log(p) - math.log(1.0 - p)
    return float(np.clip(z, -clamp, clamp))

def as_jpg_if_needed(p: Path,
                     quality: int = JPEG_FORCE_QUALITY,
                     subsampling: int | str = JPEG_FORCE_SUBSAMPLING) -> Tuple[Path, str | None]:
    p = Path(p)
    if p.suffix.lower() == ".jpg":
        return p, None
    img = Image.open(p)
    try:
        img = ImageOps.exif_transpose(img)
    except Exception:
        pass
    if img.mode != "RGB":
        img = img.convert("RGB")
    tmp = p.with_suffix(f".tmp_infer_{uuid.uuid4().hex[:8]}.jpg")
    buf = BytesIO(); img.save(buf, format="JPEG", quality=int(quality), subsampling=subsampling)
    with open(tmp, "wb") as f:
        f.write(buf.getvalue())
    return tmp, str(tmp)

# ========= PRNU/ELA CNN =========
class DSBlock(nn.Module):
    def __init__(self, c_in, c_out, stride=1):
        super().__init__()
        self.dw = nn.Conv2d(c_in, c_in, 3, stride=stride, padding=1, groups=c_in, bias=False)
        self.bn1= nn.BatchNorm2d(c_in)
        self.pw = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.bn2= nn.BatchNorm2d(c_out)
        self.act = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.act(self.bn1(self.dw(x)))
        x = self.act(self.bn2(self.pw(x)))
        return x

class FastCNN_1ch(nn.Module):
    def __init__(self, base=32, num_classes=2):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(1, base, 3, padding=1, bias=False),
                                  nn.BatchNorm2d(base), nn.ReLU(inplace=True))
        self.stage= nn.Sequential(DSBlock(base,base*2,1), DSBlock(base*2,base*4,2),
                                  DSBlock(base*4,base*4,1), DSBlock(base*4,base*8,2),
                                  DSBlock(base*8,base*8,1))
        self.head = nn.Sequential(nn.Conv2d(base*8, base*8, 1, bias=False),
                                  nn.BatchNorm2d(base*8), nn.ReLU(inplace=True))
        self.pool = nn.AdaptiveAvgPool2d(1); self.fc = nn.Linear(base*8,2)
    def forward(self, x):
        x=self.stem(x); x=self.stage(x); x=self.head(x); x=self.pool(x).flatten(1); return self.fc(x)

from collections import OrderedDict

def _safe_load(path, map_location="cpu"):
    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except TypeError:
        return torch.load(path, map_location=map_location)

def _extract_state_dict(blob):
    if isinstance(blob, dict):
        for k in ("state_dict","model","weights"):
            if k in blob and isinstance(blob[k], dict):
                sd = blob[k]; break
        else:
            sd = {k:v for k,v in blob.items() if torch.is_tensor(v)}
            if not sd: raise RuntimeError("未知 checkpoint 格式")
    elif isinstance(blob, torch.nn.Module):
        sd = blob.state_dict()
    else:
        sd = blob
    if len(sd)>0 and next(iter(sd)).startswith("module."):
        sd = OrderedDict((k[len("module."):], v) for k,v in sd.items())
    return sd

prnu_model = FastCNN_1ch().to(device).eval()
ela_model  = FastCNN_1ch().to(device).eval()
prnu_model.load_state_dict(_extract_state_dict(_safe_load(PRNU_MODEL_PATH, device)))
ela_model .load_state_dict(_extract_state_dict(_safe_load(ELA_MODEL_PATH,  device)))

@torch.no_grad()
def prnu_logit_from_i8(arr_i8: np.ndarray) -> float:
    a = arr_i8
    if a.dtype == np.int8:
        x = a.astype(np.float32) / 127.0
    elif a.dtype == np.uint8:
        x = a.astype(np.float32) / 255.0
    else:
        x = a.astype(np.float32); x = x/255.0 if x.max()>1.5 else x
    x = x - x.mean()
    t = torch.from_numpy(x[None,None,...]).to(device)
    p1 = torch.softmax(prnu_model(t), dim=1)[0,1].item()
    return _safe_logit(p1)

@torch.no_grad()
def ela_logit_from_i8(arr_i8: np.ndarray) -> float:
    a = arr_i8
    if a.dtype == np.int8:
        x = (a.astype(np.float32) + 128.0) / 255.0  # offset128
    elif a.dtype == np.uint8:
        x = a.astype(np.float32) / 255.0
    else:
        x = a.astype(np.float32); x = x/255.0 if x.max()>1.5 else x
    x = x - x.mean()
    t = torch.from_numpy(x[None,None,...]).to(device)
    p1 = torch.softmax(ela_model(t), dim=1)[0,1].item()
    return _safe_logit(p1)

# ========= open_clip + CLIP 分支 =========
_openclip = {"model":None,"pre":None,"dev":"cpu"}

def _cuml_ok():
    try:
        import cuml, cupy  # noqa: F401
        return True
    except Exception:
        return False

@torch.no_grad()
def load_openclip():
    if _openclip["model"] is None:
        import open_clip
        dev = "cuda" if torch.cuda.is_available() else "cpu"
        model, _, pre = open_clip.create_model_and_transforms(CLIP_BACKBONE, pretrained=CLIP_PRETRAINED)
        model = model.to(dev).eval()
        _openclip.update(model=model, pre=pre, dev=dev)
    return _openclip["model"], _openclip["pre"], _openclip["dev"]

class CLIPSingle:
    def __init__(self):
        self.mode=None; self.backend=None; self.model=None; self.D=None; self.platt=None
        self._load()
    def _load(self):
        # cuML LR（優先）
        if CLIP_LOGREG_PKL.exists():
            import joblib
            obj = joblib.load(CLIP_LOGREG_PKL)
            self.model = obj["model"] if isinstance(obj, dict) and "model" in obj else obj
            self.mode, self.backend, self.D = "logreg", "cuml", 1024
            print("✅ CLIP: cuML LogisticRegression（logit）"); return
        # cuML SVM（回退）
        if CLIP_SVM_CUML.exists() and _cuml_ok():
            import joblib
            clf = joblib.load(CLIP_SVM_CUML)
            self.model = clf["model"] if isinstance(clf, dict) and "model" in clf else clf
            self.mode, self.backend, self.D = "svm", "cuml", 1024
            print("⚠️ CLIP: cuML SVM（margin）")
        # Torch 線性 SVM（回退）
        elif CLIP_SVM_TORCH.exists():
            sd = _safe_load(CLIP_SVM_TORCH, "cpu")
            state = sd["state_dict"] if isinstance(sd, dict) and "state_dict" in sd else sd
            self.D = int(sd["D"]) if isinstance(sd, dict) and "D" in sd else state["weight"].shape[1]
            lin = nn.Linear(self.D, 1, bias=True).eval(); lin.load_state_dict(state)
            self.model = lin; self.mode, self.backend = "svm", "torch"
            print(f"⚠️ CLIP: Torch SVM（margin, D={self.D}）")
        else:
            print("ℹ️ 未找到 CLIP 分類器（跳過 CLIP）")
            return
        # 可選：Platt
        if self.mode == "svm" and CLIP_PLATT_PKL.exists():
            import joblib
            try:
                self.platt = joblib.load(CLIP_PLATT_PKL)
                print("✅ 載入 Platt 標定器（SVM → prob）")
            except Exception as e:
                print("⚠️ Platt 載入失敗：", e)

    @torch.no_grad()
    def predict_logit(self, pil_tile_256: Image.Image) -> float | None:
        if self.model is None: return None
        # 直接把 256×256 tile 另存為臨時 JPG，保持與 Center-256 一致
        tmp = Path(f"/tmp/clip_tile_{uuid.uuid4().hex[:8]}.jpg")
        buf = BytesIO(); pil_tile_256.save(buf, format="JPEG", quality=95, subsampling=0)
        with open(tmp, "wb") as f: f.write(buf.getvalue())
        try:
            model, pre, dev = load_openclip()
            img = Image.open(tmp).convert("RGB")
            im  = pre(img).unsqueeze(0).to(dev)
            visual = model.visual
            tokens=None
            def _apply_norm(x):
                if x.ndim==2: x=x.unsqueeze(0)
                if hasattr(visual, "ln_post") and visual.ln_post is not None:
                    x = visual.ln_post(x)
                elif hasattr(visual, "trunk") and hasattr(visual.trunk, "norm") and visual.trunk.norm is not None:
                    x = visual.trunk.norm(x)
                return x.squeeze(0)
            # 嘗試抓 token map
            ok=False
            for try_path in ("trunk.forward_features","visual.forward_features","hook"):
                try:
                    if try_path=="trunk.forward_features" and hasattr(visual,"trunk") and hasattr(visual.trunk,"forward_features"):
                        out = visual.trunk.forward_features(im)
                        if isinstance(out,(tuple,list)): out=out[0]
                        if isinstance(out,dict): out=out.get("x", out.get("tokens", out))
                        if torch.is_tensor(out) and out.ndim==3:
                            tokens=_apply_norm(out.detach()); ok=True; break
                    if try_path=="visual.forward_features" and hasattr(visual,"forward_features"):
                        out = visual.forward_features(im)
                        if isinstance(out,(tuple,list)): out=out[0]
                        if isinstance(out,dict): out=out.get("x", out.get("tokens", out))
                        if torch.is_tensor(out) and out.ndim==3:
                            tokens=_apply_norm(out.detach()); ok=True; break
                    if try_path=="hook":
                        feats={}; h=None
                        try:
                            if hasattr(visual,"trunk") and hasattr(visual.trunk,"blocks") and len(visual.trunk.blocks)>0:
                                target=visual.trunk.blocks[-1]
                                def _hook(_m,_i,o): feats["x"]=o
                                h=target.register_forward_hook(_hook)
                                _ = getattr(visual.trunk, "forward_features", visual.trunk.forward)(im)
                            elif hasattr(visual,"transformer") and hasattr(visual.transformer,"resblocks") and len(visual.transformer.resblocks)>0:
                                target=visual.transformer.resblocks[-1]
                                def _hook(_m,_i,o): feats["x"]=o
                                h=target.register_forward_hook(_hook); _=model.encode_image(im)
                        finally:
                            if h is not None:
                                try: h.remove()
                                except Exception: pass
                        x=feats.get("x",None)
                        if torch.is_tensor(x) and x.ndim==3:
                            tokens=_apply_norm(x.detach()); ok=True; break
                except Exception:
                    pass
            if not ok:
                # 保底：取 ln_post CLS（這裡沒有分類器時不返回）
                return None
            tokens = tokens.float().cpu().numpy().astype(np.float32)
            if tokens.shape[0] <= 1: return None
            pooled = tokens[1:].mean(axis=0)
            pooled /= (np.linalg.norm(pooled)+1e-12)
            # 決策
            if self.backend == "cuml":
                import cupy as cp
                s = self.model.decision_function(cp.asarray(pooled[None,:])).get().astype(np.float32)[0]
                return float(np.clip(s, -20.0, 20.0))
            else:
                with torch.no_grad():
                    s = self.model(torch.from_numpy(pooled[None,:]).float()).squeeze(1).numpy().astype(np.float32)[0]
                if self.mode == "svm" and self.platt is not None:
                    from sklearn.linear_model import LogisticRegression  # noqa
                    import numpy as _np
                    p = float(self.platt.predict_proba(_np.array(s, _np.float32).reshape(1,1))[:,1][0])
                    return _safe_logit(p)
                return float(np.clip(s, -20.0, 20.0))
        finally:
            try: os.unlink(tmp)
            except Exception: pass

clip_single = CLIPSingle()

# ========= 融合器 =========
def read_json_utf8(p: Path):
    with p.open("r", encoding="utf-8") as f:
        return json.load(f)

def load_fuser_and_meta():
    meta=None; kind="dummy"; fuser=None
    try:
        if FUSER_PKL.exists():
            import joblib
            fuser = joblib.load(FUSER_PKL); kind="lr"
    except Exception as e:
        print("⚠️ fusion_lr.pkl 載入失敗：", e)
    if FUSER_META.exists():
        try: meta = read_json_utf8(FUSER_META)
        except Exception as e: print("⚠️ fusion_lr_meta.json 讀取失敗：", e)
    return kind, fuser, meta

fuser_kind, fuser, fmeta = load_fuser_and_meta()

def fuse_logits(logits_dict: Dict[str, float | None]) -> float:
    keys = ["prnu","ela","clip"]
    z = [float(np.nan_to_num(logits_dict.get(k,0.0), nan=0.0, posinf=20.0, neginf=-20.0)) for k in keys]
    X = np.array([z], np.float32)
    if fuser_kind == "lr":
        try:
            proba = fuser.predict_proba(X)[:,1][0]
            return float(np.clip(proba, 1e-6, 1-1e-6))
        except Exception as e:
            print("⚠️ 融合器失敗，fallback：", e)
    if fmeta and "weights_if_no_sklearn" in fmeta:
        w = np.array(fmeta["weights_if_no_sklearn"], np.float32); w = w/(w.sum()+1e-12)
    else:
        w = np.array([1/3,1/3,1/3], np.float32)
    zz = float(np.clip(np.dot(X[0], w), -20.0, 20.0))
    return float(1.0 / (1.0 + np.exp(-zz)))

# ========= 特徵抽取（tile 級） =========
def _to_int8_offset128_from_01(x01: np.ndarray) -> np.ndarray:
    u8 = np.rint(np.clip(x01, 0.0, 1.0) * 255.0).astype(np.uint8)
    return (u8.astype(np.int16) - 128).astype(np.int8)

def ela_i8_from_tile(pil_tile_256: Image.Image) -> np.ndarray:
    buf = BytesIO(); pil_tile_256.save(buf, format="JPEG", quality=int(ELA_QUALITY), subsampling=0, optimize=False)
    buf.seek(0)
    diff = ImageChops.difference(pil_tile_256, Image.open(buf)).point(lambda x: x * ELA_SCALE)
    diff = diff.convert("L").resize((ELA_FEASZ, ELA_FEASZ))
    arr01 = np.asarray(diff, dtype=np.float32) / 255.0
    return _to_int8_offset128_from_01(arr01)

def prnu_i8_from_tile(np_tile_rgb: np.ndarray) -> np.ndarray:
    if np_tile_rgb.ndim == 2:
        np_tile_rgb = np.repeat(np_tile_rgb[...,None], 3, axis=-1)
    gray = np_tile_rgb.mean(axis=2, dtype=np.float32)
    try:
        den = denoise_wavelet(gray, channel_axis=None, mode=PRNU_MODE, wavelet=PRNU_WAVELET, convert2ycbcr=False)
    except TypeError:
        den = denoise_wavelet(gray, multichannel=False, mode=PRNU_MODE, wavelet=PRNU_WAVELET, convert2ycbcr=False)
    residual = gray - den
    residual -= residual.mean()
    if PRNU_Q_MODE == "per_file":
        v = residual.reshape(-1).astype(np.float32, copy=False)
        if v.size > PRNU_Q_SAMPLES:
            rng = np.random.default_rng(SEED)
            idx = rng.integers(0, v.size, size=PRNU_Q_SAMPLES, endpoint=False)
            v = np.abs(v[idx])
        else:
            v = np.abs(v)
        k = int(PRNU_Q_PERC * max(1, v.size-1))
        S = float(max(1e-8, np.partition(v, k)[k]))
    else:
        S = max(1e-6, float(np.std(residual)) * 6.0)
    x = np.clip(residual, -S, S) / S * 127.0
    q = np.rint(x).astype(np.int16)
    q = np.clip(q, -127, 127).astype(np.int8)
    return q

# ========= 覆蓋座標 =========
def make_grid(w: int, h: int, tile: int = TILE, stride: int = STRIDE) -> List[Tuple[int,int,int,int]]:
    xs = list(range(0, max(1, w - tile + 1), stride))
    ys = list(range(0, max(1, h - tile + 1), stride))
    if xs[-1] != w - tile: xs.append(max(0, w - tile))
    if ys[-1] != h - tile: ys.append(max(0, h - tile))
    coords = [(x, y, tile, tile) for y in ys for x in xs]
    return coords

# ========= 可視化 =========
# --- 量文字大小：兼容 Pillow 舊新版本 ---
def _text_size(draw, text, font):
    # Pillow ≥10
    if hasattr(draw, "textbbox"):
        l, t, r, b = draw.textbbox((0, 0), text, font=font)
        return (r - l), (b - t)
    # 後備：有些版本在 font 物件上
    if hasattr(font, "getbbox"):
        l, t, r, b = font.getbbox(text)
        return (r - l), (b - t)
    # 舊版
    if hasattr(font, "getsize"):
        return font.getsize(text)
    # 最後的保底估計
    return (8 * len(text), 12)

def overlay_tiles(base_img: Image.Image,
                  tiles: list[dict],
                  alpha: float = 0.35,
                  draw_frame: bool = True,
                  show_score: bool = True) -> Image.Image:
    W, H = base_img.size

    # 以每像素平均機率做熱度圖
    heat = np.zeros((H, W), dtype=np.float32)
    cnt  = np.zeros((H, W), dtype=np.float32)
    for t in tiles:
        x, y, w, h = t['x'], t['y'], t['w'], t['h']
        p = float(t['prob_fake'])
        heat[y:y+h, x:x+w] += p
        cnt [y:y+h, x:x+w] += 1.0
    cnt[cnt == 0] = 1.0
    heat = heat / cnt
    heat_u8 = np.rint(np.clip(heat, 0.0, 1.0) * 255.0).astype(np.uint8)

    # 轉為紅色熱圖（R=heat, G=0, B=0）
    heat_rgb = np.zeros((H, W, 3), dtype=np.uint8)
    heat_rgb[..., 0] = heat_u8
    base_rgb = base_img.convert('RGB')
    heat_img = Image.fromarray(heat_rgb, mode='RGB')

    # 混合後轉 RGBA，方便畫半透明標籤
    overlay = Image.blend(base_rgb, heat_img, alpha).convert('RGBA')
    draw = ImageDraw.Draw(overlay, 'RGBA')

    # 字型
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
    except Exception:
        font = ImageFont.load_default()

    if draw_frame or show_score:
        for t in tiles:
            x, y, w, h = t['x'], t['y'], t['w'], t['h']
            p  = float(t['prob_fake'])
            lab = int(t['pred_label'])
            color = (255, 0, 0, 255) if lab == 1 else (0, 255, 0, 255)

            if draw_frame:
                draw.rectangle([x, y, x + w, y + h], outline=color, width=2)

            if show_score:
                txt = f"{p:.2f}"
                tw, th = _text_size(draw, txt, font)
                # 黑底半透明文字框
                draw.rectangle([x, y, x + tw + 6, y + th + 4], fill=(0, 0, 0, 127))
                draw.text((x + 3, y + 2), txt, fill=(255, 255, 255, 255), font=font)

    return overlay


# ========= 主流程：整張圖（tile）推論 =========
def infer_image_by_tiles(img_path: str | Path,
                         tile: int = TILE,
                         stride: int = STRIDE,
                         aggregate: str = 'mean_prob',  # 'mean_prob'|'max_prob'|'topk_mean:0.2'|'mean_logit'
                         use_clip: bool = True,
                         save_overlay: bool = True,
                         overlay_alpha: float = 0.35) -> Dict[str, Any]:
    p = Path(img_path)
    tmp_files = []
    # 原圖（僅供 ELA/CLIP 的 jpg 流程需要；PRNU 走 numpy tile）
    if FORCE_JPG_NONJPG:
        p_jpg, t1 = as_jpg_if_needed(p); 
        if t1: tmp_files.append(t1)
    else:
        p_jpg = p

    # 讀圖（給 tile 裁切）
    img_pil = Image.open(p_jpg).convert('RGB')
    W,H = img_pil.size
    coords = make_grid(W, H, tile=tile, stride=stride)

    tiles_out: List[Dict[str, Any]] = []

    # 也準備一份 numpy（PRNU 用），以原始檔路徑讀，避免重壓縮影響
    np_img = np.asarray(Image.open(p).convert('RGB'), dtype=np.uint8)

    for (x,y,w,h) in coords:
        pil_tile = img_pil.crop((x,y,x+w,y+h))
        np_tile  = np.asarray(np_img[y:y+h, x:x+w, :], dtype=np.uint8)

        # 特徵
        prnu_i8 = prnu_i8_from_tile(np_tile)
        ela_i8  = ela_i8_from_tile(pil_tile)

        # 三路 logit
        z_prnu = prnu_logit_from_i8(prnu_i8)
        z_ela  = ela_logit_from_i8(ela_i8)
        z_clip = clip_single.predict_logit(pil_tile) if (use_clip and clip_single.model is not None) else None

        prob_fake = fuse_logits({"prnu": z_prnu, "ela": z_ela, "clip": z_clip})
        pred = int(prob_fake >= 0.5)

        tiles_out.append({
            "x":x, "y":y, "w":w, "h":h,
            "prnu_logit": float(z_prnu),
            "ela_logit":  float(z_ela),
            "clip_logit": (None if z_clip is None else float(z_clip)),
            "prob_fake": float(prob_fake),
            "pred_label": pred,
        })

    # 匯總策略
    probs = np.array([t['prob_fake'] for t in tiles_out], np.float32)
    if aggregate == 'mean_prob':
        whole_prob = float(np.clip(probs.mean() if len(probs)>0 else 0.0, 1e-6, 1-1e-6))
    elif aggregate == 'max_prob':
        whole_prob = float(np.clip(probs.max() if len(probs)>0 else 0.0, 1e-6, 1-1e-6))
    elif aggregate.startswith('topk_mean'):
        try:
            frac = float(aggregate.split(':',1)[1]) if ':' in aggregate else 0.2
        except Exception:
            frac = 0.2
        k = max(1, int(len(probs) * frac))
        idx = np.argsort(-probs)[:k]
        whole_prob = float(np.clip(probs[idx].mean(), 1e-6, 1-1e-6))
    elif aggregate == 'mean_logit':
        lgs = np.array([_safe_logit(float(p)) for p in probs], np.float32)
        lg  = float(np.clip(lgs.mean() if len(lgs)>0 else 0.0, -20.0, 20.0))
        whole_prob = float(1.0/(1.0+np.exp(-lg)))
    else:
        whole_prob = float(np.clip(probs.mean() if len(probs)>0 else 0.0, 1e-6, 1-1e-6))

    whole_pred = int(whole_prob >= 0.5)

    overlay_path = None
    if save_overlay:
        overlay_img = overlay_tiles(Image.open(p).convert('RGB'), tiles_out, alpha=overlay_alpha,
                                    draw_frame=True, show_score=True)
        overlay_path = str(p.with_suffix(".tiles_overlay.png"))
        overlay_img.save(overlay_path)

    out = {
        "path": str(p),
        "image_size": [H, W],
        "tile": tile,
        "stride": stride,
        "aggregate": aggregate,
        "overall": {"prob_fake": float(whole_prob), "pred_label": int(whole_pred)},
        "tiles": tiles_out,
        "overlay_path": overlay_path,
    }

    # 清理臨時檔
    for t in tmp_files:
        try: os.unlink(t)
        except Exception: pass

    return out


✅ CLIP: cuML LogisticRegression（logit）


In [4]:
# ==== HEIC/HEIF/AVIF 支援（註冊外掛 + 強化開圖與轉JPG）====

# 1) 註冊解碼器（可多次呼叫，無副作用）
def _register_heif_avif():
    ok = False
    try:
        import pillow_heif
        pillow_heif.register_heif_opener()
        ok = True
        print("✅ pillow-heif registered")
    except Exception as e:
        print("ℹ️ pillow-heif not available:", e)
    try:
        # 匯入即完成 AVIF 註冊
        import pillow_avif  # noqa: F401
        from pillow_avif import AvifImagePlugin  # noqa: F401
        ok = True or ok
        print("✅ pillow-avif registered")
    except Exception as e:
        print("ℹ️ pillow-avif not available:", e)
    return ok

_register_heif_avif()

# 2) 更穩的開圖（必要時直接用 pillow_heif 讀）
from PIL import Image

def _open_image_any(p: Path) -> Image.Image:
    p = Path(p)
    try:
        img = Image.open(p)
        img.load()  # 立刻讀入，避免延後才報錯
        return img
    except Exception as e:
        # 針對 heif/avif 再試一次（不靠 PIL opener）
        suf = p.suffix.lower()
        if suf in {".heic", ".heif", ".heifs", ".hif", ".avif"}:
            try:
                import pillow_heif
                h = pillow_heif.read_heif(str(p))
                img = Image.frombytes(h.mode, h.size, h.data, "raw")
                return img
            except Exception as e2:
                raise RuntimeError(f"HEIF/AVIF decode failed: {e2}") from e2
        raise

def _to_rgb_no_alpha(img: Image.Image) -> Image.Image:
    # 先套用 EXIF 方向
    try:
        from PIL import ImageOps
        img = ImageOps.exif_transpose(img)
    except Exception:
        pass
    # 去除 alpha（白底貼上）
    if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
        bg = Image.new("RGB", img.size, (255, 255, 255))
        bg.paste(img.convert("RGBA"), mask=img.convert("RGBA").split()[-1])
        return bg
    # 其他色彩空間統一到 RGB
    if img.mode != "RGB":
        return img.convert("RGB")
    return img

# 3) 取代你原本的 as_jpg_if_needed（擴充 .jpeg 與 heic/avif）
def as_jpg_if_needed(p: Path,
                     quality: int = 95,
                     subsampling: int | str = 0) -> Tuple[Path, str | None]:
    """
    - 若原檔已是 .jpg/.jpeg：直接回傳原路徑
    - 否則開圖（含 HEIC/AVIF），做 EXIF 方向、去 alpha、轉 RGB，存臨時 JPG（4:4:4, Q=quality）
    """
    p = Path(p)
    if p.suffix.lower() in {".jpg", ".jpeg"}:
        return p, None

    # 用強化版開圖 + 規範成 RGB 無 alpha
    img = _open_image_any(p)
    img = _to_rgb_no_alpha(img)

    # 輸出臨時 JPG（與你原邏輯一致：4:4:4, Q=95）
    tmp = p.with_suffix(f".tmp_infer_{uuid.uuid4().hex[:8]}.jpg")
    buf = BytesIO()
    img.save(buf, format="JPEG", quality=int(quality), subsampling=subsampling, optimize=False)
    with open(tmp, "wb") as f:
        f.write(buf.getvalue())
    return tmp, str(tmp)

# === 也建議把 PRNU 端的讀取改用 _open_image_any，確保 HEIC 可讀 ===
# 在 infer_image_by_tiles 裡這行：
#   np_img = np.asarray(Image.open(p).convert('RGB'), dtype=np.uint8)
# 換成：
#   np_img = np.asarray(_to_rgb_no_alpha(_open_image_any(p)), dtype=np.uint8)

print("HEIC/AVIF support is ready. 🎉")


✅ pillow-heif registered
✅ pillow-avif registered
HEIC/AVIF support is ready. 🎉


In [5]:
# === Gradio UI with progress (drop-in) ===
import gradio as gr
import numpy as np
from pathlib import Path
from PIL import Image
import tempfile, json, math

def _aggregate_probs(probs: np.ndarray, method: str) -> float:
    probs = np.asarray(probs, np.float32)
    if probs.size == 0: 
        return 1e-6
    if method == "mean_prob":
        p = probs.mean()
    elif method == "max_prob":
        p = probs.max()
    elif method.startswith("topk_mean"):
        try:
            frac = float(method.split(":",1)[1]) if ":" in method else 0.2
        except Exception:
            frac = 0.2
        k = max(1, int(len(probs) * frac))
        idx = np.argsort(-probs)[:k]
        p = probs[idx].mean()
    elif method == "mean_logit":
        # same as你的程式：平均 logit 再 sigmoid
        def _safe_logit(p, eps=1e-6, clamp=20.0):
            p = float(np.clip(p, eps, 1.0 - eps))
            z = math.log(p) - math.log(1.0 - p)
            return float(np.clip(z, -clamp, clamp))
        lgs = np.array([_safe_logit(float(x)) for x in probs], np.float32)
        lg = float(np.clip(lgs.mean(), -20.0, 20.0))
        p = float(1.0/(1.0+np.exp(-lg)))
    else:
        p = probs.mean()
    return float(np.clip(p, 1e-6, 1-1e-6))

def infer_stream(pil_img,
                 tile=256, stride=128, aggregate="topk_mean:0.25",
                 use_clip=True, draw_boxes=True, show_score=True, alpha=0.35,
                 progress=gr.Progress(track_tqdm=True)):
    """
    逐步推論，每 N 個 tile 回傳一次中間結果，帶進度條。
    輸出順序：caption(str), overlay(Image), raw(JSON)
    """
    # 先把上傳圖暫存到檔案，沿用你現有的影像/轉 JPG 流程
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        pil_img.save(f.name)
        p = Path(f.name)

    try:
        base_rgb = Image.open(p).convert("RGB")
        # ELA/CLIP 用 JPG（跟你的 code 一樣）
        if FORCE_JPG_NONJPG:
            p_jpg, _tmp = as_jpg_if_needed(p)
        else:
            p_jpg = p
        img_pil = Image.open(p_jpg).convert("RGB")
        W, H = img_pil.size

        coords = make_grid(W, H, tile=int(tile), stride=int(stride))
        n = len(coords)
        tiles_out = []

        # PRNU 用原始圖
        np_img = np.asarray(base_rgb, dtype=np.uint8)

        # 每多少步更新一次 UI（大約 20 段）
        every = max(1, n // 20)

        progress(0.0, desc=f"準備中… {W}×{H}, tiles={n}")
        for i, (x,y,w,h) in enumerate(coords):
            # 取 tile
            pil_tile = img_pil.crop((x,y,x+w,y+h))
            np_tile  = np.asarray(np_img[y:y+h, x:x+w, :], dtype=np.uint8)

            # 特徵 → 子模型 logit
            prnu_i8 = prnu_i8_from_tile(np_tile)
            ela_i8  = ela_i8_from_tile(pil_tile)

            z_prnu = prnu_logit_from_i8(prnu_i8)
            z_ela  = ela_logit_from_i8(ela_i8)
            z_clip = clip_single.predict_logit(pil_tile) if (use_clip and clip_single.model is not None) else None

            prob_fake = fuse_logits({"prnu": z_prnu, "ela": z_ela, "clip": z_clip})
            pred = int(prob_fake >= 0.5)

            tiles_out.append({
                "x":x, "y":y, "w":w, "h":h,
                "prnu_logit": float(z_prnu),
                "ela_logit":  float(z_ela),
                "clip_logit": (None if z_clip is None else float(z_clip)),
                "prob_fake": float(prob_fake),
                "pred_label": pred,
            })

            # 進度與中途回傳
            if (i+1) % every == 0 or (i+1) == n:
                progress((i+1)/n, desc=f"推論中… ({i+1}/{n})")
                probs = np.array([t['prob_fake'] for t in tiles_out], np.float32)
                whole_prob = _aggregate_probs(probs, aggregate)
                whole_pred = int(whole_prob >= 0.5)

                # 即時熱圖（用目前累積的 tiles）
                overlay_img = overlay_tiles(base_rgb, tiles_out, alpha=float(alpha),
                                            draw_frame=bool(draw_boxes), show_score=bool(show_score))

                out = {
                    "path": str(p),
                    "image_size": [H, W],
                    "tile": int(tile),
                    "stride": int(stride),
                    "aggregate": aggregate,
                    "overall": {"prob_fake": float(whole_prob), "pred_label": int(whole_pred)},
                    "tiles_done": len(tiles_out),
                    "tiles_total": n,
                    "tiles": tiles_out,   # 注意：這是累積中的 tiles（最後一次即是完整結果）
                }
                caption = f"pred={whole_pred}  prob_fake={whole_prob:.3f}  ({i+1}/{n})"
                yield caption, overlay_img, out

    finally:
        try: p.unlink()
        except Exception: pass


# ===== Gradio 介面 =====
with gr.Blocks(title="AI Image Detector (CUDA, with progress)") as demo:
    gr.Markdown("### Slide-256 Full-Image Inference — PRNU / ELA / CLIP + LR Fusion（含進度條）")
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type="pil", label="上傳影像")
            with gr.Row():
                tile   = gr.Number(value=256, precision=0, label="Tile size")
                stride = gr.Number(value=128, precision=0, label="Stride")
            aggregate = gr.Dropdown(choices=["mean_prob","max_prob","topk_mean:0.2","topk_mean:0.25","mean_logit"],
                                    value="topk_mean:0.25", label="Aggregator")
            use_clip   = gr.Checkbox(value=True,  label="Use CLIP head")
            draw_boxes = gr.Checkbox(value=True,  label="Draw tile boxes")
            show_score = gr.Checkbox(value=True,  label="Show tile score")
            alpha      = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Overlay alpha")
            run = gr.Button("開始推論", variant="primary")
        with gr.Column(scale=1):
            caption = gr.Label(label="整體預測")
            overlay = gr.Image(label="即時疊加視覺化（會逐步更新）")
            raw     = gr.JSON(label="Raw 輸出（最後一次為完整結果）")

    run.click(
        infer_stream,
        inputs=[inp, tile, stride, aggregate, use_clip, draw_boxes, show_score, alpha],
        outputs=[caption, overlay, raw]
    )

# 建議在 notebook 內用 queue() 取得平滑進度條
demo.queue(max_size=8).launch(share=False, inline=True)

* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.




In [6]:
# === 事後多標準彙總：不重跑模型 ===
import gradio as gr
import numpy as np
from PIL import Image

# 同你之前小工具
def _aggregate_probs(probs: np.ndarray, method: str) -> float:
    probs = np.asarray(probs, np.float32)
    if probs.size == 0: 
        return 1e-6
    if method == "mean_prob":
        p = probs.mean()
    elif method == "max_prob":
        p = probs.max()
    elif method.startswith("topk_mean"):
        try:
            frac = float(method.split(":",1)[1]) if ":" in method else 0.5
        except Exception:
            frac = 0.5
        k = max(1, int(len(probs) * frac))
        idx = np.argsort(-probs)[:k]
        p = probs[idx].mean()
    elif method == "mean_logit":
        def _safe_logit(p, eps=1e-6, clamp=20.0):
            p = float(np.clip(p, eps, 1.0 - eps))
            z = np.log(p) - np.log(1.0 - p)
            return float(np.clip(z, -clamp, clamp))
        lgs = np.array([_safe_logit(float(x)) for x in probs], np.float32)
        lg = float(np.clip(lgs.mean(), -20.0, 20.0))
        p = float(1.0/(1.0+np.exp(-lg)))
    else:
        p = probs.mean()
    return float(np.clip(p, 1e-6, 1-1e-6))

# 1) 跑一次完整推論（不需要 aggregate 入參）
def run_once(img, tile, stride, use_clip, alpha):
    res = infer_image_by_tiles(
        img, tile=int(tile), stride=int(stride),
        aggregate='mean_prob',           # 這裡只是占位，不影響後面多標準計算
        use_clip=bool(use_clip),
        save_overlay=True, overlay_alpha=float(alpha)
    )
    overlay = None
    if res.get("overlay_path"):
        overlay = Image.open(res["overlay_path"]).convert("RGB")
    # 提示：推論已完成，請在右側勾選彙總方式
    caption = "推論完成，請在右側勾選一個或多個彙總方式（不會重跑）"
    return caption, overlay, res

# 2) 從 raw 結果（tiles）事後計算多種彙總
def reaggregate(raw_result: dict, methods: list[str]):
    if not raw_result or "tiles" not in raw_result:
        return "尚未有推論結果", []
    probs = np.array([t['prob_fake'] for t in raw_result['tiles']], np.float32)
    table = []
    best = ("", -1.0, 0)
    for m in (methods or []):
        p = _aggregate_probs(probs, m)
        y = int(p >= 0.5)
        table.append([m, round(float(p), 6), y])
        if p > best[1]:
            best = (m, p, y)
    if not methods:
        return "請至少選擇一個彙總方式", []
    msg = f"最佳：{best[0]}  prob={best[1]:.4f}  pred={best[2]}"
    return msg, table

# === Gradio 介面（精簡、可直接用手機）===
with gr.Blocks(title="AI Image Detector — 事後多標準彙總", css="""
  .gradio-container {max-width: 980px !important;}
""") as demo2:
    gr.Markdown("### Slide-256 — 先推論，再多標準彙總（不重跑模型）")
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type="filepath", label="上傳或拍照",
                           sources=["upload","webcam","clipboard"], height=320)
            with gr.Row():
                tile   = gr.Number(value=256, precision=0, label="Tile")
                stride = gr.Number(value=128, precision=0, label="Stride")
            use_clip = gr.Checkbox(value=True, label="Use CLIP head")
            alpha    = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Overlay alpha")
            run_btn  = gr.Button("開始推論", variant="primary")
        with gr.Column(scale=1):
            caption = gr.Label(label="整體資訊 / 彙總訊息")
            overlay = gr.Image(label="Overlay（與彙總無關，僅依 tile 機率）")
            # 顯示原始結果（含 tiles），方便除錯 / 也提供 reaggregate 的輸入
            raw     = gr.JSON(label="Raw 結果（tiles）")
            # 多選彙總方式（事後算）
            methods = gr.CheckboxGroup(
                choices=["mean_prob","max_prob","topk_mean:0.5","topk_mean:0.25","mean_logit"],
                label="選擇一個或多個彙總方式（不重跑）"
            )
            recompute = gr.Button("重新計算彙總")
            table = gr.Dataframe(
                headers=["method","prob_fake","pred_label"],
                datatype=["str","number","number"],
                interactive=False,
                label="各彙總方式結果"
            )

    # 跑一次推論，存 raw 結果；不需要 aggregate 入參
    run_btn.click(
        run_once,
        inputs=[inp, tile, stride, use_clip, alpha],
        outputs=[caption, overlay, raw]
    )

    # 事後多標準彙總：直接讀取 raw（tiles），不重跑模型
    recompute.click(
        reaggregate,
        inputs=[raw, methods],
        outputs=[caption, table]
    )

# 用 queue 可保持順暢；Gradio 4.x 不要 concurrency_count
demo2.queue(max_size=8).launch(share=False, inline=True)


* Running on local URL:  http://127.0.0.1:7861
* To create a public link, set `share=True` in `launch()`.




In [None]:

# # ========= 範例 =========
# if __name__ == "__main__":
#     IMG_PATH = "/home/yaya/ai-detect-proj/test_img/real/Wtc-photo.jpg"
#     result = infer_image_by_tiles(IMG_PATH, tile=256, stride=128, aggregate='mean_logit', use_clip=True)
#     print(json.dumps(result, ensure_ascii=False, indent=2))
#     if result.get('overlay_path'):
#         print("Overlay saved to:", result['overlay_path'])


  heat_img = Image.fromarray(heat_rgb, mode='RGB')
  heat_img = Image.fromarray(heat_rgb, mode='RGB')
  heat_img = Image.fromarray(heat_rgb, mode='RGB')
  heat_img = Image.fromarray(heat_rgb, mode='RGB')
  heat_img = Image.fromarray(heat_rgb, mode='RGB')
  heat_img = Image.fromarray(heat_rgb, mode='RGB')
  heat_img = Image.fromarray(heat_rgb, mode='RGB')
  heat_img = Image.fromarray(heat_rgb, mode='RGB')
  heat_img = Image.fromarray(heat_rgb, mode='RGB')
  heat_img = Image.fromarray(heat_rgb, mode='RGB')
