In [14]:
import jax.numpy as jnp
import tree_math

@tree_math.struct
class LWRadiationData:
    """Long-wave radiation fields bundled as a tree_math struct (JAX pytree).

    ``rlds``, ``ftop`` and ``slr`` are 2-D arrays sized by the first two
    dimensions of the full grid shape; ``dfabs`` uses the full grid shape
    (see ``zeros``).
    """

    rlds: jnp.ndarray  # Downward flux of long-wave radiation at the surface
    dfabs: jnp.ndarray  # Flux of long-wave radiation absorbed in each atmospheric layer
    ftop: jnp.ndarray  # NOTE(review): presumably outgoing LW flux at the top of the atmosphere — confirm
    slr: jnp.ndarray  # NOTE(review): presumably net upward LW flux at the surface — confirm

    @classmethod
    def zeros(cls, shape):
        """Create an instance whose fields are all zero arrays.

        Args:
            shape: Full grid shape. ``dfabs`` gets exactly this shape; the
                other fields get ``shape[0:2]``. Presumably the trailing axis
                indexes atmospheric layers — TODO confirm against callers.

        Returns:
            A new instance of ``cls``.
        """
        shape_2d = shape[0:2]
        return cls(
            rlds=jnp.zeros(shape_2d),
            dfabs=jnp.zeros(shape),
            ftop=jnp.zeros(shape_2d),
            slr=jnp.zeros(shape_2d),
        )

    # Deliberately an instance method (not @classmethod): it reads the field
    # values of `self`, which a classmethod would bind to the class instead.
    def copy(self, rlds=None, dfabs=None, ftop=None, slr=None):
        """Return a new instance with selected fields replaced.

        Any argument left as ``None`` keeps the value from ``self``.
        ``self`` itself is not modified.

        Args:
            rlds, dfabs, ftop, slr: Optional replacement arrays.

        Returns:
            A new instance of the same class as ``self``.
        """
        return type(self)(
            rlds=rlds if rlds is not None else self.rlds,
            dfabs=dfabs if dfabs is not None else self.dfabs,
            ftop=ftop if ftop is not None else self.ftop,
            slr=slr if slr is not None else self.slr,
        )

In [20]:
shape = (2, 4, 6)
# Keyword arguments so the demo does not depend on the field declaration order.
test = LWRadiationData(
    rlds=jnp.ones(shape[0:2]),
    dfabs=jnp.ones(shape),
    ftop=jnp.ones(shape[0:2]),
    slr=jnp.ones(shape[0:2]),
)
# Replace only `rlds`; every other field should carry over from `test`.
test_copy = test.copy(rlds=jnp.zeros(shape[0:2]))
assert jnp.array_equal(test_copy.rlds, jnp.zeros(shape[0:2]))
assert jnp.array_equal(test_copy.dfabs, test.dfabs)
assert jnp.array_equal(test_copy.ftop, test.ftop)
assert jnp.array_equal(test_copy.slr, test.slr)
# The source instance must be left untouched by `copy`.
assert jnp.array_equal(test.rlds, jnp.ones(shape[0:2]))
print(test_copy.slr)

[[1. 1. 1. 1.]
 [1. 1. 1. 1.]]
