In [None]:
import torch

# Biomass target columns. The model's "biomass" output and the weight vector
# below are indexed positionally in this order.
BIOMASS_COLS = [
    "Dry_Clover_g",
    "Dry_Dead_g",
    "Dry_Green_g",
    "Dry_Total_g",
    "GDM_g",
]

# Per-target weights, paired positionally with BIOMASS_COLS (must match your loss weights).
# NOTE(review): confirm this order/weighting also matches the evaluation metric's weights.
BIOMASS_W = torch.tensor([
    0.1,  # Dry_Clover_g
    0.1,  # Dry_Dead_g
    0.1,  # Dry_Green_g
    0.2,  # Dry_Total_g
    0.5,  # GDM_g
])

def _inv_transform(col: str, x: torch.Tensor, state: dict) -> torch.Tensor:
    """
    x: tensor on standardized (and maybe log1p) scale -> returns raw scale tensor
    """
    stats = state["stats"]
    cfg   = state.get("target_cfg", {})

    mu, sig = stats[col]
    x = x * sig + mu

    if cfg.get(col, {}).get("log1p", False):
        x = torch.expm1(x)
        x = torch.clamp(x, min=0.0)

    return x

@torch.inference_mode()
def compute_biomass_r2_raw(model, data_loader, device, state, eps=1e-8):
    """
    Compute the weighted mean of per-target R^2 for the biomass targets, on the raw scale.

    Predictions and targets are first mapped back from the transformed space
    with ``_inv_transform``, one ``BIOMASS_COLS`` column at a time. R^2 is
    therefore measured in the original target units. Results are gathered on
    the CPU, and R^2 is computed once over the whole loader, not averaged per
    batch.

    Args:
        model: called as ``model(images)``. It must return a mapping whose
            ``"biomass"`` entry has one column per ``BIOMASS_COLS`` entry, in
            that order. If the model exposes ``training`` (as ``nn.Module``
            does), its train/eval mode is restored on exit.
        data_loader: iterable of batches with ``"img_t"`` (model input) and
            ``"y"`` (targets in transformed space; columns 2:7 are read as the
            biomass targets). NOTE(review): presumably columns 0-1 are other
            targets; confirm.
        device: device that images and targets are moved to.
        state: normalisation state forwarded to ``_inv_transform``.
        eps: added to the R^2 and weight-sum denominators to avoid division by zero.

    Returns:
        0-dim tensor ``sum_j(w_j * R2_j) / (sum_j(w_j) + eps)`` with
        ``w = BIOMASS_W``, where ``R2_j = 1 - SS_res_j / (SS_tot_j + eps)``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    def to_raw(t):
        # Inverse-transform each biomass column and reassemble (B, n_targets).
        return torch.stack(
            [_inv_transform(col, t[:, i], state) for i, col in enumerate(BIOMASS_COLS)],
            dim=1,
        )

    # Remember the caller's mode. Otherwise evaluation would leave the model in
    # eval mode (dropout / batch-norm frozen) for any training that follows.
    was_training = getattr(model, "training", False)
    model.eval()

    pred_chunks = []
    true_chunks = []
    try:
        for batch in data_loader:
            images = batch["img_t"].to(device)
            y_true_bio = batch["y"].to(device)[:, 2:7]   # biomass slice, transformed space
            y_pred_bio = model(images)["biomass"]        # transformed space

            pred_chunks.append(to_raw(y_pred_bio).cpu())
            true_chunks.append(to_raw(y_true_bio).cpu())
    finally:
        if was_training:
            model.train()

    if not pred_chunks:
        raise ValueError("compute_biomass_r2_raw: data_loader yielded no batches")

    y_pred = torch.cat(pred_chunks, dim=0)   # (N, n_targets), raw scale
    y_true = torch.cat(true_chunks, dim=0)   # (N, n_targets), raw scale

    # Per-column R^2, vectorised over the target dimension.
    ss_tot = ((y_true - y_true.mean(dim=0)) ** 2).sum(dim=0)
    ss_res = ((y_true - y_pred) ** 2).sum(dim=0)
    r2s = 1.0 - ss_res / (ss_tot + eps)

    w = BIOMASS_W.to(dtype=y_true.dtype)
    return (w * r2s).sum() / (w.sum() + eps)
