# Reproducing speedy gradient check bug 

In [None]:
from jax.test_util import check_vjp, check_jvp
import functools
import jax
import jax.numpy as jnp
from jcm.model import Model
from jcm.utils import convert_back, convert_to_float
from jcm.geometry import Geometry

from jcm.boundaries import BoundaryData
from jcm.physics.speedy.physics_data import SurfaceFluxData, HumidityData, ConvectionData, CondensationData, SWRadiationData, DateData, PhysicsData
from jcm.physics_interface import PhysicsState
from jcm.physics.speedy.shortwave_radiation import get_clouds, get_zonal_average_fields
from jcm.physics.speedy.params import Parameters
from jcm.geometry import Geometry

In [None]:
# Model run output nan error
# Create model that goes through one timestep
model = Model()
state = model._prepare_initial_modal_state()

state_floats = convert_to_float(state)

def f(state_f):
    state_f_mapped = jax.tree.map(jnp.asarray, state_f)
    _ = model.run(total_time=0) # to set up model fields
    predictions = model.run(initial_state=convert_back(state_f_mapped, state), save_interval=(1/48.), total_time=(1/48.))
    return convert_to_float(model._final_modal_state), convert_to_float(predictions)

# Calculate gradient
f_jvp = functools.partial(jax.jvp, f)
f_vjp = functools.partial(jax.vjp, f) 

check_vjp(f, f_vjp, args = (state,), 
                        atol=None, rtol=1, eps=0.00001)
check_jvp(f, f_jvp, args = (state,), 
                        atol=None, rtol=1, eps=0.001)

2025-10-07 13:41:32.694849: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_multistep] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2025-10-07 13:42:52.713550: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 3m20.024583s

********************************
[Compiling module jit_multistep] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


AssertionError: 
Not equal to tolerance rtol=16632, atol=33.264
tangent
x and y nan location mismatch:
 x: array([[[ 1.764052e+00,  3.310613e-01,  9.517571e-01, ...,
          3.276176e-02,  3.493220e-01, -7.208256e-01],
        [-1.980796e+00, -2.283787e-01,  1.765298e-01, ...,...
 y: array([[[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],...

In [None]:
# Reproductin shortwave radiation nan errors in gradient check unit test
# Function setup
ix, il, kx = 96, 48, 8
parameters = Parameters.default()
geometry = Geometry.from_grid_shape((ix, il), kx)
boundaries = BoundaryData.zeros((ix, il))

qa = 0.5 * 1000. * jnp.array([0., 0.00035438, 0.00347954, 0.00472337, 0.00700214,0.01416442,0.01782708, 0.0216505])
qsat = 1000. * jnp.array([0., 0.00037303, 0.00366268, 0.00787228, 0.01167024, 0.01490992, 0.01876534, 0.02279])
rh = qa/qsat
geopotential = jnp.arange(7, -1, -1, dtype = float)
se = .1*geopotential

xy = (ix, il)
zxy = (kx, ix, il)
broadcast = lambda a: jnp.tile(a[:, jnp.newaxis, jnp.newaxis], (1,) + xy)
qa, qsat, rh, geopotential, se = broadcast(qa), broadcast(qsat), broadcast(rh), broadcast(geopotential), broadcast(se)

psa = jnp.ones(xy)
precnv = -1.0 * jnp.ones(xy)
precls = 4.0 * jnp.ones(xy)
iptop = 8 * jnp.ones(xy, dtype=int)
fmask = .7 * jnp.ones(xy)

surface_flux = SurfaceFluxData.zeros(xy)
humidity = HumidityData.zeros(xy, kx, rh=rh, qsat=qsat)
convection = ConvectionData.zeros(xy, kx, iptop=iptop, precnv=precnv, se=se)
condensation = CondensationData.zeros(xy, kx, precls=precls)
sw_data = SWRadiationData.zeros(xy, kx, compute_shortwave=True)

date_data = DateData.zeros()
date_data.tyear = 0.6

physics_data = PhysicsData.zeros(xy,kx,surface_flux=surface_flux, humidity=humidity, convection=convection, condensation=condensation, shortwave_rad=sw_data, date=date_data)
state = PhysicsState.zeros(zxy, specific_humidity=qa, geopotential=geopotential, normalized_surface_pressure=psa)
# boundaries = BoundaryData.zeros(xy, fmask=fmask)

In [16]:
# Reproducing nan error in check_jvp (gives nans for finite differencing approximation)
# Set float inputs
physics_data_floats = convert_to_float(physics_data)
state_floats = convert_to_float(state)
parameters_floats = convert_to_float(parameters)
boundaries_floats = convert_to_float(boundaries)
geometry_floats = convert_to_float(geometry)

def f(physics_data_f, state_f, parameters_f, boundaries_f,geometry_f):
    tend_out, data_out = get_clouds(physics_data=convert_back(physics_data_f, physics_data), 
                                state=convert_back(state_f, state), 
                                parameters=convert_back(parameters_f, parameters), 
                                boundaries=convert_back(boundaries_f, boundaries), 
                                geometry=convert_back(geometry_f, geometry)
                                )
    return convert_to_float(data_out)

# Calculate gradient
f_jvp = functools.partial(jax.jvp, f)
f_vjp = functools.partial(jax.vjp, f)  


check_jvp(f, f_jvp, args = (physics_data_floats, state_floats, parameters_floats, boundaries_floats, geometry_floats), 
                        atol=None, rtol=1, eps=0.000001)

AssertionError: 
Not equal to tolerance rtol=36864, atol=73.728
tangent
x and y nan location mismatch:
 x: array([[[-0.593189,  0.006917,  0.107968, ...,  1.646269,  0.343626,
         -0.079846],
        [-0.408734,  1.005633,  0.369222, ..., -1.093404,  0.340922,...
 y: array([[[      nan,       nan,       nan, ...,       nan,       nan,
               nan],
        [      nan,       nan,       nan, ...,       nan,       nan,...

In [17]:
# Reproducing nan error in check_vjp (gives nans for finite differencing approximation)
check_vjp(f, f_vjp, args = (physics_data_floats, state_floats, parameters_floats, boundaries_floats, geometry_floats), 
                                atol=None, rtol=1, eps=0.00001)

AssertionError: 
Not equal to tolerance rtol=1, atol=0.002
cotangent projection
x and y nan location mismatch:
 x: array(-554.49866, dtype=float32)
 y: array(nan, dtype=float32)

In [18]:
# Reproducing nan error in check_jvp (gives nans for jax gradient calculation)
_, physics_data = get_clouds(state, physics_data, parameters, boundaries, geometry)

# Set float inputs
physics_data_floats = convert_to_float(physics_data)

def f(physics_data_f, state_f, boundaries_f,geometry_f):
    data_out = get_zonal_average_fields(physics_data=convert_back(physics_data_f, physics_data), 
                                state=convert_back(state_f, state), 
                                boundaries=convert_back(boundaries_f, boundaries), 
                                geometry=convert_back(geometry_f, geometry)
                                )
    return convert_to_float(data_out)

# Calculate gradient
f_jvp = functools.partial(jax.jvp, f)
f_vjp = functools.partial(jax.vjp, f)  

check_jvp(f, f_jvp, args = (physics_data_floats, state_floats, boundaries_floats, geometry_floats), 
                        atol=None, rtol=1, eps=0.0001)

AssertionError: 
Not equal to tolerance rtol=4608, atol=9.216
tangent
x and y nan location mismatch:
 x: array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],...
 y: array([[    0.    ,     0.    ,     0.    , ..., -1091.919 , -2044.3726,
        -2345.581 ],
       [    0.    ,     0.    ,     0.    , ..., -1091.919 , -2044.3726,...