# Submission Notebook - DINOv2 Forgery Detection

This notebook loads pre-trained weights and generates predictions for the test set.

In [None]:
import os, cv2, json, math
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModel

In [None]:
# ==================== CONFIG ====================
# Prefer the GPU when PyTorch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model settings -- these have to agree with the values used at training time.
DINO_PATH = "facebook/dinov2-base"
IMG_SIZE = 512
CHANNELS = 4

# Competition inputs as mounted inside a Kaggle notebook.
TEST_DIR = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/test_images"
SAMPLE_SUB = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/sample_submission.csv"

# Trained segmentation weights. The dataset folder below is a placeholder:
# replace it with the name of your own uploaded Kaggle dataset.
WEIGHTS_PATH = "/kaggle/input/your-model-weights-dataset/model_seg_final.pt"

# File name used for the generated submission.
OUT_PATH = "submission.csv"

print(f"Device: {device}")

In [None]:
# ==================== MODEL DEFINITION ====================

class DinoDecoder(nn.Module):
    """Convolutional decoder that turns a feature map into per-channel logits.

    The input is bilinearly upsampled by 2x three times and then resized to
    the requested output size; every resize is followed by a two-convolution
    block. A final 1x1 convolution maps the 48 remaining channels to
    ``out_ch`` outputs. Dropout is applied once inside each block.
    """

    def __init__(self, in_ch=768, out_ch=CHANNELS, dropout=0.1):
        super().__init__()

        # NOTE: attribute names (up1..up4, final) and their creation order are
        # kept exactly as-is because they determine the state_dict keys.
        self.up1 = self._block(in_ch, 384, dropout)
        self.up2 = self._block(384, 192, dropout)
        self.up3 = self._block(192, 96, dropout)
        self.up4 = self._block(96, 48, dropout)

        self.final = nn.Conv2d(48, out_ch, kernel_size=1)

    def _block(self, in_ch, out_ch, dropout):
        """Build conv-BN-ReLU-dropout-conv-BN-ReLU as one ``nn.Sequential``."""
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, f, size):
        """Decode feature map ``f`` into logits with spatial size ``size``.

        ``size`` is forwarded to ``F.interpolate`` for the last resize, so the
        returned tensor has exactly that height and width.
        """
        x = f
        # Three identical stages: double the resolution, then refine.
        for stage in (self.up1, self.up2, self.up3):
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            x = stage(x)

        # Last stage jumps straight to the requested output resolution.
        x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
        x = self.up4(x)

        return self.final(x)


class DinoSegmenter(nn.Module):
    """Transformer encoder plus a ``DinoDecoder`` segmentation head.

    ``encoder`` must expose ``encoder.layer`` (an indexable sequence of
    blocks) and ``layernorm``; ``processor`` is called on a list of HWC uint8
    images and must return something with a ``.to(device)`` method whose
    result can be unpacked as keyword arguments for the encoder.
    """

    def __init__(self, encoder, processor, unfreeze_blocks=3):
        super().__init__()
        self.encoder = encoder
        self.processor = processor

        # Start from a fully frozen encoder ...
        self.encoder.requires_grad_(False)

        # ... then re-enable gradients for the last `unfreeze_blocks` blocks
        # and for the final layer norm.
        blocks = self.encoder.encoder.layer
        total = len(blocks)
        for idx in range(total - unfreeze_blocks, total):
            blocks[idx].requires_grad_(True)
        self.encoder.layernorm.requires_grad_(True)

        self.seg_head = DinoDecoder(768, CHANNELS)

    def forward_features(self, x):
        """Encode a batch of images into a square spatial feature map.

        ``x`` is scaled by 255, clamped to [0, 255] and cast to uint8 before
        being handed to the processor, so it is presumably expected in [0, 1]
        with layout (B, C, H, W) -- confirm against the caller.

        Returns a tensor of shape (B, C, s, s) where ``s`` is the integer
        square root of the token count minus one. The token at index 0 is
        dropped (presumably the CLS token -- confirm for the encoder in use).
        """
        as_uint8 = (x * 255).clamp(0, 255).byte()
        hwc_images = as_uint8.permute(0, 2, 3, 1).cpu().numpy()
        batch = self.processor(images=list(hwc_images), return_tensors="pt").to(x.device)

        # The encoder always runs without gradient tracking here.
        with torch.no_grad():
            tokens = self.encoder(**batch).last_hidden_state

        B, N, C = tokens.shape
        side = int(math.sqrt(N - 1))
        patch_tokens = tokens[:, 1:, :]
        return patch_tokens.permute(0, 2, 1).reshape(B, C, side, side)

    def forward_seg(self, x):
        """Return segmentation logits at (IMG_SIZE, IMG_SIZE) resolution."""
        features = self.forward_features(x)
        return self.seg_head(features, (IMG_SIZE, IMG_SIZE))

In [None]:
# ==================== LOAD MODEL ====================

# Fetch the image processor and the backbone identified by DINO_PATH, then put
# the backbone in eval mode on the selected device.
print("Loading DINOv2 encoder...")
processor = AutoImageProcessor.from_pretrained(DINO_PATH)
encoder = AutoModel.from_pretrained(DINO_PATH)
encoder = encoder.eval().to(device)

# Wrap the backbone together with the segmentation head.
print("Building model...")
model_seg = DinoSegmenter(encoder, processor)
model_seg = model_seg.to(device)

# Overwrite the freshly built model's parameters with the trained checkpoint.
print(f"Loading weights from {WEIGHTS_PATH}...")
_state_dict = torch.load(WEIGHTS_PATH, map_location=device)
model_seg.load_state_dict(_state_dict)
del _state_dict  # do not keep a second copy of the weights alive
model_seg.eval()

print("Model loaded successfully!")

In [None]:
# ==================== INFERENCE FUNCTIONS ====================

@torch.no_grad()
def segment_prob_map_all_channels(pil):
    """Run the segmenter on one PIL image and return sigmoid probabilities.

    The image is resized to IMG_SIZE x IMG_SIZE, scaled to [0, 1] and sent
    through ``model_seg.forward_seg``. The batch dimension is stripped and the
    result is returned as a NumPy array on the CPU; its first axis indexes the
    output channels. The ``permute(2, 0, 1)`` below requires an image whose
    array form is 3-D (H, W, C), e.g. RGB.
    """
    resized = pil.resize((IMG_SIZE, IMG_SIZE))
    pixels = np.array(resized, np.float32) / 255.
    # HWC -> CHW, add a leading batch axis, move to the model's device.
    batch = torch.from_numpy(pixels).permute(2, 0, 1)[None].to(device)
    logits = model_seg.forward_seg(batch)
    return torch.sigmoid(logits)[0].cpu().numpy()


def enhanced_adaptive_mask(prob, alpha_grad=0.35):
    """Binarise a 2-D probability map with an edge-aware adaptive threshold.

    The map is blended with its normalised Sobel gradient magnitude (weight
    ``alpha_grad``), lightly blurred, and thresholded at
    ``mean + 0.3 * std`` of the blended map. The binary result is cleaned
    with a 5x5 morphological close followed by a 3x3 open.

    Returns ``(mask, thr)``: a uint8 mask of 0/1 values and the threshold
    that produced it.
    """
    sobel_x = cv2.Sobel(prob, cv2.CV_32F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(prob, cv2.CV_32F, 0, 1, ksize=3)
    magnitude = np.sqrt(sobel_x**2 + sobel_y**2)
    # The epsilon keeps the division finite on a perfectly flat map.
    magnitude_unit = magnitude / (magnitude.max() + 1e-6)

    blended = (1 - alpha_grad) * prob + alpha_grad * magnitude_unit
    blended = cv2.GaussianBlur(blended, (3,3), 0)

    thr = np.mean(blended) + 0.3 * np.std(blended)
    binary = (blended > thr).astype(np.uint8)

    close_kernel = np.ones((5,5), np.uint8)
    open_kernel = np.ones((3,3), np.uint8)
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, close_kernel)
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, open_kernel)
    return binary, thr


def finalize_mask(prob, orig_size):
    """Threshold ``prob`` and resize the binary mask to ``orig_size``.

    ``orig_size`` is passed to ``cv2.resize`` as the target size, i.e. in
    (width, height) order. Nearest-neighbour interpolation keeps the mask
    strictly binary. Returns ``(mask, thr)``.
    """
    binary, threshold = enhanced_adaptive_mask(prob)
    resized = cv2.resize(binary, orig_size, interpolation=cv2.INTER_NEAREST)
    return resized, threshold


def pipeline_final(pil, min_area=400, min_mean_inside=0.35):
    """Classify one image as authentic or forged and collect region masks.

    Every output channel of the segmenter is thresholded on its own. A
    channel's mask is kept only if it covers at least ``min_area`` pixels (at
    the original image resolution) and the mean probability inside the mask
    is at least ``min_mean_inside``.

    Args:
        pil: PIL image. ``pil.size`` (width, height) is the resolution of the
            returned masks.
        min_area: Minimum number of foreground pixels for a mask to be kept.
        min_mean_inside: Minimum mean probability inside a mask for it to be
            kept.

    Returns:
        A ``(label, masks, info)`` tuple.

        * ``label`` is ``"forged"`` if at least one mask was kept, otherwise
          ``"authentic"``.
        * ``masks`` is the list of kept uint8 masks (empty when authentic).
        * ``info`` is a dict with keys ``area`` (summed mask area),
          ``mean_inside`` and ``thr`` (averages over kept masks) and
          ``num_masks``. All four keys are present on both paths.
    """
    probs = segment_prob_map_all_channels(pil)

    kept = []  # one (mask, area, mean_inside, thr) tuple per accepted channel
    for prob in probs:
        mask, thr = finalize_mask(prob, pil.size)
        area = int(mask.sum())

        # Too small to matter: skip before paying for the probability resize.
        if area < min_area:
            continue

        if area > 0:
            prob_resized = cv2.resize(prob, pil.size, interpolation=cv2.INTER_LINEAR)
            mean_inside = float(prob_resized[mask == 1].mean())
        else:
            mean_inside = 0.0

        # Filter out weak detections.
        if mean_inside >= min_mean_inside:
            kept.append((mask, area, mean_inside, thr))

    if not kept:
        info = {"area": 0, "mean_inside": 0.0, "thr": 0.0, "num_masks": 0}
        return "authentic", [], info

    count = len(kept)
    info = {
        "area": sum(entry[1] for entry in kept),
        "mean_inside": sum(entry[2] for entry in kept) / count,
        "thr": sum(entry[3] for entry in kept) / count,
        "num_masks": count,
    }
    return "forged", [entry[0] for entry in kept], info

In [None]:
# ==================== RLE ENCODING ====================

def rle_encode_single(mask: np.ndarray, fg_val: int = 1) -> "str | None":
    """Run-length encode one 2-D mask as a JSON list string.

    Pixels are visited in column-major order (the mask is transposed before
    flattening). The result is a JSON array of alternating values
    ``[start, length, start, length, ...]`` where every ``start`` is the
    1-based flat index of the first pixel of a run of ``fg_val``.

    Args:
        mask: 2-D array to encode.
        fg_val: Pixel value treated as foreground.

    Returns:
        The JSON string, or ``None`` when the mask has no ``fg_val`` pixel.
    """
    pixels = mask.T.flatten()
    dots = np.flatnonzero(pixels == fg_val)
    if dots.size == 0:
        return None

    # A new run begins wherever two consecutive foreground indices are not
    # adjacent; `breaks` holds the position of the last pixel of each run
    # except the final one.
    breaks = np.flatnonzero(np.diff(dots) > 1)
    starts = dots[np.concatenate(([0], breaks + 1))]
    ends = dots[np.concatenate((breaks, [dots.size - 1]))]
    lengths = ends - starts + 1

    # Interleave 1-based starts with lengths; tolist() yields plain ints.
    run_lengths = np.column_stack((starts + 1, lengths)).ravel()
    return json.dumps(run_lengths.tolist())


def rle_encode_multi(masks: list, fg_val: int = 1) -> str:
    """Encode several masks and join the encodings with ``;``.

    Each mask is first binarised (any value > 0 becomes 1) and then passed to
    ``rle_encode_single``. Masks that produce no encoding are skipped. If
    nothing is left, the literal string ``"authentic"`` is returned.

    NOTE(review): because the masks are reduced to 0/1 before encoding, a
    ``fg_val`` other than 1 cannot match any foreground pixel -- confirm
    whether callers ever pass a different value.
    """
    candidates = (
        rle_encode_single((m > 0).astype(np.uint8), fg_val) for m in masks
    )
    parts = [enc for enc in candidates if enc is not None]
    if not parts:
        return "authentic"
    return ';'.join(parts)

In [None]:
# ==================== GENERATE SUBMISSION ====================

rows = []
# Only regular files are candidates; a stray sub-directory in TEST_DIR would
# otherwise make Image.open raise.
test_files = sorted(
    name for name in os.listdir(TEST_DIR) if (Path(TEST_DIR) / name).is_file()
)
print(f"Processing {len(test_files)} test images...")

for f in tqdm(test_files, desc="Inference"):
    # convert() returns an independent in-memory image, so the file handle can
    # be released as soon as the conversion is done.
    with Image.open(Path(TEST_DIR) / f) as src_img:
        pil = src_img.convert("RGB")
    label, masks, dbg = pipeline_final(pil)

    if label == "authentic" or not masks:
        annot = "authentic"
    else:
        annot = rle_encode_multi(masks)

    rows.append({
        "case_id": Path(f).stem,
        "annotation": annot,
    })

# Create submission DataFrame. Naming the columns explicitly keeps them
# present even when no test file was processed.
sub = pd.DataFrame(rows, columns=["case_id", "annotation"])

# Merge with sample submission to ensure correct order. case_id is compared as
# text on both sides so that numeric ids in the CSV still match file stems.
ss = pd.read_csv(SAMPLE_SUB)
ss["case_id"] = ss["case_id"].astype(str)
sub["case_id"] = sub["case_id"].astype(str)
final = ss[["case_id"]].merge(sub, on="case_id", how="left")
# Cases listed in the sample submission but without a prediction default to
# "authentic".
final["annotation"] = final["annotation"].fillna("authentic")

# Save
final[["case_id", "annotation"]].to_csv(OUT_PATH, index=False)

print(f"\n✅ Saved submission to: {OUT_PATH}")
print(f"Total rows: {len(final)}")
print(f"Forged: {(final['annotation'] != 'authentic').sum()}")
print(f"Authentic: {(final['annotation'] == 'authentic').sum()}")
print("\nFirst 10 rows:")
print(final.head(10))