Skip to content

Commit

Permalink
Change jaxdb->jdb and add option to force a backend
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Jul 29, 2022
1 parent fb0cf66 commit decdca6
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 66 deletions.
2 changes: 1 addition & 1 deletion docs/debugging/flags.md
@@ -1,6 +1,6 @@
# JAX debugging flags

JAX offers flags and context managers.
JAX offers flags and context managers that enable catching errors more easily.

## `jax_debug_nans` configuration option and context manager

Expand Down
12 changes: 9 additions & 3 deletions docs/debugging/index.md
@@ -1,6 +1,6 @@
# Debugging in JAX
# Runtime value debugging in JAX

Do you have exploding gradients? Are nans making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools!
Do you have exploding gradients? Are nans making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has tl;dr summaries and you can click the "Read more" links at the bottom to learn more.

## [Interactive inspection with `jax.debug`](print_breakpoint)

Expand All @@ -26,6 +26,8 @@ Do you have exploding gradients? Are nans making you gnash your teeth? Just want
# 🤯 0.9092974662780762 🤯
```

Click [here](print_breakpoint) to learn more!

## [Functional error checks with `jax.experimental.checkify`](checkify_guide)

**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
Expand Down Expand Up @@ -67,6 +69,8 @@ Do you have exploding gradients? Are nans making you gnash your teeth? Just want
# ValueError: nan generated by primitive sin at <...>:8 (f)
```

Click [here](checkify_guide) to learn more!

## [Throwing Python errors with JAX's debug flags](flags)

**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
Expand All @@ -80,8 +84,10 @@ def f(x, y):
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
```

Click [here](flags) to learn more!

```{toctree}
:caption: Index
:caption: Read more
:maxdepth: 1
print_breakpoint
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/debugger/cli_debugger.py
Expand Up @@ -26,7 +26,7 @@

class CliDebugger(cmd.Cmd):
"""A text-based debugger."""
prompt = '(jaxdb) '
prompt = '(jdb) '
use_rawinput: bool = False

def __init__(self, frames: List[DebuggerFrame], thread_id,
Expand All @@ -36,7 +36,7 @@ def __init__(self, frames: List[DebuggerFrame], thread_id,
self.frames = frames
self.frame_index = 0
self.thread_id = thread_id
self.intro = 'Entering jaxdb:'
self.intro = 'Entering jdb:'

def current_frame(self):
return self.frames[self.frame_index]
Expand Down
8 changes: 5 additions & 3 deletions jax/_src/debugger/core.py
Expand Up @@ -95,7 +95,9 @@ def __call__(self, frames: List[DebuggerFrame], thread_id: Optional[int],
_debugger_registry: Dict[str, Tuple[int, Debugger]] = {}


def get_debugger() -> Debugger:
def get_debugger(backend: Optional[str] = None) -> Debugger:
if backend is not None and backend in _debugger_registry:
return _debugger_registry[backend][1]
debuggers = sorted(_debugger_registry.values(), key=lambda x: -x[0])
if not debuggers:
raise ValueError("No debuggers registered!")
Expand All @@ -111,7 +113,7 @@ def register_debugger(name: str, debugger: Debugger, priority: int) -> None:
debug_lock = threading.Lock()


def breakpoint(*, ordered: bool = False, **kwargs): # pylint: disable=redefined-builtin
def breakpoint(*, ordered: bool = False, backend=None, **kwargs): # pylint: disable=redefined-builtin
"""Enters a breakpoint at a point in a program."""
frame_infos = inspect.stack()
# Filter out internal frames
Expand All @@ -131,7 +133,7 @@ def _breakpoint_callback(*flat_args):
thread_id = None
if threading.current_thread() is not threading.main_thread():
thread_id = threading.get_ident()
debugger = get_debugger()
debugger = get_debugger(backend=backend)
# Lock here because this could be called from multiple threads at the same
# time.
with debug_lock:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/debugger/web_debugger.py
Expand Up @@ -31,7 +31,7 @@

class WebDebugger(cli_debugger.CliDebugger):
"""A web-based debugger."""
prompt = '(jaxdb) '
prompt = '(jdb) '
use_rawinput: bool = False

def __init__(self, frames: List[debugger_core.DebuggerFrame], thread_id,
Expand Down
115 changes: 59 additions & 56 deletions tests/debugger_test.py
Expand Up @@ -67,7 +67,7 @@ def test_debugger_eof(self):

def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
with self.assertRaises(SystemExit):
f(2.)
Expand All @@ -79,13 +79,13 @@ def test_debugger_can_continue(self):

def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
f(2.)
jax.effects_barrier()
expected = _format_multiline(r"""
Entering jaxdb:
(jaxdb) """)
Entering jdb:
(jdb) """)
self.assertEqual(stdout.getvalue(), expected)

@jtu.skip_on_devices(*disabled_backends)
Expand All @@ -94,12 +94,12 @@ def test_debugger_can_print_value(self):

def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
expected = _format_multiline(r"""
Entering jaxdb:
(jaxdb) DeviceArray(2., dtype=float32)
(jaxdb) """)
Entering jdb:
(jdb) DeviceArray(2., dtype=float32)
(jdb) """)
f(jnp.array(2., jnp.float32))
jax.effects_barrier()
self.assertEqual(stdout.getvalue(), expected)
Expand All @@ -111,12 +111,12 @@ def test_debugger_can_print_value_in_jit(self):
@jax.jit
def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
expected = _format_multiline(r"""
Entering jaxdb:
(jaxdb) array(2., dtype=float32)
(jaxdb) """)
Entering jdb:
(jdb) array(2., dtype=float32)
(jdb) """)
f(jnp.array(2., jnp.float32))
jax.effects_barrier()
self.assertEqual(stdout.getvalue(), expected)
Expand All @@ -128,12 +128,12 @@ def test_debugger_can_print_multiple_values(self):
@jax.jit
def f(x):
y = x + 1.
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
expected = _format_multiline(r"""
Entering jaxdb:
(jaxdb) (array(2., dtype=float32), array(3., dtype=float32))
(jaxdb) """)
Entering jdb:
(jdb) (array(2., dtype=float32), array(3., dtype=float32))
(jdb) """)
f(jnp.array(2., jnp.float32))
jax.effects_barrier()
self.assertEqual(stdout.getvalue(), expected)
Expand All @@ -145,20 +145,20 @@ def test_debugger_can_print_context(self):
@jax.jit
def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
f(2.)
jax.effects_barrier()
expected = _format_multiline(r"""
Entering jaxdb:
\(jaxdb\) > .*debugger_test\.py\([0-9]+\)
Entering jdb:
\(jdb\) > .*debugger_test\.py\([0-9]+\)
@jax\.jit
def f\(x\):
y = jnp\.sin\(x\)
-> debugger\.breakpoint\(stdin=stdin, stdout=stdout\)
-> debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)
return y
.*
\(jaxdb\) """)
\(jdb\) """)
self.assertRegex(stdout.getvalue(), expected)

@jtu.skip_on_devices(*disabled_backends)
Expand All @@ -168,11 +168,11 @@ def test_debugger_can_print_backtrace(self):
@jax.jit
def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y
expected = _format_multiline(r"""
Entering jaxdb:.*
\(jaxdb\) Traceback:.*
Entering jdb:.*
\(jdb\) Traceback:.*
""")
f(2.)
jax.effects_barrier()
Expand All @@ -184,35 +184,35 @@ def test_debugger_can_work_with_multiple_stack_frames(self):

def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y

@jax.jit
def g(x):
y = f(x)
return jnp.exp(y)
expected = _format_multiline(r"""
Entering jaxdb:
\(jaxdb\) > .*debugger_test\.py\([0-9]+\)
Entering jdb:
\(jdb\) > .*debugger_test\.py\([0-9]+\)
def f\(x\):
y = jnp\.sin\(x\)
-> debugger\.breakpoint\(stdin=stdin, stdout=stdout\)
-> debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)
return y
.*
\(jaxdb\) > .*debugger_test\.py\([0-9]+\).*
\(jdb\) > .*debugger_test\.py\([0-9]+\).*
@jax\.jit
def g\(x\):
-> y = f\(x\)
return jnp\.exp\(y\)
.*
\(jaxdb\) array\(2\., dtype=float32\)
\(jaxdb\) > .*debugger_test\.py\([0-9]+\)
\(jdb\) array\(2\., dtype=float32\)
\(jdb\) > .*debugger_test\.py\([0-9]+\)
def f\(x\):
y = jnp\.sin\(x\)
-> debugger\.breakpoint\(stdin=stdin, stdout=stdout\)
-> debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)
return y
.*
\(jaxdb\) """)
\(jdb\) """)
g(jnp.array(2., jnp.float32))
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)
Expand All @@ -223,20 +223,22 @@ def test_can_use_multiple_breakpoints(self):

def f(x):
y = x + 1.
debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True)
debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True,
backend="cli")
return y

@jax.jit
def g(x):
y = f(x) * 2.
debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True)
debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True,
backend="cli")
return jnp.exp(y)
expected = _format_multiline(r"""
Entering jaxdb:
(jaxdb) array(3., dtype=float32)
(jaxdb) Entering jaxdb:
(jaxdb) array(6., dtype=float32)
(jaxdb) """)
Entering jdb:
(jdb) array(3., dtype=float32)
(jdb) Entering jdb:
(jdb) array(6., dtype=float32)
(jdb) """)
g(jnp.array(2., jnp.float32))
jax.effects_barrier()
self.assertEqual(stdout.getvalue(), expected)
Expand All @@ -251,7 +253,8 @@ def test_debugger_works_with_vmap(self):

def f(x):
y = x + 1.
debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=ordered)
debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=ordered,
backend="cli")
return 2. * y

@jax.jit
Expand All @@ -260,11 +263,11 @@ def g(x):
y = f(x)
return jnp.exp(y)
expected = _format_multiline(r"""
Entering jaxdb:
(jaxdb) array(1., dtype=float32)
(jaxdb) Entering jaxdb:
(jaxdb) array(2., dtype=float32)
(jaxdb) """)
Entering jdb:
(jdb) array(1., dtype=float32)
(jdb) Entering jdb:
(jdb) array(2., dtype=float32)
(jdb) """)
g(jnp.arange(2., dtype=jnp.float32))
jax.effects_barrier()
self.assertEqual(stdout.getvalue(), expected)
Expand All @@ -277,19 +280,19 @@ def test_debugger_works_with_pmap(self):

def f(x):
y = jnp.sin(x)
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y

@jax.pmap
def g(x):
y = f(x)
return jnp.exp(y)
expected = _format_multiline(r"""
Entering jaxdb:
\(jaxdb\) array\(.*, dtype=float32\)
\(jaxdb\) Entering jaxdb:
\(jaxdb\) array\(.*, dtype=float32\)
\(jaxdb\) """)
Entering jdb:
\(jdb\) array\(.*, dtype=float32\)
\(jdb\) Entering jdb:
\(jdb\) array\(.*, dtype=float32\)
\(jdb\) """)
g(jnp.arange(2., dtype=jnp.float32))
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)
Expand All @@ -302,7 +305,7 @@ def test_debugger_works_with_pjit(self):

def f(x):
y = x + 1
debugger.breakpoint(stdin=stdin, stdout=stdout)
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
return y

def g(x):
Expand All @@ -313,9 +316,9 @@ def g(x):
with maps.Mesh(np.array(jax.devices()), ["dev"]):
arr = (1 + np.arange(8)).astype(np.int32)
expected = _format_multiline(r"""
Entering jaxdb:
\(jaxdb\) {}
\(jaxdb\) """.format(re.escape(repr(arr))))
Entering jdb:
\(jdb\) {}
\(jdb\) """.format(re.escape(repr(arr))))
g(jnp.arange(8, dtype=jnp.int32))
jax.effects_barrier()
print(stdout.getvalue())
Expand Down

0 comments on commit decdca6

Please sign in to comment.