Skip to content

Commit

Permalink
improve docs and error message for odeint *args (#2931)
Browse files Browse the repository at this point in the history
cf. #2920
  • Loading branch information
mattjj committed May 2, 2020
1 parent a182578 commit 64f12a4
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion jax/experimental/ode.py
Expand Up @@ -28,6 +28,7 @@

import jax
import jax.numpy as np
from jax import core
from jax import lax
from jax import ops
from jax.util import safe_map, safe_zip
Expand Down Expand Up @@ -141,7 +142,9 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=np.inf):
y0: array or pytree of arrays representing the initial value for the state.
t: array of float times for evaluation, like `np.linspace(0., 10., 101)`,
in which the values must be strictly increasing.
*args: tuple of additional arguments for `func`.
*args: tuple of additional arguments for `func`, which must be arrays
scalars, or (nested) standard Python containers (tuples, lists, dicts,
namedtuples, i.e. pytrees) of those types.
rtol: float, relative local error tolerance for solver (optional).
atol: float, absolute local error tolerance for solver (optional).
mxstep: int, maximum number of steps to take for each timepoint (optional).
Expand All @@ -151,6 +154,12 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=np.inf):
point in `t`, represented as an array (or pytree of arrays) with the same
shape/structure as `y0` except with a new leading axis of length `len(t)`.
"""
def _check_arg(arg):
if not isinstance(arg, core.Tracer) and not core.valid_jaxtype(arg):
msg = ("The contents of odeint *args must be arrays or scalars, but got "
"\n{}.")
raise TypeError(msg.format(arg))
tree_map(_check_arg, args)
return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)

@partial(jax.jit, static_argnums=(0, 1, 2, 3))
Expand Down

0 comments on commit 64f12a4

Please sign in to comment.