In [8]:
import pandas as pd
import numpy as np
from PIL import Image

import importlib, csiro_biomass_shared_utils as csiro_su
importlib.reload(csiro_su)

from pathlib import Path
from typing import Optional, Dict, Any, Sequence

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode

from __future__ import annotations
from dataclasses import dataclass

# Wide, unwrapped DataFrame printing: keep each frame in one block, allow
# 200 characters per line, and never hide columns.
pd.set_option(
    "display.expand_frame_repr", False,
    "display.width", 200,
    "display.max_columns", None,
)

In [9]:
from pathlib import Path
import pandas as pd

def load_split_dfs(in_dir="dataset/"):
    """Read the train / val / test parquet splits stored under ``in_dir``.

    Args:
        in_dir: Directory (str or path-like) containing ``train.parquet``,
            ``val.parquet`` and ``test.parquet``.

    Returns:
        Tuple ``(train_df, val_df, test_df)``, in that order.
    """
    root = Path(in_dir)
    # Files are read in the same order as they are returned.
    train_df, val_df, test_df = (
        pd.read_parquet(root / f"{split}.parquet")
        for split in ("train", "val", "test")
    )
    return train_df, val_df, test_df

# Load the three splits and report their row counts.
train_df, val_df, test_df = load_split_dfs("dataset/")
print(len(train_df), len(val_df), len(test_df))


# Columns handed to csiro_su._assert_has_columns for every split.
cols = [
    "Pre_GSHH_NDVI",
    "Height_Ave_cm",
    "Dry_Clover_g",
    "Dry_Dead_g",
    "Dry_Green_g",
    "Dry_Total_g",
    "GDM_g",
]

# Same check, same order (train, val, test); the label is the name passed to the helper.
for split_name, split_df in (("train_df", train_df), ("val_df", val_df), ("test_df", test_df)):
    csiro_su._assert_has_columns(split_df, cols, split_name)

# print(train_df)
# print("All good")


255 50 51


In [10]:
# # tfms_train = build_image_transforms("train", target_h=256, target_w=512)
# # csiro_su.show_original_vs_transformed(train_df, "/kaggle/input/csiro-biomass/train", tfms_train, n=3, font_size=9)

# def collate_with_original(batch):
#     images = torch.stack([b["image"] for b in batch], dim=0)
#     ids = [b["sample_id"] for b in batch]

#     out = {"image": images, "sample_id": ids}

#     if "orig" in batch[0]:
#         out["orig"] = [b["orig"] for b in batch]  # list of [3,H,W]

#     return out

# ds = Image2BiomassData(train_df, "/kaggle/input/csiro-biomass/train",
#                        mode="train", return_original=True)

# dl = DataLoader(ds, batch_size=8, shuffle=True, num_workers=2,
#                 collate_fn=collate_with_original)

# csiro_su.show_dl_batch_original_vs_transformed(dl)

## Model to predict from the image:

* **Biomass (regression):** `Dry_Clover_g, Dry_Dead_g, Dry_Green_g, Dry_Total_g, GDM_g`
* **Aux targets to “boost” backbone (multi-task):**

  * `Pre_GSHH_NDVI` (regression)
  * `Height_Ave_cm` (regression)
  * `Species` (**classification**, not regression)

And **everything else** can be returned as **metadata** (for debugging / weighting / stratification), but not trained as targets.

Auxiliary tasks (NDVI/Height/Species) can help the backbone learn better visual features, which often improves the main biomass R².

#### About "other columns as metadata to boost R²"

Yes — keep them in `meta` for:

* tail weighting (`TailFlag`)
* regime-aware sampling (`Cbin/Tbin/...`)
* analysis/debug plots
* sanity checks

…but **don’t train to predict them** (it’s usually just noise/overhead).

#### Dataset that returns (image → multi-target) + metadata

* **Image**

  * `out["image"] = img_t` is the **transformed** image tensor (resize/augment/normalize).
  * `out["orig"] = orig` is included **only if** `self.return_original=True` (and it’s the *original* image tensor you saved for visualization/debug).

* **Targets / labels**

  * `out["y_biomass"]` = the 5 biomass targets: `["Dry_Clover_g","Dry_Dead_g","Dry_Green_g","Dry_Total_g","GDM_g"]` (possibly log/standardized depending on how you created `y_biomass`)
  * `out["y_ndvi"]`, `out["y_height"]`, `out["y_species"]` = whatever you computed (raw or transformed/encoded).

* **IDs / metadata**

  * `out["sample_id"]` = identifier
  * `out["meta"] = row.to_dict()` is included **only if** `self.return_meta=True`. It contains the full row (all columns), so yes it can carry “auxiliary” info — **but it’s not automatically used to boost R²** unless your model forward/loss actually consumes those meta fields.

So: **the dataset returns everything you need**, but “meta boosts R²” happens only if you explicitly feed some/all of it into the model (or use it for sampling/weights).

A clean way to do this:
* **Dataset = “raw fetcher”** (reads image + row/meta, no transforms)
* **One “batch preprocessor” function = does *all* transforms** (image + targets + keeps `meta`, `orig`, `sample_id`)
* Plug that preprocessor into the **DataLoader via `collate_fn`**.



In [11]:
from pathlib import Path
from typing import Any, Dict
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import cv2

# ---------- small helpers ----------
def _to_float(x):
    try:
        return float(x)
    except Exception:
        return float("nan")

import numpy as np
import pandas as pd

def compute_tail_q_from_train_df(train_df: pd.DataFrame, target_cfg: dict, stats: dict, qs=(0.90, 0.99), eps: float = 1e-8):
    """Compute per-column quantiles on the transformed (and optionally standardized) scale.

    For every column named in ``target_cfg`` the values are coerced to numeric
    (unparseable entries become NaN), optionally passed through ``log1p`` after
    clipping at 0, optionally standardized with ``stats[col] == (mu, sigma)``,
    and then reduced to the requested quantiles with NaNs ignored.

    Args:
        train_df: DataFrame containing every column named in ``target_cfg``.
        target_cfg: ``{column: {"log1p": bool, "normalize": bool}}``; a missing
            ``"log1p"`` defaults to ``False`` and a missing ``"normalize"`` to ``True``.
        stats: ``{column: (mu, sigma)}``; only read for columns that are normalized.
        qs: Quantile levels in ``[0, 1]``.
        eps: Lower bound used in place of ``sigma`` when ``sigma <= eps``.

    Returns:
        ``{column: {"q<percent>": value}}``. Whole-percent levels are labelled
        ``q50``, ``q99``, ...; fractional levels keep their fraction
        (``0.995 -> "q99.5"``) so they cannot collide with a neighbouring key.
    """

    def _quantile_label(q) -> str:
        # q * 100 is not exact in floating point (0.29 * 100 == 28.999999999999996),
        # so truncating with int() would mislabel the key; round first.
        pct = round(float(q) * 100.0, 6)
        if pct == int(pct):
            return f"q{int(pct)}"
        return f"q{pct:g}"

    qs = list(qs)
    tail_q = {}

    for col, cfg in target_cfg.items():
        # Series -> float32 array; NaNs are possible after coercion.
        values = pd.to_numeric(train_df[col], errors="coerce").values.astype(np.float32)

        # 1) log1p (if configured); negatives are clipped to 0 first
        if cfg.get("log1p", False):
            values = np.log1p(np.clip(values, 0.0, None))

        # 2) normalize (if configured)
        if cfg.get("normalize", True):
            mu, sigma = stats[col]
            sigma = float(sigma) if float(sigma) > eps else eps
            values = (values - float(mu)) / sigma

        # 3) quantiles (ignore NaNs)
        qvals = np.nanquantile(values, qs)

        tail_q[col] = {_quantile_label(q): float(v) for q, v in zip(qs, qvals)}

    return tail_q


def col_mean_std_from_series(s: pd.Series, ddof: int = 0, eps: float = 1e-8):
    """Return ``(mean, std)`` of the numeric entries of ``s``.

    Entries that cannot be parsed as numbers are dropped together with NaNs.

    Args:
        s: Input series.
        ddof: Delta degrees of freedom forwarded to ``Series.std``.
        eps: Value returned as the std when the computed std is non-finite
            or smaller than ``eps``.

    Raises:
        ValueError: If no numeric value remains after coercion.
    """
    numeric = pd.to_numeric(s, errors="coerce").dropna()
    if numeric.empty:
        raise ValueError("No valid numeric values to compute mean/std.")

    spread = float(numeric.std(ddof=ddof))
    # A NaN/inf or near-zero std would break later division; substitute eps.
    usable = np.isfinite(spread) and spread >= eps
    return float(numeric.mean()), (spread if usable else eps)

class Image2BiomassData(Dataset):
    """Map-style dataset returning one transformed image plus encoded targets and metadata per row.

    Each item is a dict holding:
      * ``img_t``      - image tensor after resize, train-only augmentation and ImageNet normalization
      * ``sample_id``  - the row's ``sample_id`` as ``str``
      * ``species_id`` - ``torch.long`` scalar looked up in ``species2id`` (unseen species -> ``__UNK__``)
      * one ``torch.float32`` scalar per key of ``target_cfg`` (optional ``log1p``, optional standardization)
      * every column of ``meta_cols`` that exists in the dataframe, copied untouched

    The fitted pieces (``stats``, ``species2id``, ``target_cfg``, ``tail_q``) are either computed from
    the dataframe given to the constructor or taken from a previously exported ``state``.
    """

    def __init__(self, dataframe: pd.DataFrame, images_dir: str,
                 mode: str = "train", export_state: bool = False,
                 import_state: bool = False, state: dict | None = None):
        """Build the dataset and either fit or import its normalization state.

        Args:
            dataframe: Rows to serve; it is copied and re-indexed from 0.
            images_dir: Directory containing ``<sample_id>.jpg`` files.
            mode: ``"train"`` enables random augmentation; any other value uses
                the deterministic resize + normalize pipeline.
            export_state: If True, a snapshot of ``get_state()`` is kept in ``self._state``.
            import_state: If True, reuse ``state`` instead of fitting on ``dataframe``.
            state: Dict with ``"stats"`` and ``"species2id"`` (required) and
                ``"target_cfg"`` / ``"tail_q"`` (optional); only read when
                ``import_state`` is True.

        Raises:
            ValueError: If ``import_state`` is True and ``state`` is None.
        """
        self.df                 = dataframe.reset_index(drop=True).copy()
        self.images_dir         = images_dir
        self.mode               = mode
        self.stats              = {}

        # fixed columns (no extra init args)
        # sample_id Species  Pre_GSHH_NDVI  Height_Ave_cm  Dry_Clover_g  Dry_Dead_g  Dry_Green_g  Dry_Total_g GDM_g
        # SpeciesBucket Tbin Gbin Dbin Cbin Rbin  GroupID  TailFlag  StratifyKey StratifyKey_safe
        self.id_col         = "sample_id"
        self.species_col    = "Species"
        self.numeric_cols   = ["Pre_GSHH_NDVI", "Height_Ave_cm", "Dry_Clover_g", "Dry_Dead_g", "Dry_Green_g", "Dry_Total_g", "GDM_g"]
        self.meta_cols      = ["SpeciesBucket", "Tbin", "Gbin", "Dbin", "Cbin", "Rbin", "GroupID", "TailFlag", "StratifyKey", "StratifyKey_safe"]

        # Per-column transform: optional log1p followed by optional standardization.
        self.target_cfg = {
            "Pre_GSHH_NDVI": {"log1p": False, "normalize": True},
            "Height_Ave_cm": {"log1p": True,  "normalize": True},
            "Dry_Clover_g":  {"log1p": True,  "normalize": True},
            "Dry_Dead_g":    {"log1p": True,  "normalize": True},
            "Dry_Green_g":   {"log1p": True,  "normalize": True},
            "Dry_Total_g":   {"log1p": True,  "normalize": True},
            "GDM_g":         {"log1p": True,  "normalize": True},
        }

        # ---- IMPORT STATE (no refit) ----
        if import_state:
            if state is None:
                raise ValueError("import_state=True requires a non-None `state`.")
            self.stats      = state["stats"]
            self.species2id = state["species2id"]
            self.target_cfg = state.get("target_cfg", self.target_cfg)
            self.tail_q      = state.get("tail_q", None)
        else:
            # ---- FIT ON THIS DATAFRAME (train) ----
            # species mapping: id 0 is reserved for unknown species, the rest are sorted names
            uniq = sorted(self.df[self.species_col].astype(str).unique().tolist())
            self.species2id = {"__UNK__": 0, **{s: i + 1 for i, s in enumerate(uniq)}}

            # 1) fit stats (needed for normalization), on the log1p scale where configured
            for col, cfg in self.target_cfg.items():
                s = pd.to_numeric(self.df[col], errors="coerce")
                if cfg.get("log1p", False):
                    # log1p requires x >= -1; negatives are clipped to 0 first
                    s = np.log1p(np.clip(s, 0, None))
                mu, sigma = col_mean_std_from_series(s)
                self.stats[col] = (mu, sigma)

            # 2) compute tail quantiles on transformed+standardized scale
            self.tail_q = compute_tail_q_from_train_df(self.df, self.target_cfg, self.stats,
                                                       qs=[0.50, 0.75, 0.90, 0.95, 0.99])
            print(self.tail_q)

        # ---- optionally store exportable state ----
        self._state = self.get_state() if export_state else None

        IMAGENET_MEAN = [0.485, 0.456, 0.406]
        IMAGENET_STD  = [0.229, 0.224, 0.225]

        # build image tfms ONCE (do not recreate per sample)
        TARGET_H, TARGET_W = 256, 512
        if self.mode == "train":
            self.img_tfms = T.Compose([
                T.ToTensor(),
                T.Resize((TARGET_H, TARGET_W), interpolation=InterpolationMode.BICUBIC),

                T.RandomHorizontalFlip(p=0.5),
                T.RandomVerticalFlip(p=0.5),

                T.RandomApply([T.RandomRotation(
                    degrees=5,
                    interpolation=InterpolationMode.BILINEAR,
                    fill=0
                )], p=0.25),

                T.RandomApply([T.ColorJitter(
                    brightness=0.2, contrast=0.2, saturation=0.15, hue=0.03
                )], p=0.25),

                T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ])
        else:
            self.img_tfms = T.Compose([
                T.ToTensor(),
                T.Resize((TARGET_H, TARGET_W), interpolation=InterpolationMode.BICUBIC),
                T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ])

    def get_state(self) -> dict:
        """Return the fitted state so another dataset can reuse it via ``import_state=True``.

        The returned dict references (does not copy) the underlying objects.
        """
        return {
            "stats": self.stats,
            "species2id": self.species2id,
            "target_cfg": self.target_cfg,
            "tail_q": self.tail_q
        }

    def _transform_and_standardize(self, col: str, x: float) -> float:
        """Apply the configured ``log1p`` / standardization for ``col`` to a single value."""
        cfg = self.target_cfg[col]

        # 1) transform (negative inputs are treated as 0 before log1p)
        if cfg.get("log1p", False):
            x = np.log1p(max(x, 0.0))

        # 2) optional standardization
        if cfg.get("normalize", True):
            mu, sigma = self.stats[col]
            x = (x - mu) / sigma

        return float(x)

    def __len__(self):
        """Number of rows in the dataframe."""
        return len(self.df)

    def _img_path(self, sid: str) -> Path:
        """Return the path of the JPEG for sample ``sid`` inside ``images_dir``."""
        return Path(self.images_dir) / f"{sid}.jpg"

    def _transform_image(self, img_path: Path) -> torch.Tensor:
        """Read an image with OpenCV, convert BGR -> RGB and run ``self.img_tfms`` on it.

        Args:
            img_path: ``Path`` or ``str``; it is converted with ``str()`` before reading.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        img_bgr = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if img_bgr is None:
            raise FileNotFoundError(f"Image not found or unreadable: {img_path}")
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        return self.img_tfms(img_rgb)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the sample dict for row ``idx`` (keys are listed in the class docstring)."""
        ## -- Get the right index ---
        row = self.df.iloc[idx]
        sid = str(row[self.id_col])

        img_t = self._transform_image(self._img_path(sid))

        out = {"img_t": img_t}

        ## -- sample_id : nothing changes --
        out.update({ "sample_id": sid})

        # species -> id (use precomputed mapping); 0 mirrors the id given to
        # "__UNK__" at fit time in case an imported mapping lacks that key
        sp = str(row[self.species_col])
        unk_id = self.species2id.get("__UNK__", 0)
        sp_id = self.species2id.get(sp, unk_id)
        out["species_id"] = torch.tensor(sp_id, dtype=torch.long)

        ## All the remaining columns:
        # Pre_GSHH_NDVI  Height_Ave_cm  Dry_Clover_g  Dry_Dead_g  Dry_Green_g  Dry_Total_g GDM_g
        for c in self.target_cfg.keys():
            out[c] = torch.tensor(self._transform_and_standardize(c, _to_float(row[c])), dtype=torch.float32)

        # copy remaining/meta columns AS-IS (no transforms)
        for k in self.meta_cols:
            if k in row.index:
                out[k] = row[k]

        return out


In [12]:
data_root_dir = "/kaggle/input/csiro-biomass"

# Fit the target statistics and species mapping on the training split only.
train_ds = Image2BiomassData(train_df, f"{data_root_dir}/train", mode="train", export_state=True)
state = train_ds.get_state()
# print(state)

# Validation and test import the train-fitted state (no refit). All three
# datasets read their images from the same train/ folder.
# `state` is intentionally not reassigned: it stays the state fitted on train_ds.
val_ds = Image2BiomassData(val_df, f"{data_root_dir}/train", mode="val", import_state=True, state=state)
test_ds = Image2BiomassData(test_df, f"{data_root_dir}/train", mode="val", import_state=True, state=state)

{'Pre_GSHH_NDVI': {'q50': 0.21079139411449432, 'q75': 0.7403659224510193, 'q90': 1.2037436962127686, 'q95': 1.3361376523971558, 'q99': 1.4632354521751414}, 'Height_Ave_cm': {'q50': -0.2719688415527344, 'q75': 0.3856850862503052, 'q90': 1.4404025077819824, 'q95': 2.3810505628585807, 'q99': 2.7710587978363037}, 'Dry_Clover_g': {'q50': -0.2647244334220886, 'q75': 0.7284159064292908, 'q90': 1.5074969053268432, 'q95': 1.8269424676895119, 'q99': 2.515016226768495}, 'Dry_Dead_g': {'q50': 0.05714616924524307, 'q75': 0.7739342451095581, 'q90': 1.264405083656311, 'q95': 1.4368308544158934, 'q99': 1.7236804413795472}, 'Dry_Green_g': {'q50': 0.21901564300060272, 'q75': 0.6587365567684174, 'q90': 1.1036961793899536, 'q95': 1.3031319260597227, 'q99': 1.667626738548279}, 'Dry_Total_g': {'q50': 0.10779093205928802, 'q75': 0.6682067811489105, 'q90': 1.2109348773956299, 'q95': 1.5225960612297054, 'q99': 1.9445340704917924}, 'GDM_g': {'q50': 0.1054128035902977, 'q75': 0.6984666883945465, 'q90': 1.1868004

In [13]:
from torch.utils.data import DataLoader

      
# Order of the target columns inside the stacked ``y`` tensor.
TARGET_COLS = [
    "Pre_GSHH_NDVI",
    "Height_Ave_cm",
    "Dry_Clover_g",
    "Dry_Dead_g",
    "Dry_Green_g",
    "Dry_Total_g",
    "GDM_g",
]


def collate_biomass(batch):
    """Merge a list of per-sample dicts into one batch dict.

    Returns a dict with:
      * ``img_t``      - images stacked along a new leading batch dimension
      * ``sample_id``  - list of ids, one per sample
      * ``species_id`` - stacked species ids, shape ``[B]`` for scalar inputs
      * ``y``          - float tensor ``[B, len(TARGET_COLS)]`` with columns in ``TARGET_COLS`` order
      * every other key of the first sample, gathered into a plain list
    """

    def _gather(key):
        return [sample[key] for sample in batch]

    collated = {}
    collated["img_t"] = torch.stack(_gather("img_t"), dim=0)
    collated["sample_id"] = _gather("sample_id")
    collated["species_id"] = torch.stack(_gather("species_id"), dim=0)

    # One row of targets per sample, then stack the rows -> [B, len(TARGET_COLS)]
    target_rows = []
    for sample in batch:
        target_rows.append(torch.stack([sample[name] for name in TARGET_COLS]))
    collated["y"] = torch.stack(target_rows, dim=0).float()

    # Anything not consumed above is metadata (strings/ints): keep it as a list.
    consumed = {"img_t", "species_id", "sample_id", *TARGET_COLS}
    for key in batch[0].keys():
        if key not in consumed:
            collated[key] = _gather(key)

    return collated

In [14]:
BATCH_SIZE = 32

# Settings shared by all three loaders; only the training loader shuffles.
_common_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_biomass,
)

train_dl = DataLoader(train_ds, shuffle=True, **_common_loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **_common_loader_kwargs)
test_dl = DataLoader(test_ds, shuffle=False, **_common_loader_kwargs)

# b = next(iter(train_dl))
# print("img_t:", b["img_t"].shape, b["img_t"].dtype)
# print("y:", b["y"].shape, b["y"].dtype)
# print("species_id:", b["species_id"].shape, b["species_id"].dtype)
# print("sample_id:", len(b["sample_id"]), b["sample_id"][0])
# print("Training & Validation dataloaders created !")
# print(b)
