# Process daily Rrs granules into CHLA depth-profile files on a Dask Gateway cluster

In [1]:
# --- Core data handling and plotting libraries ---
import earthaccess
import xarray as xr       # for working with labeled multi-dimensional arrays
import numpy as np        # for numerical operations on arrays
import pandas as pd
# --- Custom python functions ---
import os, importlib
# Fetch the helper module from GitHub if it is not in the working directory yet.
# Plain Python instead of `!wget -q`: a failed download raises here instead of
# surfacing later as a confusing ImportError, and the download goes to a ".part"
# file that is only renamed into place once it is complete.
if not os.path.exists("ml_utils.py"):
    import urllib.request

    urllib.request.urlretrieve(
        "https://raw.githubusercontent.com/fish-pace/2025-tutorials/main/ml_utils.py",
        "ml_utils.py.part",
    )
    os.replace("ml_utils.py.part", "ml_utils.py")

import ml_utils as mu
importlib.reload(mu)

<module 'ml_utils' from '/home/jovyan/chla-z-modeling/ml_utils.py'>

In [2]:
import os
import json

# In reality: read from a secret, NOT hardcoded
with open("/home/jovyan/.config/gcloud/application_default_credentials.json") as f:
    GCP_SA_JSON = f.read()

In [5]:
from functools import partial

def process_one_granule(
    res,
    lat_chunk=LAT_CHUNK,
    lon_chunk=LON_CHUNK,
    bucket_name=BUCKET_NAME,
    destination_prefix=DESTINATION_PREFIX,
    force_rerun=FORCE_RERUN,
    ed_username=None,
    ed_password=None,
    gcp_sa_json=None,
):
    """Predict CHLA depth profiles for one daily Rrs granule and upload them to GCS.

    Meant to run as a Dask task on a remote worker: library imports happen inside
    the function and credentials are passed in as arguments. The default argument
    values come from notebook constants captured when this cell runs, and the body
    reads the notebook globals ``bundle`` and ``build_chla_profile_dataset``
    (presumably serialized along with each task — worth checking task size).

    Parameters
    ----------
    res : earthaccess granule
        Search result; ``res["umm"]["TemporalExtent"]["RangeDateTime"]
        ["BeginningDateTime"]`` gives the day being processed.
    lat_chunk, lon_chunk : int
        Lat/lon chunk sizes of the ``CHLA`` variable in the output netCDF.
    bucket_name : str
        Destination GCS bucket.
    destination_prefix : str
        Object prefix; output is written to
        ``gs://{bucket_name}/{destination_prefix}/chla_z_YYYYMMDD.nc``.
    force_rerun : bool
        When false, days whose output object already exists are skipped.
    ed_username, ed_password : str, optional
        Earthdata Login credentials. When both are given they are exported as
        ``EARTHDATA_USERNAME`` / ``EARTHDATA_PASSWORD`` for
        ``earthaccess.login(strategy="environment")``.
    gcp_sa_json : str, optional
        Text of a Google credentials JSON file (service account or gcloud
        application-default credentials). When None, the storage client's
        default credential lookup is used.

    Returns
    -------
    str
        ``"[YYYYMMDD] SKIP (exists at gs://...)"`` or ``"[YYYYMMDD] WROTE gs://..."``.

    Raises
    ------
    RuntimeError
        If earthaccess returns no file handle for the granule.
    """
    import os
    import tempfile
    from pathlib import Path

    import earthaccess
    import google.auth  # installed with google-cloud-storage
    import pandas as pd
    import xarray as xr
    from google.cloud import storage

    # --- Earthdata auth via env vars (inside worker) ---
    if ed_username is not None and ed_password is not None:
        os.environ["EARTHDATA_USERNAME"] = ed_username
        os.environ["EARTHDATA_PASSWORD"] = ed_password
    auth = earthaccess.login(strategy="environment", persist=False)

    # --- GCP auth from JSON text (inside worker) ---
    # load_credentials_from_file handles both service-account and authorized-user
    # JSON, so write the text to a private temp file with a *unique* name (several
    # tasks can run at once in one worker process), load it, and delete it right
    # away. Passing the credentials object to the client avoids mutating the
    # process-wide GOOGLE_APPLICATION_CREDENTIALS, which concurrent tasks would race on.
    credentials = None
    if gcp_sa_json is not None:
        fd, cred_path = tempfile.mkstemp(suffix=".json")
        try:
            with os.fdopen(fd, "w") as f:
                f.write(gcp_sa_json)
            credentials, _ = google.auth.load_credentials_from_file(cred_path)
        finally:
            os.remove(cred_path)

    # --- Which day is this granule? ---
    day = pd.to_datetime(res["umm"]["TemporalExtent"]["RangeDateTime"]["BeginningDateTime"])
    if day.tzinfo is not None:
        # UMM times usually carry a UTC marker ("...Z"); datetime64 coordinates in
        # xarray are tz-naive, so work with the equivalent naive UTC timestamp.
        day = day.tz_convert(None)
    day_str = day.strftime("%Y%m%d")

    storage_client = storage.Client(project="noaa-gcs-public-data", credentials=credentials)
    bucket = storage_client.bucket(bucket_name)
    blob_path = f"{destination_prefix}/chla_z_{day_str}.nc"
    blob = bucket.blob(blob_path)
    gcs_url = f"gs://{bucket_name}/{blob_path}"

    # Check force_rerun first so a forced run does not pay for an exists() request.
    if not force_rerun and blob.exists():
        msg = f"[{day_str}] SKIP (exists at {gcs_url})"
        print(msg)
        return msg

    files = earthaccess.open([res], auth=auth, pqdm_kwargs={"disable": True})
    if not files:
        raise RuntimeError(f"[{day_str}] earthaccess.open returned no file for this granule")

    local_path = Path(tempfile.gettempdir()) / f"chla_z_{day_str}.nc"
    rrs_ds = None
    try:
        rrs_ds = xr.open_dataset(files[0])

        if "time" in rrs_ds.dims:
            # Selecting a single timestamp already drops the "time" dimension;
            # a following .squeeze("time") would raise KeyError.
            R = rrs_ds["Rrs"].sel(time=day)
        else:
            R = rrs_ds["Rrs"]
        R = R.transpose("lat", "lon", "wavelength")

        pred = bundle.predict(
            R,
            brt_models=bundle.model,
            feature_cols=bundle.meta["feature_cols"],
            consts={"solar_hour": 0, "type": 1},
            chunk_size_lat=100,
            time=day.to_datetime64(),
            z_name="z",
            silent=True,
        )

        ds_day = build_chla_profile_dataset(pred)

        # One output chunk = one time step x all depths x a lat/lon tile.
        encoding = {
            "CHLA": {
                "dtype": "float32",
                "zlib": True,
                "complevel": 4,
                "chunksizes": (1, ds_day.sizes["z"], lat_chunk, lon_chunk),
            }
        }

        ds_day.to_netcdf(local_path, engine="h5netcdf", encoding=encoding)
        blob.upload_from_filename(str(local_path))

        msg = f"[{day_str}] WROTE {gcs_url}"
        print(msg)
        return msg

    finally:
        # Release the remote file handles and remove the local netCDF even when
        # prediction or the upload fails; otherwise failed days leave their output
        # files behind in the worker's temp directory.
        if rrs_ds is not None:
            rrs_ds.close()
        for handle in files:
            close = getattr(handle, "close", None)
            if callable(close):
                close()
        local_path.unlink(missing_ok=True)



In [4]:
# Load the trained BRT CHLA-profile bundle: models, metadata and the stored training data
bundle = mu.load_ml_bundle("models/brt_chla_profiles_bundle.zip")

# Models (one per depth bin) and the column metadata they were trained with
brt_models, meta = bundle.model, bundle.meta
rrs_cols, chl_cols, extra = meta["rrs_cols"], meta["y_col"], meta["extra_cols"]

# Train/test split saved alongside the model
(
    dataset,
    train_idx,
    test_idx,
    X_train,
    X_test,
    y_train_all,
    y_test_all,
) = (
    bundle.data[key]
    for key in ("dataset", "train_idx", "test_idx", "X_train", "X_test", "y_train", "y_test")
)


Loaded ML bundle from: models/brt_chla_profiles_bundle.zip
  model_kind : pickle
  model_type : collection (dict), n_submodels=20
  example key: CHLA_0_10
  target     : log10_CHLA_A_B depth bins
  features   : 174 columns
  train/test : 4408 / 1102 rows
  dataset    : 5510 rows stored in bundle

Usage example (Python):
  bundle = load_ml_bundle('path/to/bundle.zip')
  # Predict using helper 'predict_all_depths_for_day'
  # Example: predict all depths for one day from a BRF dataset R
  pred = bundle.predict(
      R_dataset,                  # xr.DataArray/xr.Dataset with lat/lon + predictors
      brt_models=bundle.model,    # dict of models by depth bin
      feature_cols=bundle.meta['feature_cols'],
      consts={'solar_hour': 12.0, 'type': 1},
  )  # -> e.g. CHLA(time?, z, lat, lon)

  # Plot using helper 'make_plot_pred_map'
  fig, ax = bundle.plot(pred_da, pred_label='Prediction')



# Set up Dask Gateway

In [None]:
from dask.distributed import Client
from dask_gateway import Gateway

# Connect to the hub's Dask Gateway and start a cluster with the default worker options.
gateway = Gateway()
options = gateway.cluster_options()

# TODO(sizing): one day needs ~5 GB of RAM. A worker presumably runs one task per
# core/thread at a time, so worker memory should be roughly 5 GB x worker_cores
# plus headroom, e.g.:
#     options.worker_cores = 4
#     options.worker_memory = "32GiB"
cluster = gateway.new_cluster(options)

# ~560 days at ~30 min each: let the cluster grow and shrink with the queue between
# 4 and 16 workers instead of a fixed size (fixed alternative: cluster.scale(8)).
cluster.adapt(minimum=4, maximum=16)

client = cluster.get_client()
print(cluster)
print(client)


In [None]:
from dask.distributed import wait

# Test batch: the first granules of `rrs_results` (from an earlier earthaccess.search_data call).
# Subset by date, or take all DAY granules, for the real run.
granules = rrs_results[:10]   # or rrs_results[:100] for testing

# One Dask task per granule.
# NOTE(review): no credentials are passed here, so each worker must already be able
# to log in to Earthdata and GCS from its own environment.
futures = client.map(process_one_granule, granules)

# Wait for every task, then record each outcome. client.gather() would raise on the
# first failed day and hide the status of all the others.
wait(futures)
results = []
for fut in futures:
    if fut.status == "finished":
        results.append(fut.result())
    elif fut.status == "error":
        results.append(f"FAILED: {fut.exception()!r}")
    else:
        results.append(f"FAILED: task {fut.status}")

print("Uploaded daily files:")
for r in results:
    print("  ", r)

# The client and cluster stay open on purpose: later cells submit more work to
# `client`. Shut them down in the last cell with `client.close(); cluster.close()`.

In [None]:
from functools import partial

import os
import netrc
import json

from dask.distributed import wait

# Earthdata Login credentials from ~/.netrc, so they never appear in the notebook.
netrc_path = os.path.expanduser("~/.netrc")
auth = netrc.netrc(netrc_path)
ed_entry = auth.authenticators("urs.earthdata.nasa.gov")
if ed_entry is None:
    # authenticators() returns None when the host is missing; fail clearly instead
    # of with a TypeError from unpacking None.
    raise RuntimeError(f"No 'urs.earthdata.nasa.gov' entry found in {netrc_path}")
login, account, password = ed_entry
ED_USER = login
ED_PASS = password

# Google credentials JSON (relative to the user's home instead of a hardcoded /home/jovyan).
with open(os.path.expanduser("~/.config/gcloud/application_default_credentials.json")) as f:
    GCP_SA_JSON = f.read()

# Bind the credentials so each task can authenticate on its worker.
# Note: the credential values are shipped to the workers inside every task.
fn = partial(
    process_one_granule,
    ed_username=ED_USER,
    ed_password=ED_PASS,
    gcp_sa_json=GCP_SA_JSON,
)

futures = client.map(fn, rrs_results_subset)

# Record every day's outcome; client.gather() would stop at the first failure.
wait(futures)
results = []
for fut in futures:
    if fut.status == "finished":
        results.append(fut.result())
    elif fut.status == "error":
        results.append(f"FAILED: {fut.exception()!r}")
    else:
        results.append(f"FAILED: task {fut.status}")

In [11]:
import os
import netrc

# Earthdata Login credentials from ~/.netrc (host urs.earthdata.nasa.gov).
netrc_path = os.path.expanduser("~/.netrc")
auth = netrc.netrc(netrc_path)
ed_entry = auth.authenticators("urs.earthdata.nasa.gov")
if ed_entry is None:
    # authenticators() returns None when the host is missing; fail clearly instead
    # of with a TypeError from unpacking None.
    raise RuntimeError(f"No 'urs.earthdata.nasa.gov' entry found in {netrc_path}")
login, account, password = ed_entry


In [13]:
# Never display the password itself: cell outputs are saved with the notebook.
# The value previously shown here was stored in plain text in this notebook's
# output and should be rotated on Earthdata Login.
# Only confirm that a non-empty password was found.
bool(password)