Skip to content

Commit

Permalink
Various debugger improvements
Browse files Browse the repository at this point in the history
- disables globals
- can opt out of filtering frames
- can limit number of frames
  • Loading branch information
sharadmv committed Aug 9, 2022
1 parent 8dce848 commit c34aa39
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 16 deletions.
50 changes: 38 additions & 12 deletions jax/_src/debugger/core.py
Expand Up @@ -80,7 +80,7 @@ def from_frameinfo(cls, frame_info) -> DebuggerFrame:
return DebuggerFrame(
filename=frame_info.filename,
locals=frame_info.frame.f_locals,
globals=frame_info.frame.f_globals,
globals={},
code_context=frame_info.code_context,
source=source,
lineno=frame_info.lineno,
Expand Down Expand Up @@ -113,19 +113,45 @@ def register_debugger(name: str, debugger: Debugger, priority: int) -> None:
debug_lock = threading.Lock()


def breakpoint(*, ordered: bool = False, backend=None, **kwargs): # pylint: disable=redefined-builtin
"""Enters a breakpoint at a point in a program."""
def breakpoint(*, backend: Optional[str] = None, filter_frames: bool = True,
num_frames: Optional[int] = None, ordered: bool = False,
**kwargs): # pylint: disable=redefined-builtin
"""Enters a breakpoint at a point in a program.
Args:
backend: The debugger backend to use. By default, picks the highest priority
debugger and in the absence of other registered debuggers, falls back to
the CLI debugger.
filter_frames: Whether or not to filter out JAX-internal stack frames from
the traceback. Since some libraries, like Flax, also make user of JAX's
stack frame filtering system, this option can also affect whether stack
frames from libraries are filtered.
num_frames: The number of frames above the current stack frame to make
available for inspection in the interactive debugger.
ordered: A keyword only argument used to indicate whether or not the
staged out computation will enforce ordering of this ``debug_print``
with respect to other ordered ``debug_print`` calls.
Returns:
None.
"""
frame_infos = inspect.stack()
# Filter out internal frames
frame_infos = [
frame_info for frame_info in frame_infos
if traceback_util.include_frame(frame_info.frame)
]
frames = [
DebuggerFrame.from_frameinfo(frame_info) for frame_info in frame_infos
]
# Throw out first frame corresponding to this function
frames = frames[1:]
frame_infos = frame_infos[1:]
if num_frames is not None:
frame_infos = frame_infos[:num_frames]
# Filter out internal frames
if filter_frames:
frames = [
DebuggerFrame.from_frameinfo(frame_info)
for frame_info in frame_infos
if traceback_util.include_frame(frame_info.frame)
]
else:
frames = [
DebuggerFrame.from_frameinfo(frame_info)
for frame_info in frame_infos
]
flat_args, frames_tree = tree_util.tree_flatten(frames)

def _breakpoint_callback(*flat_args):
Expand Down
68 changes: 64 additions & 4 deletions tests/debugger_test.py
Expand Up @@ -59,6 +59,8 @@ def tearDownModule():
if jaxlib.version < (0, 3, 15):
disabled_backends.append("tpu")

foo = 2

class CliDebuggerTest(jtu.JaxTestCase):

@jtu.skip_on_devices(*disabled_backends)
Expand Down Expand Up @@ -321,8 +323,6 @@ def g(x):
\(jdb\) """.format(re.escape(repr(arr))))
g(jnp.arange(8, dtype=jnp.int32))
jax.effects_barrier()
print(stdout.getvalue())
print(expected)
self.assertRegex(stdout.getvalue(), expected)

@jtu.skip_on_devices(*disabled_backends)
Expand All @@ -344,10 +344,70 @@ def f(x):
\(jdb\) """)
f(2.)
jax.effects_barrier()
print(stdout.getvalue())
print(expected)
self.assertRegex(stdout.getvalue(), expected)


@jtu.skip_on_devices(*disabled_backends)
def test_debugger_accesses_globals(self):
stdin, stdout = make_fake_stdin_stdout(["p foo", "c"])

@jax.jit
def g():
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")

expected = _format_multiline(r"""
Entering jdb:
\(jdb\) \*\*\* NameError: name 'foo' is not defined
\(jdb\) """)
g()
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)

@jtu.skip_on_devices(*disabled_backends)
def test_can_limit_num_frames(self):
stdin, stdout = make_fake_stdin_stdout(["u", "p x", "c"])

def g():
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli",
num_frames=2)

@jax.jit
def f():
x = 2
g()
return x

_ = f()
expected = _format_multiline(r"""
Entering jdb:
\(jdb\) .*
.*
.*
.*
.*
.*
.*
\(jdb\) 2
\(jdb\) """)
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)

stdin, stdout = make_fake_stdin_stdout(["u", "u", "c"])

def g2():
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli",
num_frames=2)

@jax.jit
def f2():
x = 2
g2()
return x

expected = ".*At topmost frame.*"
_ = f2()
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)

if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit c34aa39

Please sign in to comment.