In [None]:
import os

import numpy as np
import xarray as xr
from pyrte_rrtmgp.rrtmgp_gas_optics import GasOpticsFiles, load_gas_optics
from pyrte_rrtmgp.rrtmgp_data import download_rrtmgp_data
from typing import Union

# Local directory holding the RTE+RRTMGP data files, as returned by
# pyrte_rrtmgp's download_rrtmgp_data(); the cloud and aerosol optics file
# paths further down are built relative to it.
rte_rrtmgp_dir = download_rrtmgp_data()



def compute_profiles(SST, ncol, nlay):
    """
    Construct profiles of pressure, temperature, humidity, and ozone
    following the RCEMIP protocol for a surface temperature of 300K.
    Based on Python implementation by Chiel van Heerwardeen.

    The vertical grid starts at the surface (index 0, z = 0 m) and rises to
    z_top = 70 km. The lower half of the layers evenly spans 0-15 km (up to
    the tropopause) and the upper half evenly spans 15-70 km. Every column
    receives the same profile.

    Args:
        SST: Surface temperature [K]. The surface humidity ``q_0`` is a fixed
            constant tuned for 300 K and is not rescaled with SST.
        ncol: Number of (identical) columns; must be >= 0.
        nlay: Number of layers; must be a positive even number because the
            grid is split evenly below and above the tropopause.

    Returns:
        Tuple ``(p_lay, t_lay, p_lev, t_lev, q_lay, o3)``. Layer arrays have
        shape (ncol, nlay) and level arrays have shape (ncol, nlay + 1).
        Pressures are in Pa and temperatures in K.

    Raises:
        ValueError: If ``ncol`` is negative or ``nlay`` is not a positive
            even number.
    """
    if ncol < 0:
        raise ValueError(f"ncol must be non-negative, got {ncol}")
    if nlay <= 0 or nlay % 2 != 0:
        # The two half-grids below/above the tropopause need nlay/2 layers each.
        raise ValueError(f"nlay must be a positive even number, got {nlay}")

    # Constants
    z_trop = 15000.0
    z_top = 70.0e3
    g1 = 3.6478
    g2 = 0.83209
    g3 = 11.3515
    o3_min = 1e-13
    g = 9.79764
    Rd = 287.04
    p0 = 101480.0  # Surface pressure
    z_q1 = 4.0e3
    z_q2 = 7.5e3
    q_t = 1.0e-8
    gamma = 6.7e-3
    q_0 = 0.01864  # for 300 K SST

    # Virtual temperature at the surface; reference for the pressure formula.
    Tv0 = (1.0 + 0.608 * q_0) * SST

    def _state(z):
        """Return (p, T, q) arrays at heights ``z`` [m] from the profile formulas."""
        p = np.empty_like(z)
        T = np.empty_like(z)
        q = np.empty_like(z)
        above = z > z_trop
        below = ~above

        # Troposphere: humidity decays with height, temperature follows the
        # lapse rate gamma, and pressure follows the matching power law.
        zb = z[below]
        qb = q_0 * np.exp(-zb / z_q1) * np.exp(-(zb / z_q2) ** 2)
        Tb = SST - gamma * zb / (1.0 + 0.608 * qb)
        Tvb = (1.0 + 0.608 * qb) * Tb
        q[below] = qb
        T[below] = Tb
        p[below] = p0 * (Tvb / Tv0) ** (g / (Rd * gamma))

        # Above the tropopause: constant humidity q_t and constant temperature;
        # pressure decays exponentially with height above z_trop.
        # Evaluated only on this mask: the tropospheric formula would drive
        # Tv negative at these heights and produce NaNs.
        T_strat = SST - gamma * z_trop / (1.0 + 0.608 * q_0)
        Tv_strat = (1.0 + 0.608 * q_t) * T_strat
        q[above] = q_t
        T[above] = T_strat
        p[above] = (
            p0 * (Tv_strat / Tv0) ** (g / (Rd * gamma))
            * np.exp(-((g * (z[above] - z_trop)) / (Rd * Tv_strat)))
        )
        return p, T, q

    # Split resolution above and below RCE tropopause (15 km or about 125 hPa)
    half = nlay // 2
    z_lev = np.zeros(nlay + 1)
    z_lev[1:half + 1] = 2.0 * z_trop / nlay * np.arange(1, half + 1)
    z_lev[half + 1:] = z_trop + 2.0 * (z_top - z_trop) / nlay * np.arange(1, half + 1)
    z_lay = 0.5 * (z_lev[:-1] + z_lev[1:])

    p_lay1, t_lay1, q_lay1 = _state(z_lay)
    p_lev1, t_lev1, _ = _state(z_lev)

    p_hpa = p_lay1 / 100.0
    o3_1 = np.maximum(o3_min, g1 * p_hpa ** g2 * np.exp(-p_hpa / g3) * 1.0e-6)

    # Every column is identical: replicate the 1-D profiles into fresh
    # (ncol, n) arrays.
    def _columns(profile):
        return np.tile(profile, (ncol, 1))

    return (
        _columns(p_lay1),
        _columns(t_lay1),
        _columns(p_lev1),
        _columns(t_lev1),
        _columns(q_lay1),
        _columns(o3_1),
    )

# Grid size for the example: number of columns and number of layers.
ncol, nlay = 300, 100
# Profiles for a 300 K surface temperature, unpacked in the order
# compute_profiles returns them.
p_lay, t_lay, p_lev, t_lev, q, o3 = compute_profiles(SST=300, ncol=ncol, nlay=nlay)



class GasConcentrations:
    """Per-gas volume mixing ratios sharing common (ncol, nlay) dimensions.

    ``set_vmr`` accepts a gas only after it has been registered. Register a
    gas by passing it in ``gas_names`` at construction, or by it already
    being a key of ``concs``.

    Values are stored in ``concs`` as 2-D arrays:
    - scalars as shape (1, 1);
    - 1-D profiles as (1, nlay);
    - 2-D fields as (ncol, nlay).

    The first successful 1-D/2-D assignment fixes ``nlay`` (and, for 2-D,
    ``ncol``); later assignments must match.

    The ``set_vmr*`` methods report invalid input by returning a non-empty
    error string (an empty string means success). Only ``set_vmr`` raises,
    and only for inputs that are neither scalars nor 1-D/2-D arrays.
    """

    def __init__(self, gas_names=None):
        """Create an empty container.

        Args:
            gas_names: Optional iterable of gas names that ``set_vmr`` may
                later set. A single string is treated as one name. Defaults
                to no registered gases.
        """
        if isinstance(gas_names, str):
            gas_names = (gas_names,)
        self.gas_names = tuple(gas_names) if gas_names is not None else ()
        self.concs = {}
        self.ncol = 0
        self.nlay = 0

    def _is_known_gas(self, gas: str) -> bool:
        """Return True if ``gas`` was registered at construction or already has an entry."""
        return gas in self.gas_names or gas in self.concs

    def set_vmr_scalar(self, gas: str, w: float) -> str:
        """Set scalar volume mixing ratio for a gas.

        Args:
            gas: Name of the gas
            w: Volume mixing ratio (scalar value between 0 and 1)

        Returns:
            Error message string, empty if successful
        """
        # Written as a positive range test so NaN is rejected too.
        if not 0.0 <= w <= 1.0:
            return 'GasConcentrations.set_vmr(): concentrations should be >= 0, <= 1'

        if not self._is_known_gas(gas):
            return f'GasConcentrations.set_vmr(): trying to set {gas} but name not provided at initialization'

        self.concs[gas] = np.full((1, 1), w)
        return ''

    def set_vmr_1d(self, gas: str, w: np.ndarray) -> str:
        """Set 1D volume mixing ratio profile for a gas.

        Args:
            gas: Name of the gas
            w: Volume mixing ratio profile (1D array between 0 and 1); a
               private copy is stored

        Returns:
            Error message string, empty if successful
        """
        w = np.array(w)  # copy so later changes by the caller don't leak in
        if not np.all((w >= 0.0) & (w <= 1.0)):
            return 'GasConcentrations.set_vmr: concentrations should be >= 0, <= 1'

        if self.nlay > 0 and len(w) != self.nlay:
            return 'GasConcentrations.set_vmr: different dimension (nlay)'

        if not self._is_known_gas(gas):
            return f'GasConcentrations.set_vmr(): trying to set {gas} but name not provided at initialization'

        # Only commit the dimension once every check has passed, so a
        # rejected call leaves the container's shape unchanged.
        self.nlay = len(w)
        self.concs[gas] = w.reshape(1, self.nlay)
        return ''

    def set_vmr_2d(self, gas: str, w: np.ndarray) -> str:
        """Set 2D volume mixing ratio field for a gas.

        Args:
            gas: Name of the gas
            w: Volume mixing ratio field (2D array of shape (ncol, nlay),
               values between 0 and 1); a private copy is stored

        Returns:
            Error message string, empty if successful
        """
        w = np.array(w)  # copy so later changes by the caller don't leak in
        if not np.all((w >= 0.0) & (w <= 1.0)):
            return 'GasConcentrations.set_vmr: concentrations should be >= 0, <= 1'

        if self.ncol > 0 and w.shape[0] != self.ncol:
            return 'GasConcentrations.set_vmr: different dimension (ncol)'

        if self.nlay > 0 and w.shape[1] != self.nlay:
            return 'GasConcentrations.set_vmr: different dimension (nlay)'

        if not self._is_known_gas(gas):
            return f'GasConcentrations.set_vmr(): trying to set {gas} but name not provided at initialization'

        # Commit dimensions only after all validation succeeded.
        self.ncol = w.shape[0]
        self.nlay = w.shape[1]
        self.concs[gas] = w
        return ''

    def set_vmr(self, gas: str, w: Union[float, np.ndarray]) -> str:
        """Set volume mixing ratio for a gas.

        Dispatches to appropriate method based on input type. Python and
        numpy scalars, as well as 0-d arrays, are treated as scalars.

        Args:
            gas: Name of the gas
            w: Volume mixing ratio (scalar, 1D array, or 2D array)

        Returns:
            Error message string, empty if successful

        Raises:
            ValueError: If ``w`` is neither a scalar nor a 1D/2D array.
        """
        if isinstance(w, (int, float, np.integer, np.floating)):
            return self.set_vmr_scalar(gas, float(w))
        if isinstance(w, np.ndarray):
            if w.ndim == 0:
                return self.set_vmr_scalar(gas, float(w))
            if w.ndim == 1:
                return self.set_vmr_1d(gas, w)
            if w.ndim == 2:
                return self.set_vmr_2d(gas, w)
        raise ValueError("Input must be scalar or 1D/2D array")


gas_optics_lw = load_gas_optics(gas_optics_file=GasOpticsFiles.LW_G256)
lw_clouds = os.path.join(rte_rrtmgp_dir, "rrtmgp-clouds-lw.nc")
lw_aerosols = os.path.join(rte_rrtmgp_dir, "rrtmgp-aerosols-merra-lw.nc")

# Surface albedos (direct / diffuse) and mu0 -- presumably the cosine of the
# solar zenith angle; confirm against the shortwave solver that consumes it.
sfc_alb_dir = 0.06
sfc_alb_dif = 0.06
mu0 = 0.86

# Derive the vertical ordering from the data instead of hard-coding it:
# top_at_1 means index 0 is the top of the atmosphere (lowest pressure).
# compute_profiles starts at the surface (z_lev[0] = 0), so a hard-coded True
# made t_sfc pick the top-of-atmosphere level temperature.
top_at_1 = bool(p_lay[0, 0] < p_lay[0, -1])
t_sfc = t_lev[0, nlay if top_at_1 else 0]
emis_sfc = 0.98

# TODO: port the gas-concentration setup below (Fortran reference) to
# GasConcentrations.set_vmr.
# gas_concs%set_vmr("h2o", q )
# gas_concs%set_vmr("o3",  o3)
# gas_concs%set_vmr("co2", 348.e-6_wp)
# gas_concs%set_vmr("ch4", 1650.e-9_wp)
# gas_concs%set_vmr("n2o", 306.e-9_wp)
# gas_concs%set_vmr("n2",  0.7808_wp)
# gas_concs%set_vmr("o2",  0.2095_wp)
# gas_concs%set_vmr("co",  0._wp)

