# In [None]:
import os

import multiprocessing
import numpy as np
import xarray as xr
import dask
from dask.distributed import Client

# Reuse a dask client if one is already running (e.g. started by an earlier
# notebook cell); otherwise start a local cluster with one worker per usable
# CPU core.
try:
    client = Client.current()
except ValueError:
    # Client.current() raises ValueError when no client has been created yet.
    # multiprocessing.cpu_count() reports every core on the host, including
    # cores this process is not allowed to run on (CPU-affinity masks,
    # cgroup/container quotas). dask's cpu_count() takes those limits into
    # account, so batch jobs and containers don't oversubscribe their CPUs.
    from dask.system import cpu_count

    n_workers = cpu_count()
    client = Client(n_workers=n_workers)

from pyrte_rrtmgp import rrtmgp_gas_optics
from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics
from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data
from pyrte_rrtmgp.rte_solver import RTESolver

# Absolute tolerance passed to the flux comparison against the reference data.
ERROR_TOLERANCE = 1e-7

# Obtain the RTE+RRTMGP data bundle, then locate the RFMIP clear-sky example
# and its "inputs" / "reference" subdirectories.
rte_rrtmgp_dir = download_rrtmgp_data()
rfmip_dir = os.path.join(rte_rrtmgp_dir, "examples", "rfmip-clear-sky")
input_dir, ref_dir = (
    os.path.join(rfmip_dir, subdir) for subdir in ("inputs", "reference")
)

# Longwave gas-optics data selected by the GasOpticsFiles.LW_G256 entry.
gas_optics_lw = load_gas_optics(gas_optics_file=GasOpticsFiles.LW_G256)

# Open the RFMIP atmosphere as a dask-backed dataset, chunked 3 at a time
# along the "expt" dimension so the work can be spread over the dask workers.
atmosphere_file = "multiple_input4MIPs_radiation_RFMIP_UColorado-RFMIP-1-2_none.nc"
atmosphere_path = os.path.join(input_dir, atmosphere_file)
atmosphere = xr.open_dataset(atmosphere_path, chunks={"expt": 3})

# Absorption-only gas optics. The return value is discarded, so this call
# presumably attaches its results to `atmosphere` in place (confirm against
# the pyrte_rrtmgp docs); the solver below reads them from `atmosphere`.
gas_optics_lw.gas_optics.compute(atmosphere, problem_type="absorption")

# Solve for fluxes, returned as a separate dataset (add_to_input=False).
solver = RTESolver()
fluxes = solver.solve(atmosphere, add_to_input=False)

# Reference upward (rlu) and downward (rld) longwave fluxes shipped with the
# RFMIP example. Paths are built with os.path.join, like the rest of the file.
rlu_reference = os.path.join(
    ref_dir, "rlu_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc"
)
rld_reference = os.path.join(
    ref_dir, "rld_Efx_RTE-RRTMGP-181204_rad-irf_r1i1p1f1_gn.nc"
)
# decode_cf=False: compare against the values exactly as stored in the files,
# without xarray's CF decoding (mask/scale, time decoding) applied.
rlu = xr.load_dataset(rlu_reference, decode_cf=False)
rld = xr.load_dataset(rld_reference, decode_cf=False)

# np.testing.assert_allclose is used instead of `assert np.isclose(...).all()`:
# - it is not stripped under `python -O`, unlike a bare assert;
# - on failure it reports how many elements mismatch and the max abs/rel error;
# - it fails on a shape mismatch instead of silently broadcasting.
# rtol=1e-5 and equal_nan=False reproduce np.isclose's defaults, so the pass
# criterion is unchanged: |computed - ref| <= ERROR_TOLERANCE + 1e-5 * |ref|,
# with NaNs never counted as close. For |ref| above ~1e-2 the relative term
# dominates the absolute ERROR_TOLERANCE.
np.testing.assert_allclose(
    fluxes["lw_flux_up"],
    rlu["rlu"],
    rtol=1e-5,
    atol=ERROR_TOLERANCE,
    equal_nan=False,
    err_msg="lw_flux_up does not match the RFMIP reference (rlu)",
)
np.testing.assert_allclose(
    fluxes["lw_flux_down"],
    rld["rld"],
    rtol=1e-5,
    atol=ERROR_TOLERANCE,
    equal_nan=False,
    err_msg="lw_flux_down does not match the RFMIP reference (rld)",
)