In [1]:
from __future__ import annotations

from typing import Dict, Iterable, List, Mapping

import numpy as np
import pandas as pd
import pint
import pint_xarray  # noqa: F401  (imported for its side effect: registers the ``.pint`` accessor)
import xarray as xr


# A single, process-wide unit registry: built ndarray-like from the start and
# installed as pint's application registry.
# ndarray-like mode avoids:
#   ValueError: invalid registry. Please enable 'force_ndarray_like' or 'force_ndarray'.
# (``force_ndarray=True`` would be the stricter alternative.)
ureg = pint.UnitRegistry(force_ndarray_like=True)
pint.set_application_registry(ureg)


def _union_sorted_times(all_times: List[pd.DatetimeIndex]) -> pd.DatetimeIndex:
    if not all_times:
        return pd.DatetimeIndex([])
    out = all_times[0]
    for t in all_times[1:]:
        out = out.union(t)
    return out.sort_values()


def _parse_timeseries(ts) -> tuple[pd.DatetimeIndex, np.ndarray]:
    """
    Split one entry's raw 'timeseries' payload into (time index, float values).

    Accepts a pandas Series (its index holds the timestamps), an iterable of
    (timestamp, value) pairs, or an empty / None payload.
    """
    # Series must be tested first: ``not series`` raises ValueError because the
    # truth value of a Series is ambiguous.
    if isinstance(ts, pd.Series):
        return pd.to_datetime(ts.index), ts.to_numpy(dtype=float)
    if not ts:
        return pd.DatetimeIndex([]), np.array([], dtype=float)
    stamps, values = zip(*ts)
    return pd.to_datetime(list(stamps)), np.array(list(values), dtype=float)


def build_dataset(
    entries: Iterable[Mapping],
    convert_units: bool = True,
) -> xr.Dataset:
    """
    Build a pint-aware xarray.Dataset from a list of station dictionaries.

    Each entry is expected to have keys:
      - 'name': station name
      - 'type': variable name to use (e.g., 'QO')
      - 'timeseries': list of (timestamp, value) tuples, or a pandas Series
        whose index holds the timestamps
      - 'unit': e.g. 'm^3/s'
      - 'computational_unit': string describing the type (e.g., 'subbasin');
        used as the dimension name and stored in attrs
      - 'computational_unit_id': id for the unit (used as coordinate values)
      - 'freq': e.g., '1D'

    Dataset layout:
      - dims: (dim_name, time), where dim_name is the first non-empty
        'computational_unit' found, or 'computational_unit' if no entry
        provides one
      - variables: one per unique 'type' (e.g., 'QO') with units attached via pint
      - coords:
          - time: union of all timestamps across entries
          - dim_name: all unique computational_unit_id values (first-seen order)
          - name(dim_name,): station names aligned to each id (first seen if
            duplicates, '' if never given)
          - freq(dim_name,): freq string per id (first seen if duplicates,
            '' if never given)
      - attrs:
          - computational_unit_kind: the label found in entries (e.g., 'subbasin')

    Several entries may target the same ('type', 'computational_unit_id') pair:
    their samples are merged, and where timestamps overlap the later entry wins.

    Parameters
    ----------
    entries : iterable of mappings
        Station dictionaries as described above.
    convert_units : bool, default True
        If True, values for the same 'type' with different units are converted
        to the first encountered unit for that type using pint.

    Returns
    -------
    xr.Dataset
        Dataset quantified with the module-level registry ``ureg``.

    Raises
    ------
    ValueError
        If ``convert_units`` is False and two entries of the same 'type'
        carry different unit strings.
    """
    entries = list(entries)

    # Parse every entry once into an aligned pair of lists: series + metadata.
    per_entry_series: List[pd.Series] = []
    per_entry_time_index: List[pd.DatetimeIndex] = []
    per_entry_meta: List[Dict] = []

    for e in entries:
        idx, vals = _parse_timeseries(e.get("timeseries", []))
        per_entry_series.append(pd.Series(vals, index=idx))
        per_entry_time_index.append(idx)
        per_entry_meta.append(
            dict(
                name=e.get("name"),
                typ=e.get("type"),
                unit=e.get("unit"),
                cu_kind=e.get("computational_unit"),
                cu_id=e.get("computational_unit_id"),
                freq=e.get("freq"),
            )
        )

    # Global time axis shared by every variable.
    global_time = _union_sorted_times(per_entry_time_index)

    # Map each computational_unit id to its row, in first-seen order.
    # (The default passed to setdefault is evaluated before insertion, so it is
    # the number of ids seen so far.)
    row_of_id: Dict = {}
    for m in per_entry_meta:
        row_of_id.setdefault(m["cu_id"], len(row_of_id))
    cu_ids = np.array(list(row_of_id))
    n_cu = len(cu_ids)
    n_time = len(global_time)

    # First usable cu_kind (e.g., 'subbasin') names the dimension. Fall back to
    # a fixed label so that missing kinds / empty input do not yield a None
    # dimension name.
    cu_kind: str | None = next(
        (m["cu_kind"] for m in per_entry_meta if m["cu_kind"]), None
    )
    dim_name = cu_kind or "computational_unit"

    names_by_id: Dict = {}
    freq_by_id: Dict = {}
    arrays_by_type: Dict[str, np.ndarray] = {}
    unit_by_type: Dict[str, str] = {}

    # Fill one (n_cu, n_time) matrix per variable type.
    for s, m in zip(per_entry_series, per_entry_meta):
        typ = m["typ"]
        unit = m["unit"]
        cu_id = m["cu_id"]

        # Keep the first non-None name / freq seen for each id.
        if m["name"] is not None:
            names_by_id.setdefault(cu_id, m["name"])
        if m["freq"] is not None:
            freq_by_id.setdefault(cu_id, m["freq"])

        # The first entry of a type allocates its matrix and fixes its reference unit.
        if typ not in arrays_by_type:
            arrays_by_type[typ] = np.full((n_cu, n_time), np.nan, dtype=float)
            unit_by_type[typ] = unit
        ref_unit = unit_by_type[typ]

        needs_conversion = unit != ref_unit
        if needs_conversion and not convert_units:
            raise ValueError(
                f"Found inconsistent units for type '{typ}': "
                f"{unit} vs {unit_by_type[typ]}"
            )

        # Align onto the global time axis; missing timestamps become NaN.
        values = s.reindex(global_time).to_numpy()
        if needs_conversion:
            values = ureg.Quantity(values, unit).to(ref_unit).magnitude

        # Write only the samples this entry actually provides, so an earlier
        # entry for the same (type, id) is merged rather than wiped with NaN.
        present = ~np.isnan(values)
        arrays_by_type[typ][row_of_id[cu_id], present] = values[present]

    # Per-id label coordinates. Look up with the original id objects and use ''
    # for missing labels: dtype=str would otherwise store the text 'None'.
    name_arr = np.array([names_by_id.get(cu, "") for cu in row_of_id], dtype=str)
    freq_arr = np.array([freq_by_id.get(cu, "") for cu in row_of_id], dtype=str)

    coords = {
        dim_name: cu_ids,
        "time": global_time,
        "name": (dim_name, name_arr),
        "freq": (dim_name, freq_arr),
    }

    # Data variables carry their unit string in attrs as well.
    data_vars = {
        typ: ((dim_name, "time"), arr, {"units": unit_by_type[typ]})
        for typ, arr in arrays_by_type.items()
    }

    ds = xr.Dataset(data_vars=data_vars, coords=coords)

    if cu_kind is not None:
        ds.attrs["computational_unit_kind"] = cu_kind

    # Quantify using the SAME registry configured at module level.
    return ds.pint.quantify(dict(unit_by_type), unit_registry=ureg)


if __name__ == "__main__":
    # Minimal example: two subbasin stations reporting the same variable on
    # partially overlapping days.
    data = [
        dict(
            name="station_1",
            type="QO",
            timeseries=[
                ("2020-01-01", 10.5),
                ("2020-01-02", 12.3),
            ],
            # Prefer pint-friendly 'm^3/s' spelling
            unit="m^3/s",
            computational_unit="subbasin",
            computational_unit_id=14,
            freq="1D",
        ),
        dict(
            name="station_2",
            type="QO",
            timeseries=[
                ("2020-01-01", 7.1),
                ("2020-01-03", 9.4),
            ],
            unit="m^3/s",
            computational_unit="subbasin",
            computational_unit_id=22,
            freq="1D",
        ),
    ]

    ds = build_dataset(data, convert_units=True)
    print(ds)

<xarray.Dataset> Size: 176B
Dimensions:   (subbasin: 2, time: 3)
Coordinates:
  * subbasin  (subbasin) int64 16B 14 22
  * time      (time) datetime64[ns] 24B 2020-01-01 2020-01-02 2020-01-03
    name      (subbasin) <U9 72B 'station_1' 'station_2'
    freq      (subbasin) <U2 16B '1D' '1D'
Data variables:
    QO        (subbasin, time) float64 48B [m³/s] 10.5 12.3 nan 7.1 nan 9.4
Attributes:
    computational_unit_kind:  subbasin


In [2]:
ds

0,1
Magnitude,[[10.5 12.3 nan] [7.1 nan 9.4]]
Units,meter3/second
