Skip to content

Commit

Permalink
factor AOT types out to a stages module
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Mar 15, 2022
1 parent 5354a01 commit 0474884
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 173 deletions.
3 changes: 1 addition & 2 deletions jax/__init__.py
Expand Up @@ -62,7 +62,6 @@
checkpoint as checkpoint,
checkpoint_policies as checkpoint_policies,
closure_convert as closure_convert,
Compiled as Compiled,
curry, # TODO(phawkins): update users to avoid this.
custom_ivjp as custom_ivjp,
custom_gradient as custom_gradient,
Expand Down Expand Up @@ -92,7 +91,6 @@
jvp as jvp,
local_device_count as local_device_count,
local_devices as local_devices,
Lowered as Lowered,
linearize as linearize,
linear_transpose as linear_transpose,
make_jaxpr as make_jaxpr,
Expand Down Expand Up @@ -139,6 +137,7 @@
from jax import ops as ops
from jax import profiler as profiler
from jax import random as random
from jax import stages as stages
from jax import tree_util as tree_util
from jax import util as util

Expand Down
179 changes: 13 additions & 166 deletions jax/_src/api.py
Expand Up @@ -47,21 +47,22 @@
tree_multimap, treedef_is_leaf, treedef_children,
Partial, PyTreeDef, all_leaves, treedef_tuple)

from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import stages
from jax._src import traceback_util
from jax._src.api_util import (
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, _ensure_str_tuple, argnames_partial_except)
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
Expand Down Expand Up @@ -493,161 +494,6 @@ def get_device_info():
return f_jitted


class Lowered:
"""Lowering of a function specialized to argument types and values.
A lowering is a computation ready for compilation. This class
carries a lowering together with the remaining information needed to
later compile and execute it. It also provides a common API for
querying properties of lowered computations across JAX's various
lowering paths (``jit``, ``pmap``, etc.).
"""
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering",
"_no_kwargs"
]

# The PyTreeDef of the (positional arguments, keyword arguments).
#
# To get the individual PyTreeDef for the positional an keyword arguments,
# use `in_tree.children() which will return you a sequence of 2 PyTreeDef.
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_lowering: Union[dispatch.XlaComputation,
pxla.MeshComputation,
pxla.PmapComputation]
_no_kwargs: bool

def __init__(self,
lowering,
in_tree: PyTreeDef,
in_avals,
out_tree: PyTreeDef,
donate_argnums: Tuple[int],
no_kwargs: bool = False):
"""Initializer.
Args:
in_tree: The `PyTreeDef` of (args, kwargs).
out_tree: The `PyTreeDef` of the outputs.
no_kwargs: If `True` the transformation, and the `Compiled` returned from
this object will not support keyword arguments (an error will be raised
if some are provided).
"""
self._lowering = lowering
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs

def compile(self) -> 'Compiled':
return Compiled(
self._lowering.compile(), self.in_tree, self.in_avals,
self.out_tree, self.donate_argnums, self._no_kwargs)

def compiler_ir(self, dialect: Optional[str] = None):
if dialect is None or dialect == "mhlo":
return self._lowering.mhlo()
elif dialect == "hlo":
return self._lowering.hlo()
else:
raise ValueError(f"Unknown dialect {dialect}")

# TODO(frostig): remove this in favor of `compiler_ir`
def _xla_computation(self):
return self._lowering.hlo()


class Compiled:
"""Compiled representation of a function specialized to types/values.
A compiled computation is associated with an executable and the
remaining information needed to execute it. It also provides a
common API for querying properties of compiled computations across
JAX's various compilation paths and backends.
"""
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_executable",
"_no_kwargs"
]


# The PyTreeDef of the (positional arguments, keyword arguments).
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_executable: Union[dispatch.XlaCompiledComputation,
pxla.MeshExecutable,
pxla.PmapExecutable]
_no_kwargs: bool

def __init__(self, executable, in_tree, in_avals, out_tree, donate_argnums,
no_kwargs=False):
self._executable = executable
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs

def compiler_ir(self):
"""Post-compilation IR.
Compilation typically involves code transformation and
optimization. This method exists to reflect the compiler's
representation of the program after such passes, whenever
possible.
"""
return self._executable.xla_executable.hlo_modules()

def runtime_executable(self):
return self._executable.xla_executable

def _xla_executable(self):
# TODO(frostig): finalize API. For now, return the underlying
# executable directly via this method.
return self._executable.xla_executable

def __call__(self, *args, **kwargs):
if self._no_kwargs and kwargs:
kws = ', '.join(kwargs.keys())
raise NotImplementedError(
'function was compiled by a transformation that does not support '
f"keyword arguments, but called with keyword arguments: {kws}")
args_flat, in_tree = tree_flatten((args, kwargs))
if in_tree != self.in_tree:
# TODO(frostig): provide more info about the source function
# and transformation
raise TypeError(
f'function compiled for {self.in_tree}, called with {in_tree}')
try:
out_flat = self._executable.call(*args_flat)
except TypeError as e:
# We can't transform ahead-of-time compiled calls, since we've
# lowered and compiled for a fixed function signature, and JAX
# transformations change signatures. We interpret a Tracer
# argument as an indication of a transformation attempt. We
# could check this before the executable call, but we'd rather
# avoid isinstance checks on the call path. Seeing a TypeError
# might mean that arguments have JAX-invalid types, which in
# turn might mean some are Tracers.
for arg in args_flat:
if isinstance(arg, core.Tracer):
raise TypeError(
'Cannot apply JAX transformations to a function lowered and '
'compiled for a particular signature. Detected argument of '
f'Tracer type {type(arg)}.')
else:
raise
return tree_unflatten(self.out_tree, out_flat)


def _jit_lower(fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline):
"""Make a ``lower`` method for jitted functions."""
Expand All @@ -664,7 +510,7 @@ def arg_spec(x):
return aval, None

@api_boundary
def lower(*args, **kwargs) -> Lowered:
def lower(*args, **kwargs) -> stages.Lowered:
"""Lower this function for the given arguments.
A lowered function is staged out of Python and translated to a
Expand All @@ -687,8 +533,8 @@ def lower(*args, **kwargs) -> Lowered:
computation = dispatch.lower_xla_callable(flat_fun, device, backend, name,
donated_invars,
*arg_specs_and_device)
return Lowered(computation, in_tree, in_tree.unflatten(arg_specs),
out_tree(), donate_argnums)
return stages.Lowered(computation, in_tree, in_tree.unflatten(arg_specs),
out_tree(), donate_argnums)

return lower

Expand Down Expand Up @@ -2182,7 +2028,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
# this might naturally be a method, with ``fun`` as a ``self`` and
# all the other arguments stored as attributes.
@api_boundary
def lower(*args, **kwargs) -> Lowered:
def lower(*args, **kwargs) -> stages.Lowered:
"""Lower a parallel-mapped form of this function for the given arguments.
A parallel-mapped and lowered function is staged out of Python and
Expand All @@ -2208,8 +2054,9 @@ def lower(*args, **kwargs) -> Lowered:
donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
avals=abstract_args)
return Lowered(computation, p.in_tree, p.in_tree.unflatten(abstract_args),
p.out_tree(), donate_tuple)
return stages.Lowered(
computation, p.in_tree, p.in_tree.unflatten(abstract_args),
p.out_tree(), donate_tuple)

return lower

Expand Down

0 comments on commit 0474884

Please sign in to comment.