# Define config and imports

In [None]:
from transformers import AutoImageProcessor, AutoModel
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from pathlib import Path
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import torch
import json
import math
import cv2
import os
import random
from sklearn.model_selection import train_test_split
from scipy import stats
from skimage import exposure, restoration
from skimage.feature import local_binary_pattern
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Silence every warning for the rest of the session to keep notebook output short.
# NOTE(review): this also hides deprecation and numerical warnings — confirm that is intended.
warnings.filterwarnings("ignore")
# Plot appearance: matplotlib's "seaborn-v0_8" style sheet and seaborn's "husl" palette.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# %% [markdown]
# ## 2. Configuration Class (Enhanced)

class CFG:
    """Central, class-level configuration for the notebook.

    All values are plain class attributes and are read as ``CFG.<name>``;
    the class is never instantiated in the visible code.
    """

    # Paths (Kaggle dataset layout)
    train_images_path = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_images"
    test_images_path = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/test_images"
    train_masks_path = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_masks"
    sample_sub_path = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/sample_submission.csv"

    # Model paths
    dino_path = "/kaggle/input/dinov2/pytorch/base/1"
    # NOTE(review): not referenced in the visible part of the file.
    dino_weights_path = "/kaggle/input/m/ravaghi/dinov2/pytorch/base/1/model.pt"

    # Device configuration (resolved once, when the class body is executed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Image processing: square side length used for images/masks, and loader batch size
    img_size = 512
    batch_size = 4

    # Inference settings (test-time augmentation via flips)
    use_tta = True
    tta_flip_horizontal = True
    tta_flip_vertical = True

    # Post-processing hyperparameters (to be tuned)
    alpha_grad = 0.35            # weight of the gradient-magnitude term in enhanced_postprocess
    threshold_multiplier = 0.3   # threshold = mean + multiplier * std of the enhanced map
    morph_close_kernel = 5       # side of the square kernel for morphological closing
    morph_open_kernel = 3        # side of the square kernel for morphological opening

    # Hyperparameter tuning ranges
    # NOTE(review): these ranges are not consumed anywhere in the visible code.
    alpha_grad_range = [0.2, 0.3, 0.35, 0.4, 0.5]
    threshold_range = [0.2, 0.3, 0.4, 0.5]
    kernel_range = [3, 5, 7]

    # EDA settings
    eda_sample_size = 100
    plot_dpi = 100

    # Random seed for reproducibility
    seed = 42

    @staticmethod
    def set_seed(seed=None):
        """Seed the Python, NumPy and PyTorch random number generators.

        Args:
            seed: Seed to use. ``None`` (the default) falls back to
                ``CFG.seed``, which is the original behaviour.
        """
        if seed is None:
            seed = CFG.seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)


# Seed all RNGs once, as soon as the configuration is defined.
CFG.set_seed()

# Model Architecture (pretrained DINO encoder + small decoder)

In [None]:
class DinoTinyDecoder(nn.Module):
    """Small convolutional head that turns a feature map into per-pixel logits.

    The feature map is first resized to the requested output size and then
    passed through two 3x3 conv + ReLU stages and a final 1x1 projection.
    """

    def __init__(self, in_ch=768, out_ch=1):
        super().__init__()
        # Layer order (and therefore the ``net.<idx>`` state-dict keys) must
        # stay fixed so that existing checkpoints keep loading.
        stages = [
            nn.Conv2d(in_ch, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, out_ch, kernel_size=1),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, f, size):
        """Bilinearly resize ``f`` to ``size`` and return the head's logits."""
        resized = F.interpolate(f, size=size, mode="bilinear", align_corners=False)
        return self.net(resized)


class DinoSegmenter(nn.Module):
    """Frozen encoder followed by a trainable ``DinoTinyDecoder`` head.

    Args:
        encoder: Module whose output exposes ``last_hidden_state`` with shape
            ``(B, N, C)``; the first token is dropped and the remaining
            ``N - 1`` tokens are reshaped into a square grid.
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
            whose result supports ``.to(device)`` and ``**`` unpacking.
            NOTE(review): any resizing/cropping/normalisation it applies is not
            visible here — confirm against the processor configuration.
    """

    def __init__(self, encoder, processor):
        super().__init__()
        self.encoder, self.processor = encoder, processor

        # The encoder is a fixed feature extractor: no gradients, inference mode.
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()

        self.seg_head = DinoTinyDecoder(768, 1)

    def train(self, mode=True):
        """Switch the head's mode while keeping the frozen encoder in eval mode.

        Without this override ``model.train()`` would also put the encoder into
        training mode, although its weights are frozen and it is always run
        under ``torch.no_grad()``.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward_features(self, x):
        """Return the encoder's patch tokens as a ``(B, C, s, s)`` feature map.

        Args:
            x: Image batch ``(B, 3, H, W)``; values are scaled by 255 and
               clamped to ``[0, 255]``, i.e. expected in ``[0, 1]``.

        Raises:
            ValueError: If the number of non-first tokens is not a perfect square.
        """
        # Round before casting: a bare ``.byte()`` truncates, so a value that
        # lands just below an integer after float arithmetic would lose one level.
        imgs = (x * 255).round().clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy()
        inputs = self.processor(images=list(imgs), return_tensors="pt").to(x.device)

        with torch.no_grad():
            feats = self.encoder(**inputs).last_hidden_state

        B, N, C = feats.shape
        s = math.isqrt(N - 1)
        if s * s != N - 1:
            raise ValueError(
                f"Cannot arrange {N - 1} patch tokens into a square feature map"
            )

        # Drop the first token, then (B, N-1, C) -> (B, C, s, s).
        fmap = feats[:, 1:, :].permute(0, 2, 1).reshape(B, C, s, s)
        return fmap

    def forward_seg(self, x):
        """Return segmentation logits of shape ``(B, 1, CFG.img_size, CFG.img_size)``."""
        fmap = self.forward_features(x)
        return self.seg_head(fmap, (CFG.img_size, CFG.img_size))
        

# Define eval and training

Training minimises a weighted sum of BCE and Dice loss; validation reports the mean IoU.

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed on logits.

    Logits are passed through a sigmoid, the Dice coefficient is computed per
    sample and channel over dims ``(2, 3)``, and ``1 - mean(dice)`` is returned.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        # Added to numerator and denominator so empty masks do not divide by zero.
        self.smooth = smooth

    def forward(self, pred, target):
        spatial_dims = (2, 3)
        probs = torch.sigmoid(pred)
        overlap = torch.sum(probs * target, dim=spatial_dims)
        total = torch.sum(probs, dim=spatial_dims) + torch.sum(target, dim=spatial_dims)
        dice_score = (2 * overlap + self.smooth) / (total + self.smooth)
        return 1 - dice_score.mean()

@torch.no_grad()
def eval_decoder(model, val_loader):
    """Return the mean per-sample IoU of thresholded predictions.

    Predictions are ``sigmoid(logits) > 0.5``. A smoothing term of ``1e-6`` is
    added to intersection and union, so a sample with empty prediction and
    empty target scores 1.

    Args:
        model: Object exposing ``forward_seg(images)`` plus the usual
            ``eval()`` / ``train()`` / ``training`` module API.
        val_loader: Iterable of ``(images, masks)`` batches.

    Returns:
        Mean IoU over all samples, or ``nan`` if the loader yields nothing.
    """
    was_training = model.training
    model.eval()
    sample_ious = []

    try:
        for images, masks in val_loader:
            images = images.to(CFG.device)
            masks = masks.to(CFG.device)

            logits = model.forward_seg(images)
            preds = (torch.sigmoid(logits) > 0.5).float()

            inter = (preds * masks).sum(dim=(2, 3))
            union = ((preds + masks) > 0).float().sum(dim=(2, 3))
            iou = (inter + 1e-6) / (union + 1e-6)

            # Collect per-sample values so a short final batch is not over-weighted.
            sample_ious.extend(iou.flatten().tolist())
    finally:
        # Restore whatever mode the caller had, instead of forcing train mode.
        model.train(was_training)

    if not sample_ious:
        return float("nan")
    return sum(sample_ious) / len(sample_ious)

def train_decoder(model, train_loader, val_loader=None, 
                  epochs=10, lr=1e-4, save_path="decoder_trained.pth",
                  use_amp=True, dice_weight=0.3):
    """Train only ``model.seg_head`` with a weighted BCE + Dice loss.

    Args:
        model: Segmenter exposing ``encoder``, ``seg_head`` and ``forward_seg``.
        train_loader: Iterable of ``(images, masks)`` batches.
        val_loader: Accepted for API compatibility; validation is currently
            disabled (see the commented block at the end of the epoch loop).
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        save_path: Accepted for API compatibility; this function does not
            write any file — callers save the model themselves.
        use_amp: Use CUDA automatic mixed precision. Ignored when
            ``CFG.device`` is not a CUDA device.
        dice_weight: Weight of the Dice term; BCE gets ``1 - dice_weight``.
    """
    device = CFG.device

    # torch.cuda.amp only applies to CUDA; do not request it on CPU.
    amp_enabled = bool(use_amp) and torch.device(device).type == "cuda"

    # Train ONLY the decoder
    for p in model.encoder.parameters():
        p.requires_grad = False
    for p in model.seg_head.parameters():
        p.requires_grad = True

    criterion_bce = nn.BCEWithLogitsLoss()
    criterion_dice = DiceLoss()

    optimizer = torch.optim.Adam(model.seg_head.parameters(), lr=lr)
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

    model.train()
    # model.train() also flips the frozen encoder; keep it in inference mode.
    model.encoder.eval()

    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for images, masks in pbar:
            images = images.to(device)
            masks = masks.to(device).float()

            optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=amp_enabled):
                logits = model.forward_seg(images)  # (B,1,H,W)

                loss_bce = criterion_bce(logits, masks)
                loss_dice = criterion_dice(logits, masks)

                loss = (1 - dice_weight) * loss_bce + dice_weight * loss_dice

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Read the scalar once; each .item() forces a device sync.
            batch_loss = loss.item()
            epoch_loss += batch_loss
            num_batches += 1
            pbar.set_postfix({"loss": batch_loss})

        # max(..., 1) avoids ZeroDivisionError when the loader yields no batch
        # (e.g. drop_last=True with fewer samples than one batch).
        print(f"Epoch {epoch+1}, Loss: {epoch_loss / max(num_batches, 1):.4f}")

        # Optional validation
        # if val_loader is not None:
        #     val_iou = eval_decoder(model, val_loader)
        #     print(f"Val IoU: {val_iou:.4f}")

# Define loaders and data

90 / 10 Train Validation Split

In [None]:
class CopyMoveDataset(Dataset):
    """Image / binary-mask pairs for forgery segmentation.

    - authentic → zero mask (``mask_paths[idx]`` is ``None``)
    - forged → load 0/1 mask from .npy file

    Each item is ``(image, mask)``: ``image`` is a float tensor
    ``(3, img_size, img_size)`` scaled to ``[0, 1]`` and ``mask`` is a float
    tensor ``(1, img_size, img_size)`` with values 0 or 1.

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, img_size):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and mask_paths "
                f"({len(mask_paths)}) must have the same length"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths   # contains None for authentic
        self.img_size = img_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]
        target_size = (self.img_size, self.img_size)

        # --- Load and preprocess image ---
        # The context manager closes the file handle deterministically; the
        # converted/resized image is an independent in-memory copy.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB").resize(target_size)
        image = np.array(image, np.float32) / 255.0
        image = torch.from_numpy(image).permute(2, 0, 1)

        # --- Load mask ---
        if mask_path is None:
            # authentic → no mask → zero mask
            mask = np.zeros(target_size, dtype=np.uint8)
        else:
            # forged → load .npy binary mask
            mask = np.load(mask_path)    # shape is (H, W) or (C, H, W)

            # If mask has channels, collapse them into one (union over axis 0)
            if mask.ndim == 3:
                mask = np.max(mask, axis=0)

            # Ensure binary
            mask = (mask > 0).astype(np.uint8)

            # Nearest-neighbour keeps the mask strictly 0/1 after resizing
            mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

        mask = torch.from_numpy(mask)[None].float()
        return image, mask

def get_train_val_loaders(val_ratio=0.1):
    """Build train/validation loaders with a group-aware, reproducible split.

    Images under ``authentic/`` and ``forged/`` that share a file stem are kept
    together as one group so both always land on the same side of the split.
    The validation set is filled with paired groups first, then singletons.

    Args:
        val_ratio: Fraction of *groups* (not images) assigned to validation.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        FileNotFoundError: If a forged image has no ``<stem>.npy`` mask.
    """
    image_dir = Path(CFG.train_images_path)
    mask_dir = Path(CFG.train_masks_path)

    # --- Index authentic and forged images by stem ---
    authentic_dict = {img.stem: img for img in sorted((image_dir / "authentic").glob("*.*"))}
    forged_images = sorted((image_dir / "forged").glob("*.*"))
    forged_dict = {img.stem: img for img in forged_images}

    # Check masks exist
    for img in forged_images:
        mpath = mask_dir / (img.stem + ".npy")
        if not mpath.exists():
            raise FileNotFoundError(f"Missing .npy mask for forged image: {img.name}")

    paired_groups = []   # each item: list of (image_path, mask_path)
    singletons = []      # same format

    # --- Build pairs and singletons ---
    # Iterate in sorted order: the iteration order of a set of strings changes
    # between interpreter sessions (hash randomisation), which would make the
    # seeded shuffle below produce a different split on every run.
    for stem in sorted(set(authentic_dict) | set(forged_dict)):
        has_auth = stem in authentic_dict
        has_forg = stem in forged_dict

        if has_auth and has_forg:
            # A matched pair
            paired_groups.append([
                (authentic_dict[stem], None),
                (forged_dict[stem], mask_dir / (stem + ".npy")),
            ])
        elif has_auth:
            # Standalone authentic
            singletons.append([(authentic_dict[stem], None)])
        else:
            # Standalone forged
            singletons.append([(forged_dict[stem], mask_dir / (stem + ".npy"))])

    print(f"Found {len(paired_groups)} paired authentic/forged groups.")
    print(f"Found {len(singletons)} singleton items.")

    # --- Split groups into train/val ---
    # A private generator gives the same shuffles without reseeding the global RNG.
    rng = random.Random(CFG.seed)
    rng.shuffle(paired_groups)
    rng.shuffle(singletons)

    total_units = len(paired_groups) + len(singletons)
    val_units = int(total_units * val_ratio)

    # Pairs come first, so validation is filled with intact pairs before singletons.
    ordered_groups = paired_groups + singletons
    val_groups = ordered_groups[:val_units]
    train_groups = ordered_groups[val_units:]

    # --- Flatten lists ---
    def flatten(groups):
        images, masks = [], []
        for group in groups:
            for img_path, mask_path in group:
                images.append(img_path)
                masks.append(mask_path)
        return images, masks

    train_images, train_masks = flatten(train_groups)
    val_images, val_masks = flatten(val_groups)

    print(f"Final split → train: {len(train_images)}, val: {len(val_images)}")

    # --- Build datasets ---
    train_ds = CopyMoveDataset(train_images, train_masks, CFG.img_size)
    val_ds = CopyMoveDataset(val_images, val_masks, CFG.img_size)

    # --- Loaders ---
    train_loader = DataLoader(
        train_ds,
        batch_size=CFG.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    return train_loader, val_loader

# Inference From predicted maps

Creates an inferred (predicted) mask for each image.

In [None]:
def enhanced_postprocess(preds, original_size, alpha_grad=None, threshold_multiplier=None,
                         morph_close_kernel=None, morph_open_kernel=None):
    """Turn a probability map into a binary mask at ``original_size``.

    Steps: blend the map with its normalised Sobel gradient magnitude, blur,
    threshold at ``mean + threshold_multiplier * std``, apply morphological
    closing then opening, and resize with nearest-neighbour interpolation.

    Args:
        preds: 2-D array of per-pixel scores (converted to float32 internally).
        original_size: Target size passed straight to ``cv2.resize``, i.e.
            ``(width, height)``.
        alpha_grad: Weight of the gradient term; defaults to ``CFG.alpha_grad``.
        threshold_multiplier: Std multiplier for the threshold; defaults to
            ``CFG.threshold_multiplier``.
        morph_close_kernel: Side of the square closing kernel; defaults to
            ``CFG.morph_close_kernel``.
        morph_open_kernel: Side of the square opening kernel; defaults to
            ``CFG.morph_open_kernel``.

    Returns:
        ``uint8`` mask with values 0/1.
    """
    if alpha_grad is None:
        alpha_grad = CFG.alpha_grad
    if threshold_multiplier is None:
        threshold_multiplier = CFG.threshold_multiplier
    if morph_close_kernel is None:
        morph_close_kernel = CFG.morph_close_kernel
    if morph_open_kernel is None:
        morph_open_kernel = CFG.morph_open_kernel

    # The Sobel calls below request CV_32F output; normalise the input dtype
    # so float64 or integer maps are handled as well.
    preds = np.ascontiguousarray(preds, dtype=np.float32)

    gx = cv2.Sobel(preds, cv2.CV_32F, 1, 0, ksize=3)
    gy = cv2.Sobel(preds, cv2.CV_32F, 0, 1, ksize=3)
    grad_mag = np.sqrt(gx**2 + gy**2)
    grad_norm = grad_mag / (grad_mag.max() + 1e-6)  # epsilon guards a flat map

    enhanced = (1 - alpha_grad) * preds + alpha_grad * grad_norm
    enhanced = cv2.GaussianBlur(enhanced, (3, 3), 0)

    # Adaptive threshold relative to this image's own statistics.
    thr = np.mean(enhanced) + threshold_multiplier * np.std(enhanced)
    mask = (enhanced > thr).astype(np.uint8)

    # Closing fills small holes, opening then removes small specks.
    close_kernel = np.ones((morph_close_kernel, morph_close_kernel), np.uint8)
    open_kernel = np.ones((morph_open_kernel, morph_open_kernel), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, open_kernel)
    mask = cv2.resize(mask, original_size, interpolation=cv2.INTER_NEAREST)

    return mask

def rle_encode(mask):
    """Run-length encode the pixels equal to 1 in ``mask``.

    Pixels are scanned in column-major order (the mask is transposed before
    flattening). The result is a JSON list ``[start, length, ...]`` with
    1-based starts, or the string ``"authentic"`` when no pixel equals 1.
    """
    flat = mask.T.flatten()
    hits = np.flatnonzero(flat == 1)

    if hits.size == 0:
        return "authentic"

    # A new run begins wherever consecutive hit indices are not adjacent.
    gaps = np.flatnonzero(np.diff(hits) > 1)
    run_starts = np.concatenate((hits[:1], hits[gaps + 1]))
    run_ends = np.concatenate((hits[gaps], hits[-1:]))

    encoded = np.empty(2 * run_starts.size, dtype=np.int64)
    encoded[0::2] = run_starts + 1                 # 1-based start
    encoded[1::2] = run_ends - run_starts + 1      # run length

    return json.dumps(encoded.tolist())

@torch.no_grad()
def predict_with_tta(model, image):
    """Average sigmoid predictions over the identity and the enabled flips.

    Horizontal (dim 3) and vertical (dim 2) flips are controlled by
    ``CFG.tta_flip_horizontal`` / ``CFG.tta_flip_vertical``; each flipped
    prediction is flipped back before averaging. Returns the first sample's
    first channel as a NumPy array.
    """
    # None stands for the un-augmented pass; order matches identity, H, V.
    flip_dims = [None]
    if CFG.tta_flip_horizontal:
        flip_dims.append(3)
    if CFG.tta_flip_vertical:
        flip_dims.append(2)

    outputs = []
    for dim in flip_dims:
        if dim is None:
            outputs.append(torch.sigmoid(model.forward_seg(image)))
            continue
        flipped_probs = torch.sigmoid(model.forward_seg(torch.flip(image, dims=[dim])))
        outputs.append(torch.flip(flipped_probs, dims=[dim]))

    averaged = torch.stack(outputs).mean(0)
    return averaged[0, 0].cpu().numpy()

@torch.no_grad()
def predict(model, image):
    """Return the sigmoid prediction for the first sample/channel as a NumPy array."""
    logits = model.forward_seg(image)
    probs = torch.sigmoid(logits)
    return probs[0, 0].cpu().numpy()

def infer_image(image, area_cutoff=400, mean_cutoff=0.3):
    """Classify one PIL image and return its predicted forgery mask.

    NOTE: this function uses the module-level ``model`` variable; it is not
    passed in as an argument.

    Args:
        image: PIL image (any mode; converted to RGB before inference).
        area_cutoff: Minimum number of mask pixels, counted at the image's
            original resolution, for a "forged" decision.
        mean_cutoff: Minimum mean raw prediction inside the mask for a
            "forged" decision.

    Returns:
        ``("authentic", None)`` or ``("forged", mask)`` where ``mask`` is the
        0/1 ``uint8`` mask at ``image.size``.
    """
    # Force three channels so grayscale/RGBA inputs do not break the permute.
    rgb = image.convert("RGB").resize((CFG.img_size, CFG.img_size))
    image_array = np.array(rgb, np.float32) / 255
    image_array = torch.from_numpy(image_array).permute(2, 0, 1)[None].to(CFG.device)

    if CFG.use_tta:
        preds = predict_with_tta(model, image_array)
    else:
        preds = predict(model, image_array)

    mask = enhanced_postprocess(preds, image.size)

    area = int(mask.sum())
    mean_inside = 0.0
    if area > 0:
        # Bring the mask back to model resolution to index the raw predictions.
        resized_mask = cv2.resize(mask, (CFG.img_size, CFG.img_size), interpolation=cv2.INTER_NEAREST)
        inside = preds[resized_mask == 1]
        # An empty selection would give NaN, and NaN compares False against any cutoff.
        if inside.size:
            mean_inside = float(inside.mean())

    # Enhanced decision logic with tuned thresholds
    if area < area_cutoff or mean_inside < mean_cutoff:
        return "authentic", None

    return "forged", mask


In [None]:
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
)

def precompute_val_predictions(model, val_loader):
    """Run the model once on every validation image and cache the raw outputs.

    Each image tensor is converted to an 8-bit PIL image and back, presumably
    to mirror the input path of ``infer_image`` (which receives PIL images).

    Args:
        model: Segmenter passed to ``predict`` / ``predict_with_tta``; it is
            put into eval mode and left there.
        val_loader: Iterable of ``(images, masks)`` batches.

    Returns:
        List of dicts with keys ``"pil_image"``, ``"gt_mask"``, ``"gt_label"``
        (``"forged"`` / ``"authentic"``) and ``"raw_pred"``.
    """
    model.eval()
    val_data = []
    model_size = (CFG.img_size, CFG.img_size)

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Running model on validation set"):
            for image_tensor, mask_tensor in zip(images, masks):
                # Image: round and clip before the uint8 cast, which would
                # otherwise truncate (and wrap out-of-range values).
                image_np = image_tensor.permute(1, 2, 0).cpu().numpy()
                image_np = np.rint(image_np * 255).clip(0, 255).astype(np.uint8)
                pil_img = Image.fromarray(image_np)

                # Ground truth
                gt_mask = mask_tensor[0].cpu().numpy()
                gt_label = "forged" if gt_mask.sum() > 0 else "authentic"

                # --- Forward pass ---
                image_array = np.array(pil_img.resize(model_size), np.float32) / 255
                image_array = torch.from_numpy(image_array).permute(2, 0, 1)[None].to(CFG.device)

                if CFG.use_tta:
                    preds = predict_with_tta(model, image_array)
                else:
                    preds = predict(model, image_array)

                val_data.append({
                    "pil_image": pil_img,
                    "gt_mask": gt_mask,
                    "gt_label": gt_label,
                    # NumPy map returned by predict/predict_with_tta at model
                    # resolution; it is NOT resized to the original image size.
                    "raw_pred": preds,
                })

    print(f"Precomputed predictions for {len(val_data)} validation images.")
    return val_data


def fast_tune_infer_cutoffs(val_data, area_range, mean_range):
    """Grid-search the area/mean cut-offs that maximise F1 for the "forged" class.

    Args:
        val_data: Items produced by ``precompute_val_predictions`` (uses the
            ``"raw_pred"``, ``"gt_label"`` and ``"pil_image"`` keys).
        area_range: Candidate ``area_cutoff`` values.
        mean_range: Candidate ``mean_cutoff`` values (iterated once per area
            candidate, so it must be re-iterable).

    Returns:
        ``{"area_cutoff": ..., "mean_cutoff": ...}`` for the best combination
        (first one wins on ties), or ``{}`` if either range is empty.
    """
    # Post-processing does not depend on the cut-offs, so compute each image's
    # (area, mean_inside) once instead of once per grid point.
    y_true = []
    image_stats = []
    for item in val_data:
        preds = item["raw_pred"]
        mask = enhanced_postprocess(preds, item["pil_image"].size)

        area = int(mask.sum())
        mean_inside = 0.0
        if area > 0:
            resized_mask = cv2.resize(mask, (CFG.img_size, CFG.img_size), interpolation=cv2.INTER_NEAREST)
            inside = preds[resized_mask == 1]
            # Guard against an empty selection, whose mean would be NaN.
            if inside.size:
                mean_inside = float(inside.mean())

        y_true.append(item["gt_label"])
        image_stats.append((area, mean_inside))

    best_f1 = -1
    best_params = {}

    for area_cut in area_range:
        for mean_cut in mean_range:
            y_pred = [
                "authentic" if (area < area_cut or mean_inside < mean_cut) else "forged"
                for area, mean_inside in image_stats
            ]

            f1 = f1_score(y_true, y_pred, pos_label="forged", zero_division=0)

            if f1 > best_f1:
                best_f1 = f1
                best_params = {
                    "area_cutoff": area_cut,
                    "mean_cutoff": mean_cut
                }

    print(f"\nBest thresholds: {best_params}, F1 Score: {best_f1:.4f}")
    return best_params

# Validation Metrics

Displays first few actual maps, summary stats

In [None]:
def compute_iou(pred_mask, gt_mask):
    """Intersection-over-union of two masks binarised with ``> 0``.

    Returns 1.0 when both masks are empty.
    """
    pred_fg = pred_mask > 0
    gt_fg = gt_mask > 0
    overlap = np.logical_and(pred_fg, gt_fg).sum()
    combined = np.logical_or(pred_fg, gt_fg).sum()
    if combined != 0:
        return overlap / combined
    # Empty union implies empty intersection: two empty masks agree perfectly.
    return 1.0 if overlap == 0 else 0.0

def compute_dice(pred_mask, gt_mask):
    """Dice coefficient of two masks binarised with ``> 0``.

    Returns 1.0 when both masks are empty.
    """
    pred_fg = pred_mask > 0
    gt_fg = gt_mask > 0
    overlap = np.logical_and(pred_fg, gt_fg).sum()
    total = pred_fg.sum() + gt_fg.sum()
    if total != 0:
        return 2 * overlap / total
    # Both masks empty: treated as perfect agreement.
    return 1.0 if overlap == 0 else 0.0


# -------------------------------------------------------------
#      VALIDATION PREDICTION PIPELINE USING infer_image()
# -------------------------------------------------------------

def evaluate_on_validation(model, val_loader, max_visualize=10, area_cutoff=400, mean_cutoff=0.3):
    """Evaluate image-level classification and mask quality on the validation set.

    Every validation image goes through ``infer_image``. Classification metrics
    are printed for all images; IoU/Dice are computed only for images whose
    ground truth is forged. The first ``max_visualize`` samples are plotted.

    Args:
        model: Put into eval mode here. NOTE: ``infer_image`` runs the
            module-level ``model`` variable, so this should be the same object.
        val_loader: Iterable of ``(images, masks)`` batches.
        max_visualize: Number of samples to plot; 0 disables plotting.
        area_cutoff: Forwarded to ``infer_image``.
        mean_cutoff: Forwarded to ``infer_image``.
    """
    model.eval()

    all_labels = []      # true labels (authentic/forged)
    all_preds = []       # predicted labels
    iou_scores = []      # mask IoU for forged examples
    dice_scores = []     # mask dice for forged examples
    visualizations = []

    for images, masks in val_loader:
        for image_tensor, mask_batch in zip(images, masks):
            # Convert back to PIL for infer_image(); round before the uint8
            # cast, which would otherwise truncate.
            image_np = image_tensor.permute(1, 2, 0).cpu().numpy()
            image_np = np.rint(image_np * 255).clip(0, 255).astype(np.uint8)
            pil_img = Image.fromarray(image_np)

            gt_mask = mask_batch[0].cpu().numpy()

            # --- infer_image (classification + mask prediction) ---
            pred_label, pred_mask = infer_image(pil_img, area_cutoff=area_cutoff, mean_cutoff=mean_cutoff)

            # GT label
            true_label = "forged" if gt_mask.sum() > 0 else "authentic"

            all_labels.append(true_label)
            all_preds.append(pred_label)

            # Compute mask metrics only for forged images
            if true_label == "forged":
                if pred_mask is None:
                    # Missed forgery: score it against an empty prediction.
                    pred_mask = np.zeros_like(gt_mask)

                pred_mask = pred_mask.astype(np.uint8)
                iou_scores.append(compute_iou(pred_mask, gt_mask))
                dice_scores.append(compute_dice(pred_mask, gt_mask))

            # Save visualizations
            if len(visualizations) < max_visualize:
                visualizations.append((pil_img, gt_mask, pred_mask, true_label, pred_label))

    # ---------------- Metrics ----------------
    # zero_division=0 keeps the metrics defined when no image is predicted
    # (or labelled) as forged.
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, pos_label="forged", zero_division=0)
    recall = recall_score(all_labels, all_preds, pos_label="forged", zero_division=0)
    f1 = f1_score(all_labels, all_preds, pos_label="forged", zero_division=0)
    cm = confusion_matrix(all_labels, all_preds, labels=["authentic", "forged"])

    print("\n================ Validation Results ================\n")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1 Score:  {f1:.4f}")
    print("\nConfusion Matrix (rows=true, cols=pred):")
    print("            authentic   forged")
    print(f"authentic |   {cm[0,0]:5d}      {cm[0,1]:5d}")
    print(f"forged    |   {cm[1,0]:5d}      {cm[1,1]:5d}")

    if iou_scores:
        print(f"\nMean IoU (forged only):  {np.mean(iou_scores):.4f}")
        print(f"Mean Dice (forged only): {np.mean(dice_scores):.4f}")
    else:
        print("\nNo forged samples found in validation set!")

    # ---------------- Visualization ----------------
    rows = len(visualizations)
    if rows == 0:
        # A zero-height figure is invalid; nothing to draw.
        return

    print("\nShowing first validation predictions...\n")
    plt.figure(figsize=(12, 4 * rows))

    for i, (img, gt_mask, pred_mask, true_lbl, pred_lbl) in enumerate(visualizations):

        # Ensure pred_mask is a proper numeric array
        if pred_mask is None:
            pred_mask = np.zeros_like(gt_mask, dtype=np.uint8)

        # Original
        plt.subplot(rows, 3, i*3 + 1)
        plt.imshow(img)
        plt.title(f"Image\nLabel: {true_lbl}")
        plt.axis("off")

        # Ground truth
        plt.subplot(rows, 3, i*3 + 2)
        plt.imshow(gt_mask, cmap="gray")
        plt.title("Ground Truth Mask")
        plt.axis("off")

        # Predicted
        plt.subplot(rows, 3, i*3 + 3)
        plt.imshow(pred_mask, cmap="gray")
        plt.title(f"Predicted Mask\nPred Label: {pred_lbl}")
        plt.axis("off")

    plt.tight_layout()
    plt.show()

# Actual Model Execution

In [None]:
# Initial definition and training: build the loaders, load the pretrained
# encoder and its image processor from local files, wrap them in the
# segmenter, and train the decoder head for 5 epochs.
train_loader, val_loader = get_train_val_loaders()

processor = AutoImageProcessor.from_pretrained(CFG.dino_path, local_files_only=True)
encoder = AutoModel.from_pretrained(CFG.dino_path, local_files_only=True).eval().to(CFG.device)

model = DinoSegmenter(encoder, processor).to(CFG.device)

train_decoder(
    model,
    train_loader,
    val_loader=val_loader,
    epochs=5,
    lr=1e-4,
    dice_weight=0.3
)

In [None]:
# Grid search over cutoffs
# Step 1: run the model once over the validation set and cache the raw predictions
val_data = precompute_val_predictions(model, val_loader)

# Step 2: Run fast grid search
area_candidates = [100, 200, 300, 400, 500, 600]
mean_candidates = [0.1, 0.2, 0.25, 0.3, 0.35, 0.4]

best_thresholds = fast_tune_infer_cutoffs(val_data, area_candidates, mean_candidates)

print(f'Best Area Threshold: {best_thresholds["area_cutoff"] }\nBest Mean Threshold: {best_thresholds["mean_cutoff"]}')

In [None]:
# Report validation metrics and sample plots using the cut-offs found by the grid search.
evaluate_on_validation(model, val_loader, area_cutoff=best_thresholds["area_cutoff"], mean_cutoff=best_thresholds["mean_cutoff"])

In [None]:
# Checkpoint the model after the first 5-epoch round.
checkpoint_path = "/kaggle/working/model_trained_5epochs_pair.pth"
torch.save(model.state_dict(), checkpoint_path)

# Rebuild the loaders and the model, then restore the checkpoint written above.
train_loader, val_loader = get_train_val_loaders()

processor = AutoImageProcessor.from_pretrained(CFG.dino_path, local_files_only=True)
encoder = AutoModel.from_pretrained(CFG.dino_path, local_files_only=True).eval().to(CFG.device)

model = DinoSegmenter(encoder, processor).to(CFG.device)

# Load the same file that was just saved. The original code read
# "model_trained_5epochs.pth", which is never written in this notebook.
state_dict = torch.load(checkpoint_path, map_location=CFG.device)
model.load_state_dict(state_dict)

model.eval()

print("Model loaded successfully!")

In [None]:
# Second training round: 5 more epochs on the decoder head.
train_decoder(
    model,
    train_loader,
    val_loader=val_loader,
    epochs=5,
    lr=1e-4,
    dice_weight=0.3
)

In [None]:
# Grid search over cutoffs
# Step 1: run the model once over the validation set and cache the raw predictions
val_data = precompute_val_predictions(model, val_loader)

# Step 2: Run fast grid search
area_candidates = [100, 200, 300, 400, 500, 600]
mean_candidates = [0.1, 0.2, 0.25, 0.3, 0.35, 0.4]

best_thresholds = fast_tune_infer_cutoffs(val_data, area_candidates, mean_candidates)

print(f"Best Area Threshold: {best_thresholds['area_cutoff']}\nBest Mean Threshold: {best_thresholds['mean_cutoff']}")

In [None]:
# Report validation metrics and sample plots using the cut-offs found by the grid search.
evaluate_on_validation(model, val_loader, area_cutoff=best_thresholds["area_cutoff"], mean_cutoff=best_thresholds["mean_cutoff"])

In [None]:
# Checkpoint the model after the second training round.
torch.save(model.state_dict(), "/kaggle/working/model_trained_10epochs_pair.pth")

In [None]:
# Third training round: 5 more epochs on the decoder head.
train_decoder(
    model,
    train_loader,
    val_loader=val_loader,
    epochs=5,
    lr=1e-4,
    dice_weight=0.3
)

In [None]:
# Grid search over cutoffs
# Step 1: run the model once over the validation set and cache the raw predictions
val_data = precompute_val_predictions(model, val_loader)

# Step 2: Run fast grid search
area_candidates = [100, 200, 300, 400, 500, 600]
mean_candidates = [0.1, 0.2, 0.25, 0.3, 0.35, 0.4]

best_thresholds = fast_tune_infer_cutoffs(val_data, area_candidates, mean_candidates)

print(f"Best Area Threshold: {best_thresholds['area_cutoff']}\nBest Mean Threshold: {best_thresholds['mean_cutoff']}")

In [None]:
# Report validation metrics and sample plots using the cut-offs found by the grid search.
evaluate_on_validation(model, val_loader, area_cutoff=best_thresholds["area_cutoff"], mean_cutoff=best_thresholds["mean_cutoff"])

In [None]:
# Checkpoint the model after the third training round.
torch.save(model.state_dict(), "/kaggle/working/model_trained_15epochs_pair.pth")

# Train a CNN on predicted masks for image-level classification (authentic vs. forged)

In [None]:
class ClassifierInputDataset(Dataset):
    """
    Wraps a segmentation dataset for image-level classification.

    The segmenter is NOT run here. Each item is:
        - the raw image tensor from the wrapped dataset
        - a ``torch.long`` label: forged=1, authentic=0

    The label is derived from the ground-truth mask; running the frozen
    segmenter is left to the classifier model.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        image_tensor, gt_mask = self.base_dataset[idx]

        # Any positive mass in the ground-truth mask marks the image as forged.
        is_forged = bool(gt_mask.sum() > 0)
        label = torch.tensor(int(is_forged), dtype=torch.long)

        return image_tensor, label

def get_classifier_loaders(train_dataset, val_dataset, batch_size=16):
    """Wrap both datasets in ``ClassifierInputDataset`` and return their loaders.

    Only the training loader shuffles; both use 2 workers and pinned memory.
    """
    shared_options = dict(batch_size=batch_size, num_workers=2, pin_memory=True)

    train_loader = DataLoader(
        ClassifierInputDataset(train_dataset), shuffle=True, **shared_options
    )
    val_loader = DataLoader(
        ClassifierInputDataset(val_dataset), shuffle=False, **shared_options
    )

    return train_loader, val_loader

# CNN Classifier architecture

In [None]:
class FrozenSegmenterMaskClassifier(nn.Module):
    """
    Full pipeline:
        raw RGB image → frozen segmenter → predicted mask → tiny CNN classifier (trainable)

    Args:
        segmenter: Module exposing ``forward_seg(x)`` that returns logits.
            All of its parameters are frozen here and it is kept in eval mode.
    """

    def __init__(self, segmenter):
        super().__init__()

        # Store the frozen segmenter
        self.segmenter = segmenter
        self.segmenter.eval()
        for p in self.segmenter.parameters():
            p.requires_grad = False

        # Small CNN for classification on predicted mask
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=5, stride=4, padding=2),  # 512 → 128
            nn.ReLU(),

            nn.Conv2d(8, 16, kernel_size=3, stride=4, padding=1), # 128 → 32
            nn.ReLU(),

            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 32 → 16
            nn.ReLU(),
        )

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # (B,32,1,1)
            nn.Flatten(),             # (B,32)
            nn.Linear(32, 1)          # (B,1)
        )

    def train(self, mode=True):
        """Switch the classifier's mode while keeping the segmenter in eval mode.

        ``nn.Module.train`` recurses into every child, so without this override
        a plain ``model.train()`` would put the frozen segmenter back into
        training mode.
        """
        super().train(mode)
        self.segmenter.eval()
        return self

    def forward(self, x):
        """
        x:
            Either raw image tensor (B,3,H,W)
            OR an already-segmented mask tensor (any other channel count;
            only the first channel is used)

        Returns:
            logits (B,1)
        """

        # ------------------------------------------------------
        # 1. If input is 3-ch RGB → run the frozen segmenter
        # ------------------------------------------------------
        if x.shape[1] == 3:     # RGB input
            with torch.no_grad():
                pred_mask = torch.sigmoid(self.segmenter.forward_seg(x))
        else:
            # Already a mask
            pred_mask = x

        # Keep a single channel for the 1-channel first convolution
        if pred_mask.shape[1] != 1:
            pred_mask = pred_mask[:, :1]

        # ------------------------------------------------------
        # 2. Pass mask through the trainable classifier
        # ------------------------------------------------------
        feats = self.features(pred_mask)
        out = self.classifier(feats)

        return out

# Training Loop

In [None]:
def _run_classifier_epoch(model, loader, criterion, device, optimizer=None, desc=""):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns ``(mean_loss, predictions, labels)`` where predictions are 0/1
    decisions at a 0.5 probability threshold.
    """
    training = optimizer is not None
    losses, all_preds, all_labels = [], [], []

    with torch.set_grad_enabled(training):
        for inputs, labels in tqdm(loader, desc=desc, leave=False):
            inputs = inputs.to(device)
            labels = labels.float().unsqueeze(1).to(device)

            logits = model(inputs)
            loss = criterion(logits, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            losses.append(loss.item())

            probs = torch.sigmoid(logits).detach().cpu().numpy()
            all_preds.extend((probs > 0.5).astype(int).flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    return sum(losses) / len(losses), all_preds, all_labels


def train_mask_classifier(
    model,
    train_loader,
    val_loader,
    num_epochs=10,
    lr=1e-4,
    device="cuda",
    save_path="mask_classifier_best.pth"
):
    """Train a binary classifier with BCE-with-logits and keep the best-F1 checkpoint.

    Args:
        model: Module mapping a batch from the loaders to logits ``(B, 1)``.
        train_loader: Iterable of ``(inputs, labels)`` batches for training.
        val_loader: Iterable of ``(inputs, labels)`` batches for validation.
        num_epochs: Number of epochs.
        lr: Adam learning rate.
        device: Device for the model and batches.
        save_path: Where ``model.state_dict()`` is written whenever the
            validation F1 improves (the first epoch always writes it).

    Returns:
        The model in its last-epoch state (not reloaded from the best checkpoint).
    """
    model = model.to(device)

    # Only hand trainable parameters to the optimizer; frozen ones never get gradients.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    # Start below any reachable F1 so a checkpoint exists even if F1 stays at 0.
    best_val_f1 = -1.0

    for epoch in range(1, num_epochs + 1):
        print(f"\n======== Epoch {epoch}/{num_epochs} ========")

        # -----------------------
        # TRAIN
        # -----------------------
        model.train()
        train_loss, train_preds, train_labels = _run_classifier_epoch(
            model, train_loader, criterion, device, optimizer=optimizer, desc="Training"
        )
        train_acc = accuracy_score(train_labels, train_preds)

        print(f"Train Loss: {train_loss:.4f}  |  Train Acc: {train_acc:.4f}")

        # -----------------------
        # VALIDATION
        # -----------------------
        model.eval()
        val_loss, val_preds, val_labels_list = _run_classifier_epoch(
            model, val_loader, criterion, device, desc="Validating"
        )

        val_acc = accuracy_score(val_labels_list, val_preds)
        val_prec = precision_score(val_labels_list, val_preds, zero_division=0)
        val_rec = recall_score(val_labels_list, val_preds, zero_division=0)
        val_f1 = f1_score(val_labels_list, val_preds, zero_division=0)

        print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | "
              f"Prec: {val_prec:.4f} | Rec: {val_rec:.4f} | F1: {val_f1:.4f}")

        # Save best model
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), save_path)
            print(f"✔ Saved new best model (F1={best_val_f1:.4f})")

    print("\nTraining complete!")
    # Clamp the sentinel so zero epochs still reports 0.0000 as before.
    print(f"Best F1 Score: {max(best_val_f1, 0.0):.4f}")
    return model

# Training of the post-segmentation classifier

In [None]:
# 1. Load the segmentation checkpoint saved after the second training round.
#    The original code read "model_trained_10epochs.pth", which is never
#    written in this notebook; the saved file carries the "_pair" suffix.
state_dict = torch.load("/kaggle/working/model_trained_10epochs_pair.pth", map_location=CFG.device)
model.load_state_dict(state_dict)
model.eval()

# 2. Create classifier loaders. They yield (raw image, label) pairs; the
#    predicted masks are produced inside the classifier model itself.
train_pred_loader, val_pred_loader = get_classifier_loaders(
    train_loader.dataset,
    val_loader.dataset,
    batch_size=CFG.batch_size
)

# 3. Wrap the segmenter (frozen by the wrapper) and train the small CNN on top.
post_model = FrozenSegmenterMaskClassifier(model)

trained_model = train_mask_classifier(
    post_model,
    train_pred_loader,
    val_pred_loader,
    num_epochs=5,
    lr=1e-4,
    device=CFG.device,
    save_path="/kaggle/working/tiny_mask_classifier_best.pth"
)