# Process daily files with Dask Gateway

In [4]:
# --- Core data handling libraries ---
import earthaccess
import xarray as xr       # for working with labeled multi-dimensional arrays
import numpy as np        # for numerical operations on arrays
import pandas as pd
# --- Custom python functions ---
import os, importlib
# Fetch the helper module from GitHub only when it is not already present in
# the current working directory. `!wget` is an IPython shell escape, so this
# cell runs only inside IPython/Jupyter; `-q` silences wget's output.
if not os.path.exists("ml_utils.py"):
    !wget -q https://raw.githubusercontent.com/fish-pace/2025-tutorials/main/ml_utils.py

# Import the helper module, then reload it so that edits made to ml_utils.py
# during this kernel session are picked up without restarting the kernel.
import ml_utils as mu
importlib.reload(mu)

<module 'ml_utils' from '/home/jovyan/chla-z-modeling/ml_utils.py'>

In [5]:
def process_one_granule(
    res,
    lat_chunk=100,
    lon_chunk=100,
    bucket_name="nmfs_odp_nwfsc",
    destination_prefix="CB/fish-pace-datasets/chla-z/netcdf",
    model_bundle_path="models/brt_chla_profiles_bundle.zip",
):
    """
    Run the full pipeline for a single PACE L3M Rrs DAY granule:
      - open Rrs via earthaccess
      - run BRT CHLA(z) prediction
      - build derived metrics
      - write daily NetCDF to a private temporary directory
      - upload NetCDF to GCS

    Parameters
    ----------
    res : mapping
        One earthaccess search result. The granule start time is read from
        ``res["umm"]["TemporalExtent"]["RangeDateTime"]["BeginningDateTime"]``.
    lat_chunk, lon_chunk : int
        Requested NetCDF chunk lengths along lat / lon for ``CHLA``. They are
        clamped to the actual dimension sizes, because HDF5 rejects chunks
        larger than a fixed-size dimension.
    bucket_name : str
        Destination GCS bucket.
    destination_prefix : str
        Object-name prefix ("folder") inside the bucket.
    model_bundle_path : str
        Path of the model bundle; only its file name is recorded in the
        ``model_bundle`` global attribute of the output file.

    Returns
    -------
    str
        GCS path of the uploaded NetCDF (for logging/debugging).

    Notes
    -----
    This function reads two names from the enclosing (notebook) namespace:
    ``bundle`` and ``build_chla_profile_dataset``. Both must be defined before
    the function is called and must be picklable/importable on the Dask
    workers.
    NOTE(review): ``build_chla_profile_dataset`` is not defined in the cells
    shown here - confirm where it comes from.
    """
    # Imports live inside the function so they are resolved in whichever
    # process actually executes the task.
    import tempfile
    from pathlib import Path

    import earthaccess
    import pandas as pd
    import xarray as xr
    from google.cloud import storage

    # 1. auth for earthaccess on the worker
    # NOTE(review): this presumably needs credentials (env vars or ~/.netrc)
    # to be available on the worker - confirm for your deployment.
    auth = earthaccess.login(persist=True)

    # 2. get date for this granule
    day_iso = res["umm"]["TemporalExtent"]["RangeDateTime"]["BeginningDateTime"]
    # Normalise to a timezone-naive UTC Timestamp: a trailing "Z" would
    # otherwise yield a tz-aware value that cannot be matched against a
    # tz-naive time coordinate in `.sel` below.
    day = pd.to_datetime(day_iso, utc=True).tz_localize(None)
    day_str = day.strftime("%Y%m%d")

    # 3. open Rrs dataset for this granule
    files = earthaccess.open([res], auth=auth, pqdm_kwargs={"disable": True})

    # The dataset is closed and the temporary directory (with the NetCDF in
    # it) is removed on every exit path, including failures during
    # prediction, writing or upload.
    with xr.open_dataset(files[0]) as rrs_ds, tempfile.TemporaryDirectory() as tmp_dir:
        # Rrs for that day
        if "time" in rrs_ds.dims:
            # Selecting with a scalar label already drops the time dimension,
            # so no squeeze is needed (squeezing a missing dim raises).
            R = rrs_ds["Rrs"].sel(time=day)
        else:
            R = rrs_ds["Rrs"]
        R = R.transpose("lat", "lon", "wavelength")

        # 4. CHLA(z) prediction for this day
        # NOTE: bundle and build_chla_profile_dataset must be importable/picklable
        pred = bundle.predict(
            R,
            brt_models=bundle.model,
            feature_cols=bundle.meta["feature_cols"],
            consts={"solar_hour": 0, "type": 1},
            chunk_size_lat=100,
            time=day.to_datetime64(),   # time coord length 1
            z_name="z",
            silent=True,
        )  # (time=1, z, lat, lon), float32

        ds_day = build_chla_profile_dataset(pred)

        # 5. add / update metadata
        ds_day["CHLA"].attrs.update(
            units="mg m-3",
            long_name="Chlorophyll-a concentration",
            description="BRT-derived CHLA profiles from PACE hyperspectral Rrs",
        )
        ds_day["z"].attrs.update(units="m", long_name="depth (bin center)")
        ds_day["lat"].attrs.update(units="degrees_north")
        ds_day["lon"].attrs.update(units="degrees_east")
        ds_day.attrs["source"] = "BRT model trained on BGC-Argo + OOI matchups"
        ds_day.attrs["model_bundle"] = Path(model_bundle_path).name

        # 6. write to a NetCDF inside the private temp directory
        local_path = Path(tmp_dir) / f"chla_z_{day_str}.nc"

        # Chunk layout follows the (time, z, lat, lon) order noted above.
        encoding = {
            "CHLA": {
                "dtype": "float32",
                "zlib": True,
                "complevel": 4,
                "chunksizes": (
                    1,
                    ds_day.sizes["z"],
                    min(lat_chunk, ds_day.sizes["lat"]),
                    min(lon_chunk, ds_day.sizes["lon"]),
                ),
            }
        }

        ds_day.to_netcdf(
            local_path,
            engine="h5netcdf",
            encoding=encoding,
        )

        # 7. upload to GCS
        storage_client = storage.Client(project="noaa-gcs-public-data")
        bucket = storage_client.bucket(bucket_name)

        blob_path = f"{destination_prefix}/chla_z_{day_str}.nc"
        blob = bucket.blob(blob_path)
        blob.upload_from_filename(str(local_path))

        gcs_url = f"gs://{bucket_name}/{blob_path}"
        print(f"[{day_str}] Uploaded → {gcs_url}")
        return gcs_url


In [6]:
# Load the trained BRT model bundle and expose its pieces as notebook-level
# names for the cells that follow.
bundle = mu.load_ml_bundle("models/brt_chla_profiles_bundle.zip")
brt_models, meta = bundle.model, bundle.meta

# Column groups recorded in the bundle metadata.
rrs_cols, chl_cols, extra = (
    meta[key] for key in ("rrs_cols", "y_col", "extra_cols")
)

# Stored dataset, train/test indices and the matching feature/target splits.
(
    dataset,
    train_idx,
    test_idx,
    X_train,
    X_test,
    y_train_all,
    y_test_all,
) = (
    bundle.data[name]
    for name in (
        "dataset",
        "train_idx",
        "test_idx",
        "X_train",
        "X_test",
        "y_train",
        "y_test",
    )
)


Loaded ML bundle from: models/brt_chla_profiles_bundle.zip
  model_kind : pickle
  model_type : collection (dict), n_submodels=20
  example key: CHLA_0_10
  target     : log10_CHLA_A_B depth bins
  features   : 174 columns
  train/test : 4408 / 1102 rows
  dataset    : 5510 rows stored in bundle

Usage example (Python):
  bundle = load_ml_bundle('path/to/bundle.zip')
  # Predict using helper 'predict_all_depths_for_day'
  # Example: predict all depths for one day from a BRF dataset R
  pred = bundle.predict(
      R_dataset,                  # xr.DataArray/xr.Dataset with lat/lon + predictors
      brt_models=bundle.model,    # dict of models by depth bin
      feature_cols=bundle.meta['feature_cols'],
      consts={'solar_hour': 12.0, 'type': 1},
  )  # -> e.g. CHLA(time?, z, lat, lon)

  # Plot using helper 'make_plot_pred_map'
  fig, ax = bundle.plot(pred_da, pred_label='Prediction')



# Set up Dask Gateway

In [None]:
from dask.distributed import Client
from dask_gateway import Gateway

# Connect to the Gateway and fetch the cluster options it exposes.
gateway = Gateway()
options = gateway.cluster_options()

# TODO: worker sizing is still undecided, so the Gateway defaults are used.
# For reference, processing one day takes about 5 GB of RAM.
# options.worker_cores = 4
# options.worker_memory = "32GiB"

cluster = gateway.new_cluster(options)

# TODO: a fixed worker count is the alternative to adaptive scaling.
# cluster.scale(8)  # say 8 workers

# There are 560 days at roughly 30 min each, so let the cluster grow
# between 4 and 16 workers to keep the total run time reasonable.
cluster.adapt(minimum=4, maximum=16)

client = cluster.get_client()
print(cluster)
print(client)


In [None]:
import os

from dask.distributed import wait

# `rrs_results` must already hold the DAY granules returned by
# earthaccess.search_data(...) in an earlier cell.
# Subset by date here, or widen the slice (e.g. rrs_results[:100]) once the
# small test run works.
granules = rrs_results[:10]

try:
    # ml_utils.py only exists next to this notebook. Ship it to the workers so
    # that objects defined in that module (e.g. the model bundle) can be
    # unpickled there.
    if os.path.exists("ml_utils.py"):
        client.upload_file("ml_utils.py")

    # one Dask task per granule
    futures = client.map(process_one_granule, granules)

    # Block until every task has finished or failed. Unlike client.gather,
    # this does not raise on the first failure, so one bad granule cannot hide
    # the outcome of the others.
    wait(futures)

    # Same order as `granules`: the GCS URL on success, the exception otherwise.
    results = []
    for future in futures:
        if future.status == "finished":
            results.append(future.result())
        elif future.status == "error":
            results.append(future.exception())
        else:
            results.append(
                RuntimeError(f"task ended with status {future.status!r}")
            )

    print("Uploaded daily files:")
    for r in results:
        print("  ", r if isinstance(r, str) else f"FAILED: {r!r}")
finally:
    # Always release the cluster, even when something above raised, so that
    # Gateway workers are not left running.
    client.close()
    cluster.close()