Skip to content

Commit

Permalink
jax errors: improve leading info in error docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 14, 2023
1 parent 9015920 commit c3e9f85
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions jax/_src/errors.py
Expand Up @@ -47,7 +47,8 @@ class JAXIndexError(_JAXErrorMixin, IndexError):
class ConcretizationTypeError(JAXTypeError):
"""
This error occurs when a JAX Tracer object is used in a context where a
concrete value is required. In some situations, it can be easily fixed by
concrete value is required (see :ref:`faq-different-kinds-of-jax-values`
for more on what a Tracer is). In some situations, it can be easily fixed by
marking problematic values as static; in others, it may indicate that your
program is doing operations that are not directly supported by JAX's JIT
compilation model.
Expand Down Expand Up @@ -283,41 +284,45 @@ def __init__(self, tracer: core.Tracer):
class TracerArrayConversionError(JAXTypeError):
"""
This error occurs when a program attempts to convert a JAX Tracer object into
a standard NumPy array. It typically occurs in one of a few situations.
a standard NumPy array (see :ref:`faq-different-kinds-of-jax-values` for more
on what a Tracer is). It typically occurs in one of a few situations.
Using `numpy` rather than `jax.numpy` functions
This error can occur when a JAX Tracer object is passed to a raw numpy
function, or a method on a numpy.ndarray object. For example::
Using non-JAX functions in JAX transformations
This error can occur if you attempt to use a non-JAX library like `numpy`
or ``scipy`` inside a JAX transformation (:func:`~jax.jit`, :func:`~jax.grad`,
:func:`jax.vmap`, etc.). For example::
>>> from functools import partial
>>> from jax import jit
>>> import numpy as np
>>> import jax.numpy as jnp
>>> @jit
... def func(x):
... return np.sin(x)
>>> func(jnp.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
>>> func(np.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on the JAX Tracer object
In this case, check that you are using `jax.numpy` methods rather than
`numpy` methods::
In this case, you can fix the issue by using :func:`jax.numpy.sin` in place of
:func:`numpy.sin`::
>>> import jax.numpy as jnp
>>> @jit
... def func(x):
... return jnp.sin(x)
>>> func(jnp.arange(4))
Array([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
See also `External Callbacks`_ for options for calling back to host-side computations
from transformed JAX code.
Indexing a numpy array with a tracer
If this error arises on a line that involves array indexing, it may be that
the array being indexed `x` is a raw numpy.ndarray while the indices `idx`
are traced. For example::
the array being indexed `x` is a standard numpy.ndarray while the indices `idx`
are traced JAX arrays. For example::
>>> x = np.arange(10)
Expand All @@ -343,6 +348,7 @@ class TracerArrayConversionError(JAXTypeError):
or by declaring the index as a static argument::
>>> from functools import partial
>>> @partial(jit, static_argnums=(0,))
... def func(i):
... return x[i]
Expand All @@ -353,6 +359,8 @@ class TracerArrayConversionError(JAXTypeError):
To understand more subtleties having to do with tracers vs. regular values,
and concrete vs. abstract values, you may want to read
:ref:`faq-different-kinds-of-jax-values`.
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
"""
def __init__(self, tracer: core.Tracer):
super().__init__(
Expand All @@ -364,13 +372,13 @@ def __init__(self, tracer: core.Tracer):
class TracerIntegerConversionError(JAXTypeError):
"""
This error can occur when a JAX Tracer object is used in a context where a
Python integer is expected. It typically occurs in a few situations.
Python integer is expected (see :ref:`faq-different-kinds-of-jax-values` for
more on what a Tracer is). It typically occurs in a few situations.
Passing a tracer in place of an integer
This error can occur if you attempt to pass a tracer to a function that
requires an integer argument; for example::
This error can occur if you attempt to pass a traced value to a function
that requires a static integer argument; for example::
>>> from functools import partial
>>> from jax import jit
>>> import numpy as np
Expand All @@ -387,6 +395,7 @@ class TracerIntegerConversionError(JAXTypeError):
When this happens, the solution is often to mark the problematic argument as
static::
>>> from functools import partial
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
... return np.split(x, 2, axis)
Expand All @@ -410,7 +419,6 @@ class TracerIntegerConversionError(JAXTypeError):
quantity.
For example::
>>> from functools import partial
>>> import jax.numpy as jnp
>>> from jax import jit
Expand All @@ -437,6 +445,7 @@ class TracerIntegerConversionError(JAXTypeError):
or by declaring the index as a static argument::
>>> from functools import partial
>>> @partial(jit, static_argnums=0)
... def func(i):
... return L[i]
Expand Down

0 comments on commit c3e9f85

Please sign in to comment.