Skip to content

Commit

Permalink
improve pjit in/out_axis_resources pytree errors
Browse files Browse the repository at this point in the history
This is an application of the utilities in #9372.
  • Loading branch information
mattjj committed Feb 9, 2022
1 parent fbda1a6 commit d57990e
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 47 deletions.
58 changes: 41 additions & 17 deletions jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.

import collections
import difflib
import functools
from functools import partial
import operator as op
from typing import (Any, Callable, Hashable, Iterable, Optional, Tuple, List,
Dict, Type, TypeVar, overload, TYPE_CHECKING, NamedTuple)
import textwrap

from jax._src.lib import pytree

Expand Down Expand Up @@ -380,9 +382,10 @@ def flatten_one_level(pytree: Any) -> Tuple[List[Any], Hashable]:
else:
raise ValueError(f"can't tree-flatten type: {type(pytree)}")

def prefix_errors(prefix_tree: Any, full_tree: Any
def prefix_errors(prefix_tree: Any, full_tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> List[Callable[[str], ValueError]]:
return list(_prefix_error(KeyPath(()), prefix_tree, full_tree))
return list(_prefix_error(KeyPath(()), prefix_tree, full_tree, is_leaf))

class KeyPathEntry(NamedTuple):
key: Any
Expand Down Expand Up @@ -437,38 +440,59 @@ def register_keypaths(ty: Type, handler: Callable[[Any], List[KeyPathEntry]]
register_keypaths(dict,
lambda dct: [GetitemKeyPathEntry(k) for k in sorted(dct)])

def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any
def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Iterable[Callable[[str], ValueError]]:
# A leaf is a valid prefix of any tree:
if treedef_is_leaf(tree_structure(prefix_tree)): return
if treedef_is_leaf(tree_structure(prefix_tree, is_leaf=is_leaf)): return

# The subtrees may disagree because their roots are of different types:
if type(prefix_tree) != type(full_tree):
yield lambda name: ValueError(
"pytree structure error: different types "
f"at {{name}}{key_path.pprint()}: "
f"prefix pytree {{name}} has type {type(prefix_tree)} "
f"where full pytree has type {type(full_tree)}.".format(name=name))
"pytree structure error: different types at key path\n"
f" {{name}}{key_path.pprint()}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
f" {type(prefix_tree)}\n"
f"but at the same key path the full pytree has a subtree of different type\n"
f" {type(full_tree)}.".format(name=name))
return # don't look for more errors in this subtree

# Or they may disagree if their roots have different numbers of children:
# Or they may disagree if their roots have different numbers of children (note
# that because both prefix_tree and full_tree have the same type at this
# point, and because prefix_tree is not a leaf, each can be flattened once):
prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree)
full_tree_children, full_tree_meta = flatten_one_level(full_tree)
if len(prefix_tree_children) != len(full_tree_children):
yield lambda name: ValueError(
"pytree structure error: different numbers of pytree children "
f"at {{name}}{key_path.pprint()}: "
f"prefix pytree {{name}} has {len(prefix_tree_children)} children where "
f"full pytree has {len(full_tree_children)} children.".format(name=name))
"pytree structure error: different numbers of pytree children at key path\n"
f" {{name}}{key_path.pprint()}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
f" {type(prefix_tree)}\n"
f"with {len(prefix_tree_children)} children, "
f"but at the same key path the full pytree has a subtree of the same "
f"type but with {len(full_tree_children)} children.".format(name=name))
return # don't look for more errors in this subtree

# Or they may disagree if their roots have different pytree metadata:
if prefix_tree_meta != full_tree_meta:
prefix_tree_meta_str = str(prefix_tree_meta)
full_tree_meta_str = str(full_tree_meta)
metadata_diff = textwrap.indent(
'\n'.join(difflib.ndiff(prefix_tree_meta_str.splitlines(),
full_tree_meta_str.splitlines())),
prefix=" ")
yield lambda name: ValueError(
"pytree structure error: different pytree metadata "
f"at {{name}}{key_path.pprint()}: "
f"prefix pytree {{name}} has metadata {prefix_tree_meta} where "
f"full pytree has metadata {full_tree_meta}.".format(name=name))
"pytree structure error: different pytree metadata at key path\n"
f" {{name}}{key_path.pprint()}\n"
f"At that key path, the prefix pytree {{name}} has a subtree of type\n"
f" {type(prefix_tree)}\n"
f"with metadata\n"
f" {prefix_tree_meta_str}\n"
f"but at the same key path the full pytree has a subtree of the same "
f"type but with metadata\n"
f" {full_tree_meta_str}\n"
f"so the diff in the metadata at these pytree nodes is\n"
f"{metadata_diff}".format(name=name))
return # don't look for more errors in this subtree

# If the root types and numbers of children agree, there must be an error
Expand Down
64 changes: 57 additions & 7 deletions jax/experimental/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
from jax.interpreters import partial_eval as pe
from jax.interpreters.sharded_jit import PartitionSpec
from jax._src.lib import xla_client as xc
from jax.tree_util import tree_map, tree_flatten, tree_unflatten
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
treedef_is_leaf, tree_structure)
from jax._src.tree_util import prefix_errors
from jax._src.util import (extend_name_stack, HashableFunction, safe_zip,
wrap_name, wraps, distributed_debug_log,
split_list, cache, tuple_insert)
Expand Down Expand Up @@ -293,11 +295,60 @@ def flatten_axis_resources(what, tree, axis_resources, tupled_args):
try:
return tuple(flatten_axes(what, tree, axis_resources, tupled_args=tupled_args))
except ValueError:
pass
pass # Raise a tree prefix error below

# Tree leaves are always valid prefixes, so if there was a prefix error as
# assumed here, axis_resources must not be a leaf.
assert not treedef_is_leaf(tree_structure(axis_resources))

# Check the type directly rather than using isinstance because of namedtuples.
if tupled_args and (type(axis_resources) is not tuple or
len(axis_resources) != len(tree.children())):
# We know axis_resources is meant to be a tuple corresponding to the args
# tuple, but while it is a non-leaf pytree, either it wasn't a tuple or it
# wasn't the right length.
msg = (f"{what} specification must be a tree prefix of the positional "
f"arguments tuple passed to the `pjit`-decorated function. In "
f"particular, {what} must either be a None, a PartitionSpec, or "
f"a tuple of length equal to the number of positional arguments.")
# If `tree` represents an args tuple, then `axis_resources` must be a tuple.
# TODO(mattjj,apaszke): disable implicit list casts, remove 'or list' below
if type(axis_resources) is not tuple:
msg += f" But {what} is not a tuple: got {type(axis_resources)} instead."
elif len(axis_resources) != len(tree.children()):
msg += (f" But {what} is the wrong length: got a tuple or list of length "
f"{len(axis_resources)} for an args tuple of length "
f"{len(tree.children())}.")

# As an extra hint, let's check if the user just forgot to wrap
# in_axis_resources in a singleton tuple.
if len(tree.children()) == 1:
try: flatten_axes(what, tree, (axis_resources,))
except ValueError: pass # That's not the issue.
else:
msg += (f" Given the corresponding argument being "
f"passed, it looks like {what} might need to be wrapped in "
f"a singleton tuple.")

raise ValueError(msg)

# Replace axis_resources with unparsed versions to avoid revealing internal details
flatten_axes(what, tree, tree_map(lambda parsed: parsed.user_spec, axis_resources),
tupled_args=tupled_args)
raise AssertionError("Please open a bug request!") # This should be unreachable
axis_tree = tree_map(lambda parsed: parsed.user_spec, axis_resources)

# Because ecause we only have the `tree` treedef and not the full pytree here,
# we construct a dummy tree to compare against. Revise this in callers?
dummy_tree = tree_unflatten(tree, [PytreeLeaf()] * tree.num_leaves)
errors = prefix_errors(axis_tree, dummy_tree)
if errors:
e = errors[0] # Only show information about the first disagreement found.
raise e(what)

# At this point we've failed to find a tree prefix error.
assert False, "Please open a bug report!" # This should be unreachable.

class PytreeLeaf:
def __repr__(self): return "pytree leaf"


@lu.cache
def _pjit_jaxpr(fun, mesh, local_in_avals,
Expand Down Expand Up @@ -472,8 +523,7 @@ def __repr__(self):
def _prepare_axis_resources(axis_resources,
arg_name,
allow_unconstrained_dims=False):
# PyTrees don't treat None values as leaves, so we explicitly need
# to explicitly declare them as such
# PyTrees don't treat None values as leaves, so we use an is_leaf function.
entries, treedef = tree_flatten(axis_resources, is_leaf=lambda x: x is None)
what = f"{arg_name} leaf specifications"
entries = [
Expand Down
3 changes: 3 additions & 0 deletions jax/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,7 @@
treedef_children as treedef_children,
treedef_is_leaf as treedef_is_leaf,
treedef_tuple as treedef_tuple,
register_keypaths as register_keypaths,
AttributeKeyPathEntry as AttributeKeyPathEntry,
GetitemKeyPathEntry as GetitemKeyPathEntry,
)
41 changes: 32 additions & 9 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,15 +1138,32 @@ def testEmptyMesh(self):
def testAxisResourcesMismatch(self):
x = jnp.ones([])
p = [None, None, None]

pjit(lambda x: x, (p,), p)([x, x, x]) # OK

error = re.escape(
r"pjit in_axis_resources specification must be a tree prefix of the "
r"corresponding value, got specification (None, None, None) for value "
r"tree PyTreeDef((*, *)). Note that pjit in_axis_resources that are "
r"non-trivial pytrees should always be wrapped in a tuple representing "
r"the argument list.")
"pjit in_axis_resources specification must be a tree prefix of the "
"positional arguments tuple passed to the `pjit`-decorated function. "
"In particular, pjit in_axis_resources must either be a None, a "
"PartitionSpec, or a tuple of length equal to the number of positional "
"arguments. But pjit in_axis_resources is the wrong length: got a "
"tuple or list of length 3 for an args tuple of length 2.")
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x, y: x, p, p)(x, x) # Error, but make sure we hint at tupling
pjit(lambda x, y: x, p, p)(x, x)

Foo = namedtuple('Foo', ['x'])
error = "in_axis_resources is not a tuple.*might need to be wrapped"
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x, Foo(None), Foo(None))(Foo(x))

pjit(lambda x: x, (Foo(None),), Foo(None))(Foo(x)) # OK w/ singleton tuple

# TODO(apaszke,mattjj): Disable implicit list casts and enable this
# error = ("it looks like pjit in_axis_resources might need to be wrapped in "
# "a singleton tuple.")
# with self.assertRaisesRegex(ValueError, error):
# pjit(lambda x, y: x, p, p)([x, x, x])

# TODO(apaszke): Disable implicit list casts and enable this
# error = re.escape(
# r"pjit in_axis_resources specification must be a tree prefix of the "
Expand All @@ -1158,10 +1175,16 @@ def testAxisResourcesMismatch(self):
# r"singleton tuple.")
# with self.assertRaisesRegex(ValueError, error):
# pjit(lambda x: x, p, p)([x, x, x]) # Error, but make sure we hint at singleton tuple

error = re.escape(
r"pjit out_axis_resources specification must be a tree prefix of the "
r"corresponding value, got specification [[None, None, None], None] for "
r"value tree PyTreeDef([*, *, *]).")
"pytree structure error: different numbers of pytree children at "
"key path\n"
" pjit out_axis_resources tree root\n"
"At that key path, the prefix pytree pjit out_axis_resources has a "
"subtree of type\n"
" <class 'list'>\n"
"with 2 children, but at the same key path the full pytree has a "
"subtree of the same type but with 3 children.")
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x, (p,), [p, None])([x, x, x]) # Error, we raise a generic tree mismatch message

Expand Down
40 changes: 26 additions & 14 deletions tests/tree_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,79 +439,91 @@ class TreePrefixErrorsTest(jtu.JaxTestCase):

def test_different_types(self):
e, = prefix_errors((1, 2), [1, 2])
expected = "pytree structure error: different types at in_axes tree root"
expected = ("pytree structure error: different types at key path\n"
" in_axes tree root")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')

def test_different_types_nested(self):
e, = prefix_errors(((1,), (2,)), ([3], (4,)))
expected = r"pytree structure error: different types at in_axes\[0\]"
expected = ("pytree structure error: different types at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')

def test_different_types_multiple(self):
e1, e2 = prefix_errors(((1,), (2,)), ([3], [4]))
expected = r"pytree structure error: different types at in_axes\[0\]"
expected = ("pytree structure error: different types at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e1('in_axes')
expected = r"pytree structure error: different types at in_axes\[1\]"
expected = ("pytree structure error: different types at key path\n"
r" in_axes\[1\]")
with self.assertRaisesRegex(ValueError, expected):
raise e2('in_axes')

def test_different_num_children(self):
e, = prefix_errors((1,), (2, 3))
expected = ("pytree structure error: different numbers of pytree children "
"at in_axes tree root")
"at key path\n"
" in_axes tree root")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')

def test_different_num_children_nested(self):
e, = prefix_errors([[1]], [[2, 3]])
expected = ("pytree structure error: different numbers of pytree children "
r"at in_axes\[0\]")
"at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')

def test_different_num_children_multiple(self):
e1, e2 = prefix_errors([[1], [2]], [[3, 4], [5, 6]])
expected = ("pytree structure error: different numbers of pytree children "
r"at in_axes\[0\]")
"at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e1('in_axes')
expected = ("pytree structure error: different numbers of pytree children "
r"at in_axes\[1\]")
"at key path\n"
r" in_axes\[1\]")
with self.assertRaisesRegex(ValueError, expected):
raise e2('in_axes')

def test_different_metadata(self):
e, = prefix_errors({1: 2}, {3: 4})
expected = ("pytree structure error: different pytree metadata "
"at in_axes tree root")
"at key path\n"
" in_axes tree root")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')

def test_different_metadata_nested(self):
e, = prefix_errors([{1: 2}], [{3: 4}])
expected = ("pytree structure error: different pytree metadata "
r"at in_axes\[0\]")
"at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')

def test_different_metadata_multiple(self):
e1, e2 = prefix_errors([{1: 2}, {3: 4}], [{3: 4}, {5: 6}])
expected = ("pytree structure error: different pytree metadata "
r"at in_axes\[0\]")
"at key path\n"
r" in_axes\[0\]")
with self.assertRaisesRegex(ValueError, expected):
raise e1('in_axes')
expected = ("pytree structure error: different pytree metadata "
r"at in_axes\[1\]")
"at key path\n"
r" in_axes\[1\]")
with self.assertRaisesRegex(ValueError, expected):
raise e2('in_axes')

def test_fallback_keypath(self):
e, = prefix_errors(Special(1, [2]), Special(3, 4))
expected = ("pytree structure error: different types at "
r"in_axes\[<flat index 1>\]")
expected = ("pytree structure error: different types at key path\n"
r" in_axes\[<flat index 1>\]")
with self.assertRaisesRegex(ValueError, expected):
raise e('in_axes')

Expand Down

0 comments on commit d57990e

Please sign in to comment.