Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago

PiperOrigin-RevId: 520356179
yashk2810 authored and jax authors committed Mar 29, 2023
1 parent a964ae7 commit fbc05ee
Showing 6 changed files with 26 additions and 66 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -34,6 +34,9 @@ Remember to align the itemized text with the first line of an item within a list
* CUDA 11.4 support has been dropped. JAX GPU wheels only support
CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built
from source.
* The `global_arg_shapes` argument of `pmap` only worked with `sharded_jit`, which
  was removed from JAX long ago, so the argument has now been removed as well.
  Please migrate to `pjit` and drop `global_arg_shapes` from your `pmap` calls.

## jaxlib 0.4.8

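The CHANGELOG entry above points users at pjit. As a reference, here is a minimal migration sketch; it is not part of the commit. It uses the pjit keyword names current around this release (`in_axis_resources`/`out_axis_resources`, later renamed to `in_shardings`/`out_shardings`), and the mesh axis name "devices" and the toy function `f` are made up for illustration.

```python
# Hypothetical sketch: where pmap(..., global_arg_shapes=...) was used with
# sharded_jit to describe cross-process shapes, pjit derives global shapes
# from PartitionSpecs instead, so no extra shape argument is needed.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.pjit import pjit

def f(x):
  return x * 2.0

# One mesh axis spanning every available device; the axis name is arbitrary.
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))

# Shard the input and the output along that axis.
f_sharded = pjit(f,
                 in_axis_resources=P("devices"),
                 out_axis_resources=P("devices"))

with mesh:
  x = jnp.arange(4.0 * jax.device_count())  # length divisible by device count
  print(f_sharded(x))
```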
1 change: 0 additions & 1 deletion docs/jaxpr.rst
@@ -460,7 +460,6 @@ captured using the ``xla_pmap`` primitive. Consider this example
in (k,) }
devices=None
donated_invars=(False, False)
global_arg_shapes=(None,)
global_axis_size=1
in_axes=(None, 0)
is_explicit_global_axis_size=False
48 changes: 14 additions & 34 deletions jax/_src/api.py
@@ -1437,12 +1437,6 @@ def pmap(
For more details on buffer donation see the
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
global_arg_shapes: Optional, must be set when using pmap(sharded_jit) and
the partitioned values span multiple processes. The global cross-process
per-replica shape of each argument, i.e. does not include the leading
pmapped dimension. Can be None for replicated arguments. This API is
likely to change in the future.
Returns:
A parallelized version of ``fun`` with arguments that correspond to those of
``fun`` but with extra array axes at positions indicated by ``in_axes`` and
@@ -1565,6 +1559,12 @@ def pmap(
>>> print(f2(jnp.array([2., 3.]))) # doctest: +SKIP
[ 13. 13.]
"""
if global_arg_shapes is not None:
raise ValueError(
"global_arg_shapes only worked with sharded_jit which has long been"
" removed from JAX. Please migrate to pjit and remove global_arg_shapes"
" from pmap.")

if FLAGS.experimental_cpp_pmap:
func = _cpp_pmap
else:
@@ -1579,8 +1579,7 @@
devices=devices,
backend=backend,
axis_size=axis_size,
donate_argnums=donate_argnums,
global_arg_shapes=global_arg_shapes)
donate_argnums=donate_argnums)


class PmapCallInfo(NamedTuple):
@@ -1591,7 +1590,6 @@ class PmapCallInfo(NamedTuple):
donated_invars: Sequence[bool]
in_axes_flat: Sequence[Optional[int]]
local_axis_size: int
global_arg_shapes_flat: Sequence[Optional[Tuple[int, ...]]]
out_axes_thunk: HashableFunction
devices: Optional[Sequence[xc.Device]]
global_axis_size: int
@@ -1628,7 +1626,7 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
return global_axis_size

def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, global_arg_shapes, in_devices, backend_name,
donate_tuple, in_devices, backend_name,
axis_size, args, kwargs):
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")
@@ -1651,25 +1649,15 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
dyn_in_axes = tuple(in_axes[i] for i in dyn_argnums)
else:
dyn_in_axes = in_axes
dyn_global_arg_shapes = global_arg_shapes

if isinstance(global_arg_shapes, tuple):
dyn_global_arg_shapes = tuple(global_arg_shapes[i] for i in dyn_argnums)
else:
dyn_global_arg_shapes = global_arg_shapes
else:
dyn_args, dyn_in_axes = args, in_axes
dyn_global_arg_shapes = global_arg_shapes
args, in_tree = tree_flatten((dyn_args, kwargs))

if donate_tuple and not config.jax_debug_nans:
donated_invars = donation_vector(donate_tuple, dyn_args, kwargs)
else:
donated_invars = (False,) * len(args)
in_axes_flat = tuple(flatten_axes("pmap in_axes", in_tree, (dyn_in_axes, 0)))
global_arg_shapes_flat = tuple(flatten_axes(
"pmap global_arg_shapes", in_tree, (dyn_global_arg_shapes, None),
kws=True))
local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")

f, res_paths = result_paths(f)
@@ -1709,7 +1697,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
donated_invars=donated_invars,
in_axes_flat=in_axes_flat,
local_axis_size=local_axis_size,
global_arg_shapes_flat=global_arg_shapes_flat,
out_axes_thunk=out_axes_thunk,
devices=None if in_devices is None else tuple(in_devices),
global_axis_size=global_axis_size,
@@ -1727,12 +1714,11 @@ def _get_f_mapped(
backend: Optional[str],
axis_size: Optional[int],
donate_tuple: Tuple[int, ...],
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]],
):
def pmap_f(*args, **kwargs):
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
global_arg_shapes, devices, backend, axis_size, args, kwargs)
devices, backend, axis_size, args, kwargs)
for arg in p.flat_args:
dispatch.check_arg(arg)
out = pxla.xla_pmap(
@@ -1741,7 +1727,6 @@ def pmap_f(*args, **kwargs):
devices=p.devices,
in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk,
name=p.flat_fun.__name__, donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size)
return p.out_tree, out

@@ -1780,7 +1765,6 @@ def _python_pmap(
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> stages.Wrapped:
"""The Python only implementation."""
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
@@ -1799,15 +1783,14 @@ def pmap_f(*args, **kwargs):
devices=devices,
backend=backend,
axis_size=axis_size,
global_arg_shapes=global_arg_shapes,
donate_tuple=donate_tuple)

out_tree, out_flat = f_pmapped_(*args, **kwargs)
return tree_unflatten(out_tree(), out_flat)

pmap_f.lower = _pmap_lower(
fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
backend, axis_size, global_arg_shapes, donate_tuple)
backend, axis_size, donate_tuple)

return cast(stages.Wrapped, pmap_f)

@@ -1842,7 +1825,6 @@ def _cpp_pmap(
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> Any:
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
@@ -1852,7 +1834,7 @@ def _cpp_pmap(
@api_boundary
def cache_miss(*args, **kwargs):
p = _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, global_arg_shapes, devices, backend,
donate_tuple, devices, backend,
axis_size, args, kwargs)
for arg in p.flat_args:
dispatch.check_arg(arg)
@@ -1867,7 +1849,6 @@ def cache_miss(*args, **kwargs):
out_axes_thunk=p.out_axes_thunk,
name=p.flat_fun.__name__,
donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
)

@@ -1939,13 +1920,13 @@ def cache_miss(*args, **kwargs):

pmap_f.lower = _pmap_lower(
fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
backend, axis_size, global_arg_shapes, donate_tuple)
backend, axis_size, donate_tuple)

return pmap_f


def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
devices, backend, axis_size, global_arg_shapes, donate_tuple): # noqa: F811
devices, backend, axis_size, donate_tuple): # noqa: F811
"""Make a ``lower`` method for pmapped functions."""
# If the function we returned from ``pmap`` were a class instance,
# this might naturally be a method, with ``fun`` as a ``self`` and
@@ -1966,7 +1947,7 @@ def lower(*args, _experimental_lowering_platform: Optional[str] = None,
"""
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
global_arg_shapes, devices, backend, axis_size, args, kwargs)
devices, backend, axis_size, args, kwargs)
abstract_args = list(map(shaped_abstractify, p.flat_args))
computation = pxla.lower_parallel_callable(
p.flat_fun, backend, axis_name,
@@ -1976,7 +1957,6 @@
in_axes=p.in_axes_flat,
out_axes_thunk=p.out_axes_thunk,
donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
avals=abstract_args,
lowering_platform=_experimental_lowering_platform)
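One practical consequence of the api.py change above, sketched for illustration: because the new guard runs inside `pmap` itself, passing the argument now fails as soon as the function is wrapped, before any tracing or execution. This assumes the keyword is still accepted by the public `jax.pmap` signature, which the added check implies but the hunk does not show.

```python
# Illustrative only: exercising the ValueError guard added in jax/_src/api.py.
import jax

try:
  # Wrapping alone triggers the check; no devices or call are needed.
  jax.pmap(lambda x: x + 1, global_arg_shapes=((8,),))
except ValueError as err:
  # Per the diff, the message asks users to migrate to pjit and to remove
  # global_arg_shapes from pmap.
  print(err)
```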
36 changes: 8 additions & 28 deletions jax/_src/interpreters/pxla.py
@@ -745,25 +745,22 @@ def xla_pmap_impl_lazy(
in_axes: Sequence[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
) -> Callable:
if (config.jax_disable_jit and config.jax_eager_pmap and
not is_explicit_global_axis_size and not any(d for d in donated_invars)
and not all(g is not None for g in global_arg_shapes)):
not is_explicit_global_axis_size and not any(d for d in donated_invars)):
def _emap_apply_fn(*args):
return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
axis_size=axis_size, global_axis_size=global_axis_size,
devices=devices, name=name, in_axes=in_axes,
out_axes_thunk=out_axes_thunk,
donated_invars=donated_invars,
global_arg_shapes=global_arg_shapes,
is_explicit_global_axis_size=is_explicit_global_axis_size)
return _emap_apply_fn
abstract_args = unsafe_map(xla.abstractify, args)
compiled_fun, fingerprint = parallel_callable(
fun, backend, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
in_axes, out_axes_thunk, donated_invars,
is_explicit_global_axis_size, *abstract_args)

# Don't re-abstractify args unless logging is enabled for performance.
@@ -793,15 +790,12 @@ def _emap_impl(fun: lu.WrappedFun, *args,
in_axes: Sequence[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
):
from jax._src import array
# TODO(sharadmv,mattjj): implement these cases
if any(d for d in donated_invars):
raise NotImplementedError("Buffer donation not supported in eager pmap.")
if any(g is not None for g in global_arg_shapes):
raise NotImplementedError("Global arg shapes not supported in eager pmap.")
if is_explicit_global_axis_size:
raise NotImplementedError("Non-default global_axis_size not supported in "
"eager pmap.")
@@ -1029,12 +1023,11 @@ def parallel_callable(fun: lu.WrappedFun,
in_axes: Sequence[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
*avals):
pmap_computation = lower_parallel_callable(
fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
in_axes, out_axes_thunk, donated_invars,
is_explicit_global_axis_size, avals, lowering_platform=None)
pmap_executable = pmap_computation.compile()
return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
@@ -1091,26 +1084,17 @@ def find_replicas(jaxpr, axis_size, global_axis_size):

def stage_parallel_callable(
pci: ParallelCallableInfo,
fun: lu.WrappedFun,
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
fun: lu.WrappedFun):
sharded_avals = tuple(
shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
for axis, aval in safe_zip(pci.in_axes, pci.avals))
if any(s is not None for s in global_arg_shapes):
# TODO(skye): we could take this branch unconditionally if we handled
# grad of global_arg_shapes correctly.
global_sharded_avals = [
aval.update(shape=shape) if shape is not None else aval
for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
else:
global_sharded_avals = sharded_avals # type: ignore

with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for pmap in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

@@ -1133,7 +1117,7 @@ def stage_parallel_callable(
num_global_shards = replicas.num_global_replicas * parts.num_partitions

shards = ShardInfo(
sharded_avals, out_sharded_avals, global_sharded_avals,
sharded_avals, out_sharded_avals, sharded_avals,
num_local_shards, num_global_shards)

return jaxpr, consts, replicas, parts, shards
@@ -1158,7 +1142,6 @@ def lower_parallel_callable(
in_axes: Iterable[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
avals: Sequence[core.AbstractValue],
*,
@@ -1197,8 +1180,7 @@
pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals)
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
pci, fun, global_arg_shapes)
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(pci, fun)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("sharded_avals: %s", shards.sharded_avals)
logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
@@ -1976,7 +1958,6 @@ def _pmap_dce_rule(used_outputs, eqn):
eqn.params['global_axis_size'], None):
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
_, donated_invars = partition_list(used_inputs, eqn.params['donated_invars'])
# TODO(yashkatariya,mattjj): Handle global_arg_shapes here too.
_, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
_, out_axes = partition_list(used_outputs, eqn.params['out_axes'])
new_params = dict(eqn.params, call_jaxpr=new_jaxpr,
@@ -2095,8 +2076,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl
def _pmap_lowering(ctx, *in_nodes, axis_name,
axis_size, global_axis_size, devices, name,
call_jaxpr, backend=None, in_axes, out_axes,
donated_invars, global_arg_shapes,
is_explicit_global_axis_size):
donated_invars, is_explicit_global_axis_size):
del donated_invars # Unused.
xla.check_backend_matches(backend, ctx.module_context.platform)
# We in-line here rather than generating a Call HLO as in the xla_call
3 changes: 1 addition & 2 deletions jax/_src/lax/lax.py
@@ -4272,14 +4272,13 @@ def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs):
_identity_fn, None, (), (), sharded_dim, sharded_dim)
p = api._prepare_pmap(
_identity_fn, sharded_dim, sharded_dim, static_broadcasted_tuple,
donate_tuple, None, None, None, None, args, kwargs)
donate_tuple, None, None, None, args, kwargs)
out_flat = pxla.xla_pmap_impl(
p.flat_fun, *p.flat_args, backend=None, axis_name=axis_name,
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
devices=p.devices, in_axes=p.in_axes_flat,
out_axes_thunk=p.out_axes_thunk, name=p.flat_fun.__name__,
donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
)
return tree_util.tree_unflatten(p.out_tree(), out_flat)
1 change: 0 additions & 1 deletion tests/host_callback_test.py
@@ -2999,7 +2999,6 @@ def f(xv):
in (c, f, g) }
devices=None
donated_invars=(False, False, False)
global_arg_shapes=(None,)
global_axis_size=None
in_axes=(0, 0, 0)
name=<lambda>
