## CSIRO Image2Biomass – Ratio Model Inference

This notebook runs inference using Ratio DINOv2/DINOv3 models trained with `train_ratio.py`.

**Supports:**
- `SoftmaxRatioDINO`: Predicts Total + softmax(Green, Dead, Clover) ratios
- `HierarchicalRatioDINO`: Predicts Total → GDM/Total → Green/GDM hierarchically
- `DirectDINO`: Predicts Total, Green, GDM directly; derives Dead, Clover
- **DINOv2** (518×518) and **DINOv3** (256/512/992) backbones
- Multi-fold ensemble
- Test-Time Augmentation (TTA)
- MPS/CUDA/CPU support

**Key advantage**: Green + Dead + Clover always sums to Total, and GDM always equals Green + Clover, by construction of every model head (exact up to floating-point rounding)

In [None]:
import os
import gc
import json
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Tuple, Optional, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm


class CFG:
    """Configuration for Ratio model inference.

    Data/model paths switch on ``LOCAL_TEST``. Most architecture fields below are
    defaults that ``main`` overwrites from the training run's ``results.json``.
    """

    # ==================== DATA / MODEL PATHS ====================
    # True: evaluate locally on train.csv (OOF check); False: Kaggle test set.
    LOCAL_TEST = True

    if not LOCAL_TEST:
        BASE_PATH = "/kaggle/input/csiro-biomass"
        TEST_CSV = os.path.join(BASE_PATH, "test.csv")
        TEST_IMAGE_DIR = os.path.join(BASE_PATH, "test")
        MODEL_DIR = "/kaggle/input/your-ratio-model"  # Update for Kaggle
    else:
        BASE_PATH = "../data"
        # train.csv stands in for test.csv so predictions can be checked locally (OOF)
        TEST_CSV = os.path.join(BASE_PATH, "train.csv")
        TEST_IMAGE_DIR = os.path.join(BASE_PATH, "train")
        MODEL_DIR = "../outputs/ratio_20251216_215550"  # trained ratio model directory

    # Backbone name (auto-loaded from results.json when present)
    BACKBONE = "vit_base_patch14_reg4_dinov2.lvd142m"

    # ==================== INFERENCE SETTINGS ====================
    SUBMISSION_FILE = "submission.csv"
    # Provisional CUDA-or-CPU choice; reassigned from get_device() right after this class.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    BATCH_SIZE = 1
    NUM_WORKERS = 0

    # ==================== MODEL ARCHITECTURE ====================
    # Defaults only; results.json values take precedence when available.
    MODEL_TYPE = "hierarchical"  # one of "softmax", "hierarchical", "direct"
    DROPOUT = 0.2
    HIDDEN_RATIO = 0.5
    GRID = (2, 2)
    USE_FILM = True
    USE_ATTENTION_POOL = True
    RATIO_TEMPERATURE = 1.0
    IMG_SIZE = 518  # 518 for DINOv2; 256/512/992 for DINOv3

    # ==================== TTA ====================
    USE_TTA = True

    ALL_TARGET_COLS = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]


# Backbone presets for common configurations.
# main() consults this table only when results.json does not record "img_size";
# it then takes "img_size" from the entry matching CFG.BACKBONE. The "grid"
# value is not read by main().
BACKBONE_PRESETS = {
    "vit_base_patch14_reg4_dinov2.lvd142m": {"img_size": 518, "grid": (2, 2)},
    "vit_base_patch16_dinov3": {"img_size": 512, "grid": (2, 2)},
    "dinov3_base": {"img_size": 512, "grid": (2, 2)},
}


def get_device() -> torch.device:
    """Pick the best available backend: CUDA first, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        # The hasattr guard keeps this working on torch builds without an MPS backend module.
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)


# Update device: replace the CUDA-or-CPU default set inside CFG with
# get_device()'s choice, which additionally considers Apple MPS.
CFG.DEVICE = get_device()

print(f"Device: {CFG.DEVICE}")


def load_config_from_results(model_dir: str) -> Dict:
    """Load the training config stored in ``<model_dir>/results.json``.

    Args:
        model_dir: Directory of a training run.

    Returns:
        The ``"config"`` entry of results.json, or ``{}`` when the file does not
        exist or has no ``"config"`` key.
    """
    results_path = os.path.join(model_dir, "results.json")
    try:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(results_path, encoding="utf-8") as f:
            results = json.load(f)
    except FileNotFoundError:
        return {}
    return results.get("config", {})


# ==================== DATASET ====================

class TestBiomassDataset(Dataset):
    """Dataset for test/inference that yields the left and right halves of each photo.

    Each row's ``sample_id_prefix`` names the file ``<image_dir>/<prefix>.jpg``.
    The image is read with OpenCV, converted BGR -> RGB, split at the vertical
    midline (integer division, so the right half gets any odd extra column), and
    ``transform`` (when truthy) is applied to each half independently, left first.
    """

    def __init__(self, df: pd.DataFrame, image_dir: str, transform: A.Compose):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self) -> int:
        return self.df.shape[0]

    def _read_rgb(self, sample_id):
        """Read ``<image_dir>/<sample_id>.jpg`` as an RGB array; raise if unreadable."""
        path = os.path.join(self.image_dir, f"{sample_id}.jpg")
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Image not found: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        sample_id = self.df.iloc[idx]["sample_id_prefix"]
        rgb = self._read_rgb(sample_id)

        split_col = rgb.shape[1] // 2
        halves = [rgb[:, :split_col, :], rgb[:, split_col:, :]]
        if self.transform:
            halves = [self.transform(image=half)["image"] for half in halves]

        left, right = halves
        return left, right, sample_id


def get_val_transform(img_size: Optional[int] = None) -> A.Compose:
    """Build the deterministic evaluation pipeline.

    Resizes to ``img_size`` x ``img_size`` (``CFG.IMG_SIZE`` when omitted), applies
    ImageNet mean/std normalisation, and converts to a torch tensor via ToTensorV2.
    """
    side = CFG.IMG_SIZE if img_size is None else img_size
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)


def get_tta_transforms(img_size: Optional[int] = None) -> List[A.Compose]:
    """TTA transforms: original + hflip + fixed brightness/contrast shift.

    Every view resizes to ``img_size`` x ``img_size`` (``CFG.IMG_SIZE`` when
    omitted) and uses the same ImageNet normalisation as the eval transform.

    The photometric view uses degenerate ``(0.1, 0.1)`` limits, so the brightness
    and contrast factors are fixed instead of sampled from ``[-0.1, 0.1]``. A
    sampled shift made repeated runs produce different submissions. It also gave
    the left and right halves of the same photo different jitter, because the
    dataset applies the transform to each half separately.

    Returns:
        Three ``A.Compose`` pipelines: identity, horizontal flip, photometric shift.
    """
    if img_size is None:
        img_size = CFG.IMG_SIZE

    def _view(*extra):
        # Fresh transform instances per view; `extra` runs between resize and normalise.
        return A.Compose([
            A.Resize(img_size, img_size),
            *extra,
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

    base = _view()
    hflip = _view(A.HorizontalFlip(p=1.0))
    bright = _view(
        A.RandomBrightnessContrast(
            brightness_limit=(0.1, 0.1), contrast_limit=(0.1, 0.1), p=1.0
        )
    )
    return [base, hflip, bright]


# ==================== MODEL COMPONENTS ====================

def _build_dino_by_name(backbone_name: str, pretrained: bool = False):
    """Create a headless timm backbone.

    Returns:
        ``(model, embed_dim, input_res)``. ``input_res`` is the first entry of
        ``model.patch_embed.img_size`` when the model has a ``patch_embed``,
        otherwise 518.
    """
    backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0)
    embed_dim = backbone.embed_dim
    if not hasattr(backbone, "patch_embed"):
        return backbone, embed_dim, 518

    native_res = backbone.patch_embed.img_size
    # img_size may be stored as an (H, W) pair; callers resize tiles to a square, so one side suffices.
    if isinstance(native_res, (tuple, list)):
        native_res = native_res[0]
    return backbone, embed_dim, native_res


def _make_edges(length: int, n_parts: int) -> List[Tuple[int, int]]:
    step = length // n_parts
    edges = [(i * step, (i + 1) * step) for i in range(n_parts)]
    edges[-1] = (edges[-1][0], length)
    return edges


class FiLM(nn.Module):
    """Feature-wise linear modulation: map a context vector to per-channel ``(gamma, beta)``.

    The attribute name ``mlp`` and its layer order define checkpoint keys
    (``mlp.0.*``, ``mlp.2.*``), so they must not change.
    """

    def __init__(self, in_dim: int) -> None:
        super().__init__()
        width = max(64, in_dim // 2)
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, width),
            nn.GELU(),
            nn.Linear(width, 2 * in_dim),
        )

    def forward(self, context: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(gamma, beta)``, the two halves of the MLP output along dim 1."""
        return self.mlp(context).chunk(2, dim=1)


class AttentionPooling(nn.Module):
    """Pool a set of token vectors into one vector with single-query attention.

    The query is a projection of the mean token, the keys are projections of
    every token, and the un-projected tokens themselves serve as values.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Presumably x is (batch, tokens, dim) here (tile features per side) — confirm for other callers.
        pooled_query = self.query(x.mean(dim=1, keepdim=True))
        keys = self.key(x)
        scores = torch.matmul(pooled_query, keys.transpose(-2, -1)) * self.scale
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, x).squeeze(1)


# ==================== RATIO MODELS ====================

class SoftmaxRatioDINO(nn.Module):
    """Softmax Ratio Model: Predict Total + component ratios.

    Each half-image (left / right) is cut into ``grid`` tiles. Every tile is
    resized to the backbone's native resolution, and all tiles are encoded in one
    backbone call. The per-side tile features are then optionally cross-modulated
    with FiLM, pooled (attention or mean), concatenated and projected.

    ``head_total`` gives Total (softplus, so positive). ``head_ratios`` gives
    three logits whose temperature-scaled softmax splits Total into Green / Dead /
    Clover, so the three components sum to Total by construction, and
    GDM = Green + Clover.

    NOTE: attribute names and the layer order inside each ``nn.Sequential``
    define the state-dict keys that ``load_fold_model`` loads with
    ``strict=True``; they must match the training checkpoints.
    """
    
    def __init__(
        self,
        backbone_name: str = "vit_base_patch14_reg4_dinov2.lvd142m",
        grid: Tuple[int, int] = (2, 2),
        pretrained: bool = False,
        dropout: float = 0.2,
        hidden_ratio: float = 0.5,
        use_film: bool = True,
        use_attention_pool: bool = True,
        ratio_temperature: float = 1.0,
    ) -> None:
        super().__init__()
        
        # Headless backbone; input_res is the square side every tile is resized to.
        self.backbone, feat_dim, input_res = _build_dino_by_name(backbone_name, pretrained)
        self.input_res = int(input_res)
        self.feat_dim = feat_dim
        self.grid = tuple(grid)
        self.use_film = use_film
        self.use_attention_pool = use_attention_pool
        self.ratio_temperature = ratio_temperature
        
        # Separate FiLM / pooling modules for the left and right halves.
        if use_film:
            self.film_left = FiLM(feat_dim)
            self.film_right = FiLM(feat_dim)
        
        if use_attention_pool:
            self.attn_pool_left = AttentionPooling(feat_dim)
            self.attn_pool_right = AttentionPooling(feat_dim)
        
        # Left and right pooled features are concatenated before the shared projection.
        self.combined_dim = feat_dim * 2
        hidden_dim = max(64, int(self.combined_dim * hidden_ratio))
        
        self.shared_proj = nn.Sequential(
            nn.LayerNorm(self.combined_dim),
            nn.Linear(self.combined_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        
        # Two-layer MLP head (Linear-GELU-Dropout-Linear) using half the trunk dropout.
        def _make_head(in_dim: int, out_dim: int = 1) -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(in_dim, in_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(in_dim, out_dim),
            )
        
        self.head_total = _make_head(hidden_dim, 1)
        # Three logits, ordered (green, dead, clover) in forward().
        self.head_ratios = _make_head(hidden_dim, 3)
        self.softplus = nn.Softplus(beta=1.0)
    
    def _collect_tiles(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Split ``x`` into ``self.grid`` tiles, each resized to ``input_res`` x ``input_res``.

        ``x`` must be rank 4 (the shape is unpacked as ``_, C, H, W``), presumably
        (B, C, H, W). Tiles are returned row by row; the last row/column absorbs
        any remainder pixels (see ``_make_edges``).
        """
        _, C, H, W = x.shape
        r, c = self.grid
        rows = _make_edges(H, r)
        cols = _make_edges(W, c)
        tiles = []
        for (y0, y1) in rows:
            for (x0, x1) in cols:
                tile = x[:, :, y0:y1, x0:x1]
                tile = F.interpolate(tile, size=(self.input_res, self.input_res), mode="bilinear", align_corners=False)
                tiles.append(tile)
        return tiles
    
    def _extract_tiles_fused(self, x_left: torch.Tensor, x_right: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode all left and right tiles with a single backbone call.

        Returns ``(feats_left, feats_right)``, each shaped ``(B, num_tiles, D)``.
        """
        B = x_left.size(0)
        tiles_left = self._collect_tiles(x_left)
        tiles_right = self._collect_tiles(x_right)
        num_tiles = len(tiles_left)
        
        # Concatenating along dim 0 stacks tile-major: all B samples of tile 0, then tile 1, ...
        all_tiles = torch.cat(tiles_left + tiles_right, dim=0)
        # Assumes the headless backbone returns one (N, D) feature row per tile — TODO confirm for new backbones.
        all_feats = self.backbone(all_tiles)
        
        total_tiles = 2 * num_tiles
        # Undo the tile-major stacking: (tile, batch, D) -> (batch, tile, D).
        all_feats = all_feats.view(total_tiles, B, -1).permute(1, 0, 2)
        # The first num_tiles tiles came from the left half, the rest from the right.
        feats_left = all_feats[:, :num_tiles, :]
        feats_right = all_feats[:, num_tiles:, :]
        return feats_left, feats_right
    
    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Return ``(green, dead, clover, gdm, total)``, each shaped ``(B, 1)``."""
        tiles_left, tiles_right = self._extract_tiles_fused(x_left, x_right)
        
        # Per-side context = mean tile feature.
        ctx_left = tiles_left.mean(dim=1)
        ctx_right = tiles_right.mean(dim=1)
        
        if self.use_film:
            # Cross-conditioning: each side is modulated by the OTHER side's context.
            gamma_l, beta_l = self.film_left(ctx_right)
            gamma_r, beta_r = self.film_right(ctx_left)
            # (1 + gamma) scaling leaves features unchanged when gamma and beta are 0.
            tiles_left = tiles_left * (1 + gamma_l.unsqueeze(1)) + beta_l.unsqueeze(1)
            tiles_right = tiles_right * (1 + gamma_r.unsqueeze(1)) + beta_r.unsqueeze(1)
        
        if self.use_attention_pool:
            f_l = self.attn_pool_left(tiles_left)
            f_r = self.attn_pool_right(tiles_right)
        else:
            f_l = tiles_left.mean(dim=1)
            f_r = tiles_right.mean(dim=1)
        
        f = torch.cat([f_l, f_r], dim=1)
        f = self.shared_proj(f)
        
        total = self.softplus(self.head_total(f))
        logits = self.head_ratios(f)
        # Temperature < 1 sharpens and > 1 flattens the component split.
        ratios = F.softmax(logits / self.ratio_temperature, dim=1)
        green_ratio, dead_ratio, clover_ratio = ratios[:, 0:1], ratios[:, 1:2], ratios[:, 2:3]
        
        # Ratios sum to 1, so green + dead + clover == total (up to rounding).
        green = total * green_ratio
        dead = total * dead_ratio
        clover = total * clover_ratio
        gdm = green + clover
        
        return green, dead, clover, gdm, total


class HierarchicalRatioDINO(nn.Module):
    """Hierarchical Ratio Model: Total → GDM/Total → Green/GDM.

    Uses the same tiling / FiLM / pooling / projection trunk as the other ratio
    models. Three scalar heads predict:
    - Total (softplus, positive);
    - the fraction of Total that is GDM (sigmoid);
    - the fraction of GDM that is Green (sigmoid).

    Dead = Total - GDM and Clover = GDM - Green follow, so
    Green + Dead + Clover = Total and Green + Clover = GDM by construction.

    NOTE: attribute names and the layer order inside each ``nn.Sequential``
    define the state-dict keys that ``load_fold_model`` loads with
    ``strict=True``; they must match the training checkpoints.
    """
    
    def __init__(
        self,
        backbone_name: str = "vit_base_patch14_reg4_dinov2.lvd142m",
        grid: Tuple[int, int] = (2, 2),
        pretrained: bool = False,
        dropout: float = 0.2,
        hidden_ratio: float = 0.5,
        use_film: bool = True,
        use_attention_pool: bool = True,
    ) -> None:
        super().__init__()
        
        # Headless backbone; input_res is the square side every tile is resized to.
        self.backbone, feat_dim, input_res = _build_dino_by_name(backbone_name, pretrained)
        self.input_res = int(input_res)
        self.feat_dim = feat_dim
        self.grid = tuple(grid)
        self.use_film = use_film
        self.use_attention_pool = use_attention_pool
        
        # Separate FiLM / pooling modules for the left and right halves.
        if use_film:
            self.film_left = FiLM(feat_dim)
            self.film_right = FiLM(feat_dim)
        
        if use_attention_pool:
            self.attn_pool_left = AttentionPooling(feat_dim)
            self.attn_pool_right = AttentionPooling(feat_dim)
        
        # Left and right pooled features are concatenated before the shared projection.
        self.combined_dim = feat_dim * 2
        hidden_dim = max(64, int(self.combined_dim * hidden_ratio))
        
        self.shared_proj = nn.Sequential(
            nn.LayerNorm(self.combined_dim),
            nn.Linear(self.combined_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        
        # Two-layer MLP head (Linear-GELU-Dropout-Linear) using half the trunk dropout.
        def _make_head(in_dim: int, out_dim: int = 1) -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(in_dim, in_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(in_dim, out_dim),
            )
        
        self.head_total = _make_head(hidden_dim, 1)
        # Logit of GDM / Total.
        self.head_alive_ratio = _make_head(hidden_dim, 1)
        # Logit of Green / GDM.
        self.head_green_ratio = _make_head(hidden_dim, 1)
        self.softplus = nn.Softplus(beta=1.0)
    
    def _collect_tiles(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Split ``x`` into ``self.grid`` tiles, each resized to ``input_res`` x ``input_res``.

        ``x`` must be rank 4 (the shape is unpacked as ``_, C, H, W``), presumably
        (B, C, H, W). Tiles are returned row by row; the last row/column absorbs
        any remainder pixels (see ``_make_edges``).
        """
        _, C, H, W = x.shape
        r, c = self.grid
        rows = _make_edges(H, r)
        cols = _make_edges(W, c)
        tiles = []
        for (y0, y1) in rows:
            for (x0, x1) in cols:
                tile = x[:, :, y0:y1, x0:x1]
                tile = F.interpolate(tile, size=(self.input_res, self.input_res), mode="bilinear", align_corners=False)
                tiles.append(tile)
        return tiles
    
    def _extract_tiles_fused(self, x_left: torch.Tensor, x_right: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode all left and right tiles with a single backbone call.

        Returns ``(feats_left, feats_right)``, each shaped ``(B, num_tiles, D)``.
        """
        B = x_left.size(0)
        tiles_left = self._collect_tiles(x_left)
        tiles_right = self._collect_tiles(x_right)
        num_tiles = len(tiles_left)
        
        # Concatenating along dim 0 stacks tile-major: all B samples of tile 0, then tile 1, ...
        all_tiles = torch.cat(tiles_left + tiles_right, dim=0)
        # Assumes the headless backbone returns one (N, D) feature row per tile — TODO confirm for new backbones.
        all_feats = self.backbone(all_tiles)
        
        total_tiles = 2 * num_tiles
        # Undo the tile-major stacking: (tile, batch, D) -> (batch, tile, D).
        all_feats = all_feats.view(total_tiles, B, -1).permute(1, 0, 2)
        # The first num_tiles tiles came from the left half, the rest from the right.
        feats_left = all_feats[:, :num_tiles, :]
        feats_right = all_feats[:, num_tiles:, :]
        return feats_left, feats_right
    
    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Return ``(green, dead, clover, gdm, total)``, each shaped ``(B, 1)``."""
        tiles_left, tiles_right = self._extract_tiles_fused(x_left, x_right)
        
        # Per-side context = mean tile feature.
        ctx_left = tiles_left.mean(dim=1)
        ctx_right = tiles_right.mean(dim=1)
        
        if self.use_film:
            # Cross-conditioning: each side is modulated by the OTHER side's context.
            gamma_l, beta_l = self.film_left(ctx_right)
            gamma_r, beta_r = self.film_right(ctx_left)
            # (1 + gamma) scaling leaves features unchanged when gamma and beta are 0.
            tiles_left = tiles_left * (1 + gamma_l.unsqueeze(1)) + beta_l.unsqueeze(1)
            tiles_right = tiles_right * (1 + gamma_r.unsqueeze(1)) + beta_r.unsqueeze(1)
        
        if self.use_attention_pool:
            f_l = self.attn_pool_left(tiles_left)
            f_r = self.attn_pool_right(tiles_right)
        else:
            f_l = tiles_left.mean(dim=1)
            f_r = tiles_right.mean(dim=1)
        
        f = torch.cat([f_l, f_r], dim=1)
        f = self.shared_proj(f)
        
        total = self.softplus(self.head_total(f))
        alive_ratio = torch.sigmoid(self.head_alive_ratio(f))
        green_ratio = torch.sigmoid(self.head_green_ratio(f))
        
        # Both ratios lie in [0, 1], so each subtraction below is non-negative.
        gdm = total * alive_ratio
        dead = total - gdm
        green = gdm * green_ratio
        clover = gdm - green
        
        # Defensive clamp; by construction these should already be >= 0.
        dead = F.relu(dead)
        clover = F.relu(clover)
        
        return green, dead, clover, gdm, total


class DirectDINO(nn.Module):
    """Direct Model: Predict Total, Green, GDM directly; derive Dead, Clover.

    Uses the same tiling / FiLM / pooling / projection trunk as the ratio
    models. Three softplus heads give raw Total, Green and GDM. These are clipped
    so that Total >= GDM >= Green, and then Dead = Total - GDM and
    Clover = GDM - Green. This makes Green + Dead + Clover = Total and
    Green + Clover = GDM by construction.

    NOTE: attribute names and the layer order inside each ``nn.Sequential``
    define the state-dict keys that ``load_fold_model`` loads with
    ``strict=True``; they must match the training checkpoints.
    """
    
    def __init__(
        self,
        backbone_name: str = "vit_base_patch14_reg4_dinov2.lvd142m",
        grid: Tuple[int, int] = (2, 2),
        pretrained: bool = False,
        dropout: float = 0.2,
        hidden_ratio: float = 0.5,
        use_film: bool = True,
        use_attention_pool: bool = True,
    ) -> None:
        super().__init__()
        
        # Headless backbone; input_res is the square side every tile is resized to.
        self.backbone, feat_dim, input_res = _build_dino_by_name(backbone_name, pretrained)
        self.input_res = int(input_res)
        self.feat_dim = feat_dim
        self.grid = tuple(grid)
        self.use_film = use_film
        self.use_attention_pool = use_attention_pool
        
        # Separate FiLM / pooling modules for the left and right halves.
        if use_film:
            self.film_left = FiLM(feat_dim)
            self.film_right = FiLM(feat_dim)
        
        if use_attention_pool:
            self.attn_pool_left = AttentionPooling(feat_dim)
            self.attn_pool_right = AttentionPooling(feat_dim)
        
        # Left and right pooled features are concatenated before the shared projection.
        self.combined_dim = feat_dim * 2
        hidden_dim = max(64, int(self.combined_dim * hidden_ratio))
        
        self.shared_proj = nn.Sequential(
            nn.LayerNorm(self.combined_dim),
            nn.Linear(self.combined_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        
        # Two-layer MLP head (Linear-GELU-Dropout-Linear) using half the trunk dropout.
        def _make_head(in_dim: int, out_dim: int = 1) -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(in_dim, in_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(in_dim, out_dim),
            )
        
        self.head_total = _make_head(hidden_dim, 1)
        self.head_green = _make_head(hidden_dim, 1)
        self.head_gdm = _make_head(hidden_dim, 1)
        self.softplus = nn.Softplus(beta=1.0)
    
    def _collect_tiles(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Split ``x`` into ``self.grid`` tiles, each resized to ``input_res`` x ``input_res``.

        ``x`` must be rank 4 (the shape is unpacked as ``_, C, H, W``), presumably
        (B, C, H, W). Tiles are returned row by row; the last row/column absorbs
        any remainder pixels (see ``_make_edges``).
        """
        _, C, H, W = x.shape
        r, c = self.grid
        rows = _make_edges(H, r)
        cols = _make_edges(W, c)
        tiles = []
        for (y0, y1) in rows:
            for (x0, x1) in cols:
                tile = x[:, :, y0:y1, x0:x1]
                tile = F.interpolate(tile, size=(self.input_res, self.input_res), mode="bilinear", align_corners=False)
                tiles.append(tile)
        return tiles
    
    def _extract_tiles_fused(self, x_left: torch.Tensor, x_right: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode all left and right tiles with a single backbone call.

        Returns ``(feats_left, feats_right)``, each shaped ``(B, num_tiles, D)``.
        """
        B = x_left.size(0)
        tiles_left = self._collect_tiles(x_left)
        tiles_right = self._collect_tiles(x_right)
        num_tiles = len(tiles_left)
        
        # Concatenating along dim 0 stacks tile-major: all B samples of tile 0, then tile 1, ...
        all_tiles = torch.cat(tiles_left + tiles_right, dim=0)
        # Assumes the headless backbone returns one (N, D) feature row per tile — TODO confirm for new backbones.
        all_feats = self.backbone(all_tiles)
        
        total_tiles = 2 * num_tiles
        # Undo the tile-major stacking: (tile, batch, D) -> (batch, tile, D).
        all_feats = all_feats.view(total_tiles, B, -1).permute(1, 0, 2)
        # The first num_tiles tiles came from the left half, the rest from the right.
        feats_left = all_feats[:, :num_tiles, :]
        feats_right = all_feats[:, num_tiles:, :]
        return feats_left, feats_right
    
    def forward(self, x_left: torch.Tensor, x_right: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Return ``(green, dead, clover, gdm, total)``, each shaped ``(B, 1)``."""
        tiles_left, tiles_right = self._extract_tiles_fused(x_left, x_right)
        
        # Per-side context = mean tile feature.
        ctx_left = tiles_left.mean(dim=1)
        ctx_right = tiles_right.mean(dim=1)
        
        if self.use_film:
            # Cross-conditioning: each side is modulated by the OTHER side's context.
            gamma_l, beta_l = self.film_left(ctx_right)
            gamma_r, beta_r = self.film_right(ctx_left)
            # (1 + gamma) scaling leaves features unchanged when gamma and beta are 0.
            tiles_left = tiles_left * (1 + gamma_l.unsqueeze(1)) + beta_l.unsqueeze(1)
            tiles_right = tiles_right * (1 + gamma_r.unsqueeze(1)) + beta_r.unsqueeze(1)
        
        if self.use_attention_pool:
            f_l = self.attn_pool_left(tiles_left)
            f_r = self.attn_pool_right(tiles_right)
        else:
            f_l = tiles_left.mean(dim=1)
            f_r = tiles_right.mean(dim=1)
        
        f = torch.cat([f_l, f_r], dim=1)
        f = self.shared_proj(f)
        
        # Predict Total, Green, GDM directly (all positive via softplus)
        total_raw = self.softplus(self.head_total(f))
        green_raw = self.softplus(self.head_green(f))
        gdm_raw = self.softplus(self.head_gdm(f))
        
        # Enforce constraints: Total >= GDM >= Green >= 0
        total = total_raw
        gdm = torch.minimum(gdm_raw, total)
        green = torch.minimum(green_raw, gdm)
        
        # Derive Dead and Clover
        dead = total - gdm
        clover = gdm - green
        
        # Ensure non-negative (numerical safety); the minimum() clips above already imply this.
        dead = F.relu(dead)
        clover = F.relu(clover)
        
        return green, dead, clover, gdm, total


# ==================== MODEL LOADING ====================

def _strip_module_prefix(sd: dict) -> dict:
    if not sd:
        return sd
    keys = list(sd.keys())
    if all(k.startswith("module.") for k in keys):
        return {k[len("module."):]: v for k, v in sd.items()}
    return sd


def _detect_model_type(sd_keys: set) -> str:
    if any(k.startswith("head_alive_ratio.") for k in sd_keys):
        return "hierarchical"
    elif any(k.startswith("head_ratios.") for k in sd_keys):
        return "softmax"
    elif any(k.startswith("head_green.") for k in sd_keys) and any(k.startswith("head_gdm.") for k in sd_keys):
        return "direct"
    else:
        raise ValueError("Unknown model type")


def _detect_model_config(sd_keys: set) -> dict:
    return {
        "use_film": any(k.startswith("film_left.") for k in sd_keys),
        "use_attention_pool": any(k.startswith("attn_pool_left.") for k in sd_keys),
    }


def load_fold_model(
    path: str,
    backbone_name: str,
    model_type: str = "hierarchical",
    grid: Tuple[int, int] = (2, 2),
    dropout: float = 0.2,
    hidden_ratio: float = 0.5,
    use_film: bool = True,
    use_attention_pool: bool = True,
    ratio_temperature: float = 1.0,
) -> nn.Module:
    """Load a ratio model checkpoint.

    Builds the model class matching the checkpoint, loads its weights strictly,
    moves it to ``CFG.DEVICE`` and switches it to eval mode. The model type and
    the FiLM / attention-pool switches are detected from the checkpoint keys and
    override the corresponding arguments. ``ratio_temperature`` is only used by
    the softmax variant.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the checkpoint's heads match no known model type.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint not found: {path}")

    try:
        raw_state = torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        # Presumably for older torch releases that reject the weights_only keyword.
        raw_state = torch.load(path, map_location="cpu")

    state = _strip_module_prefix(raw_state)
    state_keys = set(state.keys())

    detected_type = _detect_model_type(state_keys)
    if detected_type != model_type:
        print(f"  Auto-detected model type: {detected_type}")
        model_type = detected_type

    flags = _detect_model_config(state_keys)
    shared_kwargs = dict(
        backbone_name=backbone_name,
        grid=grid,
        pretrained=False,
        dropout=dropout,
        hidden_ratio=hidden_ratio,
        use_film=flags.get("use_film", use_film),
        use_attention_pool=flags.get("use_attention_pool", use_attention_pool),
    )
    if model_type == "softmax":
        model = SoftmaxRatioDINO(**shared_kwargs, ratio_temperature=ratio_temperature)
    elif model_type == "direct":
        model = DirectDINO(**shared_kwargs)
    else:
        model = HierarchicalRatioDINO(**shared_kwargs)

    model.load_state_dict(state, strict=True)
    model = model.to(CFG.DEVICE)
    model.eval()
    return model


def find_model_checkpoints(model_dir: str) -> List[str]:
    """Return full paths of ``ratio_best_fold*.pth`` files in ``model_dir``, sorted as strings."""
    return sorted(
        os.path.join(model_dir, entry)
        for entry in os.listdir(model_dir)
        if entry.startswith("ratio_best_fold") and entry.endswith(".pth")
    )


# ==================== INFERENCE ====================

@torch.no_grad()
def predict_one_view(models: List[nn.Module], loader: DataLoader) -> np.ndarray:
    """Run inference with fold ensemble.

    Averages the models' outputs for every batch of one data view.

    Returns:
        Array of shape ``(num_samples, 5)`` with columns Green, Dead, Clover,
        GDM, Total, in loader order.
    """
    def _fold_output(model, left, right):
        # Models return (green, dead, clover, gdm, total); lay them side by side.
        green, dead, clover, gdm, total = model(left, right)
        stacked = torch.cat([green, dead, clover, gdm, total], dim=1)
        return stacked.float().cpu().numpy()

    batch_means = []
    for left, right, _ids in tqdm(loader, desc="Predicting", leave=False):
        left = left.to(CFG.DEVICE)
        right = right.to(CFG.DEVICE)
        per_fold = [_fold_output(model, left, right) for model in models]
        batch_means.append(np.mean(per_fold, axis=0))

    return np.concatenate(batch_means, axis=0)


def run_inference_tta(models: List[nn.Module], df: pd.DataFrame) -> np.ndarray:
    """Run inference with TTA.

    Predicts every TTA view (or only the eval view when ``CFG.USE_TTA`` is off)
    with the fold ensemble and returns the element-wise mean over views.
    """
    if CFG.USE_TTA:
        views = get_tta_transforms()
    else:
        views = [get_val_transform()]

    view_preds = []
    for view_no, view in enumerate(views, start=1):
        print(f"TTA view {view_no}/{len(views)}...")
        dataset = TestBiomassDataset(df, CFG.TEST_IMAGE_DIR, view)
        loader = DataLoader(
            dataset,
            batch_size=CFG.BATCH_SIZE,
            shuffle=False,
            num_workers=CFG.NUM_WORKERS,
        )
        view_preds.append(predict_one_view(models, loader))

    return np.mean(view_preds, axis=0)


def create_submission(preds: np.ndarray, df: pd.DataFrame) -> pd.DataFrame:
    """Create submission DataFrame.

    ``preds`` columns 0-4 are Green, Dead, Clover, GDM, Total. Each is clipped at
    0 and emitted in long format: one ``<prefix>__<target>`` row per sample and
    target, in the row order of ``df["sample_id_prefix"]``.
    """
    target_suffixes = ("Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g")
    # Explicit column list keeps the IndexError behaviour if fewer than 5 columns exist.
    clipped = np.maximum(0, preds[:, [0, 1, 2, 3, 4]])

    records = [
        {"sample_id": f"{prefix}__{suffix}", "value": clipped[row, col]}
        for row, prefix in enumerate(df["sample_id_prefix"])
        for col, suffix in enumerate(target_suffixes)
    ]
    return pd.DataFrame(records)


def _apply_results_config() -> None:
    """Copy architecture settings from the run's results.json into CFG, then resolve IMG_SIZE."""
    config = load_config_from_results(CFG.MODEL_DIR)
    if config:
        CFG.MODEL_TYPE = config.get("model_type", CFG.MODEL_TYPE)
        CFG.BACKBONE = config.get("backbone", CFG.BACKBONE)
        CFG.GRID = tuple(config.get("grid", list(CFG.GRID)))
        CFG.DROPOUT = config.get("dropout", CFG.DROPOUT)
        CFG.HIDDEN_RATIO = config.get("hidden_ratio", CFG.HIDDEN_RATIO)
        CFG.USE_FILM = config.get("use_film", CFG.USE_FILM)
        CFG.USE_ATTENTION_POOL = config.get("use_attention_pool", CFG.USE_ATTENTION_POOL)
        # Softmax models divide their ratio logits by this temperature, so it must
        # match training; previously it silently stayed at the 1.0 default.
        # NOTE(review): key name assumed to follow the other snake_case fields —
        # confirm against what train_ratio.py writes.
        CFG.RATIO_TEMPERATURE = config.get("ratio_temperature", CFG.RATIO_TEMPERATURE)
        CFG.IMG_SIZE = config.get("img_size", CFG.IMG_SIZE)
        print("Loaded config from results.json")

    # Check if img_size was in config
    if "img_size" not in config:
        print("⚠️  WARNING: img_size not in results.json - using default/preset")
        # Apply backbone preset
        if CFG.BACKBONE in BACKBONE_PRESETS:
            preset = BACKBONE_PRESETS[CFG.BACKBONE]
            CFG.IMG_SIZE = preset["img_size"]
            print(f"  Applied preset for {CFG.BACKBONE}: img_size={CFG.IMG_SIZE}")
        # Auto-detect DINOv3 backbones still sitting at the DINOv2 default
        if "dinov3" in CFG.BACKBONE.lower() and CFG.IMG_SIZE == 518:
            CFG.IMG_SIZE = 512  # Default for DINOv3
            print(f"  Auto-adjusted IMG_SIZE to {CFG.IMG_SIZE} for DINOv3")
        print("  ⚠️  If training used different img_size (e.g., 992), set CFG.IMG_SIZE manually!")


def _load_fold_models() -> List[nn.Module]:
    """Load every fold checkpoint in CFG.MODEL_DIR; raise ValueError if there are none."""
    checkpoints = find_model_checkpoints(CFG.MODEL_DIR)
    if not checkpoints:
        raise ValueError(f"No checkpoints found in {CFG.MODEL_DIR}")
    print(f"Found {len(checkpoints)} fold checkpoints")

    print("\nLoading models...")
    models = []
    for ckpt in checkpoints:
        model = load_fold_model(
            ckpt,
            backbone_name=CFG.BACKBONE,
            model_type=CFG.MODEL_TYPE,
            grid=CFG.GRID,
            dropout=CFG.DROPOUT,
            hidden_ratio=CFG.HIDDEN_RATIO,
            use_film=CFG.USE_FILM,
            use_attention_pool=CFG.USE_ATTENTION_POOL,
            ratio_temperature=CFG.RATIO_TEMPERATURE,
        )
        models.append(model)
        print(f"  Loaded: {os.path.basename(ckpt)}")
    return models


def _load_test_samples() -> pd.DataFrame:
    """Read CFG.TEST_CSV and return one row per unique ``sample_id_prefix``."""
    print("\nLoading test data...")
    test_long = pd.read_csv(CFG.TEST_CSV)

    # Derive sample_id_prefix when the CSV does not provide it directly
    if "sample_id_prefix" not in test_long.columns:
        if "image_path" in test_long.columns:
            # NOTE(review): assumes image file names embed an ID such as "ID1011485656".
            test_long["sample_id_prefix"] = test_long["image_path"].str.extract(r'([A-Z]+\d+)')[0]
        elif "sample_id" in test_long.columns:
            # Long-format ids look like "<prefix>__<target>".
            test_long["sample_id_prefix"] = test_long["sample_id"].str.split("__").str[0]
        else:
            raise ValueError("Cannot find sample_id or image_path column")

    test_unique = test_long.drop_duplicates(subset=["sample_id_prefix"]).reset_index(drop=True)
    print(f"Test samples: {len(test_unique)}")
    return test_unique


def _report_predictions(preds: np.ndarray) -> None:
    """Print the G+D+C=T residual and per-target summary statistics for ``preds``."""
    component_sum = preds[:, 0] + preds[:, 1] + preds[:, 2]
    total_pred = preds[:, 4]
    diff = np.abs(component_sum - total_pred)
    print("\nConstraint check (G+D+C=T):")
    print(f"  Max diff: {diff.max():.6f}")
    print(f"  Mean diff: {diff.mean():.6f}")

    print("\nPrediction stats:")
    for i, name in enumerate(["Green", "Dead", "Clover", "GDM", "Total"]):
        print(f"  {name}: mean={preds[:, i].mean():.2f}, std={preds[:, i].std():.2f}, "
              f"min={preds[:, i].min():.2f}, max={preds[:, i].max():.2f}")


def main():
    """Run ratio model inference end to end and write the submission CSV.

    Steps:
    1. Apply the training config from results.json (with backbone presets as the
       image-size fallback).
    2. Load every fold checkpoint in ``CFG.MODEL_DIR``.
    3. Predict the unique samples of ``CFG.TEST_CSV`` (with TTA if enabled).
    4. Print sanity statistics.
    5. Save ``CFG.SUBMISSION_FILE``.

    Returns:
        The long-format submission DataFrame (``sample_id``, ``value``).

    Raises:
        ValueError: if no fold checkpoints are found, or the test CSV has no
            column from which ``sample_id_prefix`` can be derived.
    """
    print("="*60)
    print("Ratio Model Inference")
    print("="*60)
    print(f"Device: {CFG.DEVICE}")
    print(f"Model dir: {CFG.MODEL_DIR}")
    print(f"TTA: {CFG.USE_TTA}")

    _apply_results_config()

    print(f"Model type: {CFG.MODEL_TYPE}")
    print(f"Backbone: {CFG.BACKBONE}")
    print(f"Grid: {CFG.GRID}")
    print(f"Image size: {CFG.IMG_SIZE}")
    print(f"Ratio temperature: {CFG.RATIO_TEMPERATURE}")

    models = _load_fold_models()
    test_unique = _load_test_samples()

    # Run inference
    print("\nRunning inference...")
    preds = run_inference_tta(models, test_unique)
    print(f"Predictions shape: {preds.shape}")

    _report_predictions(preds)

    # Create submission
    print("\nCreating submission...")
    sub_df = create_submission(preds, test_unique)
    sub_df.to_csv(CFG.SUBMISSION_FILE, index=False)
    print(f"Saved: {CFG.SUBMISSION_FILE}")
    print(sub_df.head(10))

    # Cleanup: drop the only model references before releasing cached GPU memory
    del models
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("\n" + "="*60)
    print("Done!")
    print("="*60)
    return sub_df


# Run the whole pipeline when this cell/module executes; the returned
# submission DataFrame is kept in `submission` for inspection.
submission = main()

  from .autonotebook import tqdm as notebook_tqdm


Device: mps
Ratio Model Inference
Device: mps
Model dir: ../outputs/ratio_20251216_215550
TTA: True
Loaded config from results.json
Model type: direct
Backbone: vit_base_patch16_dinov3
Grid: (2, 2)
Image size: 992
Found 5 fold checkpoints

Loading models...
  Loaded: ratio_best_fold0.pth
  Loaded: ratio_best_fold1.pth
  Loaded: ratio_best_fold2.pth
  Loaded: ratio_best_fold3.pth
  Loaded: ratio_best_fold4.pth

Loading test data...
Test samples: 357

Running inference...
TTA view 1/3...


                                                             

TTA view 2/3...


                                                             

TTA view 3/3...


                                                             

Predictions shape: (357, 5)

Constraint check (G+D+C=T):
  Max diff: 0.000015
  Mean diff: 0.000002

Prediction stats:
  Green: mean=28.24, std=23.72, min=0.00, max=109.63
  Dead: mean=11.20, std=6.77, min=1.14, max=31.68
  Clover: mean=7.15, std=11.51, min=0.00, max=60.15
  GDM: mean=35.39, std=23.28, min=1.84, max=112.52
  Total: mean=46.60, std=24.75, min=4.01, max=122.78

Creating submission...
Saved: submission.csv
                    sample_id      value
0   ID1011485656__Dry_Green_g  25.797304
1    ID1011485656__Dry_Dead_g  24.589590
2  ID1011485656__Dry_Clover_g   0.000000
3         ID1011485656__GDM_g  25.797304
4   ID1011485656__Dry_Total_g  50.386898
5   ID1012260530__Dry_Green_g   4.848897
6    ID1012260530__Dry_Dead_g   1.139058
7  ID1012260530__Dry_Clover_g   1.187443
8         ID1012260530__GDM_g   6.036340
9   ID1012260530__Dry_Total_g   7.175398

Done!


