In [2]:
import jax.numpy as jnp
import tree_math

# Grid sizes shared by the cells below: (ix, il) is the grid shape passed as
# ``nodal_shape`` and kx is the level count passed as ``node_levels``.
ix, il, kx = 5, 3, 2

# All-ones dummy state with shape (ix, il, kx).
state = jnp.ones(shape=(ix, il, kx))

@tree_math.struct
class ConvectionData:
    """Container for convection-related fields (see the per-field comments).

    Any field not supplied to the constructor is zero-filled:
    - ``psa`` gets shape ``nodal_shape``.
    - ``iptop`` gets shape ``nodal_shape``, with an integer dtype.
    - ``se`` gets shape ``nodal_shape + (node_levels,)``.

    NOTE(review): ``tree_math.struct`` registers this class as a JAX pytree.
    It may rebuild instances by calling the constructor with the field values
    ``(psa, se, iptop)``. If so, the custom ``__init__`` below (whose first
    parameters are ``nodal_shape`` and ``node_levels``) would mis-assign them
    under ``jax.jit``, ``jax.tree_util.tree_map`` or tree-math arithmetic.
    Confirm before passing instances through JAX transforms. A field-order
    constructor plus a ``zeros``-style factory classmethod would avoid this.
    """
    psa: jnp.ndarray # normalized surface pressure 
    se: jnp.ndarray # dry static energy
    iptop: jnp.ndarray # Top of convection (layer index)

    def __init__(self, nodal_shape, node_levels, psa=None, se=None, iptop=None) -> None:
        """Build an instance, zero-filling any field that is not provided.

        Args:
            nodal_shape: Grid shape, as a tuple/list of ints or a single int.
            node_levels: Number of levels; the length of the trailing axis of
                the default ``se``.
            psa: Optional array stored as-is; defaults to zeros of
                ``nodal_shape``.
            se: Optional array stored as-is; defaults to zeros of
                ``nodal_shape + (node_levels,)``.
            iptop: Optional array stored as-is; defaults to integer zeros of
                ``nodal_shape``.
        """
        # Normalise the shape to a tuple so the level axis can be appended.
        # A list or bare int would otherwise fail on tuple concatenation,
        # even though jnp.zeros accepts them for psa/iptop.
        try:
            grid_shape = tuple(nodal_shape)
        except TypeError:  # scalar size such as 5
            grid_shape = (nodal_shape,)

        self.psa = psa if psa is not None else jnp.zeros(grid_shape)
        self.se = se if se is not None else jnp.zeros(grid_shape + (node_levels,))
        self.iptop = iptop if iptop is not None else jnp.zeros(grid_shape, dtype=int)

    def copy(self, psa=None, se=None, iptop=None):
        """Return a new instance with the given fields replaced.

        Fields passed as ``None`` are taken from ``self``. The arrays are
        passed by reference rather than duplicated.
        """
        # Every field is supplied explicitly. The shape arguments only
        # satisfy the constructor's signature and never allocate zeros.
        return type(self)(
            self.psa.shape,
            self.se.shape[-1],
            psa=self.psa if psa is None else psa,
            se=self.se if se is None else se,
            iptop=self.iptop if iptop is None else iptop,
        )
        

In [8]:
from jcm.physics_data import SWRadiationData

xy = (ix, il)  # grid shape passed as nodal_shape

# Use kx instead of a literal 2 so these fields follow the same level count
# as ConvectionData. Assumes SWRadiationData's second argument is the level
# count (dfabs came out with a trailing axis of 2 == kx). TODO: confirm.
sw_rad = SWRadiationData(xy, kx)
print(sw_rad.dfabs)

[[[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]]]


In [5]:
xy = (ix, il)

# Default construction: every field is zero-filled from the grid shape.
convection_data = ConvectionData(nodal_shape=xy, node_levels=kx)

print(convection_data.psa.shape)

(5, 3)


In [19]:
# Build an instance with a non-default psa, then check that copy() replaces
# only the requested field and keeps the others from the source instance.
convection_data = ConvectionData(state.shape[0:2], state.shape[-1], psa=jnp.ones((ix, il)) * 2)
print(convection_data.se.shape)

convection_data2 = convection_data.copy(se=jnp.ones((ix, il, kx)))
print(convection_data2.se)
print(convection_data.psa)

# Pin the copy() contract instead of relying on eyeballing the printed arrays.
assert convection_data2.se.shape == (ix, il, kx)
assert bool((convection_data2.se == 1).all())
assert convection_data2.psa is convection_data.psa
assert convection_data2.iptop is convection_data.iptop
assert bool((convection_data.se == 0).all())  # source instance left untouched

(5, 3, 2)
[[[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]]
[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]
