Skip to content

Commit

Permalink
unrevert #3674 (revert #3791)
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Aug 18, 2020
1 parent 1ba4e06 commit 13b3230
Show file tree
Hide file tree
Showing 4 changed files with 306 additions and 2 deletions.
11 changes: 11 additions & 0 deletions jax/api.py
Expand Up @@ -43,6 +43,7 @@
from .api_util import (wraps, flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
flatten_fun_nokwargs2, argnums_partial, flatten_axes,
donation_vector, rebase_donate_argnums)
from .traceback_util import api_boundary
from .tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
tree_transpose, tree_leaves, tree_multimap,
treedef_is_leaf, Partial)
Expand Down Expand Up @@ -162,6 +163,7 @@ def jit(fun: Callable[..., T],
donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)

@wraps(fun)
@api_boundary
def f_jitted(*args, **kwargs):
if _jit_is_disabled():
return fun(*args, **kwargs)
Expand Down Expand Up @@ -371,6 +373,7 @@ def abstractify(x):
return ShapedArray(np.shape(x), dtypes.result_type(x))

@wraps(fun)
@api_boundary
def computation_maker(*args, **kwargs):
wrapped = lu.wrap_init(fun)
if static_argnums:
Expand Down Expand Up @@ -468,11 +471,13 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
"positions {argnums}.")

@wraps(fun, docstr=docstr, argnums=argnums)
@api_boundary
def grad_f(*args, **kwargs):
_, g = value_and_grad_f(*args, **kwargs)
return g

@wraps(fun, docstr=docstr, argnums=argnums)
@api_boundary
def grad_f_aux(*args, **kwargs):
(_, aux), g = value_and_grad_f(*args, **kwargs)
return g, aux
Expand Down Expand Up @@ -515,6 +520,7 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
_check_callable(fun)

@wraps(fun, docstr=docstr, argnums=argnums)
@api_boundary
def value_and_grad_f(*args, **kwargs):
max_argnum = argnums if type(argnums) is int else max(argnums)
if max_argnum >= len(args):
Expand Down Expand Up @@ -930,6 +936,7 @@ def vmap(fun: Callable[..., T], in_axes=0, out_axes=0, axis_name=None) -> Callab
del in_axes_, out_axes_

@wraps(fun, docstr=docstr)
@api_boundary
def batched_fun(*args):
args_flat, in_tree = tree_flatten(args)
f = lu.wrap_init(fun)
Expand Down Expand Up @@ -1206,6 +1213,7 @@ def pmap(fun: Callable[..., T],
raise ValueError(f"pmap in_axes leaves must be 0 or None, got {in_axes}")

@wraps(fun)
@api_boundary
def f_pmapped(*args, **kwargs):
f = lu.wrap_init(fun)
if static_broadcasted_tuple:
Expand Down Expand Up @@ -1272,6 +1280,7 @@ def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
raise ValueError(f"soft_pmap in_axes leaves must be 0 or None, got {in_axes}")

@wraps(fun)
@api_boundary
def f_pmapped(*args, **kwargs):
f = lu.wrap_init(fun)
args_flat, in_tree = tree_flatten((args, kwargs))
Expand Down Expand Up @@ -1635,6 +1644,7 @@ def make_jaxpr(fun: Callable,
static_argnums = (static_argnums,)

@wraps(fun)
@api_boundary
def jaxpr_maker(*args, **kwargs):
wrapped = lu.wrap_init(fun)
if static_argnums:
Expand Down Expand Up @@ -1891,6 +1901,7 @@ def checkpoint(fun: Callable, concrete: bool = False) -> Callable:
...
"""
@wraps(fun)
@api_boundary
def fun_remat(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
Expand Down
143 changes: 143 additions & 0 deletions jax/traceback_util.py
@@ -0,0 +1,143 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import traceback
import types

from .api_util import wraps

_jax_path = os.path.dirname(__file__)
_include_paths = [
os.path.join(_jax_path, path) for path in (
'config.py', 'dlpack.py', 'experimental', 'lax', 'lax_linalg.py',
'lax_reference.py', 'nn', 'numpy', 'ops', 'profiler.py', 'random.py',
'scipy', 'test_util.py', 'third_party', 'tools',
)]

_jax_message_append = (
'The stack trace above excludes JAX-internal frames.\n'
'The following is the original exception that occurred, unmodified.\n'
'\n--------------------')

def include_frame(f):
return (not f.f_code.co_filename.startswith(_jax_path) or
any(f.f_code.co_filename.startswith(path) for path in _include_paths))

# When scanning stack traces, we might encounter frames from cpython that are
# removed from printed stack traces, such as frames from parts of importlib. We
# ignore these frames heuristically based on source and name match.
def ignore_known_hidden_frame(f):
return 'importlib._bootstrap' in f.f_code.co_filename

def filter_traceback_and_stack(e):
out = None

# Scan the traceback and collect relevant frames.

for f, lineno in reversed(list(traceback.walk_tb(e.__traceback__))):
if include_frame(f):
out = types.TracebackType(out, f, f.f_lasti, lineno) # pytype: disable=wrong-arg-count

# Continue up the call stack.
#
# We would like to avoid stepping too far up, e.g. past the exec/eval point of
# a REPL such as IPython. To that end, we stop past the first contiguous bunch
# of module-level frames, if we reach any such frames at all. This is a
# heuristic that might stop in advance of the REPL boundary. For example, if
# the call stack includes module-level frames from the current module A, and
# the current module A was imported from within a function F elsewhere, then
# the stack trace we produce will be truncated at F's frame.

reached_module_level = False
for f, lineno in traceback.walk_stack(e.__traceback__.tb_frame):
if ignore_known_hidden_frame(f):
continue
if reached_module_level and f.f_code.co_name != '<module>':
break
if include_frame(f):
out = types.TracebackType(out, f, f.f_lasti, lineno) # pytype: disable=wrong-arg-count
if f.f_code.co_name == '<module>':
reached_module_level = True

return out

def is_reraiser_frame(f):
return (f.filename == __file__ and
f.name == 'reraise_with_filtered_traceback')

def is_under_reraiser(e):
tb = traceback.extract_stack(e.__traceback__.tb_frame)
return any(is_reraiser_frame(f) for f in tb[:-1])

def format_exception_only(e):
return ''.join(traceback.format_exception_only(type(e), e)).strip()

def last_cause(e):
prev, cur = e, e.__cause__
while cur is not None:
prev, cur = cur, cur.__cause__
return prev

class FilteredStackTrace(Exception): pass

def filtered_tracebacks_supported():
return sys.version_info >= (3, 7)

def api_boundary(fun):
'''Wraps ``fun`` to form a boundary for filtering exception tracebacks.
When an exception occurs below ``fun``, this appends to it a custom
``__cause__`` that carries a filtered traceback. The traceback imitates the
stack trace of the original exception, but with JAX-internal frames removed.
This boundary annotation works in composition with itself. The topmost frame
corresponding to an ``api_boundary`` is the one below which stack traces are
filtered. In other words, if ``api_boundary(f)`` calls ``api_boundary(g)``,
directly or indirectly, the filtered stack trace provided is the same as if
``api_boundary(f)`` were to simply call ``g`` instead.
This annotation is primarily useful in wrapping functions output by JAX's
transformations. For example, consder ``g = jax.jit(f)``. When ``g`` is
called, JAX's JIT compilation machinery is invoked, which in turn calls ``f``
in order to trace and translate it. If the function ``f`` raises an exception,
the stack unwinds through JAX's JIT internals up to the original call site of
``g``. Because the function returned by ``jax.jit`` is annotated as an
``api_boundary``, such an exception is accompanied by an additional traceback
that excludes the frames specific to JAX's implementation.
'''

if not filtered_tracebacks_supported():
return fun

@wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
try:
return fun(*args, **kwargs)
except Exception as e:
if not is_under_reraiser(e):
filtered_tb = filter_traceback_and_stack(e)
if filtered_tb:
msg = format_exception_only(e)
msg = f'{msg}\n\n{_jax_message_append}'
filtered = FilteredStackTrace(msg).with_traceback(filtered_tb)
cause = last_cause(e)
cause.__cause__ = filtered
raise
else:
raise
else:
raise
return reraise_with_filtered_traceback
149 changes: 149 additions & 0 deletions tests/errors_test.py
@@ -0,0 +1,149 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import re
import traceback
import unittest

from absl.testing import absltest

from jax import grad, jit, vmap
import jax.numpy as jnp
from jax import test_util as jtu
from jax import traceback_util


from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS


def get_exception(etype, f):
try:
f()
except etype as e:
return e
assert False

def check_filtered_stack_trace(test, etype, f, frame_patterns=[]):
test.assertRaises(etype, f)
e = get_exception(etype, f)
c = traceback_util.last_cause(e)
test.assertIsInstance(c, traceback_util.FilteredStackTrace)
c_tb = traceback.format_tb(c.__traceback__)
if frame_patterns:
for (fname_pat, line_pat), frame_fmt in zip(
reversed(frame_patterns), reversed(c_tb)):
fname_pat = re.escape(fname_pat)
line_pat = re.escape(line_pat)
full_pat = (
f' File "{__file__}", line ' r'[0-9]+'
f', in {fname_pat}' r'\n\s*' f'{line_pat}')
test.assertRegex(frame_fmt, full_pat)


class FilteredTracebackTest(jtu.JaxTestCase):

def test_nested_jit(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')

@jit
def innermost(x):
assert False
@jit
def inbetween(x):
return 1 + innermost(x)
@jit
def outermost(x):
return 2 + inbetween(x)

f = lambda: outermost(jnp.array([1, 2]))

check_filtered_stack_trace(self, AssertionError, f, [
('<lambda>', 'f = lambda: outermost'),
('outermost', 'return 2 + inbetween(x)'),
('inbetween', 'return 1 + innermost(x)'),
('innermost', 'assert False')])

def test_nested_jit_and_vmap(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')

@jit
def innermost(x):
assert False
@jit
def inbetween(x):
return 1 + vmap(innermost)(x)
@jit
def outermost(x):
return 2 + inbetween(x)

f = lambda: outermost(jnp.array([1, 2]))

check_filtered_stack_trace(self, AssertionError, f, [
('<lambda>', 'f = lambda: outermost'),
('outermost', 'return 2 + inbetween(x)'),
('inbetween', 'return 1 + vmap(innermost)(x)'),
('innermost', 'assert False')])

def test_nested_jit_and_grad(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')

@jit
def innermost(x):
assert False
@jit
def inbetween(x):
return 1 + grad(innermost)(x)
@jit
def outermost(x):
return 2 + inbetween(x)

f = lambda: outermost(jnp.array([1, 2]))

check_filtered_stack_trace(self, TypeError, f, [
('<lambda>', 'f = lambda: outermost'),
('outermost', 'return 2 + inbetween(x)'),
('inbetween', 'return 1 + grad(innermost)(x)')])

def test_cause_chain(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')

@jit
def inner(x):
raise ValueError('inner')
@jit
def outer(x):
try:
inner(x)
except ValueError as e:
raise TypeError('outer') from e

f = lambda: outer(1.)

check_filtered_stack_trace(self, TypeError, f, [
('<lambda>', 'f = lambda: outer'),
('outer', 'raise TypeError')])
e = get_exception(TypeError, f)
self.assertIsInstance(e.__cause__, ValueError)
self.assertIsInstance(e.__cause__.__cause__,
traceback_util.FilteredStackTrace)

if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
5 changes: 3 additions & 2 deletions tests/loops_test.py
Expand Up @@ -355,9 +355,10 @@ def f_op(inc):
s.out += i
return s.out

with self.assertRaisesWithLiteralMatch(
with self.assertRaisesRegex(
ValueError,
"Body of cond_range or while_range should not use the index variable returned by iterator."):
r"^Body of cond_range or while_range should not use the index variable "
r"returned by iterator\."):
api.make_jaxpr(f_op)(2.)

def test_while(self):
Expand Down

0 comments on commit 13b3230

Please sign in to comment.