# File to complete data ingestion with PyTorch's `dataset` and `DataLoader` directly from NetCDF file.

In [7]:
import xarray as xr
import pandas as pd
import numpy as np
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import List, Optional, Tuple

In [5]:
DATASET_PATH = Path("data/IberFire.nc")


In [8]:
class SpatioTemporalDataset(Dataset):
    """
    A PyTorch Dataset for spatio-temporal data stored in a NetCDF file.

    Each sample:
        - inputs: sequence of past time steps (T, C, H, W)
        - target: target variable at the *next* time step (H, W)

    Spatially, the domain is split into non-overlapping patches of size (patch_size x patch_size).
    """

    def __init__(
        self,
        data_path: Path,
        sequence_length: int = 30,
        stride: int = 1,
        patch_size: int = 64,
        target_variable: str = "is_near_fire",
        predictor_variables: Optional[List[str]] = None,
    ):
        """
        Parameters
        ----------
        data_path : Path
            Path to the NetCDF file.
        sequence_length : int
            Number of past time steps T in each input sequence.
        stride : int
            Step in time between the starts of consecutive sequences.
        patch_size : int
            Height/width of each square patch (in grid cells).
        target_variable : str
            Name of the variable to predict at the next time step.
        predictor_variables : list of str, optional
            Variables to use as input channels.
            If None, uses all data variables except the target.
        """
        self.data_path = Path(data_path)
        self.sequence_length = sequence_length
        self.stride = stride
        self.patch_size = patch_size
        self.target_variable = target_variable

        # IMPORTANT: open without chunks → simpler, each __getitem__ does a small read
        # You can add engine="h5netcdf" if needed
        self.ds = xr.open_dataset(self.data_path, decode_times=True)

        # Basic dimension info
        self.time_len = int(self.ds.dims["time"])
        self.y_len = int(self.ds.dims["y"])
        self.x_len = int(self.ds.dims["x"])

        # Choose predictor variables
        if predictor_variables is None:
            self.predictor_variables = [
                v for v in self.ds.data_vars if v != self.target_variable
            ]
        else:
            self.predictor_variables = predictor_variables

        # Precompute how many patches fit in y/x
        self.n_patches_y = self.y_len // self.patch_size  # floor
        self.n_patches_x = self.x_len // self.patch_size  # floor

        # How many valid time windows?
        # We use times [t0, ..., t0+T-1] as inputs, and time t_target = t0+T as target.
        # Need t_target < time_len  →  t0 <= time_len - T - 1
        max_t0 = self.time_len - self.sequence_length - 1
        if max_t0 < 0:
            raise ValueError("sequence_length is too long for the available time dimension.")

        self.n_time_windows = (max_t0 // self.stride) + 1

        # Total samples = time_windows * patches_y * patches_x
        self._len = self.n_time_windows * self.n_patches_y * self.n_patches_x

    def __len__(self) -> int:
        return self._len

    def _index_to_coords(self, idx: int) -> Tuple[int, int, int, slice, slice]:
        """
        Map a flat dataset index to:
          - t0, t1: start/end of input sequence in time
          - t_target: time index for target (next step after sequence)
          - y_slice, x_slice: patch slice in space
        """
        if idx < 0 or idx >= self._len:
            raise IndexError(idx)

        patches_per_time = self.n_patches_y * self.n_patches_x

        # which time window?
        time_idx = idx // patches_per_time
        patch_idx = idx % patches_per_time

        # time indices
        t0 = time_idx * self.stride
        t1 = t0 + self.sequence_length
        t_target = t1  # next time step after the sequence

        # which patch in y/x?
        py = patch_idx // self.n_patches_x
        px = patch_idx % self.n_patches_x

        y0 = py * self.patch_size
        y1 = y0 + self.patch_size
        x0 = px * self.patch_size
        x1 = x0 + self.patch_size

        y_slice = slice(y0, y1)
        x_slice = slice(x0, x1)

        return t0, t1, t_target, y_slice, x_slice

    def __getitem__(self, idx: int):
        t0, t1, t_target, y_slice, x_slice = self._index_to_coords(idx)

        # ---- Inputs: (T, C, H, W) ----
        # Take selected predictor variables, slice in time and space
        # ds[predictor_variables] is a Dataset; to_array() gives DataArray with
        # dims: (variable, time, y, x)
        window_ds = self.ds[self.predictor_variables].isel(
            time=slice(t0, t1),
            y=y_slice,
            x=x_slice,
        )
        window_da = window_ds.to_array("channel")  # dims: (channel, time, y, x)
        window_da = window_da.transpose("time", "channel", "y", "x")
        x_np = window_da.values.astype("float32")  # (T, C, H, W)

        # ---- Target: (H, W) ----
        target_da = self.ds[self.target_variable].isel(
            time=t_target,
            y=y_slice,
            x=x_slice,
        )
        y_np = target_da.values.astype("float32")  # (H, W)

        # Convert to torch tensors
        x = torch.from_numpy(x_np)  # (T, C, H, W)
        y = torch.from_numpy(y_np)  # (H, W)

        sample = {
            "inputs": x,
            "target": y,
            "t0": t0,
        }
        return sample

In [10]:
from torch.utils.data import DataLoader

dataset = SpatioTemporalDataset(
    data_path=DATASET_PATH,
    sequence_length=16,
    stride=1,
    patch_size=64,
    target_variable="is_near_fire",
    predictor_variables=["temp_2m", "wind_speed", "humidity"],  # example
)

loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=False,    # keep temporal order
    num_workers=0,    # start with 0; you can tune later
    pin_memory=True,
)

batch = next(iter(loader))
print(batch["inputs"].shape)   # (B, T, C, H, W)
print(batch["target"].shape)   # (B, H, W)
print(batch["t0"])             # starting time indices for each sequence

  self.time_len = int(self.ds.dims["time"])
  self.y_len = int(self.ds.dims["y"])
  self.x_len = int(self.ds.dims["x"])


KeyError: 'temp_2m'