## CSIRO Image2Biomass – DINOv3 Direct Model Inference

This notebook runs inference using DINOv3 Direct models trained with `dinov3_train.py`.

**Model:**
- `DINOv3Direct`: Predicts Total, Green, GDM directly; derives Dead, Clover
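- When Dead and Clover are derived (the default), Green + Dead + Clover equals Total by construction; with dedicated Dead/Clover heads those components are only capped, not forced to sum to Total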

**Features:**
- Multi-fold ensemble (5-fold CV)
- Test-Time Augmentation (TTA)
- MPS/CUDA/CPU support

In [None]:
import os
import gc
import json
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Tuple, Optional, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm


class CFG:
    """Configuration for DINOv3 Direct model inference.

    Every setting is a class attribute read as ``CFG.<NAME>``; the class is
    never instantiated in the visible code.

    * The path block below is selected once, at class-definition time, from
      ``LOCAL_TEST``.
    * ``DEVICE`` is reassigned right after the class from ``get_device()``.
    * ``main()`` overwrites the architecture fields (grid, dropout, heads,
      optional feature modules) from ``results.json`` when that file carries
      a ``config`` section.
    """

    # ==================== LOCAL TESTING MODE ====================
    # True  -> read train.csv / train images under ../data (local check).
    # False -> read the Kaggle test set under /kaggle/input.
    LOCAL_TEST = True

    if LOCAL_TEST:
        BASE_PATH = "../data"
        TEST_CSV = os.path.join(BASE_PATH, "train.csv")  # Use train.csv for local OOF testing
        TEST_IMAGE_DIR = os.path.join(BASE_PATH, "train")
        # Trained DINOv3 model directory (from dinov3_train.py)
        # NOTE(review): machine-specific absolute path; must be edited per user.
        MODEL_DIR = "/Users/kienvu/Desktop/kaggle/biomass/outputs/dinov3_20251218_174413"  # Update to your model dir
    else:
        BASE_PATH = "/kaggle/input/csiro-biomass"
        TEST_CSV = os.path.join(BASE_PATH, "test.csv")
        TEST_IMAGE_DIR = os.path.join(BASE_PATH, "test")
        MODEL_DIR = "/kaggle/input/your-dinov3-model"  # Update for Kaggle

    # DINOv3 backbone (fixed)
    # NOTE(review): build_dinov3_backbone() hard-codes the same name and does
    # not read this attribute in the visible code.
    BACKBONE = "vit_base_patch16_dinov3"

    # ==================== INFERENCE SETTINGS ====================
    SUBMISSION_FILE = "submission.csv"  # Output path written by create_submission()
    # Provisional value (CUDA or CPU only); replaced by get_device() after the
    # class body, which also considers MPS.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Updated below
    BATCH_SIZE = 1
    NUM_WORKERS = 0

    # Model architecture params (auto-loaded from results.json)
    # These are fallbacks; main() replaces them with the values stored in
    # results.json when present.
    DROPOUT = 0.3
    HIDDEN_RATIO = 0.25
    GRID = (2, 2)  # (rows, cols) of tiles each half-image is split into
    USE_FILM = True
    USE_ATTENTION_POOL = True
    TRAIN_DEAD = False  # Whether model has head_dead
    TRAIN_CLOVER = False  # Whether model has head_clover
    USE_VEGETATION_INDICES = False  # Whether model uses VI features
    USE_DISPARITY = False  # Whether model uses stereo disparity features
    USE_LEARNABLE_AUG = False  # Whether model uses learnable augmentation
    LEARNABLE_AUG_COLOR = True  # Learnable color augmentation
    LEARNABLE_AUG_SPATIAL = False  # Learnable spatial augmentation
    # Side length each half-image is resized to by the transforms.
    # NOTE(review): not among the fields main() reloads in the visible code.
    IMG_SIZE = 576  # DINOv3 default from dinov3_train.py

    # TTA settings
    # True -> run_inference_tta() averages the views from get_tta_transforms();
    # False -> a single get_val_transform() view.
    USE_TTA = True

    # Fold selection: None = use all folds, int = use top K folds by aR²
    TOP_K_FOLDS: Optional[int] = 3  # e.g., 3 to use best 3 folds

    # Target column names, in the column order of the prediction array
    # consumed by create_submission().
    ALL_TARGET_COLS = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]


# DINOv3 uses 256 native resolution but accepts any size divisible by 16.
# NOTE(review): this constant is not referenced in the visible part of the
# file; build_dinov3_backbone() returns its own literal 256.
DINOV3_NATIVE_RES = 256


def get_device() -> torch.device:
    """Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.

    The ``hasattr`` guard keeps this working on torch builds that have no
    ``torch.backends.mps`` attribute.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)


# Replace the provisional CFG.DEVICE (CUDA/CPU only) with the result of
# get_device(), which also considers MPS. Runs at import / cell execution.
CFG.DEVICE = get_device()

print(f"Device: {CFG.DEVICE}")


def load_config_from_results(model_dir: str) -> Dict:
    """Return the ``config`` section of ``<model_dir>/results.json``.

    Args:
        model_dir: Directory expected to contain ``results.json``.

    Returns:
        The value stored under ``"config"``, or an empty dict when the file
        does not exist or has no such key.
    """
    results_path = os.path.join(model_dir, "results.json")
    if not os.path.exists(results_path):
        return {}
    with open(results_path) as fh:
        payload = json.load(fh)
    return payload.get("config", {})


# ==================== DATASET ====================

class TestBiomassDataset(Dataset):
    """Inference dataset yielding the left and right halves of each image.

    Each row of ``df`` must provide ``sample_id_prefix``; the image is read
    from ``<image_dir>/<sample_id_prefix>.jpg``, converted BGR -> RGB and cut
    at the middle column (for an odd width the right half gets the extra
    column).
    """

    def __init__(self, df: pd.DataFrame, image_dir: str, transform: A.Compose):
        self.transform = transform
        self.image_dir = image_dir
        # Positional index so that iloc[idx] lines up with __len__.
        self.df = df.reset_index(drop=True)

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        """Return ``(left, right, sample_id)`` for row ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        sample_id = self.df.iloc[idx]["sample_id_prefix"]
        img_path = os.path.join(self.image_dir, f"{sample_id}.jpg")

        bgr = cv2.imread(img_path)
        if bgr is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Image not found: {img_path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        half = rgb.shape[1] // 2
        views = (rgb[:, :half, :], rgb[:, half:, :])

        if self.transform:
            # The transform runs once per half, left first, so any random
            # op in it is sampled independently for the two halves.
            views = tuple(self.transform(image=view)["image"] for view in views)

        img_left, img_right = views
        return img_left, img_right, sample_id


def get_val_transform(img_size: Optional[int] = None) -> A.Compose:
    """Build the deterministic single-view transform.

    Pipeline: square resize (INTER_AREA) -> ImageNet mean/std normalisation
    -> ``ToTensorV2``. Intended to mirror the validation transform of
    dinov3_train.py (not verifiable from this file).

    Args:
        img_size: Target side length; ``CFG.IMG_SIZE`` when ``None``.
    """
    size = CFG.IMG_SIZE if img_size is None else img_size
    steps = [
        A.Resize(size, size, interpolation=cv2.INTER_AREA),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)


def get_tta_transforms(img_size: Optional[int] = None) -> List[A.Compose]:
    """Build the three TTA views: original, horizontal flip, brightness/contrast.

    Every view resizes with INTER_AREA first, applies its view-specific op (if
    any), then normalises with ImageNet mean/std and converts to a tensor.

    NOTE(review): ``RandomBrightnessContrast`` presumably samples a fresh
    factor on every call, which would make the third view non-deterministic
    across runs (and across the two halves of one image) — confirm.

    Args:
        img_size: Target side length; ``CFG.IMG_SIZE`` when ``None``.

    Returns:
        ``[base, hflip, bright]`` in that order.
    """
    size = CFG.IMG_SIZE if img_size is None else img_size

    def _pipeline(*extra) -> A.Compose:
        # Shared skeleton; ``extra`` is spliced between resize and normalise.
        return A.Compose([
            A.Resize(size, size, interpolation=cv2.INTER_AREA),
            *extra,
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

    return [
        _pipeline(),
        _pipeline(A.HorizontalFlip(p=1.0)),
        _pipeline(
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0)
        ),
    ]


# ==================== MODEL COMPONENTS ====================

def build_dinov3_backbone(pretrained: bool = False) -> Tuple[nn.Module, int, int]:
    """Create the ``vit_base_patch16_dinov3`` timm model without a classifier.

    Args:
        pretrained: Forwarded to ``timm.create_model``.

    Returns:
        ``(model, feat_dim, input_res)`` where ``feat_dim`` is the model's
        ``num_features`` and ``input_res`` is the fixed tile resolution 256.
    """
    backbone = timm.create_model(
        "vit_base_patch16_dinov3", pretrained=pretrained, num_classes=0
    )
    return backbone, backbone.num_features, 256


def _make_edges(length: int, n_parts: int) -> List[Tuple[int, int]]:
    """Split ``range(length)`` into ``n_parts`` contiguous ``(start, stop)`` spans.

    All spans have ``length // n_parts`` elements except the last, which also
    absorbs the remainder so that the spans cover ``[0, length)`` exactly.

    Args:
        length: Size of the axis being split.
        n_parts: Number of spans to produce.

    Returns:
        List of half-open ``(start, stop)`` index pairs, in order.

    Raises:
        ValueError: if ``n_parts`` is smaller than 1, or ``length`` is smaller
            than ``n_parts`` (which would produce empty spans).
    """
    if n_parts < 1:
        raise ValueError(f"n_parts must be >= 1, got {n_parts}")
    if length < n_parts:
        raise ValueError(
            f"cannot split length {length} into {n_parts} non-empty parts"
        )
    step = length // n_parts
    edges = [(i * step, (i + 1) * step) for i in range(n_parts)]
    # Stretch the final span to the end so the remainder is not dropped.
    edges[-1] = (edges[-1][0], length)
    return edges


class FiLM(nn.Module):
    """Map a context vector to a pair of modulation vectors ``(gamma, beta)``.

    A two-layer MLP (hidden width ``max(64, in_dim // 2)``, GELU) produces
    ``2 * in_dim`` values that are split in half along dim 1.
    """

    def __init__(self, in_dim: int) -> None:
        super().__init__()
        width = max(64, in_dim // 2)
        layers = [
            nn.Linear(in_dim, width),
            nn.GELU(),
            nn.Linear(width, in_dim * 2),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, context: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(gamma, beta)``, each with ``in_dim`` columns."""
        gamma, beta = self.mlp(context).chunk(2, dim=1)
        return gamma, beta


class AttentionPooling(nn.Module):
    """Pool a token sequence into one vector with a single attention query.

    The query is a linear projection of the mean token; keys are a linear
    projection of every token. The softmax weights are applied to the raw
    (unprojected) tokens.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce dim 1 of ``x`` by attention-weighted averaging."""
        mean_token = x.mean(dim=1, keepdim=True)
        query = self.query(mean_token)
        keys = self.key(x)
        scores = torch.matmul(query, keys.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)
        pooled = torch.matmul(weights, x)
        return pooled.squeeze(1)


# ==================== INNOVATIVE FEATURES ====================

class VegetationIndices(nn.Module):
    """Summarise six RGB vegetation indices and project them linearly.

    Indices, in feature order: ExG, ExR, ExG-ExR, GRVI, normalised green, VARI.
    For each index four per-image statistics are taken (mean, std, 10th and
    90th percentile), giving 24 values that pass through ``Linear + GELU``.
    """

    def __init__(self, out_dim: int = 24) -> None:
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(24, out_dim), nn.GELU())

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Compute the projected index statistics.

        ``img`` is treated as ImageNet-normalised with channels on dim 1: it
        is de-normalised and clamped to [0, 1] before the indices are formed.
        """
        dev = img.device
        mean = torch.tensor([0.485, 0.456, 0.406], device=dev).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=dev).view(1, 3, 1, 1)
        rgb = (img * std + mean).clamp(0, 1)
        r, g, b = rgb.unbind(dim=1)

        exg = 2 * g - r - b
        exr = 1.4 * r - g
        exgr = exg - exr
        grvi = (g - r) / (g + r + 1e-6)
        norm_g = g / (r + g + b + 1e-6)
        vari = (g - r) / (g + r - b + 1e-6)

        # Feature layout is index-major: [mean, std, q10, q90] per index.
        stats = []
        for plane in (exg, exr, exgr, grvi, norm_g, vari):
            flat = plane.flatten(1)
            stats += [
                plane.mean(dim=(-2, -1)),
                plane.std(dim=(-2, -1)),
                flat.quantile(0.1, dim=1),
                flat.quantile(0.9, dim=1),
            ]
        return self.proj(torch.stack(stats, dim=1))


class DisparityFeatures(nn.Module):
    """Build left/right comparison features from two sets of tile features.

    Two groups are concatenated:

    * a correlation profile: mean cosine similarity between the left tiles and
      the right tiles rolled by ``0 .. max_disparity - 1`` positions along the
      tile axis, projected to ``out_dim``;
    * six pooled statistics (cosine, difference norm/mean/std, ratio
      mean/std), projected to ``out_dim // 2``.

    ``self.out_dim`` holds the total output width.
    """

    def __init__(self, feat_dim: int, max_disparity: int = 8, out_dim: int = None) -> None:
        super().__init__()
        if out_dim is None:
            out_dim = feat_dim // 4
        half = out_dim // 2
        self.max_disparity = max_disparity
        self.proj = nn.Sequential(nn.Linear(max_disparity, out_dim), nn.GELU())
        self.diff_proj = nn.Sequential(nn.Linear(6, half), nn.GELU())
        self.out_dim = out_dim + half

    def forward(self, feat_left: torch.Tensor, feat_right: torch.Tensor) -> torch.Tensor:
        """Return the concatenated correlation and difference features."""
        # Unpacking three sizes doubles as a check that the input is 3-D.
        _batch, _tiles, _dim = feat_left.shape

        left_unit = F.normalize(feat_left, dim=-1)
        right_unit = F.normalize(feat_right, dim=-1)
        per_shift = [
            (left_unit * torch.roll(right_unit, shifts=shift, dims=1)).sum(dim=-1).mean(dim=1)
            for shift in range(self.max_disparity)
        ]
        corr_feat = self.proj(torch.stack(per_shift, dim=-1))

        pooled_l = feat_left.mean(dim=1)
        pooled_r = feat_right.mean(dim=1)
        cosine = (F.normalize(pooled_l, dim=-1) * F.normalize(pooled_r, dim=-1)).sum(
            dim=-1, keepdim=True
        )
        delta = pooled_l - pooled_r
        quotient = pooled_l / (pooled_r + 1e-6)

        stats = torch.cat(
            [
                cosine,
                delta.norm(dim=-1, keepdim=True),
                delta.mean(dim=-1, keepdim=True),
                delta.std(dim=-1, keepdim=True),
                quotient.mean(dim=-1, keepdim=True),
                quotient.std(dim=-1, keepdim=True),
            ],
            dim=-1,
        )
        return torch.cat([corr_feat, self.diff_proj(stats)], dim=-1)


class LearnableAugmentation(nn.Module):
    """Parameter container for the training-time learnable augmentation.

    The parameters and predictor networks are created only so their names
    exist when a checkpoint is loaded; ``forward`` in this file always returns
    the input unchanged. ``color_strength``, ``spatial_strength`` and
    ``noise_std`` are accepted but unused here.
    """

    def __init__(
        self,
        enable_color: bool = True,
        enable_spatial: bool = False,
        enable_blur: bool = True,
        enable_local_contrast: bool = True,
        color_strength: float = 0.25,
        spatial_strength: float = 0.15,
        noise_std: float = 0.1,
    ) -> None:
        super().__init__()
        self.enable_color = enable_color
        self.enable_spatial = enable_spatial
        self.enable_blur = enable_blur
        self.enable_local_contrast = enable_local_contrast

        # Registration order (color, spatial, blur, contrast) is kept stable.
        if enable_color:
            self.color_params = nn.Parameter(torch.zeros(6))
            self.color_predictor = self._make_predictor(32, 6, nn.Tanh())

        if enable_spatial:
            self.spatial_params = nn.Parameter(torch.zeros(5))
            self.spatial_predictor = self._make_predictor(32, 5, nn.Tanh())

        if enable_blur:
            self.blur_params = nn.Parameter(torch.zeros(2))
            self.blur_predictor = self._make_predictor(16, 2, nn.Sigmoid())
            # 3x3 binomial kernel, normalised to sum to 1.
            binomial = torch.tensor([
                [1., 2., 1.],
                [2., 4., 2.],
                [1., 2., 1.],
            ])
            self.blur_kernel = nn.Parameter(binomial / 16.0)

        if enable_local_contrast:
            self.contrast_params = nn.Parameter(torch.zeros(2))
            self.contrast_predictor = self._make_predictor(16, 2, nn.Sigmoid())

    @staticmethod
    def _make_predictor(hidden: int, n_out: int, squash: nn.Module) -> nn.Sequential:
        """Global-average-pool a 3-channel image and map it to ``n_out`` values."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(3, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_out),
            squash,
        )

    def forward(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``img`` untouched plus a zero scalar on the same device."""
        zero = torch.tensor(0.0, device=img.device)
        return img, zero


# ==================== DINOv3 DIRECT MODEL ====================

class DINOv3Direct(nn.Module):
    """
    DINOv3 Direct Model for Biomass Prediction.

    Takes a left and a right image, tiles each into ``grid`` pieces, encodes
    every tile with the shared DINOv3 backbone, fuses the two views and
    predicts Total, Green and GDM directly. Dead and Clover are either
    predicted by their own heads (``train_dead`` / ``train_clover``) or
    derived as ``Total - GDM`` and ``GDM - Green``.

    When both Dead and Clover are derived, Green + Dead + Clover equals Total
    by construction. With dedicated heads they are only capped at the
    remaining budget (plus 1e-6), so the sum is not forced to equal Total.

    Submodule attribute names are what a checkpoint's state-dict keys must
    match (see ``load_fold_model``, which loads with ``strict=True``).
    """

    def __init__(
        self,
        grid: Tuple[int, int] = (2, 2),
        pretrained: bool = False,
        dropout: float = 0.3,
        hidden_ratio: float = 0.25,
        use_film: bool = True,
        use_attention_pool: bool = True,
        train_dead: bool = False,
        train_clover: bool = False,
        use_vegetation_indices: bool = False,
        use_disparity: bool = False,
        use_learnable_aug: bool = False,
        learnable_aug_color: bool = True,
        learnable_aug_spatial: bool = False,
    ) -> None:
        """Build the backbone, optional feature modules and prediction heads.

        Args:
            grid: (rows, cols) tiling applied to each view.
            pretrained: Forwarded to the backbone factory.
            dropout: Dropout after the shared projection; heads use half of it.
            hidden_ratio: Hidden width as a fraction of ``2 * feat_dim``
                (floored at 64).
            use_film: Create the cross-view FiLM modules.
            use_attention_pool: Pool tiles with attention instead of a mean.
            train_dead: Create a dedicated Dead head.
            train_clover: Create a dedicated Clover head.
            use_vegetation_indices: Append vegetation-index features per view.
            use_disparity: Append left/right disparity features.
            use_learnable_aug: Create the (inference-inactive) augmentation
                modules so their checkpoint keys can be loaded.
            learnable_aug_color: Enables the color, blur and local-contrast
                parts of the augmentation modules.
            learnable_aug_spatial: Enables the spatial part.
        """
        super().__init__()

        # Build DINOv3 backbone
        self.backbone, self.feat_dim, self.input_res = build_dinov3_backbone(pretrained)
        self.grid = tuple(grid)
        self.use_film = use_film
        self.use_attention_pool = use_attention_pool
        self.train_dead = train_dead
        self.train_clover = train_clover
        self.use_vegetation_indices = use_vegetation_indices
        self.use_disparity = use_disparity
        self.use_learnable_aug = use_learnable_aug

        # Learnable augmentation (identity at inference)
        # Includes all strong augmentations: color, spatial, blur, CLAHE
        # These modules are never called in forward() below; they exist so a
        # checkpoint containing their weights loads under strict=True.
        # Blur and local contrast are switched by the same flag as color.
        if use_learnable_aug:
            self.learnable_aug_left = LearnableAugmentation(
                enable_color=learnable_aug_color,
                enable_spatial=learnable_aug_spatial,
                enable_blur=learnable_aug_color,
                enable_local_contrast=learnable_aug_color,
            )
            self.learnable_aug_right = LearnableAugmentation(
                enable_color=learnable_aug_color,
                enable_spatial=learnable_aug_spatial,
                enable_blur=learnable_aug_color,
                enable_local_contrast=learnable_aug_color,
            )

        # FiLM for cross-view conditioning
        if use_film:
            self.film_left = FiLM(self.feat_dim)
            self.film_right = FiLM(self.feat_dim)

        # Attention pooling for tiles
        if use_attention_pool:
            self.attn_pool_left = AttentionPooling(self.feat_dim)
            self.attn_pool_right = AttentionPooling(self.feat_dim)

        # === Optional feature modules ===
        # extra_dim accumulates the width the optional features add to the
        # concatenated feature vector.
        extra_dim = 0
        if use_vegetation_indices:
            vi_out = 24
            self.vi_left = VegetationIndices(out_dim=vi_out)
            self.vi_right = VegetationIndices(out_dim=vi_out)
            extra_dim += vi_out * 2

        if use_disparity:
            self.disparity_module = DisparityFeatures(self.feat_dim, max_disparity=8, out_dim=self.feat_dim // 4)
            extra_dim += self.disparity_module.out_dim

        # Head dimensions
        # The hidden width depends only on feat_dim, not on extra_dim.
        combined_dim = self.feat_dim * 2 + extra_dim
        hidden_dim = max(64, int((self.feat_dim * 2) * hidden_ratio))

        # Shared projection
        self.shared_proj = nn.Sequential(
            nn.LayerNorm(combined_dim),
            nn.Linear(combined_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Prediction heads
        def _make_head() -> nn.Sequential:
            # Two-layer MLP ending in a single scalar output.
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(hidden_dim, 1),
            )

        self.head_total = _make_head()
        self.head_green = _make_head()
        self.head_gdm = _make_head()

        # Optional heads for Dead and Clover
        self.head_dead = _make_head() if train_dead else None
        self.head_clover = _make_head() if train_clover else None

        # Softplus keeps every raw head output strictly positive.
        self.softplus = nn.Softplus(beta=1.0)

    def _collect_tiles(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Split image into grid of tiles.

        Tiles are produced row by row. Any tile whose spatial size differs
        from ``(input_res, input_res)`` is resized to it with bilinear
        interpolation, so every returned tile has the same shape.
        """
        _, _, H, W = x.shape
        r, c = self.grid
        rows = _make_edges(H, r)
        cols = _make_edges(W, c)

        tiles = []
        for rs, re in rows:
            for cs, ce in cols:
                tile = x[:, :, rs:re, cs:ce]
                if tile.shape[-2:] != (self.input_res, self.input_res):
                    tile = F.interpolate(
                        tile,
                        size=(self.input_res, self.input_res),
                        mode="bilinear",
                        align_corners=False,
                    )
                tiles.append(tile)
        return tiles

    def _extract_features(
        self, x_left: torch.Tensor, x_right: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract tile features from both views in one backbone call.

        Returns:
            ``(feats_left, feats_right)``, each ``(B, num_tiles, D)``.
        """
        B = x_left.size(0)

        tiles_left = self._collect_tiles(x_left)
        tiles_right = self._collect_tiles(x_right)
        num_tiles = len(tiles_left)

        # Process all tiles in one forward pass
        # Concatenation along dim 0 is tile-major: all B samples of left tile
        # 0, then left tile 1, ..., then the right tiles.
        all_tiles = torch.cat(tiles_left + tiles_right, dim=0)
        # assumes the backbone returns one feature vector per input image
        # (it is built with num_classes=0) — the view below relies on that.
        all_feats = self.backbone(all_tiles)

        # Reshape
        # (2*num_tiles*B, D) -> (2*num_tiles, B, D) -> (B, 2*num_tiles, D),
        # undoing the tile-major stacking above.
        all_feats = all_feats.view(2 * num_tiles, B, -1).permute(1, 0, 2)
        feats_left = all_feats[:, :num_tiles, :]
        feats_right = all_feats[:, num_tiles:, :]

        return feats_left, feats_right

    def forward(
        self, x_left: torch.Tensor, x_right: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass.

        Args:
            x_left: Batch of left-view images, ``(B, C, H, W)``.
            x_right: Batch of right-view images, same layout.

        Returns:
            ``(green, dead, clover, gdm, total)``; each comes from a head with
            a single output unit and is non-negative.
        """
        # Extract tile features
        tiles_left, tiles_right = self._extract_features(x_left, x_right)

        # Stereo Disparity Features (before FiLM)
        disp_feat = None
        if self.use_disparity:
            disp_feat = self.disparity_module(tiles_left, tiles_right)

        # Context for FiLM
        ctx_left = tiles_left.mean(dim=1)
        ctx_right = tiles_right.mean(dim=1)

        # Apply FiLM cross-conditioning
        # Each view is modulated by parameters predicted from the OTHER
        # view's context; "1 + gamma" makes gamma == 0 an identity scale.
        if self.use_film:
            gamma_l, beta_l = self.film_left(ctx_right)
            gamma_r, beta_r = self.film_right(ctx_left)
            tiles_left = tiles_left * (1 + gamma_l.unsqueeze(1)) + beta_l.unsqueeze(1)
            tiles_right = tiles_right * (1 + gamma_r.unsqueeze(1)) + beta_r.unsqueeze(1)

        # Pool tiles
        if self.use_attention_pool:
            f_left = self.attn_pool_left(tiles_left)
            f_right = self.attn_pool_right(tiles_right)
        else:
            f_left = tiles_left.mean(dim=1)
            f_right = tiles_right.mean(dim=1)

        # Combine features
        # Order must match combined_dim in __init__: left, right, VI, disparity.
        features_list = [f_left, f_right]

        if self.use_vegetation_indices:
            vi_left = self.vi_left(x_left)
            vi_right = self.vi_right(x_right)
            features_list.extend([vi_left, vi_right])

        if disp_feat is not None:
            features_list.append(disp_feat)

        f = torch.cat(features_list, dim=1)
        f = self.shared_proj(f)

        # Core predictions
        total_raw = self.softplus(self.head_total(f))
        green_raw = self.softplus(self.head_green(f))
        gdm_raw = self.softplus(self.head_gdm(f))

        # Enforce constraints: Total >= GDM >= Green
        total = total_raw
        gdm = torch.minimum(gdm_raw, total)
        green = torch.minimum(green_raw, gdm)

        # Dead: predicted or derived
        if self.head_dead is not None:
            dead_raw = self.softplus(self.head_dead(f))
            # Cap at the remaining budget Total - GDM (plus a small epsilon).
            dead = torch.minimum(dead_raw, total - gdm + 1e-6)
            dead = F.relu(dead)
        else:
            dead = F.relu(total - gdm)

        # Clover: predicted or derived
        if self.head_clover is not None:
            clover_raw = self.softplus(self.head_clover(f))
            # Cap at the remaining budget GDM - Green (plus a small epsilon).
            clover = torch.minimum(clover_raw, gdm - green + 1e-6)
            clover = F.relu(clover)
        else:
            clover = F.relu(gdm - green)

        return green, dead, clover, gdm, total


# ==================== MODEL LOADING ====================

def _strip_module_prefix(sd: dict) -> dict:
    """Drop a leading ``module.`` from every key of a state dict.

    The prefix is removed only when *all* keys carry it (as in checkpoints
    saved from a wrapped model). Otherwise, and for an empty or ``None``
    input, the argument is returned unchanged.
    """
    prefix = "module."
    if not sd or not all(key.startswith(prefix) for key in sd):
        return sd
    cut = len(prefix)
    return {key[cut:]: value for key, value in sd.items()}


def _detect_model_config(sd_keys: set) -> dict:
    """Infer the constructor flags of ``DINOv3Direct`` from state-dict keys.

    Each flag is true when at least one key starts with the matching
    submodule prefix. The color/spatial augmentation flags look for their
    parameter names only among the ``learnable_aug_left.`` keys.

    Returns:
        Dict whose keys match the boolean keyword arguments of
        ``DINOv3Direct`` / ``load_fold_model``.
    """
    def has_prefix(prefix: str) -> bool:
        return any(key.startswith(prefix) for key in sd_keys)

    aug_keys = [key for key in sd_keys if key.startswith("learnable_aug_left.")]

    return {
        "use_film": has_prefix("film_left."),
        "use_attention_pool": has_prefix("attn_pool_left."),
        "train_dead": has_prefix("head_dead."),
        "train_clover": has_prefix("head_clover."),
        "use_vegetation_indices": has_prefix("vi_left."),
        "use_disparity": has_prefix("disparity_module."),
        "use_learnable_aug": bool(aug_keys),
        "learnable_aug_color": any("color_params" in key for key in aug_keys),
        "learnable_aug_spatial": any("spatial_params" in key for key in aug_keys),
    }


def load_fold_model(
    path: str,
    grid: Tuple[int, int] = (2, 2),
    dropout: float = 0.3,
    hidden_ratio: float = 0.25,
    use_film: bool = True,
    use_attention_pool: bool = True,
    train_dead: bool = False,
    train_clover: bool = False,
    use_vegetation_indices: bool = False,
    use_disparity: bool = False,
    use_learnable_aug: bool = False,
    learnable_aug_color: bool = True,
    learnable_aug_spatial: bool = False,
) -> nn.Module:
    """Load one ``DINOv3Direct`` checkpoint, ready for inference.

    The boolean architecture arguments are only starting values: every one
    of them is overridden by what ``_detect_model_config`` finds in the
    checkpoint keys. ``grid``, ``dropout`` and ``hidden_ratio`` cannot be
    detected and are used as passed.

    Args:
        path: Path of the ``.pth`` state-dict file.

    Returns:
        The model on ``CFG.DEVICE`` in eval mode.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint not found: {path}")

    # Prefer the restricted loader; fall back when this torch version does
    # not accept the ``weights_only`` keyword.
    try:
        raw_sd = torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        raw_sd = torch.load(path, map_location="cpu")

    sd = _strip_module_prefix(raw_sd)

    arch = {
        "use_film": use_film,
        "use_attention_pool": use_attention_pool,
        "train_dead": train_dead,
        "train_clover": train_clover,
        "use_vegetation_indices": use_vegetation_indices,
        "use_disparity": use_disparity,
        "use_learnable_aug": use_learnable_aug,
        "learnable_aug_color": learnable_aug_color,
        "learnable_aug_spatial": learnable_aug_spatial,
    }
    # Checkpoint contents win over the caller-supplied flags.
    arch.update(_detect_model_config(set(sd.keys())))

    # DINOv3Direct model (from dinov3_train.py)
    model = DINOv3Direct(
        grid=grid,
        pretrained=False,
        dropout=dropout,
        hidden_ratio=hidden_ratio,
        **arch,
    )

    # strict=True: any key mismatch raises instead of loading partially.
    model.load_state_dict(sd, strict=True)
    model = model.to(CFG.DEVICE)
    model.eval()
    return model


def find_model_checkpoints(model_dir: str, top_k: Optional[int] = None) -> Tuple[List[str], List[int]]:
    """
    Find DINOv3 model checkpoints (dinov3_best_fold<N>.pth).

    Files that share the prefix/suffix but whose middle part is not an
    integer (e.g. ``dinov3_best_fold_ema.pth``) are ignored. When ``top_k``
    is given, folds are ranked by the score in ``results.json`` and only
    folds that actually have a checkpoint on disk take part in the ranking.

    Args:
        model_dir: Directory containing model checkpoints
        top_k: If set, return only top K folds by aR² (from results.json)

    Returns:
        Tuple of (checkpoint_paths, fold_indices), both sorted by fold number

    Raises:
        ValueError: if ``top_k`` is given but smaller than 1.
    """
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be a positive integer or None, got {top_k}")

    prefix, suffix = "dinov3_best_fold", ".pth"

    def _split(items: List[Tuple[int, str]]) -> Tuple[List[str], List[int]]:
        # (fold, path) pairs -> (paths, folds)
        return [path for _, path in items], [fold for fold, _ in items]

    # Find all checkpoints
    all_checkpoints = []
    for fname in os.listdir(model_dir):
        if not (fname.startswith(prefix) and fname.endswith(suffix)):
            continue
        try:
            fold_num = int(fname[len(prefix):-len(suffix)])
        except ValueError:
            # Same naming pattern but not a plain fold checkpoint; skip it.
            continue
        all_checkpoints.append((fold_num, os.path.join(model_dir, fname)))

    all_checkpoints.sort(key=lambda x: x[0])  # Sort by fold number

    if top_k is None or top_k >= len(all_checkpoints):
        return _split(all_checkpoints)

    # Load results.json to get fold metrics
    results_path = os.path.join(model_dir, "results.json")
    if not os.path.exists(results_path):
        print(f"Warning: results.json not found, using all {len(all_checkpoints)} folds")
        return _split(all_checkpoints)

    with open(results_path) as f:
        results = json.load(f)

    fold_results = results.get("folds", [])
    if not fold_results:
        print(f"Warning: No fold results in results.json, using all {len(all_checkpoints)} folds")
        return _split(all_checkpoints)

    # Rank only folds that have a checkpoint; otherwise a fold listed in
    # results.json without a file would consume one of the top-K slots.
    available = {fold_num for fold_num, _ in all_checkpoints}
    fold_metrics = []
    for fr in fold_results:
        fold_num = fr.get("fold")
        if fold_num not in available:
            continue
        # Prefer metrics.avg_r2; fall back to best_r2, then 0.
        metrics = fr.get("metrics") or {}
        ar2 = metrics.get("avg_r2", fr.get("best_r2", 0))
        fold_metrics.append((fold_num, 0 if ar2 is None else ar2))

    if not fold_metrics:
        print(
            "Warning: No fold results in results.json match the checkpoints, "
            f"using all {len(all_checkpoints)} folds"
        )
        return _split(all_checkpoints)

    # Sort by metric (descending) and take top K
    fold_metrics.sort(key=lambda x: x[1], reverse=True)
    top_folds = {fold_num for fold_num, _ in fold_metrics[:top_k]}

    print(f"Selecting top {top_k} folds by aR²:")
    for fold_num, ar2 in fold_metrics[:top_k]:
        print(f"  Fold {fold_num}: aR²={ar2:.4f}")

    # Filter checkpoints to top folds (all_checkpoints is already fold-sorted)
    selected = [(fold_num, path) for fold_num, path in all_checkpoints if fold_num in top_folds]
    return _split(selected)


# ==================== INFERENCE ====================

@torch.no_grad()
def predict_one_view(models: List[nn.Module], loader: DataLoader) -> np.ndarray:
    """Predict every batch of ``loader`` and average across ``models``.

    Args:
        models: Fold models; each is called as ``model(x_left, x_right)``.
        loader: Yields ``(x_left, x_right, sample_id)`` batches.

    Returns:
        Array with one row per sample and columns in the order
        green, dead, clover, gdm, total.
    """
    batch_means = []

    for x_left, x_right, _ in tqdm(loader, desc="Predicting", leave=False):
        x_left = x_left.to(CFG.DEVICE)
        x_right = x_right.to(CFG.DEVICE)

        per_model = []
        for model in models:
            outputs = model(x_left, x_right)
            # A 6-tuple carries an extra trailing aux-loss entry; drop it.
            if len(outputs) == 6:
                outputs = outputs[:5]
            green, dead, clover, gdm, total = outputs
            stacked = torch.cat([green, dead, clover, gdm, total], dim=1)
            per_model.append(stacked.float().cpu().numpy())

        batch_means.append(np.mean(per_model, axis=0))

    return np.concatenate(batch_means, axis=0)


def run_inference_tta(models: List[nn.Module], df: pd.DataFrame) -> np.ndarray:
    """Predict ``df`` once per TTA view and return the mean over views.

    With ``CFG.USE_TTA`` off, the single validation transform is used, so the
    mean is over one view. A fresh dataset and loader are built per view from
    ``CFG.TEST_IMAGE_DIR``, ``CFG.BATCH_SIZE`` and ``CFG.NUM_WORKERS``.
    """
    if CFG.USE_TTA:
        transforms = get_tta_transforms()
    else:
        transforms = [get_val_transform()]

    per_view = []
    for i, transform in enumerate(transforms):
        print(f"TTA view {i+1}/{len(transforms)}...")
        dataset = TestBiomassDataset(df, CFG.TEST_IMAGE_DIR, transform)
        loader = DataLoader(
            dataset,
            batch_size=CFG.BATCH_SIZE,
            shuffle=False,  # keep row order aligned with df
            num_workers=CFG.NUM_WORKERS,
        )
        per_view.append(predict_one_view(models, loader))

    return np.mean(per_view, axis=0)


def create_submission(final_5: np.ndarray, test_long: pd.DataFrame, test_unique: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Turn per-image predictions into the long submission table and save it.

    Args:
        final_5: Predictions with columns green, dead, clover, gdm, total; row
            order must match ``test_unique``.
        test_long: Long-format table with ``sample_id``, ``image_path`` and
            ``target_name``; it defines the output row order.
        test_unique: One row per image, providing ``image_path``.

    Returns:
        ``(sub, wide)``: the saved ``sample_id``/``value`` table and the
        one-row-per-image prediction table.

    Side effects:
        Writes ``CFG.SUBMISSION_FILE`` and prints a preview.
    """
    def nnz(x: np.ndarray) -> np.ndarray:
        # Non-finite values become 0, then negatives are clipped to 0.
        finite = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        return np.maximum(0, finite)

    column_names = ("Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g")
    cleaned = {name: nnz(final_5[:, col]) for col, name in enumerate(column_names)}

    wide = pd.DataFrame({"image_path": test_unique["image_path"], **cleaned})

    long_preds = wide.melt(
        id_vars=["image_path"],
        value_vars=CFG.ALL_TARGET_COLS,
        var_name="target_name",
        value_name="target",
    )

    # Left merge keeps every test_long row, in test_long order.
    keys = test_long[["sample_id", "image_path", "target_name"]]
    merged = pd.merge(keys, long_preds, on=["image_path", "target_name"], how="left")
    sub = merged[["sample_id", "target"]]

    # Rows without a matching prediction come back as NaN; write 0 instead.
    sub["target"] = np.nan_to_num(sub["target"], nan=0.0, posinf=0.0, neginf=0.0)
    sub.columns = ["sample_id", "value"]  # Rename for submission format
    sub.to_csv(CFG.SUBMISSION_FILE, index=False)
    print(f"Saved: {CFG.SUBMISSION_FILE}")
    print(sub.head(10))
    return sub, wide


def _apply_saved_config(config: dict) -> None:
    """Copy the architecture settings found in results.json onto CFG."""
    # dinov3_train.py saves grid as a single int. When the key is missing,
    # keep the configured CFG.GRID instead of forcing a hard-coded 2.
    grid_val = config.get("grid", CFG.GRID)
    CFG.GRID = (grid_val, grid_val) if isinstance(grid_val, int) else tuple(grid_val)

    # Each results.json key maps to the CFG attribute with the upper-cased
    # name; a missing key leaves the current CFG value untouched.
    for key in (
        "dropout",
        "hidden_ratio",
        "use_film",
        "use_attention_pool",
        "train_dead",
        "train_clover",
        "use_vegetation_indices",
        "use_disparity",
        "use_learnable_aug",
        "learnable_aug_color",
        "learnable_aug_spatial",
    ):
        attr = key.upper()
        setattr(CFG, attr, config.get(key, getattr(CFG, attr)))
    print("Loaded config from results.json")


def _print_model_summary() -> None:
    """Print the effective model settings held in CFG."""
    print("Model: DINOv3Direct")
    # Read the name from CFG so the log cannot drift from the configuration.
    print(f"Backbone: {CFG.BACKBONE}")
    print(f"Grid: {CFG.GRID}")
    print(f"Image size: {CFG.IMG_SIZE}")
    if CFG.TRAIN_DEAD or CFG.TRAIN_CLOVER:
        print(f"Optional heads: train_dead={CFG.TRAIN_DEAD}, train_clover={CFG.TRAIN_CLOVER}")

    extras = []
    if CFG.USE_VEGETATION_INDICES:
        extras.append("Vegetation Indices")
    if CFG.USE_DISPARITY:
        extras.append("Stereo Disparity")
    if CFG.USE_LEARNABLE_AUG:
        aug_types = []
        if CFG.LEARNABLE_AUG_COLOR:
            aug_types.append("color")
        if CFG.LEARNABLE_AUG_SPATIAL:
            aug_types.append("spatial")
        extras.append(f"Learnable Aug ({'+'.join(aug_types)})")
    if extras:
        print(f"Innovative features: {', '.join(extras)}")


def _load_models(checkpoints) -> list:
    """Load one model per checkpoint path using the settings in CFG."""
    models = []
    for ckpt in checkpoints:
        models.append(
            load_fold_model(
                ckpt,
                grid=CFG.GRID,
                dropout=CFG.DROPOUT,
                hidden_ratio=CFG.HIDDEN_RATIO,
                use_film=CFG.USE_FILM,
                use_attention_pool=CFG.USE_ATTENTION_POOL,
                train_dead=CFG.TRAIN_DEAD,
                train_clover=CFG.TRAIN_CLOVER,
                use_vegetation_indices=CFG.USE_VEGETATION_INDICES,
                use_disparity=CFG.USE_DISPARITY,
                use_learnable_aug=CFG.USE_LEARNABLE_AUG,
                learnable_aug_color=CFG.LEARNABLE_AUG_COLOR,
                learnable_aug_spatial=CFG.LEARNABLE_AUG_SPATIAL,
            )
        )
        print(f"  Loaded: {os.path.basename(ckpt)}")
    return models


def _load_test_frames() -> tuple[pd.DataFrame, pd.DataFrame]:
    """Read CFG.TEST_CSV and return (long frame, one-row-per-image frame).

    Raises:
        ValueError: If ``sample_id_prefix`` is absent and neither
            ``image_path`` nor ``sample_id`` exists to derive it from.
    """
    test_long = pd.read_csv(CFG.TEST_CSV)

    # Extract sample_id_prefix for dataset compatibility
    if "sample_id_prefix" not in test_long.columns:
        if "image_path" in test_long.columns:
            test_long["sample_id_prefix"] = test_long["image_path"].str.extract(r'([A-Z]+\d+)')[0]
        elif "sample_id" in test_long.columns:
            test_long["sample_id_prefix"] = test_long["sample_id"].str.split("__").str[0]
        else:
            raise ValueError("Cannot find sample_id or image_path column")

    # Use image_path for dedup to match v1 submission ordering
    test_unique = test_long.drop_duplicates(subset=["image_path"]).reset_index(drop=True)
    return test_long, test_unique


def _report_predictions(preds: np.ndarray) -> None:
    """Print the sum-constraint check and per-target summary statistics."""
    # Columns are read as 0=Green, 1=Dead, 2=Clover, 3=GDM, 4=Total.
    component_sum = preds[:, 0] + preds[:, 1] + preds[:, 2]
    diff = np.abs(component_sum - preds[:, 4])
    print("\nConstraint check (G+D+C=T):")
    print(f"  Max diff: {diff.max():.6f}")
    print(f"  Mean diff: {diff.mean():.6f}")

    print("\nPrediction stats:")
    for i, name in enumerate(["Green", "Dead", "Clover", "GDM", "Total"]):
        col = preds[:, i]
        print(f"  {name}: mean={col.mean():.2f}, std={col.std():.2f}, "
              f"min={col.min():.2f}, max={col.max():.2f}")


def _release_accelerator_memory() -> None:
    """Run the garbage collector and empty the CUDA / MPS allocator caches."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # The notebook also runs on Apple MPS; free that cache too when the
    # installed torch exposes it (guarded for older torch versions).
    mps_backend = getattr(torch.backends, "mps", None)
    mps_module = getattr(torch, "mps", None)
    if (
        mps_backend is not None
        and mps_backend.is_available()
        and hasattr(mps_module, "empty_cache")
    ):
        mps_module.empty_cache()


def main():
    """Run DINOv3 Direct model inference.

    Loads saved settings, loads the fold checkpoints, predicts on the test
    images, prints diagnostics and writes the submission file.

    Returns:
        The long-format submission DataFrame returned by ``create_submission``.

    Raises:
        ValueError: If no checkpoints are found in ``CFG.MODEL_DIR`` or the
            test CSV lacks the columns needed to identify samples.
    """
    banner = "=" * 60
    print(banner)
    print("DINOv3 Direct Model Inference")
    print(banner)
    print(f"Device: {CFG.DEVICE}")
    print(f"Model dir: {CFG.MODEL_DIR}")
    print(f"TTA: {CFG.USE_TTA}")

    # Load config from results.json (dinov3_train.py output)
    config = load_config_from_results(CFG.MODEL_DIR)
    if config:
        _apply_saved_config(config)
    _print_model_summary()

    # Find checkpoints (optionally select top K folds)
    checkpoints, fold_indices = find_model_checkpoints(CFG.MODEL_DIR, top_k=CFG.TOP_K_FOLDS)
    if not checkpoints:
        raise ValueError(f"No checkpoints found in {CFG.MODEL_DIR}")
    if CFG.TOP_K_FOLDS:
        print(f"Using top {len(checkpoints)} folds: {fold_indices}")
    else:
        print(f"Found {len(checkpoints)} fold checkpoints")

    print("\nLoading models...")
    models = _load_models(checkpoints)

    print("\nLoading test data...")
    test_long, test_unique = _load_test_frames()
    print(f"Test samples: {len(test_unique)}")

    print("\nRunning inference...")
    preds = run_inference_tta(models, test_unique)
    print(f"Predictions shape: {preds.shape}")

    _report_predictions(preds)

    # create_submission already prints the head of the table, so it is not
    # printed a second time here.
    print("\nCreating submission...")
    sub_df, _ = create_submission(preds, test_long, test_unique)

    # Drop the only reference to the models before emptying the caches.
    del models
    _release_accelerator_memory()

    print("\n" + banner)
    print("Done!")
    print(banner)
    return sub_df


# Run the inference pipeline; `submission` receives the DataFrame that
# main() returns.
submission = main()

Device: mps
DINOv3 Direct Model Inference
Device: mps
Model dir: /Users/kienvu/Desktop/kaggle/biomass/outputs/dinov3_20251218_174413
TTA: True
Loaded config from results.json
Model: DINOv3Direct
Backbone: vit_base_patch16_dinov3
Grid: (2, 2)
Image size: 576
Optional heads: train_dead=False, train_clover=True
Selecting top 3 folds by aR²:
  Fold 2: aR²=0.7024
  Fold 4: aR²=0.6392
  Fold 1: aR²=0.6036
Using top 3 folds: [1, 2, 4]

Loading models...
  Loaded: dinov3_best_fold1.pth
  Loaded: dinov3_best_fold2.pth
  Loaded: dinov3_best_fold4.pth

Loading test data...
Test samples: 357

Running inference...
TTA view 1/3...


                                                             

TTA view 2/3...


                                                             

TTA view 3/3...


                                                             

Predictions shape: (357, 5)

Constraint check (G+D+C=T):
  Max diff: 0.000015
  Mean diff: 0.000002

Prediction stats:
  Green: mean=25.14, std=23.01, min=0.38, max=116.57
  Dead: mean=12.48, std=7.48, min=0.00, max=34.58
  Clover: mean=6.48, std=11.58, min=0.00, max=72.11
  GDM: mean=31.62, std=23.21, min=0.46, max=117.48
  Total: mean=44.10, std=24.85, min=0.46, max=130.69

Creating submission...
Saved: submission.csv
                    sample_id      value
0  ID1011485656__Dry_Clover_g   0.436817
1    ID1011485656__Dry_Dead_g  25.848732
2   ID1011485656__Dry_Green_g  19.265554
3   ID1011485656__Dry_Total_g  45.551102
4         ID1011485656__GDM_g  19.702368
5  ID1012260530__Dry_Clover_g   0.114165
6    ID1012260530__Dry_Dead_g   0.007654
7   ID1012260530__Dry_Green_g   4.185601
8   ID1012260530__Dry_Total_g   4.307419
9         ID1012260530__GDM_g   4.299765
                    sample_id      value
0  ID1011485656__Dry_Clover_g   0.436817
1    ID1011485656__Dry_Dead_g  25.848732
2 