In [1]:
import os

import numpy as np
import xarray as xr
from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics
from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data
from pyrte_rrtmgp.rte_solver import RTESolver
from pyrte_rrtmgp.all_skys_funcs import compute_profiles, create_gas_dataset, compute_clouds, compute_cloud_optics, expand_to_2d
from pyrte_rrtmgp.kernels.rte import (
    increment_1scalar_by_1scalar,
    increment_1scalar_by_2stream,
    increment_2stream_by_1scalar,
    increment_2stream_by_2stream,
    inc_1scalar_by_1scalar_bybnd,
    inc_1scalar_by_2stream_bybnd,
    inc_2stream_by_1scalar_bybnd,
    inc_2stream_by_2stream_bybnd,
)

# Location of the RRTMGP data files, as returned by download_rrtmgp_data().
rte_rrtmgp_dir = download_rrtmgp_data()

# Problem size: number of sites (columns) and number of layers per site.
ncol, nlay = 24, 72

# Layer/level pressure and temperature plus water vapour and ozone profiles.
# NOTE(review): the leading 300 is presumably a surface temperature in K --
# confirm against compute_profiles.
p_lay, t_lay, p_lev, t_lev, q, o3 = compute_profiles(300, ncol, nlay)

# Bundle the profiles into a single dataset. Layer quantities live on
# ("site", "layer"); level quantities live on ("site", "level").
profiles = xr.Dataset(
    data_vars=dict(
        pres_layer=(["site", "layer"], p_lay),
        temp_layer=(["site", "layer"], t_lay),
        pres_level=(["site", "level"], p_lev),
        temp_level=(["site", "level"], t_lev),
        water_vapor=(["site", "layer"], q),
        ozone=(["site", "layer"], o3),
    )
)


# Long-wave gas optics. Loaded once here and reused for the clear-sky
# computation at the end of this section.
gas_optics_lw = load_gas_optics(gas_optics_file=GasOpticsFiles.LW_G256)

# Paths of the long-wave cloud and aerosol optics files in the data directory.
lw_clouds = os.path.join(rte_rrtmgp_dir, "rrtmgp-clouds-lw.nc")
lw_aerosols = os.path.join(rte_rrtmgp_dir, "rrtmgp-aerosols-merra-lw.nc")

# for sw
# sfc_alb_dir = 0.06
# sfc_alb_dif = 0.06
# mu0 = 0.86

# Gas concentrations, one scalar per gas, handed to create_gas_dataset.
# NOTE(review): values are presumably mole fractions -- confirm against
# create_gas_dataset.
gas_values = {
    'carbon_dioxide_GM': 348.0e-6,  # scalar
    'methane_GM': 1650.0e-9,  # scalar
    'nitrous_oxide_GM': 306.0e-9,  # scalar
    'nitrogen_GM': 0.7808,  # scalar
    'oxygen_GM': 0.2095,  # scalar
    'carbon_monoxide_GM': 0.0,  # scalar
}

gases = create_gas_dataset(gas_values, ncol, nlay)

# Full atmospheric state: the profiles plus the gas concentrations.
atmosphere = xr.merge([profiles, gases])

# Surface temperature per site, read from the level temperatures: the last
# level (index nlay) when top_at_1 is set, otherwise the first (index 0).
top_at_1 = False
t_sfc = t_lev[:, nlay if top_at_1 else 0]
atmosphere["surface_temperature"] = xr.DataArray(t_sfc, dims=["site"])

# Clear-sky optical properties, returned as a separate object
# (add_to_input=False) rather than added to `atmosphere`.
clear_sky_optical_props = gas_optics_lw.gas_optics.compute(
    atmosphere, problem_type="absorption", add_to_input=False
)






# gas_optics_lw.gas_optics.compute(atmosphere, problem_type="absorption")


# Read the long-wave cloud and aerosol optics files into memory.
cloud_optics = xr.load_dataset(lw_clouds)
aerosol_optics = xr.load_dataset(lw_aerosols)

# Surface values. The surface level is the last level (index nlay) when
# top_at_1 is set, otherwise the first (index 0).
# NOTE(review): t_sfc here is the level temperature of site 0 only.
top_at_1 = False
t_sfc = t_lev[0, nlay] if top_at_1 else t_lev[0, 0]
emis_sfc = expand_to_2d(0.98, ncol, nlay + 1, name="surface_emissivity")

# Cloud physical properties and the optical depth derived from them.
# NOTE(review): lwp/iwp/rel/rei are presumably liquid/ice water paths and
# effective radii -- confirm against compute_clouds.
lwp, iwp, rel, rei = compute_clouds(cloud_optics, ncol, nlay, p_lay, t_lay)
tau = compute_cloud_optics(lwp, iwp, rel, rei, cloud_optics)

# Cloud optical properties: optical depth only, on (site, layer, gpt).
clouds_optical_props = xr.Dataset(
    data_vars=dict(tau=(["site", "layer", "gpt"], tau))
)

# Scalar surface emissivity stored alongside the clear-sky optical properties.
clear_sky_optical_props["surface_emissivity"] = 0.98


def combine_optical_props(op1, op2):
    """
    Increment the optical properties in ``op2`` by those in ``op1``.

    ``op2`` is the target: its ``tau`` (and, for a 2-stream description,
    ``ssa`` and ``g``) are updated and written back on ``op2``. ``op1`` is
    only read.

    A dataset that has ``tau`` but no ``ssa`` is treated as 1-stream;
    otherwise it is treated as 2-stream (``tau``, ``ssa``, ``g``).

    When both datasets have the same ``gpt`` size the plain increment
    kernels are used. Otherwise ``op1`` is taken to be defined per band:
    its ``gpt`` size must equal the ``bnd`` size of ``op2``, and
    ``op2["bnd_limits_gpt"]`` is passed to the by-band kernels.

    Args:
        op1: Optical properties to add (read only).
        op2: Optical properties to be incremented; modified in place. Must
            have ``site``, ``layer`` and ``gpt`` dimensions.

    Raises:
        ValueError: If the ``gpt`` sizes differ and the ``gpt`` size of
            ``op1`` does not match the ``bnd`` size of ``op2``.
    """
    # ``sizes`` is the dimension-name -> length mapping; indexing
    # ``Dataset.dims`` like a mapping is deprecated in xarray.
    ncol = op2.sizes["site"]
    nlay = op2.sizes["layer"]
    ngpt = op2.sizes["gpt"]

    # 1-stream means tau only; 2-stream means tau, ssa and g.
    source_is_1stream = "tau" in op1.data_vars and "ssa" not in op1.data_vars
    target_is_1stream = "tau" in op2.data_vars and "ssa" not in op2.data_vars

    if op1.sizes["gpt"] == ngpt:
        # Same spectral grid: g-point by g-point increment.
        inc_1s_by_1s = increment_1scalar_by_1scalar
        inc_1s_by_2s = increment_1scalar_by_2stream
        inc_2s_by_1s = increment_2stream_by_1scalar
        inc_2s_by_2s = increment_2stream_by_2stream
        band_args = ()
    else:
        # By-band increment: op1's "gpt" axis must be op2's band axis.
        nbnd = op2.sizes["bnd"]
        if nbnd != op1.sizes["gpt"]:
            raise ValueError("Incompatible g-point structures for by-band increment")
        inc_1s_by_1s = inc_1scalar_by_1scalar_bybnd
        inc_1s_by_2s = inc_1scalar_by_2stream_bybnd
        inc_2s_by_1s = inc_2stream_by_1scalar_bybnd
        inc_2s_by_2s = inc_2stream_by_2stream_bybnd
        band_args = (nbnd, op2["bnd_limits_gpt"].values)

    # The kernels are used as in-place updaters (return values are ignored):
    # op2's arrays are passed first as the ones being updated, op1's follow as
    # the increment. Each target array is fetched once so the array handed to
    # the kernel is the same one that is written back afterwards.
    tau = op2.tau.values
    if target_is_1stream:
        updated = {"tau": tau}
        if source_is_1stream:
            inc_1s_by_1s(ncol, nlay, ngpt, tau, op1.tau.values, *band_args)
        else:
            inc_1s_by_2s(
                ncol, nlay, ngpt,
                tau,
                op1.tau.values, op1.ssa.values,
                *band_args,
            )
    elif source_is_1stream:
        ssa = op2.ssa.values
        updated = {"tau": tau, "ssa": ssa}
        inc_2s_by_1s(
            ncol, nlay, ngpt,
            tau, ssa,
            op1.tau.values,
            *band_args,
        )
    else:
        ssa = op2.ssa.values
        g = op2.g.values
        updated = {"tau": tau, "ssa": ssa, "g": g}
        inc_2s_by_2s(
            ncol, nlay, ngpt,
            tau, ssa, g,
            op1.tau.values, op1.ssa.values, op1.g.values,
            *band_args,
        )

    # Store the incremented arrays back on the target dataset.
    for name, values in updated.items():
        op2[name] = (("site", "layer", "gpt"), values)


# Add the cloud optical properties (op1) onto the clear-sky ones (op2);
# combine_optical_props writes the result back onto its second argument.
combine_optical_props(op1=clouds_optical_props, op2=clear_sky_optical_props)

# Solve for fluxes from the combined optical properties. The result is
# returned as a separate object (add_to_input=False).
solver = RTESolver()
fluxes = solver.solve(clear_sky_optical_props, add_to_input=False)

# Bare expression: shows the result when run as a notebook cell.
fluxes



  ncol = op2.dims["site"]
  nlay = op2.dims["layer"]
  ngpt = op2.dims["gpt"]
