In [42]:
%load_ext autoreload
%autoreload 2

import jax.numpy as jnp 

import geometry
import humidity
import large_scale_condensation

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [43]:
# Unpack the 14 values returned by geometry.initialize_geometry(), in order.
(
    hsg, dhs, fsg, dhsr, fsgr,
    sia_half, coa_half, sia, coa,
    radang, cosg, cosgr, cosgr2, coriol,
) = geometry.initialize_geometry()

In [44]:
def _uniform_field(value):
    """Return a (1, 48, 96) array in which every entry equals ``value``."""
    return jnp.array([[[value] * 96] * 48])


# Spatially uniform inputs passed to the humidity routines below
temp = _uniform_field(273)
pressure = _uniform_field(0.5)
qg = _uniform_field(2)
sigma_levels = 4

qsat = humidity.get_qsat(temp, pressure, sigma_levels)

# Convert qg to relative humidity, then feed that relative humidity back
# through the inverse conversion (qsat_new is rebound by the second call).
rh, qsat_new = humidity.spec_hum_to_rel_hum(temp, pressure, sigma_levels, qg)
qa, qsat_new = humidity.rel_hum_to_spec_hum(temp, pressure, sigma_levels, rh)

In [56]:
from physical_constants import p0, cp, alhc, alhs, grav
from geometry import fsg, dhs

# Constants for large-scale condensation
trlsc = 4.0   # Relaxation time (in hours) for specific humidity
rhlsc = 0.9   # Maximum relative humidity threshold (at sigma=1)
drhlsc = 0.1  # Vertical range of relative humidity threshold
rhblsc = 0.95 # Relative humidity threshold for boundary layer

def get_large_scale_condensation_tendencies(psa, qa, qsat, itop, fsg, dhs, p0, cp, alhc, grav):
    """Compute large-scale condensation tendencies and precipitation.

    Wherever the specific humidity ``qa`` exceeds the level-dependent
    threshold ``rhref * qsat``, the excess is relaxed towards the threshold
    with time scale ``trlsc`` (hours). The function reads the module-level
    constants ``trlsc``, ``rhlsc``, ``drhlsc`` and ``rhblsc``.

    Args:
        psa: 2-D array broadcastable to ``(ix, il)``; it is squared to scale
            the cap on the temperature tendency and multiplies ``precls``.
        qa: 3-D array of shape ``(ix, il, kx)``, specific humidity.
        qsat: array broadcastable to ``qa``, saturation specific humidity.
        itop: array of shape ``(ix, il)`` holding a level index per column.
        fsg: 1-D array of length ``kx``; its square sets the per-level
            humidity threshold and the tendency cap.
        dhs: 1-D array of length ``kx`` used as the per-level weight in the
            precipitation sum.
        p0, grav: scalars; ``p0 / grav`` scales the precipitation.
        cp, alhc: scalars; ``alhc / cp`` converts the humidity tendency into
            a temperature tendency.

    Returns:
        Tuple ``(itop, precls, dtlsc, dqlsc)``:
        ``itop``  -- input ``itop`` lowered, per column, to the smallest
                     level index at which condensation occurs (unchanged for
                     columns with no condensation);
        ``precls`` -- ``(ix, il)`` array, the ``dhs``-weighted column sum of
                     ``-dqlsc`` scaled by ``p0 / grav`` and ``psa``;
        ``dtlsc`` -- ``(ix, il, kx)`` temperature tendency;
        ``dqlsc`` -- ``(ix, il, kx)`` specific-humidity tendency (<= 0).
    """
    ix, il, kx = qa.shape

    # Constants for computation
    qsmax = 10.0
    rtlsc = 1.0 / (trlsc * 3600.0)  # relaxation rate: hours -> 1/seconds
    tfact = alhc / cp
    prg = p0 / grav

    psa2 = psa ** 2.0

    # Per-level reference relative humidity and cap on the tendency
    sig2 = fsg**2.0
    rhref = rhlsc + drhlsc * (sig2 - 1.0)
    # The last level uses at least the boundary-layer threshold
    rhref = rhref.at[-1].set(jnp.maximum(rhref[-1], rhblsc))
    dqmax = qsmax * sig2 * rtlsc

    # Humidity deficit relative to the threshold; negative means condensation
    dqa = rhref[jnp.newaxis, jnp.newaxis, :] * qsat - qa
    negative_dqa_mask = dqa < 0

    # Tendencies are non-zero only where dqa < 0
    dqlsc = jnp.where(negative_dqa_mask, dqa * rtlsc, 0.0)
    dtlsc = jnp.where(
        negative_dqa_mask,
        tfact * jnp.minimum(-dqlsc, dqmax[jnp.newaxis, jnp.newaxis, :] * psa2[:, :, jnp.newaxis]),
        0.0,
    )

    # Update itop: per column, take the minimum of the incoming value and the
    # smallest condensing level index. Done as a masked reduction so it stays
    # vectorised and traceable by jax.jit (no Python loop over grid points).
    level_index = jnp.arange(kx)
    first_condensing = jnp.min(jnp.where(negative_dqa_mask, level_index, kx), axis=2)
    itop = jnp.where(
        jnp.any(negative_dqa_mask, axis=2),
        jnp.minimum(itop, first_condensing.astype(itop.dtype)),
        itop,
    )

    # Large-scale precipitation
    pfact = dhs * prg
    precls = jnp.zeros((ix, il)) - jnp.sum(pfact[jnp.newaxis, jnp.newaxis, :] * dqlsc, axis=2)
    precls = precls * psa

    return itop, precls, dtlsc, dqlsc

# Example inputs
# Example inputs: one horizontal point with kx levels, all fields set to 1
ix, il, kx = 1, 1, 8
column_shape = (ix, il)
psa = jnp.ones(column_shape)
qa = jnp.ones(column_shape + (kx,))
qsat = jnp.ones(column_shape + (kx,))
itop = jnp.full(column_shape, kx - 1)

# fsg, dhs, p0, cp, alhc and grav are the names imported from geometry and
# physical_constants. Commented-out stand-in values:
#fsg = jnp.linspace(0.1, 1.0, kx)
#dhs = jnp.ones(kx)
#p0 = 1000.0
#cp = 1004.0
#alhc = 2.5e6
#grav = 9.81

# Call the function (itop is rebound to the updated value)
itop, precls, dtlsc, dqlsc = get_large_scale_condensation_tendencies(
    psa, qa, qsat, itop, fsg, dhs, p0, cp, alhc, grav
)

# Print each result under its heading
report = (
    ("Precipitation due to large-scale condensation (precls):", precls),
    ("\nTemperature tendency due to large-scale condensation (dtlsc):", dtlsc),
    ("\nSpecific humidity tendency due to large-scale condensation (dqlsc):", dqlsc),
    ("\nUpdated cloud top (itop):", itop),
)
for heading, result in report:
    print(heading)
    print(result)

Precipitation due to large-scale condensation (precls):
[[0.11387439]]

Temperature tendency due to large-scale condensation (dtlsc):
[[[1.0811790e-06 1.5612222e-05 3.3905773e-05 3.2597978e-05 3.0098290e-05
   2.6480671e-05 2.2536526e-05 8.6494329e-06]]]

Specific humidity tendency due to large-scale condensation (dqlsc):
[[[-1.38845508e-05 -1.38262167e-05 -1.36111139e-05 -1.30861135e-05
   -1.20826398e-05 -1.06303851e-05 -9.04704939e-06 -3.47222317e-06]]]

Updated cloud top (itop):
[[0]]


In [58]:
# Call the routine of the same name from the large_scale_condensation module,
# passing only psa, qa, qsat and itop. As the notebook is laid out, itop is the
# value already updated by the earlier call, and all four result names are
# rebound here.
# NOTE(review): presumably the module version takes fsg, dhs and the physical
# constants from its own imports; confirm against the module.
itop, precls, dtlsc, dqlsc = large_scale_condensation.get_large_scale_condensation_tendencies(psa, qa, qsat, itop)