# Parallel daily CHLA(z) NetCDF production using Dask-Gateway

This notebook shows my pipeline for processing NetCDF files and uploading the results to a Google Cloud Storage bucket.

[Video of me walking through this notebook](https://youtu.be/mCfMGyKEJgU)

1. Get links to the Rrs daily L3 files
2. Create a function that predicts CHLA(z) with the BRT model
3. Create a function to start the Dask Gateway cluster
4. Run!

**A few gotchas to avoid**

* This saves a temporary NetCDF to /tmp. There is not much room there (maybe around 10 GB?). If you forget to delete files as you go, you will fill it up and crash your server. If you see lots of server restarts, that is probably what happened.
* Getting authentication to work in Dask Gateway workers can be non-intuitive. Try testing your workflow in a VM where you do not have your credentials stored.
* I tried writing from cloud bucket to cloud bucket, without the temporary local write, but that was desperately slow for some reason.
* Dask workers do not have access to your home directory, so you'll need to upload any files that they need.
* The `process_one_granule()` function does not build a dask graph (there is no processing of dask-backed arrays), but if it did, it would be essential to put those computations inside `with dask.config.set(scheduler="threads"):`. If we do not, dask will try to spread the graph across all the workers and you will get cryptic serialization errors.
* Be careful not to commit any secrets or passwords to a GitHub repository. Notice how I read them in from local files that are not in the repository.
* Your Dask workers will be happiest if their image is the same as your notebook's. My image was missing google-cloud-storage, so I was trying to pip install it into my workers. That was an enormous hassle, and it was easier to create a Docker image with the packages I needed ([Docker image](https://github.com/nmfs-opensci/container-images/tree/main/images/openscapes)). With GitHub, it is easy to create and host images that you can use with Dask Gateway (or Coiled).

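The /tmp gotcha comes down to one rule: delete the temporary file even when the write or the upload fails. Below is a minimal sketch of one way to guarantee that with only the standard library. `temporary_local_file` and its `min_free_gb` check are hypothetical helpers, not part of the pipeline in this notebook.

```python
import os
import shutil
import tempfile
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def temporary_local_file(suffix=".nc", min_free_gb=1.0, directory=None):
    """Yield a fresh temp-file path and always delete it afterwards.

    Hypothetical helper: fails fast if the temp directory has less than
    `min_free_gb` GB free, instead of crashing the server mid-write.
    """
    tmp_dir = Path(directory or tempfile.gettempdir())
    free_gb = shutil.disk_usage(tmp_dir).free / 1e9
    if free_gb < min_free_gb:
        raise OSError(f"Only {free_gb:.1f} GB free in {tmp_dir}")

    # mkstemp creates a unique file; close the OS-level handle right away
    fd, name = tempfile.mkstemp(suffix=suffix, dir=tmp_dir)
    os.close(fd)
    path = Path(name)
    try:
        yield path
    finally:
        # Runs after success *and* after exceptions (e.g. a failed upload)
        path.unlink(missing_ok=True)
```

In `process_one_granule()`, the `ds_day.to_netcdf(...)` and `blob.upload_from_filename(...)` calls would both go inside `with temporary_local_file() as local_path:`.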
### Authentication to cloud buckets

Because we are writing to a cloud bucket, we need to deal with authentication and that can cause a lot of headaches when sending tasks to Dask Workers. See the notes at the bottom on how to get the authentication to work. These functions use the google-cloud-storage package to handle authentication.

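`process_one_granule()` reads the credentials JSON as text on the client, passes it to the worker as an ordinary argument, writes it to a temp file there, and points `GOOGLE_APPLICATION_CREDENTIALS` at that file (the environment variable Google's client libraries check when looking for Application Default Credentials). Below is a minimal stdlib sketch of a variation on that pattern. It assumes a worker may run several tasks as threads in one process. Environment variables are process-wide, so the sketch derives the file name from a hash of the contents: every task reuses one file instead of creating and deleting its own while other threads may still depend on it. `install_worker_credentials` is a hypothetical name. The trade-off is that the file stays in /tmp until the worker goes away.

```python
import hashlib
import os
import tempfile
from pathlib import Path


def install_worker_credentials(json_text, env_var="GOOGLE_APPLICATION_CREDENTIALS"):
    """Write credentials text to a file on this machine and point env_var at it.

    Hypothetical helper: the same text always maps to the same file name,
    so concurrent tasks on one worker share a single credentials file.
    """
    digest = hashlib.sha256(json_text.encode("utf-8")).hexdigest()[:16]
    path = Path(tempfile.gettempdir()) / f"gcp_creds_{digest}.json"

    if not path.exists():
        # mkstemp creates a file readable/writable only by this user; write to
        # it, then atomically move it into place so no reader sees a partial file
        fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
        with os.fdopen(fd, "w") as f:
            f.write(json_text)
        os.replace(tmp_name, path)

    # Process-wide setting: every thread in this worker sees the same path
    os.environ[env_var] = str(path)
    return path
```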
### What is Dask Gateway

Dask Gateway is a service that lets users create and manage Dask clusters (a scheduler plus a set of worker processes, usually running in containers or virtual machines) on shared infrastructure (often Kubernetes) from a notebook or script, without needing direct access to the underlying compute platform. It must be installed and configured by your JupyterHub or platform administrator (resource limits, images, authentication, scaling policies); users can only request clusters once that backend setup is in place.

https://coiled.io/ is a similar service that also lets you run your tasks on the cloud provider of your choice. Anyone can use Coiled; your admin does not need to set it up for you.

## Set up the `process_one_granule()` function

The cell below also fetches `ml_utils.py` and defines a few helper functions.

- Searches PACE L3M Rrs DAY granules via earthaccess
- For each granule/day:
    * downloads Rrs
    * runs BRT CHLA(z) prediction
    * computes integrated/peak metrics
    * writes a daily NetCDF locally (to /tmp)
    * uploads it to GCS, then deletes the local file
- Skips days that already exist in GCS unless FORCE_RERUN=True

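The "skip days that already exist" step only needs the expected object name for each day. In this notebook each worker checks `blob.exists()` itself. Below is a minimal sketch of doing the same filtering up front on the client, from a single listing of the bucket (for example `gcsfs` `glob` or `bucket.list_blobs(prefix=...)`), so that no task is dispatched for finished days. `expected_name` and `pending_days` are hypothetical helpers; the `chla_z_YYYYMMDD_v2.nc` naming follows `blob_path` in `process_one_granule()`.

```python
from datetime import date


def expected_name(day, prefix="CB/fish-pace-datasets/chla-z/netcdf", version="_v2"):
    """Object name for one day's output (mirrors blob_path in the pipeline)."""
    return f"{prefix}/chla_z_{day:%Y%m%d}{version}.nc"


def pending_days(days, existing_names, force_rerun=False, **name_kwargs):
    """Days whose output is not already in existing_names (all days if force_rerun)."""
    if force_rerun:
        return list(days)
    existing = set(existing_names)  # one listing, O(1) membership tests
    return [d for d in days if expected_name(d, **name_kwargs) not in existing]


days = [date(2025, 2, 5), date(2025, 2, 6)]
existing = ["CB/fish-pace-datasets/chla-z/netcdf/chla_z_20250205_v2.nc"]
print(pending_days(days, existing))  # [datetime.date(2025, 2, 6)]
```
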
In [1]:
"""
Parallel daily CHLA(z) production using Dask-Gateway.

- Searches PACE L3M Rrs DAY granules via earthaccess
- For each granule/day:
    * downloads Rrs
    * runs BRT CHLA(z) prediction
    * computes integrated/peak metrics
    * writes a daily NetCDF locally
    * uploads to GCS
- Skips days that already exist in GCS unless FORCE_RERUN=True
"""

import os
from pathlib import Path
import tempfile

import numpy as np
import pandas as pd
import xarray as xr
import earthaccess
from google.cloud import storage
from dask_gateway import Gateway
from dask.distributed import Client

# --------------------------------------------------------------------------------------
# CONFIG
# --------------------------------------------------------------------------------------

# Path to your saved ML bundle (zip) – adjust as needed
BUNDLE_PATH = "models/brt_chla_profiles_bundle.zip"
BUNDLE_FILENAME = Path(BUNDLE_PATH).name  # "brt_chla_profiles_bundle.zip"

# GCS target
BUCKET_NAME = "nmfs_odp_nwfsc"
DESTINATION_PREFIX = "CB/fish-pace-datasets/chla-z/netcdf"

# Dask-Gateway settings
MIN_WORKERS = 4
MAX_WORKERS = 12
WORKER_CORES = 4
WORKER_MEMORY = "32GiB"

# Spatial chunking for NetCDF output
LAT_CHUNK = 100
LON_CHUNK = 100

# Rerun control: if False, skip days that already exist in GCS
FORCE_RERUN = False

# Optional date filtering for rrs_results (None = no filter)
START_DATE = None  # e.g. "2024-03-01"
END_DATE   = None  # e.g. "2024-04-30"

#START_DATE = "2024-04-01"
#END_DATE   = "2024-04-02"

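import netrc

# Read Earthdata credentials from ~/.netrc here, so they can be passed to the
# workers as arguments (workers cannot see your home directory)
netrc_path = os.path.expanduser("~/.netrc")
auth = netrc.netrc(netrc_path)
login, account, password = auth.authenticators("urs.earthdata.nasa.gov")
ED_USER = login
ED_PASS = password

# Read the GCP credentials JSON as text; each worker writes it to a temp file
with open(os.path.expanduser("~/.config/gcloud/application_default_credentials.json")) as f:
    GCP_SA_JSON = f.read()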


# --------------------------------------------------------------------------------------
# Helper: load ML bundle and build CHLA profile dataset
# --------------------------------------------------------------------------------------

# Ensure ml_utils is available
if not os.path.exists("ml_utils.py"):
    import subprocess
    subprocess.run(
        [
            "wget",
            "-q",
            "https://raw.githubusercontent.com/fish-pace/chla-z-modeling/main/ml_utils.py",
        ],
        check=True,
    )

import ml_utils as mu  

#######################
# - Helper; Not needed as I use a Docker image with the packages I need
#######################
def ensure_google_cloud_storage():
    """Install google-cloud-storage on the worker if it's missing."""
    import importlib
    import subprocess
    import sys

    try:
        importlib.import_module("google.cloud.storage")
    except ImportError:
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "google-cloud-storage"]
        )


def build_chla_profile_dataset(CHLA: xr.DataArray) -> xr.Dataset:
    """
    Given CHLA(time, z, lat, lon), compute derived metrics and
    return an xr.Dataset suitable for writing to Zarr/NetCDF.
    """
    # Start from CHLA's own dataset so its coords (including z_start/z_end) win
    ds = CHLA.to_dataset(name="CHLA")

    # ---- Layer thickness (z dimension) ----
    z_start = CHLA.coords.get("z_start", None)
    z_end   = CHLA.coords.get("z_end", None)

    if (z_start is not None) and (z_end is not None):
        z_thick = (z_end - z_start).rename("z_thickness")   # (z)
    else:
        # fallback: uniform layer thickness, e.g. 10 m
        z_thick = xr.full_like(CHLA["z"], 10.0).rename("z_thickness")

    z_center = CHLA["z"]

    # total CHLA in column (used for validity + center-of-mass)
    col_total = CHLA.sum("z")          # (time, lat, lon)
    valid = col_total > 0              # True where there is some CHLA

    # ---- Integrated CHLA (nominal 0–200 m; actual range = z extent) ----
    CHLA_int = (CHLA * z_thick).sum("z")
    CHLA_int = CHLA_int.where(valid)
    CHLA_int.name = "CHLA_int_0_200"

    # ---- Peak value and depth (NaN-safe) ----
    CHLA_filled = CHLA.fillna(-np.inf)
    peak_idx = CHLA_filled.argmax("z")       # (time, lat, lon) integer indices

    CHLA_peak = CHLA.isel(z=peak_idx).where(valid)
    CHLA_peak.name = "CHLA_peak"

    CHLA_peak_depth = z_center.isel(z=peak_idx).where(valid)
    CHLA_peak_depth.name = "CHLA_peak_depth"

    # ---- Depth-weighted mean depth (center of mass) ----
    num = (CHLA * z_center).sum("z")
    den = col_total
    depth_cm = (num / den).where(valid)
    depth_cm.name = "CHLA_depth_center_of_mass"

    # ---- Attach derived fields to the dataset ----
    ds["CHLA_int_0_200"] = CHLA_int
    ds["CHLA_peak"] = CHLA_peak
    ds["CHLA_peak_depth"] = CHLA_peak_depth
    ds["CHLA_depth_center_of_mass"] = depth_cm
    ds["z_thickness"] = z_thick

    # ---- Variable attributes ----
    ds["CHLA"].attrs.setdefault("units", "mg m-3")
    ds["CHLA"].attrs.setdefault("long_name", "Chlorophyll-a concentration")
    ds["CHLA"].attrs.setdefault("standard_name", "mass_concentration_of_chlorophyll_a_in_sea_water")
    ds["CHLA"].attrs.setdefault(
        "description",
        "BRT-derived chlorophyll-a profiles from PACE hyperspectral Rrs",
    )

    ds["CHLA_int_0_200"].attrs.update(
        units="mg m-2",
        long_name="Depth-integrated chlorophyll-a",
        description=(
            "Vertical integral of CHLA over the available depth bins "
            "(nominally 0–200 m; actual range defined by z_start/z_end)."
        ),
    )

    ds["CHLA_peak"].attrs.update(
        units="mg m-3",
        long_name="Peak chlorophyll-a concentration in the water column",
        standard_name="mass_concentration_of_chlorophyll_a_in_sea_water",
        description="Maximum CHLA value over depth at each (time, lat, lon).",
    )

    ds["CHLA_peak_depth"].attrs.update(
        units="m",
        long_name="Depth of peak chlorophyll-a",
        positive="down",
        description=(
            "Depth (bin center) where CHLA is maximal in the water column "
            "at each (time, lat, lon)."
        ),
    )

    ds["CHLA_depth_center_of_mass"].attrs.update(
        units="m",
        long_name="Chlorophyll-a depth center of mass",
        positive="down",
        description=(
            "Depth of the chlorophyll-a center of mass, computed as "
            "sum_z(CHLA * z) / sum_z(CHLA)."
        ),
    )

    ds["z_thickness"].attrs.update(
        units="m",
        long_name="Layer thickness",
        description=(
            "Thickness of each vertical bin used for depth integration. "
            "Derived from z_end - z_start when available; otherwise set to a "
            "uniform nominal thickness."
        ),
    )
    ds["z_thickness"] = ds["z_thickness"].expand_dims(time=ds["time"])

    return ds


# --------------------------------------------------------------------------------------
# Worker-side function: process ONE granule/day
# --------------------------------------------------------------------------------------

from functools import partial

def process_one_granule(
    res,
    lat_chunk=LAT_CHUNK,
    lon_chunk=LON_CHUNK,
    bucket_name=BUCKET_NAME,
    destination_prefix=DESTINATION_PREFIX,
    force_rerun=FORCE_RERUN,
    ed_username=ED_USER,
    ed_password=ED_PASS,
    gcp_sa_json=GCP_SA_JSON,
    bundle_filename=BUNDLE_FILENAME,
):
    import os
    import tempfile
    import earthaccess
    import xarray as xr
    import pandas as pd
    from pathlib import Path
    import ml_utils as mu  # <- now workers can import this

    # --- ensure google-cloud-storage is available on THIS worker ---
    import importlib
    import subprocess
    import sys
    try:
        importlib.import_module("google.cloud.storage")
    except ImportError:
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "google-cloud-storage"]
        )

    from google.cloud import storage  # now this should succeed

    # --- locate the bundle file next to ml_utils.py ---
    bundle_path = Path(mu.__file__).with_name(bundle_filename)
    # just to be extra defensive:
    if not bundle_path.exists():
        raise FileNotFoundError(f"Bundle not found at {bundle_path}")

    bundle = mu.load_ml_bundle(str(bundle_path))
    
    # DEBUG Load bundle on the worker from the uploaded zip file
    #bundle = mu.load_ml_bundle(bundle_filename)

    # --- EARTHACCESS AUTH VIA ENV VARS (inside worker) ---
    if ed_username is not None and ed_password is not None:
        os.environ["EARTHDATA_USERNAME"] = ed_username
        os.environ["EARTHDATA_PASSWORD"] = ed_password

    auth = earthaccess.login(strategy="environment", persist=False)

    # --- GCP AUTH VIA JSON TEXT (inside worker) ---
    import uuid

    cred_path = None
    if gcp_sa_json:
        cred_path = os.path.join(tempfile.gettempdir(), f"gcp_sa_worker_{uuid.uuid4().hex}.json")
        with open(cred_path, "w") as f:
            f.write(gcp_sa_json)
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path

    # -------------------------------
    #  Normal per-day pipeline below
    # -------------------------------
    day_iso = res["umm"]["TemporalExtent"]["RangeDateTime"]["BeginningDateTime"]
    day = pd.to_datetime(day_iso)
    day_str = day.strftime("%Y%m%d")

    rrs_ds = None
    local_path = None
    try:
        storage_client = storage.Client(project="noaa-gcs-public-data")
        bucket = storage_client.bucket(bucket_name)
        blob_path = f"{destination_prefix}/chla_z_{day_str}_v2.nc"
        blob = bucket.blob(blob_path)

        if blob.exists() and not force_rerun:
            msg = f"[{day_str}] SKIP (exists at gs://{bucket_name}/{blob_path})"
            print(msg)
            return msg

        files = earthaccess.open([res], pqdm_kwargs={"disable": True})
        rrs_ds = xr.open_dataset(files[0])
        # DEBUG
        # rrs_ds = rrs_ds.sel(lat=slice(40, 20), lon=slice(-70, -60) )

        if "time" in rrs_ds.dims:
            R = rrs_ds["Rrs"].sel(time=day).squeeze("time")
        else:
            R = rrs_ds["Rrs"]
        R = R.transpose("lat", "lon", "wavelength")

        pred = bundle.predict(
            R,
            brt_models=bundle.model,
            feature_cols=bundle.meta["feature_cols"],
            consts={"solar_hour": 0, "type": 1},
            chunk_size_lat=100,
            time=day.to_datetime64(),
            z_name="z",
            silent=True,
            linear=True,
        )

        ds_day = build_chla_profile_dataset(pred)

        tmp_dir = Path(tempfile.gettempdir())
        local_path = tmp_dir / f"chla_z_{day_str}.nc"

        # Fix chunking: one time step and the full z column per chunk
        chunks4d = (1, ds_day.sizes["z"], lat_chunk, lon_chunk)  # CHLA (time, z, lat, lon)
        chunks3d = (1, lat_chunk, lon_chunk)                      # metrics (time, lat, lon)
        chunks2d = (1, ds_day.sizes["z"])                         # z_thickness (time, z)
        encoding = {
            "CHLA": {"dtype": "float32", "zlib": True, "complevel": 4, "chunksizes": chunks4d},
            "CHLA_int_0_200": {"dtype": "float32", "zlib": True, "complevel": 4, "chunksizes": chunks3d},
            "CHLA_peak": {"dtype": "float32", "zlib": True, "complevel": 4, "chunksizes": chunks3d},
            "CHLA_peak_depth": {"dtype": "float32", "zlib": True, "complevel": 4, "chunksizes": chunks3d},
            "CHLA_depth_center_of_mass": {"dtype": "float32", "zlib": True, "complevel": 4, "chunksizes": chunks3d},
            "z_thickness": {"dtype": "float32", "zlib": True, "complevel": 4, "chunksizes": chunks2d},
        }

        ds_day.to_netcdf(local_path, engine="h5netcdf", encoding=encoding)
        blob.upload_from_filename(str(local_path))

        gcs_url = f"gs://{bucket_name}/{blob_path}"
        msg = f"[{day_str}] WROTE {gcs_url}"
        print(msg)
        return msg

    finally:
        if rrs_ds is not None:
            rrs_ds.close()
        # This is crucial: remove the file from /tmp even if the write or upload
        # fails. Otherwise /tmp will fill up and the server will crash.
        if local_path is not None:
            local_path.unlink(missing_ok=True)
        # Clean up the creds file (also on the SKIP path)
        if cred_path is not None:
            try:
                os.remove(cred_path)
            except FileNotFoundError:
                pass

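To see what `build_chla_profile_dataset()` computes for a single water column, here is a NumPy sketch of the same four metrics on one made-up profile. The depths, the 10 m layer thickness, and the CHLA values are illustrative, not PACE output. The rules mirror the xarray code: NaN bins are skipped in sums, a column with no CHLA gets NaN everywhere, and the center of mass is `sum_z(CHLA * z) / sum_z(CHLA)`. `column_metrics` is a hypothetical helper.

```python
import numpy as np


def column_metrics(chla, z_center, z_thick):
    """NaN-aware metrics for ONE water column, mirroring build_chla_profile_dataset()."""
    chla = np.asarray(chla, dtype=float)
    z_center = np.asarray(z_center, dtype=float)
    z_thick = np.asarray(z_thick, dtype=float)

    total = np.nansum(chla)          # like CHLA.sum("z"), which skips NaN
    if not total > 0:                # same rule as valid = col_total > 0
        return {k: np.nan for k in ("CHLA_int_0_200", "CHLA_peak",
                                    "CHLA_peak_depth", "CHLA_depth_center_of_mass")}

    # like CHLA.fillna(-np.inf).argmax("z"): a NaN bin can never be the peak
    peak_idx = int(np.argmax(np.where(np.isnan(chla), -np.inf, chla)))
    return {
        "CHLA_int_0_200": np.nansum(chla * z_thick),                       # mg m-2
        "CHLA_peak": chla[peak_idx],                                       # mg m-3
        "CHLA_peak_depth": z_center[peak_idx],                             # m
        "CHLA_depth_center_of_mass": np.nansum(chla * z_center) / total,   # m
    }


# Illustrative profile: 10 m bins centered at 5..45 m, deepest bin missing
z = np.array([5.0, 15.0, 25.0, 35.0, 45.0])
dz = np.full(5, 10.0)
metrics = column_metrics([0.2, 0.5, 1.0, 0.3, np.nan], z, dz)
```

For this profile the integral is 10 × (0.2 + 0.5 + 1.0 + 0.3) = 20 mg m-2, the peak is 1.0 mg m-3 at 25 m, and the center of mass is (1 + 7.5 + 25 + 10.5) / 2.0 = 22 m.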

## Create the driver function that starts the Dask Gateway cluster and dispatches the tasks

In [None]:
# --------------------------------------------------------------------------------------
# DRIVER: search granules, filter, and dispatch via Dask-Gateway
# --------------------------------------------------------------------------------------

def main():
    # 1. Earthaccess login on client
    auth = earthaccess.login(strategy="netrc", persist=True)
    if not auth.authenticated:
        raise RuntimeError("earthaccess login failed")

    # 2. Search PACE L3M Rrs daily granules
    rrs_results = earthaccess.search_data(
        short_name="PACE_OCI_L3M_RRS",
        granule_name="*.DAY.*.4km.nc",
        temporal=(START_DATE, END_DATE),
    )

    print(f"Found {len(rrs_results)} DAY granules after date filter.")
    if not rrs_results:
        print("Nothing to do.")
        return

    # 3. Dask-Gateway cluster setup
    gateway = Gateway()
    options = gateway.cluster_options()
    # Option names and allowed values are set by your hub's Dask Gateway config;
    # inspect `options` to see them. (WORKER_CORES / WORKER_MEMORY above are not used here.)
    options.worker_resource_allocation = "4CPU, 30.2Gi"

    cluster = gateway.new_cluster(options)
    cluster.adapt(minimum=MIN_WORKERS, maximum=MAX_WORKERS)

    client = cluster.get_client()
    print(cluster)
    print(client)

    # Dashboard link (copy/paste into a browser tab)
    print("Dask dashboard:", client.dashboard_link)

    # Make sure workers have the files they need
    client.upload_file("ml_utils.py")
    client.upload_file(BUNDLE_PATH)

    # Install google-cloud-storage on the workers (not needed with a custom image).
    # Note: client.run() only reaches workers that are already running.
    client.run(ensure_google_cloud_storage)

    # 4. Dispatch one task per granule
    futures = client.map(process_one_granule, rrs_results)

    # 5. Stream results as they complete (instead of blocking on gather)
    from dask.distributed import as_completed

    n = len(futures)
    done = 0
    errors = 0

    try:
        for fut in as_completed(futures):
            try:
                msg = fut.result()
                done += 1
                print(f"[{done}/{n}] {msg}")
            except Exception as e:
                errors += 1
                done += 1
                print(f"[{done}/{n}] ERROR: {repr(e)}")
                # If you want to stop on first error, uncomment:
                # raise
    finally:
        print(f"Finished. Success={done - errors}, Errors={errors}")
        client.close()
        cluster.close()

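Before paying for a cluster, it can help to smoke-test `process_one_granule()` on one or two granules locally. Below is a minimal sketch of a local stand-in for the Dask loop above. It uses the standard library's `concurrent.futures` (a swapped-in technique, not Dask) and keeps the same stream-as-completed and success/error tally logic. `run_locally` is a hypothetical helper.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_locally(func, items, max_workers=2):
    """Run func over items in a thread pool, printing progress as tasks finish.

    Returns (n_success, n_errors), like the summary printed by main().
    """
    n = len(items)
    done = errors = 0
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(func, item) for item in items]
        for fut in as_completed(futures):
            done += 1
            try:
                msg = fut.result()
                print(f"[{done}/{n}] {msg}")
            except Exception as e:
                errors += 1
                print(f"[{done}/{n}] ERROR: {e!r}")
    return done - errors, errors
```

For example, `run_locally(process_one_granule, rrs_results[:2], max_workers=1)`. One thread is the safer choice here because `process_one_granule()` sets process-wide environment variables.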
In [None]:
# took 10 hours for 560 files
if __name__ == "__main__": main()

Found 560 DAY granules after date filter.
GatewayCluster<prod.c0c526860a59442a82b1cb067f3c5a4e, status=running>
<Client: 'tls://192.168.36.66:8786' processes=0 threads=0, memory=0 B>
Dask dashboard: /services/dask-gateway/clusters/prod.c0c526860a59442a82b1cb067f3c5a4e/status


In [1]:
# test if all the files are there
import gcsfs
import numpy as np
import pandas as pd

token = "/home/jovyan/.config/gcloud/application_default_credentials.json"
fs = gcsfs.GCSFileSystem(token=token)
# all .nc
base = "nmfs_odp_nwfsc/CB/fish-pace-datasets/chla-z/netcdf"
all_nc = sorted(fs.glob(f"{base}/chla_z_*.nc"))
# keep only v1 (exclude anything with _v2 anywhere in the name)
paths_v1 = ["gcs://" + p for p in all_nc if "_v2" not in p]
paths_v2 = ["gcs://" + p for p in all_nc if "_v2" in p]
print("files:", len(paths_v1), "first:", paths_v1[0])
print("files:", len(paths_v2), "first:", paths_v2[0])

day_strs_v1 = [p.split("chla_z_")[1].split(".nc")[0] for p in paths_v1]
all_times_v1 = np.array(
    pd.to_datetime(sorted(set(day_strs_v1)), format="%Y%m%d"),
    dtype="datetime64[ns]"
)
day_strs_v2 = [p.split("chla_z_")[1].split("_v2")[0] for p in paths_v2]
all_times_v2 = np.array(
    pd.to_datetime(sorted(set(day_strs_v2)), format="%Y%m%d"),
    dtype="datetime64[ns]"
)
missing = [x for x in all_times_v1 if x not in all_times_v2]
missing

files: 560 first: gcs://nmfs_odp_nwfsc/CB/fish-pace-datasets/chla-z/netcdf/chla_z_20240305.nc
files: 559 first: gcs://nmfs_odp_nwfsc/CB/fish-pace-datasets/chla-z/netcdf/chla_z_20240305_v2.nc


[np.datetime64('2025-02-06T00:00:00.000000000')]

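The same check can be done with one regular expression and a set difference instead of two `split`-based comprehensions. Below is a minimal sketch, assuming the `chla_z_YYYYMMDD.nc` / `chla_z_YYYYMMDD_v2.nc` naming shown above. `missing_days` is a hypothetical helper.

```python
import re
from datetime import datetime

# chla_z_20240305.nc -> ("20240305", None); chla_z_20240305_v2.nc -> ("20240305", "_v2")
NAME_RE = re.compile(r"chla_z_(\d{8})(_v2)?\.nc$")


def missing_days(paths, have="", want="_v2"):
    """Sorted dates that have a `have` file but no `want` file ("" or "_v2")."""
    found = {"": set(), "_v2": set()}
    for p in paths:
        m = NAME_RE.search(p)
        if m:  # ignore anything that doesn't follow the naming scheme
            found[m.group(2) or ""].add(m.group(1))
    return sorted(datetime.strptime(d, "%Y%m%d").date() for d in found[have] - found[want])
```

With the listing above, `missing_days(fs.glob(f"{base}/chla_z_*.nc"))` would return the v1 days that still lack a `_v2` file.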
In [None]:
# Rerun the day(s) that are missing a _v2 file (main() reads START_DATE/END_DATE when called)
START_DATE = '20250206'
END_DATE   = '20250206'
if __name__ == "__main__": main()

Found 1 DAY granules after date filter.
GatewayCluster<prod.6cf49325472e4ac4a66bd3af73c1a4bc, status=running>
<Client: 'tls://192.168.53.134:8786' processes=0 threads=0, memory=0 B>
Dask dashboard: /services/dask-gateway/clusters/prod.6cf49325472e4ac4a66bd3af73c1a4bc/status
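
One thing to watch when rerunning: `process_one_granule()` takes `force_rerun=FORCE_RERUN` (and the other settings) as default arguments, and Python evaluates defaults once, when the `def` runs. Reassigning `FORCE_RERUN = True` in a later cell therefore does not change what the workers receive unless the first cell is re-run. `START_DATE`/`END_DATE` behave differently because `main()` reads them as globals at call time. One option is to bind the value explicitly with `functools.partial`, which the first cell already imports, e.g. `client.map(partial(process_one_granule, force_rerun=True), rrs_results)` inside `main()`. The sketch below uses a hypothetical stand-in function to show the behavior.

```python
from functools import partial

FORCE_RERUN = False


def process(item, force_rerun=FORCE_RERUN):
    """Stand-in for process_one_granule(): the default is fixed at def time."""
    return (item, force_rerun)


FORCE_RERUN = True                    # reassigning the global later...
print(process("20250206"))            # ('20250206', False): default unchanged

process_forced = partial(process, force_rerun=True)
print(process_forced("20250206"))     # ('20250206', True)
```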
