In [1]:
from __future__ import annotations

import pathlib
from typing import Dict, Optional, Tuple, Union

import cf_xarray
import fsspec
import gcsfs
import numpy as np
import s3fs
import xarray as xr
import zarr

In [2]:
### helper functions: variable selection and coordinate normalization

def _select_variable(ds: xr.Dataset, var: Union[str, Dict[str, str]]) -> str:
    """
    Resolve a data-variable name from an explicit name or attribute hints.

    Parameters
    ----------
    ds:
        Dataset whose ``data_vars`` are searched.
    var:
        Either a variable name (matched exactly first, then
        case-insensitively), or a mapping of attribute hints using the keys
        ``"standard_name"``, ``"long_name"`` and/or ``"units"``, e.g.
        ``{"standard_name": "sea_surface_temperature"}``.

    Returns
    -------
    str
        Name of the first matching data variable.

    Raises
    ------
    KeyError
        If no data variable matches the name or any of the hints.
    """
    if isinstance(var, str):
        if var in ds.data_vars:
            return var
        wanted = var.lower()
        for name in ds.data_vars:
            if name.lower() == wanted:
                return name
        raise KeyError(f"Variable '{var}' not found. Available: {list(ds.data_vars)}")

    hint_keys = ("standard_name", "long_name", "units")

    # CF-aware lookup is attempted only when the dataset exposes a ``.cf``
    # accessor. Probing the object directly avoids depending on a module-level
    # availability flag.
    cf_accessor = getattr(ds, "cf", None)
    if cf_accessor is not None:
        for key in hint_keys:
            if key not in var:
                continue
            try:
                matches = cf_accessor.select_variables(**{key: var[key]})
            except (AttributeError, KeyError, TypeError, ValueError):
                # CF selection is best-effort: on failure, try the next hint
                # and ultimately the plain attribute comparison below.
                continue
            if matches:
                return list(matches)[0]

    # Fallback: case-insensitive comparison of each hint against the
    # variable's own attributes; the first variable matching any hint wins.
    for candidate in ds.data_vars:
        attrs = ds[candidate].attrs
        for key in hint_keys:
            if key in var and str(var[key]).lower() == str(attrs.get(key, "")).lower():
                return candidate
    raise KeyError(f"Could not locate variable from hints {var}. Variables: {list(ds.data_vars)}")


def _normalize_coord_names(ds: xr.Dataset) -> xr.Dataset:
    """
    Rename short coordinate aliases to their long names.

    ``lat`` becomes ``latitude`` and ``lon`` becomes ``longitude``. An alias
    is renamed only when it is present in ``ds.coords`` and the long name is
    not already a coordinate. The input is returned as-is when nothing needs
    renaming.
    """
    aliases = (("lat", "latitude"), ("lon", "longitude"))
    pending = {
        short: full
        for short, full in aliases
        if short in ds.coords and full not in ds.coords
    }
    if not pending:
        return ds
    return ds.rename(pending)


def _infer_target_lon_frame(lon_min: float, lon_max: float) -> str:
    """
    Guess the longitude convention of a pair of user-supplied bounds.

    Returns "0-360" when ``lon_min`` is non-negative and ``lon_max`` does not
    exceed 360; every other combination is reported as "-180-180".
    """
    fits_positive_frame = lon_min >= 0 and lon_max <= 360
    if fits_positive_frame:
        return "0-360"
    return "-180-180"


def _coerce_longitudes(ds: xr.Dataset, target_frame: str, assume_frame: Optional[str] = None) -> xr.Dataset:
    """
    Coerce dataset longitudes to a target frame ('0-360' or '-180-180').

    Parameters
    ----------
    ds:
        Dataset to convert. Returned unchanged when it has no ``longitude``
        coordinate.
    target_frame:
        ``"0-360"`` selects the modulo-360 mapping; any other value is
        treated as ``"-180-180"``.
    assume_frame:
        If given (and truthy), taken as the dataset's current frame instead
        of detecting it from the coordinate's min/max.

    Returns
    -------
    xr.Dataset
        ``ds`` itself when it is already in ``target_frame``; otherwise a
        dataset with converted longitudes, sorted by longitude when that
        coordinate is one-dimensional.
    """
    if "longitude" not in ds.coords:
        return ds

    lon_coord = ds["longitude"]
    lon = lon_coord.values
    if assume_frame:
        current = assume_frame
    else:
        current = "0-360" if (np.nanmin(lon) >= 0 and np.nanmax(lon) <= 360) else "-180-180"

    if current == target_frame:
        return ds

    if target_frame == "0-360":
        lon_new = np.mod(lon, 360.0)
    else:  # target is -180-180
        lon_new = ((lon + 180) % 360) - 180

    # Pass dims and attrs explicitly: a bare ndarray carries neither, so the
    # coordinate's metadata (units, standard_name, ...) would not be forwarded
    # and a multi-dimensional longitude could not be assigned at all.
    ds = ds.assign_coords(longitude=(lon_coord.dims, lon_new, dict(lon_coord.attrs)))

    # Sorting is only meaningful along a 1-D longitude coordinate.
    if lon_coord.ndim == 1:
        return ds.sortby("longitude")
    return ds


def _ensure_lat_monotonic(ds: xr.Dataset) -> xr.Dataset:
    """
    Sort the dataset by latitude when a 1-D latitude runs high-to-low.

    Only the first and last latitude values are compared. Datasets without a
    ``latitude`` coordinate, or with a latitude that is not one-dimensional,
    are returned untouched.
    """
    if "latitude" not in ds.coords:
        return ds
    if ds["latitude"].ndim != 1:
        return ds

    lat_values = ds["latitude"].values
    if lat_values[0] > lat_values[-1]:
        return ds.sortby("latitude")
    return ds


def _slice_longitude(ds: xr.Dataset, lon_min: float, lon_max: float) -> xr.Dataset:
    """
    Select a longitude band, including bands that wrap past the coordinate's end.

    When ``lon_min <= lon_max`` a single label slice is taken. Otherwise
    (e.g. 350E to 10E) the band is assembled from two pieces: ``lon_min`` up
    to the coordinate's maximum, followed by the coordinate's minimum up to
    ``lon_max``, concatenated along ``longitude`` in that order.
    """
    wraps_around = not (lon_min <= lon_max)
    if wraps_around:
        lon = ds["longitude"]
        upper_edge = float(lon.max())
        upper_part = ds.sel(longitude=slice(lon_min, upper_edge))
        lower_edge = float(lon.min())
        lower_part = ds.sel(longitude=slice(lower_edge, lon_max))
        return xr.concat([upper_part, lower_part], dim="longitude")

    return ds.sel(longitude=slice(lon_min, lon_max))

In [3]:
def load_aws_dataset(
    s3_path: str,
    variable_of_interest: Union[str, Dict[str, str]],
    region_of_interest: Optional[Dict[str, float]] = None,
    time_of_interest: Optional[Union[slice, Tuple[str, str]]] = None,
    *,
    group: Optional[str] = None,
    consolidated: Optional[bool] = None,
    chunks: Optional[Dict] = None,
    assume_lon: Optional[str] = None,  # "0-360" or "-180-180" if you know...
    return_dataset: bool = False,
    save_to: Optional[Union[str, pathlib.Path]] = None,
) -> Union[xr.DataArray, xr.Dataset]:
    """
    Load and subset a Zarr dataset from a public AWS S3 bucket.

    Parameters
    ----------
    s3_path:
        The full S3 path to the Zarr store (e.g., "s3://era5-pds/zarr/...").
    variable_of_interest:
        - Name of the variable in the dataset (e.g., "sst", "tos", "t2m"), OR
        - A mapping of CF/long-name hints to try, e.g.:
          {"standard_name": "sea_surface_temperature"}
    region_of_interest:
        Dict with geographic bounds: {"lat_min": -90, "lat_max": 90, "lon_min": 0, "lon_max": 360}.
        Longitudes may be 0–360 or −180–180. Function will reconcile.
        A latitude or longitude selection is applied only when both of its
        bounds are present.
    time_of_interest:
        Either a Python slice (e.g., slice("1990-01-01","2000-12-31")) or a 2-tuple of ISO strings.
    group:
        Zarr group within the store (e.g., "spatial" for ERA5).
    consolidated:
        Whether the Zarr store is consolidated. If None, a store opened with
        a ``group`` whose path contains "era5" is read without consolidated
        metadata; for any other store the choice is left to xarray, which
        tries consolidated metadata first and falls back if it is absent.
    chunks:
        Dask chunking dict, e.g., {"time": 2400}.
    assume_lon:
        If set, forces interpretation of dataset longitudes as "0-360" or "-180-180".
    return_dataset:
        If True, return the full Dataset. Otherwise return the selected DataArray.
    save_to:
        Optional path to save the subset as NetCDF. Parent directories are
        created if needed.

    Returns
    -------
    xr.DataArray or xr.Dataset
        The subsetted data.

    Raises
    ------
    KeyError
        If ``variable_of_interest`` cannot be resolved to a data variable.
    """
    # normalize input params
    if isinstance(time_of_interest, tuple):
        time_of_interest = slice(time_of_interest[0], time_of_interest[1])

    region = region_of_interest or {}
    lat_min = region.get("lat_min", None)
    lat_max = region.get("lat_max", None)
    lon_min = region.get("lon_min", None)
    lon_max = region.get("lon_max", None)

    # config aws: anonymous access to the public bucket
    storage_options = {"anon": True}
    # Only force a value for the ERA5 group case. Otherwise keep None so
    # xarray auto-detects; forcing True would fail on stores that ship
    # without consolidated metadata.
    if consolidated is None and group is not None and "era5" in s3_path.lower():
        consolidated = False

    # open dataset
    ds = xr.open_dataset(
        s3_path,
        engine="zarr",
        chunks=chunks,
        consolidated=consolidated,
        backend_kwargs={
            "storage_options": storage_options,
            **({"group": group} if group else {}),
        },
    )

    # select variable (cf-aware if possible)
    var = _select_variable(ds, variable_of_interest)

    # normalize coordinate names
    ds = _normalize_coord_names(ds)

    # fix lon to desired slicing frame, if needed
    if (lon_min is not None) and (lon_max is not None):
        ds = _coerce_longitudes(ds, target_frame=_infer_target_lon_frame(lon_min, lon_max), assume_frame=assume_lon)

    # ensure latitude selection works if lat is descending (ERA5 style)
    if (lat_min is not None) and (lat_max is not None):
        ds = _ensure_lat_monotonic(ds)

    # apply coord selections
    sel = ds
    if time_of_interest is not None and ("time" in sel.dims or "time" in sel.coords):
        sel = sel.sel(time=time_of_interest)
    if (lat_min is not None) and (lat_max is not None) and "latitude" in sel.coords:
        # bounds are ordered here so callers may pass them in either order
        sel = sel.sel(latitude=slice(min(lat_min, lat_max), max(lat_min, lat_max)))
    if (lon_min is not None) and (lon_max is not None) and "longitude" in sel.coords:
        sel = _slice_longitude(sel, lon_min, lon_max)

    # return subset da/ds
    out = sel if return_dataset else sel[var]
    if save_to is not None:
        save_path = pathlib.Path(save_to).expanduser().resolve()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        out.to_netcdf(save_path)
    return out