# Testing AutoDiff for functions in vertical diffusion

Here, I test how JAX's VJP (vector-Jacobian product) works with functions in the vertical diffusion module, starting with `get_vertical_diffusion_tend`.

In [1]:
import jax
import jax.numpy as jnp
from jax import grad
from jax import vjp
from jcm.model import initialize_modules

In [2]:
# Grid/time parameters for this test case and the grid shapes derived from them
ix = 4          # number of longitudes
iy = 3          # number of latitudes per hemisphere
il = iy * 2     # number of latitudes over the full sphere
kx = 5          # number of vertical levels
tyear = 0.25    # example time of year (spring equinox)

xy = (ix, il)       # horizontal grid shape
xyz = (*xy, kx)     # horizontal grid shape extended by the vertical levels
initialize_modules(kx=kx, il=il)

In [3]:
from jcm.physics import PhysicsTendency, PhysicsState
from jcm.physics_data import LWRadiationData, SWRadiationData, CondensationData, ConvectionData, HumidityData, SurfaceFluxData, DateData, PhysicsData
from jcm.vertical_diffusion import get_vertical_diffusion_tend

In [4]:
# Build zero-initialized inputs for get_vertical_diffusion_tend on the test grid.
date_data = DateData(tyear=tyear)  # date information for this test case
physics_data = PhysicsData.zeros(xy, kx, date=date_data)  # physics data argument
state = PhysicsState.zeros(xyz)  # physics state argument

In [5]:
# First call get_vertical_diffusion_tend once on the zero-initialized inputs to check
# that it runs; the returned value (beginning with a PhysicsTendency) is the cell output.
get_vertical_diffusion_tend(state, physics_data)

(PhysicsTendency(u_wind=Array([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],
 
        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],
 
        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],
 
        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], dtype=float32), v_wind=Array([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],
 
        [[0

In [6]:

# Next, evaluate get_vertical_diffusion_tend together with its pullback: `primals`
# holds the function's outputs and `f_vjp` maps output cotangents back to
# cotangents for (state, physics_data).
primals, f_vjp = jax.vjp(get_vertical_diffusion_tend, state, physics_data)

# Testing gradient function

In [7]:
def grad_get_vertical_diffusion_tend(f, state, physics_data):
    '''
    Compute vector-Jacobian products of ``f`` with respect to ``state`` and ``physics_data``.

    ``f`` is called as ``f(state, physics_data)`` and must return a pair
    ``(tendency, data)``. ``tendency`` must expose ``u_wind``, ``v_wind``,
    ``temperature`` and ``specific_humidity`` and a ``copy`` method that takes
    replacements for those four fields positionally. ``data`` must have a
    ``copy()`` method.

    The pullback is seeded with ones for the four tendency fields and with a
    copy of the primal ``data`` output itself. Because that second seed is not
    ones, the result is a VJP rather than the plain gradient of the summed
    outputs unless the data output does not depend on the inputs.

    Returns:
        ``(d_state, d_physics_data)``: cotangents with the same pytree
        structure as ``state`` and ``physics_data``, respectively.
    '''
    (tend_out, data_out), f_vjp = vjp(f, state, physics_data)

    # Seed the four tendency fields with ones. NOTE(review): any other fields of
    # the tendency object keep whatever ``copy`` leaves in them - confirm.
    tend_cotangent = tend_out.copy(
        jnp.ones_like(tend_out.u_wind),
        jnp.ones_like(tend_out.v_wind),
        jnp.ones_like(tend_out.temperature),
        jnp.ones_like(tend_out.specific_humidity),
    )
    # TODO: seed with ones as well once a ones-like constructor exists for the
    # data object; for now the primal data output is reused as its own seed.
    data_cotangent = data_out.copy()

    # The pullback returns one cotangent per primal input of ``f``.
    d_state, d_physics_data = f_vjp((tend_cotangent, data_cotangent))
    return d_state, d_physics_data

In [8]:
# Run the helper. Despite their names, df_dtends and df_ddatas hold the pullback's
# cotangents with respect to `state` and `physics_data`, respectively.
df_dtends, df_ddatas = grad_get_vertical_diffusion_tend(get_vertical_diffusion_tend, state, physics_data)