In [30]:
%load_ext autoreload
%autoreload 2

import jax.numpy as jnp 

import geometry
import humidity
import large_scale_condensation

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
hsg, dhs, fsg, dhsr, fsgr, sia_half, coa_half, sia, coa, radang, cosg, cosgr, \
cosgr2, coriol = geometry.initialize_geometry()

In [32]:
temp = jnp.array([[[273] * 96] * 48])
pressure = jnp.array([[[0.5] * 96] * 48])
sigma_levels = 4
qg = jnp.array([[[2] * 96] * 48])

qsat = humidity.get_qsat(temp, pressure, sigma_levels)

rh, qsat_new = humidity.spec_hum_to_rel_hum(temp, pressure, sigma_levels, qg)

qa, qsat_new = humidity.rel_hum_to_spec_hum(temp, pressure, sigma_levels, rh)

In [33]:
from physical_constants import p0, cp, alhc, alhs, grav
from geometry import fsg, dhs

# Constants for large-scale condensation
trlsc = 4.0   # Relaxation time (in hours) for specific humidity
rhlsc = 0.9   # Maximum relative humidity threshold (at sigma=1)
drhlsc = 0.1  # Vertical range of relative humidity threshold
rhblsc = 0.95 # Relative humidity threshold for boundary layer

def get_large_scale_condensation_tendencies(psa, qa, qsat, itop, fsg, dhs, p0, cp, alhc, grav):
    ix, il, kx = qa.shape

    # Initialize outputs
    dtlsc = jnp.zeros_like(qa)
    dqlsc = jnp.zeros_like(qa)
    precls = jnp.zeros((ix, il))

    # Constants for computation
    qsmax = 10.0
    rtlsc = 1.0 / (trlsc * 3600.0)
    tfact = alhc / cp
    prg = p0 / grav

    psa2 = psa ** 2.0

    # Compute sig2, rhref, and dqmax arrays
    sig2 = fsg**2.0
    rhref = rhlsc + drhlsc * (sig2 - 1.0)
    rhref = rhref.at[-1].set(jnp.maximum(rhref[-1], rhblsc))
    dqmax = qsmax * sig2 * rtlsc

    # Compute dqa array
    dqa = rhref[jnp.newaxis, jnp.newaxis, :] * qsat - qa

    # Calculate dqlsc and dtlsc where dqa < 0
    negative_dqa_mask = dqa < 0
    dqlsc = jnp.where(negative_dqa_mask, dqa * rtlsc, dqlsc)
    dtlsc = jnp.where(negative_dqa_mask, tfact * jnp.minimum(-dqlsc, dqmax[jnp.newaxis, jnp.newaxis, :] * psa2[:, :, jnp.newaxis]), dtlsc)

    # Update itop
    def update_itop(itop, indices, values):
        for idx, val in zip(zip(*indices), values):
            itop = itop.at[idx[:2]].set(jnp.minimum(itop[idx[:2]], val))
        return itop

    itop_update_indices = jnp.where(negative_dqa_mask)
    itop = update_itop(itop, itop_update_indices, itop_update_indices[2])

    # Large-scale precipitation
    pfact = dhs * prg
    precls -= jnp.sum(pfact[jnp.newaxis, jnp.newaxis, :] * dqlsc, axis=2)
    precls *= psa

    return itop, precls, dtlsc, dqlsc

# Example inputs
ix, il, kx = 1, 1, 8
psa = jnp.ones((ix, il))
qa = jnp.ones((ix, il, kx))
qsat = jnp.ones((ix, il, kx))
itop = jnp.full((ix, il), kx - 1)
#fsg = jnp.linspace(0.1, 1.0, kx)
#dhs = jnp.ones(kx)
#p0 = 1000.0
#cp = 1004.0
#alhc = 2.5e6
#grav = 9.81

# Call the function
itop, precls, dtlsc, dqlsc = get_large_scale_condensation_tendencies(psa, qa, qsat, itop, fsg, dhs, p0, cp, alhc, grav)

# Print the results
print("Precipitation due to large-scale condensation (precls):")
print(precls)
print("\nTemperature tendency due to large-scale condensation (dtlsc):")
print(dtlsc)
print("\nSpecific humidity tendency due to large-scale condensation (dqlsc):")
print(dqlsc)
print("\nUpdated cloud top (itop):")
print(itop)

Precipitation due to large-scale condensation (precls):
[[0.11387439]]

Temperature tendency due to large-scale condensation (dtlsc):
[[[1.0811790e-06 1.5612222e-05 3.3905773e-05 3.2597978e-05 3.0098290e-05
   2.6480671e-05 2.2536526e-05 8.6494329e-06]]]

Specific humidity tendency due to large-scale condensation (dqlsc):
[[[-1.38845508e-05 -1.38262167e-05 -1.36111139e-05 -1.30861135e-05
   -1.20826398e-05 -1.06303851e-05 -9.04704939e-06 -3.47222317e-06]]]

Updated cloud top (itop):
[[0]]


In [34]:
itop, precls, dtlsc, dqlsc = large_scale_condensation.get_large_scale_condensation_tendencies(psa, qa, qsat, itop)

In [35]:
import jax.numpy as jnp
from jax import jit, random
from physical_constants import alhc, wvi

# Constants for large-scale condensation
trlsc = 4.0   # Relaxation time (in hours) for specific humidity
rhlsc = 0.9   # Maximum relative humidity threshold (at sigma=1)
drhlsc = 0.1  # Vertical range of relative humidity threshold
rhblsc = 0.95 # Relative humidity threshold for boundary layer

@jit
def diagnose_convection(psa, se, qa, qsat, psmin, rhbl, wvi, alhc):
    ix, il, kx = se.shape
    itop = jnp.full((ix, il), kx + 1, dtype=int)  # Initialize itop with nlp
    qdif = jnp.zeros((ix, il), dtype=float)
    
    # Saturation moist static energy
    mss = se + alhc * qsat
    
    rlhc = 1.0 / alhc

    # Mask for psa > psmin
    mask_psa = psa > psmin

    mse0 = se[:, :, kx-1] + alhc * qa[:, :, kx-1]
    mse1 = se[:, :, kx-2] + alhc * qa[:, :, kx-2]
    mse1 = jnp.minimum(mse0, mse1)

    mss0 = jnp.maximum(mse0, mss[:, :, kx-1])

    # Compute mss2 array for all k layers (3 to kx-3)
    k_indices = jnp.arange(3, kx-3, dtype=int)
    mss2 = mss[:, :, k_indices] + wvi[k_indices, 1] * (mss[:, :, k_indices + 1] - mss[:, :, k_indices])
    
    # Check 1: conditional instability
    mask_conditional_instability = mss0[:, :, None] > mss2
    ktop1 = jnp.full((ix, il), kx, dtype=int)
    ktop1 = k_indices[jnp.argmax(mask_conditional_instability, axis=2)]

    # Check 2: gradient of actual moist static energy
    mask_mse1_greater_mss2 = mse1[:, :, None] > mss2
    ktop2 = jnp.full((ix, il), kx, dtype=int)
    ktop2 = k_indices[jnp.argmax(mask_mse1_greater_mss2, axis=2)]
    msthr = jnp.zeros((ix, il), dtype=float)
    msthr = mss2[jnp.arange(ix)[:, None], jnp.arange(il), jnp.argmax(mask_mse1_greater_mss2, axis=2)]

    # Check 3: RH > RH_c at both k=kx and k=kx-1
    qthr0 = rhbl * qsat[:, :, kx-1]
    qthr1 = rhbl * qsat[:, :, kx-2]
    lqthr = (qa[:, :, kx-1] > qthr0) & (qa[:, :, kx-2] > qthr1)

    # Applying masks to itop and qdif
    mask_ktop1_less_kx = ktop1 < kx
    mask_ktop2_less_kx = ktop2 < kx

    combined_mask1 = mask_ktop1_less_kx & mask_ktop2_less_kx
    itop = jnp.where(combined_mask1, ktop1, itop)
    qdif = jnp.where(combined_mask1, jnp.maximum(qa[:, :, kx-1] - qthr0, (mse0 - msthr) * rlhc), qdif)

    combined_mask2 = mask_ktop1_less_kx & lqthr & ~combined_mask1
    itop = jnp.where(combined_mask2, ktop1, itop)
    qdif = jnp.where(combined_mask2, qa[:, :, kx-1] - qthr0, qdif)

    return itop, qdif

# Example usage:
# Define constants and inputs (these would need to be provided or calculated as appropriate)
#alhc = 1.0  # Example value, replace with actual
rhbl = 1.0  # Example value, replace with actual
psmin = 1.0  # Example value, replace with actual

# Example dimensions
ix, il, kx = 10, 10, 80

# Random key for reproducibility
key = random.PRNGKey(0)

# Example value arrays, replace with actual values
psa = random.uniform(key, (ix, il))
se = random.uniform(key, (ix, il, kx))
qa = random.uniform(key, (ix, il, kx))
qsat = random.uniform(key, (ix, il, kx))
wvi = random.uniform(key, (kx, 3))

itop, qdif = diagnose_convection(psa, se, qa, qsat, psmin, rhbl, wvi, alhc)
print("itop:\n", itop)
print("qdif:\n", qdif)


itop:
 [[ 5  4  3  3  3  6  3  3  3  3]
 [63  4  6  3  7  3  3  3  3  3]
 [41 10  4  4 20  3  3  3  3  6]
 [ 5 62  3  3  3  4  3  3  3 33]
 [ 3  3  3  3  6 10  3  7  3  6]
 [ 3  4  4 22 33  4  3  3 29  3]
 [ 3  3  3  4 10  3 13 27  4  3]
 [20  5  3 21 10  3  3  3 11  3]
 [ 4 22  6 11  7  3  3  3  3 38]
 [ 3  3  6  3 10  3 72  6  3  3]]
qdif:
 [[0.2379781  0.31681213 0.20483765 0.37375352 0.35226506 0.
  0.68903154 0.61212    0.154889   0.4957041 ]
 [0.01243855 0.49612427 0.08989421 0.7335642  0.11288419 0.
  0.66539276 0.67531276 0.41168654 0.15678015]
 [0.01546369 0.00209293 0.40730274 0.02478252 0.11370323 0.
  0.20390245 0.04800238 0.40182525 0.0283331 ]
 [0.09192085 0.00690194 0.16491126 0.6718196  0.3445324  0.04116547
  0.5424763  0.46149874 0.25512344 0.09200244]
 [0.7380173  0.5086936  0.41437596 0.18492189 0.2850928  0.24283719
  0.4429484  0.05564942 0.02827912 0.05362209]
 [0.00589981 0.05816976 0.12481099 0.01371972 0.05169467 0.1002105
  0.07322574 0.44911122 0.03536947 0.

In [41]:
import jax.numpy as jnp

from physical_constants import alhc, wvi
from geometry import dhs, fsg

def get_convection_tendencies(psa, se, qa, qsat, itop, cbmf, precnv, dfse, dfqa):
    ix, il, kx = se.shape

    # Entrainment profile (up to sigma = 0.5)
    entr = jnp.maximum(0.0, fsg[1:kx-1] - 0.5)**2.0
    sentr = jnp.sum(entr)
    entr *= entmax / sentr

    # Diagnose convection
    itop, qdif = diagnose_convection(psa, se, qa, qsat, psmin, rhbl, wvi, alhc)

    mask = itop < kx
    qmax = jnp.maximum(1.01 * qa[:, :, -1], qsat[:, :, -1])
    sb = se[:, :, -2] + wvi[-2, 1] * (se[:, :, -1] - se[:, :, -2])
    qb = jnp.minimum(qa[:, :, -2] + wvi[-2, 1] * (qa[:, :, -1] - qa[:, :, -2]), qa[:, :, -1])

    fqmax = 5.0
    fm0 = p0 * dhs[-1] / (grav * trcnv * 3600.0)
    rdps = 2.0 / (1.0 - psmin)

    fpsa = psa * jnp.minimum(1.0, (psa - psmin) * rdps)
    fmass = fm0 * fpsa * jnp.minimum(fqmax, qdif / (qmax - qb))
    cbmf = jnp.where(mask, fmass, cbmf)

    fus = fmass * se[:, :, -1]
    fuq = fmass * qmax
    fds = fmass * sb
    fdq = fmass * qb

    dfse = dfse.at[:, :, -1].set(fds - fus)
    dfqa = dfqa.at[:, :, -1].set(fdq - fuq)

    # Create an array of k values to use for broadcasting
    k_vals = jnp.arange(kx-2, 0, -1)

    # Initialize fmass, fus, and fuq arrays for broadcasting
    fmass_broadcast = jnp.tile(fmass[:, :, jnp.newaxis], (1, 1, len(k_vals)))
    fus_broadcast = jnp.tile(fus[:, :, jnp.newaxis], (1, 1, len(k_vals)))
    fuq_broadcast = jnp.tile(fuq[:, :, jnp.newaxis], (1, 1, len(k_vals)))

    # Calculate sb and qb for each layer in the loop using broadcasting
    sb_vals = se[:, :, k_vals-1] + wvi[k_vals-1, 1] * (se[:, :, k_vals] - se[:, :, k_vals-1])
    qb_vals = qa[:, :, k_vals-1] + wvi[k_vals-1, 1] * (qa[:, :, k_vals] - qa[:, :, k_vals-1])

    enmass = entr[k_vals-1] * psa[:, :, jnp.newaxis] * cbmf[:, :, jnp.newaxis]

    fmass_broadcast += enmass
    fus_broadcast += enmass * se[:, :, k_vals]
    fuq_broadcast += enmass * qa[:, :, k_vals]

    fds_vals = fmass_broadcast * sb_vals
    fdq_vals = fmass_broadcast * qb_vals

    dfse = dfse.at[:, :, k_vals].set(fus_broadcast - fds_vals)
    dfqa = dfqa.at[:, :, k_vals].set(fuq_broadcast - fdq_vals)

    delq_vals = rhil * qsat[:, :, k_vals] - qa[:, :, k_vals]
    fsq_vals = jnp.where(delq_vals > 0, smf * cbmf[:, :, jnp.newaxis] * delq_vals, 0.0)

    dfqa = dfqa.at[:, :, k_vals].add(fsq_vals)
    dfqa = dfqa.at[:, :, -1].add(-jnp.sum(fsq_vals, axis=-1))

    # Top layer (condensation and detrainment)
    k = itop
    qsatb = qsat[:, :, k] + wvi[k, 1] *(qsat[:, :, k+1]-qsat[:, :, k])
    precnv = jnp.where(mask, jnp.maximum(fuq - fmass * qsatb, 0.0), precnv)

    dfse = fus - fds + alhc * precnv
    dfqa = fuq - fdq - precnv

    return dfse, dfqa, cbmf, precnv

# Example data
ix, il, kx = 4, 4, 10
psa = random.uniform(key, (ix, il))
se = random.uniform(key, (ix, il, kx))
qa = random.uniform(key, (ix, il, kx))
qsat = random.uniform(key, (ix, il, kx))
wvi = random.uniform(key, (kx, 3))
itop = jnp.ones((ix, il), dtype=int) * (kx - 2)
cbmf = jnp.ones((ix, il))
precnv = jnp.ones((ix, il))
dfse = jnp.ones((ix, il, kx))
dfqa = jnp.ones((ix, il, kx))
p0 = 1000.0
alhc = 2.5e6
wvi = jnp.ones((kx, 2))
grav = 9.81
fsg = jnp.ones(kx)
dhs = jnp.ones(kx + 1)
trcnv = 1.0
psmin = 0.1
entmax = 0.1
rhil = 0.8
smf = 0.1

# Run the function with example data
dfse, dfqa, cbmf, precnv = get_convection_tendencies(psa, se, qa, qsat, itop, cbmf, precnv, dfse, dfqa)

print("dfse:", dfse)
print("dfqa:", dfqa)
print("cbmf:", cbmf)
print("precnv:", precnv)

dfse: [[[[0.00000000e+00 2.03904953e+05 1.16675406e+05 1.92730094e+05]
   [1.59761592e+04 0.00000000e+00 0.00000000e+00 0.00000000e+00]
   [0.00000000e+00 8.90240039e+03 0.00000000e+00 1.72552754e+04]
   [2.24424536e+03 2.71414398e+02 0.00000000e+00 0.00000000e+00]]

  [[0.00000000e+00 6.76733828e+04 3.02986250e+04 7.09461016e+04]
   [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
   [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
   [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]

  [[0.00000000e+00 2.98579625e+05 1.76703312e+05 2.77364375e+05]
   [7.74353359e+04 0.00000000e+00 0.00000000e+00 0.00000000e+00]
   [0.00000000e+00 2.47236641e+04 0.00000000e+00 0.00000000e+00]
   [2.40779219e+04 1.13734204e+03 0.00000000e+00 0.00000000e+00]]

  [[0.00000000e+00 3.74265195e+04 1.11207734e+04 4.39069609e+04]
   [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
   [0.00000000e+00 0.00000000e+00 0.00000000e+00 3.53373047e+04]
   [0.0000000

In [37]:
import convection

key = random.PRNGKey(0)
ix, il, kx = 4, 4, 10
psa = random.uniform(key, (ix, il))
se = random.uniform(key, (ix, il, kx))
qa = random.uniform(key, (ix, il, kx))
qsat = random.uniform(key, (ix, il, kx))
itop = jnp.ones((ix, il), dtype=int) * (kx - 2)

dfse = jnp.zeros_like(dfse)
dfqa = jnp.zeros_like(dfqa)
cbmf = jnp.zeros_like(cbmf)
precnv = jnp.zeros_like(precnv)
        
print(dfse, dfqa, cbmf, precnv)

dfse, dfqa, cbmf, precnv = convection.get_convection_tendencies(psa, se, qa, qsat, itop, cbmf, precnv, dfse, dfqa)


[[[[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]]


 [[[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]]


 [[[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]]


 [[[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]
   [0. 0. 0. 0.]]

  [[0. 0. 0. 0

ValueError: Incompatible shapes for broadcasting: (4, 4, 8) and requested shape (4, 4, 8, 4)