# Comparison using kerchunk / datatree

In [None]:
import coiled
import dask
import pandas as pd
import xarray as xr
from dask.distributed import wait
from datatree import DataTree
from utils import generate_WBGT, load_elev

In [None]:
# Start a Coiled-managed Dask cluster for this comparison run.
cluster = coiled.Cluster(
    n_workers=10,  # ask for 10 workers
    # NOTE(review): presumably "use spot instances, fall back to on-demand
    # when none are available" -- confirm against the Coiled docs.
    spot_policy="spot_with_fallback",
    # NOTE(review): presumably requests ARM-based instances -- confirm.
    arm=True,
)
# Dask client bound to the cluster; the cells below use it to compute
# delayed objects and to shut everything down at the end.
client = cluster.get_client()

In [None]:
# Bare expression: as the last line of a notebook cell this renders the
# client's repr in the cell output.
client

In [None]:
# Load the reference catalog CSV from S3 into a pandas DataFrame.
cat_df = pd.read_csv(
    "s3://carbonplan-share/nasa-nex-reference/reference_catalog_nested.csv"
)

# Keep the rows whose ID mentions either "ssp245" or "historical" (the pattern
# is a regex alternation), then take the first 20 rows so the timing
# comparison stays small.
ssp245_historical_catalog = cat_df.loc[
    cat_df["ID"].str.contains("ssp245|historical")
].head(20)

# Build an {ID: value} dictionary from the first non-ID column of the subset.
catalog = ssp245_historical_catalog.set_index("ID").iloc[:, 0].to_dict()

In [None]:
@dask.delayed
def load_ref_ds(gcm_scenario: str, url: str):
    """Open one kerchunk reference with xarray and keep the WBGT input variables.

    Args:
        gcm_scenario: Label for this dataset; stored in the dataset attrs and
            used as the key of the returned mapping.
        url: Location of the kerchunk reference handed to ``xr.open_dataset``.

    Returns:
        ``{gcm_scenario: ds}`` where ``ds`` contains only ``huss``, ``tasmax``
        and ``tas``, or ``None`` when any of those variables is missing from
        the opened dataset.
    """
    required_vars = ["huss", "tasmax", "tas"]

    opened = xr.open_dataset(
        url,
        engine="kerchunk",
        # options passed to fsspec
        storage_options={
            "remote_protocol": "s3",
            "target_protocol": "s3",
            "target_options": {"anon": True},
            "lazy": True,
            "skip_instance_cache": True,
        },
        # options passed to xarray
        open_dataset_options={"chunks": {}},
    )

    # Datasets lacking any required variable are dropped by the caller,
    # which filters out None results.
    if not set(required_vars) <= set(opened.data_vars):
        return None

    subset = opened[required_vars]
    # adding the gcm/scenario combo to attrs for later down the pipeline
    subset.attrs["gcm_scenario"] = gcm_scenario
    return {gcm_scenario: subset}

In [None]:
# Create one delayed load per (gcm_scenario, url) catalog entry and compute
# them all through the client; sync=True makes the call return the results.
catalog_tuple_list = list(catalog.items())
jobs = [load_ref_ds(gcm_scenario, url) for gcm_scenario, url in catalog_tuple_list]
catalog_dict_list = client.compute(jobs, sync=True)

# Merge the per-dataset {gcm_scenario: ds} mappings into a single dictionary,
# skipping the None results returned for datasets missing a required variable.
catalog_dict = dict(
    entry
    for loaded in catalog_dict_list
    if loaded is not None
    for entry in loaded.items()
)

In [None]:
# Create datatree object from computed dictionary: catalog_dict maps each
# gcm/scenario label to its dataset, and the whole mapping is handed to
# DataTree.from_dict.
dt = DataTree.from_dict(catalog_dict)

In [None]:
@dask.delayed
def calc_wbgt(ds):
    """Run ``generate_WBGT`` on the first 365 time steps of ``ds``.

    Args:
        ds: Dataset with a ``time`` dimension and a ``gcm_scenario`` entry in
            its attrs; that entry names the output zarr store.

    Returns:
        Whatever ``generate_WBGT`` returns for this dataset.

    Note:
        ``elev`` is not a parameter; it is looked up as a module-level global
        when the task runs, so it must be defined before computing.
    """
    subset = ds.isel(time=slice(0, 365))
    gcm_scenario = subset.attrs["gcm_scenario"]
    output_fpath = (
        "s3://carbonplan-scratch/TEMP_NASA_NEX/wbgt-shade-gridded/years/"
        f"{gcm_scenario}.zarr"
    )
    return generate_WBGT(ds=subset, output_fpath=output_fpath, elev=elev)


# Module-level global that calc_wbgt reads and forwards to generate_WBGT.
elev = load_elev()
# One dataset per leaf node of the tree, skipping nodes whose dims are empty.
ds_list = [node.to_dataset() for node in dt.leaves if node.dims]

In [None]:
# compute: build one delayed WBGT task per dataset, submit them all through
# the client, and block until the resulting futures have finished.
# NOTE(review): `wait` only blocks; it does not appear to re-raise task
# errors, so failed tasks could pass silently here -- confirm, and consider
# checking the future statuses if failures must be surfaced.
delayed_list = [calc_wbgt(ds) for ds in ds_list]
wait(client.compute(delayed_list))

In [None]:
# Tear down once the comparison is finished.
# NOTE(review): Client.shutdown() is documented to stop the connected
# scheduler and workers too, not only this client -- confirm that is intended.
client.shutdown()