Deal with NaN errors in jax.experimental.ode.optimal_step_size() #14612
Labels: bug, contributions welcome
Description
The following is related to my discussion about diffrax with @patrick-kidger here: patrick-kidger/diffrax#223
When solving certain ODE systems, the error estimate used for adaptive stepping can become NaN (e.g., if a trial timestep is too large). Currently this causes jax.experimental.ode.odeint to fail. If we use jax.debug.print to print the state variables, derivatives, and other intermediate quantities inside the integrator function while odeint is running, we see NaNs at every timestep except perhaps the first few: once an initial NaN appears, odeint cannot recover.
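To illustrate why one NaN is enough to poison the rest of the run, here is a toy sketch of a step-size controller. The helper `next_step_size`, its default constants, and the hand-picked error ratios are illustrative assumptions and are not copied from JAX's source; the sketch relies only on the fact that `jnp.minimum` and `jnp.maximum` propagate NaN.

```python
import math
import jax.numpy as jnp

def next_step_size(dt, error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5.0):
    """Hypothetical controller: scale dt by a factor clamped to [dfactor, ifactor]."""
    candidate = error_ratio ** (-1.0 / order) * safety
    # jnp.maximum / jnp.minimum return NaN if either argument is NaN.
    factor = jnp.minimum(ifactor, jnp.maximum(candidate, dfactor))
    return dt * factor

dt = jnp.asarray(0.1)
# Hand-picked error ratios: a good step, a NaN estimate, then a good step again.
error_ratios = [jnp.asarray(0.5), jnp.asarray(jnp.nan), jnp.asarray(0.5)]

history = []
for ratio in error_ratios:
    accepted = bool(ratio <= 1.0)   # any comparison with NaN is False, so the step is rejected
    dt = next_step_size(dt, ratio)  # but the proposed step size is now NaN as well
    history.append((accepted, float(dt)))

for accepted, step in history:
    print(f"accepted={accepted}, next dt={step}")
```

In this sketch the NaN step is rejected, but the step size proposed for the retry is itself NaN, so it stays NaN even when a later error ratio is finite. In a real integrator the retry would use the NaN step and produce another NaN error estimate, which matches the behavior described above.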
The fix is simple -- change this line in jax.experimental.ode.optimal_step_size() from
to something like
I don't have a minimal working example because the code in which I ran into this problem is a bit complicated (it involves a non-autonomous ODE system). However, you can plug in mean_error_ratio = jnp.nan in the above two examples to see what I mean.
I wrote my own adaptive RK23 (Bogacki-Shampine) integrator in JAX to diagnose the problem. My simple integrator had the same problem as JAX's odeint until I switched to jnp.nanmin and jnp.nanmax when computing the multiplicative factor for the new optimal step size (following Eq. 4.13 of Hairer, Section II.4).
With that change, both JAX's odeint and my manual JAX-based RK23 integrator give the expected solution. I checked them against my original pure Python implementation of the ODE system, solved with scipy.integrate.solve_ivp and with manual non-adaptive Euler integration using extremely small timesteps.
For what it's worth, scipy.integrate.solve_ivp did not run into this issue because it uses Python's built-in min and max functions, which return their first argument when every comparison against a NaN is False, so a NaN is effectively ignored as long as it is not the first argument (see lines 152-168 of scipy.integrate._ivp.rk).
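A minimal sketch of the difference, assuming the factor is a clamp of the form min(ifactor, max(candidate, dfactor)). The helper name `nan_safe_factor` and the default constants are illustrative assumptions, not the actual line in optimal_step_size(), and this is not necessarily the exact change proposed above.

```python
import jax.numpy as jnp

def nan_safe_factor(mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5.0):
    """Hypothetical helper: clamp the step-size factor to [dfactor, ifactor], ignoring NaN."""
    candidate = jnp.asarray(mean_error_ratio) ** (-1.0 / order) * safety
    # nanmax drops a NaN candidate, so the shrink bound dfactor is what remains.
    lower_bounded = jnp.nanmax(jnp.array([candidate, dfactor]))
    return jnp.nanmin(jnp.array([ifactor, lower_bounded]))

nan = float("nan")

# Element-wise jnp.minimum / jnp.maximum propagate NaN.
propagating = jnp.minimum(10.0, jnp.maximum(jnp.asarray(nan), 0.2))

# The nan-aware reductions fall back to dfactor, so the next step shrinks.
safe = nan_safe_factor(nan)

# Python's built-ins keep the first argument when comparisons with NaN are False.
builtin_nan_second = min(10.0, max(0.2, nan))  # NaN is the second argument to max
builtin_nan_first = max(nan, 0.2)              # NaN is the first argument to max

print(propagating, safe, builtin_nan_second, builtin_nan_first)
```

With a NaN error ratio, the nan-aware version returns the smallest allowed factor, so the controller retries with a smaller step instead of carrying the NaN forward. The built-in version gives the same result only because the NaN happens to be the second argument.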
Thoughts? Could this spell trouble for taking the gradient? (I assume not since steps involving NaN would now explicitly be rejected?)
What jax/jaxlib version are you using?
jax 0.4.2, jaxlib 0.4.2
Which accelerator(s) are you using?
CPU
Additional system info
MacBook M1 Pro
NVIDIA GPU info
No response