# Testing AutoDiff for functions in shortwave radiation

Here, I test how JAX's vector-Jacobian product (`jax.vjp`) behaves when applied to the functions in the shortwave radiation module (`jcm.shortwave_radiation`).

In [1]:
import jax
import jax.numpy as jnp
from jax import grad
from jax import vjp
from jcm.shortwave_radiation import clouds, get_zonal_average_fields, solar

In [2]:
#Define gradient function

def grad_get_zonal_average_fields(f, tyear, ix, il):
    '''
    Calculate the gradient of f (e.g. get_zonal_average_fields) with respect to tyear

    The gradient is obtained as a vector-Jacobian product (VJP) in which the
    cotangent is all ones for every output of f, so the result is the sum over
    every element of every output of d(output)/d(tyear).

    Parameters:
    f : function
        Function to take the gradient of. It is called as f(tyear) and may
        return a single array or any tuple/pytree of arrays.
    tyear : float
        Time as fraction of year (0-1, 0 = 1 Jan)
    ix : integer
        row dimension of the outputs of f. Kept for backward compatibility;
        the cotangent shapes are now taken from the outputs of f directly.
    il : integer
        column dimension of the outputs of f. Kept for backward compatibility;
        the cotangent shapes are now taken from the outputs of f directly.

    Returns:
    tuple
        One-element tuple holding the VJP of f with respect to tyear.
    '''
    outputs, pullback = vjp(f, tyear)
    # Build the all-ones cotangent from the actual outputs so that its
    # structure, shapes and dtypes always match what the pullback expects,
    # instead of assuming exactly five (ix, il) arrays.
    cotangents = jax.tree_util.tree_map(jnp.ones_like, outputs)
    return pullback(cotangents)


In [5]:
# Test the gradient helper on get_zonal_average_fields.
# ix and il are the grid dimensions passed to grad_get_zonal_average_fields.
ix = 96         # Number of longitudes
iy = 24         # Number of latitudes in hemisphere
il = 2 * iy     # Number of latitudes in full sphere
tyear = 0.25  # Example time of year: a quarter of the way through the year (0 = 1 Jan), i.e. around 1 April, shortly after the March equinox

df_dtyear = grad_get_zonal_average_fields(get_zonal_average_fields, tyear, ix, il)
print(df_dtyear)

# FIXME: the gradient printed by this cell is NaN; the cause has not been tracked down here.

(Array(nan, dtype=float32),)
