In [2]:
import jax.numpy as jnp
import tree_math

# Toy grid dimensions for experimenting with the physics data containers.
# The first two axes of `state` are later used as the horizontal (nodal) shape
# and its last axis as the number of levels.
ix, il, kx = 5, 3, 2

# Placeholder all-ones field of shape (ix, il, kx).
state = jnp.ones((ix, il, kx))

@tree_math.struct
class ConvectionData:
    """Container for the fields used by the convection scheme.

    Fields (default shapes, i.e. when the value is not supplied):
        psa: normalized surface pressure, shape ``nodal_shape``.
        se: dry static energy, shape ``nodal_shape + (node_levels,)``.
        iptop: top of convection (layer index), integer array of shape ``nodal_shape``.

    NOTE(review): ``tree_math.struct`` presumably turns this class into a
    dataclass / JAX pytree, but ``__init__`` here takes
    ``(nodal_shape, node_levels, ...)`` rather than the three field values.
    Confirm that pytree unflattening (e.g. under ``jax.jit`` / ``jax.vmap``)
    rebuilds instances correctly.
    """

    psa: jnp.ndarray  # normalized surface pressure
    se: jnp.ndarray  # dry static energy
    iptop: jnp.ndarray  # top of convection (layer index)

    def __init__(self, nodal_shape, node_levels, psa=None, se=None, iptop=None) -> None:
        """Set the fields, zero-filling any that are not supplied.

        Args:
            nodal_shape: tuple giving the horizontal grid shape.
            node_levels: size of the trailing (level) axis of the default ``se``.
            psa, se, iptop: optional initial values; used as-is (no copy and
                no shape check) when given.
        """
        self.psa = jnp.zeros(nodal_shape) if psa is None else psa
        self.se = jnp.zeros(nodal_shape + (node_levels,)) if se is None else se
        self.iptop = jnp.zeros(nodal_shape, dtype=int) if iptop is None else iptop

    def copy(self, psa=None, se=None, iptop=None):
        """Return a new ConvectionData with selected fields replaced.

        Any argument left as None keeps the corresponding array from ``self``;
        that array object is reused, not duplicated.
        """
        def keep_or_replace(new, current):
            return current if new is None else new

        return ConvectionData(
            self.psa.shape,
            self.se.shape[-1],
            psa=keep_or_replace(psa, self.psa),
            se=keep_or_replace(se, self.se),
            iptop=keep_or_replace(iptop, self.iptop),
        )
        

In [32]:
import jax
import jax.numpy as jnp
from jcm.geometry import fsg

# TODO: move to humidity.py, and document that under jax.vmap over levels
# `sig` arrives as a scalar per level while `ta` / `ps` arrive as slices.
def get_qsat(ta, ps, sig):
    """
    Computes saturation specific humidity.

    Uses an exponential fit for the saturation vapour pressure, with one set of
    coefficients for ta >= t0 and another for ta < t0 (presumably water vs.
    ice -- confirm), then converts it to specific humidity in g/kg.

    Args:
        ta: Absolute temperature [K] (array).
        ps: Normalized pressure (p/1000 hPa), broadcastable against ``ta``.
            When ``sig <= 0`` only its first element (row-major order, i.e.
            ``ps[0, 0]`` for a 2-D field) is used, as a constant pressure.
        sig: Sigma level (scalar, or broadcastable against ``ta``).

    Returns:
        qsat: Saturation specific humidity (g/kg), with the broadcast shape of
        the inputs.
    """
    e0 = 6.108e-3
    c1 = 17.269
    c2 = 21.875
    t0 = 273.16
    t1 = 35.86
    t2 = 7.66

    # 1. Saturation vapour pressure from T (degK), in the same normalized
    #    units as ps.
    esat = jnp.where(
        ta >= t0,
        e0 * jnp.exp(c1 * (ta - t0) / (ta - t1)),
        e0 * jnp.exp(c2 * (ta - t0) / (ta - t2)),
    )

    # 2. Convert to Qsat (g/kg). If sig > 0, P = Ps * sigma; otherwise
    #    P = Ps(1) = const (first element of ps). ravel() makes this work for
    #    scalar, 1-D and 2-D ps alike.
    ps_ref = jnp.ravel(jnp.asarray(ps))[0]
    return jnp.where(
        sig <= 0.0,
        622.0 * esat / (ps_ref - 0.378 * esat),
        622.0 * esat / (sig * ps - 0.378 * esat),
    )
    

# Dummy 96 x 48 x 8 fields used to exercise get_qsat.
# NOTE(review): psa is presumably tiled across the last axis only so that vmap
# can map over axis 2 alongside temp; confirm this is intended.
temp = 273 * jnp.ones((96, 48, 8))
psa = 0.5 * jnp.ones((96, 48, 8))
qg = 2 * jnp.ones((96, 48, 8))


def get_qsat_lambda(ta, ps, fg):
    """Positional wrapper around get_qsat used as the vmap target."""
    return get_qsat(ta, ps, fg)


# Map over axis 2 of ta and ps and over axis 0 (the only axis) of fsg, so each
# call sees one 2-D slice plus a scalar sigma; results are stacked on output axis 2.
map_qsat = jax.vmap(get_qsat_lambda, in_axes=(2, 2, 0), out_axes=2)
# TODO: check the result has the shape of temp (96 x 48 x 8).
qsat = map_qsat(temp, psa, fsg)


In [8]:
from jcm.physics_data import SWRadiationData

xy = ix, il  # horizontal grid shape

# The second argument is presumably the number of levels: the printed dfabs
# below is an all-zero array with a trailing axis of size 2.
sw_rad = SWRadiationData(xy, 2)
print(sw_rad.dfabs)

[[[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]]]


In [5]:
xy = ix, il  # horizontal grid shape

# All fields default to zeros when no initial values are passed.
convection_data = ConvectionData(nodal_shape=xy, node_levels=kx)
print(convection_data.psa.shape)

(5, 3)


In [19]:
# Size the container from `state`: leading two axes -> horizontal shape,
# last axis -> number of levels; psa is overridden with 2s.
convection_data = ConvectionData(
    nodal_shape=state.shape[:2],
    node_levels=state.shape[-1],
    psa=2 * jnp.ones((ix, il)),
)
print(convection_data.se.shape)

# copy() with a replacement se; psa and iptop are carried over from the original.
convection_data2 = convection_data.copy(se=jnp.ones((ix, il, kx)))
print(convection_data2.se)
print(convection_data.psa)

(5, 3, 2)
[[[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]]
[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]
