Deal with NaN errors in jax.experimental.ode.optimal_step_size() #14612

Open
virajpandya opened this issue Feb 21, 2023 · 0 comments
Labels
bug (Something isn't working), contributions welcome (The JAX team has not prioritized work on this. Community contributions are welcome.)

virajpandya commented Feb 21, 2023

Description

The following is related to my discussion about diffrax with @patrick-kidger here: patrick-kidger/diffrax#223

When solving certain ODE systems, the error estimate used for adaptive stepping can become NaN (e.g., if a trial timestep is too large). Currently this causes jax.experimental.ode.odeint to fail. If we use jax.debug.print to print the state variables, derivatives, and other intermediate quantities inside the integrator function while odeint is running, we see NaNs at all timesteps except perhaps the first few: once an initial NaN appears, odeint cannot recover.

The fix is simple -- change this line in jax.experimental.ode.optimal_step_size() from

factor = jnp.minimum(ifactor,
                      jnp.maximum(mean_error_ratio**(-1.0 / order) * safety, dfactor))

to something like

factor = jnp.nanmin(jnp.array([ifactor,
                                 jnp.nanmax(jnp.array([mean_error_ratio**(-1.0 / order) * safety, dfactor]))]))  
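
A minimal sketch for checking the two expressions side by side, evaluated in isolation rather than inside odeint. The constant values for safety, ifactor, dfactor, and order, and the helper names factor_current and factor_nan_safe, are assumptions chosen for illustration and are not taken from the library source.

```python
import jax.numpy as jnp

# Illustrative constants (assumed values, not read from the library source).
SAFETY, IFACTOR, DFACTOR, ORDER = 0.9, 10.0, 0.2, 5.0


def factor_current(mean_error_ratio):
    """Current expression: jnp.minimum / jnp.maximum propagate NaN."""
    ratio = jnp.asarray(mean_error_ratio)
    candidate = ratio ** (-1.0 / ORDER) * SAFETY
    return jnp.minimum(IFACTOR, jnp.maximum(candidate, DFACTOR))


def factor_nan_safe(mean_error_ratio):
    """Proposed expression: jnp.nanmin / jnp.nanmax ignore NaN entries."""
    ratio = jnp.asarray(mean_error_ratio)
    candidate = ratio ** (-1.0 / ORDER) * SAFETY
    lower_bounded = jnp.nanmax(jnp.array([candidate, DFACTOR]))
    return jnp.nanmin(jnp.array([IFACTOR, lower_bounded]))


# NaN error ratio: the two expressions are expected to differ.
nan_current = factor_current(jnp.nan)
nan_safe = factor_nan_safe(jnp.nan)

# Finite error ratio: the two expressions are expected to agree.
ok_current = factor_current(2.0)
ok_safe = factor_nan_safe(2.0)

print(nan_current, nan_safe, ok_current, ok_safe)
```

With a NaN error ratio, the first expression should return NaN while the second should fall back to dfactor; for a finite ratio the two should give the same factor.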

I don't have a minimal working example because the code in which I ran into this problem is a bit complicated (it involves a non-autonomous ODE system). However, you can plug mean_error_ratio = jnp.nan into the two expressions above to see what I mean. To diagnose the problem, I wrote my own adaptive RK23 (Bogacki-Shampine) integrator in JAX. My simple integrator had the same problem as JAX's odeint until I switched to jnp.nanmin and jnp.nanmax when computing the multiplicative factor for the new optimal step size (following eqn 4.13 of Hairer section II.4). Now both JAX's odeint and my manual JAX-based RK23 integrator give the expected solution, matching my original pure Python implementation of the ODE system using scipy.integrate.solve_ivp, as well as manual non-adaptive Euler integration with extremely small timesteps. For what it's worth, scipy.integrate.solve_ivp did not run into this issue because it uses Python's built-in min and max functions, which happen to tolerate NaNs as long as the NaN is not the first argument (see lines 152-168 of scipy.integrate._ivp.rk)...
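
The argument-order behavior of Python's built-in min and max can be checked directly. This is a small sketch in plain Python; the values 0.2 and 10.0 are placeholders for a lower and upper bound on the step factor and are assumptions for illustration, not constants taken from SciPy.

```python
import math

nan = float("nan")

# Built-in max/min keep the first argument unless a later one compares
# strictly greater/smaller. Every comparison with NaN is False, so a NaN
# in a later position is never selected.
max_nan_second = max(0.2, nan)    # NaN is not the first argument
min_nan_second = min(10.0, nan)   # NaN is not the first argument

# A NaN in the first position is kept, because nothing compares
# greater or smaller than it.
max_nan_first = max(nan, 0.2)
min_nan_first = min(nan, 10.0)

print(max_nan_second, min_nan_second, max_nan_first, min_nan_first)
```

So whether a NaN survives depends only on where it sits in the argument list, which is why the bound has to come first for this to act as a NaN guard.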

Thoughts? Could this spell trouble for taking the gradient? (I assume not since steps involving NaN would now explicitly be rejected?)
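
To illustrate what "explicitly rejected" would mean, here is a toy accept/reject loop in plain Python. It is a sketch under stated assumptions, not the odeint implementation: error_ratio is a hypothetical stand-in that returns NaN whenever the trial step exceeds 0.1, the constants are made up, and acceptance is modeled as error_ratio <= 1. It says nothing about how gradients behave.

```python
import math


def error_ratio(dt):
    """Hypothetical error model: NaN for too-large steps, finite otherwise."""
    return math.nan if dt > 0.1 else (dt / 0.05) ** 3


def nan_safe_factor(ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=3.0):
    """Step-size factor that falls back to dfactor when the ratio is NaN."""
    if math.isnan(ratio):
        return dfactor
    if ratio == 0.0:
        return ifactor
    return min(ifactor, max(ratio ** (-1.0 / order) * safety, dfactor))


dt = 1.0
history = []  # (trial step, accepted?) for each attempt
for _ in range(10):
    ratio = error_ratio(dt)
    accepted = ratio <= 1.0  # any comparison with NaN is False, so NaN means reject
    history.append((dt, accepted))
    if accepted:
        break
    dt *= nan_safe_factor(ratio)  # shrink and retry instead of propagating NaN

print(history)
```

In this toy setup the NaN trials are rejected and the step shrinks by dfactor each time until the error estimate is finite again, so the step size itself never becomes NaN.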

What jax/jaxlib version are you using?

jax 0.4.2, jaxlib 0.4.2

Which accelerator(s) are you using?

CPU

Additional system info

Macbook M1 Pro

NVIDIA GPU info

No response

virajpandya added the bug label on Feb 21, 2023
apaszke added the contributions welcome label on Feb 22, 2023