## CSIRO Image2Biomass – DINOv3 Direct Model Inference (with Depth)

This notebook runs inference using DINOv3 Direct models trained with `dinov3_train.py`.

**Model:**
- `DINOv3Direct`: Predicts Total, Green, and GDM directly; derives Dead and Clover from them (optionally predicted by dedicated heads)
- Components always sum to Total (mathematically guaranteed)

**Depth Features (Depth Anything V2):**
- `use_depth`: Extracts depth-map statistics (the depth_gradient feature has a correlation of r = 0.63 with green biomass)
- `use_depth_attention`: Uses depth maps to guide tile attention pooling

**Other Features:**
- Multi-fold ensemble of the 5-fold CV models (optionally only the top-K folds by aR²)
- Test-Time Augmentation (TTA)
- Vegetation Indices, Stereo Disparity (optional)
- MPS/CUDA/CPU support

In [None]:
import os
import sys

# ==================== ENVIRONMENT SETUP (must run before HF imports) ====================
# 1) Kaggle ships TensorFlow, whose protobuf build conflicts with the one
#    transformers needs; forcing the pure-Python protobuf backend avoids it.
# 2) Kaggle submissions have no internet, so the HuggingFace libraries are put
#    into offline mode before they are imported and can attempt network access.
os.environ.update(
    {
        "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
        "HF_HUB_OFFLINE": "1",
        "TRANSFORMERS_OFFLINE": "1",
        "HF_DATASETS_OFFLINE": "1",
    }
)

import gc
import json
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Tuple, Optional, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm

# Put the parent of the working directory first on the import path so the
# project's `src` modules can be imported from this notebook.
sys.path.insert(0, os.path.dirname(os.path.abspath(os.curdir)))


class CFG:
    """Configuration for DINOv3 Direct model inference.

    Every setting is a class attribute read straight off the class (e.g.
    ``CFG.IMG_SIZE``). The architecture flags are defaults; per the note on them
    they are expected to be auto-loaded from the training run's ``results.json``.
    """

    # ==================== DATA / MODEL PATHS ====================
    # LOCAL_TEST=True scores the local train split (out-of-fold style check);
    # LOCAL_TEST=False targets the Kaggle competition test set.
    LOCAL_TEST = True

    BASE_PATH = "../data" if LOCAL_TEST else "/kaggle/input/csiro-biomass"
    TEST_CSV = os.path.join(BASE_PATH, "train.csv" if LOCAL_TEST else "test.csv")
    TEST_IMAGE_DIR = os.path.join(BASE_PATH, "train" if LOCAL_TEST else "test")
    # Trained DINOv3 checkpoints written by dinov3_train.py
    MODEL_DIR = (
        "/workspace/biomass-kaggle/outputs/dinov3_full_mse"  # local run with presence/ndvi/height heads
        if LOCAL_TEST
        else "/kaggle/input/your-dinov3-model"  # placeholder: point at the uploaded Kaggle dataset
    )

    # Backbone identifier (timm model name); fixed for this notebook
    BACKBONE = "vit_base_patch16_dinov3"

    # ==================== INFERENCE SETTINGS ====================
    SUBMISSION_FILE = "submission.csv"
    # Placeholder device; replaced at module level with the best available device.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    BATCH_SIZE = 1
    NUM_WORKERS = 0

    # ==================== ARCHITECTURE (auto-loaded from results.json) ====================
    DROPOUT = 0.3
    HIDDEN_RATIO = 0.25
    GRID = (2, 2)
    USE_FILM = True
    USE_ATTENTION_POOL = True
    TRAIN_DEAD = False              # model has head_dead
    TRAIN_CLOVER = False            # model has head_clover
    USE_VEGETATION_INDICES = False  # model consumes vegetation-index features
    USE_DISPARITY = False           # model consumes stereo-disparity features
    USE_DEPTH = False               # model consumes Depth Anything V2 features
    DEPTH_MODEL_SIZE = "small"      # depth model size: "small" or "base"
    USE_DEPTH_ATTENTION = False     # model uses depth-guided tile attention

    # Auxiliary heads (presence, NDVI, height, species)
    USE_PRESENCE_HEADS = False      # presence heads for Dead/Clover
    USE_NDVI_HEAD = False           # NDVI regression head
    USE_HEIGHT_HEAD = False         # height regression head
    USE_SPECIES_HEAD = False        # species classification head

    # Offline depth weights for Kaggle: a folder saved via `model.save_pretrained(...)`
    # and uploaded as a Kaggle dataset, e.g. "/kaggle/input/depth-anything-v2/small".
    # None -> load by HuggingFace model name.
    DEPTH_MODEL_PATH: Optional[str] = None
    USE_LEARNABLE_AUG = False       # model includes learnable augmentation
    LEARNABLE_AUG_COLOR = True      # learnable color augmentation
    LEARNABLE_AUG_SPATIAL = False   # learnable spatial augmentation
    IMG_SIZE = 672                  # DINOv3 default from dinov3_train.py

    # ==================== TEST-TIME AUGMENTATION ====================
    USE_TTA = True
    # "none"    -> base transform only (1 view)
    # "light"   -> base + hflip (2 views)
    # "default" -> base + hflip + brightness (3 views)
    # "heavy"   -> default + vflip + darker (5 views)
    # "extreme" -> heavy + rotate90 + hue_shift + gamma (8 views)
    TTA_LEVEL = "default"

    # Fold selection: None = ensemble all 5 folds; int K = keep the top-K folds by aR²
    TOP_K_FOLDS: Optional[int] = None

    ALL_TARGET_COLS = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]


# Native input resolution of the DINOv3 backbone. Larger inputs also work as long
# as both sides are a multiple of the 16 px patch size.
DINOV3_NATIVE_RES: int = 256


def get_device() -> torch.device:
    """Return the preferred torch device, trying CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        name = "cuda"
    else:
        # Older torch builds have no `torch.backends.mps` attribute at all.
        mps_backend = getattr(torch.backends, "mps", None)
        name = "mps" if mps_backend is not None and mps_backend.is_available() else "cpu"
    return torch.device(name)


# Update device: replace the cuda/cpu placeholder set in the CFG class body with
# the best available device (CUDA -> MPS -> CPU, per get_device) and report it.
CFG.DEVICE = get_device()

print(f"Device: {CFG.DEVICE}")


def load_config_from_results(model_dir: str) -> Dict:
    """Load the training config saved next to the model checkpoints.

    Args:
        model_dir: Directory expected to contain ``results.json``.

    Returns:
        The ``"config"`` mapping from ``results.json``. Returns an empty dict when
        the file does not exist, when its top level is not a JSON object, or when
        ``"config"`` is missing, ``null`` or not an object. This way callers can
        always use dict methods on the result.

    Raises:
        json.JSONDecodeError: if ``results.json`` exists but is not valid JSON.
    """
    results_path = os.path.join(model_dir, "results.json")
    try:
        with open(results_path, encoding="utf-8") as f:
            results = json.load(f)
    except FileNotFoundError:
        return {}

    if not isinstance(results, dict):
        return {}
    config = results.get("config")
    # Guard against `"config": null` so the declared `Dict` return type holds.
    return config if isinstance(config, dict) else {}


# ==================== DATASET ====================

class TestBiomassDataset(Dataset):
    """Inference dataset yielding the left and right halves of each image.

    Each row of ``df`` must have a ``sample_id_prefix`` column. The image is read
    from ``<image_dir>/<sample_id_prefix>.jpg`` and converted BGR -> RGB. It is then
    split at the vertical midline into a left and a right view. Any odd middle
    column goes to the right half.
    """

    def __init__(self, df: pd.DataFrame, image_dir: str, transform: Optional[A.Compose]):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        # Applied independently to each half; None returns the raw RGB arrays.
        self.transform = transform

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        """Return ``(left, right, sample_id)`` for row ``idx``.

        With a transform, ``left``/``right`` are the transform's ``"image"`` outputs
        (tensors for the ToTensorV2 pipelines in this file). Without one they are
        HxWx3 RGB numpy arrays as returned by OpenCV.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        sample_id = self.df.iloc[idx]["sample_id_prefix"]

        img_path = os.path.join(self.image_dir, f"{sample_id}.jpg")
        img = cv2.imread(img_path)
        if img is None:  # cv2.imread signals a missing/unreadable file with None
            raise FileNotFoundError(f"Image not found: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mid = img.shape[1] // 2
        img_left = img[:, :mid, :]
        img_right = img[:, mid:, :]

        # Explicit None check: albumentations Compose defines __len__, so an empty
        # Compose would be falsy and silently skipped under a truthiness test.
        if self.transform is not None:
            img_left = self.transform(image=img_left)["image"]
            img_right = self.transform(image=img_right)["image"]

        return img_left, img_right, sample_id


def get_val_transform(img_size: Optional[int] = None) -> A.Compose:
    """Build the deterministic evaluation pipeline (must mirror dinov3_train.py).

    Args:
        img_size: Square output side length; ``None`` means ``CFG.IMG_SIZE``.

    Returns:
        A Compose that area-resizes to ``img_size`` x ``img_size``, applies ImageNet
        mean/std normalization and converts the result to a tensor.
    """
    size = CFG.IMG_SIZE if img_size is None else img_size
    steps = [
        A.Resize(size, size, interpolation=cv2.INTER_AREA),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)


def get_tta_transforms(img_size: Optional[int] = None, level: str = "default") -> List[A.Compose]:
    """
    TTA transforms with configurable levels.

    Every view area-resizes to ``img_size`` x ``img_size`` and applies at most one
    augmentation. It then ImageNet-normalizes and converts to a tensor. Views are
    cumulative across levels, in this order:

    - none:    base                                         (1 view)
    - light:   + horizontal flip                            (2 views)
    - default: + brighter / higher contrast                 (3 views)
    - heavy:   + vertical flip, darker                      (5 views)
    - extreme: + 90° rotation, hue/sat/value shift, gamma   (8 views)

    The photometric views sample their parameters from small ranges on every call.
    They are therefore only reproducible if the RNG is seeded.

    Args:
        img_size: Image size (default: CFG.IMG_SIZE)
        level: TTA level - "none", "light", "default", "heavy", "extreme".
            Any other value falls back to "default" with a warning.

    Returns:
        List of albumentations Compose transforms, base view first.
    """
    if img_size is None:
        img_size = CFG.IMG_SIZE

    known_levels = ("none", "light", "default", "heavy", "extreme")
    if level not in known_levels:
        # Previously a typo (e.g. "Heavy") silently produced the default views.
        import warnings

        warnings.warn(
            f"Unknown TTA level {level!r}; expected one of {known_levels}. Falling back to 'default'.",
            stacklevel=2,
        )
        level = "default"

    def make_view(*augs) -> A.Compose:
        """Wrap ``augs`` between the shared resize and normalize/to-tensor steps."""
        return A.Compose([
            A.Resize(img_size, img_size, interpolation=cv2.INTER_AREA),
            *augs,
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

    # Base transform (always included)
    views = [make_view()]
    if level == "none":
        return views

    # Horizontal flip
    views.append(make_view(A.HorizontalFlip(p=1.0)))
    if level == "light":
        return views

    # Brightness/Contrast adjustment (slightly brighter)
    views.append(make_view(
        A.RandomBrightnessContrast(brightness_limit=(0.08, 0.12), contrast_limit=(0.08, 0.12), p=1.0),
    ))
    if level == "default":
        return views

    views += [
        # Vertical flip (vegetation is somewhat invariant to this)
        make_view(A.VerticalFlip(p=1.0)),
        # Darker version (simulates different lighting)
        make_view(
            A.RandomBrightnessContrast(brightness_limit=(-0.12, -0.08), contrast_limit=(-0.05, 0.05), p=1.0),
        ),
    ]
    if level == "heavy":
        return views

    # level == "extreme"
    views += [
        # 90° rotation
        make_view(A.Rotate(limit=(90, 90), p=1.0, border_mode=cv2.BORDER_REFLECT_101)),
        # Hue shift (simulates different camera settings)
        make_view(A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=10, p=1.0)),
        # Gamma adjustment
        make_view(A.RandomGamma(gamma_limit=(90, 110), p=1.0)),
    ]
    return views


# ==================== MODEL COMPONENTS ====================

def build_dinov3_backbone(
    pretrained: bool = False,
    name: str = "vit_base_patch16_dinov3",
) -> Tuple[nn.Module, int, int]:
    """Build a headless DINOv3 backbone through timm.

    Args:
        pretrained: Passed to ``timm.create_model``; whether timm loads its own
            pretrained weights for ``name``.
        name: timm model name. Defaults to ``vit_base_patch16_dinov3``, the
            backbone this notebook was trained with; exposed so other variants can
            be plugged in without editing this function.

    Returns:
        ``(model, feat_dim, input_res)``:
        - ``model``: the timm model with ``num_classes=0`` (no classifier head).
        - ``feat_dim``: ``model.num_features`` (768 for ViT-B).
        - ``input_res``: ``DINOV3_NATIVE_RES``. NOTE(review): assumed to hold for
          any DINOv3 variant passed via ``name``; confirm for non-default models.
    """
    model = timm.create_model(name, pretrained=pretrained, num_classes=0)
    feat_dim = model.num_features
    # Reuse the module-level constant instead of repeating the literal 256.
    input_res = DINOV3_NATIVE_RES
    return model, feat_dim, input_res


def _make_edges(length: int, n_parts: int) -> List[Tuple[int, int]]:
    step = length // n_parts
    edges = [(i * step, (i + 1) * step) for i in range(n_parts)]
    edges[-1] = (edges[-1][0], length)
    return edges


class FiLM(nn.Module):
    """Feature-wise linear modulation: predict per-channel (gamma, beta) from a context.

    A two-layer GELU MLP maps a context of width ``in_dim`` to ``2 * in_dim``
    values. These are split along dim 1 into a scale half (``gamma``) and a shift
    half (``beta``).
    """

    def __init__(self, in_dim: int) -> None:
        super().__init__()
        # Hidden width is half the input, but never narrower than 64 units.
        hidden = max(in_dim // 2, 64)
        layers = [nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, 2 * in_dim)]
        self.mlp = nn.Sequential(*layers)

    def forward(self, context: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(gamma, beta)``; for a 2-D ``(B, in_dim)`` context each is ``(B, in_dim)``."""
        gamma, beta = self.mlp(context).chunk(2, dim=1)
        return gamma, beta


class AttentionPooling(nn.Module):
    """Single-query attention pooling over a set of tokens.

    The query comes from the mean token and the keys from every token. The output
    is the softmax-weighted sum of the raw (unprojected) tokens.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool tokens ``x`` of shape (B, N, D) into a (B, D) summary."""
        summary = x.mean(dim=1, keepdim=True)                          # (B, 1, D)
        scores = self.query(summary) @ self.key(x).transpose(-2, -1)   # (B, 1, N)
        weights = (scores * self.scale).softmax(dim=-1)
        pooled = weights @ x                                           # (B, 1, D)
        return pooled.squeeze(1)


# ==================== INNOVATIVE FEATURES ====================

class VegetationIndices(nn.Module):
    """Summarize six RGB vegetation indices into a learned feature vector.

    The indices are computed on the de-normalized image clamped to [0, 1]:
    ExG, ExR, ExG - ExR, GRVI, normalized green g / (r + g + b), and VARI.
    For each index four statistics are taken: spatial mean, std, and the 10th and
    90th percentiles (6 x 4 = 24 values). These are projected to ``out_dim`` with
    Linear + GELU.

    NOTE: the formulas and statistic order presumably have to match training,
    because the projection weights come from the trained checkpoint.
    """

    def __init__(self, out_dim: int = 24) -> None:
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(24, out_dim), nn.GELU())

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Map an ImageNet-normalized (B, 3, H, W) batch to (B, out_dim) features."""
        dev = img.device
        mean = torch.tensor([0.485, 0.456, 0.406], device=dev).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=dev).view(1, 3, 1, 1)
        rgb = (img * std + mean).clamp(0, 1)
        red, green, blue = rgb.unbind(dim=1)

        eps = 1e-6
        exg = 2 * green - red - blue
        exr = 1.4 * red - green
        index_maps = (
            exg,                                          # ExG
            exr,                                          # ExR
            exg - exr,                                    # ExGR
            (green - red) / (green + red + eps),          # GRVI
            green / (red + green + blue + eps),           # normalized green
            (green - red) / (green + red - blue + eps),   # VARI
        )

        stats = []
        for index_map in index_maps:
            flat = index_map.flatten(1)
            stats += [
                index_map.mean(dim=(-2, -1)),
                index_map.std(dim=(-2, -1)),
                flat.quantile(0.1, dim=1),
                flat.quantile(0.9, dim=1),
            ]
        return self.proj(torch.stack(stats, dim=1))  # (B, 24) -> (B, out_dim)


class DisparityFeatures(nn.Module):
    """Stereo-matching features between left and right tile embeddings.

    Two parts are concatenated:
    * a correlation profile: for each shift d = 0 .. max_disparity-1, the mean
      cosine similarity between the left tiles and the right tiles cyclically
      rolled by d along the tile axis. It is projected to ``out_dim``.
    * six global left-vs-right statistics of the mean-pooled features. These are
      cosine similarity, norm/mean/std of the difference, and mean/std of the
      ratio. They are projected to ``out_dim // 2``.

    ``self.out_dim`` stores the total output width, ``out_dim + out_dim // 2``.
    """

    def __init__(self, feat_dim: int, max_disparity: int = 8, out_dim: Optional[int] = None) -> None:
        super().__init__()
        self.max_disparity = max_disparity
        corr_dim = feat_dim // 4 if out_dim is None else out_dim
        self.proj = nn.Sequential(nn.Linear(max_disparity, corr_dim), nn.GELU())
        self.diff_proj = nn.Sequential(nn.Linear(6, corr_dim // 2), nn.GELU())
        self.out_dim = corr_dim + corr_dim // 2

    def forward(self, feat_left: torch.Tensor, feat_right: torch.Tensor) -> torch.Tensor:
        """Map (B, N, D) left/right tile features to (B, self.out_dim) features."""
        unit_l = F.normalize(feat_left, dim=-1)
        unit_r = F.normalize(feat_right, dim=-1)
        # torch.roll wraps around, so the shift is cyclic over the N tiles.
        corr_volume = torch.stack(
            [
                (unit_l * torch.roll(unit_r, shifts=d, dims=1)).sum(dim=-1).mean(dim=1)
                for d in range(self.max_disparity)
            ],
            dim=-1,
        )  # (B, max_disparity)

        pooled_l = feat_left.mean(dim=1)
        pooled_r = feat_right.mean(dim=1)
        cosine = (F.normalize(pooled_l, dim=-1) * F.normalize(pooled_r, dim=-1)).sum(dim=-1, keepdim=True)
        diff = pooled_l - pooled_r
        ratio = pooled_l / (pooled_r + 1e-6)
        global_stats = torch.cat(
            [
                cosine,
                diff.norm(dim=-1, keepdim=True),
                diff.mean(dim=-1, keepdim=True),
                diff.std(dim=-1, keepdim=True),
                ratio.mean(dim=-1, keepdim=True),
                ratio.std(dim=-1, keepdim=True),
            ],
            dim=-1,
        )  # (B, 6)

        return torch.cat([self.proj(corr_volume), self.diff_proj(global_stats)], dim=-1)


class DepthFeatures(nn.Module):
    """
    Extract depth-based features using Depth Anything V2.
    
    Uses frozen DA2 model to generate depth maps, then extracts statistics.
    Key features: depth_gradient (r=0.63 with green), depth_mean, depth_range, depth_volume.
    
    For Kaggle submission: depth model is loaded upfront and weights come from checkpoint.

    Pipeline (see ``forward``):
    1. Predict a depth map for each of the two input views.
    2. Compute 10 summary statistics per map.
    3. Append 2 left/right comparison statistics.
    4. Project the resulting 22-vector to ``out_dim`` with Linear-GELU-Linear.

    Args:
        out_dim: Width of the returned feature vector.
        model_size: "small" or "base"; any other value falls back to "small" when
            loading by HuggingFace name.
        depth_model_path: Local directory of a saved DA2 model. It is used only
            if it exists; otherwise the model is loaded by HuggingFace name.
            NOTE(review): with the HF offline env vars set at the top of this
            notebook, that fallback presumably needs the model in the local HF
            cache. Confirm on Kaggle.
    """
    
    def __init__(self, out_dim: int = 32, model_size: str = "small", depth_model_path: Optional[str] = None) -> None:
        super().__init__()
        self.out_dim = out_dim
        self.model_size = model_size
        
        # Load depth model upfront (weights will come from checkpoint)
        # Imported lazily so transformers is only required when depth features are used.
        from transformers import AutoModelForDepthEstimation
        
        # Use local path if provided (for Kaggle offline), else HuggingFace
        # NOTE: a non-existent depth_model_path silently falls through to the HF branch.
        if depth_model_path and os.path.exists(depth_model_path):
            print(f"    Loading depth model from local: {depth_model_path}")
            self._depth_model = AutoModelForDepthEstimation.from_pretrained(depth_model_path, local_files_only=True)
        else:
            model_names = {
                "small": "depth-anything/Depth-Anything-V2-Small-hf",
                "base": "depth-anything/Depth-Anything-V2-Base-hf",
            }
            # Unknown sizes fall back to the small model.
            model_name = model_names.get(model_size, model_names["small"])
            print(f"    Loading depth model from HuggingFace: {model_name}")
            self._depth_model = AutoModelForDepthEstimation.from_pretrained(model_name)
        
        self._depth_model.eval()
        
        # Freeze depth model
        for p in self._depth_model.parameters():
            p.requires_grad = False
        
        # Project depth statistics to feature space
        # 10 stats per view × 2 views + 2 stereo stats = 22 features
        self.proj = nn.Sequential(
            nn.Linear(22, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )
    
    @torch.no_grad()
    def _get_depth_map(self, img: torch.Tensor) -> torch.Tensor:
        """Get depth map from image tensor.

        Args:
            img: ImageNet-normalized batch, assumed (B, 3, H, W) per the unpacking below.

        Returns:
            (B, H, W) DA2 predicted depth, resized back to the input's H x W.
        """
        device = img.device
        
        # Move depth model to same device if needed
        # (checks the first parameter's device, so this happens at most once per device change)
        if next(self._depth_model.parameters()).device != device:
            self._depth_model = self._depth_model.to(device)
        
        # Denormalize from ImageNet to [0, 1]
        mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
        img_denorm = (img * std + mean).clamp(0, 1)
        
        # DA2 expects specific preprocessing - resize to 518
        B, _, H, W = img.shape
        img_resized = F.interpolate(img_denorm, size=(518, 518), mode="bilinear", align_corners=False)
        
        # Apply DA2 normalization
        # NOTE(review): these constants equal the ImageNet ones used above, so the
        # denormalize -> renormalize round trip only adds the [0, 1] clamp. Left
        # as-is to stay identical to training.
        da2_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
        da2_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
        img_normalized = (img_resized - da2_mean) / da2_std
        
        # Get depth
        outputs = self._depth_model(pixel_values=img_normalized)
        depth = outputs.predicted_depth  # (B, h, w)
        
        # Resize back to original
        depth = F.interpolate(
            depth.unsqueeze(1), size=(H, W), mode="bilinear", align_corners=False
        ).squeeze(1)
        
        return depth
    
    def _compute_stats(self, depth: torch.Tensor) -> torch.Tensor:
        """Compute statistics from depth map.

        Args:
            depth: (B, H, W) depth maps.

        Returns:
            (B, 10) tensor, columns in order: mean, std, min, max, range, p10, p90,
            mean absolute gradient (y + x), mean height above the minimum, and
            fraction of pixels above the 75th percentile.
        """
        B, H, W = depth.shape
        flat = depth.view(B, -1)
        
        # Basic statistics
        depth_mean = flat.mean(dim=1)
        depth_std = flat.std(dim=1)
        depth_min = flat.min(dim=1).values
        depth_max = flat.max(dim=1).values
        depth_range = depth_max - depth_min
        
        # Percentiles
        depth_p10 = flat.quantile(0.1, dim=1)
        depth_p90 = flat.quantile(0.9, dim=1)
        
        # Gradient (vegetation boundaries) - KEY FEATURE (r=0.63)
        grad_y = torch.abs(depth[:, 1:, :] - depth[:, :-1, :]).mean(dim=(1, 2))
        grad_x = torch.abs(depth[:, :, 1:] - depth[:, :, :-1]).mean(dim=(1, 2))
        depth_gradient = grad_y + grad_x
        
        # Volume proxy (sum above minimum)
        depth_volume = (flat - depth_min.unsqueeze(1)).mean(dim=1)
        
        # High depth ratio
        threshold = flat.quantile(0.75, dim=1, keepdim=True)
        depth_high_ratio = (flat > threshold).float().mean(dim=1)
        
        return torch.stack([
            depth_mean, depth_std, depth_min, depth_max, depth_range,
            depth_p10, depth_p90, depth_gradient, depth_volume, depth_high_ratio
        ], dim=1)  # (B, 10)
    
    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> torch.Tensor:
        """Extract depth features from stereo images.

        Args:
            x_left: ImageNet-normalized left view; assumed (B, 3, H, W).
            x_right: ImageNet-normalized right view. It must have the same H x W
                as ``x_left``, because the two depth maps are compared element-wise.

        Returns:
            (B, out_dim) projected depth features.
        """
        # Get depth maps
        depth_left = self._get_depth_map(x_left)
        depth_right = self._get_depth_map(x_right)
        
        # Compute per-view statistics
        stats_left = self._compute_stats(depth_left)    # (B, 10)
        stats_right = self._compute_stats(depth_right)  # (B, 10)
        
        # Stereo statistics (L-R difference as disparity proxy)
        depth_lr_diff = torch.abs(depth_left - depth_right).mean(dim=(1, 2)).unsqueeze(1)  # (B, 1)
        depth_lr_corr = F.cosine_similarity(
            depth_left.flatten(1), depth_right.flatten(1), dim=1
        ).unsqueeze(1)  # (B, 1)
        
        # Combine all features
        all_stats = torch.cat([
            stats_left, stats_right, depth_lr_diff, depth_lr_corr
        ], dim=1)  # (B, 22)
        
        return self.proj(all_stats)  # (B, out_dim)


class DepthGuidedAttention(nn.Module):
    """
    Depth-guided attention for tile pooling.
    Uses depth maps to weight which spatial regions contribute more to predictions.
    
    For Kaggle submission: depth model is loaded upfront and weights come from checkpoint.

    Two sets of attention weights over the ``grid[0] * grid[1]`` tiles are mixed
    with a learned sigmoid gate:
    * depth-based weights, from an MLP over 5 per-tile depth statistics;
    * feature-based weights, from a mean-token query against per-tile keys.
    The mixed weights are then used to average the tile features.

    Args:
        feat_dim: Width D of each tile feature.
        grid: (rows, cols) tile grid; must match how the caller's tile features are laid out.
        model_size: DA2 size used to build the HuggingFace name
            "depth-anything/Depth-Anything-V2-<Size>-hf" (capitalized).
        depth_model_path: Local directory of a saved DA2 model. It is used only
            if it exists; otherwise the model is loaded by HuggingFace name.
    """
    
    def __init__(
        self, 
        feat_dim: int, 
        grid: Tuple[int, int] = (2, 2),
        model_size: str = "small",
        depth_model_path: Optional[str] = None
    ) -> None:
        super().__init__()
        self.feat_dim = feat_dim
        self.grid = grid
        self.num_tiles = grid[0] * grid[1]
        
        # Load depth model upfront (weights will come from checkpoint)
        from transformers import AutoModelForDepthEstimation
        
        # Use local path if provided (for Kaggle offline), else HuggingFace
        if depth_model_path and os.path.exists(depth_model_path):
            print(f"    Loading depth model from local: {depth_model_path}")
            self._depth_model = AutoModelForDepthEstimation.from_pretrained(depth_model_path, local_files_only=True)
        else:
            # NOTE: unlike DepthFeatures, unknown sizes are not mapped to "small";
            # the name is built directly from model_size.
            model_name = f"depth-anything/Depth-Anything-V2-{model_size.capitalize()}-hf"
            print(f"    Loading depth model from HuggingFace: {model_name}")
            self._depth_model = AutoModelForDepthEstimation.from_pretrained(model_name)
        
        self._depth_model.eval()
        
        # Freeze depth model
        for p in self._depth_model.parameters():
            p.requires_grad = False
        
        # Depth stats per tile (5 stats: mean, max, gradient, volume, high_ratio)
        self.depth_stats_dim = 5
        
        # Depth → attention weight (per tile)
        self.depth_to_attn = nn.Sequential(
            nn.Linear(self.depth_stats_dim, 32),
            nn.GELU(),
            nn.Linear(32, 1),
        )
        
        # Feature-based attention (like original)
        self.query = nn.Linear(feat_dim, feat_dim)
        self.key = nn.Linear(feat_dim, feat_dim)
        self.scale = feat_dim ** -0.5
        
        # Combine depth attention and feature attention
        # Raw logit; forward applies sigmoid, so the initial mix is sigmoid(0.5) ≈ 0.62 depth.
        self.gate = nn.Parameter(torch.tensor(0.5))
    
    @torch.no_grad()
    def _get_depth_map(self, img: torch.Tensor) -> torch.Tensor:
        """Get depth map from normalized image.

        Args:
            img: ImageNet-normalized (B, C, H, W) batch.

        Returns:
            (B, H, W) DA2 predicted depth, resized to the input's H x W if needed.
        """
        device = img.device
        B, C, H, W = img.shape
        
        # Move depth model to same device if needed
        if next(self._depth_model.parameters()).device != device:
            self._depth_model = self._depth_model.to(device)
        
        # De-normalize from ImageNet
        mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
        img_denorm = (img * std + mean).clamp(0, 1)
        
        # DA2 expects specific preprocessing - resize to 518
        img_resized = F.interpolate(img_denorm, size=(518, 518), mode="bilinear", align_corners=False)
        
        # Apply DA2 normalization
        # NOTE(review): same constants as the ImageNet de-normalization above;
        # kept verbatim to stay identical to training.
        da2_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
        da2_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
        img_normalized = (img_resized - da2_mean) / da2_std
        
        # Get depth
        outputs = self._depth_model(pixel_values=img_normalized)
        depth = outputs.predicted_depth
        
        # Resize to match input
        if depth.shape[-2:] != (H, W):
            depth = F.interpolate(
                depth.unsqueeze(1), size=(H, W), mode="bilinear", align_corners=False
            ).squeeze(1)
        
        return depth
    
    def _compute_tile_stats(self, depth: torch.Tensor) -> torch.Tensor:
        """Compute per-tile depth statistics.

        Tiles are visited row-major (row i, then column j). Tile sizes use floor
        division, so any remainder rows/columns at the bottom/right edge are ignored.
        (This differs from _make_edges, which extends the last tile instead.)

        Args:
            depth: (B, H, W) depth maps.

        Returns:
            (B, num_tiles, 5): mean, max, mean absolute gradient, mean height above
            the tile minimum, and fraction above the tile's 75th percentile.
        """
        B, H, W = depth.shape
        r, c = self.grid
        
        tile_h = H // r
        tile_w = W // c
        
        stats_list = []
        for i in range(r):
            for j in range(c):
                tile = depth[:, i*tile_h:(i+1)*tile_h, j*tile_w:(j+1)*tile_w]
                flat = tile.reshape(B, -1)
                
                tile_mean = flat.mean(dim=1)
                tile_max = flat.max(dim=1).values
                
                # Gradient
                grad_y = torch.abs(tile[:, 1:, :] - tile[:, :-1, :]).mean(dim=(1, 2))
                grad_x = torch.abs(tile[:, :, 1:] - tile[:, :, :-1]).mean(dim=(1, 2))
                tile_gradient = grad_y + grad_x
                
                # Volume
                tile_min = flat.min(dim=1).values
                tile_volume = (flat - tile_min.unsqueeze(1)).mean(dim=1)
                
                # High ratio
                threshold = flat.quantile(0.75, dim=1, keepdim=True)
                tile_high_ratio = (flat > threshold).float().mean(dim=1)
                
                stats = torch.stack([tile_mean, tile_max, tile_gradient, tile_volume, tile_high_ratio], dim=1)
                stats_list.append(stats)
        
        return torch.stack(stats_list, dim=1)  # (B, num_tiles, 5)
    
    def forward(self, x: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, num_tiles, D) tile features
            img: (B, 3, H, W) original image for depth
        Returns:
            pooled: (B, D)

        NOTE(review): assumes the tiles in ``x`` are in the same row-major order as
        _compute_tile_stats; confirm against the caller.
        """
        B, N, D = x.shape
        
        # Get depth-based attention weights
        depth_map = self._get_depth_map(img)  # (B, H, W)
        tile_stats = self._compute_tile_stats(depth_map)  # (B, num_tiles, 5)
        depth_attn = self.depth_to_attn(tile_stats).squeeze(-1)  # (B, num_tiles)
        depth_attn = F.softmax(depth_attn, dim=-1)
        
        # Get feature-based attention weights
        q = self.query(x.mean(dim=1, keepdim=True))  # (B, 1, D)
        k = self.key(x)  # (B, N, D)
        feat_attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, 1, N)
        feat_attn = F.softmax(feat_attn, dim=-1).squeeze(1)  # (B, N)
        
        # Combine attention weights
        gate = torch.sigmoid(self.gate)
        combined_attn = gate * depth_attn + (1 - gate) * feat_attn
        # Both inputs are softmaxes, so this renormalization is a numerical safeguard only.
        combined_attn = combined_attn / combined_attn.sum(dim=-1, keepdim=True)
        
        # Apply attention
        pooled = (combined_attn.unsqueeze(-1) * x).sum(dim=1)  # (B, D)
        
        return pooled


class LearnableAugmentation(nn.Module):
    """Learnable augmentation module (identity at inference time).

    The training-time parameters and predictor heads are recreated with the same
    names and shapes. They cover color, spatial, blur and local-contrast
    (CLAHE-style) augmentation. Presumably this lets checkpoint state_dicts load.
    ``forward`` ignores all of them: it returns the input unchanged plus a scalar
    zero tensor. That second output is presumably the training-time auxiliary
    term; confirm against dinov3_train.py.

    ``color_strength``, ``spatial_strength`` and ``noise_std`` are accepted for
    signature compatibility but unused here.
    """

    def __init__(
        self,
        enable_color: bool = True,
        enable_spatial: bool = False,
        enable_blur: bool = True,
        enable_local_contrast: bool = True,
        color_strength: float = 0.25,
        spatial_strength: float = 0.15,
        noise_std: float = 0.1,
    ) -> None:
        super().__init__()
        self.enable_color = enable_color
        self.enable_spatial = enable_spatial
        self.enable_blur = enable_blur
        self.enable_local_contrast = enable_local_contrast

        # Creation order below is kept fixed so parameter initialization draws from
        # the RNG in the same sequence as training.
        if enable_color:
            self.color_params = nn.Parameter(torch.zeros(6))
            self.color_predictor = self._stat_head(32, 6, nn.Tanh())

        if enable_spatial:
            self.spatial_params = nn.Parameter(torch.zeros(5))
            self.spatial_predictor = self._stat_head(32, 5, nn.Tanh())

        if enable_blur:
            self.blur_params = nn.Parameter(torch.zeros(2))
            self.blur_predictor = self._stat_head(16, 2, nn.Sigmoid())
            # 3x3 binomial (Gaussian-like) kernel: outer([1, 2, 1], [1, 2, 1]) / 16
            binomial = torch.tensor([1.0, 2.0, 1.0])
            self.blur_kernel = nn.Parameter(torch.outer(binomial, binomial) / 16.0)

        if enable_local_contrast:
            self.contrast_params = nn.Parameter(torch.zeros(2))
            self.contrast_predictor = self._stat_head(16, 2, nn.Sigmoid())

    @staticmethod
    def _stat_head(hidden: int, out_dim: int, activation: nn.Module) -> nn.Sequential:
        """Global-average-pool the 3 input channels, then MLP -> bounded activation."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(3, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_dim),
            activation,
        )

    def forward(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``img`` untouched and a scalar zero tensor on the same device."""
        zero = torch.tensor(0.0, device=img.device)
        return img, zero


# ==================== DINOv3 DIRECT MODEL ====================

class DINOv3Direct(nn.Module):
    """
    DINOv3 Direct Model for Biomass Prediction.
    
    Each view (``x_left`` / ``x_right``) is split into ``grid`` tiles. The tiles
    of both views are encoded by the DINOv3 backbone in one batched call,
    optionally cross-conditioned with FiLM, and pooled per view. The pooled
    features are concatenated with any enabled extra features and projected to
    a shared hidden vector that feeds every prediction head.
    
    Total, Green and GDM are predicted directly (softplus) and then clamped so
    that Total >= GDM >= Green. Dead and Clover are handled in one of two ways:
    - Derived (no dedicated head): Dead = Total - GDM, Clover = GDM - Green.
      In this case Green + Dead + Clover equals Total, up to floating-point
      rounding.
    - Predicted (``train_dead`` / ``train_clover``): produced by their own
      heads, only upper-bounded by those differences (+1e-6), and optionally
      scaled by presence probabilities. The components are then NOT
      guaranteed to sum to Total.
    
    Optional features:
    - Vegetation Indices (VI): ExG, ExR, GRVI etc.
    - Stereo Disparity: 3D volume features from stereo correspondence
    - Depth Features: Depth Anything V2 depth maps (r=0.63 correlation with green!)
    - Depth-guided Attention: Uses depth to weight tile attention
    
    NOTE: submodule attribute names (``head_total``, ``film_left``, ...) are the
    state-dict keys that trained checkpoints are loaded against. Keep them in
    sync with the training code.
    """
    
    # Auxiliary head class counts (must match dataset.py).
    # Only NUM_SPECIES is used inside this class (species head output size).
    NUM_STATES = 4
    NUM_MONTHS = 10
    NUM_SPECIES = 8
    
    def __init__(
        self,
        grid: Tuple[int, int] = (2, 2),
        pretrained: bool = False,
        dropout: float = 0.3,
        hidden_ratio: float = 0.25,
        use_film: bool = True,
        use_attention_pool: bool = True,
        train_dead: bool = False,
        train_clover: bool = False,
        use_vegetation_indices: bool = False,
        use_disparity: bool = False,
        use_depth: bool = False,
        depth_model_size: str = "small",
        use_depth_attention: bool = False,
        use_learnable_aug: bool = False,
        learnable_aug_color: bool = True,
        learnable_aug_spatial: bool = False,
        depth_model_path: Optional[str] = None,  # For Kaggle offline: local path to depth model
        use_presence_heads: bool = False,  # Binary presence for Dead/Clover
        use_ndvi_head: bool = False,  # NDVI auxiliary head
        use_height_head: bool = False,  # Height auxiliary head
        use_species_head: bool = False,  # Species classification head
        ckpt_path: Optional[str] = None,  # NOTE(review): never read in this class (not passed to build_dinov3_backbone)
    ) -> None:
        """
        Build the backbone, the optional feature modules and the prediction heads.
        
        Args:
            grid: (rows, cols) tiling applied to each view before the backbone.
            pretrained: Forwarded to ``build_dinov3_backbone``.
            dropout: Dropout used in the shared projection and in the auxiliary
                heads. It is halved inside the five biomass heads.
            hidden_ratio: Hidden width = max(64, int(2 * feat_dim * hidden_ratio)).
            use_film: Cross-view FiLM conditioning of tile features.
            use_attention_pool: Attention pooling over tiles (mean otherwise).
            train_dead, train_clover: Give Dead / Clover their own heads instead
                of deriving them from Total/GDM/Green.
            use_vegetation_indices, use_disparity, use_depth: Append the
                corresponding extra feature vectors before the shared projection.
            depth_model_size, depth_model_path: Forwarded to
                ``DepthFeatures`` / ``DepthGuidedAttention``.
            use_depth_attention: Use ``DepthGuidedAttention`` for tile pooling.
                Takes precedence over ``use_attention_pool``.
            use_learnable_aug, learnable_aug_color, learnable_aug_spatial: Build
                ``LearnableAugmentation`` modules. ``learnable_aug_color`` also
                enables their blur and local-contrast branches.
            use_presence_heads, use_ndvi_head, use_height_head, use_species_head:
                Build the corresponding auxiliary heads.
            ckpt_path: Accepted for signature compatibility; unused here.
        """
        super().__init__()
        
        # Build DINOv3 backbone
        self.backbone, self.feat_dim, self.input_res = build_dinov3_backbone(pretrained)
        self.grid = tuple(grid)
        self.use_film = use_film
        self.use_attention_pool = use_attention_pool
        self.train_dead = train_dead
        self.train_clover = train_clover
        self.use_vegetation_indices = use_vegetation_indices
        self.use_disparity = use_disparity
        self.use_depth = use_depth
        self.use_depth_attention = use_depth_attention
        self.use_learnable_aug = use_learnable_aug
        self.depth_model_size = depth_model_size
        self.use_presence_heads = use_presence_heads
        self.use_ndvi_head = use_ndvi_head
        self.use_height_head = use_height_head
        self.use_species_head = use_species_head
        self.hidden_dim = max(64, int((self.feat_dim * 2) * hidden_ratio))
        
        # Learnable augmentation (identity at inference)
        # Includes all strong augmentations: color, spatial, blur, CLAHE
        # These modules are registered so checkpoint parameters load cleanly;
        # forward() below never calls them.
        if use_learnable_aug:
            self.learnable_aug_left = LearnableAugmentation(
                enable_color=learnable_aug_color,
                enable_spatial=learnable_aug_spatial,
                enable_blur=learnable_aug_color,
                enable_local_contrast=learnable_aug_color,
            )
            self.learnable_aug_right = LearnableAugmentation(
                enable_color=learnable_aug_color,
                enable_spatial=learnable_aug_spatial,
                enable_blur=learnable_aug_color,
                enable_local_contrast=learnable_aug_color,
            )
        
        # FiLM for cross-view conditioning
        if use_film:
            self.film_left = FiLM(self.feat_dim)
            self.film_right = FiLM(self.feat_dim)
        
        # Attention pooling for tiles (or depth-guided attention)
        # Both variants register under attn_pool_left/right, so checkpoint
        # keys share that prefix.
        if use_depth_attention:
            self.attn_pool_left = DepthGuidedAttention(self.feat_dim, grid=grid, model_size=depth_model_size, depth_model_path=depth_model_path)
            self.attn_pool_right = DepthGuidedAttention(self.feat_dim, grid=grid, model_size=depth_model_size, depth_model_path=depth_model_path)
        elif use_attention_pool:
            self.attn_pool_left = AttentionPooling(self.feat_dim)
            self.attn_pool_right = AttentionPooling(self.feat_dim)
        
        # === Optional feature modules ===
        # extra_dim accumulates the width of every feature vector appended
        # after the two pooled view features.
        extra_dim = 0
        if use_vegetation_indices:
            vi_out = 24
            self.vi_left = VegetationIndices(out_dim=vi_out)
            self.vi_right = VegetationIndices(out_dim=vi_out)
            extra_dim += vi_out * 2
        
        if use_disparity:
            self.disparity_module = DisparityFeatures(self.feat_dim, max_disparity=8, out_dim=self.feat_dim // 4)
            extra_dim += self.disparity_module.out_dim
        
        # Depth Features (Depth Anything V2)
        if use_depth:
            depth_out_dim = 32
            self.depth_module = DepthFeatures(out_dim=depth_out_dim, model_size=depth_model_size, depth_model_path=depth_model_path)
            extra_dim += depth_out_dim
        
        # Head dimensions
        combined_dim = self.feat_dim * 2 + extra_dim
        # Same value as self.hidden_dim above.
        hidden_dim = max(64, int((self.feat_dim * 2) * hidden_ratio))
        
        # Shared projection
        self.shared_proj = nn.Sequential(
            nn.LayerNorm(combined_dim),
            nn.Linear(combined_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        
        # Prediction heads (each maps hidden_dim -> 1 raw value; softplus is applied in forward)
        def _make_head() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(hidden_dim, 1),
            )
        
        self.head_total = _make_head()
        self.head_green = _make_head()
        self.head_gdm = _make_head()
        
        # Optional heads for Dead and Clover (None => value is derived in forward)
        self.head_dead = _make_head() if train_dead else None
        self.head_clover = _make_head() if train_clover else None
        
        # Presence heads: Binary classification for "has Dead?" / "has Clover?"
        # (output raw logits; sigmoid is applied in forward)
        if use_presence_heads:
            self.head_dead_presence = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, 1),
            )
            self.head_clover_presence = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, 1),
            )
        
        # NDVI auxiliary head (sigmoid output in [0, 1])
        if use_ndvi_head:
            self.head_ndvi = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, 1),
                nn.Sigmoid(),
            )
        
        # Height auxiliary head (softplus output, non-negative)
        if use_height_head:
            self.head_height = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, 1),
                nn.Softplus(),
            )
        
        # Species classification head (raw logits over NUM_SPECIES classes)
        if use_species_head:
            self.head_species_only = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, self.NUM_SPECIES),
            )
        
        self.softplus = nn.Softplus(beta=1.0)
    
    def _collect_tiles(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Split a (B, C, H, W) batch into ``rows * cols`` tiles in row-major order.
        
        Tile boundaries come from ``_make_edges``, presumably (start, end) pairs
        covering the full extent (TODO confirm). Any tile whose spatial size is
        not ``input_res x input_res`` is bilinearly resized to it.
        """
        _, _, H, W = x.shape
        r, c = self.grid
        rows = _make_edges(H, r)
        cols = _make_edges(W, c)
        
        tiles = []
        for rs, re in rows:
            for cs, ce in cols:
                tile = x[:, :, rs:re, cs:ce]
                if tile.shape[-2:] != (self.input_res, self.input_res):
                    tile = F.interpolate(
                        tile,
                        size=(self.input_res, self.input_res),
                        mode="bilinear",
                        align_corners=False,
                    )
                tiles.append(tile)
        return tiles
    
    def _extract_features(
        self, x_left: torch.Tensor, x_right: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extract tile features from both views in one backbone call.
        
        Returns:
            (feats_left, feats_right), each shaped (B, num_tiles, D). This
            assumes the backbone returns one D-dim vector per input tile.
        """
        B = x_left.size(0)
        
        tiles_left = self._collect_tiles(x_left)
        tiles_right = self._collect_tiles(x_right)
        num_tiles = len(tiles_left)
        
        # Process all tiles in one forward pass
        all_tiles = torch.cat(tiles_left + tiles_right, dim=0)
        all_feats = self.backbone(all_tiles)
        
        # Reshape.
        # Tiles were concatenated tile-major (tile0 for every sample, then
        # tile1, ..., left-view tiles first), so view(2*num_tiles, B, D)
        # followed by permute yields (B, 2*num_tiles, D).
        all_feats = all_feats.view(2 * num_tiles, B, -1).permute(1, 0, 2)
        feats_left = all_feats[:, :num_tiles, :]
        feats_right = all_feats[:, num_tiles:, :]
        
        return feats_left, feats_right
    
    def forward(
        self, x_left: torch.Tensor, x_right: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], ...]:
        """
        Forward pass.
        
        Args:
            x_left, x_right: Image batches for the two views, 4-D
                (B, C, H, W). Presumably C == 3 (RGB); TODO confirm.
        
        Returns:
            ``(green, dead, clover, gdm, total, aux_loss)``. Each biomass
            tensor is (B, 1); ``aux_loss`` is a scalar zero placeholder.
            When any presence / NDVI / height / species head is enabled, five
            more items are appended: dead_presence_logit,
            clover_presence_logit, ndvi_pred, height_pred, species_logits.
            Each of these is None if its head is disabled.
        """
        # Extract tile features
        tiles_left, tiles_right = self._extract_features(x_left, x_right)
        
        # Stereo Disparity Features (before FiLM, i.e. on unmodulated tile features)
        disp_feat = None
        if self.use_disparity:
            disp_feat = self.disparity_module(tiles_left, tiles_right)
        
        # Context for FiLM (mean over tiles per view)
        ctx_left = tiles_left.mean(dim=1)
        ctx_right = tiles_right.mean(dim=1)
        
        # Apply FiLM cross-conditioning: each view is modulated by the OTHER view's context
        if self.use_film:
            gamma_l, beta_l = self.film_left(ctx_right)
            gamma_r, beta_r = self.film_right(ctx_left)
            tiles_left = tiles_left * (1 + gamma_l.unsqueeze(1)) + beta_l.unsqueeze(1)
            tiles_right = tiles_right * (1 + gamma_r.unsqueeze(1)) + beta_r.unsqueeze(1)
        
        # Pool tiles
        if self.use_depth_attention:
            # Depth-guided attention needs the original images
            f_left = self.attn_pool_left(tiles_left, x_left)
            f_right = self.attn_pool_right(tiles_right, x_right)
        elif self.use_attention_pool:
            f_left = self.attn_pool_left(tiles_left)
            f_right = self.attn_pool_right(tiles_right)
        else:
            f_left = tiles_left.mean(dim=1)
            f_right = tiles_right.mean(dim=1)
        
        # Combine features (order must match the training code, since shared_proj is position-sensitive)
        features_list = [f_left, f_right]
        
        if self.use_vegetation_indices:
            vi_left = self.vi_left(x_left)
            vi_right = self.vi_right(x_right)
            features_list.extend([vi_left, vi_right])
        
        if disp_feat is not None:
            features_list.append(disp_feat)
        
        # Depth Features (Depth Anything V2)
        if self.use_depth:
            depth_feat = self.depth_module(x_left, x_right)
            features_list.append(depth_feat)
        
        f = torch.cat(features_list, dim=1)
        f = self.shared_proj(f)  # (B, hidden_dim)
        
        # Core predictions
        total_raw = self.softplus(self.head_total(f))
        green_raw = self.softplus(self.head_green(f))
        gdm_raw = self.softplus(self.head_gdm(f))
        
        # Enforce constraints: Total >= GDM >= Green
        total = total_raw
        gdm = torch.minimum(gdm_raw, total)
        green = torch.minimum(green_raw, gdm)
        
        # Presence probabilities for gating (if enabled).
        # *_presence_prob only exist when use_presence_heads is set and are
        # read only under that same condition below.
        dead_presence_logit = None
        clover_presence_logit = None
        if self.use_presence_heads:
            dead_presence_logit = self.head_dead_presence(f)
            clover_presence_logit = self.head_clover_presence(f)
            dead_presence_prob = torch.sigmoid(dead_presence_logit)
            clover_presence_prob = torch.sigmoid(clover_presence_logit)
        
        # Dead: predicted (bounded by Total - GDM) or derived (exactly Total - GDM)
        if self.head_dead is not None:
            dead_raw = self.softplus(self.head_dead(f))
            dead = torch.minimum(dead_raw, total - gdm + 1e-6)
            dead = F.relu(dead)
            if self.use_presence_heads:
                dead = dead * dead_presence_prob
        else:
            dead = F.relu(total - gdm)
        
        # Clover: predicted (bounded by GDM - Green) or derived (exactly GDM - Green)
        if self.head_clover is not None:
            clover_raw = self.softplus(self.head_clover(f))
            clover = torch.minimum(clover_raw, gdm - green + 1e-6)
            clover = F.relu(clover)
            if self.use_presence_heads:
                clover = clover * clover_presence_prob
        else:
            clover = F.relu(gdm - green)
        
        # Auxiliary predictions
        ndvi_pred = None
        if self.use_ndvi_head:
            ndvi_pred = self.head_ndvi(f)
        
        height_pred = None
        if self.use_height_head:
            height_pred = self.head_height(f)
        
        species_logits = None
        if self.use_species_head:
            species_logits = self.head_species_only(f)
        
        # Return with auxiliary outputs if any enabled.
        # aux_loss is always zero here: learnable augmentation is never applied in this class.
        aux_loss = torch.tensor(0.0, device=x_left.device)
        if self.use_presence_heads or self.use_ndvi_head or self.use_height_head or self.use_species_head:
            return green, dead, clover, gdm, total, aux_loss, dead_presence_logit, clover_presence_logit, ndvi_pred, height_pred, species_logits
        
        return green, dead, clover, gdm, total, aux_loss


# ==================== MODEL LOADING ====================

def _strip_module_prefix(sd: dict) -> dict:
    """Remove 'module.' prefix from state dict keys (for DDP-trained models)."""
    if not sd:
        return sd
    keys = list(sd.keys())
    if all(k.startswith("module.") for k in keys):
        return {k[len("module."):]: v for k, v in sd.items()}
    return sd


def _filter_depth_model_keys(sd: dict) -> dict:
    """
    Filter out embedded depth model weights from old checkpoints.
    
    Old checkpoints saved depth model weights as part of state_dict.
    New code loads depth model separately, so these keys are unexpected.
    """
    if not sd:
        return sd
    
    filtered = {}
    removed_count = 0
    for k, v in sd.items():
        if "._depth_model." in k:
            removed_count += 1
        else:
            filtered[k] = v
    
    if removed_count > 0:
        print(f"  Filtered {removed_count} embedded depth model keys from checkpoint")
    
    return filtered


def _detect_model_config(sd_keys: set) -> dict:
    """Auto-detect model config from checkpoint keys."""
    # Detect learnable aug color/spatial from keys
    has_learnable_aug = any(k.startswith("learnable_aug_left.") for k in sd_keys)
    learnable_aug_color = any("color_params" in k for k in sd_keys if k.startswith("learnable_aug_left."))
    learnable_aug_spatial = any("spatial_params" in k for k in sd_keys if k.startswith("learnable_aug_left."))
    
    # Detect depth-guided attention (has depth_to_attn in attn_pool)
    has_depth_attention = any("depth_to_attn" in k for k in sd_keys if k.startswith("attn_pool_left."))
    
    # Detect new auxiliary heads
    has_presence_heads = any(k.startswith("head_dead_presence.") for k in sd_keys)
    has_ndvi_head = any(k.startswith("head_ndvi.") for k in sd_keys)
    has_height_head = any(k.startswith("head_height.") for k in sd_keys)
    has_species_head = any(k.startswith("head_species_only.") for k in sd_keys)
    
    return {
        "use_film": any(k.startswith("film_left.") for k in sd_keys),
        "use_attention_pool": any(k.startswith("attn_pool_left.") for k in sd_keys),
        "train_dead": any(k.startswith("head_dead.") for k in sd_keys),
        "train_clover": any(k.startswith("head_clover.") for k in sd_keys),
        "use_vegetation_indices": any(k.startswith("vi_left.") for k in sd_keys),
        "use_disparity": any(k.startswith("disparity_module.") for k in sd_keys),
        "use_depth": any(k.startswith("depth_module.") for k in sd_keys),
        "use_depth_attention": has_depth_attention,
        "use_learnable_aug": has_learnable_aug,
        "learnable_aug_color": learnable_aug_color,
        "learnable_aug_spatial": learnable_aug_spatial,
        "use_presence_heads": has_presence_heads,
        "use_ndvi_head": has_ndvi_head,
        "use_height_head": has_height_head,
        "use_species_head": has_species_head,
    }


def get_depth_model_path() -> Optional[str]:
    """
    Resolve where the depth model for inference should be loaded from.

    Lookup order:
    1. ``CFG.DEPTH_MODEL_PATH``, if set and it exists. A warning is printed
       when it is set but missing.
    2. ``<CFG.MODEL_DIR>/depth_model``, if that directory exists (saved
       during training).
    3. ``None``. The caller then has to fetch the model from HuggingFace,
       which requires internet.
    """
    explicit = getattr(CFG, "DEPTH_MODEL_PATH", None)
    if explicit:
        if os.path.exists(explicit):
            return explicit
        print(f"Warning: DEPTH_MODEL_PATH set but not found: {explicit}")

    bundled = os.path.join(CFG.MODEL_DIR, "depth_model")
    return bundled if os.path.exists(bundled) else None


def load_fold_model(
    path: str,
    grid: Tuple[int, int] = (2, 2),
    dropout: float = 0.3,
    hidden_ratio: float = 0.25,
    use_film: bool = True,
    use_attention_pool: bool = True,
    train_dead: bool = False,
    train_clover: bool = False,
    use_vegetation_indices: bool = False,
    use_disparity: bool = False,
    use_depth: bool = False,
    depth_model_size: str = "small",
    use_depth_attention: bool = False,
    use_learnable_aug: bool = False,
    learnable_aug_color: bool = True,
    learnable_aug_spatial: bool = False,
    use_presence_heads: bool = False,
    use_ndvi_head: bool = False,
    use_height_head: bool = False,
    use_species_head: bool = False,
    depth_model_path: Optional[str] = None,  # For Kaggle offline
) -> nn.Module:
    """Load a DINOv3Direct checkpoint and return the model in eval mode on ``CFG.DEVICE``.

    Architecture flags (``use_film`` ... ``use_species_head``) are
    auto-detected from the checkpoint's state-dict keys via
    ``_detect_model_config``. A detected value overrides the matching
    argument, which acts only as a fallback. ``grid``, ``dropout``,
    ``hidden_ratio``, ``depth_model_size`` and ``depth_model_path`` are used
    exactly as passed.

    The backbone is built with ``pretrained=False`` because its weights come
    from the checkpoint. Depth-model weights embedded in older checkpoints are
    dropped. The depth model itself is obtained via ``depth_model_path``
    (presumably downloaded from HuggingFace when None; needs internet).

    Args:
        path: Path to the ``.pth`` state-dict file.
        grid, dropout, hidden_ratio, depth_model_size, depth_model_path:
            Forwarded to ``DINOv3Direct``.
        use_* / train_* / learnable_aug_*: Fallbacks for auto-detected flags.

    Returns:
        The loaded model, on ``CFG.DEVICE`` and in eval mode.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint not found: {path}")

    try:
        raw_sd = torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        # Older torch versions do not accept the weights_only argument.
        raw_sd = torch.load(path, map_location="cpu")

    sd = _strip_module_prefix(raw_sd)

    # Detect config BEFORE filtering: embedded depth-model keys live under
    # depth_module./attn_pool_left. and still indicate which features were used.
    detected_config = _detect_model_config(set(sd.keys()))

    # Filter out embedded depth model keys from old checkpoints
    # (new code loads the depth model separately).
    sd = _filter_depth_model_keys(sd)

    # Caller-supplied flags are fallbacks; checkpoint-detected values win.
    fallback_flags = {
        "use_film": use_film,
        "use_attention_pool": use_attention_pool,
        "train_dead": train_dead,
        "train_clover": train_clover,
        "use_vegetation_indices": use_vegetation_indices,
        "use_disparity": use_disparity,
        "use_depth": use_depth,
        "use_depth_attention": use_depth_attention,
        "use_learnable_aug": use_learnable_aug,
        "learnable_aug_color": learnable_aug_color,
        "learnable_aug_spatial": learnable_aug_spatial,
        "use_presence_heads": use_presence_heads,
        "use_ndvi_head": use_ndvi_head,
        "use_height_head": use_height_head,
        "use_species_head": use_species_head,
    }
    flags = {name: detected_config.get(name, value) for name, value in fallback_flags.items()}

    # DINOv3Direct is defined locally in this notebook (no src import needed for Kaggle).
    model = DINOv3Direct(
        grid=grid,
        pretrained=False,  # Don't download - backbone weights come from checkpoint
        dropout=dropout,
        hidden_ratio=hidden_ratio,
        depth_model_size=depth_model_size,
        depth_model_path=depth_model_path,  # Local path for Kaggle offline (depth model)
        **flags,
    )

    # strict=False because embedded depth-model keys were filtered out above.
    # Report any other mismatch so that a wrong checkpoint/config combination
    # does not silently run with randomly initialised layers.
    incompatible = model.load_state_dict(sd, strict=False)
    missing = [k for k in incompatible.missing_keys if "._depth_model." not in k]
    unexpected = list(incompatible.unexpected_keys)
    if missing:
        print(
            f"  Warning: {len(missing)} model keys missing from checkpoint "
            f"{os.path.basename(path)} (left at init): {missing[:5]}"
        )
    if unexpected:
        print(
            f"  Warning: {len(unexpected)} unexpected checkpoint keys ignored "
            f"in {os.path.basename(path)}: {unexpected[:5]}"
        )

    model = model.to(CFG.DEVICE)
    model.eval()
    return model


def find_model_checkpoints(model_dir: str, top_k: Optional[int] = None) -> Tuple[List[str], List[int]]:
    """
    Find DINOv3 model checkpoints in ``model_dir``.

    Naming schemes are tried in priority order. The first scheme that matches
    any file wins and later schemes are not consulted:

    1. Final checkpoints: ``dinov3_best_fold<F>.pth``
    2. Top-K ranked checkpoints: ``dinov3_top<R>_fold<F>.pth``
    3. In-progress checkpoints from ``--save-top-k``: ``_topk_fold<F>_ep<E>.pth``.
       - If they all share one fold (train-all mode), every file becomes an
         ensemble member, newest epoch first, numbered 0..N-1 as pseudo-folds.
       - Otherwise the latest epoch of each fold is used.

    File names must match a pattern exactly. Stray files such as
    ``dinov3_best_fold0_ema.pth`` are ignored instead of crashing the
    fold-number parse.

    Args:
        model_dir: Directory containing model checkpoints.
        top_k: If set and smaller than the number of checkpoints found, keep
            only the K best folds by aR² from ``results.json``. Only folds
            that actually have a checkpoint are ranked. In train-all mode the
            pseudo-folds have no per-fold metrics, so the K newest
            checkpoints are kept instead.

    Returns:
        Tuple of (checkpoint_paths, fold_indices), sorted by fold index.
    """
    import re

    # Sorted once so ties (e.g. several ranks of the same fold) are ordered deterministically.
    filenames = sorted(os.listdir(model_dir))

    def _split(pairs: List[Tuple[int, str]]) -> Tuple[List[str], List[int]]:
        """Turn (fold, path) pairs into (paths, folds)."""
        return [path for _, path in pairs], [fold for fold, _ in pairs]

    all_checkpoints: List[Tuple[int, str]] = []
    pseudo_folds = False  # True when indices are epoch ranks, not real fold numbers

    # Pattern 1: Final checkpoints (dinov3_best_fold<F>.pth)
    for fname in filenames:
        match = re.fullmatch(r"dinov3_best_fold(\d+)\.pth", fname)
        if match:
            all_checkpoints.append((int(match.group(1)), os.path.join(model_dir, fname)))

    # Pattern 2: Top-K ranked checkpoints (dinov3_top<R>_fold<F>.pth)
    if not all_checkpoints:
        for fname in filenames:
            match = re.fullmatch(r"dinov3_top(\d+)_fold(\d+)\.pth", fname)
            if match:
                all_checkpoints.append((int(match.group(2)), os.path.join(model_dir, fname)))

    # Pattern 3: In-progress checkpoints (_topk_fold<F>_ep<E>.pth) - training not finished,
    # OR top-k checkpoints from --save-top-k (all from the same fold in train-all mode)
    if not all_checkpoints:
        in_progress: List[Tuple[int, int, str]] = []
        for fname in filenames:
            match = re.fullmatch(r"_topk_fold(\d+)_ep(\d+)\.pth", fname)
            if match:
                in_progress.append(
                    (int(match.group(1)), int(match.group(2)), os.path.join(model_dir, fname))
                )

        if in_progress:
            if len({fold for fold, _, _ in in_progress}) == 1:
                # Train-all mode: use ALL checkpoints as an ensemble, newest epoch first.
                in_progress.sort(key=lambda item: item[1], reverse=True)
                all_checkpoints = [(idx, path) for idx, (_, _, path) in enumerate(in_progress)]
                pseudo_folds = True
                print(f"Using top-{len(in_progress)} checkpoints for ensemble (train-all mode)")
            else:
                # CV mode: take the latest saved epoch per fold (no metrics available to pick "best").
                latest: Dict[int, Tuple[int, str]] = {}
                for fold_num, epoch, path in in_progress:
                    if fold_num not in latest or epoch > latest[fold_num][0]:
                        latest[fold_num] = (epoch, path)
                all_checkpoints = [(fold_num, path) for fold_num, (_, path) in latest.items()]
                print("Using in-progress checkpoints (training not finished)")

    all_checkpoints.sort(key=lambda item: item[0])  # Sort by fold number

    if top_k is None or top_k >= len(all_checkpoints):
        return _split(all_checkpoints)

    if pseudo_folds:
        # Pseudo-fold indices do not correspond to results.json folds; keep the newest K.
        print(f"Selecting {top_k} most recent checkpoints (train-all mode)")
        return _split(all_checkpoints[:top_k])

    # Load results.json to get fold metrics
    results_path = os.path.join(model_dir, "results.json")
    if not os.path.exists(results_path):
        print(f"Warning: results.json not found, using all {len(all_checkpoints)} folds")
        return _split(all_checkpoints)

    with open(results_path) as f:
        results = json.load(f)

    fold_results = results.get("folds", [])
    if not fold_results:
        print(f"Warning: No fold results in results.json, using all {len(all_checkpoints)} folds")
        return _split(all_checkpoints)

    # Rank only folds that actually have a checkpoint, so we really return K models.
    available = {fold_num for fold_num, _ in all_checkpoints}
    fold_metrics = []
    for fr in fold_results:
        fold_num = fr.get("fold")
        if fold_num not in available:
            continue
        # best_r2 is aR² for newer dinov3_train.py runs; older runs store metrics.avg_r2.
        metrics = fr.get("metrics", {})
        ar2 = metrics.get("avg_r2", fr.get("best_r2", 0))
        fold_metrics.append((fold_num, ar2))

    if not fold_metrics:
        print(f"Warning: No results.json metrics for available folds, using all {len(all_checkpoints)} folds")
        return _split(all_checkpoints)

    # Sort by metric (descending) and take top K
    fold_metrics.sort(key=lambda item: item[1], reverse=True)
    chosen = fold_metrics[:top_k]
    top_folds = {fold_num for fold_num, _ in chosen}

    print(f"Selecting top {top_k} folds by aR²:")
    for fold_num, ar2 in chosen:
        print(f"  Fold {fold_num}: aR²={ar2:.4f}")

    selected = [(fold_num, path) for fold_num, path in all_checkpoints if fold_num in top_folds]
    return _split(selected)


# ==================== INFERENCE ====================

@torch.no_grad()
def predict_one_view(models: List[nn.Module], loader: DataLoader) -> np.ndarray:
    """Run fold-ensemble inference over one TTA view.

    For every batch, each model's biomass predictions are stacked as columns
    ``[green, dead, clover, gdm, total]`` and averaged across models. The
    per-batch averages are then concatenated in loader order.

    Args:
        models: Fold models to ensemble (equal weights).
        loader: Yields ``(x_left, x_right, _)`` batches.

    Returns:
        Array of shape (N, 5) in float32. It is empty (0, 5) if the loader
        yields no batches.

    Raises:
        ValueError: If ``models`` is empty (averaging zero folds is undefined).
    """
    if not models:
        raise ValueError("predict_one_view requires at least one model")

    out_list = []
    for batch in tqdm(loader, desc="Predicting", leave=False):
        x_left, x_right, _ = batch
        x_left = x_left.to(CFG.DEVICE)
        x_right = x_right.to(CFG.DEVICE)

        fold_preds = []
        for model in models:
            outputs = model(x_left, x_right)
            # The model returns a sequence of 6 or 11 items (extra aux outputs
            # when aux heads are enabled). The first five are always the
            # biomass targets, each (B, 1).
            green, dead, clover, gdm, total = outputs[:5]
            pred = torch.cat([green, dead, clover, gdm, total], dim=1)
            fold_preds.append(pred.float().cpu().numpy())

        out_list.append(np.mean(fold_preds, axis=0))

    if not out_list:
        return np.zeros((0, 5), dtype=np.float32)
    return np.concatenate(out_list, axis=0)


def run_inference_tta(models: List[nn.Module], df: pd.DataFrame) -> np.ndarray:
    """Predict ``df`` with every TTA view and return the mean over views.

    With ``CFG.USE_TTA`` off, only the validation transform is used. Each view
    builds its own dataset/loader and is scored by ``predict_one_view``.
    """
    if not CFG.USE_TTA:
        transforms = [get_val_transform()]
        print("TTA: disabled (1 view)")
    else:
        level = getattr(CFG, 'TTA_LEVEL', 'default')
        transforms = get_tta_transforms(level=level)
        print(f"TTA level: {level} ({len(transforms)} views)")

    n_views = len(transforms)
    view_preds = []
    for view_no, tfm in enumerate(transforms, start=1):
        print(f"  View {view_no}/{n_views}...")
        loader = DataLoader(
            TestBiomassDataset(df, CFG.TEST_IMAGE_DIR, tfm),
            batch_size=CFG.BATCH_SIZE,
            shuffle=False,
            num_workers=CFG.NUM_WORKERS,
        )
        view_preds.append(predict_one_view(models, loader))

    return np.mean(view_preds, axis=0)


def create_submission(final_5: np.ndarray, test_long: pd.DataFrame, test_unique: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Write the long-format submission CSV and return it with the wide predictions.

    Args:
        final_5: Prediction matrix whose columns 0-4 are Green, Dead, Clover,
            GDM and Total. Row ``i`` belongs to row ``i`` of ``test_unique``.
        test_long: Long-format frame with ``sample_id``, ``image_path`` and
            ``target_name``. Its row order is preserved in the output.
        test_unique: One row per image, with an ``image_path`` column.

    Returns:
        ``(sub, wide)``:
        - ``sub``: the ``sample_id``/``value`` frame written to
          ``CFG.SUBMISSION_FILE``.
        - ``wide``: the per-image frame of cleaned predictions.
    """
    target_order = ("Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g")

    # NaN/inf predictions become 0 and negative values are clipped to 0.
    columns = {"image_path": test_unique["image_path"]}
    for col_idx, target in enumerate(target_order):
        finite = np.nan_to_num(final_5[:, col_idx], nan=0.0, posinf=0.0, neginf=0.0)
        columns[target] = np.maximum(0, finite)
    wide = pd.DataFrame(columns)

    long_preds = wide.melt(
        id_vars=["image_path"],
        value_vars=CFG.ALL_TARGET_COLS,
        var_name="target_name",
        value_name="target",
    )

    join_keys = ["image_path", "target_name"]
    merged = test_long[["sample_id", *join_keys]].merge(long_preds, on=join_keys, how="left")
    sub = merged[["sample_id", "target"]]

    # Rows without a matching prediction (left-join misses) fall back to 0.
    sub["target"] = np.nan_to_num(sub["target"], nan=0.0, posinf=0.0, neginf=0.0)
    sub = sub.rename(columns={"target": "value"})  # Submission format column name

    sub.to_csv(CFG.SUBMISSION_FILE, index=False)
    print(f"Saved: {CFG.SUBMISSION_FILE}")
    print(sub.head(10))
    return sub, wide


# CFG attribute -> key in the results.json written by dinov3_train.py.
# Order matters only for readability; each entry is applied independently.
# USE_DEPTH_ATTENTION is read from "depth_attention" (no "use_" prefix),
# exactly as before.
_RESULTS_CONFIG_KEYS: tuple[tuple[str, str], ...] = (
    ("IMG_SIZE", "img_size"),
    ("DROPOUT", "dropout"),
    ("HIDDEN_RATIO", "hidden_ratio"),
    ("USE_FILM", "use_film"),
    ("USE_ATTENTION_POOL", "use_attention_pool"),
    ("TRAIN_DEAD", "train_dead"),
    ("TRAIN_CLOVER", "train_clover"),
    ("USE_VEGETATION_INDICES", "use_vegetation_indices"),
    ("USE_DISPARITY", "use_disparity"),
    ("USE_DEPTH", "use_depth"),
    ("DEPTH_MODEL_SIZE", "depth_model_size"),
    ("USE_DEPTH_ATTENTION", "depth_attention"),
    ("USE_LEARNABLE_AUG", "use_learnable_aug"),
    ("LEARNABLE_AUG_COLOR", "learnable_aug_color"),
    ("LEARNABLE_AUG_SPATIAL", "learnable_aug_spatial"),
    # Auxiliary heads
    ("USE_PRESENCE_HEADS", "use_presence_heads"),
    ("USE_NDVI_HEAD", "use_ndvi_head"),
    ("USE_HEIGHT_HEAD", "use_height_head"),
    ("USE_SPECIES_HEAD", "use_species_head"),
)


def _apply_results_config(config: dict) -> None:
    """Overwrite CFG settings in place with the values saved in results.json."""
    # dinov3_train.py saves grid as single int
    grid_val = config.get("grid", 2)
    CFG.GRID = (grid_val, grid_val) if isinstance(grid_val, int) else tuple(grid_val)
    for attr, key in _RESULTS_CONFIG_KEYS:
        # Keep the current CFG value when the training run did not save the key.
        setattr(CFG, attr, config.get(key, getattr(CFG, attr)))
    print("Loaded config from results.json")


def _enabled_extras() -> list[str]:
    """Return display names of the optional model features enabled in CFG."""
    extras = []
    if CFG.USE_VEGETATION_INDICES:
        extras.append("Vegetation Indices")
    if CFG.USE_DISPARITY:
        extras.append("Stereo Disparity")
    if CFG.USE_DEPTH:
        extras.append(f"Depth Stats (DA2-{CFG.DEPTH_MODEL_SIZE})")
    if CFG.USE_DEPTH_ATTENTION:
        extras.append(f"Depth Attention (DA2-{CFG.DEPTH_MODEL_SIZE})")
    if CFG.USE_LEARNABLE_AUG:
        aug_types = []
        if CFG.LEARNABLE_AUG_COLOR:
            aug_types.append("color")
        if CFG.LEARNABLE_AUG_SPATIAL:
            aug_types.append("spatial")
        extras.append(f"Learnable Aug ({'+'.join(aug_types)})")
    if CFG.USE_PRESENCE_HEADS:
        extras.append("Presence Heads (Dead/Clover)")
    if CFG.USE_NDVI_HEAD:
        extras.append("NDVI Head")
    if CFG.USE_HEIGHT_HEAD:
        extras.append("Height Head")
    # The species head is passed to the model too, so report it as well.
    if CFG.USE_SPECIES_HEAD:
        extras.append("Species Head")
    return extras


def _load_models(checkpoints: List[str], depth_model_path) -> list:
    """Load one model per checkpoint using the current CFG architecture settings."""
    print("\nLoading models...")
    models = []
    for ckpt in checkpoints:
        models.append(
            load_fold_model(
                ckpt,
                grid=CFG.GRID,
                dropout=CFG.DROPOUT,
                hidden_ratio=CFG.HIDDEN_RATIO,
                use_film=CFG.USE_FILM,
                use_attention_pool=CFG.USE_ATTENTION_POOL,
                train_dead=CFG.TRAIN_DEAD,
                train_clover=CFG.TRAIN_CLOVER,
                use_vegetation_indices=CFG.USE_VEGETATION_INDICES,
                use_disparity=CFG.USE_DISPARITY,
                use_depth=CFG.USE_DEPTH,
                depth_model_size=CFG.DEPTH_MODEL_SIZE,
                use_depth_attention=CFG.USE_DEPTH_ATTENTION,
                use_learnable_aug=CFG.USE_LEARNABLE_AUG,
                learnable_aug_color=CFG.LEARNABLE_AUG_COLOR,
                learnable_aug_spatial=CFG.LEARNABLE_AUG_SPATIAL,
                use_presence_heads=CFG.USE_PRESENCE_HEADS,
                use_ndvi_head=CFG.USE_NDVI_HEAD,
                use_height_head=CFG.USE_HEIGHT_HEAD,
                use_species_head=CFG.USE_SPECIES_HEAD,
                depth_model_path=depth_model_path,  # Pass local path for Kaggle
            )
        )
        print(f"  Loaded: {os.path.basename(ckpt)}")
    return models


def _load_test_data() -> tuple[pd.DataFrame, pd.DataFrame]:
    """Read CFG.TEST_CSV and return ``(test_long, test_unique)``.

    ``test_unique`` holds one row per image.

    Raises:
        ValueError: if ``sample_id_prefix`` is absent and neither
            ``image_path`` nor ``sample_id`` is available to derive it.
    """
    print("\nLoading test data...")
    test_long = pd.read_csv(CFG.TEST_CSV)

    # Extract sample_id_prefix for dataset compatibility
    if "sample_id_prefix" not in test_long.columns:
        if "image_path" in test_long.columns:
            test_long["sample_id_prefix"] = test_long["image_path"].str.extract(r'([A-Z]+\d+)')[0]
        elif "sample_id" in test_long.columns:
            test_long["sample_id_prefix"] = test_long["sample_id"].str.split("__").str[0]
        else:
            raise ValueError("Cannot find sample_id or image_path column")

    # Use image_path for dedup to match v1 submission ordering
    test_unique = test_long.drop_duplicates(subset=["image_path"]).reset_index(drop=True)
    print(f"Test samples: {len(test_unique)}")
    return test_long, test_unique


def _report_predictions(preds: np.ndarray) -> None:
    """Print the G+D+C=T consistency check and per-target summary statistics."""
    # Columns: 0=Green, 1=Dead, 2=Clover, 3=GDM, 4=Total
    component_sum = preds[:, 0] + preds[:, 1] + preds[:, 2]
    diff = np.abs(component_sum - preds[:, 4])
    print("\nConstraint check (G+D+C=T):")
    print(f"  Max diff: {diff.max():.6f}")
    print(f"  Mean diff: {diff.mean():.6f}")

    print("\nPrediction stats:")
    for i, name in enumerate(["Green", "Dead", "Clover", "GDM", "Total"]):
        col = preds[:, i]
        print(f"  {name}: mean={col.mean():.2f}, std={col.std():.2f}, "
              f"min={col.min():.2f}, max={col.max():.2f}")


def main():
    """Run DINOv3 Direct model inference end to end and write the submission.

    Steps:
        1. Sync CFG with the training run's results.json, if one is found.
        2. Load every selected fold checkpoint.
        3. Predict the test images with TTA and print sanity statistics.
        4. Save the submission via ``create_submission``.

    Returns:
        The long-format submission DataFrame produced by ``create_submission``.

    Raises:
        ValueError: if no checkpoints are found in ``CFG.MODEL_DIR``, or if the
            test CSV lacks the columns needed to derive ``sample_id_prefix``.
    """
    print("="*60)
    print("DINOv3 Direct Model Inference")
    print("="*60)
    print(f"Device: {CFG.DEVICE}")
    print(f"Model dir: {CFG.MODEL_DIR}")
    tta_info = f"{CFG.TTA_LEVEL} ({len(get_tta_transforms(level=CFG.TTA_LEVEL))} views)" if CFG.USE_TTA else "disabled"
    print(f"TTA: {tta_info}")

    # Load config from results.json (dinov3_train.py output)
    config = load_config_from_results(CFG.MODEL_DIR)
    if config:
        _apply_results_config(config)

    print("Model: DINOv3Direct")
    print("Backbone: vit_base_patch16_dinov3")
    print(f"Grid: {CFG.GRID}")
    print(f"Image size: {CFG.IMG_SIZE}")
    if CFG.TRAIN_DEAD or CFG.TRAIN_CLOVER:
        print(f"Optional heads: train_dead={CFG.TRAIN_DEAD}, train_clover={CFG.TRAIN_CLOVER}")
    extras = _enabled_extras()
    if extras:
        print(f"Innovative features: {', '.join(extras)}")

    # Find checkpoints (optionally select top K folds)
    checkpoints, fold_indices = find_model_checkpoints(CFG.MODEL_DIR, top_k=CFG.TOP_K_FOLDS)
    if not checkpoints:
        raise ValueError(f"No checkpoints found in {CFG.MODEL_DIR}")
    if CFG.TOP_K_FOLDS:
        print(f"Using top {len(checkpoints)} folds: {fold_indices}")
    else:
        print(f"Found {len(checkpoints)} fold checkpoints")

    # Detect depth model path (for Kaggle offline)
    depth_model_path = get_depth_model_path()
    if depth_model_path:
        print(f"\nDepth model path: {depth_model_path}")
    else:
        print("\nDepth model: Will download from HuggingFace (requires internet)")

    # The helper builds the list internally, so no stray loop variable keeps
    # the last fold model alive after cleanup.
    models = _load_models(checkpoints, depth_model_path)
    try:
        test_long, test_unique = _load_test_data()

        print("\nRunning inference...")
        preds = run_inference_tta(models, test_unique)
        print(f"Predictions shape: {preds.shape}")

        _report_predictions(preds)

        # create_submission already saves the CSV and prints its head.
        print("\nCreating submission...")
        sub_df, _wide_df = create_submission(preds, test_long, test_unique)
    finally:
        # Drop the only reference to the fold models before collecting, so
        # their (GPU) memory can be released even if inference failed.
        del models
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\n" + "="*60)
    print("Done!")
    print("="*60)
    return sub_df


# Run inference only when executed as a notebook cell or script, not when this
# file is imported as a module. Jupyter/Kaggle kernels run cells with
# __name__ == "__main__", so notebook behaviour is unchanged.
if __name__ == "__main__":
    submission = main()

Device: cuda
DINOv3 Direct Model Inference
Device: cuda
Model dir: /workspace/biomass-kaggle/outputs/dinov3_full_mse
TTA: default (3 views)
Loaded config from results.json
Model: DINOv3Direct
Backbone: vit_base_patch16_dinov3
Grid: (2, 2)
Image size: 576
Optional heads: train_dead=False, train_clover=True
Innovative features: Depth Stats (DA2-small), Depth Attention (DA2-small), Presence Heads (Dead/Clover), NDVI Head, Height Head
Using top-5 checkpoints for ensemble (train-all mode)
Found 5 fold checkpoints

Depth model path: /workspace/biomass-kaggle/outputs/dinov3_full_mse/depth_model

Loading models...
  Filtered 861 embedded depth model keys from checkpoint
  Loaded: _topk_fold0_ep52.pth
  Filtered 861 embedded depth model keys from checkpoint
  Loaded: _topk_fold0_ep51.pth
  Filtered 861 embedded depth model keys from checkpoint
  Loaded: _topk_fold0_ep48.pth
  Filtered 861 embedded depth model keys from checkpoint
  Loaded: _topk_fold0_ep47.pth
  Filtered 861 embedded depth mode

                                                             

  View 2/3...


                                                             

  View 3/3...


                                                             

Predictions shape: (357, 5)

Constraint check (G+D+C=T):
  Max diff: 0.247450
  Mean diff: 0.005716

Prediction stats:
  Green: mean=25.51, std=24.34, min=0.00, max=126.05
  Dead: mean=11.00, std=9.35, min=0.00, max=42.67
  Clover: mean=6.61, std=11.89, min=0.00, max=69.95
  GDM: mean=32.13, std=24.36, min=0.62, max=128.31
  Total: mean=43.12, std=26.37, min=0.62, max=139.16

Creating submission...
Saved: submission.csv
                    sample_id         value
0  ID1011485656__Dry_Clover_g  9.999725e-07
1    ID1011485656__Dry_Dead_g  3.211934e+01
2   ID1011485656__Dry_Green_g  1.832914e+01
3   ID1011485656__Dry_Total_g  5.044847e+01
4         ID1011485656__GDM_g  1.832914e+01
5  ID1012260530__Dry_Clover_g  9.319995e-07
6    ID1012260530__Dry_Dead_g  0.000000e+00
7   ID1012260530__Dry_Green_g  8.601438e+00
8   ID1012260530__Dry_Total_g  8.601438e+00
9         ID1012260530__GDM_g  8.601438e+00
                    sample_id         value
0  ID1011485656__Dry_Clover_g  9.999725e-07
1   