In [None]:
# ============================================================================
# 🔧 CONFIGURATION - Edit these variables to test different settings
# ============================================================================

# --- Model & Weights ---
USE_KAGGLE_WEIGHTS = False  # True -> KAGGLE_WEIGHTS_PATH, False -> LOCAL_WEIGHTS_PATH
KAGGLE_WEIGHTS_PATH = "/kaggle/input/recod-1219/best_model.pt"
LOCAL_WEIGHTS_PATH = "../checkpoints/best_model.pt"
DINO_PATH = "facebook/dinov2-base"  # or "/kaggle/input/dinov2/pytorch/base/1" on Kaggle
IMG_SIZE = 512          # images are resized to IMG_SIZE x IMG_SIZE before inference
CHANNELS = 4            # number of output channels produced by the decoder
UNFREEZE_BLOCKS = 3     # trailing encoder layers marked trainable (only affects requires_grad)
DECODER_DROPOUT = 0.05  # Dropout2d rate inside the decoder blocks (inactive after model.eval())

# --- Data ---
DATASET_ID = "eliplutchok/recod-finetune"
SAMPLE_SIZE = 100  # None (or 0) = evaluate on every sample

# --- Post-Processing (re-run evaluation cell after changing these) ---
USE_ENHANCED_ADAPTIVE = True  # False = simple fixed threshold (SIMPLE_THRESHOLD)
ALPHA_GRAD = 0.35             # weight of the normalized gradient-magnitude term in the enhanced map
GAUSSIAN_BLUR_SIZE = 3        # Gaussian kernel side; OpenCV requires a positive odd value
THRESHOLD_STD_MULT = 0.3      # adaptive threshold = mean + THRESHOLD_STD_MULT * std of the enhanced map
SIMPLE_THRESHOLD = 0.5
USE_MORPHOLOGY = True         # apply a closing then an opening to each binary mask
MORPH_CLOSE_KERNEL = 5
MORPH_OPEN_KERNEL = 3
MIN_AREA = 400                # minimum mask pixel count, measured at the original image size
MIN_MEAN_PROB = 0.35          # minimum mean probability inside a mask to keep it

if USE_KAGGLE_WEIGHTS:
    WEIGHTS_PATH = KAGGLE_WEIGHTS_PATH
else:
    WEIGHTS_PATH = LOCAL_WEIGHTS_PATH


In [None]:
# Setup, Model Definition & Loading
import os, cv2, math, random
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
from datasets import load_dataset
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class DinoDecoder(nn.Module):
    """Upsampling head that turns a patch-feature map into per-pixel logits.

    ``forward`` performs three x2 bilinear upsamplings, each followed by a conv block.
    It then resizes bilinearly to ``target_size``, applies a fourth conv block, and
    finishes with a 1x1 convolution to ``out_channels``.
    """

    def __init__(self, in_channels=768, out_channels=4, dropout=0.1):
        super().__init__()
        # The attribute names (up1..up4, final) and the Sequential layout inside each
        # block determine the state_dict keys, so they must stay stable for checkpoints.
        self.up1 = self._block(in_channels, 384, dropout)
        self.up2 = self._block(384, 192, dropout)
        self.up3 = self._block(192, 96, dropout)
        self.up4 = self._block(96, 48, dropout)
        self.final = nn.Conv2d(48, out_channels, kernel_size=1)

    def _block(self, in_ch, out_ch, dropout):
        """Conv-BN-ReLU, channel dropout, then Conv-BN-ReLU (3x3 convs, padding keeps size)."""
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, features, target_size):
        """Decode ``features`` (B, in_channels, h, w) into logits of spatial size ``target_size``."""
        out = features
        for stage in (self.up1, self.up2, self.up3):
            out = F.interpolate(out, scale_factor=2, mode='bilinear', align_corners=False)
            out = stage(out)
        # The last upsampling jumps straight to the requested output resolution.
        out = F.interpolate(out, size=target_size, mode='bilinear', align_corners=False)
        return self.final(self.up4(out))

class DinoSegmenter(nn.Module):
    """DINOv2 (Hugging Face ``AutoModel``) encoder + ``DinoDecoder`` segmentation head.

    Inputs are expected in [0, 1] with 3 channels. ImageNet mean/std normalization is
    applied inside ``forward``. The output has ``out_channels`` logit maps at the
    input's spatial size.
    """

    def __init__(self, backbone, out_channels=4, unfreeze_blocks=3, decoder_dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone)
        hidden_size = self.encoder.config.hidden_size
        # ImageNet statistics, stored as buffers so they follow .to(device) and the state_dict.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # Freeze the whole backbone, then re-enable the last `unfreeze_blocks` layers
        # and the final layernorm. The start index is clamped so oversized or negative
        # values behave sensibly (all layers / no layers).
        for p in self.encoder.parameters():
            p.requires_grad = False
        layers = self.encoder.encoder.layer
        for layer in layers[max(0, len(layers) - unfreeze_blocks):]:
            for p in layer.parameters():
                p.requires_grad = True
        for p in self.encoder.layernorm.parameters():
            p.requires_grad = True

        self.decoder = DinoDecoder(hidden_size, out_channels, decoder_dropout)

    def _patch_grid(self, x, num_tokens):
        """Return (rows, cols) of the ViT patch grid produced for input ``x``."""
        patch = getattr(self.encoder.config, "patch_size", None)
        if patch is None:
            # Fallback: assume one leading CLS token followed by a square grid.
            side = int(math.sqrt(num_tokens - 1))
            if side * side != num_tokens - 1:
                raise ValueError(f"cannot infer a square patch grid from {num_tokens} tokens")
            return side, side
        ph, pw = (patch, patch) if isinstance(patch, int) else tuple(patch)
        # The patch embedding is a stride-`patch` convolution, so the grid is floor(H/p) x floor(W/p).
        return x.shape[2] // ph, x.shape[3] // pw

    def forward(self, x):
        """Return decoder logits of shape (B, out_channels, H, W) for ``x`` of shape (B, 3, H, W)."""
        x_norm = (x - self.mean) / self.std
        feats = self.encoder(pixel_values=x_norm).last_hidden_state
        B, N, C = feats.shape
        gh, gw = self._patch_grid(x, N)
        n_patches = gh * gw
        if n_patches > N - 1:  # at least the CLS token must precede the patch tokens
            raise ValueError(f"expected more than {n_patches} tokens, got {N}")
        # Patch tokens are the trailing `n_patches` tokens. This skips the CLS token and
        # any register tokens, and works for non-square inputs (unlike sqrt(N-1)).
        patch_tokens = feats[:, N - n_patches:, :]
        fmap = patch_tokens.permute(0, 2, 1).reshape(B, C, gh, gw)
        target_size = (x.shape[2], x.shape[3])
        return self.decoder(fmap, target_size)

print(f"Device: {device}")
print(f"Loading weights from: {WEIGHTS_PATH}")
model = DinoSegmenter(DINO_PATH, CHANNELS, UNFREEZE_BLOCKS, DECODER_DROPOUT).to(device)

# SECURITY: weights_only=False fully unpickles the file and can execute arbitrary code.
# Only load checkpoints you produced or trust. Consider weights_only=True if the
# checkpoint contains nothing but tensors and plain containers.
ckpt = torch.load(WEIGHTS_PATH, map_location=device, weights_only=False)
if isinstance(ckpt, nn.Module):
    # A whole pickled model rather than a dict: take its parameters.
    state_dict = ckpt.state_dict()
elif 'model_state_dict' in ckpt:
    # Training checkpoint wrapping the weights alongside other state.
    state_dict = ckpt['model_state_dict']
else:
    # Bare state_dict.
    state_dict = ckpt
model.load_state_dict(state_dict)
model.eval()  # disables dropout and uses BatchNorm running statistics
print("✅ Model loaded")


In [None]:
# Load dataset and cache all model outputs (run once, then test different post-processing)
print(f"Loading dataset: {DATASET_ID}")
dataset = load_dataset(DATASET_ID, split="train")
n_total = len(dataset)
if SAMPLE_SIZE and SAMPLE_SIZE < n_total:
    # Reseed the global `random` module so the same subset is drawn on every run
    # (this also makes later global `random` calls in the notebook reproducible).
    random.seed(42)
    subset_indices = random.sample(range(n_total), SAMPLE_SIZE)
    dataset = dataset.select(subset_indices)
print(f"Testing on {len(dataset)} samples")

# Cache: one dict per sample with keys "probs", "size" (original image size) and "gt" (list of GT masks)
cache = []

@torch.no_grad()
def get_probs(pil_img):
    """Run the segmenter on one 3-channel (RGB) PIL image.

    The image is resized to IMG_SIZE x IMG_SIZE and scaled to [0, 1]. Returns the
    per-channel sigmoid probabilities as a numpy array of shape
    (CHANNELS, IMG_SIZE, IMG_SIZE).
    """
    resized = pil_img.resize((IMG_SIZE, IMG_SIZE))
    pixels = np.array(resized, np.float32) / 255.
    # HWC -> CHW, then add a batch dimension of 1.
    batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)
    logits = model(batch)
    return torch.sigmoid(logits)[0].cpu().numpy()

def _to_gt_masks(raw):
    """Normalize a dataset "mask" field into a list of uint8 numpy masks.

    A list of masks is converted element-wise. A single PIL image or ndarray is
    wrapped in a one-element list. The original code discarded a single mask
    (treating the sample as having no ground truth) and crashed on an ndarray's
    ambiguous truth value. Anything else (e.g. None) yields [].
    """
    if isinstance(raw, list):
        return [np.array(m).astype(np.uint8) for m in raw]
    if isinstance(raw, (Image.Image, np.ndarray)):
        return [np.array(raw).astype(np.uint8)]
    return []


print("Running model on all samples (this only needs to run once)...")
for idx in tqdm(range(len(dataset))):
    example = dataset[idx]
    img = example["image"].convert("RGB")
    original_size = img.size  # PIL (width, height)

    gt_masks = _to_gt_masks(example.get("mask"))

    probs = get_probs(img)
    cache.append({"probs": probs, "size": original_size, "gt": gt_masks})

print(f"✅ Cached {len(cache)} samples. Now you can re-run the evaluation cell with different settings.")


In [None]:
# ⚡ EVALUATE - Re-run this cell after changing post-processing settings in the configuration cell (first cell)
def process_probs(probs, size):
    """Apply the current post-processing settings to cached probability maps.

    Args:
        probs: per-channel probability maps, iterated along the first axis
            (one 2-D map per channel, at model resolution).
        size: target (width, height), passed to ``cv2.resize`` as ``dsize``.

    Returns:
        A list of uint8 {0, 1} masks resized to ``size``, one per channel that survives
        both filters:
        - at least MIN_AREA (and at least one) foreground pixels;
        - mean probability inside the mask >= MIN_MEAN_PROB.
    """
    # Loop-invariant work hoisted out of the per-channel loop.
    blur_ksize = (GAUSSIAN_BLUR_SIZE, GAUSSIAN_BLUR_SIZE)
    if USE_MORPHOLOGY:
        close_kernel = np.ones((MORPH_CLOSE_KERNEL, MORPH_CLOSE_KERNEL), np.uint8)
        open_kernel = np.ones((MORPH_OPEN_KERNEL, MORPH_OPEN_KERNEL), np.uint8)

    masks = []
    for p in probs:
        if USE_ENHANCED_ADAPTIVE:
            # Blend the probability map with its normalized gradient magnitude, blur,
            # then threshold adaptively at mean + k * std of the blended map.
            gx = cv2.Sobel(p, cv2.CV_32F, 1, 0, ksize=3)
            gy = cv2.Sobel(p, cv2.CV_32F, 0, 1, ksize=3)
            magnitude = np.sqrt(gx**2 + gy**2)  # computed once, reused for the max
            grad = magnitude / (magnitude.max() + 1e-6)
            enhanced = cv2.GaussianBlur((1 - ALPHA_GRAD) * p + ALPHA_GRAD * grad, blur_ksize, 0)
            thr = np.mean(enhanced) + THRESHOLD_STD_MULT * np.std(enhanced)
            mask = (enhanced > thr).astype(np.uint8)
        else:
            mask = (p > SIMPLE_THRESHOLD).astype(np.uint8)

        if USE_MORPHOLOGY:
            # Closing fills small holes; opening then removes small specks.
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, open_kernel)

        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
        area = mask.sum()
        # `area > 0` matters when MIN_AREA <= 0: an empty mask would give mean() of nothing.
        if area < MIN_AREA or area == 0:
            continue
        prob_resized = cv2.resize(p, size, interpolation=cv2.INTER_LINEAR)
        if prob_resized[mask == 1].mean() >= MIN_MEAN_PROB:
            masks.append(mask)
    return masks

def compute_metrics(pred, gt):
    """Pixel-level overlap metrics between a predicted and a ground-truth mask.

    Both inputs are coerced to bool, so 0/1 and 0/255 masks behave identically.
    The original ``~`` and ``.sum()`` logic was only correct for bool arrays.

    Returns:
        (iou, dice, precision, recall, f1) as floats. When both masks are empty the
        prediction is a perfect match and every metric is 1.0. Previously precision,
        recall and F1 were 0 in that case, contradicting IoU/Dice and penalizing
        correctly-empty predictions. Precision or recall with a zero denominator is 0.0.
    """
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)
    tp = int(np.logical_and(pred, gt).sum())
    fp = int(np.logical_and(pred, ~gt).sum())
    fn = int(np.logical_and(~pred, gt).sum())

    if tp + fp + fn == 0:  # both masks empty
        return 1.0, 1.0, 1.0, 1.0, 1.0

    iou = tp / (tp + fp + fn)           # |A ∩ B| / |A ∪ B|
    dice = 2 * tp / (2 * tp + fp + fn)  # 2|A ∩ B| / (|A| + |B|)
    prec = tp / (tp + fp) if tp + fp else 0.0
    rec = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    return iou, dice, prec, rec, f1

results = []
for idx, item in enumerate(cache):
    probs, size, gt_masks = item["probs"], item["size"], item["gt"]
    pred_masks = process_probs(probs, size)

    # `size` is (width, height); numpy masks are (height, width).
    h, w = size[1], size[0]
    pred = np.zeros((h, w), dtype=bool)
    for m in pred_masks:
        pred |= m.astype(bool)

    gt = np.zeros((h, w), dtype=bool)
    n_gt_skipped = 0
    for m in gt_masks:
        if m.shape == (h, w):
            gt |= m.astype(bool)
        else:
            # A shape-mismatched GT mask cannot be merged; count it so the omission
            # is reported below instead of silently changing the metrics.
            n_gt_skipped += 1

    iou, dice, prec, rec, f1 = compute_metrics(pred, gt)
    results.append({"idx": idx, "iou": iou, "dice": dice, "precision": prec, "recall": rec, "f1": f1,
                    "n_pred": len(pred_masks), "n_gt": len(gt_masks), "n_gt_skipped": n_gt_skipped})

df = pd.DataFrame(results)

# Print results
print("=" * 60)
print(f"Settings: {'Enhanced' if USE_ENHANCED_ADAPTIVE else 'Simple'} | Morph={USE_MORPHOLOGY} | Area≥{MIN_AREA} | Prob≥{MIN_MEAN_PROB}")
if USE_ENHANCED_ADAPTIVE:
    print(f"  Alpha={ALPHA_GRAD}, Blur={GAUSSIAN_BLUR_SIZE}, STD×{THRESHOLD_STD_MULT}")
else:
    print(f"  Threshold={SIMPLE_THRESHOLD}")
print("=" * 60)
for label, col in [("IoU", "iou"), ("Dice", "dice"), ("Precision", "precision"),
                   ("Recall", "recall"), ("F1", "f1")]:
    print(f"{label + ':':<11}{df[col].mean():.4f} ± {df[col].std():.4f}")
print(f"Samples with preds: {(df['n_pred'] > 0).sum()}/{len(df)}")
n_skipped_total = int(df["n_gt_skipped"].sum())
if n_skipped_total:
    print(f"WARNING: {n_skipped_total} GT mask(s) ignored because their shape did not match the image size")


In [None]:
# Visualize samples (worst, random, or specific index)
def viz(idx):
    """Plot image, GT union, prediction union and mean probability for cached sample `idx`.

    Uses the current post-processing settings via `process_probs`. The figure title
    reports IoU/Dice/F1 of the prediction union against the GT union.
    """
    entry = cache[idx]
    probs, size, gt_masks = entry["probs"], entry["size"], entry["gt"]
    pred_masks = process_probs(probs, size)
    img = dataset[idx]["image"].convert("RGB")

    width, height = size
    pred = np.zeros((height, width), dtype=np.uint8)
    for mask in pred_masks:
        pred |= mask
    gt = np.zeros((height, width), dtype=np.uint8)
    for mask in (m for m in gt_masks if m.shape == (height, width)):
        gt |= mask

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    axes[0].imshow(img)
    axes[1].imshow(gt, cmap="Reds")
    axes[2].imshow(pred, cmap="Blues")
    mean_prob = probs.mean(axis=0)  # average over channels
    axes[3].imshow(cv2.resize(mean_prob, size), cmap="viridis", vmin=0, vmax=1)
    titles = ["Image", f"GT ({len(gt_masks)})", f"Pred ({len(pred_masks)})", "Avg Prob"]
    for ax, title in zip(axes, titles):
        ax.set_title(title)
        ax.axis("off")

    iou, dice, _, _, f1 = compute_metrics(pred.astype(bool), gt.astype(bool))
    fig.suptitle(f"Sample {idx} | IoU={iou:.3f} | Dice={dice:.3f} | F1={f1:.3f}")
    plt.tight_layout()
    plt.show()

# Show worst, best, and a random sample
worst_idx = df.loc[df["f1"].idxmin(), "idx"]
best_idx = df.loc[df["f1"].idxmax(), "idx"]
# Show the lowest- and highest-F1 samples, then one random sample.
for label, sample_idx in (("Worst", worst_idx), ("Best", best_idx)):
    print(f"{label} sample:")
    viz(sample_idx)
print("Random sample:")
viz(random.randint(0, len(cache) - 1))


In [None]:
# Quick parameter sweep (optional - compares different threshold values)
def eval_with_params(use_enhanced, threshold=0.5, alpha=0.35, std_mult=0.3, min_area=400, min_prob=0.35,
                     blur_size=None, use_morphology=None):
    """Mean per-sample F1 over `cache` for one post-processing configuration.

    Mirrors the pipeline of the evaluation cell, with the knobs passed as arguments.
    The original version always skipped morphology and hard-coded a 3x3 blur, so its
    F1 did not match the configuration being tuned.

    Args:
        use_enhanced: gradient-enhanced adaptive threshold (True) or fixed `threshold` (False).
        threshold, alpha, std_mult, min_area, min_prob: same meaning as the config constants.
        blur_size: Gaussian kernel side for the enhanced map; None -> GAUSSIAN_BLUR_SIZE.
        use_morphology: apply closing/opening (MORPH_* kernels); None -> USE_MORPHOLOGY.
            Pass False to reproduce the old morphology-free sweep.

    Returns:
        np.mean of the per-sample F1 scores (nan if `cache` is empty).
    """
    if blur_size is None:
        blur_size = GAUSSIAN_BLUR_SIZE
    if use_morphology is None:
        use_morphology = USE_MORPHOLOGY
    if use_morphology:
        close_kernel = np.ones((MORPH_CLOSE_KERNEL, MORPH_CLOSE_KERNEL), np.uint8)
        open_kernel = np.ones((MORPH_OPEN_KERNEL, MORPH_OPEN_KERNEL), np.uint8)

    results = []
    for item in cache:
        probs, size, gt_masks = item["probs"], item["size"], item["gt"]
        masks = []
        for p in probs:
            if use_enhanced:
                gx = cv2.Sobel(p, cv2.CV_32F, 1, 0, ksize=3)
                gy = cv2.Sobel(p, cv2.CV_32F, 0, 1, ksize=3)
                magnitude = np.sqrt(gx**2 + gy**2)
                grad = magnitude / (magnitude.max() + 1e-6)
                enhanced = cv2.GaussianBlur((1 - alpha) * p + alpha * grad, (blur_size, blur_size), 0)
                mask = (enhanced > np.mean(enhanced) + std_mult * np.std(enhanced)).astype(np.uint8)
            else:
                mask = (p > threshold).astype(np.uint8)
            if use_morphology:
                mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel)
                mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, open_kernel)
            mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
            area = mask.sum()
            if area >= min_area and area > 0:
                prob_r = cv2.resize(p, size, interpolation=cv2.INTER_LINEAR)
                if prob_r[mask == 1].mean() >= min_prob:
                    masks.append(mask)

        # `size` is (width, height); numpy masks are (height, width).
        h, w = size[1], size[0]
        pred = np.zeros((h, w), dtype=bool)
        for m in masks:
            pred |= m.astype(bool)
        gt = np.zeros((h, w), dtype=bool)
        for m in gt_masks:
            if m.shape == (h, w):
                gt |= m.astype(bool)

        _, _, _, _, f1 = compute_metrics(pred, gt)
        results.append(f1)
    return np.mean(results)

# Compare simple thresholds (other settings use eval_with_params defaults)
print("Simple threshold sweep:")
for t in [0.3, 0.4, 0.5, 0.6, 0.7]:
    f1 = eval_with_params(use_enhanced=False, threshold=t)
    print(f"  threshold={t:.1f} → F1={f1:.4f}")

# Compare the std multiplier of the enhanced adaptive threshold
print("\nEnhanced adaptive sweep (STD mult):")
for s in [0.1, 0.2, 0.3, 0.4, 0.5]:
    f1 = eval_with_params(use_enhanced=True, std_mult=s)
    print(f"  std_mult={s:.1f} → F1={f1:.4f}")
