# Evaluate TTA
---
- TTAによる精度確認をする

- 使い方（重要）
    - これで出る weighted_r2_oof が「他の参加者が言うCV」に一番近いです.  
    （foldごとの値の平均じゃなく、全OOF結合の1発計算）
    - TTAはまず HFlipだけをおすすめ（VFlipが悪化するケースもあるのでOOFで確認）
    - 平均は今コードでは log空間平均です
raw空間平均も試したければ、pred_log_stack を expm1 して raw で平均 → log1p に戻す、に変える版も作れます（必要なら言ってください）

In [1]:
# =========================================================
# 0. Setup
# =========================================================
from __future__ import annotations

import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn

# --- path: src を import できるように ---
PROJECT_DIR = Path("/mnt/nfs/home/hidebu/study/CSIRO---Image2Biomass-Prediction")
SRC_DIR = PROJECT_DIR / "src"
assert SRC_DIR.exists(), SRC_DIR
import sys
sys.path.append(str(SRC_DIR))

# --- your modules ---
from omegaconf import OmegaConf
from datasets.dataset import CsiroDataset
from datasets.transforms import build_transforms
from utils.metric import global_weighted_r2_score, r2_per_target

# model (new)
from models.biomass_mil_hurdle import BiomassConvNeXtMILHurdle
# model (old) - 必要なら
from models.convnext_regressor import ConvNeXtRegressor

# original
import sys
sys.path.append("../src")
from utils.data import sep, show_df, glob_walk, set_seed, save_config_yaml, dict_to_namespace
from utils.metric import global_weighted_r2_score, r2_per_target

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

# 速度/再現性（必要に応じて）
torch.backends.cudnn.benchmark = True

DEVICE: cuda


In [19]:
# =========================================================
# 1. Config / Paths
# =========================================================
EXP = "101_train_exp017"
MODEL = "BiomassConvNeXtMILHurdle" # "BiomassConvNeXtMILHurdle"| "ConvNeXtRegressor"

# ★ここを変える（評価したい exp ディレクトリ）
EXP_DIR = PROJECT_DIR / "experiments" / EXP
CFG_PATH = EXP_DIR / "yaml" / "config.yaml"
assert CFG_PATH.exists(), CFG_PATH

cfg = OmegaConf.load(CFG_PATH)
print("Loaded cfg:", CFG_PATH)

# 学習で使った preprocess pivot を読む（fold列が入っている）
pp_dir = Path(cfg.pp_dir) / str(cfg.preprocess_ver)
pivot_path = pp_dir / str(cfg.pivot_csv_name)
assert pivot_path.exists(), pivot_path
df = pd.read_csv(pivot_path)
print("df:", df.shape)
df.head()

Loaded cfg: /mnt/nfs/home/hidebu/study/CSIRO---Image2Biomass-Prediction/experiments/101_train_exp017/yaml/config.yaml
df: (357, 36)


Unnamed: 0,image_id,image_path,Sampling_Date,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,Dry_Clover_g,Dry_Dead_g,Dry_Green_g,...,is_Lucerne,is_Mixed,is_Phalaris,is_Ryegrass,is_SilverGrass,is_SpearGrass,is_SubcloverDalkeith,is_SubcloverLosa,is_WhiteClover,Fold
0,ID1011485656,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,0.0,31.9984,16.2751,...,False,False,False,True,False,False,False,False,False,3
1,ID1012260530,train/ID1012260530.jpg,2015/4/1,NSW,Lucerne,0.55,16.0,0.0,0.0,7.6,...,True,False,False,False,False,False,False,False,False,2
2,ID1025234388,train/ID1025234388.jpg,2015/9/1,WA,SubcloverDalkeith,0.38,1.0,6.05,0.0,0.0,...,False,False,False,False,False,False,True,False,False,1
3,ID1028611175,train/ID1028611175.jpg,2015/5/18,Tas,Ryegrass,0.66,5.0,0.0,30.9703,24.2376,...,False,False,False,True,False,False,False,False,False,0
4,ID1035947949,train/ID1035947949.jpg,2015/9/11,Tas,Ryegrass,0.54,3.5,0.4343,23.2239,10.5261,...,False,False,False,True,False,False,False,False,False,3


In [20]:
# =========================================================
# 2. Helpers (model / tta / predict)
# =========================================================
def build_model_from_cfg(cfg) -> nn.Module:
    """cfg からモデルを構築する（新モデル: BiomassConvNeXtMILHurdle 想定）。

    Notes:
        - 旧モデルも評価したい場合はここで分岐を足してください。
        - ここでは config.yaml の `model.*` をそのまま使う方針です。
    """
    mcfg = cfg.model  # ← ここが model セクション

    if MODEL == "ConvNeXtRegressor":
        # ✅ target_cols は model 配下ではなく cfg 直下
        num_targets = len(cfg.target_cols)

        return ConvNeXtRegressor(
            backbone=str(getattr(mcfg, "backbone", "convnext_base")),
            pretrained=False,  # ✅ 評価はFalse（ckptロード前提、DL回避）
            num_targets=num_targets,
            in_chans=int(getattr(mcfg, "in_chans", 3)),
            drop_rate=float(getattr(mcfg, "drop_rate", 0.0)),
            drop_path_rate=float(getattr(mcfg, "drop_path_rate", 0.0)),
            head_dropout=float(getattr(mcfg, "head_dropout", 0.0)),
        )
    elif MODEL == "BiomassConvNeXtMILHurdle":
        return BiomassConvNeXtMILHurdle(
            backbone_name=str(mcfg.backbone),
            pretrained=False,  # ckptから読むのでFalse推奨
            in_chans=int(mcfg.in_chans),
            pool_dropout=float(getattr(mcfg, "pool_dropout", 0.0)),
            pool_temperature=float(getattr(mcfg, "pool_temperature", 1.0)),
            mil_mode=str(getattr(mcfg, "mil_mode", "gated_attn")),
            mil_attn_dim=int(getattr(mcfg, "mil_attn_dim", 256)),
            mil_dropout=float(getattr(mcfg, "mil_dropout", 0.0)),
            head_hidden_dim=int(getattr(mcfg, "head_hidden_dim", 512)),
            head_dropout=float(getattr(mcfg, "head_dropout", 0.2)),
            return_attention=bool(getattr(mcfg, "return_attention", False)),
        )
    else:
        raise ValueError(f"Unknown MODEL={MODEL}")

def load_ckpt(model: nn.Module, ckpt_path: Path, device: torch.device) -> nn.Module:
    """チェックポイントをロードして eval モードにする。"""
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # state_dict 抽出
    if isinstance(ckpt, dict):
        state = ckpt.get("model_state_dict") or ckpt.get("state_dict") or ckpt.get("model") or ckpt
    else:
        state = ckpt

    # DDP の module. を剥がす（必要な場合）
    if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()):
        state = {k.replace("module.", "", 1): v for k, v in state.items()}

    model.load_state_dict(state, strict=True)  # ← weights_only はここじゃない
    model.to(device)
    model.eval()
    return model


@torch.no_grad()
def _forward_to_pred_log1p(model_out) -> torch.Tensor:
    """モデル出力から log1p 予測テンソル (B,5) を取り出す。

    - 新モデル: dict で "pred_log1p" が入っている
    - 旧モデル: tensor そのものが log1p (B,5) の想定
    """
    if isinstance(model_out, dict):
        return model_out["pred_log1p"]
    return model_out


def _tta_apply(x: torch.Tensor, tta: str) -> List[torch.Tensor]:
    """入力テンソルにTTAを適用してバリエーションを返す。

    Args:
        x: (B,C,H,W) または (B,M,C,H,W)
        tta: "none" | "hflip" | "hflip_vflip"

    Returns:
        xs: xのリスト（複数ビュー）
    """
    if tta == "none":
        return [x]

    # hflip: 最後の次元(-1)反転
    def hflip(t: torch.Tensor) -> torch.Tensor:
        return torch.flip(t, dims=[-1])

    # vflip: 高さ方向(-2)反転
    def vflip(t: torch.Tensor) -> torch.Tensor:
        return torch.flip(t, dims=[-2])

    if tta == "hflip":
        return [x, hflip(x)]

    if tta == "hflip_vflip":
        return [x, hflip(x), vflip(x), vflip(hflip(x))]

    raise ValueError(f"Unknown tta={tta}")


@torch.no_grad()
def predict_loader(
    model: nn.Module,
    loader,
    device: torch.device,
    use_amp: bool = True,
    tta: str = "none",
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """loaderを回して (pred_log1p, target_log1p, ids) を返す。

    Notes:
        - predはTTA平均後の log1p
        - targetは dataset が返す (log1p target) をそのまま
    """
    preds_log_list: List[np.ndarray] = []
    targs_log_list: List[np.ndarray] = []
    ids_all: List[str] = []

    pbar = tqdm(loader, total=len(loader))
    for batch in pbar:
        x = batch["image"].to(device, non_blocking=True)     # (B,C,H,W) or (B,M,C,H,W)
        y = batch["target"].to(device, non_blocking=True)    # (B,5) log1p
        ids = batch["id"]

        if isinstance(ids, (list, tuple, np.ndarray)):
            ids_all.extend([str(v) for v in ids])
        else:
            ids_all.append(str(ids))

        xs = _tta_apply(x, tta=tta)

        with torch.amp.autocast(device_type="cuda", enabled=(use_amp and device.type == "cuda")):
            pred_log_stack = []
            for x_aug in xs:
                out = model(x_aug)
                pred_log = _forward_to_pred_log1p(out)  # (B,5)
                pred_log_stack.append(pred_log)

            # TTA平均（log空間平均）
            pred_log_mean = torch.stack(pred_log_stack, dim=0).mean(dim=0)

        preds_log_list.append(pred_log_mean.detach().float().cpu().numpy())
        targs_log_list.append(y.detach().float().cpu().numpy())

    preds_log = np.concatenate(preds_log_list, axis=0)
    targs_log = np.concatenate(targs_log_list, axis=0)
    return preds_log, targs_log, ids_all

In [21]:
# =========================================================
# 3. OOF evaluation (per fold + all folds combined)
# =========================================================
from torch.utils.data import DataLoader

def make_valid_loader(val_df: pd.DataFrame, cfg, batch_size: int) -> DataLoader:
    """valid用DataLoaderを作る（学習と同じvalid transformを使う）"""
    valid_tfm = build_transforms(cfg, is_train=False)
    ds = CsiroDataset(
        df=val_df.reset_index(drop=True),
        image_root=str(cfg.input_dir),
        target_cols=cfg.target_cols,
        transform=valid_tfm,
        use_log1p_target=bool(cfg.use_log1p_target),
        return_target=True,
    )
    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=int(getattr(cfg, "num_workers", 4)),
        pin_memory=bool(getattr(cfg, "pin_memory", True)),
        persistent_workers=bool(getattr(cfg, "persistent_workers", True)),
        drop_last=False,
    )
    return loader


def log_to_raw(x_log: np.ndarray, clip_min: float = -20.0, clip_max: float = 20.0) -> np.ndarray:
    """log1p -> raw"""
    x_log = np.clip(x_log, clip_min, clip_max)
    x_raw = np.expm1(x_log)
    x_raw = np.nan_to_num(x_raw, nan=0.0, posinf=0.0, neginf=0.0)
    x_raw = np.clip(x_raw, 0.0, None)
    return x_raw


def eval_oof(
    df: pd.DataFrame,
    cfg,
    exp_dir: Path,
    folds: List[int],
    tta_modes: List[str] = ["none", "hflip"],
    batch_size: Optional[int] = None,
) -> Dict[str, Dict[str, float]]:
    """各TTAモードでOOF評価（fold結合）を行う。

    Returns:
        results[tta] = {
            "weighted_r2_oof": ...,
            "r2_mean": ...（参考：単純平均）
            "r2_<target>": ...（target別）
        }
    """
    fold_col = str(cfg.fold_col)
    weights = np.asarray(cfg.metric.weights, dtype=np.float64)
    target_names = list(cfg.target_cols)

    if batch_size is None:
        batch_size = int(cfg.train.batch_size) if hasattr(cfg, "train") else 32

    results: Dict[str, Dict[str, float]] = {}

    for tta in tta_modes:
        print(f"\n===== TTA: {tta} =====")
        all_preds_log = []
        all_targs_log = []
        all_ids = []

        # fold loop
        for fold in folds:
            ckpt_path = exp_dir / "model" / f"best_fold{fold}.pth"
            assert ckpt_path.exists(), ckpt_path

            model = build_model_from_cfg(cfg)
            model = load_ckpt(model, ckpt_path, device=DEVICE)

            val_df = df[df[fold_col] == fold].reset_index(drop=True)
            loader = make_valid_loader(val_df, cfg, batch_size=batch_size)

            preds_log, targs_log, ids = predict_loader(
                model=model,
                loader=loader,
                device=DEVICE,
                use_amp=bool(getattr(cfg, "use_amp", True)),
                tta=tta,
            )

            all_preds_log.append(preds_log)
            all_targs_log.append(targs_log)
            all_ids.extend(ids)

            # fold単体も見たい場合はここで計算して print してOK
            preds_raw_fold = log_to_raw(preds_log)
            targs_raw_fold = log_to_raw(targs_log)
            r2_fold = global_weighted_r2_score(targs_raw_fold, preds_raw_fold, weights)
            print(f"  fold{fold} weighted_r2={r2_fold:.6f} (n={len(val_df)})")

        # --- all folds combined OOF ---
        preds_log_all = np.concatenate(all_preds_log, axis=0)
        targs_log_all = np.concatenate(all_targs_log, axis=0)

        preds_raw = log_to_raw(preds_log_all)
        targs_raw = log_to_raw(targs_log_all)

        weighted_r2 = global_weighted_r2_score(targs_raw, preds_raw, weights)
        r2_each = r2_per_target(targs_raw, preds_raw)

        out = {
            "weighted_r2_oof": float(weighted_r2),
            "r2_mean": float(np.mean(r2_each)),
        }
        for name, r2v in zip(target_names, r2_each):
            out[f"r2_{name}"] = float(r2v)

        results[tta] = out

        print("  ---- OOF (all folds combined) ----")
        print(f"  weighted_r2_oof: {weighted_r2:.6f}")
        for name, r2v in zip(target_names, r2_each):
            print(f"  r2_{name}: {r2v:.6f}")

    return results

In [22]:
# =========================================================
# 4. Run evaluation
# =========================================================
folds = list(cfg.folds)  # config.yaml 内に保存されている想定
# もし config に folds がない場合は手動で:
# folds = [0,1,2,3,4]

res = eval_oof(
    df=df,
    cfg=cfg,
    exp_dir=EXP_DIR,
    folds=folds,
    tta_modes=["none", "hflip", "hflip_vflip"],  # 必要なものだけ
    batch_size=int(cfg.train.batch_size) if hasattr(cfg, "train") else 32,
)

res


===== TTA: none =====


100%|██████████| 2/2 [00:07<00:00,  3.58s/it]


  fold0 weighted_r2=0.595501 (n=99)


100%|██████████| 2/2 [00:05<00:00,  2.95s/it]


  fold1 weighted_r2=0.760625 (n=85)


100%|██████████| 2/2 [00:06<00:00,  3.06s/it]


  fold2 weighted_r2=0.679762 (n=88)


100%|██████████| 2/2 [00:05<00:00,  2.92s/it]


  fold3 weighted_r2=0.756235 (n=85)
  ---- OOF (all folds combined) ----
  weighted_r2_oof: 0.699859
  r2_Dry_Green_g: 0.636151
  r2_Dry_Clover_g: 0.564040
  r2_Dry_Dead_g: 0.357449
  r2_GDM_g: 0.645936
  r2_Dry_Total_g: 0.597993

===== TTA: hflip =====


100%|██████████| 2/2 [00:05<00:00,  2.65s/it]


  fold0 weighted_r2=0.603286 (n=99)


100%|██████████| 2/2 [00:05<00:00,  2.55s/it]


  fold1 weighted_r2=0.768903 (n=85)


100%|██████████| 2/2 [00:06<00:00,  3.19s/it]


  fold2 weighted_r2=0.664434 (n=88)


100%|██████████| 2/2 [00:05<00:00,  2.84s/it]


  fold3 weighted_r2=0.758813 (n=85)
  ---- OOF (all folds combined) ----
  weighted_r2_oof: 0.700594
  r2_Dry_Green_g: 0.632436
  r2_Dry_Clover_g: 0.557887
  r2_Dry_Dead_g: 0.354836
  r2_GDM_g: 0.641310
  r2_Dry_Total_g: 0.601917

===== TTA: hflip_vflip =====


100%|██████████| 2/2 [00:04<00:00,  2.33s/it]


  fold0 weighted_r2=0.617292 (n=99)


100%|██████████| 2/2 [00:06<00:00,  3.16s/it]


  fold1 weighted_r2=0.771007 (n=85)


100%|██████████| 2/2 [00:05<00:00,  2.59s/it]


  fold2 weighted_r2=0.678001 (n=88)


100%|██████████| 2/2 [00:06<00:00,  3.17s/it]


  fold3 weighted_r2=0.774796 (n=85)
  ---- OOF (all folds combined) ----
  weighted_r2_oof: 0.711288
  r2_Dry_Green_g: 0.644527
  r2_Dry_Clover_g: 0.549083
  r2_Dry_Dead_g: 0.360797
  r2_GDM_g: 0.653937
  r2_Dry_Total_g: 0.617958


{'none': {'weighted_r2_oof': 0.6998589374072298,
  'r2_mean': 0.5603139586646169,
  'r2_Dry_Green_g': 0.6361510130025221,
  'r2_Dry_Clover_g': 0.5640404645341136,
  'r2_Dry_Dead_g': 0.3574490866270049,
  'r2_GDM_g': 0.6459361648535427,
  'r2_Dry_Total_g': 0.5979930643059006},
 'hflip': {'weighted_r2_oof': 0.7005943446475429,
  'r2_mean': 0.5576771277843997,
  'r2_Dry_Green_g': 0.6324355076031747,
  'r2_Dry_Clover_g': 0.5578871612583497,
  'r2_Dry_Dead_g': 0.3548356509402033,
  'r2_GDM_g': 0.6413099043689443,
  'r2_Dry_Total_g': 0.6019174147513269},
 'hflip_vflip': {'weighted_r2_oof': 0.7112880906628143,
  'r2_mean': 0.5652602686694481,
  'r2_Dry_Green_g': 0.6445265066559239,
  'r2_Dry_Clover_g': 0.5490831411971732,
  'r2_Dry_Dead_g': 0.3607965127046058,
  'r2_GDM_g': 0.6539368063300139,
  'r2_Dry_Total_g': 0.6179583764595241}}