# In [1]:
# =============================================================================
# CSIRO Image2Biomass - v4 CrossPVT T2T Mamba Inference (Self-contained)
# -----------------------------------------------------------------------------
# - 复刻 train.py 中的 CrossPVT_T2T_MambaDINO 结构
# - 直接从 checkpoint 中读取 cfg 覆盖本地 CFG，保证结构一致
# - 5-fold ensemble + TTA (原图/水平翻转/垂直翻转)
# - 自动处理 DataParallel 的 module. 前缀
# - 使用 parse_known_args() 规避 Kaggle/Colab 多余命令行参数
# - 输出 submission.csv
# =============================================================================

import os
import gc
import math
import argparse
import logging
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Optional

import cv2
import timm
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

# =============================================================================
# Logger
# =============================================================================
def _configure_logger(name: str) -> logging.Logger:
    """Return the named logger at INFO level with exactly one console handler.

    The handler is attached only when the logger has none yet, so re-running this
    cell/module does not print every message several times.
    """
    logger = logging.getLogger(name)
    if len(logger.handlers) == 0:
        logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)
    return logger


LOGGER = _configure_logger("csiro_infer_v4")


# =============================================================================
# 训练时使用的配置（会被 checkpoint 中的 cfg 覆盖）
# =============================================================================
@dataclass
class TrainCFG:
    """Model/training hyper-parameters mirrored from train.py (per the file header).

    These values are only defaults: update_cfg_from_checkpoint overwrites
    same-named fields with the ``cfg`` stored in each checkpoint before a model
    is built, so the inference architecture matches the trained one.
    """

    dropout: float = 0.1
    hidden_ratio: float = 0.35

    # DINO backbone: timm model names, tried in order by _build_dino_backbone
    dino_candidates: Tuple[str, ...] = (
        "vit_base_patch14_dinov2",
        "vit_base_patch14_reg4_dinov2",
        "vit_small_patch14_dinov2",
    )
    # (rows, cols) tile grids passed to TileEncoder for the two scales
    small_grid: Tuple[int, int] = (4, 4)
    big_grid: Tuple[int, int] = (2, 2)
    t2t_depth: int = 2
    cross_layers: int = 2
    cross_heads: int = 6

    # Pyramid + Mamba (PyramidMixer settings)
    pyramid_dims: Tuple[int, int, int] = (384, 512, 640)
    mobilevit_heads: int = 4
    mobilevit_depth: int = 2
    sra_heads: int = 8
    sra_ratio: int = 2
    mamba_depth: int = 3
    mamba_kernel: int = 5
    aux_head: bool = True
    # NOTE(review): not read by the visible inference code; presumably the
    # training-time weight of the auxiliary loss.
    aux_loss_weight: float = 0.4

    # Target columns (aux head output order / 5-target packing)
    ALL_TARGET_COLS: Tuple[str, ...] = (
        "Dry_Green_g",
        "Dry_Dead_g",
        "Dry_Clover_g",
        "GDM_g",
        "Dry_Total_g",
    )


# Shared module-level config read by the model classes; mutated in place by
# update_cfg_from_checkpoint.
CFG = TrainCFG()


def update_cfg_from_checkpoint(cfg_dict: dict, cfg: Optional[object] = None):
    """Overwrite same-named config fields with the values saved in a checkpoint.

    This keeps the inference architecture identical to the one used in training.
    Keys that the target config does not define are ignored.

    Args:
        cfg_dict: The checkpoint's ``cfg`` entry. Usually a dict; objects exposing
            their fields through ``vars()`` (e.g. ``argparse.Namespace`` or a
            dataclass instance) are accepted as well. ``None``/empty is a no-op.
        cfg: Config object to update in place. Defaults to the module-level CFG.
    """
    if not cfg_dict:
        return
    target = CFG if cfg is None else cfg

    # Plain dicts expose .items(); Namespace/dataclass objects only via vars().
    items = cfg_dict.items() if hasattr(cfg_dict, "items") else vars(cfg_dict).items()
    for key, value in items:
        if not hasattr(target, key):
            continue
        # Serialised configs (e.g. JSON) turn tuples into lists; restore the
        # tuple type of fields such as small_grid / pyramid_dims.
        if isinstance(value, list) and isinstance(getattr(target, key), tuple):
            value = tuple(value)
        setattr(target, key, value)


# =============================================================================
# Model blocks（与 train.py 保持一致的简化实现）
# =============================================================================
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The layers live in ``self.net`` so the checkpoint keys are ``net.0.*``
    (expansion to ``int(dim * mlp_ratio)``) and ``net.3.*`` (projection back to ``dim``).
    """

    def __init__(self, dim, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class AttentionBlock(nn.Module):
    """Pre-norm Transformer encoder block on batch-first (B, T, dim) sequences.

    x -> x + MHSA(LN(x)) -> x + FFN(LN(x)). Attention weights are not returned.
    """

    def __init__(self, dim, heads=8, dropout=0.0, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_ratio=mlp_ratio, dropout=dropout)

    def forward(self, x):
        normed = self.norm1(x)
        # Self-attention: the same normalised tensor is query, key and value.
        x = x + self.attn(normed, normed, normed, need_weights=False)[0]
        return x + self.ff(self.norm2(x))


class MobileViTBlock(nn.Module):
    """
    Lightweight MobileViT-style block: local CNN + small Transformer over
    patch-unfolded tokens (tokenise, attend, fold back), fused with the input.
    """

    def __init__(self, dim, heads=4, depth=2, patch=(2, 2), dropout=0.0):
        super().__init__()
        # Local mixing: depthwise 3x3 conv -> pointwise 1x1 conv -> GELU.
        self.local = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1, groups=dim),
            nn.Conv2d(dim, dim, 1),
            nn.GELU(),
        )
        self.patch = patch
        self.transformer = nn.ModuleList(
            [AttentionBlock(dim, heads=heads, dropout=dropout, mlp_ratio=2.0) for _ in range(depth)]
        )
        # 1x1 conv mapping concat([x, transformer output]) (2*dim channels) back to dim.
        self.fuse = nn.Conv2d(dim * 2, dim, kernel_size=1)

    def forward(self, x: torch.Tensor):
        """Mix ``x`` locally and globally; the output has the same shape as ``x``.

        ``x`` must be (B, dim, H, W): the convolutions are built for ``dim``
        channels and the result is resized back to x's spatial size before fusion.
        """
        local_feat = self.local(x)
        B, C, H, W = local_feat.shape
        ph, pw = self.patch
        # Round H and W up to multiples of the patch size; resize bilinearly if needed.
        new_h = math.ceil(H / ph) * ph
        new_w = math.ceil(W / pw) * pw
        if new_h != H or new_w != W:
            local_feat = F.interpolate(local_feat, size=(new_h, new_w), mode="bilinear", align_corners=False)
            H, W = new_h, new_w

        # Non-overlapping ph x pw patches: (B, C, nh, nw, ph, pw) -> (B, C, nh*nw, ph, pw)
        # -> (B, nh*nw*ph*pw, C). Attention below runs over all of these tokens jointly.
        tokens = local_feat.unfold(2, ph, ph).unfold(3, pw, pw)  # B,C,nh,nw,ph,pw
        tokens = tokens.contiguous().view(B, C, -1, ph, pw)
        tokens = tokens.permute(0, 2, 3, 4, 1).reshape(B, -1, C)

        for blk in self.transformer:
            tokens = blk(tokens)

        # Inverse of the tokenisation: (B, nh*nw, ph*pw, C) -> (B, C, nh*nw, ph*pw)
        # -> (B, C, nh, nw, ph, pw) -> (B, C, nh, ph, nw, pw) -> (B, C, H, W).
        feat = tokens.view(B, -1, ph * pw, C).permute(0, 3, 1, 2)
        nh = H // ph
        nw = W // pw
        feat = feat.view(B, C, nh, nw, ph, pw).permute(0, 1, 2, 4, 3, 5)
        feat = feat.reshape(B, C, H, W)

        # Undo the resize above so feat matches x spatially.
        if feat.shape[-2:] != x.shape[-2:]:
            feat = F.interpolate(feat, size=x.shape[-2:], mode="bilinear", align_corners=False)

        # Fuse the original input with the attended features via the 1x1 conv.
        out = self.fuse(torch.cat([x, feat], dim=1))
        return out


class SpatialReductionAttention(nn.Module):
    """Multi-head attention whose keys/values come from a spatially reduced token map (PVT-style).

    With ``sr_ratio > 1`` the tokens are reshaped to an (H, W) map and downsampled
    by a strided conv before keys/values are computed; otherwise keys/values use
    the full token sequence.
    """

    def __init__(self, dim, heads=8, sr_ratio=2, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.q = nn.Linear(dim, dim)
        # Keys and values share one projection; split with chunk() in forward.
        self.kv = nn.Linear(dim, dim * 2)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # kernel_size == stride: downsamples H and W by a factor of sr_ratio.
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
        else:
            self.sr = None
        self.proj = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, hw: Tuple[int, int]):
        """Attend from ``x`` (B, N, C) to its (optionally reduced) self.

        ``hw`` = (H, W) is only used when spatial reduction is active, and then
        H * W must equal N (the reshape below requires it).
        """
        B, N, C = x.shape
        # q: (B, heads, N, C // heads)
        q = self.q(x).reshape(B, N, self.heads, C // self.heads).permute(0, 2, 1, 3)

        if self.sr is not None:
            H, W = hw
            # tokens -> (B, C, H, W) map -> strided conv -> (B, M, C) reduced tokens -> LayerNorm
            feat = x.transpose(1, 2).reshape(B, C, H, W)
            feat = self.sr(feat)
            feat = feat.reshape(B, C, -1).transpose(1, 2)
            feat = self.norm(feat)
        else:
            feat = x

        kv = self.kv(feat)
        k, v = kv.chunk(2, dim=-1)
        # k is laid out (B, heads, head_dim, M) so q @ k directly gives (B, heads, N, M) scores.
        k = k.reshape(B, -1, self.heads, C // self.heads).permute(0, 2, 3, 1)
        # v: (B, heads, M, head_dim)
        v = v.reshape(B, -1, self.heads, C // self.heads).permute(0, 2, 1, 3)

        attn = torch.matmul(q, k) * self.scale
        attn = attn.softmax(dim=-1)
        # Dropout acts on the attention weights; there is no dropout after proj.
        attn = self.drop(attn)
        out = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(B, N, C)
        out = self.proj(out)
        return out


class PVTBlock(nn.Module):
    """Pre-norm PVT encoder block: spatial-reduction attention + MLP, each residual.

    ``hw`` is the (H, W) token grid forwarded to SpatialReductionAttention.
    """

    def __init__(self, dim, heads=8, sr_ratio=2, dropout=0.0, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.sra = SpatialReductionAttention(dim, heads=heads, sr_ratio=sr_ratio, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_ratio=mlp_ratio, dropout=dropout)

    def forward(self, x, hw: Tuple[int, int]):
        attended = x + self.sra(self.norm1(x), hw)
        return attended + self.ff(self.norm2(attended))


class LocalMambaBlock(nn.Module):
    """
    Simplified "local Mamba" block on (B, N, dim) tokens:
    LayerNorm -> sigmoid gate -> depthwise 1-D conv along the sequence -> Linear -> Dropout,
    added back to the input as a residual.
    """

    def __init__(self, dim, kernel_size=5, dropout=0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Depthwise conv over the token axis; "same" padding for odd kernel sizes.
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
        self.gate = nn.Linear(dim, dim)
        self.proj = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        y = self.norm(x)
        y = y * torch.sigmoid(self.gate(y))
        # Conv1d expects channels first: (B, N, C) -> (B, C, N) -> back to (B, N, C).
        y = self.dwconv(y.transpose(1, 2)).transpose(1, 2)
        y = self.drop(self.proj(y))
        return x + y


class T2TRetokenizer(nn.Module):
    """
    Run attention over the H x W grid of tile tokens, then average-pool the grid
    down to 2 x 2 re-tokens.

    forward returns (retokens of shape (B, 4, C), attended map of shape (B, C, H, W)).
    """

    def __init__(self, dim, depth=2, heads=4, dropout=0.0):
        super().__init__()
        self.blocks = nn.ModuleList(
            [AttentionBlock(dim, heads=heads, dropout=dropout, mlp_ratio=2.0) for _ in range(depth)]
        )

    def forward(self, tokens: torch.Tensor, grid_hw: Tuple[int, int]):
        batch, _, channels = tokens.shape
        rows, cols = grid_hw
        # Round-trip through a (B, C, H, W) map: the token order is unchanged, but
        # the reshape also enforces T == H * W.
        seq = tokens.transpose(1, 2).reshape(batch, channels, rows, cols).flatten(2).transpose(1, 2)
        for block in self.blocks:
            seq = block(seq)
        seq_map = seq.transpose(1, 2).reshape(batch, channels, rows, cols)
        retokens = F.adaptive_avg_pool2d(seq_map, (2, 2)).flatten(2).transpose(1, 2)
        return retokens, seq_map


class CrossScaleFusion(nn.Module):
    """Exchange information between the small-scale and big-scale token sets.

    Each branch gets a zero CLS token prepended (not a learned parameter). For
    every layer, both branches run their own self-attention block. Then each
    branch's normalised CLS token cross-attends to the other branch's tokens plus
    that branch's normalised CLS, and the result is added to the CLS token.

    Output order along dim 1: [cls_s, cls_b, small tokens, big tokens], i.e.
    shape (B, 2 + Ts + Tb, C).
    """

    def __init__(self, dim, heads=6, dropout=0.0, layers=2):
        super().__init__()
        self.layers_s = nn.ModuleList(
            [AttentionBlock(dim, heads=heads, dropout=dropout, mlp_ratio=2.0) for _ in range(layers)]
        )
        self.layers_b = nn.ModuleList(
            [AttentionBlock(dim, heads=heads, dropout=dropout, mlp_ratio=2.0) for _ in range(layers)]
        )
        self.cross_s = nn.ModuleList(
            [
                nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True, kdim=dim, vdim=dim)
                for _ in range(layers)
            ]
        )
        self.cross_b = nn.ModuleList(
            [
                nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True, kdim=dim, vdim=dim)
                for _ in range(layers)
            ]
        )
        self.norm_s = nn.LayerNorm(dim)
        self.norm_b = nn.LayerNorm(dim)

    @staticmethod
    def _context(tokens: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
        """Cross-attention key/value sequence: the other branch's tokens followed by its query."""
        return torch.cat([tokens, query], dim=1)

    def forward(self, tok_s: torch.Tensor, tok_b: torch.Tensor):
        B, _, C = tok_s.shape
        tok_s = torch.cat([tok_s.new_zeros(B, 1, C), tok_s], dim=1)
        tok_b = torch.cat([tok_b.new_zeros(B, 1, C), tok_b], dim=1)

        layer_stack = zip(self.layers_s, self.layers_b, self.cross_s, self.cross_b)
        for block_s, block_b, attn_s, attn_b in layer_stack:
            tok_s = block_s(tok_s)
            tok_b = block_b(tok_b)
            query_s = self.norm_s(tok_s[:, :1])
            query_b = self.norm_b(tok_b[:, :1])
            # Key and value are deliberately built as two distinct tensors (as in
            # training): nn.MultiheadAttention branches on `key is value`, which
            # may change the projection path and thus floating-point rounding.
            upd_s = attn_s(
                query_s, self._context(tok_b, query_b), self._context(tok_b, query_b), need_weights=False
            )[0]
            upd_b = attn_b(
                query_b, self._context(tok_s, query_s), self._context(tok_s, query_s), need_weights=False
            )[0]
            tok_s = torch.cat([tok_s[:, :1] + upd_s, tok_s[:, 1:]], dim=1)
            tok_b = torch.cat([tok_b[:, :1] + upd_b, tok_b[:, 1:]], dim=1)

        return torch.cat([tok_s[:, :1], tok_b[:, :1], tok_s[:, 1:], tok_b[:, 1:]], dim=1)


class TileEncoder(nn.Module):
    """Cut each image into a rows x cols grid of tiles and embed every tile with the backbone.

    Tile borders come from evenly spaced, rounded edges, so tiles may differ by a
    pixel. Each tile is bilinearly resized to ``input_res`` x ``input_res`` if needed,
    and all tiles are encoded in one batched backbone call.
    forward returns (B, rows * cols, D) with tiles in row-major order, where D is
    the last dimension of the backbone output.
    """

    def __init__(self, backbone: nn.Module, input_res: int):
        super().__init__()
        self.backbone = backbone
        self.input_res = input_res

    def forward(self, x: torch.Tensor, grid: Tuple[int, int]):
        B, C, H, W = x.shape
        rows, cols = grid
        res = self.input_res
        # One host transfer per axis instead of one .item() call per tile edge.
        row_edges = torch.linspace(0, H, steps=rows + 1, device=x.device).round().long().tolist()
        col_edges = torch.linspace(0, W, steps=cols + 1, device=x.device).round().long().tolist()

        crops = []
        for top, bottom in zip(row_edges[:-1], row_edges[1:]):
            for left, right in zip(col_edges[:-1], col_edges[1:]):
                crop = x[:, :, top:bottom, left:right]
                if crop.shape[-2:] != (res, res):
                    crop = F.interpolate(crop, size=(res, res), mode="bilinear", align_corners=False)
                crops.append(crop)

        # (B, T, C, res, res) -> (B*T, C, res, res) for a single backbone pass.
        batch_tiles = torch.stack(crops, dim=1).view(-1, C, res, res)
        embeddings = self.backbone(batch_tiles)
        return embeddings.view(B, -1, embeddings.shape[-1])


class PyramidMixer(nn.Module):
    """Three-stage token mixer that reduces fused tile tokens to one global vector.

    Stage 1: project to ``dims[0]``, lay tokens out on a fixed 3x4 map, MobileViTBlock.
    Stage 2: project to ``dims[1]``, shorten the sequence (truncation + adaptive
    average pooling), zero-pad onto a near-square grid, PVTBlock + LocalMambaBlock.
    Stage 3: project to ``dims[2]``, summarise as (mean, max) tokens, run
    ``mamba_depth`` LocalMambaBlocks and one AttentionBlock, then average.

    Submodule attribute names define the checkpoint state_dict keys.
    """

    def __init__(
        self,
        dim_in: int,
        dims: Tuple[int, int, int],
        mobilevit_heads: int = 4,
        mobilevit_depth: int = 2,
        sra_heads: int = 6,
        sra_ratio: int = 2,
        mamba_depth: int = 3,
        mamba_kernel: int = 5,
        dropout: float = 0.0,
    ):
        super().__init__()
        c1, c2, c3 = dims
        self.proj1 = nn.Linear(dim_in, c1)
        self.mobilevit = MobileViTBlock(c1, heads=mobilevit_heads, depth=mobilevit_depth, dropout=dropout)
        self.proj2 = nn.Linear(c1, c2)
        self.pvt = PVTBlock(c2, heads=sra_heads, sr_ratio=sra_ratio, dropout=dropout, mlp_ratio=3.0)
        self.mamba_local = LocalMambaBlock(c2, kernel_size=mamba_kernel, dropout=dropout)
        self.proj3 = nn.Linear(c2, c3)
        self.mamba_global = nn.ModuleList(
            [LocalMambaBlock(c3, kernel_size=mamba_kernel, dropout=dropout) for _ in range(mamba_depth)]
        )
        # Head count grows with width and is capped at 8; nn.MultiheadAttention
        # requires c3 to be divisible by it.
        self.final_attn = AttentionBlock(c3, heads=min(8, c3 // 64 + 1), dropout=dropout, mlp_ratio=2.0)

    def _tokens_to_map(self, tokens: torch.Tensor, target_hw: Tuple[int, int]):
        """Zero-pad or truncate (B, N, C) tokens to H*W and reshape them to (B, C, H, W)."""
        B, N, C = tokens.shape
        H, W = target_hw
        need = H * W
        if N < need:
            pad = tokens.new_zeros(B, need - N, C)
            tokens = torch.cat([tokens, pad], dim=1)
        tokens = tokens[:, :need, :]
        feat_map = tokens.transpose(1, 2).reshape(B, C, H, W)
        return feat_map

    @staticmethod
    def _fit_hw(n_tokens: int) -> Tuple[int, int]:
        """Pick a near-square (h, w) grid with h * w >= n_tokens."""
        h = int(math.sqrt(n_tokens))
        w = h
        while h * w < n_tokens:
            w += 1
            if h * w < n_tokens:
                h += 1
        return h, w

    def forward(self, tokens: torch.Tensor):
        """Mix (B, N, dim_in) tokens into a (B, dims[2]) global feature.

        Returns:
            ``(global_feat, aux)`` where ``aux`` holds detached intermediates:
            ``"stage1_map"`` (B, dims[0], 3, 4), ``"stage2_tokens"`` (B, h*w, dims[1])
            and ``"stage3_tokens"`` (B, 2, dims[2]).
        """
        B, N, _ = tokens.shape
        # Stage 1 uses a fixed 3x4 map (12 slots); fewer tokens are zero-padded.
        # NOTE(review): more than 12 tokens would be truncated here; the fixed
        # layout is kept as-is to stay consistent with the trained weights.
        map_hw = (3, 4)

        t1 = self.proj1(tokens)
        m1 = self._tokens_to_map(t1, map_hw)
        m1 = self.mobilevit(m1)
        t1_out = m1.flatten(2).transpose(1, 2)[:, :N]

        # Stage 2: halve the token count (at least 4) via truncation + average pooling.
        t2 = self.proj2(t1_out)
        new_len = max(4, N // 2)
        t2 = t2[:, :new_len] + F.adaptive_avg_pool1d(t2.transpose(1, 2), new_len).transpose(1, 2)
        hw2 = self._fit_hw(t2.size(1))
        # SRA reshapes tokens to an hw2 map, so the sequence must fill the grid exactly.
        if t2.size(1) < hw2[0] * hw2[1]:
            pad = t2.new_zeros(B, hw2[0] * hw2[1] - t2.size(1), t2.size(2))
            t2 = torch.cat([t2, pad], dim=1)
        t2 = self.pvt(t2, hw2)
        t2 = self.mamba_local(t2)

        # Stage 3: global (mean, max) summary tokens.
        t3 = self.proj3(t2)
        pooled = torch.stack([t3.mean(dim=1), t3.max(dim=1).values], dim=1)  # (B,2,C)
        t3 = pooled
        for blk in self.mamba_global:
            t3 = blk(t3)
        t3 = self.final_attn(t3)
        global_feat = t3.mean(dim=1)
        return global_feat, {"stage1_map": m1.detach(), "stage2_tokens": t2.detach(), "stage3_tokens": t3.detach()}


class CrossPVT_T2T_MambaDINO(nn.Module):
    """Two-branch (left/right image half) biomass regressor on a DINO ViT backbone.

    Each half is encoded by ``_half_forward``:
    1. backbone features of tiles on a small grid and a big grid;
    2. T2T re-tokenisation of the small-grid tokens;
    3. cross-scale fusion;
    4. a PyramidMixer that yields one global vector.

    The two vectors are cross-gated, concatenated and fed to three Softplus heads
    (green, clover, dead). GDM = green + clover and total = GDM + dead are derived
    from them.

    The architecture is read from the module-level CFG (updated from the
    checkpoint before construction). Submodule attribute names define the
    state_dict keys and must not change.

    Args:
        dropout: Dropout probability used throughout the model.
        hidden_ratio: Hidden width of each regression head relative to the
            concatenated feature width (at least 32 units).
        grad_checkpointing: Enable timm gradient checkpointing on the backbone.
            It only saves memory when gradients are computed; under
            ``torch.no_grad`` inference it is pure overhead, so it is off by default.
    """

    def __init__(self, dropout: float = 0.1, hidden_ratio: float = 0.35, grad_checkpointing: bool = False):
        super().__init__()
        self.backbone, self.feat_dim, self.backbone_name, self.input_res = self._build_dino_backbone(
            grad_checkpointing=grad_checkpointing
        )
        # The same backbone module is registered again inside tile_encoder, so its
        # weights appear under both "backbone.*" and "tile_encoder.backbone.*".
        self.tile_encoder = TileEncoder(self.backbone, self.input_res)
        self.t2t = T2TRetokenizer(self.feat_dim, depth=CFG.t2t_depth, heads=CFG.cross_heads, dropout=dropout)
        self.cross = CrossScaleFusion(
            self.feat_dim, heads=CFG.cross_heads, dropout=dropout, layers=CFG.cross_layers
        )
        self.pyramid = PyramidMixer(
            dim_in=self.feat_dim,
            dims=CFG.pyramid_dims,
            mobilevit_heads=CFG.mobilevit_heads,
            mobilevit_depth=CFG.mobilevit_depth,
            sra_heads=CFG.sra_heads,
            sra_ratio=CFG.sra_ratio,
            mamba_depth=CFG.mamba_depth,
            mamba_kernel=CFG.mamba_kernel,
            dropout=dropout,
        )

        # Left and right global features are concatenated before the heads.
        combined = CFG.pyramid_dims[-1] * 2
        self.combined_dim = combined
        hidden = max(32, int(combined * hidden_ratio))

        def head():
            return nn.Sequential(
                nn.Linear(combined, hidden),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, 1),
            )

        self.head_green = head()
        self.head_clover = head()
        self.head_dead = head()
        # Not used by forward() below (presumably a training-time head); kept so
        # checkpoint weights load without unexpected keys.
        self.score_head = nn.Sequential(nn.LayerNorm(combined), nn.Linear(combined, 1))
        self.aux_head = (
            nn.Sequential(nn.LayerNorm(CFG.pyramid_dims[1]), nn.Linear(CFG.pyramid_dims[1], 5))
            if CFG.aux_head
            else None
        )
        self.softplus = nn.Softplus(beta=1.0)

        # Each branch is gated by a sigmoid of the *other* branch's feature (see _merge_heads).
        self.cross_gate_left = nn.Linear(CFG.pyramid_dims[-1], CFG.pyramid_dims[-1])
        self.cross_gate_right = nn.Linear(CFG.pyramid_dims[-1], CFG.pyramid_dims[-1])

    def _build_dino_backbone(self, grad_checkpointing: bool = False):
        """Create the first constructible DINO ViT listed in CFG.dino_candidates.

        Only the architecture is created (``pretrained=False``), so nothing is
        downloaded in an offline Kaggle environment; the weights come from the
        checkpoint. For each candidate, global_pool "token", "avg" and the timm
        default are tried in that order.

        NOTE(review): the first candidate that builds wins. If training fell back
        to a later candidate, the checkpoint cfg must carry a matching
        ``dino_candidates`` or loading its weights may fail.

        Returns:
            ``(model, feature_dim, model_name, input_resolution)``.

        Raises:
            RuntimeError: If no candidate can be created.
        """
        last_err = None
        for name in CFG.dino_candidates:
            for gp in ["token", "avg", "__default__"]:
                try:
                    if gp == "__default__":
                        m = timm.create_model(name, pretrained=False, num_classes=0)
                        gp_str = "default"
                    else:
                        m = timm.create_model(name, pretrained=False, num_classes=0, global_pool=gp)
                        gp_str = gp
                    feat = m.num_features
                    input_res = self._infer_input_res(m)
                    LOGGER.info(
                        f"✅ 使用 DINO 主干: {name} | global_pool={gp_str} | "
                        f"feat_dim={feat} | input_res={input_res}"
                    )
                    # Checkpointing trades compute for activation memory during
                    # backprop; with no gradients it saves nothing, hence opt-in.
                    if grad_checkpointing and hasattr(m, "set_grad_checkpointing"):
                        m.set_grad_checkpointing(True)
                    return m, feat, name, int(input_res)
                except Exception as e:
                    # Deliberate fallback to the next configuration, but keep a trace.
                    LOGGER.debug("DINO 主干创建失败: %s | global_pool=%s | %s", name, gp, e)
                    last_err = e
                    continue
        raise RuntimeError(f"无法创建任何 DINO 主干。最后错误: {last_err}")

    @staticmethod
    def _infer_input_res(m) -> int:
        """Best-effort lookup of the square input size the timm model expects.

        Checks ``patch_embed.img_size``, then ``img_size``, then
        ``default_cfg["input_size"]``; falls back to 518 if none is available.
        """
        if hasattr(m, "patch_embed") and hasattr(m.patch_embed, "img_size"):
            isz = m.patch_embed.img_size
            return int(isz if isinstance(isz, (int, float)) else isz[0])
        if hasattr(m, "img_size"):
            isz = m.img_size
            return int(isz if isinstance(isz, (int, float)) else isz[0])
        dc = getattr(m, "default_cfg", {}) or {}
        ins = dc.get("input_size", None)
        if ins:
            if isinstance(ins, (tuple, list)) and len(ins) >= 2:
                return int(ins[1])
            return int(ins if isinstance(ins, (int, float)) else 224)
        return 518

    def _half_forward(self, x_half: torch.Tensor):
        """Encode one image half into a global feature plus intermediate maps."""
        tiles_small = self.tile_encoder(x_half, CFG.small_grid)
        tiles_big = self.tile_encoder(x_half, CFG.big_grid)
        t2, stage1_map = self.t2t(tiles_small, CFG.small_grid)
        fused = self.cross(t2, tiles_big)
        feat, feat_maps = self.pyramid(fused)
        # Expose the T2T attention map as "stage1_map" instead of the pyramid's.
        feat_maps["stage1_map"] = stage1_map
        return feat, feat_maps

    def _merge_heads(self, f_l: torch.Tensor, f_r: torch.Tensor):
        """Cross-gate both halves and predict (total, gdm, green, concatenated feature)."""
        g_l = torch.sigmoid(self.cross_gate_left(f_r))
        g_r = torch.sigmoid(self.cross_gate_right(f_l))
        f_l = f_l * g_l
        f_r = f_r * g_r
        f = torch.cat([f_l, f_r], dim=1)
        green_pos = self.softplus(self.head_green(f))
        clover_pos = self.softplus(self.head_clover(f))
        dead_pos = self.softplus(self.head_dead(f))
        # Components are non-negative (Softplus); composites are built by summation.
        gdm = green_pos + clover_pos
        total = gdm + dead_pos
        return total, gdm, green_pos, f

    def _param_device_dtype(self):
        """Device and dtype of the first parameter (CPU/float32 if there are none)."""
        try:
            ref = next(self.parameters())
            return ref.device, ref.dtype
        except StopIteration:
            return torch.device("cpu"), torch.float32

    def _empty_forward_output(self, device=None, dtype=None, return_features: bool = False):
        """Output dict with zero-row tensors, used when the batch is empty or missing."""
        if device is None or dtype is None:
            device_p, dtype_p = self._param_device_dtype()
            if device is None:
                device = device_p
            if dtype is None:
                dtype = dtype_p

        zero = torch.zeros(0, 1, device=device, dtype=dtype)
        out = {
            "total": zero,
            "gdm": zero,
            "green": zero,
            "score_feat": torch.zeros(0, self.combined_dim, device=device, dtype=dtype),
        }
        if self.aux_head is not None:
            out["aux"] = torch.zeros(0, len(CFG.ALL_TARGET_COLS), device=device, dtype=dtype)
        if return_features:
            out["feature_maps"] = {}
        return out

    def forward(self, *inputs, x_left=None, x_right=None, return_features: bool = False):
        """Predict biomass from the left and right image halves.

        Accepted call styles: ``model(x_left, x_right)``, ``model((x_left, x_right))``,
        keyword ``x_left``/``x_right``, or ``model(x)`` where ``x`` is both halves
        concatenated along dim 1 (split back with ``torch.chunk``).

        Returns:
            Dict with ``"total"``, ``"gdm"``, ``"green"`` (each (B, 1)) and
            ``"score_feat"``. It also has ``"aux"`` when the aux head exists and
            ``"feature_maps"`` when ``return_features`` is set.
        """
        # Accept DataParallel-style calls: a single tensor, a (left, right)
        # tuple/list, or two positional tensors.
        if inputs:
            if len(inputs) == 1:
                first = inputs[0]
                if isinstance(first, (tuple, list)):
                    if len(first) >= 1:
                        x_left = first[0]
                    if len(first) >= 2:
                        x_right = first[1]
                else:
                    x_left = first
            else:
                x_left = inputs[0]
                x_right = inputs[1]

        if x_left is None:
            return self._empty_forward_output(return_features=return_features)
        if isinstance(x_left, torch.Tensor) and x_left.shape[0] == 0:
            return self._empty_forward_output(return_features=return_features)

        if x_right is None:
            if isinstance(x_left, torch.Tensor):
                if x_left.shape[1] % 2 != 0:
                    raise ValueError("无法从单个张量推断左右分支，请显式提供 x_right。")
                x_left, x_right = torch.chunk(x_left, 2, dim=1)
            else:
                raise ValueError("缺少 x_right 输入。")

        feat_l, feats_l = self._half_forward(x_left)
        feat_r, feats_r = self._half_forward(x_right)
        total, gdm, green, f_concat = self._merge_heads(feat_l, feat_r)

        out = {
            "total": total,
            "gdm": gdm,
            "green": green,
            "score_feat": f_concat,
        }
        if self.aux_head is not None:
            aux_tokens = torch.cat([feats_l["stage2_tokens"], feats_r["stage2_tokens"]], dim=1)
            aux_pred = self.softplus(self.aux_head(aux_tokens.mean(dim=1)))
            out["aux"] = aux_pred  # column order follows CFG.ALL_TARGET_COLS
        if return_features:
            out["feature_maps"] = {
                "stage1_left": feats_l.get("stage1_map"),
                "stage1_right": feats_r.get("stage1_map"),
                "stage3_left": feats_l.get("stage3_tokens"),
                "stage3_right": feats_r.get("stage3_tokens"),
            }
        return out


# =============================================================================
# 推理配置
# =============================================================================
class INF_CFG:
    """Inference settings: data paths, checkpoint layout, runtime and TTA options."""

    # ---- Data (Kaggle defaults; edit for local runs) ----
    BASE_PATH = "/kaggle/input/csiro-biomass"
    TEST_CSV = os.path.join(BASE_PATH, "test.csv")
    TEST_IMAGE_DIR = os.path.join(BASE_PATH, "test")

    # ---- Trained weights (one best_wr2.pt per fold) ----
    EXPERIMENT_DIR = "/kaggle/input/csiro/pytorch/default/12"
    # Both fold directory layouts are supported: "fold_{k}/..." and "fold{k}/...".
    CKPT_PATTERN_FOLD_X = os.path.join(EXPERIMENT_DIR, "fold_{fold}", "checkpoints", "best_wr2.pt")
    CKPT_PATTERN_FOLDX = os.path.join(EXPERIMENT_DIR, "fold{fold}", "checkpoints", "best_wr2.pt")
    N_FOLDS = 5

    # ---- Output ----
    SUBMISSION_FILE = "submission.csv"

    # ---- Runtime ----
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    BATCH_SIZE = 1
    NUM_WORKERS = 0
    MIXED_PRECISION = True

    # ---- Test-time augmentation ----
    USE_TTA = True
    TTA_TRANSFORMS = ["original", "hflip", "vflip"]

    # Target column order; must match training.
    ALL_TARGET_COLS = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]


# Echo the resolved device and weights directory when this module/cell runs.
print(f"Device: {INF_CFG.DEVICE}")
print(f"Experiment Dir: {INF_CFG.EXPERIMENT_DIR}")


# =============================================================================
# 数据集
# =============================================================================
class TestBiomassDataset(Dataset):
    """Test dataset yielding the (left, right) halves of each image.

    Each image is read with OpenCV and converted BGR -> RGB. It is split at the
    middle column into a left and a right half, and each half is transformed
    separately.

    Args:
        df: Must contain an ``image_path`` column. Only the basename of each path
            is used, joined with ``image_dir``.
        transform: albumentations-style callable, invoked as ``transform(image=...)``
            and expected to return a dict with an ``"image"`` entry.
        image_dir: Directory holding the test images.
    """

    def __init__(self, df: pd.DataFrame, transform, image_dir: str):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.image_dir = image_dir
        self.paths = self.df["image_path"].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        filename = os.path.basename(self.paths[idx])
        full_path = os.path.join(self.image_dir, filename)

        img = cv2.imread(full_path)
        if img is None:
            # Best effort: keep the submission complete by predicting on a black
            # placeholder, but make the unreadable file visible in the logs.
            LOGGER.warning("无法读取图像，使用黑图占位: %s", full_path)
            img = np.zeros((1000, 2000, 3), dtype=np.uint8)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Split into left/right halves (same split as training, per the original note).
        mid = img.shape[1] // 2
        left = img[:, :mid]
        right = img[:, mid:]

        left_t = self.transform(image=left)["image"]
        right_t = self.transform(image=right)["image"]

        return left_t, right_t


# =============================================================================
# TTA 变换
# =============================================================================
def get_tta_transforms(img_size: int, views: Optional[List[str]] = None) -> List[A.Compose]:
    """Build one preprocessing pipeline per TTA view.

    Each pipeline applies the view's flip (if any). It then resizes to
    ``img_size`` x ``img_size`` with INTER_AREA, normalises with the fixed
    mean/std below, and converts to a tensor.

    Args:
        img_size: Square output size (the backbone input resolution).
        views: View names in output order. Supported: "original", "hflip",
            "vflip". Defaults to INF_CFG.TTA_TRANSFORMS, so the pipelines always
            line up with the configured view names.

    Returns:
        A list of ``A.Compose`` pipelines, one per entry of ``views``.

    Raises:
        ValueError: For an unknown view name.
    """
    if views is None:
        views = INF_CFG.TTA_TRANSFORMS

    # Flip class applied before resizing for each supported view (None = no flip).
    flips = {
        "original": None,
        "hflip": A.HorizontalFlip,
        "vflip": A.VerticalFlip,
    }

    transforms = []
    for view in views:
        if view not in flips:
            raise ValueError(f"未知的 TTA 视角: {view!r}（支持: {list(flips)}）")
        steps = []
        if flips[view] is not None:
            steps.append(flips[view](p=1.0))
        # Fresh transform instances per pipeline instead of sharing objects.
        steps += [
            A.Resize(img_size, img_size, interpolation=cv2.INTER_AREA),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]
        transforms.append(A.Compose(steps))

    return transforms


# =============================================================================
# 权重加载
# =============================================================================
def strip_module_prefix(state_dict: dict) -> dict:
    """Drop the ``module.`` prefix that ``nn.DataParallel`` puts on every key.

    The prefix is removed only when *every* key carries it. Otherwise, and for an
    empty or falsy input, the very same object is returned unchanged.
    """
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        cut = len(prefix)
        return {key[cut:]: tensor for key, tensor in state_dict.items()}
    return state_dict


def load_checkpoint(path: str) -> dict:
    """Read a checkpoint file onto the CPU.

    ``weights_only=False`` is requested because the checkpoint may hold
    non-tensor objects (e.g. a saved config). That unpickles arbitrary objects,
    so only load trusted files. Torch versions without the ``weights_only``
    argument raise TypeError; the load is then retried without it.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint not found: {path}")

    load_kwargs = {"map_location": "cpu", "weights_only": False}
    try:
        return torch.load(path, **load_kwargs)
    except TypeError:
        load_kwargs.pop("weights_only")
        return torch.load(path, **load_kwargs)


def load_model_from_checkpoint(ckpt_path: str) -> nn.Module:
    """Rebuild the model described by a checkpoint, load its weights, return it in eval mode.

    Loading proceeds in these steps:
    - The checkpoint's ``cfg`` entry (if any) is applied to the global CFG before
      the model is constructed, so the architecture matches training.
    - Weights are taken from the first dict found under ``model_state`` /
      ``state_dict`` / ``model``. If none exists, the checkpoint itself is treated
      as a state_dict.
    - A DataParallel ``module.`` prefix is stripped.
    - Loading is non-strict, but a checkpoint that matches none of the model's
      keys is rejected instead of silently yielding a randomly initialised model.

    Args:
        ckpt_path: Path to the checkpoint file.

    Returns:
        The model on INF_CFG.DEVICE, in eval mode.

    Raises:
        FileNotFoundError: If ``ckpt_path`` does not exist.
        TypeError: If the checkpoint is not a dict.
        RuntimeError: If no checkpoint entry matches any model key.
    """
    print(f"\n加载 checkpoint: {os.path.basename(ckpt_path)}")

    state = load_checkpoint(ckpt_path)
    if not isinstance(state, dict):
        raise TypeError(f"不支持的 checkpoint 格式 ({type(state).__name__}): {ckpt_path}")

    # Apply the training cfg first so the rebuilt architecture matches the weights.
    # `or {}` also covers checkpoints that stored cfg=None.
    cfg_dict = state.get("cfg") or {}
    if not isinstance(cfg_dict, dict):
        # e.g. an argparse.Namespace or dataclass instance saved as-is
        cfg_dict = dict(vars(cfg_dict))
    update_cfg_from_checkpoint(cfg_dict)

    dropout = cfg_dict.get("dropout", CFG.dropout)
    hidden_ratio = cfg_dict.get("hidden_ratio", CFG.hidden_ratio)

    model = CrossPVT_T2T_MambaDINO(dropout=dropout, hidden_ratio=hidden_ratio)

    # Locate the weights; fall back to treating the whole checkpoint as a state_dict.
    model_state = next(
        (state[key] for key in ("model_state", "state_dict", "model") if isinstance(state.get(key), dict)),
        state,
    )
    model_state = strip_module_prefix(model_state)

    missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)

    # strict=False tolerates small mismatches, but matching nothing means the
    # model would predict with random weights.
    n_model_keys = len(model.state_dict())
    if n_model_keys and len(missing_keys) == n_model_keys:
        raise RuntimeError(
            f"checkpoint 中没有任何与模型匹配的权重 ({len(unexpected_keys)} 个意外的键): {ckpt_path}"
        )

    if missing_keys:
        print(f"  ⚠️  缺失的键: {len(missing_keys)} 个")
        for k in missing_keys[:10]:
            print(f"    - {k}")
        if len(missing_keys) > 10:
            print("    ...")

    if unexpected_keys:
        print(f"  ⚠️  意外的键: {len(unexpected_keys)} 个")
        for k in unexpected_keys[:10]:
            print(f"    - {k}")
        if len(unexpected_keys) > 10:
            print("    ...")

    model.to(INF_CFG.DEVICE)
    model.eval()

    input_res = getattr(model, "input_res", 518)
    backbone_name = getattr(model, "backbone_name", "unknown")

    print(f"  ✓ 模型加载成功 | backbone={backbone_name} | input_res={input_res}")

    return model


# =============================================================================
# 推理
# =============================================================================
def pack5_targets(total: torch.Tensor, gdm: torch.Tensor, green: torch.Tensor) -> torch.Tensor:
    """Expand (total, gdm, green) into the five target columns.

    Uses clover = gdm - green and dead = total - gdm. The inputs are
    concatenated along dim 1 in the order [green, dead, clover, gdm, total].
    """
    columns = (green, total - gdm, gdm - green, gdm, total)
    return torch.cat(columns, dim=1)


@torch.no_grad()
def predict_one_view(models: List[nn.Module], loader: DataLoader) -> np.ndarray:
    """Predict all five targets for one TTA view, averaging over the fold models.

    Args:
        models: Fold models, already on INF_CFG.DEVICE and in eval mode.
        loader: Yields ``(left, right)`` image-half batches.

    Returns:
        float32 array of shape (n_samples, 5). Columns follow pack5_targets:
        green, dead, clover, gdm, total. If the loader yields nothing, the array
        has shape (0, 5).

    Raises:
        ValueError: If ``models`` is empty.
    """
    if not models:
        raise ValueError("predict_one_view 需要至少一个模型。")

    device = INF_CFG.DEVICE
    # torch.amp.autocast takes a device *type* string ("cuda"/"cpu"), not a dtype.
    amp_device = device.type
    # Mixed precision only pays off on CUDA. CPU autocast silently switches to
    # bfloat16, which loses precision and is slow on many CPUs.
    use_amp = INF_CFG.MIXED_PRECISION and amp_device == "cuda"

    preds_list = []
    for xl, xr in tqdm(loader, desc="  Predicting", leave=False):
        xl = xl.to(device, non_blocking=True)
        xr = xr.to(device, non_blocking=True)

        # Concatenate the halves along dim 1; the model splits them back with torch.chunk.
        x_cat = torch.cat([xl, xr], dim=1)

        per_model_preds = []
        with torch.amp.autocast(amp_device, enabled=use_amp):
            for model in models:
                out = model(x_cat, return_features=False)
                five = pack5_targets(out["total"], out["gdm"], out["green"])
                # dead and clover are differences of model outputs and can dip
                # slightly below zero through rounding; clamp all targets at 0.
                five = torch.clamp(five, min=0.0)
                per_model_preds.append(five.float().cpu())

        # Fold ensemble: plain mean over models.
        stacked = torch.stack(per_model_preds, dim=0).mean(dim=0)
        preds_list.append(stacked.numpy())

    if not preds_list:
        return np.zeros((0, len(INF_CFG.ALL_TARGET_COLS)), dtype=np.float32)
    return np.concatenate(preds_list, axis=0)


def _make_test_loader(test_df: pd.DataFrame, transform, image_dir: str) -> DataLoader:
    """Build the ordered test DataLoader for one transform (one TTA view)."""
    ds = TestBiomassDataset(test_df, transform, image_dir)
    return DataLoader(
        ds,
        batch_size=INF_CFG.BATCH_SIZE,
        # Row order must be preserved: predictions are paired positionally with
        # the rows of the test frame when the submission is built.
        shuffle=False,
        num_workers=INF_CFG.NUM_WORKERS,
        # Pinned memory only speeds up host->GPU copies; on CPU it just warns.
        pin_memory=INF_CFG.DEVICE.type == "cuda",
    )


def _load_fold_models() -> Tuple[List[nn.Module], Optional[int]]:
    """Load every fold checkpoint that exists, trying both directory naming schemes.

    Returns:
        ``(models, input_res)``. ``input_res`` comes from the first loaded model
        (518 if it has no ``input_res`` attribute); it is None when no
        checkpoint was found.
    """
    models: List[nn.Module] = []
    input_res: Optional[int] = None

    for fold in range(INF_CFG.N_FOLDS):
        candidates = (
            INF_CFG.CKPT_PATTERN_FOLD_X.format(fold=fold),
            INF_CFG.CKPT_PATTERN_FOLDX.format(fold=fold),
        )
        ckpt_path = next((p for p in candidates if os.path.exists(p)), None)

        if ckpt_path is None:
            print(f"  ⚠️  Fold {fold} checkpoint 不存在，尝试路径:")
            for path in candidates:
                print(f"    - {path}")
            continue

        model = load_model_from_checkpoint(ckpt_path)
        models.append(model)

        if input_res is None:
            input_res = getattr(model, "input_res", 518)
            print(f"  输入分辨率: {input_res}")

    return models, input_res


def run_inference(test_df: pd.DataFrame, image_dir: str) -> np.ndarray:
    """Run the full inference pipeline: fold ensemble plus optional TTA.

    Args:
        test_df: one row per unique test image.
        image_dir: directory passed through to ``TestBiomassDataset``.

    Returns:
        Array of shape (len(test_df), 5), averaged over fold models and (when
        ``INF_CFG.USE_TTA``) over all TTA views.

    Raises:
        RuntimeError: if no fold checkpoint could be found.
    """
    print("\n" + "=" * 80)
    print("开始推理")
    print("=" * 80)

    print("\n加载模型 (5-fold)...")
    models, input_res = _load_fold_models()

    if not models:
        print("\n❌ 错误：没有找到任何可用的 checkpoint！")
        print(f"   请检查实验目录: {INF_CFG.EXPERIMENT_DIR}")
        for fold in range(INF_CFG.N_FOLDS):
            print(f"     - {INF_CFG.CKPT_PATTERN_FOLD_X.format(fold=fold)}")
            print(f"     - {INF_CFG.CKPT_PATTERN_FOLDX.format(fold=fold)}")
        raise RuntimeError("没有找到任何可用的 checkpoint！")

    print(f"\n✓ 成功加载 {len(models)} 个模型")

    if INF_CFG.USE_TTA:
        tta_transforms = get_tta_transforms(input_res)
        print(f"\n使用 TTA: {len(tta_transforms)} 个视角")

        per_view_preds = []
        for i, transform in enumerate(tta_transforms):
            view_name = INF_CFG.TTA_TRANSFORMS[i] if i < len(INF_CFG.TTA_TRANSFORMS) else f"view_{i+1}"
            print(f"\n--- TTA 视角 {i+1}/{len(tta_transforms)}: {view_name} ---")

            loader = _make_test_loader(test_df, transform, image_dir)
            per_view_preds.append(predict_one_view(models, loader))

        # TTA ensemble: plain mean over views.
        final_pred = np.mean(per_view_preds, axis=0)
        print(f"\n✓ TTA 完成，最终预测形状: {final_pred.shape}")
    else:
        # Without TTA, only the first transform is used.
        transform = get_tta_transforms(input_res)[0]
        final_pred = predict_one_view(models, _make_test_loader(test_df, transform, image_dir))

    return final_pred


# =============================================================================
# 生成提交文件
# =============================================================================
def create_submission(final_pred: np.ndarray, test_long: pd.DataFrame, test_unique: pd.DataFrame) -> pd.DataFrame:
    """Turn wide per-image predictions into the long submission file and save it.

    Args:
        final_pred: (len(test_unique), 5) predictions whose columns are in the
            order green, dead, clover, gdm, total, paired positionally with
            the rows of ``test_unique``.
        test_long: original test CSV, one row per (image, target); must have
            ``sample_id``, ``image_path`` and ``target_name`` columns.
        test_unique: one row per unique ``image_path``.

    Returns:
        DataFrame with ``sample_id`` and ``target`` columns, also written to
        ``INF_CFG.SUBMISSION_FILE``. Targets with no matching prediction are 0.

    Raises:
        ValueError: if ``final_pred`` does not have shape (len(test_unique), 5).
    """
    print("\n" + "=" * 80)
    print("生成提交文件")
    print("=" * 80)

    # Column order of final_pred, as produced by pack5_targets.
    pred_cols = ("Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g")

    final_pred = np.asarray(final_pred)
    if final_pred.ndim != 2 or final_pred.shape != (len(test_unique), len(pred_cols)):
        raise ValueError(
            f"final_pred has shape {final_pred.shape}, expected "
            f"({len(test_unique)}, {len(pred_cols)})"
        )

    def clean(x: np.ndarray) -> np.ndarray:
        """Replace NaN/Inf with 0 and clamp to be non-negative."""
        x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        return np.maximum(0, x)

    cleaned = {name: clean(final_pred[:, i]) for i, name in enumerate(pred_cols)}

    # to_numpy(): pair rows positionally regardless of test_unique's index.
    wide = pd.DataFrame({"image_path": test_unique["image_path"].to_numpy(), **cleaned})

    long_preds = wide.melt(
        id_vars=["image_path"],
        value_vars=INF_CFG.ALL_TARGET_COLS,
        var_name="target_name",
        value_name="target",
    )

    sub = pd.merge(
        test_long[["sample_id", "image_path", "target_name"]],
        long_preds,
        on=["image_path", "target_name"],
        how="left",
    )[["sample_id", "target"]]

    # Rows whose target_name had no prediction come out of the left merge as NaN.
    sub["target"] = np.nan_to_num(sub["target"], nan=0.0, posinf=0.0, neginf=0.0)
    sub.to_csv(INF_CFG.SUBMISSION_FILE, index=False)

    print(f"\n✓ 提交文件已保存: {INF_CFG.SUBMISSION_FILE}")
    print(f"  样本数: {len(sub)}")
    print("  预测统计:")
    for name in pred_cols:
        values = cleaned[name]
        label = f"{name}:"
        if values.size == 0:
            # min/max on an empty array would raise.
            print(f"    {label:<15}(empty)")
            continue
        print(
            f"    {label:<15}mean={values.mean():.2f}, std={values.std():.2f}, "
            f"min={values.min():.2f}, max={values.max():.2f}"
        )
    print("\n前 10 行预览:")
    print(sub.head(10).to_string())
    return sub


# =============================================================================
# 主函数（使用 parse_known_args 避免 SystemExit:2）
# =============================================================================
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line overrides for the inference run.

    Args:
        argv: arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            exactly as before; passing a list lets callers and tests drive the
            parser without touching ``sys.argv``.

    Returns:
        Namespace with ``test_csv``, ``test_image_dir``, ``experiment_dir``,
        ``output``, ``batch_size`` (each None when not given) and ``no_tta``
        (bool). Unrecognized arguments are ignored rather than causing
        ``SystemExit: 2``; notebook kernels (Kaggle/Colab/Jupyter) inject
        extras such as ``-f /tmp/xxx.json``.
    """
    parser = argparse.ArgumentParser(description="CSIRO v4 CrossPVT T2T Mamba Inference")

    parser.add_argument(
        "--test-csv",
        type=str,
        default=None,
        help="测试集 CSV 路径（默认: INF_CFG.TEST_CSV）",
    )
    parser.add_argument(
        "--test-image-dir",
        type=str,
        default=None,
        help="测试图像目录（默认: INF_CFG.TEST_IMAGE_DIR）",
    )
    parser.add_argument(
        "--experiment-dir",
        type=str,
        default=None,
        help="实验目录（checkpoint 所在位置，默认: INF_CFG.EXPERIMENT_DIR）",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="输出文件路径（默认: INF_CFG.SUBMISSION_FILE）",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=None,
        help="批次大小（默认: INF_CFG.BATCH_SIZE）",
    )
    parser.add_argument(
        "--no-tta",
        action="store_true",
        help="禁用 TTA",
    )

    # parse_known_args swallows the extra arguments notebook kernels add.
    args, _unknown = parser.parse_known_args(argv)
    return args


def main():
    """Script entry point: apply CLI overrides, run inference, write the submission.

    Mutates the module-level ``INF_CFG`` with any CLI overrides. GC and CUDA
    cache cleanup run in a ``finally`` block, so they also happen when
    inference fails.

    Raises:
        FileNotFoundError: if the test CSV does not exist.
    """
    args = parse_args()

    # Apply CLI overrides (falsy values such as None or 0 keep the configured default).
    if args.test_csv:
        INF_CFG.TEST_CSV = args.test_csv
    if args.test_image_dir:
        INF_CFG.TEST_IMAGE_DIR = args.test_image_dir
    if args.experiment_dir:
        INF_CFG.EXPERIMENT_DIR = args.experiment_dir
        # Rebuild both supported checkpoint layouts: fold_{k}/ and fold{k}/.
        INF_CFG.CKPT_PATTERN_FOLD_X = os.path.join(
            INF_CFG.EXPERIMENT_DIR, "fold_{fold}", "checkpoints", "best_wr2.pt"
        )
        INF_CFG.CKPT_PATTERN_FOLDX = os.path.join(
            INF_CFG.EXPERIMENT_DIR, "fold{fold}", "checkpoints", "best_wr2.pt"
        )
    if args.output:
        INF_CFG.SUBMISSION_FILE = args.output
    if args.batch_size:
        INF_CFG.BATCH_SIZE = args.batch_size
    if args.no_tta:
        INF_CFG.USE_TTA = False

    print("=" * 80)
    print("CSIRO Image2Biomass - v4 CrossPVT T2T Mamba Inference")
    print("=" * 80)
    print(f"测试 CSV: {INF_CFG.TEST_CSV}")
    print(f"测试图像目录: {INF_CFG.TEST_IMAGE_DIR}")
    print(f"实验目录: {INF_CFG.EXPERIMENT_DIR}")
    print(f"输出文件: {INF_CFG.SUBMISSION_FILE}")
    print(f"批次大小: {INF_CFG.BATCH_SIZE}")
    print(f"使用 TTA: {INF_CFG.USE_TTA}")

    print("\n加载测试数据...")
    if not os.path.exists(INF_CFG.TEST_CSV):
        raise FileNotFoundError(f"测试 CSV 不存在: {INF_CFG.TEST_CSV}")

    # The test CSV has one row per (image, target); inference runs once per image.
    test_long = pd.read_csv(INF_CFG.TEST_CSV)
    test_unique = test_long.drop_duplicates(subset=["image_path"]).reset_index(drop=True)
    print(f"✓ 找到 {len(test_unique)} 张独立测试图像")
    print(f"  总测试样本数: {len(test_long)}")

    try:
        final_pred = run_inference(test_unique, INF_CFG.TEST_IMAGE_DIR)
        create_submission(final_pred, test_long, test_unique)

        print("\n" + "=" * 80)
        print("推理完成！")
        print("=" * 80)
    finally:
        # Release host/GPU memory even if inference raised (matters in notebooks,
        # where the process keeps running after a failure).
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()


  data = fetch_version_info()


Device: cuda
Experiment Dir: /kaggle/input/csiro/pytorch/default/12
CSIRO Image2Biomass - v4 CrossPVT T2T Mamba Inference
测试 CSV: /kaggle/input/csiro-biomass/test.csv
测试图像目录: /kaggle/input/csiro-biomass/test
实验目录: /kaggle/input/csiro/pytorch/default/12
输出文件: submission.csv
批次大小: 1
使用 TTA: True

加载测试数据...
✓ 找到 1 张独立测试图像
  总测试样本数: 5

开始推理

加载模型 (5-fold)...

加载 checkpoint: best_wr2.pt


✅ 使用 DINO 主干: vit_base_patch14_dinov2 | global_pool=token | feat_dim=768 | input_res=518


  ✓ 模型加载成功 | backbone=vit_base_patch14_dinov2 | input_res=518
  输入分辨率: 518

加载 checkpoint: best_wr2.pt


✅ 使用 DINO 主干: vit_base_patch14_dinov2 | global_pool=token | feat_dim=768 | input_res=518


  ✓ 模型加载成功 | backbone=vit_base_patch14_dinov2 | input_res=518

加载 checkpoint: best_wr2.pt


✅ 使用 DINO 主干: vit_base_patch14_dinov2 | global_pool=token | feat_dim=768 | input_res=518


  ✓ 模型加载成功 | backbone=vit_base_patch14_dinov2 | input_res=518

加载 checkpoint: best_wr2.pt


✅ 使用 DINO 主干: vit_base_patch14_dinov2 | global_pool=token | feat_dim=768 | input_res=518


  ✓ 模型加载成功 | backbone=vit_base_patch14_dinov2 | input_res=518

加载 checkpoint: best_wr2.pt


✅ 使用 DINO 主干: vit_base_patch14_dinov2 | global_pool=token | feat_dim=768 | input_res=518


  ✓ 模型加载成功 | backbone=vit_base_patch14_dinov2 | input_res=518

✓ 成功加载 5 个模型

使用 TTA: 3 个视角

--- TTA 视角 1/3: original ---


                                                           


--- TTA 视角 2/3: hflip ---


                                                           


--- TTA 视角 3/3: vflip ---


                                                           


✓ TTA 完成，最终预测形状: (1, 5)

生成提交文件

✓ 提交文件已保存: submission.csv
  样本数: 5
  预测统计:
    Dry_Green_g:   mean=25.12, std=0.00, min=25.12, max=25.12
    Dry_Dead_g:    mean=31.38, std=0.00, min=31.38, max=31.38
    Dry_Clover_g:  mean=6.09, std=0.00, min=6.09, max=6.09
    GDM_g:         mean=31.21, std=0.00, min=31.21, max=31.21
    Dry_Total_g:   mean=62.60, std=0.00, min=62.60, max=62.60

前 10 行预览:
                    sample_id     target
0  ID1001187975__Dry_Clover_g   6.090722
1    ID1001187975__Dry_Dead_g  31.381250
2   ID1001187975__Dry_Green_g  25.123957
3   ID1001187975__Dry_Total_g  62.595928
4         ID1001187975__GDM_g  31.214682

推理完成！
