New lowering APIs for pjit
This is the first in a series of refactoring patches that add the new ahead-of-time (AOT)
lowering APIs to all JIT-like transforms in JAX. I'm sending this early because I expect
it will come in handy when adding reverse-mode AD support for pjit.

PiperOrigin-RevId: 395510449
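
A minimal usage sketch of the AOT flow this patch adds for pjit, assuming the patch is applied. The method names (lower, compile, unsafe_call) are taken from the diff below; the single-axis mesh and replicated sharding are purely illustrative.

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import maps
from jax.experimental.pjit import pjit

devices = np.asarray(jax.devices()).reshape(-1, 1)
with maps.mesh(devices, ('x', 'y')):
  f = pjit(lambda a: a @ a.T, in_axis_resources=None, out_axis_resources=None)
  x = jnp.ones((8, 8))
  lowered = f.lower(x)            # new in this patch: trace and lower, but don't compile yet
  executable = lowered.compile()  # MeshComputation -> MeshExecutable (memoized)
  out = executable(x)             # __call__ forwards to unsafe_call; validation is still a TODO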
apaszke authored and jax authors committed Sep 8, 2021
1 parent e869e5e commit 5b4757d
Showing 3 changed files with 129 additions and 95 deletions.
31 changes: 16 additions & 15 deletions jax/experimental/maps.py
@@ -560,7 +560,7 @@ def normalize_resource(r) -> ResourceAxisName:
has_input_rank_assertions = any(spec.expected_rank is not None for spec in in_axes_entries)
has_output_rank_assertions = any(spec.expected_rank is not None for spec in out_axes_entries)

def fun_mapped(*args):
def infer_params(*args):
# Putting this outside of fun_mapped would make resources lexically scoped
resource_env = thread_resources.env
available_resources = set(resource_env.shape.keys())
@@ -609,8 +609,7 @@ def fun_mapped(*args):
raise ValueError(f"xmap argument has an in_axes specification of {spec.user_repr}, "
f"which asserts that it should be of rank {spec.expected_rank}, "
f"but the argument has rank {arg.ndim} (and shape {arg.shape})")
out_flat = xmap_p.bind(
fun_flat, *args_flat,
params = dict(
name=getattr(fun, '__name__', '<unnamed function>'),
in_axes=tuple(in_axes_flat),
out_axes_thunk=out_axes_thunk,
@@ -621,14 +620,22 @@
backend=backend,
spmd_in_axes=None,
spmd_out_axes_thunk=None)
return fun_flat, args_flat, params, out_tree

def verify_outputs(out_flat, out_tree, params):
if has_output_rank_assertions:
for out, spec in zip(out_flat, out_axes_thunk()):
for out, spec in zip(out_flat, params['out_axes_thunk']()):
if spec.expected_rank is not None and spec.expected_rank != out.ndim:
raise ValueError(f"xmap output has an out_axes specification of {spec.user_repr}, "
f"which asserts that it should be of rank {spec.expected_rank}, "
f"but the output has rank {out.ndim} (and shape {out.shape})")
return tree_unflatten(out_tree(), out_flat)

def fun_mapped(*args):
fun_flat, args_flat, params, out_tree = infer_params(*args)
out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
return verify_outputs(out_flat, out_tree, params)

# Decorate fun_mapped
for loop_params in reversed(anon_serial_loops):
fun_mapped = serial_loop(*loop_params)(fun_mapped)
@@ -689,17 +696,11 @@ def make_xmap_callable(fun: lu.WrappedFun,
if used_mesh_axes:
assert spmd_in_axes is None and spmd_out_axes_thunk is None # No outer xmaps, so should be None
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
return pxla.mesh_callable(f,
name,
backend,
resource_env.physical_mesh,
mesh_in_axes,
mesh_out_axes,
donated_invars,
use_spmd_lowering,
*in_avals,
tile_by_mesh_axes=True,
do_resource_typecheck=None)
return pxla.lower_mesh_computation(
f, name, resource_env.physical_mesh,
mesh_in_axes, mesh_out_axes, donated_invars,
use_spmd_lowering, in_avals,
tile_by_mesh_axes=True, do_resource_typecheck=None).compile().unsafe_call
else:
return xla._xla_callable(f, None, backend, name, donated_invars,
*((a, None) for a in in_avals))
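The restructuring above (and the analogous change to pjit below) follows one pattern: parameter inference is split out of the call wrapper so that the eager bind path and an AOT lowering path can share it. A self-contained sketch of that pattern, with hypothetical names standing in for the real flattening and primitive machinery:

from typing import Any, Callable, Dict, Tuple

def make_wrapper(fun: Callable) -> Callable:
  def infer_params(*args) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    # Stand-in for flattening arguments and building the primitive's params.
    params = dict(name=getattr(fun, '__name__', '<unnamed function>'))
    return args, params

  def wrapped(*args):
    args_flat, params = infer_params(*args)
    # Eager path: in the real code this is xmap_p.bind / pjit_p.bind.
    return fun(*args_flat)

  def lower(*args):
    # AOT path: reuse the same inference, hand params to a lowering rule instead.
    args_flat, params = infer_params(*args)
    return params

  wrapped.lower = lower
  return wrapped

f = make_wrapper(lambda x, y: x + y)
assert f(1, 2) == 3
assert f.lower(1, 2)['name'] == '<lambda>'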
45 changes: 27 additions & 18 deletions jax/experimental/pjit.py
@@ -178,10 +178,9 @@ def pjit(fun: Callable,
donate_argnums = _ensure_index_tuple(donate_argnums)
donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)

@wraps(fun)
def wrapped(*args, **kwargs):
def infer_params(*args, **kwargs):
if kwargs:
raise NotImplementedError("pjit over kwargs not yet supported")
raise NotImplementedError("pjit does not support kwargs")
if max(static_argnums + donate_argnums, default=-1) >= len(args):
raise ValueError(f"jitted function has static_argnums={static_argnums}, "
f"donate_argnums={donate_argnums} but "
@@ -214,16 +213,28 @@ def wrapped(*args, **kwargs):
_pjit_jaxpr(flat_fun, mesh, local_in_avals,
in_tree, hashable_pytree(in_axis_resources),
HashableFunction(out_tree, closure=()), hashable_pytree(out_axis_resources))

out = pjit_p.bind(
*args_flat,
params = dict(
jaxpr=jaxpr,
in_axis_resources=in_axis_resources_flat,
out_axis_resources=out_axis_resources_flat,
resource_env=resource_env,
donated_invars=donated_invars,
name=flat_fun.__name__)
return tree_unflatten(out_tree(), out)
return args_flat, params, out_tree()

@wraps(fun)
def wrapped(*args, **kwargs):
args_flat, params, out_tree = infer_params(*args, **kwargs)
out = pjit_p.bind(*args_flat, **params)
return tree_unflatten(out_tree, out)

def lower(*args, **kwargs):
args_flat, params, out_tree = infer_params(*args, **kwargs)
return _pjit_lower(
params['jaxpr'], params['in_axis_resources'],
params['out_axis_resources'], params['resource_env'],
params['donated_invars'], params['name'])
wrapped.lower = lower

return wrapped

@@ -399,36 +410,34 @@ def _check_shapes_against_resources(what: str, is_global_shape: bool, mesh_shape
def _pjit_call_impl(*args, jaxpr,
in_axis_resources, out_axis_resources,
resource_env, donated_invars, name):
compiled = pjit_callable(
compiled = _pjit_lower(
jaxpr, in_axis_resources, out_axis_resources,
resource_env, donated_invars, name)
resource_env, donated_invars, name).compile()
distributed_debug_log(("Running pjit'd function", name),
("mesh", resource_env.physical_mesh))
return compiled(*args)
return compiled.unsafe_call(*args)
pjit_p.def_impl(_pjit_call_impl)

@cache()
def pjit_callable(
def _pjit_lower(
jaxpr: core.ClosedJaxpr,
in_axis_resources: Tuple[ParsedPartitionSpec, ...],
out_axis_resources: Tuple[ParsedPartitionSpec, ...],
resource_env,
donated_invars,
name: str):

in_axes = [get_array_mapping(axes) for axes in in_axis_resources]
out_axes = [get_array_mapping(axes) for axes in out_axis_resources]
f = core.jaxpr_as_fun(jaxpr)
f.__name__ = name
fun = lu.wrap_init(f)
local_in_avals = global_to_local(resource_env.physical_mesh,
jaxpr.in_avals, in_axis_resources)
# TODO(skye): allow for using a submesh of physical_mesh
return pxla.mesh_callable(fun, name, None, resource_env.physical_mesh,
in_axes, out_axes, donated_invars,
True, *local_in_avals, tile_by_mesh_axes=False,
do_resource_typecheck="pjit")

return pxla.lower_mesh_computation(
fun, name, resource_env.physical_mesh,
in_axes, out_axes, donated_invars,
True, local_in_avals, tile_by_mesh_axes=False,
do_resource_typecheck="pjit")


def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env, **_):
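Note the caching structure in the pjit changes above: _pjit_lower keeps the @cache() decorator that pjit_callable had, and (as the pxla.py diff below shows) the object it returns memoizes its own compile() result, so repeated lower/compile round trips stay cheap. A self-contained sketch of that two-level memoization, with hypothetical names and functools.lru_cache standing in for JAX's cache():

import functools

class Lowered:
  """Stand-in for MeshComputation: holds the lowered artifact, compiles lazily."""

  def __init__(self, hlo: str):
    self.hlo = hlo
    self._executable = None

  def compile(self):
    if self._executable is None:                 # compile at most once per lowering
      self._executable = ('compiled', self.hlo)  # stand-in for XLA compilation
    return self._executable

@functools.lru_cache()                           # stand-in for jax's cache() on _pjit_lower
def lower(name: str) -> Lowered:
  return Lowered(f'hlo<{name}>')

assert lower('f') is lower('f')                      # the lowering is shared across calls
assert lower('f').compile() is lower('f').compile()  # and so is the compiled executable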
148 changes: 86 additions & 62 deletions jax/interpreters/pxla.py
@@ -36,7 +36,6 @@
import threading
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional,
Sequence, Set, Tuple, Type, Union, Iterable)

from absl import logging
import numpy as np

@@ -1529,17 +1528,20 @@ def vtile_by_mesh(fun: lu.WrappedFun,
main_type=SPMDBatchTrace)
return fun

def mesh_callable(fun: lu.WrappedFun,
transformed_name: str,
backend_name: Optional[str],
mesh: Mesh,
in_axes: Sequence[ArrayMapping],
out_axes: Union[Sequence[ArrayMapping], Callable[[], Sequence[ArrayMapping]]],
donated_invars: Sequence[bool],
spmd_lowering: bool,
*local_in_untiled_avals,
tile_by_mesh_axes: bool,
do_resource_typecheck: Optional[str]):
def lower_mesh_computation(
fun: lu.WrappedFun,
transformed_name: str,
mesh: Mesh,
in_axes: Sequence[ArrayMapping],
out_axes: Union[Sequence[ArrayMapping], Callable[[], Sequence[ArrayMapping]]],
donated_invars: Sequence[bool],
spmd_lowering: bool,
local_in_untiled_avals: Sequence[core.ShapedArray],
tile_by_mesh_axes: bool,
do_resource_typecheck: Optional[str]):
assert not mesh.empty
backend = xb.get_device_backend(mesh.devices.flat[0])

local_mesh = mesh.local_mesh
global_axis_sizes = mesh.shape
local_axis_sizes = local_mesh.shape
@@ -1615,72 +1617,94 @@ def mesh_callable(fun: lu.WrappedFun,
donated_invars=donated_invars)
with core.extend_axis_env_nd(mesh.shape.items()):
out_nodes = xla.jaxpr_subcomp(
c, jaxpr, backend_name, axis_env, xla_consts,
c, jaxpr, backend.platform, axis_env, xla_consts,
extend_name_stack(wrap_name(transformed_name, 'xmap')), *xla_args)
if backend_name is None:
backend = xb.get_device_backend(mesh.devices.flat[0])
else:
backend = xb.get_backend(backend_name)
if spmd_lowering:
out_partitions_t = xb.tuple_sharding_proto(out_partitions)
out_tuple = xb.with_sharding_proto(c, out_partitions_t, xops.Tuple, c, out_nodes)
else:
out_tuple = xops.Tuple(c, out_nodes)

if backend.platform in ("gpu", "tpu"):
xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
# TODO: Warn about unused donations?

built = c.Build(out_tuple)
return MeshComputation(
built, mesh, local_in_untiled_avals,
local_out_untiled_avals, in_axes, out_axes,
spmd_lowering, tuple_args)

return compile_and_wrap_mesh_hlo(built, backend, mesh, local_in_untiled_avals,
local_out_untiled_avals, in_axes, out_axes,
spmd_lowering, tuple_args)

class MeshComputation:
def __init__(self, hlo, *compile_args):
self._executable = None
self.hlo = hlo
self.compile_args = compile_args

def compile_and_wrap_mesh_hlo(computation: xc.XlaComputation, backend,
mesh: Mesh,
local_in_untiled_avals: Sequence[ShapedArray],
local_out_untiled_avals: Sequence[ShapedArray],
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
spmd_lowering: bool, tuple_args: bool):
local_mesh = mesh.local_mesh
local_axis_sizes = local_mesh.shape
if spmd_lowering:
num_replicas, num_partitions = 1, mesh.size
num_local_replicas, num_local_partitions = 1, local_mesh.size
else:
num_replicas, num_partitions = mesh.size, 1
num_local_replicas, num_local_partitions = local_mesh.size, 1
device_assignment = mesh.device_ids.reshape((num_replicas, num_partitions))
compile_options = xb.get_compile_options(
num_replicas=num_replicas,
num_partitions=num_partitions,
device_assignment=device_assignment,
use_spmd_partitioning=spmd_lowering,
)
compile_options.parameter_is_tupled_arguments = tuple_args
def compile(self):
if self._executable is None:
self._executable = MeshExecutable(self.hlo, *self.compile_args)
return self._executable

local_sharding_spec = mesh_sharding_specs(local_axis_sizes, mesh.axis_names)
local_input_specs = [local_sharding_spec(aval, aval_in_axes)
if aval is not core.abstract_unit else None
for aval, aval_in_axes in safe_zip(local_in_untiled_avals, in_axes)]
input_indices = [spec_to_indices(aval.shape, spec)
if spec is not None else None
for aval, spec in safe_zip(local_in_untiled_avals, local_input_specs)]

local_output_specs = [local_sharding_spec(aval, aval_out_axes)
for aval, aval_out_axes in safe_zip(local_out_untiled_avals, out_axes)]
handle_outs = avals_to_results_handler(num_local_replicas, num_local_partitions,
local_output_specs, local_out_untiled_avals)
class MeshExecutable:
def __init__(self,
computation: xc.XlaComputation,
mesh: Mesh,
local_in_untiled_avals: Sequence[ShapedArray],
local_out_untiled_avals: Sequence[ShapedArray],
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
spmd_lowering: bool, tuple_args: bool):
assert not mesh.empty
backend = xb.get_device_backend(mesh.devices.flat[0])

local_mesh = mesh.local_mesh
local_axis_sizes = local_mesh.shape
if spmd_lowering:
num_replicas, num_partitions = 1, mesh.size
num_local_replicas, num_local_partitions = 1, local_mesh.size
else:
num_replicas, num_partitions = mesh.size, 1
num_local_replicas, num_local_partitions = local_mesh.size, 1
device_assignment = mesh.device_ids.reshape((num_replicas, num_partitions))
compile_options = xb.get_compile_options(
num_replicas=num_replicas,
num_partitions=num_partitions,
device_assignment=device_assignment,
use_spmd_partitioning=spmd_lowering,
)
compile_options.parameter_is_tupled_arguments = tuple_args

local_sharding_spec = mesh_sharding_specs(local_axis_sizes, mesh.axis_names)
local_input_specs = [local_sharding_spec(aval, aval_in_axes)
if aval is not core.abstract_unit else None
for aval, aval_in_axes in safe_zip(local_in_untiled_avals, in_axes)]
input_indices = [spec_to_indices(aval.shape, spec)
if spec is not None else None
for aval, spec in safe_zip(local_in_untiled_avals, local_input_specs)]

local_output_specs = [local_sharding_spec(aval, aval_out_axes)
for aval, aval_out_axes in safe_zip(local_out_untiled_avals, out_axes)]
handle_outs = avals_to_results_handler(num_local_replicas, num_local_partitions,
local_output_specs, local_out_untiled_avals)

if hasattr(backend, "compile_replicated"):
self.unsafe_call = backend.compile_replicated(
computation, compile_options,
input_indices, local_input_specs,
handle_outs)
else:
compiled = xla.compile_or_get_cached(backend, computation, compile_options)
handle_args = InputsHandler(compiled.local_devices(), local_input_specs,
input_indices)
self.unsafe_call = partial(execute_replicated, compiled, backend, handle_args, handle_outs)

def __call__(self, *args):
# TODO(apaszke): Validate arguments
return self.unsafe_call(*args)

if hasattr(backend, "compile_replicated"):
return backend.compile_replicated(computation, compile_options,
input_indices, local_input_specs,
handle_outs)
compiled = xla.compile_or_get_cached(backend, computation, compile_options)
handle_args = InputsHandler(compiled.local_devices(), local_input_specs,
input_indices)
return partial(execute_replicated, compiled, backend, handle_args, handle_outs)

_forbidden_primitives = {
'xla_pmap': 'pmap',
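MeshExecutable above separates a checked entry point (__call__, where argument validation is still a TODO) from unsafe_call, which trusted internal callers such as _pjit_call_impl and make_xmap_callable invoke directly. A minimal sketch of that split, with illustrative names and an example of the kind of check __call__ is meant to grow:

class CheckedExecutable:
  def __init__(self, unsafe_call, expected_arity: int):
    self.unsafe_call = unsafe_call        # assumes arguments were already validated
    self._expected_arity = expected_arity

  def __call__(self, *args):
    # Illustrative validation; the real class leaves this as a TODO for now.
    if len(args) != self._expected_arity:
      raise TypeError(f'expected {self._expected_arity} arguments, got {len(args)}')
    return self.unsafe_call(*args)

exe = CheckedExecutable(lambda x, y: x + y, expected_arity=2)
assert exe(1, 2) == 3              # checked path
assert exe.unsafe_call(1, 2) == 3  # fast path that skips the check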
