Skip to content

Commit

Permalink
Try flattening the locals/globals dict in the debugger and have a
Browse files Browse the repository at this point in the history
fallback if it fails
  • Loading branch information
sharadmv committed Aug 9, 2022
1 parent 88636d2 commit 18f164f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 10 deletions.
61 changes: 52 additions & 9 deletions jax/_src/debugger/core.py
Expand Up @@ -14,11 +14,10 @@
from __future__ import annotations

import dataclasses
import functools
import inspect
import threading

from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Hashable, List, Optional, Tuple
from typing_extensions import Protocol

import jax.numpy as jnp
Expand All @@ -29,6 +28,46 @@
from jax._src import util
import numpy as np


@tree_util.register_pytree_node_class
class _DictWrapper:
keys: list[Hashable]
values: list[Any]

def __init__(self, keys, values):
self._keys = keys
self._values = values

def to_dict(self):
return dict(zip(self._keys, self._values))

def tree_flatten(self):
return self._values, self._keys

@classmethod
def tree_unflatten(cls, keys, values):
return _DictWrapper(keys, values)


class _CantFlatten:
__repr__ = lambda _: "<cant_flatten>"
cant_flatten = _CantFlatten()

def _safe_flatten_dict(dct: dict[Any, Any]
) -> tuple[list[Any], tree_util.PyTreeDef]:
# We avoid comparison between keys by just using the original order
keys, values = [], []
for key, value in dct.items():
try:
tree_util.tree_leaves(value)
except:
# If flattening fails, we substitute a sentinel object.
value = cant_flatten
keys.append(key)
values.append(value)
return tree_util.tree_flatten(_DictWrapper(keys, values))


@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class DebuggerFrame:
Expand All @@ -42,22 +81,26 @@ class DebuggerFrame:
offset: Optional[int]

def tree_flatten(self):
flat_vars, vars_tree = tree_util.tree_flatten((self.locals, self.globals))
flat_locals, locals_tree = _safe_flatten_dict(self.locals)
flat_globals, globals_tree = _safe_flatten_dict(self.globals)
flat_vars = flat_locals + flat_globals
is_valid = [
isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray))
for l in flat_vars
]
invalid_vars, valid_vars = util.partition_list(is_valid, flat_vars)
return valid_vars, (is_valid, invalid_vars, vars_tree, self.filename,
self.code_context, self.source, self.lineno,
self.offset)
return valid_vars, (is_valid, invalid_vars, locals_tree, globals_tree,
len(flat_locals), self.filename, self.code_context,
self.source, self.lineno, self.offset)

@classmethod
def tree_unflatten(cls, info, valid_vars):
(is_valid, invalid_vars, vars_tree, filename, code_context, source,
lineno, offset) = info
(is_valid, invalid_vars, locals_tree, globals_tree, num_locals, filename,
code_context, source, lineno, offset) = info
flat_vars = util.merge_lists(is_valid, invalid_vars, valid_vars)
locals_, globals_ = tree_util.tree_unflatten(vars_tree, flat_vars)
flat_locals, flat_globals = util.split_list(flat_vars, [num_locals])
locals_ = tree_util.tree_unflatten(locals_tree, flat_locals).to_dict()
globals_ = tree_util.tree_unflatten(globals_tree, flat_globals).to_dict()
return DebuggerFrame(filename, locals_, globals_, code_context, source,
lineno, offset)

Expand Down
23 changes: 22 additions & 1 deletion tests/debugger_test.py
Expand Up @@ -376,7 +376,6 @@ def f():
x = 2
g()
return x

_ = f()
expected = _format_multiline(r"""
Entering jdb:
Expand Down Expand Up @@ -409,5 +408,27 @@ def f2():
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)

def test_can_handle_dictionaries_with_unsortable_keys(self):
stdin, stdout = make_fake_stdin_stdout(["p x", "p weird_dict",
"p weirder_dict", "c"])

@jax.jit
def f():
weird_dict = {(lambda x: x): 2., (lambda x: x * 2): 3}
weirder_dict = {(lambda x: x): weird_dict}
x = 2.
debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
del weirder_dict
return x
expected = _format_multiline(r"""
Entering jdb:
\(jdb\) 2.0
\(jdb\) <cant_flatten>
\(jdb\) <cant_flatten>
\(jdb\) """)
_ = f()
jax.effects_barrier()
self.assertRegex(stdout.getvalue(), expected)

if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 18f164f

Please sign in to comment.