Skip to content

Commit

Permalink
track input argument information in one tree at each AOT stage
Browse files Browse the repository at this point in the history
Both `Lowered` and `Compiled` carry information about input arguments
for which the underlying computation was lowered (namely avals,
donation bits, and the input pytree structure today). This change
rearranges some internals so that all of this information is held
together in a single pytree of structs. Doing so simplifies the fields
of both stage classes and helps ensure the input argument properties
are consistent with one another (e.g. now they must share a consistent
pytree structure by definition).
  • Loading branch information
froystig committed Mar 20, 2022
1 parent e9f59ae commit 4b4dc3c
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 71 deletions.
9 changes: 4 additions & 5 deletions jax/_src/api.py
Expand Up @@ -533,8 +533,8 @@ def lower(*args, **kwargs) -> stages.Lowered:
computation = dispatch.lower_xla_callable(flat_fun, device, backend, name,
donated_invars,
*arg_specs_and_device)
return stages.Lowered(computation, in_tree, in_tree.unflatten(arg_specs),
out_tree(), donate_argnums)
return stages.Lowered.from_flat_info(
computation, in_tree, arg_specs, donate_argnums, out_tree())

return lower

Expand Down Expand Up @@ -2059,9 +2059,8 @@ def lower(*args, **kwargs) -> stages.Lowered:
donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
avals=abstract_args)
return stages.Lowered(
computation, p.in_tree, p.in_tree.unflatten(abstract_args),
p.out_tree(), donate_tuple)
return stages.Lowered.from_flat_info(
computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())

return lower

Expand Down
155 changes: 93 additions & 62 deletions jax/_src/stages.py
Expand Up @@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
from typing_extensions import Protocol

from jax import core
from jax import tree_util
from jax.interpreters import pxla
from jax.tree_util import PyTreeDef, tree_flatten, tree_unflatten

from jax._src import dispatch
from jax._src import source_info_util
Expand All @@ -33,38 +34,70 @@
zip, unsafe_zip = util.safe_zip, zip


class Compiled:
Computation = Union[dispatch.XlaComputation,
pxla.MeshComputation,
pxla.PmapComputation]

Executable = Union[dispatch.XlaCompiledComputation,
pxla.MeshExecutable,
pxla.PmapExecutable]


@dataclass
class ArgInfo:
aval: core.ShapedArray
donated: bool


class Stage:
args_info: Any # PyTree of ArgInfo

@property
def in_tree(self):
"""``PyTreeDef`` of the pair (positional arguments, keyword arguments)."""
return tree_util.tree_structure(self.args_info)

@property
def in_avals(self):
"""Tree of input avals."""
return tree_util.tree_map(lambda x: x.aval, self.args_info)

@property
def donate_argnums(self):
"""Flat tuple of donated argument indices."""
return tuple([
i for i, x in enumerate(tree_util.tree_leaves(self.args_info))
if x.donated])


def make_args_info(in_tree, in_avals, donate_argnums):
donate_argnums = frozenset(donate_argnums)
flat_avals, _ = tree_util.tree_flatten(in_avals) # todo: remove
return in_tree.unflatten([
ArgInfo(aval, i in donate_argnums)
for i, aval in enumerate(flat_avals)])


class Compiled(Stage):
"""Compiled representation of a function specialized to types/values.
A compiled computation is associated with an executable and the
remaining information needed to execute it. It also provides a
common API for querying properties of compiled computations across
JAX's various compilation paths and backends.
"""
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_executable",
"_no_kwargs"
]

# The PyTreeDef of the (positional arguments, keyword arguments).
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_executable: Union[dispatch.XlaCompiledComputation,
pxla.MeshExecutable,
pxla.PmapExecutable]
__slots__ = ["args_info", "out_tree", "_executable", "_no_kwargs"]

args_info: Any # PyTree of ArgInfo
out_tree: tree_util.PyTreeDef
_executable: Executable
_no_kwargs: bool

def __init__(self, executable, in_tree, in_avals, out_tree, donate_argnums,
no_kwargs=False):
def __init__(self, executable, args_info, out_tree, no_kwargs=False):
self._executable = executable
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs
self.args_info = args_info
self.out_tree = out_tree

def compiler_ir(self):
"""Post-compilation IR.
Expand All @@ -90,7 +123,7 @@ def __call__(self, *args, **kwargs):
raise NotImplementedError(
"function was compiled by a transformation that does not support "
f"keyword arguments, but called with keyword arguments: {kws}")
args_flat, in_tree = tree_flatten((args, kwargs))
args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
if in_tree != self.in_tree:
# TODO(frostig): provide more info about the source function
# and transformation
Expand All @@ -115,10 +148,10 @@ def __call__(self, *args, **kwargs):
f"Tracer type {type(arg)}.")
else:
raise
return tree_unflatten(self.out_tree, out_flat)
return tree_util.tree_unflatten(self.out_tree, out_flat)


class Lowered:
class Lowered(Stage):
"""Lowering of a function specialized to argument types and values.
A lowering is a computation ready for compilation. This class
Expand All @@ -127,52 +160,50 @@ class Lowered:
querying properties of lowered computations across JAX's various
lowering paths (``jit``, ``pmap``, etc.).
"""
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering",
"_no_kwargs"
]

# The PyTreeDef of the (positional arguments, keyword arguments).
#
# To get the individual PyTreeDef for the positional an keyword arguments,
# use `in_tree.children() which will return you a sequence of 2 PyTreeDef.
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_lowering: Union[dispatch.XlaComputation,
pxla.MeshComputation,
pxla.PmapComputation]
__slots__ = ["args_info", "out_tree", "_lowering", "_no_kwargs"]

args_info: Any # PyTree of ArgInfo
out_tree: tree_util.PyTreeDef
_lowering: Computation
_no_kwargs: bool

def __init__(self,
lowering,
in_tree: PyTreeDef,
in_avals,
out_tree: PyTreeDef,
donate_argnums: Tuple[int],
no_kwargs: bool = False):
"""Initializer.
@staticmethod
def from_flat_info(lowering: Computation,
in_tree: tree_util.PyTreeDef,
in_avals,
donate_argnums: Tuple[int],
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False):
"""Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.
Args:
in_tree: The `PyTreeDef` of (args, kwargs).
out_tree: The `PyTreeDef` of the outputs.
no_kwargs: If `True` the transformation, and the `Compiled` returned from
this object will not support keyword arguments (an error will be raised
if some are provided).
in_tree: The ``PyTreeDef`` of (args, kwargs).
out_tree: The ``PyTreeDef`` of the outputs.
no_kwargs: If ``True`` the transformation, and the
``Compiled`` returned from this object will not support keyword
arguments (an error will be raised if some are provided).
"""
return Lowered(lowering, make_args_info(in_tree, in_avals, donate_argnums),
out_tree, no_kwargs=no_kwargs)
donate_argnums = frozenset(donate_argnums)
flat_avals, _ = tree_util.tree_flatten(in_avals) # todo: remove
return in_tree.unflatten([
ArgInfo(aval, i in donate_argnums)
for i, aval in enumerate(flat_avals)])

def __init__(self,
lowering: Computation,
args_info, # PyTreee of ArgInfo
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False):
self._lowering = lowering
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs
self.args_info = args_info
self.out_tree = out_tree

def compile(self) -> Compiled:
return Compiled(
self._lowering.compile(), self.in_tree, self.in_avals,
self.out_tree, self.donate_argnums, self._no_kwargs)
return Compiled(self._lowering.compile(), self.args_info,
self.out_tree, no_kwargs=self._no_kwargs)

def compiler_ir(self, dialect: Optional[str] = None):
if dialect is None or dialect == "mhlo":
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/maps.py
Expand Up @@ -670,8 +670,8 @@ def lower(*args):

in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
in_avals = in_tree.unflatten(avals_flat)
return stages.Lowered(
computation, in_tree, in_avals, out_tree(), donate_argnums,
return stages.Lowered.from_flat_info(
computation, in_tree, in_avals, donate_argnums, out_tree(),
no_kwargs=True)

fun_mapped.lower = lower
Expand Down
5 changes: 3 additions & 2 deletions jax/experimental/pjit.py
Expand Up @@ -283,8 +283,9 @@ def lower(*args, **kwargs):

args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
return stages.Lowered(lowering, args_kwargs_in_tree, local_in_avals,
out_tree, donate_argnums, no_kwargs=True)
return stages.Lowered.from_flat_info(
lowering, args_kwargs_in_tree, local_in_avals, donate_argnums, out_tree,
no_kwargs=True)

wrapped.lower = lower
return wrapped
Expand Down

0 comments on commit 4b4dc3c

Please sign in to comment.