From 4b4dc3c745e2d43f966d974dad601c28bddddce8 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Fri, 18 Mar 2022 21:35:55 -0700 Subject: [PATCH] track input argument information in one tree at each AOT stage Both `Lowered` and `Compiled` carry information about input arguments for which the underlying computation was lowered (namely avals, donation bits, and the input pytree structure today). This change rearranges some internals so that all of this information is held together in a single pytree of structs. Doing so simplifies the fields of both stage classes and helps ensure the input argument properties are consistent with one another (e.g. now they must share a consistent pytree structure by definition). --- jax/_src/api.py | 9 +-- jax/_src/stages.py | 155 +++++++++++++++++++++++---------------- jax/experimental/maps.py | 4 +- jax/experimental/pjit.py | 5 +- 4 files changed, 102 insertions(+), 71 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 694345e11963..9c913e6eb264 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -533,8 +533,8 @@ def lower(*args, **kwargs) -> stages.Lowered: computation = dispatch.lower_xla_callable(flat_fun, device, backend, name, donated_invars, *arg_specs_and_device) - return stages.Lowered(computation, in_tree, in_tree.unflatten(arg_specs), - out_tree(), donate_argnums) + return stages.Lowered.from_flat_info( + computation, in_tree, arg_specs, donate_argnums, out_tree()) return lower @@ -2059,9 +2059,8 @@ def lower(*args, **kwargs) -> stages.Lowered: donated_invars=p.donated_invars, global_arg_shapes=p.global_arg_shapes_flat, avals=abstract_args) - return stages.Lowered( - computation, p.in_tree, p.in_tree.unflatten(abstract_args), - p.out_tree(), donate_tuple) + return stages.Lowered.from_flat_info( + computation, p.in_tree, abstract_args, donate_tuple, p.out_tree()) return lower diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 7f955f4a812c..bd2e468ab278 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Any, Optional, Tuple, Union from typing_extensions import Protocol from jax import core +from jax import tree_util from jax.interpreters import pxla -from jax.tree_util import PyTreeDef, tree_flatten, tree_unflatten from jax._src import dispatch from jax._src import source_info_util @@ -33,7 +34,51 @@ zip, unsafe_zip = util.safe_zip, zip -class Compiled: +Computation = Union[dispatch.XlaComputation, + pxla.MeshComputation, + pxla.PmapComputation] + +Executable = Union[dispatch.XlaCompiledComputation, + pxla.MeshExecutable, + pxla.PmapExecutable] + + +@dataclass +class ArgInfo: + aval: core.ShapedArray + donated: bool + + +class Stage: + args_info: Any # PyTree of ArgInfo + + @property + def in_tree(self): + """``PyTreeDef`` of the pair (positional arguments, keyword arguments).""" + return tree_util.tree_structure(self.args_info) + + @property + def in_avals(self): + """Tree of input avals.""" + return tree_util.tree_map(lambda x: x.aval, self.args_info) + + @property + def donate_argnums(self): + """Flat tuple of donated argument indices.""" + return tuple([ + i for i, x in enumerate(tree_util.tree_leaves(self.args_info)) + if x.donated]) + + +def make_args_info(in_tree, in_avals, donate_argnums): + donate_argnums = frozenset(donate_argnums) + flat_avals, _ = tree_util.tree_flatten(in_avals) # todo: remove + return in_tree.unflatten([ + ArgInfo(aval, i in donate_argnums) + for i, aval in enumerate(flat_avals)]) + + +class Compiled(Stage): """Compiled representation of a function specialized to types/values. A compiled computation is associated with an executable and the @@ -41,30 +86,18 @@ class Compiled: common API for querying properties of compiled computations across JAX's various compilation paths and backends. """ - __slots__ = [ - "in_tree", "in_avals", "out_tree", "donate_argnums", "_executable", - "_no_kwargs" - ] - - # The PyTreeDef of the (positional arguments, keyword arguments). - in_tree: PyTreeDef - # The nested input tree of `ShapedArray` abstract values of (args, kwargs). - in_avals: Any - out_tree: PyTreeDef - donate_argnums: Tuple[int] - _executable: Union[dispatch.XlaCompiledComputation, - pxla.MeshExecutable, - pxla.PmapExecutable] + __slots__ = ["args_info", "out_tree", "_executable", "_no_kwargs"] + + args_info: Any # PyTree of ArgInfo + out_tree: tree_util.PyTreeDef + _executable: Executable _no_kwargs: bool - def __init__(self, executable, in_tree, in_avals, out_tree, donate_argnums, - no_kwargs=False): + def __init__(self, executable, args_info, out_tree, no_kwargs=False): self._executable = executable - self.in_tree = in_tree - self.in_avals = in_avals - self.out_tree = out_tree - self.donate_argnums = donate_argnums self._no_kwargs = no_kwargs + self.args_info = args_info + self.out_tree = out_tree def compiler_ir(self): """Post-compilation IR. @@ -90,7 +123,7 @@ def __call__(self, *args, **kwargs): raise NotImplementedError( "function was compiled by a transformation that does not support " f"keyword arguments, but called with keyword arguments: {kws}") - args_flat, in_tree = tree_flatten((args, kwargs)) + args_flat, in_tree = tree_util.tree_flatten((args, kwargs)) if in_tree != self.in_tree: # TODO(frostig): provide more info about the source function # and transformation @@ -115,10 +148,10 @@ def __call__(self, *args, **kwargs): f"Tracer type {type(arg)}.") else: raise - return tree_unflatten(self.out_tree, out_flat) + return tree_util.tree_unflatten(self.out_tree, out_flat) -class Lowered: +class Lowered(Stage): """Lowering of a function specialized to argument types and values. A lowering is a computation ready for compilation. This class @@ -127,52 +160,50 @@ class Lowered: querying properties of lowered computations across JAX's various lowering paths (``jit``, ``pmap``, etc.). """ - __slots__ = [ - "in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering", - "_no_kwargs" - ] - - # The PyTreeDef of the (positional arguments, keyword arguments). - # - # To get the individual PyTreeDef for the positional an keyword arguments, - # use `in_tree.children() which will return you a sequence of 2 PyTreeDef. - in_tree: PyTreeDef - # The nested input tree of `ShapedArray` abstract values of (args, kwargs). - in_avals: Any - out_tree: PyTreeDef - donate_argnums: Tuple[int] - _lowering: Union[dispatch.XlaComputation, - pxla.MeshComputation, - pxla.PmapComputation] + __slots__ = ["args_info", "out_tree", "_lowering", "_no_kwargs"] + + args_info: Any # PyTree of ArgInfo + out_tree: tree_util.PyTreeDef + _lowering: Computation _no_kwargs: bool - def __init__(self, - lowering, - in_tree: PyTreeDef, - in_avals, - out_tree: PyTreeDef, - donate_argnums: Tuple[int], - no_kwargs: bool = False): - """Initializer. + @staticmethod + def from_flat_info(lowering: Computation, + in_tree: tree_util.PyTreeDef, + in_avals, + donate_argnums: Tuple[int], + out_tree: tree_util.PyTreeDef, + no_kwargs: bool = False): + """Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef. Args: - in_tree: The `PyTreeDef` of (args, kwargs). - out_tree: The `PyTreeDef` of the outputs. - no_kwargs: If `True` the transformation, and the `Compiled` returned from - this object will not support keyword arguments (an error will be raised - if some are provided). + in_tree: The ``PyTreeDef`` of (args, kwargs). + out_tree: The ``PyTreeDef`` of the outputs. + no_kwargs: If ``True`` the transformation, and the + ``Compiled`` returned from this object will not support keyword + arguments (an error will be raised if some are provided). """ + return Lowered(lowering, make_args_info(in_tree, in_avals, donate_argnums), + out_tree, no_kwargs=no_kwargs) + donate_argnums = frozenset(donate_argnums) + flat_avals, _ = tree_util.tree_flatten(in_avals) # todo: remove + return in_tree.unflatten([ + ArgInfo(aval, i in donate_argnums) + for i, aval in enumerate(flat_avals)]) + + def __init__(self, + lowering: Computation, + args_info, # PyTreee of ArgInfo + out_tree: tree_util.PyTreeDef, + no_kwargs: bool = False): self._lowering = lowering - self.in_tree = in_tree - self.in_avals = in_avals - self.out_tree = out_tree - self.donate_argnums = donate_argnums self._no_kwargs = no_kwargs + self.args_info = args_info + self.out_tree = out_tree def compile(self) -> Compiled: - return Compiled( - self._lowering.compile(), self.in_tree, self.in_avals, - self.out_tree, self.donate_argnums, self._no_kwargs) + return Compiled(self._lowering.compile(), self.args_info, + self.out_tree, no_kwargs=self._no_kwargs) def compiler_ir(self, dialect: Optional[str] = None): if dialect is None or dialect == "mhlo": diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 098a4c498177..22bae4540f36 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -670,8 +670,8 @@ def lower(*args): in_tree = treedef_tuple([in_tree, tree_flatten({})[1]]) in_avals = in_tree.unflatten(avals_flat) - return stages.Lowered( - computation, in_tree, in_avals, out_tree(), donate_argnums, + return stages.Lowered.from_flat_info( + computation, in_tree, in_avals, donate_argnums, out_tree(), no_kwargs=True) fun_mapped.lower = lower diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index d65aed262e37..33472a0aaee1 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -283,8 +283,9 @@ def lower(*args, **kwargs): args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]]) local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals) - return stages.Lowered(lowering, args_kwargs_in_tree, local_in_avals, - out_tree, donate_argnums, no_kwargs=True) + return stages.Lowered.from_flat_info( + lowering, args_kwargs_in_tree, local_in_avals, donate_argnums, out_tree, + no_kwargs=True) wrapped.lower = lower return wrapped