# Testing AutoDiff for functions in shortwave radiation

Here, I test how JAX VJP works with functions in the shortwave radiation fluxes file

In [1]:
import jax
import jax.numpy as jnp
from jax import grad
from jax import vjp
from jcm.physics import PhysicsTendency, PhysicsState
from jcm.physics_data import SWRadiationData, CondensationData, ConvectionData, HumidityData, SurfaceFluxData, DateData, PhysicsData
from jcm.shortwave_radiation import clouds, get_zonal_average_fields, solar

In [2]:
#Define gradient function

def grad_get_zonal_average_fields(f, state, physics_state): 
    '''
    Calculate the gradient of get_zonal_average_fields with respect to tyear
    '''
    primals, f_vjp = vjp(f, state, physics_state)
    tyear = 0.25
    ix, il, kx = state.temperature.shape
    date_data = DateData(tyear=tyear)
    datas = PhysicsData(xy,kx,date=date_data)
    xy = (ix, il)
    xyz = (ix, il, kx)
    tends = PhysicsTendency(jnp.ones_like(state.u_wind),jnp.ones_like(state.v_wind),jnp.ones_like(state.temperature),jnp.ones_like(state.temperature))
    input = (tends, datas)
    df_dtends, df_ddatas = f_vjp(input)
    return df_dtends, df_ddatas


In [5]:
#Testing gradient function of get_zonal_average_fields
ix = 96         # Number of longitudes
iy = 24         # Number of latitudes in hemisphere
il = 2 * iy     # Number of latitudes in full sphere
kx = 8          # Number of vertical levels
tyear = 0.25  # Example time of the year (spring equinox)
xy = (ix, il)
xyz = (ix, il, kx)
date_data = DateData(tyear=tyear)
physics_data = PhysicsData(xy,kx,date=date_data)
state = PhysicsState(jnp.zeros(xyz), jnp.zeros(xyz), jnp.zeros(xyz), jnp.zeros(xyz), jnp.zeros(xyz), jnp.zeros(xy))

df_dtends, df_ddatas = grad_get_zonal_average_fields(get_zonal_average_fields, tyear, ix, il)
print(df_dtends)
print(df_ddatas)

#Not great, outputting nans

(Array(nan, dtype=float32),)
