In [None]:
from datetime import datetime

import netCDF4  # noqa

from pyrte_rrtmgp.external_data_helpers import download_dyamond2_data

# Fetch the DYAMOND2 input files for the 2020-02-01 09:00 timestamp into the
# local "GEOS-DYAMOND2-data" directory and keep whatever the helper returns.
# NOTE(review): compute_gas_optics=False presumably skips a gas-optics step
# inside the helper — confirm against pyrte_rrtmgp.
downloaded_files = download_dyamond2_data(
    datetime(year=2020, month=2, day=1, hour=9),
    data_dir="GEOS-DYAMOND2-data",
    compute_gas_optics=False,
)

In [None]:
import numpy as np
import xarray as xr
from dask.distributed import Client

In [None]:
# Create a Dask client with a fixed sizing of 7 workers x 7 threads each.
# These numbers are hard-coded (presumably tuned for the original author's
# machine) and need adjusting on smaller hosts.
# NOTE(review): dask documents `memory_limit` as a per-worker limit, so this
# may allow up to 7 x 64GB in total — confirm that matches the host.
client = Client(n_workers=7, threads_per_worker=7, memory_limit="64GB")
# Bare expression so the notebook renders the client summary.
client

In [None]:
# Open every file in the download directory that matches the glob as a single
# dataset, leaving out the variables listed below (judging by their names they
# are presumably cubed-sphere grid metadata not needed here — confirm).
data = xr.open_mfdataset(
    paths="GEOS-DYAMOND2-data/*inst_01hr_3d_*.nc4",
    drop_variables=[
        "anchor", "cubed_sphere", "orientation",
        "contacts", "corner_lats", "corner_lons",
    ],
)
# Bare expression so the notebook renders the dataset summary.
data

In [None]:
# Restrict the dataset to level indices 78..180 (the slice end is exclusive)
# and rechunk it. `lev=-1` puts all selected levels in a single chunk, which
# the per-block function needs because it sums over the whole level axis.
atmosphere = data.isel({"lev": slice(78, 181)}).chunk(
    dict(Xdim=2880, Ydim=18, nf=1, lev=-1)
)
# Bare expression so the notebook renders the rechunked dataset summary.
atmosphere

In [None]:
from functools import partial

from pyrte_rrtmgp import rrtmgp_cloud_optics
from pyrte_rrtmgp.data_types import CloudOpticsFiles

# Load the cloud-optics data from the LW_BND file once in the notebook
# process; it is later bound into the per-chunk function via `partial`.
_cloud_optics_lw = rrtmgp_cloud_optics.load_cloud_optics(
    cloud_optics_file=CloudOpticsFiles.LW_BND
)


def _process_chunk(ds, cloud_optics_lw):
    """Compute the aggregated cloud-optics ``tau`` for one block of data.

    Written to be used as the ``func`` of ``xr.map_blocks``: it derives the
    cloud inputs from the block, runs ``compute_cloud_optics`` on them and
    reduces the result over the ``bnd`` and ``layer`` dimensions.

    Parameters
    ----------
    ds : xarray.Dataset
        Block providing ``DELP``, ``QL``, ``QI``, ``RL`` and ``RI`` on a
        ``lev`` dimension. The input is not modified.
    cloud_optics_lw
        Object exposing ``compute_cloud_optics`` (here: the value returned
        by ``rrtmgp_cloud_optics.load_cloud_optics``).

    Returns
    -------
    The output of ``compute_cloud_optics`` summed over ``bnd`` and
    ``layer`` (NaNs skipped).
    """
    # Function-scope imports so they execute wherever the block is processed.
    # NOTE(review): the two `noqa` imports are not referenced by name; they
    # are presumably kept for import side effects on the workers (e.g.
    # registering xarray accessors) — confirm before removing.
    from pyrte_rrtmgp import rrtmgp_cloud_optics  # noqa
    from pyrte_rrtmgp.constants import HELMERT1
    from pyrte_rrtmgp.data_types import CloudOpticsFiles  # noqa

    # Build the derived variables on a new Dataset (`assign` returns a copy)
    # instead of writing them into the caller's `ds`.
    # NOTE(review): the 1000 and 1e6 factors look like unit conversions
    # (kg -> g and m -> micron) and HELMERT1 like gravitational
    # acceleration — confirm against the input data and pyrte_rrtmgp.
    cloud_props = ds.assign(
        lwp=(ds["DELP"] * ds["QL"]) * 1000 / HELMERT1,
        iwp=(ds["DELP"] * ds["QI"]) * 1000 / HELMERT1,
        rel=ds["RL"] * 1e6,
        rei=ds["RI"] * 1e6,
    )
    # The vertical dimension is renamed to "layer" before the call, and only
    # the four derived variables are passed on.
    cloud_props = cloud_props.rename({"lev": "layer"})
    cloud_props = cloud_props[["lwp", "iwp", "rel", "rei"]]

    # Original author's optimization note (refers to pyrte_rrtmgp internals
    # that are not visible here): removing the copy() on line 107 of
    # __call__ and avoiding the 5 repeated `unstack` calls would likely be
    # the best speed-up.
    tau_chunk_ds = cloud_optics_lw.compute_cloud_optics(
        cloud_props, problem_type="absorption", add_to_input=False
    )

    # Aggregate over 'bnd' and 'layer' dimensions
    tau_agg_chunk = tau_chunk_ds.sum(dim=["bnd", "layer"], skipna=True)

    # TODO (original author's idea): since something has to be returned,
    # return the smallest dataset possible and write multifile output from
    # inside this function to avoid data transfers.
    return tau_agg_chunk


# Bind the preloaded cloud-optics object so the resulting callable only needs
# the dataset block and can be passed directly as `func` to `xr.map_blocks`.
process_chunk = partial(_process_chunk, cloud_optics_lw=_cloud_optics_lw)

In [None]:
# Template describing what `process_chunk` returns, as required by
# `xr.map_blocks`: NaN-filled data with DELP's dims/coords minus `lev`.
dask_data = xr.full_like(atmosphere["DELP"].isel(lev=0, drop=True), np.nan)

# `full_like` keeps DELP's attrs, and wrapping a DataArray in `xr.DataArray`
# inherits them unless `attrs` is given. `map_blocks` copies template variable
# attrs onto the result, so set them explicitly here; otherwise `tau` would be
# written out labelled with DELP's metadata.
template_da = xr.DataArray(
    data=dask_data,
    dims=dask_data.dims,
    coords=dask_data.coords,
    attrs={
        "description": (
            "tau from compute_cloud_optics(problem_type='absorption'), "
            "summed over bnd and layer"
        )
    },
)

template_agg = xr.Dataset(
    data_vars={
        "tau": template_da.copy(),
        # Only `tau` is declared; the placeholders below are left disabled.
        # "ssa": template_da.copy(),
        # "g": template_da.copy()
    }
)

# Lazily apply `process_chunk` to every block of `atmosphere`.
result = xr.map_blocks(
    func=process_chunk, obj=atmosphere, template=template_agg
)
# Bare expression so the notebook renders the (still lazy) result summary.
result

In [None]:
# Write the result to "tau.nc" in the working directory; compute=True makes
# the write (and therefore the block computation) run immediately.
result.to_netcdf("tau.nc", compute=True)

In [None]:
# Re-open the file written above to check its contents.
# NOTE(review): the dataset handle is never closed explicitly; fine for an
# interactive session, but close it before rewriting "tau.nc".
read_results = xr.open_dataset("tau.nc")
# Bare expression so the notebook renders the `tau` variable.
read_results.tau