In [1]:
# ============================================================
# CSIRO Image2Biomass - Robust Inference Notebook (Kaggle)
# - CPU-only / hidden rerun friendly
# - Generates submission.csv
# ============================================================

import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

import yaml
from tqdm.auto import tqdm

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True  # truncated image„ÅßËêΩ„Å°„Å´„Åè„Åè

# display„ÅåÁÑ°„ÅÑÁí∞Â¢É„Åß„ÇÇËêΩ„Å°„Å™„ÅÑ„Çà„ÅÜ„Å´„Åô„Çã
try:
    from IPython.display import display
except Exception:
    display = None


# --------------------------
# Utils
# --------------------------
def sep(title: str, n: int = 80):
    print("=" * n)
    print(title)
    print("=" * n)


def show_df(df: pd.DataFrame, n: int = 3, tail: bool = False):
    """display„ÅåÁÑ°„ÅÑÁí∞Â¢É„Åß„ÇÇËêΩ„Å°„Å™„ÅÑ DataFrame Ë°®Á§∫."""
    print("shape:", df.shape)
    if display is not None:
        display(df.head(n))
        if tail:
            display(df.tail(n))
    else:
        print(df.head(n).to_string(index=False))
        if tail:
            print(df.tail(n).to_string(index=False))


def seed_everything(seed: int = 1129) -> None:
    """CPU-only„Åß„ÇÇÂÆâÂÖ®„Å™seedÂõ∫ÂÆö„ÄÇ"""
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    print(f"üå± Seed fixed: {seed}")


def _load_yaml(path: Path) -> Dict:
    with open(path, "r") as f:
        return yaml.safe_load(f)


def glob_walk(root: Path, pattern: str) -> List[Path]:
    return sorted(list(Path(root).glob(pattern)))


def find_comp_dir(comp_root: Optional[str] = None) -> Path:
    """
    comp dataset root „ÇíËá™ÂãïÊ§úÂá∫„ÄÇ
    /kaggle/input ÈÖç‰∏ã„Åã„Çâ test.csv „Å® sample_submission.csv „ÇíÂê´„ÇÄdir„ÇíÊé¢„Åô„ÄÇ
    """
    if comp_root is not None:
        p = Path(comp_root)
        if not p.exists():
            raise FileNotFoundError(f"comp_root not found: {p}")
        return p

    base = Path("/kaggle/input")
    if not base.exists():
        raise FileNotFoundError("'/kaggle/input' not found. Are you running on Kaggle?")

    candidates = []
    for d in base.iterdir():
        if d.is_dir() and (d / "sample_submission.csv").exists():
            # test.csv „ÅåÁÑ°„ÅÑ„Ç≥„É≥„Éö„ÇÇ„ÅÇ„Çã„Åå„ÄÅÂü∫Êú¨„ÅØ„ÅÇ„ÇãÂâçÊèê
            if (d / "test.csv").exists() or (d / "train.csv").exists():
                candidates.append(d)

    if len(candidates) == 0:
        raise FileNotFoundError("Could not find competition dataset dir with sample_submission.csv")

    # csiro„Å£„ÅΩ„ÅÑÂêçÂâç„ÇíÂÑ™ÂÖà
    candidates_sorted = sorted(
        candidates,
        key=lambda x: (("csiro" not in x.name.lower()), x.name.lower())
    )
    chosen = candidates_sorted[0]
    print(f"[INFO] comp_dir auto-detected: {chosen}")
    return chosen


def find_artifact_root(artifact_root: Optional[str] = None) -> Path:
    """
    artifact_root „ÇíËá™ÂãïÊ§úÂá∫„ÄÇ
    /kaggle/input ÈÖç‰∏ã„Åã„Çâ yaml/config.yaml „Å® model/*.pth „ÇíÂê´„ÇÄdir„ÇíÊé¢„Åô„ÄÇ
    """
    if artifact_root is not None:
        p = Path(artifact_root)
        if not p.exists():
            raise FileNotFoundError(f"artifact_root not found: {p}")
        return p

    base = Path("/kaggle/input")
    candidates = []
    for d in base.iterdir():
        if not d.is_dir():
            continue
        cfg = d / "yaml" / "config.yaml"
        mdl = d / "model"
        if cfg.exists() and mdl.exists() and len(list(mdl.glob("*.pth"))) > 0:
            candidates.append(d)

    if len(candidates) == 0:
        raise FileNotFoundError("Could not auto-detect artifact_root (need yaml/config.yaml and model/*.pth).")

    chosen = sorted(candidates, key=lambda x: x.name.lower())[0]
    print(f"[INFO] artifact_root auto-detected: {chosen}")
    return chosen


def _pick_state_dict(ckpt: Dict) -> Dict:
    """checkpoint dict„Åã„Çâ state_dict „ÇíÊäΩÂá∫„Åó„ÄÅÂøÖË¶Å„Å™„Çâ 'module.' „ÇíÂâ•„Åå„Åô„ÄÇ"""
    for key in ["model_state_dict", "state_dict", "model"]:
        if key in ckpt and isinstance(ckpt[key], dict):
            sd = ckpt[key]
            break
    else:
        # Áõ¥„Å´ state_dict „Å£„ÅΩ„ÅÑÂ†¥Âêà
        if isinstance(ckpt, dict) and all(isinstance(k, str) for k in ckpt.keys()):
            sd = ckpt
        else:
            raise KeyError("state_dict not found in checkpoint")

    # DDPÁ≠â„Åß module. „Åå‰ªò„ÅÑ„Å¶„ÅÑ„ÇãÂ†¥Âêà„ÅØÂâ•„Åå„Åô
    keys = list(sd.keys())
    if len(keys) > 0 and all(k.startswith("module.") for k in keys):
        sd = {k[len("module."):]: v for k, v in sd.items()}
    return sd


# --------------------------
# Model
# --------------------------
class ConvNeXtRegressor(nn.Module):
    def __init__(
        self,
        backbone: str = "convnext_base",
        pretrained: bool = False,
        num_targets: int = 5,
        in_chans: int = 3,
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        head_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.backbone = timm.create_model(
            backbone,
            pretrained=pretrained,   # Kaggle hidden„Åß„ÅØFalseÊé®Â•®
            num_classes=0,
            global_pool="avg",
            in_chans=in_chans,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
        )
        feat_dim = self.backbone.num_features

        self.head_dropout = nn.Dropout(head_dropout) if head_dropout and head_dropout > 0 else nn.Identity()
        self.head = nn.Linear(feat_dim, num_targets)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feat = self.backbone(x)
        feat = self.head_dropout(feat)
        return self.head(feat)


# --------------------------
# Transform / Dataset
# --------------------------
def build_infer_transform(img_size: int, mean: List[float], std: List[float]) -> A.Compose:
    return A.Compose(
        [
            A.Resize(img_size, img_size),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ]
    )


class TestImageDataset(Dataset):
    def __init__(
        self,
        df_unique: pd.DataFrame,
        data_root: Path,
        transform: A.Compose,
        fallback_size: int = 256,
    ):
        self.df = df_unique.reset_index(drop=True)
        self.data_root = Path(data_root)
        self.transform = transform
        self.fallback_size = int(fallback_size)

        self.image_paths = self.df["image_path"].astype(str).values

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Dict:
        rel_path = self.image_paths[idx]
        img_path = self.data_root / rel_path

        try:
            with Image.open(img_path) as im:
                im = im.convert("RGB")
                img = np.array(im)
        except Exception:
            # hidden„ÅßÂ£ä„Çå/Ê¨†Êêç„ÅåÊ∑∑„Åò„Å£„Å¶„ÇÇËêΩ„Å°„Å™„ÅÑÔºà„Å®„Å´„Åã„ÅèÂÆåËµ∞ÂÑ™ÂÖàÔºâ
            img = np.zeros((self.fallback_size, self.fallback_size, 3), dtype=np.uint8)

        x = self.transform(image=img)["image"]
        return {"image": x, "image_path": rel_path}


# --------------------------
# Inference core
# --------------------------
@torch.inference_mode()
def predict_one_ckpt(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    use_amp: bool = True,
) -> Tuple[List[str], np.ndarray]:
    model.eval()

    paths_all: List[str] = []
    preds_all: List[np.ndarray] = []

    amp_enabled = bool(use_amp and device.type == "cuda")
    amp_device_type = device.type  # 'cuda' or 'cpu'

    for batch in tqdm(loader, desc="üîÆ infer", leave=False):
        x = batch["image"].to(device, non_blocking=(device.type == "cuda"))

        with autocast(device_type=amp_device_type, enabled=amp_enabled):
            pred = model(x)

        preds_all.append(pred.float().cpu().numpy())
        paths_all.extend(list(batch["image_path"]))

    preds = np.concatenate(preds_all, axis=0)
    return paths_all, preds


def ensemble_predict(
    cfg: Dict,
    ckpt_paths: List[Path],
    unique_test_df: pd.DataFrame,
    data_root: Path,
    device: torch.device,
    batch_size: int = 64,
    num_workers: int = 2,
    use_amp: bool = True,
    enforce_consistency: bool = True,  # GDM/Dry_Total„ÇíÂÜçË®àÁÆó„Åó„Å¶Êï¥ÂêàÊÄß„ÇíÂèñ„Çã
) -> pd.DataFrame:
    sep("üéØ Ensemble inference")

    # config
    img_size = int(cfg.get("img_size", 224))
    norm = cfg.get("normalize", {})
    mean = norm.get("mean", [0.485, 0.456, 0.406])
    std = norm.get("std", [0.229, 0.224, 0.225])

    # ‚úÖ target_cols „ÅØ config „ÇíÊ≠£„Å®„Åô„ÇãÔºàÈ†ÜÂ∫è„Ç∫„É¨Èò≤Ê≠¢Ôºâ
    target_cols = list(cfg.get("target_cols", []))
    if len(target_cols) == 0:
        # ÊúÄÁµÇfallback
        target_cols = ["Dry_Green_g", "Dry_Clover_g", "Dry_Dead_g", "GDM_g", "Dry_Total_g"]

    use_log1p_target = bool(cfg.get("use_log1p_target", True))

    print(f"[INFO] device={device}  amp={bool(use_amp and device.type=='cuda')}")
    print(f"[INFO] img_size={img_size}  targets={target_cols}  log1p={use_log1p_target}")
    print(f"[INFO] #ckpts={len(ckpt_paths)}")

    tfm = build_infer_transform(img_size, mean, std)
    ds = TestImageDataset(unique_test_df, data_root=data_root, transform=tfm)

    # hidden rerun„ÅØCPU„ÅÆÂèØËÉΩÊÄß„ÅåÈ´ò„ÅÑ ‚Üí num_workers=0„ÅåÂÆâÂÆö
    nw = int(num_workers) if device.type == "cuda" else 0
    loader = DataLoader(
        ds,
        batch_size=int(batch_size),
        shuffle=False,
        num_workers=nw,
        pin_memory=(device.type == "cuda"),
        drop_last=False,
    )

    pred_sum = None
    paths_ref = None

    for ckpt_path in ckpt_paths:
        # build model
        model_cfg = cfg.get("model", {})
        model = ConvNeXtRegressor(
            backbone=str(model_cfg.get("backbone", "convnext_base")),
            pretrained=False,
            num_targets=len(target_cols),
            in_chans=int(model_cfg.get("in_chans", 3)),
            drop_rate=float(model_cfg.get("drop_rate", 0.0)),
            drop_path_rate=float(model_cfg.get("drop_path_rate", 0.0)),
            head_dropout=float(model_cfg.get("head_dropout", 0.0)),
        ).to(device)

        # load checkpoint
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        state_dict = _pick_state_dict(ckpt)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if len(missing) > 0 or len(unexpected) > 0:
            print(f"[WARN] load_state_dict: missing={len(missing)} unexpected={len(unexpected)}")

        # predict
        paths, pred = predict_one_ckpt(model, loader, device=device, use_amp=use_amp)

        if paths_ref is None:
            paths_ref = paths
        else:
            # È†ÜÂ∫è„ÅÆ‰∏ÄËá¥„ÅåÂ¥©„Çå„Çã„Å®Ëá¥ÂëΩÁöÑ
            if paths_ref != paths:
                raise RuntimeError("Image order mismatch across checkpoints")

        pred_sum = pred if pred_sum is None else (pred_sum + pred)

        # cleanup
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()

    preds = pred_sum / float(len(ckpt_paths))

    # log1p -> raw
    if use_log1p_target:
        preds = np.expm1(np.clip(preds, -20.0, 20.0))

    # mass is non-negative
    preds = np.clip(preds, 0.0, None)

    pred_df = pd.DataFrame({"image_path": paths_ref})
    for j, col in enumerate(target_cols):
        pred_df[col] = preds[:, j].astype(np.float32)

    # optional: enforce consistency (GDM = Green + Clover, Total = GDM + Dead)
    if enforce_consistency:
        if {"Dry_Green_g", "Dry_Clover_g", "GDM_g"}.issubset(pred_df.columns):
            pred_df["GDM_g"] = pred_df["Dry_Green_g"] + pred_df["Dry_Clover_g"]
        if {"GDM_g", "Dry_Dead_g", "Dry_Total_g"}.issubset(pred_df.columns):
            pred_df["Dry_Total_g"] = pred_df["GDM_g"] + pred_df["Dry_Dead_g"]

    return pred_df


# --------------------------
# Submission
# --------------------------
def pick_ckpt_paths(artifact_root: Path, fold: str = "0") -> List[Path]:
    model_dir = artifact_root / "model"
    if not model_dir.exists():
        raise FileNotFoundError(f"model dir not found: {model_dir}")

    fold_u = str(fold).upper()
    if fold_u == "ALL":
        paths = sorted(model_dir.glob("best_fold*.pth"))
        if len(paths) == 0:
            paths = sorted(model_dir.glob("*.pth"))
        return paths

    # single fold
    p1 = model_dir / f"best_fold{fold}.pth"
    if p1.exists():
        return [p1]

    # fallback pattern
    paths = sorted(model_dir.glob(f"*fold{fold}*.pth"))
    if len(paths) > 0:
        return paths

    # final fallback
    paths = sorted(model_dir.glob("*.pth"))
    if len(paths) == 1:
        return paths

    raise FileNotFoundError(f"No checkpoint found for fold={fold} under {model_dir}")


def make_submission(
    artifact_root: Optional[str] = "/kaggle/input/csiro-artifacts",
    comp_root: Optional[str] = None,
    output_csv: str = "submission.csv",
    fold: str = "0",
    batch_size: int = 64,
    num_workers: int = 2,
    use_amp: bool = True,
    is_test: bool = True,
    device: Optional[str] = None,
    enforce_consistency: bool = True,
) -> pd.DataFrame:
    sep("üìù Make submission")

    # device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # CPU-onlyÊôÇ„ÅØAMP„ÇíÁÑ°ÂäπÂåñÔºàÂÆâÂÖ®Ôºâ
    use_amp = bool(use_amp and device.type == "cuda")

    # detect dirs
    comp_dir = find_comp_dir(comp_root)
    artifact_root_path = find_artifact_root(artifact_root)

    # choose files
    csv_name = "test.csv" if is_test else "train.csv"
    data_csv = comp_dir / csv_name
    sub_csv = comp_dir / "sample_submission.csv"
    if not data_csv.exists():
        raise FileNotFoundError(f"{csv_name} not found: {data_csv}")
    if not sub_csv.exists():
        raise FileNotFoundError(f"sample_submission.csv not found: {sub_csv}")

    # artifacts
    cfg_path = artifact_root_path / "yaml" / "config.yaml"
    if not cfg_path.exists():
        raise FileNotFoundError(f"config.yaml not found: {cfg_path}")
    cfg = _load_yaml(cfg_path)

    ckpt_paths = pick_ckpt_paths(artifact_root_path, fold=fold)
    if len(ckpt_paths) == 0:
        raise FileNotFoundError("No checkpoints found")

    print(f"[INFO] device={device} use_amp={use_amp}")
    print(f"[INFO] comp_dir={comp_dir}")
    print(f"[INFO] artifact_root={artifact_root_path}")
    print(f"[INFO] csv={data_csv.name}  sample_sub={sub_csv.name}")
    print(f"[INFO] ckpts={[p.name for p in ckpt_paths]}")

    # load data
    test_df = pd.read_csv(data_csv)
    sample_sub_df = pd.read_csv(sub_csv)

    # minimal validation
    required_cols = {"sample_id", "image_path", "target_name"}
    missing_cols = required_cols - set(test_df.columns)
    if len(missing_cols) > 0:
        raise KeyError(f"{csv_name} missing columns: {missing_cols}. got={test_df.columns.tolist()}")

    # unique images
    unique_test_df = test_df[["image_path"]].drop_duplicates().reset_index(drop=True)

    # inference
    pred_df = ensemble_predict(
        cfg=cfg,
        ckpt_paths=ckpt_paths,
        unique_test_df=unique_test_df,
        data_root=comp_dir,
        device=device,
        batch_size=batch_size,
        num_workers=num_workers,
        use_amp=use_amp,
        enforce_consistency=enforce_consistency,
    )

    # long merge: (image_path, target_name) -> target
    target_cols = list(cfg.get("target_cols", []))
    if len(target_cols) == 0:
        target_cols = ["Dry_Green_g", "Dry_Clover_g", "Dry_Dead_g", "GDM_g", "Dry_Total_g"]

    pred_long = pred_df.melt(
        id_vars=["image_path"],
        value_vars=[c for c in target_cols if c in pred_df.columns],
        var_name="target_name",
        value_name="target",
    )

    merged = test_df.merge(pred_long, on=["image_path", "target_name"], how="left")

    # submission df
    submission_df = merged[["sample_id", "target"]].copy()

    # ‚úÖ sample_submission „ÅÆÈ†Ü„Å´Êï¥ÂàóÔºàÈ†ÜÂ∫èÂâçÊèê„ÅÆÊé°ÁÇπ„ÇÇÊΩ∞„ÅôÔºâ
    submission_df = sample_sub_df[["sample_id"]].merge(
        submission_df, on="sample_id", how="left", sort=False
    )

    # ‚úÖ Ê¨†Êêç„ÇÑÈùûÊï∞„Åå„ÅÇ„Å£„Å¶„ÇÇÂÆåËµ∞ÔºàhiddenÂ∑ÆÁï∞ÂØæÁ≠ñÔºâ
    submission_df["target"] = pd.to_numeric(submission_df["target"], errors="coerce")
    n_nan = int(submission_df["target"].isna().sum())
    if n_nan > 0:
        print(f"[WARN] NaN targets after merge: {n_nan} -> fill 0.0")
    submission_df["target"] = submission_df["target"].fillna(0.0)
    submission_df["target"] = np.where(
        np.isfinite(submission_df["target"].values),
        submission_df["target"].values,
        0.0
    )

    # save
    submission_df.to_csv(output_csv, index=False)
    print(f"[OK] saved: {output_csv}  rows={len(submission_df)}")
    return submission_df


# ============================================================
# Run
# ============================================================
seed_everything(1129)

# „Åì„Åì„Å†„ÅëÂøÖË¶Å„Å´Âøú„Åò„Å¶Â§â„Åà„Å¶„Åè„Å†„Åï„ÅÑ
submission = make_submission(
    artifact_root="/kaggle/input/csiro-artifacts",  # Ëá™ÂãïÊ§úÂá∫„Åó„Åü„Åë„Çå„Å∞ None „Å´„Åó„Å¶OK
    comp_root=None,                                # Ëá™ÂãïÊ§úÂá∫
    output_csv="submission.csv",
    fold="0",                                      # "ALL" „ÅßÂÖ®fold„Ç¢„É≥„Çµ„É≥„Éñ„É´
    batch_size=64,
    num_workers=2,                                 # CPUÊôÇ„ÅØËá™Âãï„Åß0„Å´„Å™„Çä„Åæ„Åô
    use_amp=True,                                  # CPUÊôÇ„ÅØËá™Âãï„ÅßFalse„Å´„Å™„Çä„Åæ„Åô
    is_test=True,
    enforce_consistency=False,
)

# ÊúÄÁµÇ„ÉÅ„Çß„ÉÉ„ÇØÔºàËªΩ„ÇÅÔºâ
sep("üìã Sanity check")
print("rows:", len(submission))
print("cols:", submission.columns.tolist())
print("nan:", int(submission.isna().sum().sum()))
print("min/max:", float(submission["target"].min()), float(submission["target"].max()))
print(submission.head())

  data = fetch_version_info()


üå± Seed fixed: 1129
üìù Make submission
[INFO] comp_dir auto-detected: /kaggle/input/csiro-biomass
[INFO] device=cuda use_amp=True
[INFO] comp_dir=/kaggle/input/csiro-biomass
[INFO] artifact_root=/kaggle/input/csiro-artifacts
[INFO] csv=test.csv  sample_sub=sample_submission.csv
[INFO] ckpts=['best_fold0.pth']
üéØ Ensemble inference
[INFO] device=cuda  amp=True
[INFO] img_size=224  targets=['Dry_Green_g', 'Dry_Clover_g', 'Dry_Dead_g', 'GDM_g', 'Dry_Total_g']  log1p=True
[INFO] #ckpts=1
[WARN] load_state_dict: missing=342 unexpected=342


üîÆ infer:   0%|          | 0/1 [00:00<?, ?it/s]

[OK] saved: submission.csv  rows=5
üìã Sanity check
rows: 5
cols: ['sample_id', 'target']
nan: 0
min/max: 0.0 16.72542381286621
                    sample_id     target
0  ID1001187975__Dry_Clover_g   0.000000
1    ID1001187975__Dry_Dead_g   0.250154
2   ID1001187975__Dry_Green_g  16.725424
3   ID1001187975__Dry_Total_g   0.000000
4         ID1001187975__GDM_g   0.391756
