In [None]:
from datetime import datetime

from pyrte_rrtmgp.external_data_helpers import download_dyamond2_data

# Fetch the input files for the 2020-02-01 09:00 snapshot into a local folder.
# The helper's return value is kept in `downloaded_files`.
# NOTE(review): presumably a list of downloaded paths — confirm against the helper.
downloaded_files = download_dyamond2_data(
    datetime(year=2020, month=2, day=1, hour=9),
    data_dir="GEOS-DYAMOND2-data",
    compute_gas_optics=False,
)

### Reorganize the data to improve the processing time

In [None]:
import xarray as xr

from pyrte_rrtmgp.constants import HELMERT1

# Level-index window kept from the model output: indices min_lev_ice .. nlev - 1.
# NOTE(review): presumably 181 is the total level count and 78 the first level
# where cloud condensate matters — confirm against the dataset.
nlev = 181
min_lev_ice = 78

# Variables excluded when the files are opened.
unused_grid_vars = [
    "anchor",
    "cubed_sphere",
    "orientation",
    "contacts",
    "corner_lats",
    "corner_lons",
]

# Open all files matching the pattern as a single lazy dataset.
raw_atmosphere = xr.open_mfdataset(
    "GEOS-DYAMOND2-data/*inst_01hr_3d_*.nc4",
    drop_variables=unused_grid_vars,
)

# Keep only the selected levels, call that dimension "layer", and rechunk so
# that each block spans the whole layer axis (-1 means "do not split").
atmosphere = raw_atmosphere.isel(lev=slice(min_lev_ice, nlev))
atmosphere = atmosphere.rename({"lev": "layer"})
atmosphere = atmosphere.chunk({"Xdim": 2880, "Ydim": 72, "nf": 1, "layer": -1})

# Liquid/ice water paths, intended to be in g/m2: DELP * mixing ratio * 1000 / HELMERT1.
for path_name, mixing_ratio_name in (("lwp", "QL"), ("iwp", "QI")):
    atmosphere[path_name] = (
        atmosphere["DELP"] * atmosphere[mixing_ratio_name] * 1000 / HELMERT1
    )

# Effective radii, intended to be in microns: source value scaled by 1e6.
for radius_name, source_name in (("rel", "RL"), ("rei", "RI")):
    atmosphere[radius_name] = atmosphere[source_name] * 1e6

needed_vars = ["lwp", "iwp", "rel", "rei"]

# Write only the derived variables, each zlib-compressed at level 5.
compression = {name: {"zlib": True, "complevel": 5} for name in needed_vars}
atmosphere[needed_vars].to_netcdf("atmosphere.nc", encoding=compression)

### Compute the cloud optics

To avoid memory issues, please use dask version 2025.3.0 or higher. That release includes a [fix](https://docs.dask.org/en/stable/changelog.html#v2025-3-0) for `apply_ufunc` that resolves the memory issues.

In [None]:
import dask.array as da
import numpy as np
import xarray as xr
from dask.diagnostics import ProgressBar

from pyrte_rrtmgp import rrtmgp_cloud_optics
from pyrte_rrtmgp.data_types import CloudOpticsFiles

# Re-open the reorganised file lazily. Each dask block covers one time step,
# one "nf" entry and 72 rows of Ydim; Xdim is left unsplit (-1).
atmosphere = xr.open_dataset(
    "atmosphere.nc",
    chunks=dict(Ydim=72, nf=1, Xdim=-1, time=1),
)

# Long-wave, band-resolved cloud optics. process_chunk reads this object as a
# module-level global (stated to be small enough to be serialized to workers).
cloud_optics_lw = rrtmgp_cloud_optics.load_cloud_optics(
    cloud_optics_file=CloudOpticsFiles.LW_BND,
)

# Define the function to be applied to each chunk


def process_chunk(atm_chunk):
    """Compute cloud optics for one block and sum them over bands and layers.

    Uses the module-level ``cloud_optics_lw`` object with
    ``problem_type="absorption"`` and does not add the result to the input.

    Parameters
    ----------
    atm_chunk : xarray.Dataset
        One block of the chunked atmosphere dataset, as handed over by
        ``xr.map_blocks``.

    Returns
    -------
    xarray.DataArray
        The optics variable summed over the "bnd" and "layer" dimensions,
        cast to float32, with the input's "lats" assigned onto it.
    """
    # NOTE(review): the imported name is not referenced below. The import is
    # kept on purpose — presumably it makes sure the package (and anything it
    # registers on xarray objects) is loaded inside worker processes. Confirm
    # before removing.
    from pyrte_rrtmgp import rrtmgp_cloud_optics

    optics = cloud_optics_lw.compute_cloud_optics(
        atm_chunk, problem_type="absorption", add_to_input=False
    )

    # A Dataset result is reduced to its first data variable; anything else is
    # used as-is. Adjust the selection if the wanted variable is not the first.
    if isinstance(optics, xr.Dataset):
        first_var = list(optics.data_vars)[0]
        optics = optics[first_var]

    # Collapse the band and layer axes, then store the result as float32.
    summed = optics.sum(dim=["bnd", "layer"], skipna=True)
    collapsed = summed.astype(np.float32)

    collapsed["lats"] = atm_chunk["lats"]
    return collapsed


# Dimensions that process_chunk sums away and that are also present on the
# input variable used to build the template ("bnd" only exists on the output).
dims_to_remove = ["layer"]

# Template describing the map_blocks output: a lazy, NaN-filled float32 array
# shaped like "lwp" with the removed dimensions indexed away.
#
# The template is taken directly from full_like instead of being re-wrapped
# with dims collected from the Dataset: Dataset-level dimension order is not
# guaranteed to match the axis order of the data, and a dimension without a
# coordinate variable would have been lost. full_like already carries the
# correct dims, coordinates and dask chunking of the source variable.
template_agg = xr.full_like(
    atmosphere["lwp"].isel({dim: 0 for dim in dims_to_remove}, drop=True),
    np.nan,
    dtype=np.float32,
)
template_agg.name = "aggregated_tau"

# --- Apply the function chunk-wise using map_blocks ---
tau_agg = xr.map_blocks(
    process_chunk,
    atmosphere,  # Input Dataset (chunked)
    template=template_agg,  # Lazy template with the expected dims/coords/chunks
)

# Evaluate the lazy graph in worker processes and keep the result in memory.
with ProgressBar():
    result = tau_agg.compute(scheduler="multiprocessing")

In [None]:
import dask.array as da
import numpy as np
import xarray as xr
from dask.diagnostics import ProgressBar

from pyrte_rrtmgp import rrtmgp_cloud_optics
from pyrte_rrtmgp.data_types import CloudOpticsFiles

# Re-open the reorganised file lazily. Each dask block covers one time step,
# one "nf" entry and 36 rows of Ydim; Xdim and layer are left unsplit (-1).
atmosphere = xr.open_dataset(
    "atmosphere.nc",
    chunks=dict(Ydim=36, nf=1, Xdim=-1, time=1, layer=-1),
)

# Long-wave, band-resolved cloud optics. process_chunk reads this object as a
# module-level global (stated to be small enough to be serialized to workers).
cloud_optics_lw = rrtmgp_cloud_optics.load_cloud_optics(
    cloud_optics_file=CloudOpticsFiles.LW_BND,
)

# Define the function to be applied to each chunk


def process_chunk(atm_chunk):
    """Compute the cloud optics for one block without any aggregation.

    Uses the module-level ``cloud_optics_lw`` object with
    ``problem_type="absorption"`` and does not add the result to the input.

    Parameters
    ----------
    atm_chunk : xarray.Dataset
        One block of the chunked atmosphere dataset, as handed over by
        ``xr.map_blocks``.

    Returns
    -------
    xarray.DataArray
        The optics variable for this block, with the input's "lons" assigned
        onto it.
    """
    # NOTE(review): the imported name is not referenced below. The import is
    # kept on purpose — presumably it makes sure the package (and anything it
    # registers on xarray objects) is loaded inside worker processes. Confirm
    # before removing.
    from pyrte_rrtmgp import rrtmgp_cloud_optics

    optics = cloud_optics_lw.compute_cloud_optics(
        atm_chunk, problem_type="absorption", add_to_input=False
    )

    # A Dataset result is reduced to its first data variable; anything else is
    # used as-is. Adjust the selection if the wanted variable is not the first.
    if isinstance(optics, xr.Dataset):
        first_var = list(optics.data_vars)[0]
        optics = optics[first_var]

    optics["lons"] = atm_chunk["lons"]
    return optics


# Size of the extra "bnd" axis in the output template.
# NOTE(review): presumably this must equal the number of bands in the LW cloud
# optics file and could be read from cloud_optics_lw instead — confirm.
n_bnd = 16

# The template is modelled on "lwp": same dims, coordinates and chunking, plus
# one trailing, unsplit "bnd" axis.
base_array = atmosphere["lwp"]

# Lazy NaN-filled dask array; nothing is allocated until it is computed.
template_data = da.full(
    shape=(*base_array.shape, n_bnd),
    fill_value=np.nan,
    chunks=(*base_array.chunks, (n_bnd,)),
    dtype=base_array.dtype,
)

# Axis order the template is transposed to; it has to match what process_chunk
# returns for map_blocks to accept the blocks.
dim_order = ("layer", "bnd", "time", "nf", "Ydim", "Xdim")

template_agg = xr.DataArray(
    data=template_data,
    dims=[*base_array.dims, "bnd"],
    coords=base_array.coords,
    name="aggregated_tau",
).transpose(*dim_order)

# --- Apply the function chunk-wise using map_blocks ---
tau_agg = xr.map_blocks(process_chunk, atmosphere, template=template_agg)

print("Computing and saving result to NetCDF...")
# Previously the Delayed returned by to_netcdf(..., compute=False) was thrown
# away (so the data was never written) and the computed array was discarded as
# well. Keep the computed result and write that to disk.
with ProgressBar():
    tau = tau_agg.compute(scheduler="multiprocessing")
# NOTE(review): this holds the full result in memory, exactly as the original
# compute call did. If that is too large, compute the Delayed object returned
# by to_netcdf(..., compute=False) instead — confirm which dask schedulers
# support that write path before switching.
tau.to_netcdf("computed_tau.nc")
print("Done.")