Add support for in_axes=None (but not out_axes, or in_axes>0) to pmap (#2896)

* allow in_axes=None for pmap in api.py

* wire in_axes=None through parallel_callable

* add test

* fix error string

* fixes

* fixes

* add test for nested pmap with in_axes

* test pmap still defaults to (implicit) out_axes=0
jekbradbury committed May 1, 2020
1 parent 49a8901 commit 1cdd8f1
Showing 4 changed files with 154 additions and 83 deletions.
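For quick orientation before the diff, here is a minimal usage sketch of the behavior this commit enables. It mirrors the docstring example added in jax/api.py below and assumes at least two local XLA devices; the script itself is illustrative and not part of the commit.

import jax
import jax.numpy as jnp

# With in_axes=(0, None), y has no mapped axis: it is broadcast to every
# device rather than split along a leading axis (the new behavior).
x, y = jnp.arange(2.), 4.
sums, doubled = jax.pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y)
print(sums)     # [4. 5.]
print(doubled)  # [8. 8.]  -- outputs are always mapped over axis 0 (implicit out_axes=0)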
143 changes: 81 additions & 62 deletions jax/api.py
@@ -757,58 +757,59 @@ def vmap(fun: Callable, in_axes=0, out_axes=0) -> Callable:
"or a (nested) tuple of those types, got {} and {} respectively.")
raise TypeError(msg.format(type(in_axes), type(out_axes)))

def _get_axis_size(i:int, shape: Tuple[int, ...], axis: int):
try:
return shape[axis]
except (IndexError, TypeError) as e:
raise ValueError(f"vmap got arg {i} of rank {len(shape)} but axis to be mapped {axis}") from e

def _check_axis_sizes(tree, vals, dims):
mapped_axis_sizes = {_get_axis_size(i, onp.shape(x), d) for i, (x, d) in enumerate(zip(vals, dims))
if d is not None}
try:
size, = mapped_axis_sizes
except ValueError as e:
if not mapped_axis_sizes:
raise ValueError("vmap must have at least one non-None in_axes") from e
msg = "vmap got inconsistent sizes for array axes to be mapped:\n{}"
# we switch the error message based on whether args is a tuple of arrays,
# in which case we can produce an error message based on argument indices,
# or if it has nested containers.
# TODO(mattjj,phawkins): add a way to inspect pytree kind more directly
if tree == tree_flatten((core.unit,) * tree.num_leaves)[1]:
lines1 = ["arg {} has shape {} and axis {} is to be mapped"
.format(i, x.shape, d) for i, (x, d) in enumerate(zip(vals, dims))]
sizes = collections.defaultdict(list)
for i, (x, d) in enumerate(zip(vals, dims)):
if d is not None:
sizes[x.shape[d]].append(i)
lines2 = ["{} {} {} {} to be mapped of size {}".format(
"args" if len(idxs) > 1 else "arg",
", ".join(map(str, idxs)),
"have" if len(idxs) > 1 else "has",
"axes" if len(idxs) > 1 else "an axis",
size)
for size, idxs in sizes.items()]
raise ValueError(msg.format("\n".join(lines1 + ["so"] + lines2))) from e
else:
sizes = [x.shape[d] if d is not None else None for x, d in zip(vals, dims)]
sizes = tree_unflatten(tree, sizes)
raise ValueError(msg.format("the tree of axis sizes is:\n{}".format(sizes))) from e

@wraps(fun, docstr=docstr)
def batched_fun(*args):
args_flat, in_tree = tree_flatten(args)
f = lu.wrap_init(fun)
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
in_axes_flat = _flatten_axes(in_tree, in_axes)
_check_axis_sizes(in_tree, args_flat, in_axes_flat)
_ = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "vmap")
out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
lambda: _flatten_axes(out_tree(), out_axes))
return tree_unflatten(out_tree(), out_flat)

return batched_fun

def _get_axis_size(i:int, shape: Tuple[int, ...], axis: int):
try:
return shape[axis]
except (IndexError, TypeError) as e:
raise ValueError(f"vmap got arg {i} of rank {len(shape)} but axis to be mapped {axis}") from e

def _mapped_axis_size(tree, vals, dims, name):
mapped_axis_sizes = {_get_axis_size(i, onp.shape(x), d) for i, (x, d) in enumerate(zip(vals, dims))
if d is not None}
try:
size, = mapped_axis_sizes
return size
except ValueError as e:
if not mapped_axis_sizes:
raise ValueError("{} must have at least one non-None in_axes".format(name)) from e
msg = "{} got inconsistent sizes for array axes to be mapped:\n".format(name) + "{}"
# we switch the error message based on whether args is a tuple of arrays,
# in which case we can produce an error message based on argument indices,
# or if it has nested containers.
# TODO(mattjj,phawkins): add a way to inspect pytree kind more directly
if tree == tree_flatten((core.unit,) * tree.num_leaves)[1]:
lines1 = ["arg {} has shape {} and axis {} is to be mapped"
.format(i, x.shape, d) for i, (x, d) in enumerate(zip(vals, dims))]
sizes = collections.defaultdict(list)
for i, (x, d) in enumerate(zip(vals, dims)):
if d is not None:
sizes[x.shape[d]].append(i)
lines2 = ["{} {} {} {} to be mapped of size {}".format(
"args" if len(idxs) > 1 else "arg",
", ".join(map(str, idxs)),
"have" if len(idxs) > 1 else "has",
"axes" if len(idxs) > 1 else "an axis",
size)
for size, idxs in sizes.items()]
raise ValueError(msg.format("\n".join(lines1 + ["so"] + lines2))) from e
else:
sizes = [x.shape[d] if d is not None else None for x, d in zip(vals, dims)]
sizes = tree_unflatten(tree, sizes)
raise ValueError(msg.format("the tree of axis sizes is:\n{}".format(sizes))) from e

def _flatten_axes(treedef, axis_tree):
# given an axis spec tree axis_tree (a pytree with integers and Nones at the
# leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of
@@ -830,7 +831,7 @@ def _flatten_axes(treedef, axis_tree):
return axes


def pmap(fun: Callable, axis_name: Optional[AxisName] = None,
def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
static_broadcasted_argnums: Union[int, Iterable[int]] = (),
devices=None, backend: Optional[str] = None,
axis_size: Optional[int] = None) -> Callable:
@@ -876,6 +877,9 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None,
(tuple/list/dict) thereof.
axis_name: Optional, a hashable Python object used to identify the mapped
axis so that parallel collectives can be applied.
in_axes: A nonnegative integer, None, or nested Python container thereof
that specifies which axes in the input to map over (see ``vmap``).
Currently, only 0 and None are supported axes for pmap.
static_broadcasted_argnums: An int or collection of ints specifying which
positional arguments to treat as static (compile-time constant).
Operations that only depend on static arguments will be constant-folded.
@@ -895,12 +899,12 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None,
Returns:
A parallelized version of ``fun`` with arguments that correspond to those of
``fun`` but each with an additional leading array axis (with equal sizes)
``fun`` but with extra array axes at positions indicated by ``in_axes``
and with output that has an additional leading array axis (with the same
size).
For example, assuming 8 XLA devices are available, ``pmap`` can be used as a
map along a leading array axes:
map along a leading array axis:
>>> out = pmap(lambda x: x ** 2)(np.arange(8))
>>> print(out)
@@ -916,6 +920,18 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None,
[[ 1412. 1737.]
[ 1740. 2141.]]]
As with ``vmap``, using ``None`` in ``in_axes`` indicates that an argument
doesn't have an extra axis and should be broadcasted, rather than mapped,
across the replicas:
>>> x, y = np.arange(2.), 4.
>>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y)
>>> print(out)
([4., 5.], [8., 8.])
Note that ``pmap`` always returns values mapped over their leading axis,
equivalent to using ``out_axes=0`` in ``vmap``.
In addition to expressing pure maps, ``pmap`` can also be used to express
parallel single-program multiple-data (SPMD) programs that communicate via
collective operations. For example:
@@ -1008,10 +1024,17 @@ def f_pmapped(*args, **kwargs):
if static_broadcasted_argnums:
dyn_argnums = [i for i in range(len(args)) if i not in static_broadcasted_argnums]
f, dyn_args = argnums_partial(f, dyn_argnums, args)
if isinstance(in_axes, tuple):
dyn_in_axes = tuple(in_axes[i] for i in dyn_argnums)
else:
dyn_in_axes = in_axes
else:
dyn_args = args
dyn_args, dyn_in_axes = args, in_axes
args, in_tree = tree_flatten((dyn_args, kwargs))
local_axis_size = _pmap_axis_size(args)
in_axes_flat = _flatten_axes(in_tree, (dyn_in_axes, 0))
assert all(axis in (0, None) for axis in in_axes_flat), \
"pmap currently only supports mapping over the leading axis"
local_axis_size = _mapped_axis_size(in_tree, args, in_axes_flat, "pmap")
_check_args(args)
flat_fun, out_tree = flatten_fun(f, in_tree)
out = pxla.xla_pmap(
@@ -1023,23 +1046,13 @@ def f_pmapped(*args, **kwargs):
global_axis_size=axis_size,
devices=tuple(devices) if devices is not None else devices,
name=flat_fun.__name__,
mapped_invars=(True,) * len(args))
mapped_invars=tuple(axis is not None for axis in in_axes_flat))
return tree_unflatten(out_tree(), out)

namestr = "pmap({}, axis_name={})".format
f_pmapped.__name__ = namestr(f_pmapped.__name__, axis_name)
return f_pmapped

def _pmap_axis_size(xs):
for x in xs:
try:
return x.shape[0]
except AttributeError:
pass
else:
msg = "pmap got value with no leading axis to map over: {}."
raise ValueError(msg.format([x for x in xs if not hasattr(x, 'shape')]))

class _TempAxisName(object):
def __init__(self, obj):
self.obj = obj
@@ -1051,8 +1064,8 @@ def __eq__(self, other):
return self.obj is other.obj


def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None,
backend: Optional[str] = None) -> Callable:
def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, *,
in_axes=0, backend: Optional[str] = None) -> Callable:
warn("soft_pmap is an experimental feature and probably has bugs!")
_check_callable(fun)
axis_name = _TempAxisName(fun) if axis_name is None else axis_name
@@ -1061,7 +1074,11 @@ def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None,
def f_pmapped(*args, **kwargs):
f = lu.wrap_init(fun)
args_flat, in_tree = tree_flatten((args, kwargs))
axis_size = _pmap_axis_size(args_flat)
in_axes_flat = _flatten_axes(in_tree, (in_axes, 0))
assert all(axis in (0, None) for axis in in_axes_flat), \
"soft_pmap currently only supports mapping over the leading axis"
mapped_invars = tuple(axis is not None for axis in in_axes_flat)
axis_size = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "soft_pmap")
_check_args(args_flat)
flat_fun, out_tree = flatten_fun(f, in_tree)

@@ -1081,7 +1098,7 @@ def f_pmapped(*args, **kwargs):
axis_name=axis_name, axis_size=num_chunks,
global_axis_size=None, devices=None,
name=soft_mapped_fun.__name__,
mapped_invars=(True,) * len(reshaped_args))
mapped_invars=mapped_invars)
outs = [_reshape_merge(out) for out in reshaped_outs]
return tree_unflatten(out_tree(), outs)

@@ -1112,7 +1129,8 @@ def papply_fun(*args, **kwargs):
f = lu.wrap_init(fun)
args_flat, in_tree = tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun(f, in_tree)
axis_size = _pmap_axis_size(args_flat)
axis_size = _mapped_axis_size(
in_tree, args_flat, (0,) * len(args_flat), "papply")
out_flat = parallel.papply(flat_fun, axis_name, args_flat, axis_size)
return tree_unflatten(out_tree(), out_flat)

@@ -1126,7 +1144,8 @@ def pfun(*args):
f = lu.wrap_init(fun)
args_flat, in_tree = tree_flatten(args)
f, out_tree = flatten_fun_nokwargs(f, in_tree)
axis_size = _pmap_axis_size(args_flat)
axis_size = _mapped_axis_size(
in_tree, args_flat, (0,) * len(args_flat), "parallelize")

chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count())
if chunk_size == 0 and leftover:
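The api.py changes above fold the old _pmap_axis_size check into the shared _mapped_axis_size helper, which now backs vmap, pmap, soft_pmap, and papply. Below is a simplified restatement of that check; the function name is hypothetical and the real helper additionally maps size errors back onto the argument pytree.

import numpy as np

def mapped_axis_size_sketch(vals, dims, name="pmap"):
    # Every argument with a non-None in_axes entry must agree on the size of
    # its mapped axis; broadcast (None) arguments place no constraint.
    sizes = {np.shape(x)[d] for x, d in zip(vals, dims) if d is not None}
    if not sizes:
        raise ValueError(f"{name} must have at least one non-None in_axes")
    if len(sizes) > 1:
        raise ValueError(
            f"{name} got inconsistent sizes for array axes to be mapped: {sorted(sizes)}")
    return sizes.pop()

# A scalar passed with in_axes=None does not constrain the axis size:
assert mapped_axis_size_sketch([np.zeros((8, 3)), 4.0], [0, None]) == 8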
47 changes: 29 additions & 18 deletions jax/interpreters/pxla.py
@@ -141,10 +141,10 @@ def spec_to_indices(shape: Tuple[int, ...],

# `product` will always return a sequence of tuples. Skip the tuples if each
# index is a single element.
if len(indices_per_axis) > 1:
indices = list(product(*indices_per_axis))
else:
if len(indices_per_axis) == 1:
indices = list(indices_per_axis[0])
else:
indices = list(product(*indices_per_axis))

return tuple(i for i in indices
for _ in range(sharding_spec.replication_factor))
@@ -196,6 +196,7 @@ def shard_args(devices: Sequence[xb.xla_client.Device],
buffers[r][a] = (buf if buf.device() == devices[r]
else buf.copy_to_device(devices[r]))
else:
arg = xla.canonicalize_dtype(arg)
bufs = shard_arg_handlers[type(arg)](arg, devices, indices[a])
for r, buf in enumerate(bufs):
buffers[r][a] = buf
@@ -460,8 +461,9 @@ def __init__(self,
# creating pmap-style ShardedDeviceArrays.
if device_buffers is None:
device_buffers = sharding_spec
sharded_aval = ShapedArray(aval.shape[1:], aval.dtype)
sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0],
aval.shape[1:])
sharded_aval, True)

# TODO(skye): assert invariants. Keep performance in mind though.
if indices is None:
@@ -562,12 +564,13 @@ def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size, glob
devices, name, mapped_invars):
abstract_args = map(xla.abstractify, args)
compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
global_axis_size, devices, name, *abstract_args)
global_axis_size, devices, name, mapped_invars,
*abstract_args)
return compiled_fun(*args)

@lu.cache
def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
devices, name, *avals):
devices, name, mapped_invars, *avals):
if devices is not None and len(devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")

@@ -613,9 +616,10 @@ def dynamic_fun(dummy, *args):
with extend_dynamic_axis_env(axis_name, dummy._trace, global_axis_size):
return fun.call_wrapped(*args)

sharded_avals = tuple(map(partial(shard_aval, axis_size), avals))
sharded_avals = tuple(shard_aval(axis_size, aval) if m else aval
for m, aval in zip(mapped_invars, avals))
pvals = [pe.PartialVal.unknown(aval) for aval in sharded_avals]
# We add a dummy first invar, to carry the trace details to `dynamic_fun`
# We add a dummy first invar, to carry the trace details to `dynamic_fun`
pval = pe.PartialVal.unknown(core.abstract_unit) # dummy value for axis env
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True)
@@ -703,9 +707,8 @@ def dynamic_fun(dummy, *args):
compiled = backend.compile(built, compile_options=compile_options)

input_sharding_specs = [_pmap_sharding_spec(num_local_replicas, axis_size,
aval.shape)
if aval is not core.abstract_unit else None
for aval in sharded_avals]
aval, m)
for m, aval in zip(mapped_invars, sharded_avals)]
input_indices = [spec_to_indices(aval.shape, spec)
if spec is not None else None
for aval, spec in zip(avals, input_sharding_specs)]
@@ -771,7 +774,7 @@ def replicate(val, axis_size, nrep, devices=None, backend=None):
# TODO(jekbradbury): use ShardingSpec.replication_factor instead
aval = xla.abstractify(val) # type: ShapedArray
replicated_aval = ShapedArray((axis_size,) + aval.shape, aval.dtype)
sharding_spec = _pmap_sharding_spec(nrep, axis_size, aval.shape)
sharding_spec = _pmap_sharding_spec(nrep, axis_size, aval, True)
device_buffers = [xla.device_put(val, d) for d in devices]
return ShardedDeviceArray(replicated_aval, sharding_spec, device_buffers)

@@ -797,19 +800,27 @@ def _pval_to_result_handler(axis_size, nrep, pval, devices, backend):
return lambda _: bcast_const
else:
if pv is not core.abstract_unit:
sharding_spec = _pmap_sharding_spec(nrep, axis_size, pv.shape)
sharding_spec = _pmap_sharding_spec(nrep, axis_size, pv, True)
indices = spec_to_indices((axis_size,) + pv.shape, sharding_spec)
else:
sharding_spec = indices = None
return aval_to_result_handler(axis_size, sharding_spec, indices, pv)

def _pmap_sharding_spec(nrep, axis_size, sharded_shape):
def _pmap_sharding_spec(nrep, axis_size, sharded_aval, mapped):
if sharded_aval is core.abstract_unit:
return None
replication_factor, ragged = divmod(nrep, axis_size)
assert not ragged
return ShardingSpec(
shards_per_axis=(axis_size,) + (1,) * len(sharded_shape),
is_axis_materialized=(False,) + (True,) * len(sharded_shape),
replication_factor=replication_factor)
if mapped:
return ShardingSpec(
shards_per_axis=(axis_size,) + (1,) * len(sharded_aval.shape),
is_axis_materialized=(False,) + (True,) * len(sharded_aval.shape),
replication_factor=replication_factor)
else:
return ShardingSpec(
shards_per_axis=(1,) * len(sharded_aval.shape),
is_axis_materialized=(True,) * len(sharded_aval.shape),
replication_factor=replication_factor * axis_size)


def execute_replicated(compiled, backend, in_handler, out_handler, *args):
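The pxla.py change above gives _pmap_sharding_spec a mapped flag: mapped arguments are sharded one slice per device along an unmaterialized leading axis, while broadcast (in_axes=None) arguments are fully replicated. The sketch below paraphrases that branch, returning plain dicts with the ShardingSpec field names used by this version of pxla; the helper name is hypothetical.

def pmap_sharding_spec_sketch(nrep, axis_size, sharded_shape, mapped):
    # The replica count must be a whole multiple of the mapped axis size.
    replication_factor, ragged = divmod(nrep, axis_size)
    assert not ragged
    if mapped:
        # One shard per device along an unmaterialized leading axis.
        return dict(shards_per_axis=(axis_size,) + (1,) * len(sharded_shape),
                    is_axis_materialized=(False,) + (True,) * len(sharded_shape),
                    replication_factor=replication_factor)
    # Broadcast argument: nothing is sharded, and the replication factor
    # absorbs the axis size so the whole value lands on every device.
    return dict(shards_per_axis=(1,) * len(sharded_shape),
                is_axis_materialized=(True,) * len(sharded_shape),
                replication_factor=replication_factor * axis_size)

# e.g. a broadcast scalar across 8 devices:
# pmap_sharding_spec_sketch(8, 8, (), mapped=False)
#   -> {'shards_per_axis': (), 'is_axis_materialized': (), 'replication_factor': 8}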
4 changes: 2 additions & 2 deletions jax/lax/lax.py
@@ -2916,8 +2916,8 @@ def _reshape_sharded_device_array(array, new_sizes, old_sizes):
if ragged: return None
if new_sizes[0] != split_axis_size: return None
aval = ShapedArray(new_sizes, array.dtype)
sharding_spec = pxla._pmap_sharding_spec(new_sizes[0], new_sizes[0],
new_sizes[1:])
sharding_spec = pxla._pmap_sharding_spec(
new_sizes[0], new_sizes[0], ShapedArray(new_sizes[1:], array.dtype), True)
return pxla.ShardedDeviceArray(aval, sharding_spec, array.device_buffers)

return None
