# Create daily global NetCDF files


## Step 1: A function to make CHLA predictions from the BRT models

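This function lives in `ml_utils.py` and is also bundled into the saved model. It returns an xarray DataArray named `CHLA` holding the predictions, with dimensions (time, z, lat, lon); the `time` dimension is present only when a `time` argument is passed.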

In [2]:
import numpy as np
import xarray as xr

def _parse_depth_bins(depth_labels):
    """
    Parse depth-bin bounds from model labels of the form ``NAME_z0_z1``.

    Returns
    -------
    (z_start, z_end, z_center) : tuple of three float32 np.ndarray
        One entry per label, in label order. A label with fewer than three
        underscore-separated tokens, or whose last two tokens are not numeric,
        gets NaN in all three arrays.
    """
    n_depth = len(depth_labels)
    z_start = np.full(n_depth, np.nan, dtype="float32")
    z_end = np.full(n_depth, np.nan, dtype="float32")
    z_center = np.full(n_depth, np.nan, dtype="float32")

    for i, label in enumerate(depth_labels):
        parts = label.split("_")
        if len(parts) < 3:
            continue
        try:
            z0, z1 = float(parts[-2]), float(parts[-1])
        except ValueError:
            # Non-numeric depth tokens: leave this bin as NaN (best effort).
            continue
        z_start[i] = z0
        z_end[i] = z1
        z_center[i] = 0.5 * (z0 + z1)

    return z_start, z_end, z_center


def predict_all_depths_for_day(
    R: xr.DataArray,         # (lat, lon, wavelength)
    brt_models: dict,        # e.g. {"CHLA_0_10": model0, "CHLA_10_20": model1, ...}
    feature_cols: list,
    consts=None,
    chunk_size_lat: int = 100,
    time=None,               # e.g. "2024-07-15" or np.datetime64
    z: np.ndarray | None = None,   # optional override for depth centers
    z_name: str = "z",       # vertical dimension name
):
    """
    Run BRT predictions for all depth bins for a single day.

    Parameters
    ----------
    R : xr.DataArray
        Predictor array of Rrs wavelengths from PACE on (lat, lon, wavelength)
        (no time dimension).
    brt_models : dict
        Mapping depth-label -> fitted model, e.g.
        {"CHLA_0_10": model0, "CHLA_10_20": model1, ...}.
        The last two underscore-separated tokens are assumed to be
        depth start/end in meters, e.g. "CHLA_0_10" -> 0, 10.
    feature_cols : list of str
        Columns expected by the BRT models.
    consts : dict, optional
        Constant-feature values keyed by feature name. The keys
        ``"solar_hour"`` (default 0) and ``"type"`` (default 1) are forwarded
        to ``make_prediction_brt`` as ``solar_const`` / ``type_const``; any
        other keys are currently ignored.
    chunk_size_lat : int
        Number of latitude indices per chunk (must be >= 1).
    time : str or np.datetime64, optional
        Time stamp for this prediction. If provided, a `time` dimension of
        length 1 is added to the output.
    z : array-like, optional
        Depth centers (same order as brt_models keys). If not given, centers
        are inferred as (z_start + z_end)/2 from the model name.
    z_name : str, default "z"
        Name of the vertical dimension in the output.

    Returns
    -------
    xr.DataArray
        CHLA prediction named "CHLA" with dims:
            (time, z_name, lat, lon)      if `time` provided
            (z_name, lat, lon)            otherwise

        Coordinates:
            z_name            : depth center (m)
            f"{z_name}_start" : depth bin lower bound (m)
            f"{z_name}_end"   : depth bin upper bound (m)

    Raises
    ------
    ValueError
        If `brt_models` is empty, `chunk_size_lat` < 1, or `z` does not have
        exactly one entry per model.
    """
    consts = consts or {}
    if not brt_models:
        raise ValueError("brt_models is empty; need at least one depth-bin model")
    if chunk_size_lat < 1:
        raise ValueError(f"chunk_size_lat must be >= 1, got {chunk_size_lat}")

    # make_prediction_brt takes the constant features as dedicated keyword
    # arguments; take them from `consts`, defaulting to the values that used
    # to be hard-coded here (solar_hour=0, type=1).
    solar_const = consts.get("solar_hour", 0)
    type_const = consts.get("type", 1)

    R = R.transpose("lat", "lon", "wavelength")

    depth_labels = list(brt_models.keys())
    n_depth = len(depth_labels)

    # --- parse z_start / z_end / z_center from labels like ABC_0_10 ---
    z_start_arr, z_end_arr, z_center_arr = _parse_depth_bins(depth_labels)

    # if user provided z, override centers (bounds still come from the labels)
    if z is not None:
        z_center_arr = np.atleast_1d(np.asarray(z, dtype="float32"))
        if z_center_arr.shape != (n_depth,):
            raise ValueError(
                f"len(z)={z_center_arr.size} does not match number of models={n_depth}"
            )

    nlat = R.sizes["lat"]
    lat_coord = R["lat"]

    depth_chunks = {label: [] for label in depth_labels}

    # --- chunk over latitude to bound the size of each predictor batch ---
    for start in range(0, nlat, chunk_size_lat):
        stop = min(start + chunk_size_lat, nlat)
        R_chunk = R.isel(lat=slice(start, stop))

        for label, model in brt_models.items():
            pred_chunk = make_prediction_brt(
                R_chunk,
                brt_model=model,
                feature_cols=feature_cols,
                solar_const=solar_const,
                type_const=type_const,
            )
            depth_chunks[label].append(pred_chunk)

    # --- stitch each depth over lat, then stack into vertical dimension ---
    per_depth = []
    for idx, label in enumerate(depth_labels):
        da = xr.concat(depth_chunks[label], dim="lat").assign_coords(lat=lat_coord)
        # placeholder integer index; replaced by the depth centers below
        per_depth.append(da.expand_dims({z_name: [idx]}))

    pred_all = xr.concat(per_depth, dim=z_name)  # (z, lat, lon)
    pred_all.name = "CHLA"

    # vertical coordinates
    pred_all = pred_all.assign_coords(
        {
            z_name: z_center_arr,
            f"{z_name}_start": (z_name, z_start_arr),
            f"{z_name}_end":   (z_name, z_end_arr),
        }
    )

    # optional time dimension
    if time is not None:
        pred_all = pred_all.expand_dims(time=[np.datetime64(time)])

    # note about where the depth coordinates came from
    if z is None:
        depth_note = (
            f"Depth coordinates inferred from brt_models keys of form 'NAME_z0_z1'. "
            f"z is the bin center, {z_name}_start/{z_name}_end are bin bounds (m)."
        )
    else:
        depth_note = (
            f"Depth bin bounds {z_name}_start/{z_name}_end (m) inferred from "
            f"brt_models keys of form 'NAME_z0_z1'; bin centers {z_name} supplied by caller."
        )
    pred_all.attrs.setdefault("depth_info", depth_note)

    return pred_all

## Create a dataset with our derived variables

In [None]:
def build_chla_profile_dataset(CHLA: xr.DataArray, z_name: str = "z") -> xr.Dataset:
    """
    Compute depth-profile summary metrics from a CHLA prediction cube and
    return an xr.Dataset suitable for writing to Zarr/NetCDF.

    Parameters
    ----------
    CHLA : xr.DataArray
        CHLA predictions with a vertical dimension `z_name`, e.g. the output of
        predict_all_depths_for_day: (time, z, lat, lon) or (z, lat, lon).
        The `z_name` coordinate holds bin centers (m); optional
        f"{z_name}_start" / f"{z_name}_end" coordinates hold bin bounds (m).
    z_name : str, default "z"
        Name of the vertical dimension.

    Returns
    -------
    xr.Dataset
        Variables:
            CHLA                      : the input cube
            CHLA_int_0_200            : sum over bins of CHLA * bin thickness
            CHLA_peak                 : maximum CHLA over depth
            CHLA_peak_depth           : bin center where that maximum occurs
            CHLA_depth_center_of_mass : CHLA-weighted mean of the bin centers
        Derived variables have CHLA's dims minus `z_name`. Columns that are
        NaN at every depth (e.g. land or missing pixels) are NaN in every
        derived variable.
    """
    # Bin thickness (m) assumed when the bound coordinates are missing or all NaN.
    fallback_thickness = 10.0

    # Look the bounds up separately: subtracting before the None check would
    # raise TypeError (None - None) when the coordinates are absent.
    z_lo = CHLA.coords.get(f"{z_name}_start")
    z_hi = CHLA.coords.get(f"{z_name}_end")
    z_thick = None if z_lo is None or z_hi is None else z_hi - z_lo
    if z_thick is None or bool(z_thick.isnull().all()):
        # Uniform bins: a scalar broadcasts against every depth level.
        z_thick = fallback_thickness

    # Coordinates that live on the vertical dim (z, z_start, z_end, ...). When
    # indexing pulls them onto the horizontal grid they would clash with
    # CHLA's 1-D versions once the Dataset is assembled, so they are dropped.
    z_coords = [name for name, coord in CHLA.coords.items() if z_name in coord.dims]

    def _drop_z_coords(da):
        return da.drop_vars([name for name in z_coords if name in da.coords])

    # Integrated CHLA 0-200 m. The variable name assumes the bins span 0-200 m
    # — TODO confirm against the models' depth labels.
    # min_count=1 keeps all-NaN columns NaN instead of summing to 0.
    CHLA_int = (CHLA * z_thick).sum(z_name, min_count=1)
    CHLA_int.name = "CHLA_int_0_200"

    # Peak value over depth.
    CHLA_peak = CHLA.max(z_name)
    CHLA_peak.name = "CHLA_peak"

    # Depth of the peak. idxmax returns NaN for all-NaN columns, whereas
    # argmax raises "All-NaN slice encountered" on them.
    depth_peak = _drop_z_coords(CHLA.idxmax(z_name))
    depth_peak.name = "CHLA_peak_depth"

    # CHLA-weighted mean depth.
    z_center = CHLA[z_name]
    num = (CHLA * z_center).sum(z_name, min_count=1)
    den = CHLA.sum(z_name, min_count=1)
    # Avoid 0/0 for columns that are all zeros.
    depth_cm = num / den.where(den != 0)
    depth_cm.name = "CHLA_depth_center_of_mass"

    ds = xr.Dataset(
        {
            "CHLA": CHLA,                              # input dims
            "CHLA_int_0_200": CHLA_int,                # CHLA dims minus z_name
            "CHLA_peak": CHLA_peak,                    # CHLA dims minus z_name
            "CHLA_peak_depth": depth_peak,             # CHLA dims minus z_name
            "CHLA_depth_center_of_mass": depth_cm,     # CHLA dims minus z_name
        }
    )

    # carry over coords and attrs from CHLA
    ds = ds.assign_coords(CHLA.coords)
    ds.attrs.update(CHLA.attrs)

    return ds


In [None]:
%%time
# Time one global prediction over all depth bins (~14 min per the original note).
# Uses notebook globals defined elsewhere in the notebook: rrs_ds, X_train, brt_models.

R = rrs_ds["Rrs"]
# NOTE(review): presumably X_train is the training feature frame, so its
# columns give the feature order the BRT models expect — confirm.
feature_cols = list(X_train.columns)

# No `time` argument is passed, so CHLA_day has dims (z, lat, lon) and no
# time dimension.
CHLA_day = predict_all_depths_for_day(
    R=R,
    brt_models=brt_models,
    feature_cols=feature_cols,
    consts={"type": 1, "solar_hour": 0},  
    chunk_size_lat=100,
)