From c3e9f85cfa3ee68dc14aff5850e28da2ca115879 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 14 Jun 2023 02:07:00 -0700 Subject: [PATCH] jax errors: improve leading info in error docs --- jax/_src/errors.py | 43 ++++++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 5ee8355808c1..acd6c077daad 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -47,7 +47,8 @@ class JAXIndexError(_JAXErrorMixin, IndexError): class ConcretizationTypeError(JAXTypeError): """ This error occurs when a JAX Tracer object is used in a context where a - concrete value is required. In some situations, it can be easily fixed by + concrete value is required (see :ref:`faq-different-kinds-of-jax-values` + for more on what a Tracer is). In some situations, it can be easily fixed by marking problematic values as static; in others, it may indicate that your program is doing operations that are not directly supported by JAX's JIT compilation model. @@ -283,30 +284,31 @@ def __init__(self, tracer: core.Tracer): class TracerArrayConversionError(JAXTypeError): """ This error occurs when a program attempts to convert a JAX Tracer object into - a standard NumPy array. It typically occurs in one of a few situations. + a standard NumPy array (see :ref:`faq-different-kinds-of-jax-values` for more + on what a Tracer is). It typically occurs in one of a few situations. - Using `numpy` rather than `jax.numpy` functions - This error can occur when a JAX Tracer object is passed to a raw numpy - function, or a method on a numpy.ndarray object. For example:: + Using non-JAX functions in JAX transformations + This error can occur if you attempt to use a non-JAX library like `numpy` + or ``scipy`` inside a JAX transformation (:func:`~jax.jit`, :func:`~jax.grad`, + :func:`jax.vmap`, etc.). For example:: - >>> from functools import partial >>> from jax import jit >>> import numpy as np - >>> import jax.numpy as jnp >>> @jit ... def func(x): ... return np.sin(x) - >>> func(jnp.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> func(np.arange(4)) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object - In this case, check that you are using `jax.numpy` methods rather than - `numpy` methods:: + In this case, you can fix the issue by using :func:`jax.numpy.sin` in place of + :func:`numpy.sin`:: + >>> import jax.numpy as jnp >>> @jit ... def func(x): ... return jnp.sin(x) @@ -314,10 +316,13 @@ class TracerArrayConversionError(JAXTypeError): >>> func(jnp.arange(4)) Array([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32) + See also `External Callbacks`_ for options for calling back to host-side computations + from transformed JAX code. + Indexing a numpy array with a tracer If this error arises on a line that involves array indexing, it may be that - the array being indexed `x` is a raw numpy.ndarray while the indices `idx` - are traced. For example:: + the array being indexed `x` is a standard numpy.ndarray while the indices `idx` + are traced JAX arrays. For example:: >>> x = np.arange(10) @@ -343,6 +348,7 @@ class TracerArrayConversionError(JAXTypeError): or by declaring the index as a static argument:: + >>> from functools import partial >>> @partial(jit, static_argnums=(0,)) ... def func(i): ... return x[i] @@ -353,6 +359,8 @@ class TracerArrayConversionError(JAXTypeError): To understand more subtleties having to do with tracers vs. regular values, and concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`. + + .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html """ def __init__(self, tracer: core.Tracer): super().__init__( @@ -364,13 +372,13 @@ def __init__(self, tracer: core.Tracer): class TracerIntegerConversionError(JAXTypeError): """ This error can occur when a JAX Tracer object is used in a context where a - Python integer is expected. It typically occurs in a few situations. + Python integer is expected (see :ref:`faq-different-kinds-of-jax-values` for + more on what a Tracer is). It typically occurs in a few situations. Passing a tracer in place of an integer - This error can occur if you attempt to pass a tracer to a function that - requires an integer argument; for example:: + This error can occur if you attempt to pass a traced value to a function + that requires a static integer argument; for example:: - >>> from functools import partial >>> from jax import jit >>> import numpy as np @@ -387,6 +395,7 @@ class TracerIntegerConversionError(JAXTypeError): When this happens, the solution is often to mark the problematic argument as static:: + >>> from functools import partial >>> @partial(jit, static_argnums=1) ... def func(x, axis): ... return np.split(x, 2, axis) @@ -410,7 +419,6 @@ class TracerIntegerConversionError(JAXTypeError): quantity. For example:: - >>> from functools import partial >>> import jax.numpy as jnp >>> from jax import jit @@ -437,6 +445,7 @@ class TracerIntegerConversionError(JAXTypeError): or by declaring the index as a static argument:: + >>> from functools import partial >>> @partial(jit, static_argnums=0) ... def func(i): ... return L[i]