make xla.aval_to_result_handler return number of args
jakevdp committed Sep 18, 2020
1 parent 5792b09 commit 10eafcf
Showing 2 changed files with 54 additions and 32 deletions.
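
Note on the new contract: aval_to_result_handler (and every entry in the
xla_result_handlers table) now returns a (num_buffers, handler) pair instead of
a bare handler, so callers can partition a flat list of output buffers among
handlers whose avals may span several buffers. A minimal sketch of the pair,
using stand-in values rather than real JAX avals or XLA buffers:

def toy_result_handler(aval):
  # Stand-in for aval_to_result_handler: one buffer per ordinary array output.
  return (1, lambda buf: ("DeviceArray", buf))

nouts, handler = toy_result_handler(aval=None)
assert nouts == 1
print(handler("buf0"))  # ('DeviceArray', 'buf0')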
84 changes: 53 additions & 31 deletions jax/interpreters/xla.py
@@ -74,7 +74,8 @@ def identity(x): return x
 _scalar_types = dtypes.python_scalar_dtypes.keys()
 
 # unit representation
-def _make_unit(c): return xb.constant(c, np.zeros((), dtype=np.dtype('bool')))
+def _make_unit(c):
+  return xb.constant(c, np.zeros((), dtype=np.dtype('bool')))
 def _make_abstract_unit(_):
   return (xc.Shape.array_shape(np.dtype('bool'), ()),)
 def _device_put_unit(_, device):
@@ -100,17 +101,17 @@ def aval_to_xla_shapes(aval):
   ConcreteArray: _make_array_shape,
 }
 
-def aval_to_result_handler(device: Optional[Device], aval: core.ShapedArray):
+def aval_to_result_handler(device: Optional[Device], aval: core.ShapedArray) -> Tuple[int, Callable]:
   try:
     return xla_result_handlers[type(aval)](device, aval)
   except KeyError as err:
     raise TypeError(f"No xla_result_handler for type: {type(aval)}") from err
 
 def array_result_handler(device: Optional[Device], aval: core.ShapedArray):
-  return partial(DeviceArray, raise_to_shaped(aval), device, lazy.array(aval.shape))
+  return (1, partial(DeviceArray, raise_to_shaped(aval), device, lazy.array(aval.shape)))
 
-xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = {
-  core.AbstractUnit: lambda _, __: lambda _: core.unit,
+xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Tuple[int, Callable]]] = {
+  core.AbstractUnit: lambda _, __: (1, lambda _: core.unit),
   ShapedArray: array_result_handler,
   ConcreteArray: array_result_handler,
 }
@@ -129,7 +130,9 @@ def _device_put_array(x, device: Optional[Device]):
 def _device_put_scalar(x, device):
   return _device_put_array(dtypes.coerce_to_array(x), device)
 
-device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = {core.Unit: _device_put_unit}
+device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = {
+  core.Unit: _device_put_unit
+}
 device_put_handlers.update((t, _device_put_array) for t in array_types)
 device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
 
@@ -225,6 +228,13 @@ def apply_primitive(prim, *args, **params):
   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
   return compiled_fun(*args)
 
+
+def _partition_outputs(nouts, outs):
+  assert sum(nouts) == len(outs), "Internal error: sum(nouts) should equal len(outs)."
+  outs = iter(outs)
+  return [[next(outs) for _ in range(nout)] for nout in nouts]
+
+
 @cache()
 def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
                                                    Optional[Device]], **params):
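
_partition_outputs is the piece that consumes those counts: it splits a flat
buffer list into one group per output. A self-contained run with toy values
(the buffer strings are placeholders, not XLA buffers):

def _partition_outputs(nouts, outs):
  assert sum(nouts) == len(outs), "Internal error: sum(nouts) should equal len(outs)."
  outs = iter(outs)
  return [[next(outs) for _ in range(nout)] for nout in nouts]

# Three outputs consuming 1, 2, and 1 buffers respectively:
print(_partition_outputs([1, 2, 1], ["b0", "b1", "b2", "b3"]))
# [['b0'], ['b1', 'b2'], ['b3']]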
@@ -240,10 +250,12 @@ def prim_fun(*args):
                                   *arg_specs)
   aval_out = prim.abstract_eval(*avals, **params)
   if not prim.multiple_results:
-    handle_result = aval_to_result_handler(device, aval_out)
+    nouts, handle_result = aval_to_result_handler(device, aval_out)
+    assert nouts == 1, "Internal error: expected nouts == 1"
   else:
-    handlers = map(partial(aval_to_result_handler, device), aval_out)
-    handle_result = lambda xs: tuple(h(x) for h, x in zip(handlers, xs))
+    nouts, handlers = unzip2(map(partial(aval_to_result_handler, device), aval_out))
+    handle_result = lambda *bufses: \
+        tuple(h(*bufs) for h, bufs in zip(handlers, _partition_outputs(nouts, bufses)))
   tuple_args = len(avals) > 100
   if prim in initial_style_translations:
     nreps = initial_style_primitive_replicas(params)
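
unzip2 here is the existing jax.util helper that transposes a sequence of
pairs into a pair of tuples; an equivalent pure-Python rendition for
reference:

def unzip2(pairs):
  xs, ys = [], []
  for x, y in pairs:
    xs.append(x)
    ys.append(y)
  return tuple(xs), tuple(ys)

nouts, handlers = unzip2([(1, "handler_a"), (1, "handler_b")])
# nouts == (1, 1); handlers == ('handler_a', 'handler_b')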
@@ -327,21 +339,19 @@ def backend_compile(backend, built_c, options):
 
 def _execute_compiled_primitive(prim, compiled, result_handler, *args):
   device, = compiled.local_devices()
-  input_bufs = list(it.chain(
-      *(device_put(x, device) for x in args if x is not token)))
+  input_bufs = [buf for x in args for buf in device_put(x, device) if x is not token]
   out_bufs = compiled.execute(input_bufs)
   if FLAGS.jax_debug_nans:
     check_nans(prim, out_bufs)
-  return result_handler(out_bufs if prim.multiple_results else out_bufs[0])
+  return result_handler(*out_bufs)
 
 def _execute_replicated_primitive(prim, compiled, result_handler, *args):
   input_bufs = [
-      list(it.chain(*(device_put(x, device) for x in args if x is not token)))
+      [buf for x in args for buf in device_put(x, device) if x is not token]
       for device in compiled.local_devices()]
-  out_buf = compiled.execute_on_local_devices(input_bufs)[0]
-  if not prim.multiple_results:
-    out_buf, = out_buf
-  return result_handler(out_buf)
+  out_bufs = compiled.execute_on_local_devices(input_bufs)[0]
+  return result_handler(*out_bufs)
+
 
 def check_nans(prim, bufs):
   for buf in bufs:
@@ -608,15 +618,15 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *arg_specs):
   device = _xla_callable_device(nreps, backend, device, arg_devices)
   backend = device.platform if device else backend
   if config.omnistaging_enabled:
-    result_handlers = tuple(aval_to_result_handler(device, a) for a in out_avals)
+    nouts, result_handlers = unzip2(map(partial(aval_to_result_handler, device), out_avals))
   else:
-    result_handlers = tuple(map(partial(_pval_to_result_handler, device), pvals))  # type: ignore
+    nouts, result_handlers = unzip2(map(partial(_pval_to_result_handler, device), pvals))
 
   # Computations that only produce constants and/or only rearrange their inputs,
   # which are often produced from partial evaluation, don't need compilation,
   # and don't need to force their (potentially lazy) arguments.
   if not jaxpr.eqns:
-    return partial(_execute_trivial, jaxpr, device, consts, result_handlers)
+    return partial(_execute_trivial, jaxpr, device, consts, nouts, result_handlers)
 
   if not _on_exit:
     log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
@@ -666,9 +676,9 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *arg_specs):
   options.parameter_is_tupled_arguments = tuple_args
   compiled = backend_compile(backend, built, options)
   if nreps == 1:
-    return partial(_execute_compiled, compiled, result_handlers)
+    return partial(_execute_compiled, compiled, nouts, result_handlers)
   else:
-    return partial(_execute_replicated, compiled, result_handlers)
+    return partial(_execute_replicated, compiled, nouts, result_handlers)
 
 def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
   """Configures input/output "must" aliasing based on `donated_args`."""
@@ -758,23 +768,25 @@ def _xla_param(builder, param_num, xla_shape, replicated, partitions):
   else:
     return xb.with_sharding(builder, partitions, make_param)
 
-def _execute_compiled(compiled: XlaExecutable, handlers, *args):
+def _execute_compiled(compiled: XlaExecutable, nouts, handlers, *args):
   device, = compiled.local_devices()
   input_bufs = list(it.chain(
       *(device_put(x, device) for x in args if x is not token)))
   out_bufs = compiled.execute(input_bufs)
-  if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
-  return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
+  if FLAGS.jax_debug_nans:
+    check_nans(xla_call_p, out_bufs)
+  return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(nouts, out_bufs))]
 
-def _execute_replicated(compiled: XlaExecutable, handlers, *args):
+def _execute_replicated(compiled: XlaExecutable, nouts, handlers, *args):
   input_bufs = [
       list(it.chain(*(device_put(x, device) for x in args if x is not token)))
       for device in compiled.local_devices()]
   out_bufs = compiled.execute_on_local_devices(input_bufs)[0]
-  if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
-  return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
+  if FLAGS.jax_debug_nans:
+    check_nans(xla_call_p, out_bufs)
+  return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(nouts, out_bufs))]
 
-def _execute_trivial(jaxpr, device: Optional[Device], consts, handlers, *args):
+def _execute_trivial(jaxpr, device: Optional[Device], consts, nouts, handlers, *args):
   env = {core.unitvar: core.unit}
   map(env.setdefault, jaxpr.invars, args)
   map(env.setdefault, jaxpr.constvars, consts)
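
All of the executors now follow the same pattern: partition the flat out_bufs
by nouts, then hand each handler exactly its own buffers. A toy walkthrough of
that zip, including a hypothetical output backed by two buffers (dummy
handlers and strings, not JAX objects):

def _partition_outputs(nouts, outs):
  outs = iter(outs)
  return [[next(outs) for _ in range(nout)] for nout in nouts]

handlers = [lambda b: f"one<{b}>",
            lambda b0, b1: f"two<{b0},{b1}>"]  # pretend two-buffer output
nouts = [1, 2]
out_bufs = ["buf0", "buf1", "buf2"]
print([h(*bs) for h, bs in zip(handlers, _partition_outputs(nouts, out_bufs))])
# ['one<buf0>', 'two<buf1,buf2>']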
@@ -924,7 +936,7 @@ class Token(object): pass
 pytype_aval_mappings[Token] = lambda _: abstract_token
 core.pytype_aval_mappings[Token] = lambda _: abstract_token
 xla_shape_handlers[AbstractToken] = lambda _: (xc.Shape.token_shape(),)
-xla_result_handlers[AbstractToken] = lambda _, __: lambda _: token
+xla_result_handlers[AbstractToken] = lambda _, __: (1, lambda _: token)
 canonicalize_dtype_handlers[Token] = identity
 
 
@@ -1222,7 +1234,8 @@ def _device_put_impl(x, device: Optional[Device] = None):
   except TypeError as err:
     raise TypeError(
         f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
-  handler = aval_to_result_handler(device, a)  # type: ignore[arg-type]
+  nouts, handler = aval_to_result_handler(device, a)  # type: ignore[arg-type]
+  assert nouts == 1, "DeviceArray cannot handle aval with multiple buffers."
   return handler(*device_put(x, device))
 
 device_put_p = core.Primitive('device_put')
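
In _device_put_impl the aval is always a single array, so the returned count
must be 1 and the handler receives the one buffer that device_put yields.
Schematically, with stand-in objects in place of the real handler and buffer:

nouts, handler = (1, lambda buf: f"DeviceArray({buf})")  # stand-in pair
assert nouts == 1
bufs = ("buf",)          # a single-buffer device_put result, as a 1-tuple
print(handler(*bufs))    # DeviceArray(buf)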
@@ -1292,6 +1305,15 @@ def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
   return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
 parallel_translations[core.axis_index_p] = _axis_index_translation_rule  # type: ignore
 
+
+def _pval_to_result_handler(device, pval):
+  pv, const = pval
+  if pv is None:
+    const = _device_put_impl(const, device) if device else const
+    return (1, lambda _: const)
+  else:
+    return aval_to_result_handler(device, pv)
+
 pe.staged_out_calls.add(xla_call_p)
 
 @config.register_omnistaging_disabler
 def omnistaging_disabler() -> None:
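
For the pre-omnistaging path, a pval is an (abstract value, constant) pair
produced by partial evaluation; when the abstract value is None the output is
already a known constant, so the handler ignores its buffer. A stand-alone
rendition with toy values (the else branch is a stand-in for the real
aval_to_result_handler):

def pval_to_result_handler_sketch(device, pval):
  pv, const = pval
  if pv is None:
    return (1, lambda _: const)             # known constant: ignore the buffer
  else:
    return (1, lambda buf: ("array", buf))  # stand-in for aval_to_result_handler

nouts, handler = pval_to_result_handler_sketch(None, (None, 42))
print(handler("ignored-buffer"))  # 42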
2 changes: 1 addition & 1 deletion jax/lax/lax.py
@@ -1404,7 +1404,7 @@ def _device_put_raw(x):
     return x
   else:
     aval = raise_to_shaped(core.get_aval(x))
-    return xla.array_result_handler(None, aval)(*xla.device_put(x))
+    return xla.array_result_handler(None, aval)[1](*xla.device_put(x))
 
 def iota(dtype: DType, size: int) -> Array:
   """Wraps XLA's `Iota
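
The lax.py call site shows the migration the new pair imposes: code that
previously called the handler directly now indexes [1] to drop the buffer
count. A toy version of the pattern (stand-in handler, not the real one):

def array_result_handler_sketch(device, aval):
  return (1, lambda buf: ("DeviceArray", buf))  # stand-in (nouts, handler) pair

handler = array_result_handler_sketch(None, None)[1]  # [1] selects the handler
print(handler("buf"))  # ('DeviceArray', 'buf')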
