In [1]:
!pip install google-cloud-storage



In [1]:
from google.cloud import storage

In [2]:
def ensure_google_cloud_storage():
    """
    Install google-cloud-storage into the running interpreter if it's missing.

    Does nothing when ``google.cloud.storage`` already imports, so it is safe
    to call repeatedly.
    """
    import importlib
    import subprocess
    import sys

    try:
        importlib.import_module("google.cloud.storage")
    except ImportError:
        # Use this interpreter's pip so the package lands in the environment
        # this process imports from.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "google-cloud-storage"]
        )
        # A package installed after start-up can be hidden by the import
        # system's cached directory listings until they are invalidated.
        importlib.invalidate_caches()


ensure_google_cloud_storage()

In [3]:
"""
Parallel daily CHLA(z) production using Dask-Gateway.

- Searches PACE L3M Rrs DAY granules via earthaccess
- For each granule/day:
    * downloads Rrs
    * runs BRT CHLA(z) prediction
    * computes integrated/peak metrics
    * writes a daily NetCDF locally
    * uploads to GCS
- Skips days that already exist in GCS unless FORCE_RERUN=True
"""

import os
from pathlib import Path
import tempfile

import numpy as np
import pandas as pd
import xarray as xr
import earthaccess
from google.cloud import storage
from dask_gateway import Gateway
from dask.distributed import Client

# --------------------------------------------------------------------------------------
# CONFIG
# --------------------------------------------------------------------------------------

# Path to your saved ML bundle (zip) – adjust as needed
# Path to your saved ML bundle (zip) – adjust as needed
BUNDLE_PATH = "models/brt_chla_profiles_bundle.zip"
BUNDLE_FILENAME = Path(BUNDLE_PATH).name  # "brt_chla_profiles_bundle.zip"

# GCS target
BUCKET_NAME = "nmfs_odp_nwfsc"
DESTINATION_PREFIX = "CB/fish-pace-datasets/chla-z/netcdf"

# Dask-Gateway settings
MIN_WORKERS = 4
MAX_WORKERS = 12
# NOTE(review): WORKER_CORES / WORKER_MEMORY are not read by the code in this
# cell; main() requests workers through a "worker_resource_allocation" string.
WORKER_CORES = 4
WORKER_MEMORY = "32GiB"

# Spatial chunking for NetCDF output
LAT_CHUNK = 100
LON_CHUNK = 100

# Rerun control: if False, skip days that already exist in GCS
FORCE_RERUN = False

# Optional date filtering for rrs_results (None = no filter)
START_DATE = None  # e.g. "2024-03-01"
END_DATE   = None  # e.g. "2024-04-30"

#START_DATE = "2024-04-01"
#END_DATE   = "2024-04-02"

import netrc
import json

# --- Credentials: read once on the client and baked into the defaults of
# --- process_one_granule so workers receive them with each task.
netrc_path = os.path.expanduser("~/.netrc")
auth = netrc.netrc(netrc_path)
_earthdata_entry = auth.authenticators("urs.earthdata.nasa.gov")
if _earthdata_entry is None:
    # netrc returns None for an unknown host; fail with a clear message
    # instead of a TypeError from tuple unpacking.
    raise RuntimeError(
        f"No 'machine urs.earthdata.nasa.gov' entry in {netrc_path}; "
        "add your Earthdata login there."
    )
login, account, password = _earthdata_entry
ED_USER = login
ED_PASS = password

# Google credentials JSON, resolved under the user's home directory rather
# than a hard-coded /home/jovyan path.
# NOTE(review): an application-default-credentials file may hold user
# credentials rather than a service-account key, despite the SA_JSON name.
GCP_CREDENTIALS_PATH = os.path.expanduser(
    "~/.config/gcloud/application_default_credentials.json"
)
with open(GCP_CREDENTIALS_PATH) as f:
    GCP_SA_JSON = f.read()


# --------------------------------------------------------------------------------------
# Helper: load ML bundle and build CHLA profile dataset
# --------------------------------------------------------------------------------------

# Ensure ml_utils is available (fetched from the project repository on first run).
# NOTE(review): this tracks the `main` branch, so the module can change
# between runs; pin a commit hash in the URL for reproducible results.
ML_UTILS_URL = (
    "https://raw.githubusercontent.com/fish-pace/chla-z-modeling/main/ml_utils.py"
)
if not os.path.exists("ml_utils.py"):
    import urllib.request

    # Download to a temporary name and rename only on success. A failed or
    # interrupted download then never leaves a truncated ml_utils.py behind,
    # which the existence check above would otherwise keep forever.
    _partial_path = "ml_utils.py.part"
    try:
        urllib.request.urlretrieve(ML_UTILS_URL, _partial_path)
        os.replace(_partial_path, "ml_utils.py")
    finally:
        if os.path.exists(_partial_path):
            os.remove(_partial_path)

import ml_utils as mu  # noqa: E402

# The ML bundle is not loaded on the client: each worker task loads it from the
# zip uploaded next to ml_utils.py (see process_one_granule).

#######################
# - Helper
#######################
def ensure_google_cloud_storage():
    """
    Install google-cloud-storage into the running interpreter if it's missing.

    Intended for ``client.run`` on Dask workers. It does nothing when
    ``google.cloud.storage`` already imports, so repeated calls are cheap.
    """
    import importlib
    import subprocess
    import sys

    try:
        importlib.import_module("google.cloud.storage")
    except ImportError:
        # Use this interpreter's pip so the package lands in the environment
        # this process imports from.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "google-cloud-storage"]
        )
        # A package installed after start-up can be hidden by the import
        # system's cached directory listings until they are invalidated.
        importlib.invalidate_caches()


def build_chla_profile_dataset(
    CHLA: xr.DataArray, default_thickness: float = 10.0
) -> xr.Dataset:
    """
    Given CHLA(time, z, lat, lon), compute derived metrics and
    return an xr.Dataset suitable for writing to Zarr/NetCDF.

    Parameters
    ----------
    CHLA : xr.DataArray
        Chlorophyll-a profiles with a ``z`` dimension. When both the ``z_start``
        and ``z_end`` coordinates are present, ``z_end - z_start`` is used as
        the layer thickness.
    default_thickness : float, default 10.0
        Uniform layer thickness used when ``z_start``/``z_end`` are not both
        available.

    Returns
    -------
    xr.Dataset
        ``CHLA`` plus:

        - ``CHLA_int_0_200``: sum over z of CHLA * thickness.
        - ``CHLA_peak`` / ``CHLA_peak_depth``: the max over z and its z value.
        - ``CHLA_depth_center_of_mass``: sum(CHLA * z) / sum(CHLA).
        - ``z_thickness``.

        Derived fields are NaN wherever the column sum of CHLA is not > 0.
        When the result has a ``time`` dimension, ``z_thickness`` is broadcast
        to (time, z).
    """
    # Start from CHLA's own dataset so its coords (including z_start/z_end) win
    ds = CHLA.to_dataset(name="CHLA")

    # ---- Layer thickness (z dimension) ----
    z_start = CHLA.coords.get("z_start", None)
    z_end   = CHLA.coords.get("z_end", None)

    if (z_start is not None) and (z_end is not None):
        z_thick = (z_end - z_start).rename("z_thickness")   # (z)
    else:
        # fallback: uniform layer thickness. Force float so an integer z axis
        # does not silently make the thickness (and the integral) integer-typed.
        z_thick = xr.full_like(CHLA["z"], default_thickness, dtype=float).rename(
            "z_thickness"
        )

    z_center = CHLA["z"]

    # total CHLA in column (used for validity + center-of-mass);
    # sum skips NaN, so an all-NaN column sums to 0 and is marked invalid
    col_total = CHLA.sum("z")          # (time, lat, lon)
    valid = col_total > 0              # True where there is some CHLA

    # ---- Integrated CHLA (nominal 0–200 m; actual range = z extent) ----
    CHLA_int = (CHLA * z_thick).sum("z")
    CHLA_int = CHLA_int.where(valid)
    CHLA_int.name = "CHLA_int_0_200"

    # ---- Peak value and depth (NaN-safe) ----
    # NaN -> -inf so argmax never picks a missing bin; all-NaN columns give
    # index 0 but are masked by `valid` below.
    CHLA_filled = CHLA.fillna(-np.inf)
    peak_idx = CHLA_filled.argmax("z")       # (time, lat, lon) integer indices

    CHLA_peak = CHLA.isel(z=peak_idx).where(valid)
    CHLA_peak.name = "CHLA_peak"

    CHLA_peak_depth = z_center.isel(z=peak_idx).where(valid)
    CHLA_peak_depth.name = "CHLA_peak_depth"

    # ---- Depth-weighted mean depth (center of mass) ----
    num = (CHLA * z_center).sum("z")
    den = col_total
    depth_cm = (num / den).where(valid)
    depth_cm.name = "CHLA_depth_center_of_mass"

    # ---- Attach derived fields to the dataset ----
    ds["CHLA_int_0_200"] = CHLA_int
    ds["CHLA_peak"] = CHLA_peak
    ds["CHLA_peak_depth"] = CHLA_peak_depth
    ds["CHLA_depth_center_of_mass"] = depth_cm
    ds["z_thickness"] = z_thick

    # ---- Variable attributes ----
    ds["CHLA"].attrs.setdefault("units", "mg m-3")
    ds["CHLA"].attrs.setdefault("long_name", "Chlorophyll-a concentration")
    ds["CHLA"].attrs.setdefault("standard_name", "mass_concentration_of_chlorophyll_a_in_sea_water")
    ds["CHLA"].attrs.setdefault(
        "description",
        "BRT-derived chlorophyll-a profiles from PACE hyperspectral Rrs",
    )

    ds["CHLA_int_0_200"].attrs.update(
        units="mg m-2",
        long_name="Depth-integrated chlorophyll-a",
        description=(
            "Vertical integral of CHLA over the available depth bins "
            "(nominally 0–200 m; actual range defined by z_start/z_end)."
        ),
    )

    ds["CHLA_peak"].attrs.update(
        units="mg m-3",
        long_name="Peak chlorophyll-a concentration in the water column",
        standard_name="mass_concentration_of_chlorophyll_a_in_sea_water",
        description="Maximum CHLA value over depth at each (time, lat, lon).",
    )

    ds["CHLA_peak_depth"].attrs.update(
        units="m",
        long_name="Depth of peak chlorophyll-a",
        positive="down",
        description=(
            "Depth (bin center) where CHLA is maximal in the water column "
            "at each (time, lat, lon)."
        ),
    )

    ds["CHLA_depth_center_of_mass"].attrs.update(
        units="m",
        long_name="Chlorophyll-a depth center of mass",
        positive="down",
        description=(
            "Depth of the chlorophyll-a center of mass, computed as "
            "sum_z(CHLA * z) / sum_z(CHLA)."
        ),
    )

    ds["z_thickness"].attrs.update(
        units="m",
        long_name="Layer thickness",
        description=(
            "Thickness of each vertical bin used for depth integration. "
            "Derived from z_end - z_start when available; otherwise set to a "
            "uniform nominal thickness."
        ),
    )
    # Give z_thickness a time axis so every variable can be written per day
    # (time-region writes). Skip when there is no time dimension, or when it
    # already has one.
    if "time" in ds.dims and "time" not in ds["z_thickness"].dims:
        ds["z_thickness"] = ds["z_thickness"].expand_dims(time=ds["time"])

    return ds


# --------------------------------------------------------------------------------------
# Worker-side function: process ONE granule/day
# --------------------------------------------------------------------------------------

from functools import partial

def _chla_netcdf_encoding(ds_day, lat_chunk, lon_chunk):
    """
    Build the NetCDF encoding for the daily CHLA(z) dataset.

    Every variable is float32 with zlib level 4. Chunks hold one time step and
    the whole z column, tiled over lat/lon.
    """
    nz = ds_day.sizes["z"]
    # h5py rejects chunk shapes larger than a fixed-size dimension, which can
    # happen on spatially subset (debug) days; clamp to the grid size.
    lat_c = min(lat_chunk, ds_day.sizes["lat"])
    lon_c = min(lon_chunk, ds_day.sizes["lon"])
    chunks_by_var = {
        "CHLA": (1, nz, lat_c, lon_c),
        "CHLA_int_0_200": (1, lat_c, lon_c),
        "CHLA_peak": (1, lat_c, lon_c),
        "CHLA_peak_depth": (1, lat_c, lon_c),
        "CHLA_depth_center_of_mass": (1, lat_c, lon_c),
        "z_thickness": (1, nz),
    }
    return {
        name: {"dtype": "float32", "zlib": True, "complevel": 4, "chunksizes": chunks}
        for name, chunks in chunks_by_var.items()
    }


def process_one_granule(
    res,
    lat_chunk=LAT_CHUNK,
    lon_chunk=LON_CHUNK,
    bucket_name=BUCKET_NAME,
    destination_prefix=DESTINATION_PREFIX,
    force_rerun=FORCE_RERUN,
    ed_username=ED_USER,
    ed_password=ED_PASS,
    gcp_sa_json=GCP_SA_JSON,
    bundle_filename=BUNDLE_FILENAME,
):
    """
    Predict, write and upload the daily CHLA(z) NetCDF for ONE Rrs granule.

    Runs as a Dask task. Imports and credential setup happen inside the
    function so it is self-contained on the worker.

    Parameters
    ----------
    res : earthaccess search result
        One granule from ``earthaccess.search_data``. Its
        ``umm.TemporalExtent.RangeDateTime.BeginningDateTime`` fixes the day.
    lat_chunk, lon_chunk : int
        Requested NetCDF chunk sizes along lat/lon (clamped to the grid size).
    bucket_name, destination_prefix : str
        GCS destination; the object is ``{prefix}/chla_z_{YYYYMMDD}_v2.nc``.
    force_rerun : bool
        If False and that object already exists, the day is skipped.
    ed_username, ed_password : str or None
        Earthdata credentials, exported as environment variables for
        ``earthaccess.login(strategy="environment")``.
    gcp_sa_json : str or None
        Google credentials JSON text. It is written to a per-task temp file
        that ``GOOGLE_APPLICATION_CREDENTIALS`` points at, and the file is
        removed before returning.
    bundle_filename : str
        File name of the ML bundle zip, expected next to ``ml_utils.py``.

    Returns
    -------
    str
        ``"[YYYYMMDD] SKIP ..."`` or ``"[YYYYMMDD] WROTE gs://..."``.

    Raises
    ------
    FileNotFoundError
        If the bundle zip is not next to ``ml_utils.py`` on the worker.
    RuntimeError
        If the Earthdata login on the worker fails.
    """
    import importlib
    import os
    import subprocess
    import sys
    import tempfile
    import uuid
    from pathlib import Path

    import earthaccess
    import pandas as pd
    import xarray as xr
    import ml_utils as mu  # shipped to workers via client.upload_file

    # --- ensure google-cloud-storage is available on THIS worker ---
    try:
        importlib.import_module("google.cloud.storage")
    except ImportError:
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "google-cloud-storage"]
        )
        # newly installed packages can be hidden by stale import-path caches
        importlib.invalidate_caches()

    from google.cloud import storage

    # --- locate and load the bundle uploaded next to ml_utils.py ---
    bundle_path = Path(mu.__file__).with_name(bundle_filename)
    if not bundle_path.exists():
        raise FileNotFoundError(f"Bundle not found at {bundle_path}")
    bundle = mu.load_ml_bundle(str(bundle_path))

    # --- Earthdata auth via environment variables (inside worker) ---
    if ed_username is not None and ed_password is not None:
        os.environ["EARTHDATA_USERNAME"] = ed_username
        os.environ["EARTHDATA_PASSWORD"] = ed_password
    auth = earthaccess.login(strategy="environment", persist=False)
    if auth is None or not auth.authenticated:
        raise RuntimeError("earthaccess login failed on worker")

    # --- GCP auth via a per-task credentials file ---
    # NOTE(review): os.environ is process-wide, so concurrent tasks in one
    # worker process overwrite each other's GOOGLE_APPLICATION_CREDENTIALS.
    # The storage client is created right after, but the race window is not
    # zero; passing explicit credentials to storage.Client would close it.
    cred_path = None
    if gcp_sa_json:
        cred_path = os.path.join(
            tempfile.gettempdir(), f"gcp_sa_worker_{uuid.uuid4().hex}.json"
        )
        with open(cred_path, "w") as f:
            f.write(gcp_sa_json)
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path

    # Everything after the credentials file exists runs under try/finally, so
    # the file is removed on every exit path, including the SKIP return.
    try:
        day_iso = res["umm"]["TemporalExtent"]["RangeDateTime"]["BeginningDateTime"]
        day = pd.to_datetime(day_iso)
        if day.tzinfo is not None:
            # "...Z" timestamps parse as tz-aware; compare as naive UTC.
            day = day.tz_convert(None)
        day_str = day.strftime("%Y%m%d")

        storage_client = storage.Client(project="noaa-gcs-public-data")
        bucket = storage_client.bucket(bucket_name)
        blob_path = f"{destination_prefix}/chla_z_{day_str}_v2.nc"
        blob = bucket.blob(blob_path)

        if blob.exists() and not force_rerun:
            msg = f"[{day_str}] SKIP (exists at gs://{bucket_name}/{blob_path})"
            print(msg)
            return msg

        files = earthaccess.open([res], pqdm_kwargs={"disable": True})
        rrs_ds = None
        local_path = None
        try:
            rrs_ds = xr.open_dataset(files[0])
            if "time" in rrs_ds.dims:
                # A scalar label selection already drops the time dimension.
                R = rrs_ds["Rrs"].sel(time=day)
            else:
                R = rrs_ds["Rrs"]
            R = R.transpose("lat", "lon", "wavelength")

            pred = bundle.predict(
                R,
                brt_models=bundle.model,
                feature_cols=bundle.meta["feature_cols"],
                consts={"solar_hour": 0, "type": 1},
                chunk_size_lat=100,
                time=day.to_datetime64(),
                z_name="z",
                silent=True,
                linear=True,
            )

            ds_day = build_chla_profile_dataset(pred)

            local_path = Path(tempfile.gettempdir()) / f"chla_z_{day_str}.nc"
            ds_day.to_netcdf(
                local_path,
                engine="h5netcdf",
                encoding=_chla_netcdf_encoding(ds_day, lat_chunk, lon_chunk),
            )
            blob.upload_from_filename(str(local_path))
        finally:
            # Clean up the local file even if writing or uploading failed.
            if local_path is not None:
                local_path.unlink(missing_ok=True)
            if rrs_ds is not None:
                rrs_ds.close()
            for fobj in files:
                fobj.close()

        gcs_url = f"gs://{bucket_name}/{blob_path}"
        msg = f"[{day_str}] WROTE {gcs_url}"
        print(msg)
        return msg

    finally:
        if cred_path is not None:
            try:
                os.remove(cred_path)
            except FileNotFoundError:
                pass

# --------------------------------------------------------------------------------------
# DRIVER: search granules, filter, and dispatch via Dask-Gateway
# --------------------------------------------------------------------------------------

def main(worker_resource="4CPU, 30.2Gi"):
    """
    Search PACE L3M daily Rrs granules and process each one on Dask-Gateway.

    ``START_DATE``/``END_DATE`` are read from module globals at call time, so a
    later cell can rebind them before calling ``main()``. One
    ``process_one_granule`` task is submitted per granule. Results are printed
    as tasks finish, and a failing task is counted and reported without
    stopping the others.

    Parameters
    ----------
    worker_resource : str, default "4CPU, 30.2Gi"
        Value for the gateway's ``worker_resource_allocation`` option.

    Raises
    ------
    RuntimeError
        If the client-side Earthdata login fails.
    """
    from dask.distributed import as_completed

    # 1. Earthaccess login on client
    auth = earthaccess.login(strategy="netrc", persist=True)
    if not auth.authenticated:
        raise RuntimeError("earthaccess login failed")

    # 2. Search PACE L3M Rrs daily granules
    rrs_results = earthaccess.search_data(
        short_name="PACE_OCI_L3M_RRS",
        granule_name="*.DAY.*.4km.nc",
        temporal=(START_DATE, END_DATE),
    )

    print(f"Found {len(rrs_results)} DAY granules after date filter.")
    if not rrs_results:
        print("Nothing to do.")
        return

    # 3. Dask-Gateway cluster. Everything after the Gateway exists runs under
    #    try/finally, so the cluster is shut down and the gateway connection
    #    closed even if setup (uploads, client.run, map) fails. Closing the
    #    Gateway explicitly avoids leaving it to Gateway.__del__.
    gateway = Gateway()
    cluster = None
    client = None
    futures = None
    done = 0
    errors = 0
    try:
        options = gateway.cluster_options()
        setattr(options, "worker_resource_allocation", worker_resource)

        cluster = gateway.new_cluster(options)
        cluster.adapt(minimum=MIN_WORKERS, maximum=MAX_WORKERS)

        client = cluster.get_client()
        print(cluster)
        print(client)

        # Dashboard link (copy/paste into a browser tab)
        print("Dask dashboard:", client.dashboard_link)

        # Make sure workers have needed files
        client.upload_file("ml_utils.py")
        client.upload_file(BUNDLE_PATH)

        # Pre-install google-cloud-storage on workers that are already
        # connected; process_one_granule repeats the check for any others.
        client.run(ensure_google_cloud_storage)

        # 4. Dispatch one task per granule and stream results as they complete
        futures = client.map(process_one_granule, rrs_results)
        n = len(futures)
        for fut in as_completed(futures):
            done += 1
            try:
                msg = fut.result()
                print(f"[{done}/{n}] {msg}")
            except Exception as e:
                errors += 1
                print(f"[{done}/{n}] ERROR: {repr(e)}")
                # To stop on the first error, re-raise here.
    finally:
        if futures is not None:
            print(f"Finished. Success={done - errors}, Errors={errors}")
        if client is not None:
            client.close()
        if cluster is not None:
            cluster.close()
        gateway.close()


In [None]:
# Full run over every DAY granule (the previous full run took ~10 hours for 560 files).
if __name__ == "__main__":
    main()

Found 560 DAY granules after date filter.
GatewayCluster<prod.c0c526860a59442a82b1cb067f3c5a4e, status=running>
<Client: 'tls://192.168.36.66:8786' processes=0 threads=0, memory=0 B>
Dask dashboard: /services/dask-gateway/clusters/prod.c0c526860a59442a82b1cb067f3c5a4e/status


In [1]:
import os

import gcsfs
import numpy as np
import pandas as pd

# Compare the original daily files (chla_z_YYYYMMDD.nc) with the _v2 run
# (chla_z_YYYYMMDD_v2.nc) and list days that have a v1 file but no v2 file.
token = os.path.expanduser("~/.config/gcloud/application_default_credentials.json")
fs = gcsfs.GCSFileSystem(token=token)
# all .nc
base = "nmfs_odp_nwfsc/CB/fish-pace-datasets/chla-z/netcdf"
all_nc = sorted(fs.glob(f"{base}/chla_z_*.nc"))
# v1 = no "_v2" anywhere in the name; v2 = everything else
paths_v1 = ["gcs://" + p for p in all_nc if "_v2" not in p]
paths_v2 = ["gcs://" + p for p in all_nc if "_v2" in p]
# guard the preview so an empty listing prints None instead of raising IndexError
print("files:", len(paths_v1), "first:", paths_v1[0] if paths_v1 else None)
print("files:", len(paths_v2), "first:", paths_v2[0] if paths_v2 else None)

day_strs_v1 = [p.split("chla_z_")[1].split(".nc")[0] for p in paths_v1]
all_times_v1 = np.array(
    pd.to_datetime(sorted(set(day_strs_v1)), format="%Y%m%d"),
    dtype="datetime64[ns]",
)
day_strs_v2 = [p.split("chla_z_")[1].split("_v2")[0] for p in paths_v2]
all_times_v2 = np.array(
    pd.to_datetime(sorted(set(day_strs_v2)), format="%Y%m%d"),
    dtype="datetime64[ns]",
)
# set lookup instead of scanning all_times_v2 for every v1 day
v2_days = set(all_times_v2)
missing = [t for t in all_times_v1 if t not in v2_days]
missing

files: 560 first: gcs://nmfs_odp_nwfsc/CB/fish-pace-datasets/chla-z/netcdf/chla_z_20240305.nc
files: 559 first: gcs://nmfs_odp_nwfsc/CB/fish-pace-datasets/chla-z/netcdf/chla_z_20240305_v2.nc


[np.datetime64('2025-02-06T00:00:00.000000000')]

In [None]:
# Re-run only the day(s) reported as missing from the _v2 output.
# NOTE: this rebinds the module-level START_DATE/END_DATE that main() reads,
# so any later main() call in this kernel uses the same one-day window.
START_DATE = END_DATE = '20250206'
if __name__ == "__main__":
    main()

Found 1 DAY granules after date filter.
GatewayCluster<prod.6cf49325472e4ac4a66bd3af73c1a4bc, status=running>
<Client: 'tls://192.168.53.134:8786' processes=0 threads=0, memory=0 B>
Dask dashboard: /services/dask-gateway/clusters/prod.6cf49325472e4ac4a66bd3af73c1a4bc/status


# Process Zarr

Convert the daily NetCDF files into a single Zarr v3 store. (In hindsight, this step should have been built into the first pipeline.)

In [13]:
%%writefile chla_zarr_worker.py
import os, uuid, tempfile
from pathlib import Path

import numpy as np
import xarray as xr
import gcsfs
import zarr

def normalize_time(ds: xr.Dataset) -> xr.Dataset:
    """
    Return ``ds`` with its ``time`` coordinate cast to datetime64[ns].

    A dataset without a ``time`` coordinate is returned as-is (same object).
    """
    if "time" not in ds.coords:
        return ds
    time_ns = ds["time"].astype("datetime64[ns]")
    return ds.assign_coords(time=time_ns)

def make_time_only_for_region(ds: xr.Dataset) -> xr.Dataset:
    """
    Reduce ``ds`` to what a ``region={"time": ...}`` Zarr write accepts.

    The steps are:

    - keep only data variables that have a ``time`` dimension;
    - drop every coordinate except ``time``;
    - cast ``time`` to datetime64[ns].

    Raises
    ------
    ValueError
        If no data variable depends on ``time``.
    """
    names = [name for name in ds.data_vars if "time" in ds[name].dims]
    if len(names) == 0:
        raise ValueError("No time-dependent variables found for region write.")

    subset = ds[names]
    extra_coords = [coord for coord in subset.coords if coord != "time"]
    if extra_coords:
        subset = subset.drop_vars(extra_coords)

    time_ns = subset["time"].astype("datetime64[ns]")
    return subset.assign_coords(time=time_ns)

def chunk_spec(ds: xr.Dataset, lat_chunk: int, lon_chunk: int) -> dict:
    """
    Return Dask chunk sizes for ``ds``.

    Uses one step per chunk along ``time``, the full ``z`` extent in one
    chunk, and ``lat_chunk``/``lon_chunk`` along lat/lon. Dimensions that
    ``ds`` lacks are omitted.
    """
    fixed = {"time": 1, "lat": lat_chunk, "lon": lon_chunk}
    spec = {}
    for dim in ("time", "z", "lat", "lon"):
        if dim not in ds.dims:
            continue
        spec[dim] = ds.sizes["z"] if dim == "z" else fixed[dim]
    return spec

def assert_dimension_names_present(mapper, ds: xr.Dataset) -> None:
    """
    Check that the Zarr v3 store has ``dimension_names`` for ``time`` and for
    every data variable in ``ds``.

    Raises
    ------
    KeyError
        Names the first array that is absent or has no dimension names.
        ``time`` is checked first, then ``ds.data_vars`` in order.
    """
    group = zarr.open_group(mapper, mode="r")

    def _has_dimension_names(array_name) -> bool:
        if array_name not in group:
            return False
        return bool(getattr(group[array_name].metadata, "dimension_names", None))

    for array_name in ("time", *ds.data_vars):
        if not _has_dimension_names(array_name):
            raise KeyError(
                f"Store missing v3 dimension_names for {array_name!r} "
                "(recreate store at fresh path)."
            )

def write_one_day_region(gcs_url: str, time_index: int, zarr_path: str, gcp_sa_json: str,
                         lat_chunk: int, lon_chunk: int) -> str:
    """
    Write one daily NetCDF into slot ``time_index`` of an existing Zarr v3 store.

    Intended to run as a Dask task. The file is downloaded to a local temp
    path, and only its time-dependent variables are written, with
    ``region={"time": slice(time_index, time_index + 1)}`` in mode ``r+``.
    The store at ``zarr_path`` must therefore already exist.

    Parameters
    ----------
    gcs_url : str
        ``gcs://`` URL of one daily NetCDF.
    time_index : int
        Position of that day on the store's time axis.
    zarr_path : str
        ``gcs://`` URL of the Zarr v3 store.
    gcp_sa_json : str
        Google credentials JSON text. It is written to a per-call temp file
        that ``GOOGLE_APPLICATION_CREDENTIALS`` points at, then removed.
    lat_chunk, lon_chunk : int
        Dask chunk sizes along lat/lon; keep them equal to the store's chunks
        so concurrent writers never touch the same chunk.

    Returns
    -------
    str
        ``"OK <file name> -> time_index=<i>"``.

    Raises
    ------
    KeyError
        If the store lacks v3 ``dimension_names`` (assert_dimension_names_present).
    ValueError
        If the file has no time-dependent variables (make_time_only_for_region).
    """
    # auth on worker
    # NOTE(review): os.environ is shared by all tasks in a worker process, so
    # concurrent tasks can overwrite each other's credential path.
    cred_path = os.path.join(tempfile.gettempdir(), f"gcp_sa_worker_{uuid.uuid4().hex}.json")
    with open(cred_path, "w") as f:
        f.write(gcp_sa_json)
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path

    try:
        fs = gcsfs.GCSFileSystem(token="google_default")
        mapper = fs.get_mapper(zarr_path)

        tmp_dir = Path("/tmp/chla_nc_to_zarr_workers")
        tmp_dir.mkdir(parents=True, exist_ok=True)
        local_nc = tmp_dir / Path(gcs_url).name

        try:
            fs.get(gcs_url, str(local_nc))
            # Close the Dataset returned by open_dataset itself. Datasets derived
            # from it (normalize_time, .chunk) may not carry its close hook.
            with xr.open_dataset(local_nc, engine="h5netcdf") as ds_raw:
                ds = normalize_time(ds_raw)
                ds_time = make_time_only_for_region(ds).chunk(chunk_spec(ds, lat_chunk, lon_chunk))

                assert_dimension_names_present(mapper, ds_time)

                region = {"time": slice(time_index, time_index + 1)}
                ds_time.to_zarr(mapper, mode="r+", region=region, consolidated=False, zarr_version=3)
        finally:
            # Also removes partial downloads when fs.get/open_dataset fails.
            local_nc.unlink(missing_ok=True)

        return f"OK {Path(gcs_url).name} -> time_index={time_index}"

    finally:
        try:
            os.remove(cred_path)
        except FileNotFoundError:
            pass


Writing chla_zarr_worker.py


In [3]:
from __future__ import annotations

import os
import re
import uuid
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr
import gcsfs
import zarr

from dask_gateway import Gateway
from dask.distributed import as_completed

import chla_zarr_worker as wz  # <-- your uploaded module


# =============================================================================
# CONFIG (edit as needed)
# =============================================================================
# Google credentials JSON, resolved under the user's home directory instead of
# a hard-coded /home/jovyan path.
TOKEN_PATH = os.path.expanduser("~/.config/gcloud/application_default_credentials.json")

BUCKET = "nmfs_odp_nwfsc"
NETCDF_PREFIX = "CB/fish-pace-datasets/chla-z/netcdf"
# Glob (no scheme) over the daily NetCDF files under NETCDF_PREFIX.
NETCDF_PATTERN = f"{BUCKET}/{NETCDF_PREFIX}/chla_z_*.nc"

# IMPORTANT: start with a fresh v3 path (the template step refuses a non-empty prefix)
ZARR_PATH = f"gcs://{BUCKET}/CB/fish-pace-datasets/chla-z/zarr_v12"

# Zarr/Dask chunk sizes along lat/lon (template and region writes must agree)
LAT_CHUNK = 128
LON_CHUNK = 128

MIN_WORKERS = 4
MAX_WORKERS = 12
WORKER_RESOURCE = "4CPU, 30.2Gi"

# =============================================================================
# LISTING + TIME AXIS
# =============================================================================
# Daily NetCDF names are chla_z_YYYYMMDD.nc (original run) or
# chla_z_YYYYMMDD_v2.nc (current pipeline output). Both live under the same
# prefix and both match NETCDF_PATTERN, so listing must select one version.
_date_re = re.compile(r"chla_z_(\d{8})(?:_v\d+)?\.nc$")


def date_from_url(gcs_url: str) -> pd.Timestamp:
    """
    Parse the day from a daily NetCDF URL named ``chla_z_YYYYMMDD[_vN].nc``.

    Raises
    ------
    ValueError
        If the name does not follow that pattern.
    """
    m = _date_re.search(gcs_url)
    if not m:
        raise ValueError(f"Could not parse date from: {gcs_url}")
    return pd.to_datetime(m.group(1), format="%Y%m%d")


def list_netcdf_urls(
    fs: gcsfs.GCSFileSystem,
    pattern: str | None = None,
    version_suffix: str = "_v2",
) -> list[str]:
    """
    List ``gcs://`` URLs of the daily NetCDFs for ONE product version, sorted by date.

    Parameters
    ----------
    fs : gcsfs.GCSFileSystem
        Filesystem used for globbing.
    pattern : str, optional
        Glob pattern without scheme; defaults to ``NETCDF_PATTERN``.
    version_suffix : str, default "_v2"
        Suffix before ``.nc`` that selects the version. ``"_v2"`` gives
        ``chla_z_YYYYMMDD_v2.nc`` and ``""`` gives ``chla_z_YYYYMMDD.nc``.
        Files of other versions are ignored, so each date appears once and
        the time axis built from the list has no duplicates.
    """
    if pattern is None:
        pattern = NETCDF_PATTERN
    name_re = re.compile(r"chla_z_\d{8}" + re.escape(version_suffix) + r"\.nc$")
    urls = ["gcs://" + p for p in fs.glob(pattern) if name_re.search(p)]
    return sorted(urls, key=date_from_url)


def build_time_index(urls: list[str]) -> pd.DatetimeIndex:
    """
    Return a datetime64[ns] index named ``time`` with one entry per URL.

    Entries are in URL order; each date is parsed by ``date_from_url``.
    """
    stamps = pd.to_datetime([date_from_url(url) for url in urls])
    return pd.DatetimeIndex(stamps.astype("datetime64[ns]"), name="time")


# =============================================================================
# AUTH HELPERS
# =============================================================================
def read_sa_json(token_path: str = TOKEN_PATH) -> str:
    """
    Return the Google credentials file at ``token_path`` as raw text.

    The text is not parsed here. Callers write it back to temporary files
    where it is needed.
    """
    with open(token_path) as handle:
        contents = handle.read()
    return contents


def make_gcsfs_with_sa_json(gcp_sa_json: str) -> tuple[gcsfs.GCSFileSystem, str]:
    """
    Create a GCS filesystem authenticated from credentials JSON text.

    The JSON is written to a uniquely named file in the temp directory, and
    ``GOOGLE_APPLICATION_CREDENTIALS`` is pointed at it before the filesystem
    is built with ``token="google_default"``.

    Returns
    -------
    (fs, cred_path)
        The caller is responsible for deleting ``cred_path``.
    """
    cred_file = Path(tempfile.gettempdir()) / f"gcp_sa_{uuid.uuid4().hex}.json"
    cred_file.write_text(gcp_sa_json)
    cred_path = str(cred_file)
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path
    return gcsfs.GCSFileSystem(token="google_default"), cred_path


# =============================================================================
# ZARR v3 TEMPLATE (dimension_names required)
# =============================================================================
def ensure_store_is_empty(fs: gcsfs.GCSFileSystem, zarr_path: str) -> None:
    """
    Refuse to build a template over existing data.

    A missing prefix counts as empty.

    Raises
    ------
    ValueError
        If ``zarr_path`` is not a ``gcs://`` URL.
    RuntimeError
        If anything already exists under the prefix.
    """
    scheme = "gcs://"
    if not zarr_path.startswith(scheme):
        raise ValueError("Expected zarr_path like gcs://bucket/prefix")

    bucket_key = zarr_path[len(scheme):]
    try:
        contents = fs.ls(bucket_key)
    except FileNotFoundError:
        return

    if contents:
        raise RuntimeError(
            f"Target ZARR_PATH is not empty:\n  {zarr_path}\n"
            "Pick a NEW ZARR_PATH or delete the existing prefix before running."
        )


def normalize_time(ds: xr.Dataset) -> xr.Dataset:
    """
    Cast the ``time`` coordinate of ``ds`` to datetime64[ns].

    If ``ds`` has no ``time`` coordinate, the same object is returned.
    """
    if "time" not in ds.coords:
        return ds
    return ds.assign_coords(time=ds["time"].astype("datetime64[ns]"))


def _string_attrs(attrs) -> dict:
    """Keep only string-valued attributes (always JSON-serialisable in zarr metadata)."""
    return {str(k): v for k, v in attrs.items() if isinstance(v, str)}


def create_zarr_v3_template_direct(
    gcp_sa_json: str,
    sample_url: str,
    times: pd.DatetimeIndex,
    zarr_path: str,
    lat_chunk: int,
    lon_chunk: int,
) -> None:
    """
    Create an empty Zarr v3 store laid out for per-day region writes.

    Variables, dimensions and dtypes are taken from the NetCDF at
    ``sample_url``, and the ``time`` axis has one entry per element of
    ``times``. The arrays are created as follows:

    - ``lat``/``lon``/``z``/``z_start``/``z_end`` coordinates are written with
      their values.
    - Data variables are created empty (NaN fill for floats) with chunks of
      ``time=1``, the full ``z`` extent, and ``lat_chunk`` x ``lon_chunk``.

    String-valued attributes are copied from the sample file.

    Raises
    ------
    ValueError
        If ``zarr_path`` is not a ``gcs://`` URL.
    RuntimeError
        If ``zarr_path`` already contains data.
    """
    fs, cred_path = make_gcsfs_with_sa_json(gcp_sa_json)
    try:
        ensure_store_is_empty(fs, zarr_path)

        tmp_dir = Path("/tmp/chla_zarr_template_v3")
        tmp_dir.mkdir(parents=True, exist_ok=True)
        local_nc = tmp_dir / Path(sample_url).name

        try:
            fs.get(sample_url, str(local_nc))
            # Close the Dataset returned by open_dataset itself; the one
            # returned by normalize_time may not carry its close hook.
            with xr.open_dataset(local_nc, engine="h5netcdf") as ds_raw:
                ds0 = normalize_time(ds_raw)
                nt = len(times)

                mapper = fs.get_mapper(zarr_path)
                root = zarr.group(store=mapper, overwrite=True, zarr_format=3)

                # time coordinate stored as int64 nanoseconds since the Unix epoch.
                # The CF units/calendar let xarray decode it as datetime64.
                # Region writes re-encode values with the store's existing
                # encoding; without units, xarray would infer fresh units per
                # write and no longer match these integers.
                time_int = times.values.astype("datetime64[ns]").astype("int64")
                root.create_array(
                    name="time",
                    shape=(nt,),
                    chunks=(min(nt, 1024),),
                    dtype="int64",
                    dimension_names=("time",),
                    attributes={
                        "units": "nanoseconds since 1970-01-01",
                        "calendar": "proleptic_gregorian",
                    },
                    overwrite=True,
                )
                root["time"][:] = time_int

                # coords: use data=..., and DO NOT pass dtype/shape with data
                for cname in ["lat", "lon", "z", "z_start", "z_end"]:
                    if cname in ds0.coords:
                        vals = np.asarray(ds0[cname].values)
                        root.create_array(
                            name=cname,
                            data=vals,
                            chunks=vals.shape,
                            dimension_names=tuple(ds0[cname].dims),
                            attributes=_string_attrs(ds0[cname].attrs),
                            overwrite=True,
                        )

                # data vars: create empty arrays for region writes
                for v in ds0.data_vars:
                    da0 = ds0[v]
                    dims = tuple(da0.dims)

                    shape: list[int] = []
                    chunks: list[int] = []
                    for d in dims:
                        if d == "time":
                            shape.append(nt); chunks.append(1)
                        elif d == "z":
                            shape.append(ds0.sizes["z"]); chunks.append(ds0.sizes["z"])
                        elif d == "lat":
                            shape.append(ds0.sizes["lat"]); chunks.append(lat_chunk)
                        elif d == "lon":
                            shape.append(ds0.sizes["lon"]); chunks.append(lon_chunk)
                        else:
                            shape.append(ds0.sizes[d]); chunks.append(ds0.sizes[d])

                    fill_value = float("nan") if np.issubdtype(da0.dtype, np.floating) else 0

                    root.create_array(
                        name=v,
                        shape=tuple(shape),
                        chunks=tuple(chunks),
                        dtype=da0.dtype,
                        fill_value=fill_value,
                        dimension_names=dims,
                        attributes=_string_attrs(da0.attrs),
                        overwrite=True,
                    )
        finally:
            # Also removes partial downloads when fs.get/open_dataset fails.
            local_nc.unlink(missing_ok=True)

    finally:
        try:
            os.remove(cred_path)
        except FileNotFoundError:
            pass


# =============================================================================
# MAIN
# =============================================================================
def zarr_main():
    """
    Build the Zarr v3 store at ZARR_PATH from the daily NetCDF files.

    1. On the client: list the NetCDF files, derive the time axis (file i maps
       to time index i), and create the empty template store.
    2. On a Dask-Gateway cluster: run one
       ``chla_zarr_worker.write_one_day_region`` task per file. Failures are
       reported per file and do not stop the other tasks.

    Raises
    ------
    RuntimeError
        If no NetCDF files are found, or if the listing contains duplicate dates.
    """
    gcp_sa_json = read_sa_json(TOKEN_PATH)

    # list urls on client
    fs, cred_path = make_gcsfs_with_sa_json(gcp_sa_json)
    try:
        urls = list_netcdf_urls(fs)
        if not urls:
            raise RuntimeError("No NetCDF files found with pattern: " + NETCDF_PATTERN)
        print("nfiles:", len(urls), "first:", urls[0])
        times = build_time_index(urls)
    finally:
        try:
            os.remove(cred_path)
        except FileNotFoundError:
            pass

    # Region writes address days by position, so a date listed twice (e.g. two
    # file versions of one day) would give the store a repeated time value.
    if not times.is_unique:
        dupes = times[times.duplicated()].strftime("%Y-%m-%d").tolist()
        raise RuntimeError(f"Duplicate dates in NetCDF listing: {dupes}")

    print(f"Creating Zarr v3 template at: {ZARR_PATH}")
    create_zarr_v3_template_direct(
        gcp_sa_json=gcp_sa_json,
        sample_url=urls[0],
        times=times,
        zarr_path=ZARR_PATH,
        lat_chunk=LAT_CHUNK,
        lon_chunk=LON_CHUNK,
    )
    print("Template created:", ZARR_PATH)

    # Dask-Gateway cluster. Everything after the Gateway exists runs under
    # try/finally, so the cluster and the gateway connection are released even
    # if setup fails. The Gateway is closed explicitly instead of being left
    # to Gateway.__del__.
    gateway = Gateway()
    cluster = None
    client = None
    try:
        options = gateway.cluster_options()
        setattr(options, "worker_resource_allocation", WORKER_RESOURCE)

        cluster = gateway.new_cluster(options)
        cluster.adapt(minimum=MIN_WORKERS, maximum=MAX_WORKERS)
        client = cluster.get_client()

        print(cluster)
        print(client)
        print("Dask dashboard:", client.dashboard_link)

        # ensure workers have your module
        client.upload_file("chla_zarr_worker.py")
        client.run(lambda: __import__("chla_zarr_worker").__name__)

        # broadcast the credentials JSON once instead of embedding it in every task
        gcp_sa_json_fut = client.scatter(gcp_sa_json, broadcast=True)

        futures = []
        fut_to_info = {}
        for idx, url in enumerate(urls):
            fut = client.submit(
                wz.write_one_day_region,
                url,
                idx,
                ZARR_PATH,
                gcp_sa_json_fut,
                LAT_CHUNK,
                LON_CHUNK,
                pure=False,
            )
            futures.append(fut)
            fut_to_info[fut.key] = (idx, url)

        n = len(futures)
        errors = 0
        for done, fut in enumerate(as_completed(futures), start=1):
            idx, url = fut_to_info.get(fut.key, (None, None))
            try:
                msg = fut.result()
                print(f"[{done}/{n}] {msg}")
            except Exception as e:
                errors += 1
                print(f"[{done}/{n}] ERROR on idx={idx} url={url}")
                print(f"Exception: {type(e).__name__}: {e}")

        print(f"Finished region writes. Success={n - errors}, Errors={errors}")

    finally:
        if client is not None:
            client.close()
        if cluster is not None:
            cluster.close()
        gateway.close()

    


Exception ignored in: <function Gateway.__del__ at 0x7f5af1c756c0>
Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.11/site-packages/dask_gateway/client.py", line 380, in __del__
    self.close()
  File "/srv/conda/envs/notebook/lib/python3.11/site-packages/dask_gateway/client.py", line 353, in close
    elif self.loop.asyncio_loop.is_running():
         ^^^^^^^^^
  File "/srv/conda/envs/notebook/lib/python3.11/site-packages/dask_gateway/client.py", line 330, in loop
    return self._loop_runner.loop
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/srv/conda/envs/notebook/lib/python3.11/site-packages/distributed/utils.py", line 648, in loop
    raise RuntimeError(
RuntimeError: Accessing the loop property while the loop is not running is not supported


In [None]:
# Build the Zarr v3 store: create the empty template at ZARR_PATH, then write
# each daily NetCDF into its time slot from tasks on a Dask-Gateway cluster.
zarr_main()