Skip to content

Commit

Permalink
Fix URL in custom errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 16, 2021
1 parent 2d148a3 commit e9195ba
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
12 changes: 10 additions & 2 deletions jax/_src/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

class JAXTypeError(TypeError):
"""Base class for JAX-specific TypeErrors"""
_error_page = 'https://jax.readthedocs.io/en/latest/errors.html'

def __init__(self, message: str):
error_page = 'https://jax.readthedocs.io/en/latest/errors.html'
module_name = self.__class__.__module__
error_page = self._error_page
module_name = getattr(self, '_module_name', self.__class__.__module__)
class_name = self.__class__.__name__
error_msg = f'{message} ({error_page}#{module_name}.{class_name})'
super().__init__(error_msg)
Expand Down Expand Up @@ -128,6 +130,8 @@ class ConcretizationTypeError(JAXTypeError):
To understand more subtleties having to do with tracers vs. regular values, and
concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
"""
_module_name = "jax.errors"

def __init__(self, tracer: "core.Tracer", context: str = ""):
super().__init__(
"Abstract tracer value encountered where concrete value is expected: "
Expand Down Expand Up @@ -201,6 +205,8 @@ class TracerArrayConversionError(JAXTypeError):
To understand more subtleties having to do with tracers vs. regular values, and concrete vs.
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
"""
_module_name = "jax.errors"

def __init__(self, tracer: "core.Tracer"):
super().__init__(
"The numpy.ndarray conversion method __array__() was called on "
Expand Down Expand Up @@ -287,6 +293,8 @@ class TracerIntegerConversionError(JAXTypeError):
To understand more subtleties having to do with tracers vs. regular values, and concrete vs.
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
"""
_module_name = "jax.errors"

def __init__(self, tracer: "core.Tracer"):
super().__init__(
f"The __index__() method was called on the JAX Tracer object {tracer}")
18 changes: 17 additions & 1 deletion tests/errors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import unittest

from absl.testing import absltest
from absl.testing import parameterized

from jax import grad, jit, vmap, lax
import jax
from jax import core, grad, jit, vmap, lax
import jax.numpy as jnp
from jax import test_util as jtu
from jax._src import traceback_util
Expand Down Expand Up @@ -307,5 +309,19 @@ def outer(x):
traceback_util.FilteredStackTrace)


class CustomErrorsTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(errorclass), "errorclass": errorclass}
for errorclass in dir(jax.errors)
if errorclass.endswith('Error') and errorclass != 'JAXTypeError'))
def testErrorsURL(self, errorclass):
class FakeTracer(core.Tracer):
aval = None
ErrorClass = getattr(jax.errors, errorclass)
err = ErrorClass(FakeTracer(None))

self.assertIn(f'https://jax.readthedocs.io/en/latest/errors.html#jax.errors.{errorclass}', str(err))


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit e9195ba

Please sign in to comment.