# In [None]:
import os, cv2, json, math
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoImageProcessor

# ==================== CONFIG ====================
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Input locations (Kaggle dataset mounts).
TEST_DIR = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/test_images"
SAMPLE_SUB = "/kaggle/input/recodai-luc-scientific-image-forgery-detection/sample_submission.csv"
WEIGHTS_PATH = "/kaggle/input/recod-1219/best_model_clean.pt"  # UPDATE THIS
DINO_PATH = "/kaggle/input/dinov2/pytorch/base/1"

# Model / preprocessing settings. CHANNELS must match the checkpoint's final
# conv layer; the others presumably mirror the training run — confirm.
IMG_SIZE = 512
CHANNELS = 4
UNFREEZE_BLOCKS = 4
DECODER_DROPOUT = 0.15

# ==================== MODEL ====================
class DinoDecoder(nn.Module):
    """Convolutional upsampling head mapping a patch-feature map to per-pixel logits.

    Three stages each double the spatial size (bilinear) and apply a conv block.
    A fourth resize goes straight to ``target_size`` before the last conv block.
    A 1x1 conv then projects to ``out_channels``.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: number of output logit channels.
        dropout: ``Dropout2d`` probability used inside every conv block.
    """

    def __init__(self, in_channels=768, out_channels=4, dropout=0.1):
        super().__init__()
        # The attribute names and the layer order inside each block form the
        # state_dict keys (e.g. "up1.4.weight"); keep them stable for checkpoints.
        self.up1 = self._block(in_channels, 384, dropout)
        self.up2 = self._block(384, 192, dropout)
        self.up3 = self._block(192, 96, dropout)
        self.up4 = self._block(96, 48, dropout)
        self.final = nn.Conv2d(48, out_channels, kernel_size=1)

    def _block(self, in_ch, out_ch, dropout):
        """Return conv-BN-ReLU, Dropout2d, conv-BN-ReLU packed in one Sequential."""
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
        ]
        layers += [
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, features, target_size):
        """Decode ``features`` (B, C, h, w) into logits of spatial size ``target_size``."""
        out = features
        for stage in (self.up1, self.up2, self.up3):
            out = F.interpolate(out, scale_factor=2, mode='bilinear', align_corners=False)
            out = stage(out)
        out = F.interpolate(out, size=target_size, mode='bilinear', align_corners=False)
        return self.final(self.up4(out))

class DinoSegmenter(nn.Module):
    """DINOv2 encoder followed by a ``DinoDecoder`` head producing per-pixel logits.

    Args:
        backbone: model id or local path given to ``from_pretrained`` for both
            the encoder and its image processor.
        out_channels: number of logit channels produced by the decoder.
        unfreeze_blocks: how many of the last encoder layers keep
            ``requires_grad=True``; all earlier encoder parameters are frozen.
            The encoder's final ``layernorm`` is always left trainable.
        decoder_dropout: ``Dropout2d`` probability inside the decoder.
    """

    def __init__(self, backbone="facebook/dinov2-base", out_channels=4, unfreeze_blocks=3, decoder_dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone)
        hidden_size = self.encoder.config.hidden_size
        # Plain attribute (not an nn.Module), so it adds nothing to state_dict.
        self.processor = AutoImageProcessor.from_pretrained(backbone, use_fast=True)
        for param in self.encoder.parameters():
            param.requires_grad = False
        num_blocks = len(self.encoder.encoder.layer)
        # Clamp at 0 so an over-large count cannot produce negative (wrap-around)
        # indices; the set of unfrozen layers is the same either way.
        first_trainable = max(num_blocks - unfreeze_blocks, 0)
        for i in range(first_trainable, num_blocks):
            for param in self.encoder.encoder.layer[i].parameters():
                param.requires_grad = True
        for param in self.encoder.layernorm.parameters():
            param.requires_grad = True
        self.decoder = DinoDecoder(hidden_size, out_channels, decoder_dropout)

    def forward_features(self, x):
        """Encode ``x`` into a (B, C, side, side) patch-feature map.

        Args:
            x: float tensor with values in [0, 1]. It is assumed to be (B, 3, H, W)
                channels-first RGB, because the permute below moves dim 1 last.

        Raises:
            ValueError: if the encoder's patch tokens do not form a square grid.
        """
        # Back to uint8 [0, 255] channels-last so the HF processor sees ordinary
        # images. round() before the cast: inputs are typically k/255 in float32,
        # and x*255 can land just below k (e.g. 254.99998); a plain truncating
        # cast would turn that into k-1.
        imgs = (x * 255).round().clamp(0, 255).byte().permute(0, 2, 3, 1)
        inputs = self.processor(images=imgs, return_tensors="pt")
        # NOTE(review): with the stock DINOv2 processor config this resizes and
        # center-crops, so the features may not cover the full input frame even
        # though the decoder resizes back to the full (H, W) — confirm vs training.
        # Move to same device/dtype as model
        inputs = {k: v.to(x.device, x.dtype) if v.is_floating_point() else v.to(x.device) for k, v in inputs.items()}
        feats = self.encoder(**inputs).last_hidden_state
        B, N, C = feats.shape
        # Token 0 is dropped; the remaining tokens are folded into a square grid.
        num_patches = N - 1
        side = math.isqrt(num_patches)
        if side * side != num_patches:
            raise ValueError(f"expected a square patch grid, got {num_patches} patch tokens")
        fmap = feats[:, 1:, :].permute(0, 2, 1).reshape(B, C, side, side)
        return fmap

    def forward(self, x):
        """Return decoder logits with the same spatial size (H, W) as ``x``."""
        target_size = (x.shape[2], x.shape[3])
        fmap = self.forward_features(x)
        return self.decoder(fmap, target_size)

# ==================== LOAD MODEL ====================
print("Loading model...")
model = DinoSegmenter(DINO_PATH, CHANNELS, UNFREEZE_BLOCKS, DECODER_DROPOUT)
model = model.to(device)
# weights_only=False unpickles arbitrary Python objects; only acceptable for a
# trusted checkpoint file — never point WEIGHTS_PATH at untrusted data.
checkpoint = torch.load(WEIGHTS_PATH, map_location=device, weights_only=False)
# Accept either a training checkpoint that wraps the weights or a bare state_dict.
if 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model.eval()
print("✅ Model loaded")

# ==================== INFERENCE WITH TTA ====================
@torch.no_grad()
def segment_prob_map_single(pil):
    """Run the model once on ``pil`` and return per-channel sigmoid probabilities.

    The image is resized to IMG_SIZE x IMG_SIZE, scaled to [0, 1] and passed as a
    batch of one. ``pil`` is assumed to be 3-channel RGB, since the array is
    permuted from HWC to CHW. The result is a NumPy array of shape
    (channels, IMG_SIZE, IMG_SIZE) on the CPU.
    """
    resized = pil.resize((IMG_SIZE, IMG_SIZE))
    pixels = np.array(resized, np.float32) / 255.
    batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)
    logits = model(batch)
    return torch.sigmoid(logits)[0].cpu().numpy()

@torch.no_grad()
def segment_prob_map_tta(pil):
    """Average predictions over identity, horizontal, vertical and double flips.

    Each flipped prediction is flipped back along the same axes before being
    summed, so all four maps align with the original image. Returns an array
    with the same (channels, H, W) shape as ``segment_prob_map_single``.
    """
    lr, tb = Image.FLIP_LEFT_RIGHT, Image.FLIP_TOP_BOTTOM
    # (PIL transposes applied in order, index that undoes them on a (C, H, W) map)
    variants = (
        ((), np.s_[:, :, :]),
        ((lr,), np.s_[:, :, ::-1]),
        ((tb,), np.s_[:, ::-1, :]),
        ((lr, tb), np.s_[:, ::-1, ::-1]),
    )
    total = None
    for ops, undo in variants:
        view = pil
        for op in ops:
            view = view.transpose(op)
        aligned = segment_prob_map_single(view)[undo]
        # Accumulate left to right: orig + hflip + vflip + both.
        total = aligned if total is None else total + aligned
    return total / 4.0

def threshold_mask(prob, thr=0.99, close_kernel=5, open_kernel=3):
    """Binarise a probability map and clean it up with morphology.

    Args:
        prob: probability map; expected to be 2-D (one channel of the model output).
        thr: pixels strictly greater than this become foreground.
        close_kernel: side length of the square kernel for ``MORPH_CLOSE``
            (dilate then erode: bridges small gaps / holes).
        open_kernel: side length of the square kernel for ``MORPH_OPEN``
            (erode then dilate: removes small isolated specks).

    Returns:
        uint8 array of 0/1 values with the same shape as ``prob``.
    """
    mask = (prob > thr).astype(np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((close_kernel, close_kernel), np.uint8))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((open_kernel, open_kernel), np.uint8))
    return mask

def pipeline_final(pil, thr=0.99, min_area=200):
    """Predict forgery masks for ``pil`` at its original resolution.

    Runs flip-TTA, thresholds each output channel with ``threshold_mask`` and
    resizes the result back to the image size (nearest neighbour, so the values
    stay 0/1). Channels whose mask has fewer than ``min_area`` foreground pixels
    are discarded.

    Args:
        pil: input image (RGB, as produced by the caller in this file).
        thr: probability threshold forwarded to ``threshold_mask``.
        min_area: minimum foreground pixel count, measured at original
            resolution, for a mask to be kept.

    Returns:
        List of uint8 0/1 masks shaped (height, width), one per kept channel;
        possibly empty.
    """
    probs = segment_prob_map_tta(pil)
    all_masks = []
    for channel_prob in probs:
        mask = threshold_mask(channel_prob, thr=thr)
        # PIL's size is (width, height), which is exactly cv2's dsize order.
        mask = cv2.resize(mask, pil.size, interpolation=cv2.INTER_NEAREST)
        if int(mask.sum()) >= min_area:
            all_masks.append(mask)
    return all_masks

# ==================== RLE ENCODING ====================
def rle_encode_single(mask):
    """Run-length encode the pixels equal to 1 in ``mask``.

    Pixels are scanned in column-major order (the mask is transposed before
    flattening). Run starts are 1-based. The result is a JSON list string
    ``"[start1, len1, start2, len2, ...]"``.

    Returns:
        The JSON string, or ``None`` when no pixel equals 1.
    """
    pixels = np.asarray(mask).T.flatten()
    fg = (pixels == 1).astype(np.int8)
    if not fg.any():
        return None
    # Pad with background on both sides so every run has a rising edge (start)
    # and a falling edge (one past its end); diff marks both.
    edges = np.flatnonzero(np.diff(np.concatenate(([0], fg, [0]))))
    starts, ends = edges[0::2], edges[1::2]
    runs = np.empty(2 * len(starts), dtype=np.int64)
    runs[0::2] = starts + 1  # 1-based start positions
    runs[1::2] = ends - starts
    return json.dumps(runs.tolist())

def rle_encode_multi(masks):
    """Encode several masks and join their non-empty encodings with ';'.

    Each mask is binarised with ``> 0`` before encoding. Returns "authentic" when
    no mask has any foreground pixel, which includes an empty ``masks``.
    """
    parts = []
    for m in masks:
        rle = rle_encode_single((m > 0).astype(np.uint8))
        if rle is not None:
            parts.append(rle)
    if not parts:
        return "authentic"
    return ';'.join(parts)

# ==================== GENERATE SUBMISSION ====================
rows = []
test_files = sorted(os.listdir(TEST_DIR))
print(f"Processing {len(test_files)} test images...")

for f in test_files:
    # The context manager closes the underlying file; convert() returns an
    # independent in-memory copy, so `pil` stays usable after the with-block.
    with Image.open(Path(TEST_DIR) / f) as im:
        pil = im.convert("RGB")
    masks = pipeline_final(pil)
    annot = rle_encode_multi(masks) if masks else "authentic"
    rows.append({"case_id": Path(f).stem, "annotation": annot})

# Explicit columns: with no rows, pd.DataFrame([]) has no "case_id" column and
# the astype/merge below would raise KeyError.
sub = pd.DataFrame(rows, columns=["case_id", "annotation"])
ss = pd.read_csv(SAMPLE_SUB)
ss["case_id"] = ss["case_id"].astype(str)
sub["case_id"] = sub["case_id"].astype(str)
# Keep the sample submission's ids and order; ids with no prediction become "authentic".
final = ss[["case_id"]].merge(sub, on="case_id", how="left")
final["annotation"] = final["annotation"].fillna("authentic")
final[["case_id", "annotation"]].to_csv("submission.csv", index=False)

import os
print(f"\n✅ Saved submission.csv")
print(f"File exists: {os.path.exists('submission.csv')}")
print(f"Total: {len(final)} | Forged: {(final['annotation'] != 'authentic').sum()} | Authentic: {(final['annotation'] == 'authentic').sum()}")
