In [65]:
%load_ext autoreload
%autoreload 2

import jax.numpy as jnp 

import geometry
import humidity

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [66]:
# Unpack all 14 fields returned by geometry.initialize_geometry() into
# notebook-level names.
(
    hsg, dhs, fsg, dhsr, fsgr,
    sia_half, coa_half, sia, coa,
    radang, cosg, cosgr, cosgr2, coriol,
) = geometry.initialize_geometry()

In [67]:
# Uniform test fields of shape (1, 48, 96). dtype=int / dtype=float reproduce
# the default dtypes jnp.array assigns to nested lists of Python ints / floats.
temp = jnp.full((1, 48, 96), 273, dtype=int)
pressure = jnp.full((1, 48, 96), 0.5, dtype=float)
qg = jnp.full((1, 48, 96), 2, dtype=int)
sigma_levels = 4

qsat = humidity.get_qsat(temp, pressure, sigma_levels)

# Convert qg to relative humidity, then convert that back to specific humidity.
rh, qsat_new = humidity.spec_hum_to_rel_hum(temp, pressure, sigma_levels, qg)
qa, qsat_new = humidity.rel_hum_to_spec_hum(temp, pressure, sigma_levels, rh)

In [4]:
import jax.numpy as jnp
from physical_constants import alhc, cp, p0, grav
import jax
from geometry import fsg, dhs


def get_large_scale_condensation_tendencies(psa, qa, qsat, itop):
    """Compute large-scale condensation tendencies and precipitation.

    Modeled on the SPEEDY ``lscond`` scheme. Wherever specific humidity
    exceeds a sigma-dependent reference fraction ``rhref`` of saturation, the
    excess is relaxed away on a time scale of ``trlsc`` hours, and the released
    latent heat warms the layer. Heating is capped by ``dqmax * psa**2``. The
    first level (index 0 on the last axis) never condenses.

    The tuning constants ``trlsc``, ``rhlsc``, ``drhlsc`` and ``rhblsc`` are
    read from the notebook's global namespace and must be defined before
    calling. ``fsg`` and ``dhs`` (imported from ``geometry``) must have at
    least ``kx`` entries.

    Args:
        psa: surface pressure (presumably normalized by p0 -- TODO confirm),
            shape (ix, il).
        qa: specific humidity, shape (ix, il, kx), with levels on the last axis.
        qsat: saturation specific humidity, same shape as ``qa``.
        itop: top-of-convection layer index, shape (ix, il). Currently unused.
            TODO: also return the updated top (min of itop and the condensing
            levels); it is omitted to keep the 3-tuple callers unpack.

    Returns:
        Tuple ``(precls, dtlsc, dqlsc)``:
            precls: large-scale precipitation, shape (ix, il).
            dtlsc: temperature tendency, shape (ix, il, kx).
            dqlsc: specific-humidity tendency, shape (ix, il, kx).
    """
    # Take the level count from the data instead of a global defined later.
    kx = qa.shape[-1]

    qsmax = 10.0  # scales dqmax, the cap on condensation heating
    rtlsc = 1.0 / (trlsc * 3600.0)  # relaxation rate: trlsc hours -> 1/seconds
    tfact = alhc / cp
    prg = p0 / grav

    # Per-level parameters for the condensing levels 2..kx (1-based), i.e.
    # indices 1..kx-1 on the last axis. A slice works for NumPy and JAX fsg.
    levels = jnp.arange(2, kx + 1)
    sig2 = jnp.asarray(fsg[1:kx]) ** 2.0
    rhref = rhlsc + drhlsc * (sig2 - 1.0)
    rhref = jnp.where(levels == kx, jnp.maximum(rhref, rhblsc), rhref)
    dqmax = qsmax * sig2 * rtlsc

    # The level-dependent factors have shape (kx-1,), so they broadcast along
    # the LAST (level) axis of the (ix, il, kx-1) humidity slices.
    dqa = rhref * qsat[..., 1:] - qa[..., 1:]
    condensing = dqa < 0.0
    dq = jnp.where(condensing, dqa * rtlsc, 0.0)
    dt = jnp.where(
        condensing,
        tfact * jnp.minimum(-dq, dqmax * (psa ** 2.0)[..., jnp.newaxis]),
        0.0,
    )

    # Level 0 keeps zero tendencies. Use the result dtype so that integer
    # qsat input cannot truncate the tendencies.
    dqlsc = jnp.zeros(qsat.shape, dtype=dq.dtype).at[..., 1:].set(dq)
    dtlsc = jnp.zeros(qsat.shape, dtype=dt.dtype).at[..., 1:].set(dt)

    # Precipitation: vertically integrate the moisture sink over the
    # condensing levels (layer thickness dhs times p0/grav), then scale by psa.
    pfact = jnp.asarray(dhs[1:kx]) * prg
    precls = -jnp.sum(pfact * dq, axis=-1) * psa

    return precls, dtlsc, dqlsc

# Example data: small random fields on an (ix, il) grid with kx levels.
ix, il, kx = 3, 4, 5

# Large-scale condensation tuning constants, looked up as globals by
# get_large_scale_condensation_tendencies.
trlsc, rhlsc, drhlsc, rhblsc = 4.0, 0.9, 0.1, 0.95

# Reproducible random inputs from fixed PRNG keys 0, 1 and 2.
psa = jax.random.uniform(jax.random.PRNGKey(0), shape=(ix, il))
qa, qsat = (
    jax.random.uniform(jax.random.PRNGKey(seed), shape=(ix, il, kx))
    for seed in (1, 2)
)
itop = jnp.full((ix, il), kx)  # every column starts with its top at level kx

precls, dtlsc, dqlsc = get_large_scale_condensation_tendencies(psa, qa, qsat, itop)

# Show each returned field with its label.
for label, value in (("precls:", precls), ("dtlsc:", dtlsc), ("dqlsc:", dqlsc)):
    print(label, value)