# Testing AutoDiff for functions in shortwave radiation

Here, I test how JAX VJP works with functions in the shortwave radiation fluxes file

In [1]:
import jax
import jax.numpy as jnp
from jax import grad
from jax import vjp
from jcm.physics import PhysicsTendency, PhysicsState
from jcm.physics_data import LWRadiationData, SWRadiationData, CondensationData, ConvectionData, HumidityData, SurfaceFluxData, DateData, PhysicsData
from jcm.vertical_diffusion import get_vertical_diffusion_tend

In [3]:
# Test to take the gradient of get_vertical_diffusion_tend

# Defining input parameters and parameter dependencies 
ix = 4         # Number of longitudes
iy = 3         # Number of latitudes in hemisphere
il = 2 * iy     # Number of latitudes in full sphere
kx = 1          # Number of vertical levels
tyear = 0.25  # Example time of the year (spring equinox)
xy = (ix, il)
xyz = (ix, il, kx)
date_data = DateData(tyear=tyear)  # Define date data object
physics_data = PhysicsData(xy,kx,date=date_data)  # Create PhysicsData object (parameter)
state = PhysicsState(jnp.zeros(xyz), jnp.zeros(xyz), jnp.zeros(xyz), jnp.zeros(xyz), jnp.zeros(xyz), jnp.zeros(xy)) # Create PhysicsState object (parameter)

In [4]:

# First call get_vertical_diffusion_tend to see if the function is working properly
get_vertical_diffusion_tend(state, physics_data)

(PhysicsTendency(u_wind=Array([[[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.]],
 
        [[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.]],
 
        [[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.]],
 
        [[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.]]], dtype=float32), v_wind=Array([[[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.]],
 
        [[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.]],
 
        [[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.]],
 
        [[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.]]], dtype=float32), temperature=Array([[[-0.],
         [-0.],
         [-0.],
         [-0.],
         [-0.],
         [-0.]],
 
        [[-0.],
         [-0.],
         [-0.],
         [-0.],
         [-0.],
       

In [5]:

# Next, call vjp to start the process of getting the gradients of get_vertical_diffusion_tend
primals, f_vjp = vjp(get_vertical_diffusion_tend, state, physics_data)

# Should produce a "Shapes must be 1D sequences of concrete values of integer type" error
# which is coming from the line: self.ftop = ftop if ftop is not None else jnp.zeros((nodal_shape)) 
# and we're not sure why it's coming from that specific line

TypeError: Shapes must be 1D sequences of concrete values of integer type, got Traced<ConcreteArray([[0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[4,6])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[4,6]), None)
    recipe = LambdaBinding().

# More notes: 
- In the get_vertical_diffusion_tend function, the only calls to the PhysicsData parameter are to get certain values from the object
- We think that when JAX takes the gradient, it is transforming the nodal_shape parameter in PhysicsData to a Traced Array
- The error is coming from the PhysicsData constructor (not the copy)