# In [None]:
import os, cv2, json, math
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

# ==================== CONFIG ====================
# Compute device: CUDA when torch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Input data and checkpoint locations ---
TEST_DIR = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/test_images"
SAMPLE_SUB = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/sample_submission.csv"
WEIGHTS_PATH = "/kaggle/input/recod-1219/best_model.pt"  # UPDATE THIS to the trained checkpoint

# --- Model / preprocessing settings ---
# NOTE(review): DINO_PATH and CHANNELS shape the model's parameters, so they
# presumably must match the values used when WEIGHTS_PATH was trained.
DINO_PATH = "/kaggle/input/dinov2/pytorch/base/1"  # local DINOv2 weights for AutoModel
IMG_SIZE = 512           # images are resized to IMG_SIZE x IMG_SIZE before inference
CHANNELS = 4             # number of output mask channels
UNFREEZE_BLOCKS = 4      # trailing encoder blocks marked trainable
DECODER_DROPOUT = 0.15   # Dropout2d probability inside decoder blocks

# Per-channel RGB normalization statistics (ImageNet).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# ==================== MODEL ====================
class DinoDecoder(nn.Module):
    """Convolutional upsampling head that maps a patch-feature grid to mask logits.

    Three stages each upsample x2 (bilinear) and apply a conv block. The map
    is then resized to ``target_size``, passed through one more conv block,
    and projected to ``out_channels`` with a 1x1 conv.
    """

    def __init__(self, in_channels=768, out_channels=4, dropout=0.1):
        super().__init__()
        # Attribute names and the per-block layer order define the state_dict
        # keys that saved checkpoints are loaded against; keep them fixed.
        self.up1 = self._block(in_channels, 384, dropout)
        self.up2 = self._block(384, 192, dropout)
        self.up3 = self._block(192, 96, dropout)
        self.up4 = self._block(96, 48, dropout)
        self.final = nn.Conv2d(48, out_channels, kernel_size=1)

    def _block(self, in_ch, out_ch, dropout):
        """Build conv-BN-ReLU, Dropout2d, conv-BN-ReLU (7 layers, in this order)."""
        def conv_bn_relu(src_ch):
            return [
                nn.Conv2d(src_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        return nn.Sequential(*conv_bn_relu(in_ch), nn.Dropout2d(dropout), *conv_bn_relu(out_ch))

    def forward(self, features, target_size):
        """Decode ``features`` (B, C, h, w) into logits of spatial size ``target_size``."""
        out = features
        # Three fixed x2 upsampling stages.
        for stage in (self.up1, self.up2, self.up3):
            out = stage(F.interpolate(out, scale_factor=2, mode='bilinear', align_corners=False))
        # Last stage snaps to the requested output resolution.
        out = self.up4(F.interpolate(out, size=target_size, mode='bilinear', align_corners=False))
        return self.final(out)

class DinoSegmenter(nn.Module):
    """DINOv2 encoder followed by a ``DinoDecoder`` producing per-pixel logits.

    Args:
        backbone: model name or local path passed to ``AutoModel.from_pretrained``.
        out_channels: number of output mask channels.
        unfreeze_blocks: how many trailing encoder layers keep ``requires_grad=True``;
            all other encoder parameters except the final layernorm are frozen.
        decoder_dropout: ``Dropout2d`` probability used inside the decoder blocks.
    """

    def __init__(self, backbone="facebook/dinov2-base", out_channels=4, unfreeze_blocks=3, decoder_dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone)
        hidden_size = self.encoder.config.hidden_size
        # Normalization stats live on the module (follow .to(device)) but are
        # non-persistent, so they are not part of saved/loaded state_dicts.
        self.register_buffer('pixel_mean', torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1), persistent=False)
        self.register_buffer('pixel_std', torch.tensor(IMAGENET_STD).view(1, 3, 1, 1), persistent=False)

        # Freeze everything, then re-enable the last `unfreeze_blocks` layers
        # and the final layernorm.
        for param in self.encoder.parameters():
            param.requires_grad = False
        layers = self.encoder.encoder.layer
        num_blocks = len(layers)
        # Clamp so an oversized `unfreeze_blocks` simply unfreezes every layer.
        for i in range(max(0, num_blocks - unfreeze_blocks), num_blocks):
            for param in layers[i].parameters():
                param.requires_grad = True
        for param in self.encoder.layernorm.parameters():
            param.requires_grad = True

        self.decoder = DinoDecoder(hidden_size, out_channels, decoder_dropout)

    def forward(self, x):
        """Return logits of shape (B, out_channels, H, W) for an RGB batch ``x`` (B, 3, H, W).

        ``x`` is presumably scaled to [0, 1]; ImageNet normalization is applied here.
        """
        x = (x - self.pixel_mean) / self.pixel_std
        feats = self.encoder(pixel_values=x).last_hidden_state  # (B, tokens, C)
        B, n_tokens, C = feats.shape
        # Derive the patch grid from the input size instead of sqrt(tokens - 1),
        # which only worked for square inputs. The patch embedding is a
        # stride-`patch_size` conv, hence the floor division.
        patch = self.encoder.config.patch_size
        grid_h, grid_w = x.shape[2] // patch, x.shape[3] // patch
        # Keep the trailing grid_h * grid_w patch tokens. For this backbone that
        # drops the single leading CLS token, same as the original `[:, 1:]`.
        # An explicit start index avoids the `-0` slice trap.
        # NOTE(review): other backbones with extra prefix tokens should also
        # work this way — verify before relying on it.
        patch_tokens = feats[:, n_tokens - grid_h * grid_w:, :]
        fmap = patch_tokens.permute(0, 2, 1).reshape(B, C, grid_h, grid_w)
        return self.decoder(fmap, (x.shape[2], x.shape[3]))

# ==================== LOAD MODEL ====================
print("Loading model...")
model = DinoSegmenter(DINO_PATH, CHANNELS, UNFREEZE_BLOCKS, DECODER_DROPOUT)
model = model.to(device)

# NOTE(review): weights_only=False unpickles arbitrary Python objects; this is
# only acceptable while WEIGHTS_PATH points at a trusted checkpoint.
checkpoint = torch.load(WEIGHTS_PATH, map_location=device, weights_only=False)
# Accept either a training checkpoint dict wrapping the weights or a bare state_dict.
if 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model.eval()  # inference mode: BatchNorm running stats, dropout disabled
print("✅ Model loaded")

# ==================== INFERENCE ====================
@torch.no_grad()
def segment_prob_map(pil):
    """Run the global ``model`` on one RGB PIL image and return sigmoid probabilities.

    The image is resized to IMG_SIZE x IMG_SIZE and scaled by 1/255 before
    inference on ``device``. Returns a NumPy array of shape
    (out_channels, IMG_SIZE, IMG_SIZE).
    """
    resized = pil.resize((IMG_SIZE, IMG_SIZE))
    pixels = np.array(resized, dtype=np.float32) / 255.
    # HWC -> CHW, then add a batch axis of size 1.
    batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)
    logits = model(batch)
    return logits.sigmoid()[0].cpu().numpy()

def enhanced_adaptive_mask(prob, alpha_grad=0.35, thr_std=0.3, close_ksize=5, open_ksize=3):
    """Binarize a probability map with a gradient-boosted adaptive threshold.

    Steps:
      1. Blend the map with its normalized Sobel gradient magnitude
         (weight ``alpha_grad``) and apply a 3x3 Gaussian blur.
      2. Threshold at ``mean + thr_std * std`` of the blended map.
      3. Apply a morphological close, then an open.

    Args:
        prob: 2-D float map, e.g. one sigmoid channel from ``segment_prob_map``.
        alpha_grad: weight of the gradient term in the blend.
        thr_std: number of standard deviations above the mean used as threshold.
        close_ksize: side length of the square closing kernel.
        open_ksize: side length of the square opening kernel.

    Returns:
        ``(mask, thr)`` — a uint8 0/1 mask with ``prob``'s shape and the threshold used.
    """
    gx = cv2.Sobel(prob, cv2.CV_32F, 1, 0, ksize=3)
    gy = cv2.Sobel(prob, cv2.CV_32F, 0, 1, ksize=3)
    grad_mag = np.sqrt(gx**2 + gy**2)  # computed once, reused for normalization
    grad_norm = grad_mag / (grad_mag.max() + 1e-6)  # epsilon guards a flat map
    enhanced = cv2.GaussianBlur((1 - alpha_grad) * prob + alpha_grad * grad_norm, (3, 3), 0)
    thr = np.mean(enhanced) + thr_std * np.std(enhanced)
    mask = (enhanced > thr).astype(np.uint8)
    # Close fills small holes/gaps; open then removes isolated specks.
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((close_ksize, close_ksize), np.uint8))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((open_ksize, open_ksize), np.uint8))
    return mask, thr

def pipeline_final(pil, min_area=400, min_mean_prob=0.35):
    """Predict forgery masks for one image, keeping only confident, large regions.

    Each output channel of ``segment_prob_map`` is binarized with
    ``enhanced_adaptive_mask`` and resized back to the original image size.
    A channel's mask is kept when both hold:
      * it covers at least ``min_area`` pixels;
      * the mean full-resolution probability inside it is at least
        ``min_mean_prob``.

    Args:
        pil: RGB PIL image.
        min_area: minimum mask area in original-resolution pixels.
        min_mean_prob: minimum mean probability inside the mask.

    Returns:
        List of uint8 0/1 masks shaped (height, width) of ``pil``; may be empty.
    """
    probs = segment_prob_map(pil)
    out_size = pil.size  # PIL gives (width, height), the dsize order cv2.resize expects
    kept = []
    for ch_prob in probs:
        mask, _ = enhanced_adaptive_mask(ch_prob)
        mask = cv2.resize(mask, out_size, interpolation=cv2.INTER_NEAREST)
        area = int(mask.sum())
        # Skip the costly full-resolution probability resize when the mask is
        # empty or already too small to be kept.
        if area == 0 or area < min_area:
            continue
        prob_full = cv2.resize(ch_prob, out_size, interpolation=cv2.INTER_LINEAR)
        mean_inside = float(prob_full[mask == 1].mean())
        if mean_inside >= min_mean_prob:
            kept.append(mask)
    return kept

# ==================== RLE ENCODING ====================
def rle_encode_single(mask):
    """Run-length encode the pixels equal to 1 in ``mask``, in column-major order.

    Returns a JSON list string ``"[start1, len1, start2, len2, ...]"`` with
    1-based start positions into ``mask.T.flatten()``. Returns ``None`` when
    no pixel equals 1. Pixels with any other value (e.g. 255) count as
    background.
    """
    pixels = mask.T.flatten()
    # Pad with background on both ends so every run has a rising and a falling edge.
    fg = np.concatenate(([False], pixels == 1, [False]))
    edges = np.flatnonzero(fg[1:] != fg[:-1])
    if edges.size == 0:
        return None
    # Even edges are run starts (0-based pixel index); odd edges are one past each run's end.
    starts = edges[0::2]
    lengths = edges[1::2] - starts
    runs = np.column_stack((starts + 1, lengths)).ravel()
    return json.dumps([int(v) for v in runs])

def rle_encode_multi(masks):
    """Encode several masks and join their RLE strings with ';'.

    Any non-zero pixel counts as foreground. Masks with no foreground are
    dropped. Returns "authentic" when nothing remains.
    """
    parts = []
    for m in masks:
        rle = rle_encode_single((m > 0).astype(np.uint8))
        if rle is not None:
            parts.append(rle)
    if not parts:
        return "authentic"
    return ';'.join(parts)

# ==================== GENERATE SUBMISSION ====================
rows = []
test_files = sorted(os.listdir(TEST_DIR))
print(f"Processing {len(test_files)} test images...")

for f in test_files:
    img_path = Path(TEST_DIR) / f
    try:
        # The context manager closes the file handle Image.open keeps;
        # convert() returns an independent, fully loaded copy.
        with Image.open(img_path) as im:
            pil = im.convert("RGB")
    except OSError as err:
        # Unreadable / non-image entry (PIL raises OSError subclasses): record it
        # as authentic rather than aborting the whole submission.
        print(f"⚠️ Could not read {f}: {err}")
        annot = "authentic"
    else:
        # rle_encode_multi already returns "authentic" for an empty mask list.
        annot = rle_encode_multi(pipeline_final(pil))
    rows.append({"case_id": Path(f).stem, "annotation": annot})

# Explicit columns keep the merge below working even when no test files were found.
sub = pd.DataFrame(rows, columns=["case_id", "annotation"])
ss = pd.read_csv(SAMPLE_SUB)
ss["case_id"] = ss["case_id"].astype(str)
sub["case_id"] = sub["case_id"].astype(str)
# Align to the sample submission's ids/order; ids without a prediction default to authentic.
final = ss[["case_id"]].merge(sub, on="case_id", how="left")
final["annotation"] = final["annotation"].fillna("authentic")
final[["case_id", "annotation"]].to_csv("submission.csv", index=False)

print("\n✅ Saved submission.csv")
print(f"File exists: {os.path.exists('submission.csv')}")
print(f"Total: {len(final)} | Forged: {(final['annotation'] != 'authentic').sum()} | Authentic: {(final['annotation'] == 'authentic').sum()}")
