In [2]:
import logging
import os
from pathlib import Path

import pandas as pd
import xarray as xr

In [4]:
# Load the NetCDF datacube lazily with dask chunking (no data is read yet).
# The path is relative to the repository root (the kernel's working directory
# is assumed to be the repo root — adjust if the notebook runs elsewhere) and
# can be overridden with the IBERFIRE_PATH environment variable instead of
# hardcoding a machine-specific absolute path.
IBERFIRE_PATH = Path(os.environ.get("IBERFIRE_PATH", "data/IberFire.nc"))
ds = xr.open_dataset(IBERFIRE_PATH, chunks="auto")

In [5]:
# Bare last expression: Jupyter renders xarray's rich (HTML) repr instead of plain text.
ds

<xarray.Dataset> Size: 731GB
Dimensions:                                        (y: 920, x: 1188, time: 6241)
Coordinates:
  * x                                              (x) float64 10kB 2.675e+06...
  * y                                              (y) float64 7kB 2.492e+06 ...
  * time                                           (time) datetime64[ns] 50kB ...
Data variables: (12/261)
    x_index                                        (y, x) uint16 2MB dask.array<chunksize=(920, 1188), meta=np.ndarray>
    y_index                                        (y, x) uint16 2MB dask.array<chunksize=(920, 1188), meta=np.ndarray>
    is_spain                                       (y, x) uint16 2MB dask.array<chunksize=(920, 1188), meta=np.ndarray>
    is_fire                                        (time, y, x) uint8 7GB dask.array<chunksize=(1562, 230, 298), meta=np.ndarray>
    is_near_fire                                   (time, y, x) uint8 7GB dask.array<chunksize=(1562, 230, 298), meta=

In [None]:
"""
Preprocess IberFire.nc into sharded ConvLSTM-ready samples without loading the cube in memory.

Outputs
- data/convlstm/{train,val,test}/shard_{k:06d}.npz (shard number zero-padded to six digits) with:
  - X: [N, T, C, H, W] float32
  - y: [N, H, W] uint8 (next-day fire mask)
- data/convlstm/{split}/manifest.parquet
- data/convlstm/stats.json (per-variable mean/std computed on train split)
- data/convlstm/config.json (pipeline config for training notebook)
"""

import argparse
import json
from pathlib import Path
from typing import Dict, List, Tuple, Iterable

import dask
import numpy as np
import pandas as pd
import xarray as xr

# ----------------------
# Config
# ----------------------

# Default pipeline configuration. The CLI entry point copies it and overrides a
# subset of keys; ``run`` writes the configuration it received to config.json.
DEFAULT_CONFIG = dict(
    # --- Input ---------------------------------------------------------------
    input_path="data/IberFire.nc",  # or data/iberfire_catalonia.nc
    # Feature variables, in channel order. Names absent from the dataset are
    # skipped with a warning; add/remove as needed.
    feature_vars=[
        "LST",      # Land Surface Temperature
        "SWI_001",  # Soil Water Index layers (examples)
        "SWI_010",
        "SWI_020",
        "FWI",      # Fire Weather Index
    ],
    label_var="is_fire",
    # --- Time windowing ------------------------------------------------------
    seq_len=7,   # T: number of input days per sample
    horizon=1,   # label day = last input day + horizon (1 -> next day)
    stride=1,    # days between consecutive window starts
    # --- Spatial tiling ------------------------------------------------------
    tile_h=128,
    tile_w=128,
    # --- Dask chunk sizes used when opening / re-chunking the dataset -------
    chunks=dict(time=256, y=256, x=256),
    # --- Date range per split (both ends inclusive) --------------------------
    splits=dict(
        train=["2008-01-01", "2018-12-31"],
        val=["2019-01-01", "2019-12-31"],
        test=["2020-01-01", "2021-12-31"],
    ),
    # --- Sharding ------------------------------------------------------------
    samples_per_shard=32,
    out_dir="data/convlstm",
    # --- NaN handling: value substituted after normalization -----------------
    nan_fill=0.0,
)

# ----------------------
# Helpers
# ----------------------

def open_ds(path: str, chunks: Dict[str, int]) -> xr.Dataset:
    """Open a NetCDF datacube lazily (dask-backed) and check its core dimensions.

    Args:
        path: Path to the NetCDF file.
        chunks: Dask chunk sizes per dimension; an empty/None value falls back
            to xarray's ``"auto"`` chunking.

    Returns:
        The lazily opened dataset.

    Raises:
        ValueError: If ``time``, ``y`` or ``x`` is not a dimension of the dataset.
    """
    ds = xr.open_dataset(path, chunks=chunks or "auto")
    missing = [dim for dim in ("time", "y", "x") if dim not in ds.dims]
    if missing:
        # Release the underlying file handle before bailing out; the caller
        # never receives ``ds`` and so could not close it.
        ds.close()
        raise ValueError(f"Dataset missing required dim: {missing[0]}")
    return ds

def validate_vars(ds: xr.Dataset, feature_vars: List[str], label_var: str) -> List[str]:
    """Return the requested feature variables usable as ConvLSTM channels.

    A feature is usable when it is a data variable whose dimensions are exactly
    ``{time, y, x}``. Anything else is skipped with a warning: missing or extra
    dimensions would not fit the per-sample ``[T, C, H, W]`` layout.

    Args:
        ds: Opened datacube.
        feature_vars: Candidate feature names, in channel order.
        label_var: Name of the label variable.

    Returns:
        The usable feature names, preserving the order of ``feature_vars``.

    Raises:
        ValueError: If the label variable is absent or does not have exactly
            ``time``/``y``/``x`` dims, or if none of the requested features is usable
            (otherwise every sample extraction would fail later).
    """
    required = {"time", "y", "x"}
    present: List[str] = []
    missing: List[str] = []
    for v in feature_vars:
        if v in ds.data_vars and set(ds[v].dims) == required:
            present.append(v)
        else:
            missing.append(v)
    if missing:
        print(f"[warn] Missing feature vars (skipped): {missing}")
    if label_var not in ds.data_vars:
        raise ValueError(f"Label var '{label_var}' not found in dataset.")
    if set(ds[label_var].dims) != required:
        raise ValueError(
            f"Label var '{label_var}' must have dims {sorted(required)}, got {tuple(ds[label_var].dims)}."
        )
    if not present:
        raise ValueError("None of the requested feature vars are usable; nothing to preprocess.")
    return present

def time_index_for_range(time_index: xr.DataArray, start: str, end: str) -> np.ndarray:
    """Positions of the timestamps that fall inside the closed interval [start, end].

    Args:
        time_index: Time coordinate whose ``.values`` are datetime-like.
        start: Inclusive lower bound (any string accepted by ``np.datetime64``).
        end: Inclusive upper bound.

    Returns:
        1-D integer array of positions into ``time_index``.
    """
    stamps = pd.to_datetime(time_index.values)
    lower = np.datetime64(start)
    upper = np.datetime64(end)
    in_range = (stamps >= lower) & (stamps <= upper)
    return np.flatnonzero(in_range)

def compute_norm_stats(ds: xr.Dataset, vars_: List[str], time_idx: np.ndarray) -> Dict[str, Dict[str, float]]:
    """Per-variable mean/std over the selected time steps, for z-score normalization.

    Statistics cover every (time, y, x) value of the selected time steps and
    ignore NaNs. Values are upcast to float64 before reducing, so accumulating
    billions of float32 values does not lose precision. Mean and std are computed
    in a single ``dask.compute`` call so the data is read once per variable, not
    twice.

    Args:
        ds: Dataset containing ``vars_`` with ``time``/``y``/``x`` dims.
        vars_: Variables to summarize.
        time_idx: Positional time indices to use (the train split in ``run``).

    Returns:
        ``{var: {"mean": float, "std": float}}``. A std that is NaN, infinite or
        <= 1e-6 is replaced by 1.0 (avoids division by zero). A non-finite mean
        (e.g. an all-NaN variable) is replaced by 0.0, so stats.json stays valid JSON.

    Raises:
        ValueError: If ``time_idx`` is empty.
    """
    if len(time_idx) == 0:
        raise ValueError("Cannot compute normalization stats on an empty time selection.")
    dims = ("time", "y", "x")
    dstrain = ds.isel(time=time_idx)  # lazy selection, nothing is loaded yet
    stats: Dict[str, Dict[str, float]] = {}
    for v in vars_:
        arr = dstrain[v].astype("float64")
        # One graph for both reductions: shared chunk reads are executed once.
        mean_da, std_da = dask.compute(
            arr.mean(dim=dims, skipna=True),
            arr.std(dim=dims, skipna=True),
        )
        mean = float(mean_da.item())
        std = float(std_da.item())
        if not np.isfinite(mean):
            mean = 0.0
        if not (np.isfinite(std) and std > 1e-6):
            std = 1.0
        stats[v] = {"mean": mean, "std": std}
    return stats

def iter_time_windows(all_t_idx: np.ndarray, seq_len: int, horizon: int, stride: int) -> Iterable[Tuple[int, int, int]]:
    """
    Yield sliding (input window, label) index triples inside one split.

    Each triple is ``(t_start, t_end_exclusive, t_label)``. The model input
    covers positions ``t_start .. t_end_exclusive - 1``, and the label is
    ``horizon`` steps after the last input step. Only windows whose label is at
    or before the split's last index are produced, so labels never come from
    the following split.

    Args:
        all_t_idx: Ascending positional time indices of the split (as returned
            by ``np.nonzero``); may be empty.
        seq_len: Number of input time steps (T).
        horizon: Lead time of the label after the last input step.
        stride: Step between consecutive window starts.

    Yields:
        ``(t_start, t_end_exclusive, t_label)`` as Python ints.
    """
    if len(all_t_idx) == 0:
        # Nothing to window (the original indexed [-1] and raised IndexError).
        return
    last_label = int(all_t_idx[-1])
    for t0 in all_t_idx[::stride]:
        t0 = int(t0)
        t_end = t0 + seq_len  # exclusive
        t_label = t_end - 1 + horizon
        if t_label > last_label:
            # Starts are ascending, so no later window can fit either.
            break
        yield (t0, t_end, t_label)

def iter_tiles(y_size: int, x_size: int, tile_h: int, tile_w: int) -> Iterable[Tuple[int, int]]:
    """
    Yield top-left ``(y0, x0)`` corners of tiles covering a ``y_size x x_size`` grid.

    Tiles are laid out row-major with steps of ``tile_h`` / ``tile_w``. The last
    tile along each axis is shifted back to end exactly on the grid edge,
    overlapping its neighbour, instead of being cut off. Every tile is then
    full-size whenever the grid spans at least one tile, so samples stack into
    fixed-shape shards. (Partial edge tiles made ``np.stack`` fail when a shard
    mixed tile shapes.) For a grid smaller than one tile the only corner is 0.

    The tile count is unchanged: ceil(y_size / tile_h) * ceil(x_size / tile_w).
    """
    def _starts(size: int, tile: int) -> List[int]:
        # Regular starts, with the final one clamped so the tile ends on the edge.
        last_start = max(size - tile, 0)
        return [min(s, last_start) for s in range(0, size, tile)]

    x_starts = _starts(x_size, tile_w)
    for y0 in _starts(y_size, tile_h):
        for x0 in x_starts:
            yield y0, x0

def extract_window(
    ds: xr.Dataset,
    feature_vars: List[str],
    label_var: str,
    t0: int,
    t_end: int,
    t_label: int,
    y0: int,
    x0: int,
    tile_h: int,
    tile_w: int,
    stats: Dict[str, Dict[str, float]],
    nan_fill: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Materialize one normalized (input sequence, label) sample for a spatial tile.

    Args:
        ds: Dataset holding ``feature_vars`` and ``label_var`` with (time, y, x) dims.
        feature_vars: Feature names; their order defines the channel axis.
        label_var: Label variable name.
        t0, t_end: Input time positions ``[t0, t_end)``.
        t_label: Time position of the label.
        y0, x0: Top-left corner of the tile.
        tile_h, tile_w: Requested tile size.
        stats: ``{var: {"mean": ..., "std": ...}}`` used to z-score each feature.
        nan_fill: Value substituted for NaNs after normalization; also used to
            pad tiles clipped by the grid edge.

    Returns:
        ``X`` float32 of shape ``[T, C, tile_h, tile_w]`` and ``y`` uint8 of shape
        ``[tile_h, tile_w]``. Tiles clipped by the grid edge are padded at the
        bottom/right (X with ``nan_fill``, y with 0), so every sample has the same
        shape and samples can be stacked into a shard.
    """
    # ``ds.dims[...]`` as a mapping is deprecated in recent xarray; ``sizes`` is not.
    sizes = ds.sizes
    ys = slice(y0, min(y0 + tile_h, sizes["y"]))
    xs = slice(x0, min(x0 + tile_w, sizes["x"]))
    ts = slice(t0, t_end)

    # Build every selection lazily first. ``transpose`` pins the axis order no
    # matter how each variable is laid out in the file.
    lazy_feats = []
    for v in feature_vars:
        v_arr = ds[v].isel(time=ts, y=ys, x=xs).transpose("time", "y", "x").astype("float32")
        mean, std = stats[v]["mean"], stats[v]["std"]
        lazy_feats.append(((v_arr - mean) / std).fillna(nan_fill))

    # Fill NaNs *before* the integer cast: casting NaN to uint8 is undefined,
    # and a fillna after the cast would be a no-op.
    lazy_label = (
        ds[label_var]
        .isel(time=t_label, y=ys, x=xs)
        .transpose("y", "x")
        .fillna(0)
        .astype("uint8")
    )

    # A single compute call lets dask schedule all reads of this sample together.
    *feats, label = dask.compute(*lazy_feats, lazy_label)

    X = np.stack([np.asarray(f, dtype=np.float32) for f in feats], axis=1)  # [T, C, h, w]
    y = np.asarray(label, dtype=np.uint8)  # [h, w]

    pad_h = tile_h - X.shape[-2]
    pad_w = tile_w - X.shape[-1]
    if pad_h > 0 or pad_w > 0:
        X = np.pad(X, ((0, 0), (0, 0), (0, pad_h), (0, pad_w)), constant_values=nan_fill)
        y = np.pad(y, ((0, pad_h), (0, pad_w)), constant_values=0)
    return X, y

def write_shard(out_dir: Path, split: str, shard_idx: int, Xs: List[np.ndarray], ys: List[np.ndarray]) -> Path:
    """Stack samples and save them as ``<out_dir>/<split>/shard_<idx:06d>.npz``.

    The archive is first written to a sibling ``.tmp`` file and then renamed
    into place, so an interrupted run never leaves a truncated shard under its
    final name.

    Args:
        out_dir: Root output directory.
        split: Split name, used as the sub-directory.
        shard_idx: Shard number (zero-padded to six digits in the filename).
        Xs: Per-sample inputs, all with one shape (``[T, C, H, W]`` in this pipeline).
        ys: Per-sample labels, all with one shape (``[H, W]`` in this pipeline).

    Returns:
        Path of the written shard.

    Raises:
        ValueError: If there are no samples, ``Xs`` and ``ys`` differ in length,
            or the samples of either list do not share one shape.
    """
    if not Xs or len(Xs) != len(ys):
        raise ValueError(
            f"write_shard needs matching, non-empty sample lists (got {len(Xs)} X, {len(ys)} y)."
        )
    for name, arrays in (("X", Xs), ("y", ys)):
        shapes = {a.shape for a in arrays}
        if len(shapes) > 1:
            raise ValueError(f"Cannot stack {name} samples with differing shapes: {sorted(shapes)}")

    out_dir_split = out_dir / split
    out_dir_split.mkdir(parents=True, exist_ok=True)
    shard_path = out_dir_split / f"shard_{shard_idx:06d}.npz"
    X = np.stack(Xs, axis=0)  # [N, T, C, H, W]
    y = np.stack(ys, axis=0)  # [N, H, W]

    tmp_path = shard_path.with_name(shard_path.name + ".tmp")
    try:
        # Pass a file object so numpy does not append another ".npz" suffix.
        with open(tmp_path, "wb") as fh:
            np.savez_compressed(fh, X=X, y=y)
        tmp_path.replace(shard_path)  # atomic rename on the same filesystem
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
    return shard_path

# ----------------------
# Main
# ----------------------

def _manifest_row(shard_path: Path, sample_windows: List[Tuple[int, int, int]], t_values: np.ndarray) -> Dict:
    """Describe one written shard: file name, sample count and the time span it covers."""
    first_t0 = sample_windows[0][0]
    last_label = sample_windows[-1][2]
    return {
        "shard": shard_path.name,
        "num_samples": len(sample_windows),
        # Span of the whole shard: first input day of its first sample through
        # the label day of its last sample (samples are appended in time order).
        "t_start": str(pd.to_datetime(t_values[first_t0])),
        "t_label": str(pd.to_datetime(t_values[last_label])),
    }


def _process_split(
    ds: xr.Dataset,
    split: str,
    idx: np.ndarray,
    feat_vars: List[str],
    stats: Dict[str, Dict[str, float]],
    t_values: np.ndarray,
    out_dir: Path,
    cfg: Dict,
) -> None:
    """Extract, shard and write every (time window, tile) sample of one split, plus its manifest."""
    print(f"[info] Generating {split} shards...")
    tile_h, tile_w = cfg["tile_h"], cfg["tile_w"]
    y_size, x_size = ds.sizes["y"], ds.sizes["x"]
    windows = list(iter_time_windows(idx, cfg["seq_len"], cfg["horizon"], cfg["stride"]))
    n_tiles = ((y_size + tile_h - 1) // tile_h) * ((x_size + tile_w - 1) // tile_w)
    total_windows = len(windows) * n_tiles

    shard_idx = 0
    X_bucket: List[np.ndarray] = []
    y_bucket: List[np.ndarray] = []
    win_bucket: List[Tuple[int, int, int]] = []  # window of each buffered sample
    manifest_rows: List[Dict] = []
    done = failed = 0

    def _flush() -> None:
        # Write the buffered samples as the next shard, record it, reset buffers.
        nonlocal shard_idx
        shard_path = write_shard(out_dir, split, shard_idx, X_bucket, y_bucket)
        manifest_rows.append(_manifest_row(shard_path, win_bucket, t_values))
        shard_idx += 1
        X_bucket.clear()
        y_bucket.clear()
        win_bucket.clear()

    for window in windows:
        t0, t_end, t_label = window
        for (y0, x0) in iter_tiles(y_size, x_size, tile_h, tile_w):
            try:
                X, y = extract_window(
                    ds, feat_vars, cfg["label_var"],
                    t0, t_end, t_label,
                    y0, x0, tile_h, tile_w,
                    stats, cfg["nan_fill"],
                )
            except Exception as e:
                # Best effort: skip a bad sample, but count it so systematic
                # failures show up in the end-of-split summary.
                failed += 1
                print(f"[warn] failed window t0={t0}, y0={y0}, x0={x0}: {e}")
                continue

            X_bucket.append(X)
            y_bucket.append(y)
            win_bucket.append(window)
            if len(X_bucket) >= cfg["samples_per_shard"]:
                _flush()

            done += 1
            if done % 50 == 0:
                print(f"[info] {split}: {done}/{total_windows} samples queued...")

    # Flush remainder
    if X_bucket:
        _flush()

    if manifest_rows:
        man_df = pd.DataFrame(manifest_rows)
        man_df.to_parquet(out_dir / split / "manifest.parquet", index=False)
        print(f"[info] Wrote {len(manifest_rows)} shards to {out_dir / split}")
    if failed:
        print(f"[warn] {split}: {failed}/{total_windows} samples failed and were skipped.")


def run(cfg: Dict):
    """Preprocess the datacube described by ``cfg`` into sharded ConvLSTM samples.

    Writes ``stats.json`` (train-split normalization stats) and ``config.json``
    to ``cfg["out_dir"]``. Then, for each split in ``cfg["splits"]`` that selects
    any time steps, it writes ``shard_*.npz`` files and a ``manifest.parquet``.

    Args:
        cfg: Pipeline configuration with the keys of ``DEFAULT_CONFIG``.

    Raises:
        ValueError: If the train split selects no time steps, so normalization
            statistics cannot be computed.
    """
    out_dir = Path(cfg["out_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)

    print("[info] Opening dataset...")
    ds = open_ds(cfg["input_path"], cfg["chunks"])
    feat_vars = validate_vars(ds, cfg["feature_vars"], cfg["label_var"])
    print(f"[info] Using features: {feat_vars}; label: {cfg['label_var']}")
    # Keep only the needed variables. ``.chunk`` just (re)sets the lazy dask
    # chunking; it does not load or cache any data.
    ds = ds[feat_vars + [cfg["label_var"]]].chunk(cfg["chunks"])

    # Split indices
    t_values = ds["time"].values
    split_idx = {
        k: time_index_for_range(ds["time"], v[0], v[1]) for k, v in cfg["splits"].items()
    }
    if len(split_idx.get("train", [])) == 0:
        raise ValueError("Train split selects no time steps; cannot compute normalization stats.")

    print("[info] Computing normalization stats (train split)...")
    stats = compute_norm_stats(ds, feat_vars, split_idx["train"])
    with open(out_dir / "stats.json", "w") as f:
        json.dump(stats, f, indent=2)
    # Record the features actually used (missing ones were dropped), so the
    # training side sees the true channel order.
    with open(out_dir / "config.json", "w") as f:
        json.dump({**cfg, "feature_vars": feat_vars}, f, indent=2)

    for split, idx in split_idx.items():
        if len(idx) == 0:
            print(f"[warn] No time indices for split {split}, skipping.")
            continue
        _process_split(ds, split, idx, feat_vars, stats, t_values, out_dir, cfg)

    print("[done] Preprocessing complete.")

if __name__ == "__main__":
    import copy
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, default=DEFAULT_CONFIG["input_path"])
    parser.add_argument("--out", type=str, default=DEFAULT_CONFIG["out_dir"])
    parser.add_argument("--seq-len", type=int, default=DEFAULT_CONFIG["seq_len"])
    parser.add_argument("--horizon", type=int, default=DEFAULT_CONFIG["horizon"])
    parser.add_argument("--stride", type=int, default=DEFAULT_CONFIG["stride"])
    parser.add_argument("--tile-h", type=int, default=DEFAULT_CONFIG["tile_h"])
    parser.add_argument("--tile-w", type=int, default=DEFAULT_CONFIG["tile_w"])
    parser.add_argument("--samples-per-shard", type=int, default=DEFAULT_CONFIG["samples_per_shard"])
    # Inside a Jupyter kernel __name__ is also "__main__", but sys.argv then holds
    # the kernel's own launch flags (typically "-f <connection file>"). argparse
    # would reject those with SystemExit, so use the defaults there.
    args = parser.parse_args([] if "ipykernel" in sys.modules else None)

    # Deep copy: DEFAULT_CONFIG holds nested lists/dicts (feature_vars, chunks,
    # splits) that a shallow copy would share with the module-level default.
    cfg = copy.deepcopy(DEFAULT_CONFIG)
    cfg["input_path"] = args.input
    cfg["out_dir"] = args.out
    cfg["seq_len"] = args.seq_len
    cfg["horizon"] = args.horizon
    cfg["stride"] = args.stride
    cfg["tile_h"] = args.tile_h
    cfg["tile_w"] = args.tile_w
    cfg["samples_per_shard"] = args.samples_per_shard

    run(cfg)