diff --git a/jax/_src/debugger/core.py b/jax/_src/debugger/core.py index 3f615cf12e45..5a8f0e109401 100644 --- a/jax/_src/debugger/core.py +++ b/jax/_src/debugger/core.py @@ -14,11 +14,10 @@ from __future__ import annotations import dataclasses -import functools import inspect import threading -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Hashable, List, Optional, Tuple from typing_extensions import Protocol import jax.numpy as jnp @@ -29,6 +28,46 @@ from jax._src import util import numpy as np + +@tree_util.register_pytree_node_class +class _DictWrapper: + keys: list[Hashable] + values: list[Any] + + def __init__(self, keys, values): + self._keys = keys + self._values = values + + def to_dict(self): + return dict(zip(self._keys, self._values)) + + def tree_flatten(self): + return self._values, self._keys + + @classmethod + def tree_unflatten(cls, keys, values): + return _DictWrapper(keys, values) + + +class _CantFlatten: + __repr__ = lambda _: "" +cant_flatten = _CantFlatten() + +def _safe_flatten_dict(dct: dict[Any, Any] + ) -> tuple[list[Any], tree_util.PyTreeDef]: + # We avoid comparison between keys by just using the original order + keys, values = [], [] + for key, value in dct.items(): + try: + tree_util.tree_leaves(value) + except: + # If flattening fails, we substitute a sentinel object. + value = cant_flatten + keys.append(key) + values.append(value) + return tree_util.tree_flatten(_DictWrapper(keys, values)) + + @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) class DebuggerFrame: @@ -42,22 +81,26 @@ class DebuggerFrame: offset: Optional[int] def tree_flatten(self): - flat_vars, vars_tree = tree_util.tree_flatten((self.locals, self.globals)) + flat_locals, locals_tree = _safe_flatten_dict(self.locals) + flat_globals, globals_tree = _safe_flatten_dict(self.globals) + flat_vars = flat_locals + flat_globals is_valid = [ isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray)) for l in flat_vars ] invalid_vars, valid_vars = util.partition_list(is_valid, flat_vars) - return valid_vars, (is_valid, invalid_vars, vars_tree, self.filename, - self.code_context, self.source, self.lineno, - self.offset) + return valid_vars, (is_valid, invalid_vars, locals_tree, globals_tree, + len(flat_locals), self.filename, self.code_context, + self.source, self.lineno, self.offset) @classmethod def tree_unflatten(cls, info, valid_vars): - (is_valid, invalid_vars, vars_tree, filename, code_context, source, - lineno, offset) = info + (is_valid, invalid_vars, locals_tree, globals_tree, num_locals, filename, + code_context, source, lineno, offset) = info flat_vars = util.merge_lists(is_valid, invalid_vars, valid_vars) - locals_, globals_ = tree_util.tree_unflatten(vars_tree, flat_vars) + flat_locals, flat_globals = util.split_list(flat_vars, [num_locals]) + locals_ = tree_util.tree_unflatten(locals_tree, flat_locals).to_dict() + globals_ = tree_util.tree_unflatten(globals_tree, flat_globals).to_dict() return DebuggerFrame(filename, locals_, globals_, code_context, source, lineno, offset) diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 462e502f7487..3065ee5b84f0 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -376,7 +376,6 @@ def f(): x = 2 g() return x - _ = f() expected = _format_multiline(r""" Entering jdb: @@ -409,5 +408,27 @@ def f2(): jax.effects_barrier() self.assertRegex(stdout.getvalue(), expected) + def test_can_handle_dictionaries_with_unsortable_keys(self): + stdin, stdout = make_fake_stdin_stdout(["p x", "p weird_dict", + "p weirder_dict", "c"]) + + @jax.jit + def f(): + weird_dict = {(lambda x: x): 2., (lambda x: x * 2): 3} + weirder_dict = {(lambda x: x): weird_dict} + x = 2. + debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli") + del weirder_dict + return x + expected = _format_multiline(r""" + Entering jdb: + \(jdb\) 2.0 + \(jdb\) + \(jdb\) + \(jdb\) """) + _ = f() + jax.effects_barrier() + self.assertRegex(stdout.getvalue(), expected) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())