In [2]:
!pip install google-cloud-storage

Collecting google-cloud-storage
  Downloading google_cloud_storage-3.7.0-py3-none-any.whl.metadata (14 kB)
Collecting google-auth<3.0.0,>=2.26.1 (from google-cloud-storage)
  Downloading google_auth-2.43.0-py2.py3-none-any.whl.metadata (6.6 kB)
Collecting google-api-core<3.0.0,>=2.27.0 (from google-cloud-storage)
  Downloading google_api_core-2.28.1-py3-none-any.whl.metadata (3.3 kB)
Collecting google-cloud-core<3.0.0,>=2.4.2 (from google-cloud-storage)
  Downloading google_cloud_core-2.5.0-py3-none-any.whl.metadata (3.1 kB)
Collecting google-resumable-media<3.0.0,>=2.7.2 (from google-cloud-storage)
  Downloading google_resumable_media-2.8.0-py3-none-any.whl.metadata (2.6 kB)
Collecting google-crc32c<2.0.0,>=1.1.3 (from google-cloud-storage)
  Downloading google_crc32c-1.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Collecting googleapis-common-protos<2.0.0,>=1.56.2 (from google-api-core<3.0.0,>=2.27.0->google-cloud-storage)
  Downloading googleapis_c

In [1]:
"""
Parallel daily CHLA(z) production using Dask-Gateway.

- Searches PACE L3M Rrs DAY granules via earthaccess
- For each granule/day:
    * downloads Rrs
    * runs BRT CHLA(z) prediction
    * computes integrated/peak metrics
    * writes a daily NetCDF locally
    * uploads to GCS
- Skips days that already exist in GCS unless FORCE_RERUN=True
"""

import os
from pathlib import Path
import tempfile

import numpy as np
import pandas as pd
import xarray as xr
import earthaccess
from google.cloud import storage
from dask_gateway import Gateway
from dask.distributed import Client

# --------------------------------------------------------------------------------------
# CONFIG
# --------------------------------------------------------------------------------------

# Trained BRT model bundle (zip archive); a relative path, adjust as needed.
BUNDLE_PATH = "models/chla_brt_bundle.zip"

# Output location: gs://<BUCKET_NAME>/<DESTINATION_PREFIX>/chla_z_YYYYMMDD.nc
BUCKET_NAME = "nmfs_odp_nwfsc"
DESTINATION_PREFIX = "CB/fish-pace-datasets/chla-z/netcdf"

# Dask-Gateway sizing: the cluster adapts between MIN_WORKERS and MAX_WORKERS,
# each worker requesting WORKER_CORES cores and WORKER_MEMORY of RAM.
MIN_WORKERS, MAX_WORKERS = 4, 12
WORKER_CORES, WORKER_MEMORY = 4, "32GiB"

# Chunk length (grid cells) along lat and lon for the CHLA variable in NetCDF.
LAT_CHUNK = LON_CHUNK = 100

# When False, days whose output file is already in GCS are skipped.
FORCE_RERUN = False

# Inclusive date window applied to the granule search results; None = unbounded.
START_DATE = None  # e.g. "2024-03-01"
END_DATE = None  # e.g. "2024-04-30"

# --------------------------------------------------------------------------------------
# Helper: load ML bundle and build CHLA profile dataset
# --------------------------------------------------------------------------------------

# Ensure ml_utils is available (fetched from the fish-pace tutorials repository).
ML_UTILS_URL = (
    "https://raw.githubusercontent.com/fish-pace/2025-tutorials/main/ml_utils.py"
)

if not os.path.exists("ml_utils.py"):
    # Pure-Python download: no dependency on an external `wget` binary.
    # Fetch to a temporary name and rename afterwards, so an interrupted download
    # never leaves a truncated ml_utils.py that the exists() check would then
    # accept on every later run.
    import urllib.request

    _ml_utils_partial = "ml_utils.py.part"
    urllib.request.urlretrieve(ML_UTILS_URL, _ml_utils_partial)  # raises on HTTP errors
    os.replace(_ml_utils_partial, "ml_utils.py")

import ml_utils as mu  # noqa: E402

# Load the bundle once on the client side; process_one_granule references this
# global, so workers receive it via pickling.
# NOTE(review): if the bundle's objects are defined in ml_utils, workers must
# also be able to import ml_utils to unpickle them — confirm.
bundle = mu.load_ml_bundle(BUNDLE_PATH)


def build_chla_profile_dataset(
    CHLA: xr.DataArray, default_thickness: float = 10.0
) -> xr.Dataset:
    """
    Bundle CHLA profiles with derived water-column metrics.

    Parameters
    ----------
    CHLA : xr.DataArray
        Chlorophyll-a profiles with a ``z`` dimension whose coordinate holds
        bin-center depths. Optional ``z_start``/``z_end`` coordinates give the
        bin edges. Typically dims are (time, z, lat, lon); only ``z`` is
        required here, and the other dims are carried through.
    default_thickness : float, optional
        Uniform layer thickness (m) used when ``z_start``/``z_end`` are not
        both present. Defaults to 10.0 (the previous hard-coded value).

    Returns
    -------
    xr.Dataset
        The CHLA variable plus CHLA_int_0_200, CHLA_peak, CHLA_peak_depth,
        CHLA_depth_center_of_mass and z_thickness, with attributes set.
        Derived metrics are NaN wherever the column's CHLA sum is not > 0.
    """
    # Start from CHLA's own dataset so its coords (including z_start/z_end) win
    ds = CHLA.to_dataset(name="CHLA")

    # ---- Layer thickness (z dimension) ----
    z_start = CHLA.coords.get("z_start", None)
    z_end = CHLA.coords.get("z_end", None)

    if (z_start is not None) and (z_end is not None):
        z_thick = (z_end - z_start).rename("z_thickness")  # (z)
    else:
        # Fallback: uniform layer thickness. Force float so a non-integer
        # default is not truncated when the z coordinate has an integer dtype.
        z_thick = xr.full_like(CHLA["z"], default_thickness, dtype=float).rename(
            "z_thickness"
        )

    z_center = CHLA["z"]

    # Total CHLA in the column (NaNs skipped); used for validity + center of mass.
    col_total = CHLA.sum("z")
    valid = col_total > 0  # True where there is some CHLA

    # ---- Integrated CHLA (nominal 0–200 m; actual range = z extent) ----
    CHLA_int = (CHLA * z_thick).sum("z")
    CHLA_int = CHLA_int.where(valid)
    CHLA_int.name = "CHLA_int_0_200"

    # ---- Peak value and depth (NaN-safe) ----
    # NaN -> -inf so argmax ignores missing bins; all-NaN columns yield index 0
    # and are masked out by `valid` below.
    CHLA_filled = CHLA.fillna(-np.inf)
    peak_idx = CHLA_filled.argmax("z")  # integer indices, z dimension removed

    CHLA_peak = CHLA.isel(z=peak_idx).where(valid)
    CHLA_peak.name = "CHLA_peak"

    CHLA_peak_depth = z_center.isel(z=peak_idx).where(valid)
    CHLA_peak_depth.name = "CHLA_peak_depth"

    # ---- Depth-weighted mean depth (center of mass) ----
    # Mask the denominator *before* dividing: invalid columns become NaN without
    # triggering 0/0 "invalid value" RuntimeWarnings (same result as masking after).
    num = (CHLA * z_center).sum("z")
    depth_cm = num / col_total.where(valid)
    depth_cm.name = "CHLA_depth_center_of_mass"

    # ---- Attach derived fields to the dataset ----
    ds["CHLA_int_0_200"] = CHLA_int
    ds["CHLA_peak"] = CHLA_peak
    ds["CHLA_peak_depth"] = CHLA_peak_depth
    ds["CHLA_depth_center_of_mass"] = depth_cm
    ds["z_thickness"] = z_thick

    # ---- Variable attributes ----
    ds["CHLA"].attrs.setdefault("units", "mg m-3")
    ds["CHLA"].attrs.setdefault("long_name", "Chlorophyll-a concentration")
    ds["CHLA"].attrs.setdefault(
        "description",
        "BRT-derived chlorophyll-a profiles from PACE hyperspectral Rrs",
    )

    ds["CHLA_int_0_200"].attrs.update(
        units="mg m-2",
        long_name="Depth-integrated chlorophyll-a",
        description=(
            "Vertical integral of CHLA over the available depth bins "
            "(nominally 0–200 m; actual range defined by z_start/z_end)."
        ),
    )

    ds["CHLA_peak"].attrs.update(
        units="mg m-3",
        long_name="Peak chlorophyll-a concentration in the water column",
        description="Maximum CHLA value over depth at each (time, lat, lon).",
    )

    ds["CHLA_peak_depth"].attrs.update(
        units="m",
        long_name="Depth of peak chlorophyll-a",
        positive="down",
        description=(
            "Depth (bin center) where CHLA is maximal in the water column "
            "at each (time, lat, lon)."
        ),
    )

    ds["CHLA_depth_center_of_mass"].attrs.update(
        units="m",
        long_name="Chlorophyll-a depth center of mass",
        positive="down",
        description=(
            "Depth of the chlorophyll-a center of mass, computed as "
            "sum_z(CHLA * z) / sum_z(CHLA)."
        ),
    )

    ds["z_thickness"].attrs.update(
        units="m",
        long_name="Layer thickness",
        description=(
            "Thickness of each vertical bin used for depth integration. "
            "Derived from z_end - z_start when available; otherwise set to a "
            "uniform nominal thickness."
        ),
    )

    return ds


# --------------------------------------------------------------------------------------
# Worker-side function: process ONE granule/day
# --------------------------------------------------------------------------------------

def process_one_granule(
    res,
    lat_chunk=LAT_CHUNK,
    lon_chunk=LON_CHUNK,
    bucket_name=BUCKET_NAME,
    destination_prefix=DESTINATION_PREFIX,
    force_rerun=FORCE_RERUN,
):
    """
    Run the full pipeline for a single PACE L3M Rrs DAY granule.

    Steps:
      - check if the daily NetCDF already exists in GCS (skip if so and not force_rerun)
      - open Rrs via earthaccess
      - run BRT CHLA(z) prediction with the module-level ``bundle``
      - compute derived metrics (build_chla_profile_dataset)
      - write the daily NetCDF into a private temporary directory
      - upload the NetCDF to GCS

    Parameters
    ----------
    res : earthaccess search result
        Only ``res["umm"]["TemporalExtent"]["RangeDateTime"]["BeginningDateTime"]``
        is read directly; the object is otherwise handed to ``earthaccess.open``.
    lat_chunk, lon_chunk : int
        Requested NetCDF chunk lengths along lat/lon for CHLA. They are clamped
        to the actual dimension sizes, because HDF5 rejects chunks larger than a
        fixed-size dimension.
    bucket_name, destination_prefix : str
        Output goes to ``gs://{bucket_name}/{destination_prefix}/chla_z_YYYYMMDD.nc``.
    force_rerun : bool
        If True, recompute and overwrite even when the output blob exists.

    Returns
    -------
    str
        Message with status and GCS path (or SKIP info).

    Raises
    ------
    RuntimeError
        If earthaccess returns no file for the granule.
    """
    # Imports live inside the function so it is self-contained when executed
    # on a Dask worker.
    import tempfile
    from pathlib import Path

    import earthaccess
    import pandas as pd
    import xarray as xr
    from google.cloud import storage

    # Day from UMM metadata. A trailing "Z" makes pandas return a tz-aware
    # Timestamp, which cannot be compared with the tz-naive datetime64 time
    # coordinate xarray decodes; normalize to tz-naive UTC.
    day = pd.to_datetime(
        res["umm"]["TemporalExtent"]["RangeDateTime"]["BeginningDateTime"]
    )
    if day.tzinfo is not None:
        day = day.tz_convert("UTC").tz_localize(None)
    day_str = day.strftime("%Y%m%d")

    # Destination blob for this day's file
    storage_client = storage.Client(project="noaa-gcs-public-data")
    bucket = storage_client.bucket(bucket_name)
    blob_path = f"{destination_prefix}/chla_z_{day_str}.nc"
    blob = bucket.blob(blob_path)
    gcs_url = f"gs://{bucket_name}/{blob_path}"

    # Only query GCS when the answer can actually cause a skip.
    if not force_rerun and blob.exists():
        msg = f"[{day_str}] SKIP (already exists at {gcs_url})"
        print(msg)
        return msg

    # Earthaccess auth on the worker.
    # NOTE(review): assumes Earthdata credentials (e.g. ~/.netrc) are visible to
    # the worker process, and that the installed earthaccess.open() accepts the
    # `auth` keyword — confirm both for your deployment.
    auth = earthaccess.login(persist=True)

    files = earthaccess.open([res], auth=auth, pqdm_kwargs={"disable": True})
    if not files:
        raise RuntimeError(f"[{day_str}] earthaccess.open returned no files")

    try:
        # Keep the Rrs dataset open until the NetCDF has been written, in case
        # the prediction is evaluated lazily from it.
        with xr.open_dataset(files[0]) as rrs_ds:
            rrs = rrs_ds["Rrs"]
            if "time" in rrs.dims:
                # A scalar .sel() already drops the time dimension (no squeeze).
                rrs = rrs.sel(time=day)
            rrs = rrs.transpose("lat", "lon", "wavelength")

            # CHLA(z) prediction for this day (uses bundle.predict)
            pred = bundle.predict(
                rrs,
                brt_models=bundle.model,
                feature_cols=bundle.meta["feature_cols"],
                consts={"solar_hour": 0, "type": 1},
                chunk_size_lat=100,
                time=day.to_datetime64(),
                z_name="z",
                silent=True,
            )  # (time=1, z, lat, lon), float32

            ds_day = build_chla_profile_dataset(pred)

            # Add/override metadata
            ds_day["CHLA"].attrs.update(
                units="mg m-3",
                long_name="Chlorophyll-a concentration",
                description="BRT-derived CHLA profiles from PACE hyperspectral Rrs",
            )
            ds_day["z"].attrs.update(units="m", long_name="depth (bin center)")
            ds_day["lat"].attrs.update(units="degrees_north")
            ds_day["lon"].attrs.update(units="degrees_east")
            ds_day.attrs["source"] = "BRT model trained on BGC-Argo + OOI matchups"
            ds_day.attrs["model_bundle"] = Path(BUNDLE_PATH).name

            # Chunk sizes built from CHLA's actual dims (not a fixed order):
            # one time step, the full z column, and lat/lon tiles clamped to
            # the dimension sizes.
            chla = ds_day["CHLA"]
            preferred = {"time": 1, "lat": lat_chunk, "lon": lon_chunk}
            chunksizes = tuple(
                max(1, min(preferred.get(dim, size), size))
                for dim, size in zip(chla.dims, chla.shape)
            )
            encoding = {
                "CHLA": {
                    "dtype": "float32",
                    "zlib": True,
                    "complevel": 4,
                    "chunksizes": chunksizes,
                }
            }

            # A private temp dir avoids name collisions between tasks on the same
            # node and is removed even if the write or the upload fails.
            with tempfile.TemporaryDirectory(prefix=f"chla_z_{day_str}_") as tmp_dir:
                local_path = Path(tmp_dir) / f"chla_z_{day_str}.nc"
                ds_day.to_netcdf(
                    local_path,
                    engine="h5netcdf",
                    encoding=encoding,
                )
                blob.upload_from_filename(str(local_path))
    finally:
        # Release the remote file handles returned by earthaccess.open.
        for fobj in files:
            close = getattr(fobj, "close", None)
            if callable(close):
                close()

    msg = f"[{day_str}] WROTE {gcs_url}"
    print(msg)
    return msg


# --------------------------------------------------------------------------------------
# DRIVER: search granules, filter, and dispatch via Dask-Gateway
# --------------------------------------------------------------------------------------

def main():
    """
    Search PACE L3M Rrs DAY granules, filter them by date, and process each on Dask-Gateway.

    The optional START_DATE/END_DATE window is inclusive. The cluster and client
    are always closed, even when a step fails. Per-granule task failures are
    collected and reported together, so one bad day does not hide the outcome of
    the others; they are then raised as a single RuntimeError.

    Raises
    ------
    RuntimeError
        If the client-side earthaccess login fails, or if any granule task failed.
    """

    def to_naive_utc(value):
        """Parse ``value`` to a pandas Timestamp expressed as tz-naive UTC."""
        ts = pd.to_datetime(value)
        if ts.tzinfo is not None:
            ts = ts.tz_convert("UTC").tz_localize(None)
        return ts

    def granule_day(res):
        """Granule start time from its UMM metadata, as tz-naive UTC."""
        iso = res["umm"]["TemporalExtent"]["RangeDateTime"]["BeginningDateTime"]
        return to_naive_utc(iso)

    # 1. Earthaccess login on client
    auth = earthaccess.login(persist=True)
    if not auth.authenticated:
        raise RuntimeError("earthaccess login failed")

    # 2. Search PACE L3M Rrs daily granules
    rrs_results = earthaccess.search_data(
        short_name="PACE_OCI_L3M_RRS",
        granule_name="*.DAY.*.4km.nc",
    )

    # 3. Optional date filtering. Both sides are normalized to tz-naive UTC:
    #    UMM times may carry a "Z" (tz-aware) while START_DATE/END_DATE are
    #    usually naive, and pandas raises TypeError when comparing the two.
    if START_DATE is not None:
        start = to_naive_utc(START_DATE)
        rrs_results = [r for r in rrs_results if granule_day(r) >= start]

    if END_DATE is not None:
        end = to_naive_utc(END_DATE)
        rrs_results = [r for r in rrs_results if granule_day(r) <= end]

    print(f"Found {len(rrs_results)} DAY granules after date filter.")

    if not rrs_results:
        print("Nothing to do.")
        return

    # 4. Dask-Gateway cluster setup
    gateway = Gateway()
    options = gateway.cluster_options()

    # These options exist only on some deployments; the hasattr guards skip them
    # otherwise (set resources via the Gateway UI in that case).
    if hasattr(options, "worker_cores"):
        options.worker_cores = WORKER_CORES
    if hasattr(options, "worker_memory"):
        options.worker_memory = WORKER_MEMORY

    n_failed = 0
    cluster = gateway.new_cluster(options)
    try:
        cluster.adapt(minimum=MIN_WORKERS, maximum=MAX_WORKERS)
        client = cluster.get_client()
        try:
            print(cluster)
            print(client)

            # Ship ml_utils.py to the workers. NOTE(review): `bundle` is pickled
            # along with process_one_granule; if its objects are defined in
            # ml_utils, workers cannot unpickle them without this module. Recent
            # `distributed` releases also deliver uploaded files to workers that
            # join later (adaptive scaling) — confirm for your version.
            client.upload_file("ml_utils.py")

            # 5. Dispatch one task per granule
            futures = client.map(process_one_granule, rrs_results)

            # 6. Wait for every task; record failures instead of aborting on the
            #    first one so the remaining days still finish and get reported.
            summaries = []
            for res, fut in zip(rrs_results, futures):
                try:
                    summaries.append(fut.result())
                except Exception as exc:
                    n_failed += 1
                    summaries.append(f"[{granule_day(res):%Y%m%d}] FAILED: {exc!r}")

            print("Pipeline complete. Task summaries:")
            for summary in summaries:
                print("  ", summary)
        finally:
            client.close()
    finally:
        # Always shut the cluster down, even when a step above raised.
        cluster.close()

    if n_failed:
        raise RuntimeError(
            f"{n_failed} of {len(rrs_results)} granule task(s) failed; "
            "see the task summaries above."
        )


if __name__ == "__main__":
    main()


ModuleNotFoundError: No module named 'google'