Replace out_specs with out_shardings and remove out_indices in ResultsHandler.

PiperOrigin-RevId: 461788039
yashk2810 authored and jax authors committed Jul 19, 2022
1 parent e1fdd57 commit ea627b8
Showing 2 changed files with 47 additions and 86 deletions.
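
The gist of the change, distilled from the diff below (a summary sketch, not code quoted from the commit): ResultsHandler now stores the output Sharding objects themselves instead of precomputed sharding specs and per-device indices, and callers that still need those fields derive them from the shardings on demand.

  # Before:  ResultsHandler(handlers, out_specs, out_indices, out_avals)
  # After:   ResultsHandler(handlers, out_shardings, out_avals)
  #
  # Code that still wants the old fields recomputes them, as the pmap
  # fast path below does:
  #   specs   = [s.sharding_spec for s in out_shardings]   # PmapSharding only
  #   indices = [tuple(s.devices_indices_map(aval.shape).values())
  #              for s, aval in safe_zip(out_shardings, out_avals)]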
jax/_src/api.py (6 changes: 4 additions & 2 deletions)
@@ -2172,6 +2172,8 @@ def cache_miss(*args, **kwargs):
    execute_replicated = execute[0]
    out_handler = execute_replicated.out_handler
    in_handler = execute_replicated.in_handler
+   out_indices = [tuple(s.devices_indices_map(a.shape).values())
+                  for s, a in safe_zip(out_handler.out_shardings, out_handler.out_avals)]
    fastpath_data = _PmapFastpathData(
        version=1,
        xla_executable=execute_replicated.xla_executable,
@@ -2181,8 +2183,8 @@ def cache_miss(*args, **kwargs):
        input_sharding_specs=in_handler.sharding_specs,
        input_devices=in_handler.local_devices,
        input_indices=in_handler.input_indices,
-       out_sharding_specs=out_handler.out_specs,
-       out_indices=out_handler.out_indices,
+       out_sharding_specs=[s.sharding_spec for s in out_handler.out_shardings],
+       out_indices=out_indices,
        out_avals=out_handler.out_avals,
    )

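The fast path above now derives out_indices from each output sharding via devices_indices_map rather than reading a stored field. A minimal standalone sketch of what that method returns, written against today's public jax.sharding module (at the time of this commit the classes lived in jax.experimental.sharding); it runs on any number of devices, including a single CPU:

  import jax
  from jax.sharding import Mesh, NamedSharding, PartitionSpec

  devs = jax.devices()
  mesh = Mesh(devs, axis_names=("x",))
  s = NamedSharding(mesh, PartitionSpec("x"))

  # Maps each device to the numpy-style index of the shard it holds; for a
  # (4 * ndev, 2) array split along axis 0, device i gets rows 4i:4(i+1).
  for dev, idx in s.devices_indices_map((4 * len(devs), 2)).items():
    print(dev, idx)  # e.g. CpuDevice(id=0) (slice(0, 4), slice(None))
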
jax/interpreters/pxla.py (127 changes: 43 additions & 84 deletions)
@@ -1218,6 +1218,11 @@ def from_hlo(xla_computation,
        )
    compile_options.parameter_is_tupled_arguments = tuple_args

+   process_index = xb.process_index(pci.backend)
+   local_device_assignment = np.array([
+       d for d in device_assignment.flat if d.process_index == process_index
+   ])
+
    local_arg_parts_ = parts.local_arg_parts or [None] * len(pci.avals)
    input_sharding_specs = [
        _pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size,
@@ -1262,7 +1267,8 @@ def from_hlo(xla_computation,
            parts.local_num_partitions, out_parts, aval, out_axis)
        for out_parts, aval, out_axis in safe_zip(
            local_out_parts, local_out_avals, pci.out_axes)]
-   handle_outs = local_avals_to_results_handler(out_specs, local_unmapped_avals)
+   pmap_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
+   handle_outs = local_avals_to_results_handler(local_unmapped_avals, pmap_shardings)

    if hasattr(pci.backend, "compile_replicated"):
      execute_fun = pci.backend.compile_replicated(
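
The new pmap path above builds PmapSharding objects for the outputs via _get_pmap_sharding, a helper added elsewhere in this commit and not shown in these hunks. A plausible reconstruction from its call site (local devices plus one ShardingSpec per output, yielding one sharding per output); the body here is inferred, not quoted:

  def _get_pmap_sharding(devices, specs):
    # Hypothetical reconstruction of the helper called above.
    from jax.experimental import sharding
    return [sharding.PmapSharding(devices, spec) for spec in specs]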
@@ -1465,111 +1471,64 @@ class ResultsHandler:
  # `out_avals` is the `GlobalDeviceArray` global avals when using pjit or xmap
  # with `config.parallel_functions_output_gda=True`. It is the local one
  # otherwise, and also when using `pmap`.
-  __slots__ = ("handlers", "out_specs", "out_indices", "out_avals")
+  __slots__ = ("handlers", "out_shardings", "out_avals")

-  def __init__(self, handlers, out_specs, out_indices, out_avals):
+  def __init__(self, handlers, out_shardings, out_avals):
    self.handlers = handlers
-    self.out_specs = out_specs
-    self.out_indices = out_indices
+    self.out_shardings = out_shardings
    self.out_avals = out_avals

  def __call__(self, out_bufs):
    return [h(bufs) for h, bufs in safe_zip(self.handlers, out_bufs)]


+def _get_sharding_specs(
+    shardings: Sequence[XLACompatibleSharding], avals: Sequence[ShapedArray]
+) -> Sequence[ShardingSpec]:
+  from jax.experimental import sharding
+
+  if all(isinstance(s, sharding.PmapSharding) for s in shardings):
+    return [s.sharding_spec for s in shardings]  # type: ignore
+  elif all(isinstance(s, sharding.MeshPspecSharding) for s in shardings):
+    return [new_mesh_sharding_specs(s.mesh.shape, s.mesh.axis_names)(
+        aval.ndim, _get_array_mapping(s.spec))
+        for aval, s in safe_zip(avals, shardings)]
+  else:
+    raise ValueError('Getting sharding spec is only supported for '
+                     'PmapSharding and MeshPspecSharding.')
+
+
 def local_avals_to_results_handler(
-    local_out_specs: Sequence[Optional[ShardingSpec]],
-    unmapped_local_out_avals: Sequence[Optional[ShapedArray]]):
-  out_indices = [spec_to_indices(aval.shape, spec)
-                 for aval, spec in safe_zip(unmapped_local_out_avals, local_out_specs)]  # pytype: disable=attribute-error
+    unmapped_local_out_avals: Sequence[Optional[ShapedArray]],
+    local_shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
+  local_out_specs = _get_sharding_specs(
+      local_shardings, cast(Sequence[ShapedArray], unmapped_local_out_avals))
+  out_indices = [tuple(s.devices_indices_map(aval.shape).values())
+                 for s, aval in safe_zip(local_shardings, unmapped_local_out_avals)]
  handlers = [
      local_aval_to_result_handler(aval, spec, idcs)
      for aval, spec, idcs in safe_zip(unmapped_local_out_avals, local_out_specs, out_indices)
  ]
-  return ResultsHandler(handlers, local_out_specs, out_indices, unmapped_local_out_avals)
-
-
-def _get_mesh_sharding_spec_and_avals(
-    shardings: Sequence[MeshPspecSharding], avals: Sequence[ShapedArray],
-    is_global: bool) -> Tuple[Sequence[ShardingSpec], Sequence[ShapedArray]]:
-  global_mesh = shardings[0].mesh
-  if is_global:
-    global_sharding_spec = mesh_sharding_specs(
-        global_mesh.shape, global_mesh.axis_names)
-    return [global_sharding_spec(aval, _get_array_mapping(s.spec))
-            for aval, s in safe_zip(avals, shardings)], avals
-  else:
-    out_axes = [_get_array_mapping(s.spec) for s in shardings]
-    local_sharding_spec = mesh_sharding_specs(
-        global_mesh.local_mesh.shape, global_mesh.axis_names)
-    local_out_untiled_avals = [
-        global_mesh._global_to_local(o, aval)
-        for aval, o in safe_zip(avals, out_axes)
-    ]
-    return [local_sharding_spec(aval, oa)
-            for aval, oa in safe_zip(local_out_untiled_avals, out_axes)], local_out_untiled_avals
-
-
-def _get_sharding_spec_and_avals(
-    shardings: Sequence[XLACompatibleSharding],
-    avals: Sequence[ShapedArray],
-    is_global: bool) -> Tuple[Sequence[ShardingSpec], Sequence[ShapedArray]]:
-  """Returns the sharding spec and the avals required by `ResultsHandler`.
-
-  If `is_global`, then global sharding specs and global avals are returned else
-  host local sharding specs and host local avals are returned.
-  """
-  from jax.experimental import sharding
-
-  if not shardings:
-    return [], []
-
-  if is_global:
-    if config.jax_parallel_functions_output_gda:
-      assert all(isinstance(s, sharding.MeshPspecSharding) for s in shardings)
-      return _get_mesh_sharding_spec_and_avals(
-          cast(Sequence[sharding.MeshPspecSharding], shardings), avals, is_global=True)
-    elif config.jax_array:
-      if all(isinstance(s, sharding.PmapSharding) for s in shardings):
-        # Cast for type checkers. Does not affect runtime.
-        shardings = cast(Sequence[sharding.PmapSharding], shardings)
-        return [s.sharding_spec for s in shardings], avals
-      elif all(isinstance(s, sharding.MeshPspecSharding) for s in shardings):
-        return _get_mesh_sharding_spec_and_avals(
-            cast(Sequence[sharding.MeshPspecSharding], shardings), avals, is_global=True)
-      else:
-        # TODO(b/239098037): Delete `_get_sharding_spec_and_avals` and
-        # _get_mesh_sharding_spec_and_avals. It's ok to return `[]` because
-        # for shardings other than MeshPspecSharding and PmapSharding, they
-        # don't have sharding specs and it doesn't make sense for other
-        # shardings to have that field set.
-        return [], avals
-    else:
-      raise ValueError('Option not recognized. Please file a bug against JAX.')
-  else:
-    assert all(isinstance(s, sharding.MeshPspecSharding) for s in shardings)
-    return _get_mesh_sharding_spec_and_avals(
-        cast(Sequence[sharding.MeshPspecSharding], shardings), avals, is_global=False)
+  return ResultsHandler(handlers, local_shardings, unmapped_local_out_avals)


 def global_avals_to_results_handler(
    global_out_avals: Sequence[ShapedArray],
-    shardings: Sequence[XLACompatibleSharding]):
+    shardings: Sequence[XLACompatibleSharding]) -> ResultsHandler:
+  from jax.experimental.sharding import MeshPspecSharding
+
  if config.jax_parallel_functions_output_gda or config.jax_array:
-    global_out_specs, _ = _get_sharding_spec_and_avals(
-        shardings, global_out_avals, is_global=True)
-    global_out_indices = [tuple(s.devices_indices_map(aval.shape).values())
-                          for s, aval in safe_zip(shardings, global_out_avals)]
    handlers = [
        global_aval_to_result_handler(global_aval, s)
        for global_aval, s in safe_zip(global_out_avals, shardings)
    ]
-    return ResultsHandler(handlers, global_out_specs, global_out_indices,
-                          global_out_avals)
+    return ResultsHandler(handlers, shardings, global_out_avals)
  else:
-    local_out_specs, local_out_untiled_avals = _get_sharding_spec_and_avals(
-        shardings, global_out_avals, is_global=False)
-    return local_avals_to_results_handler(local_out_specs, local_out_untiled_avals)
+    # This path is taken when the outputs are SDAs.
+    assert all(isinstance(s, MeshPspecSharding) for s in shardings)
+    local_out_avals = [s.mesh._global_to_local(_get_array_mapping(s.spec), aval)
+                       for aval, s in safe_zip(global_out_avals, shardings)]
+    local_shardings = [MeshPspecSharding(s.mesh.local_mesh, s.spec) for s in shardings]  # type: ignore
+    return local_avals_to_results_handler(local_out_avals, local_shardings)


@profiler.annotate_function
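
For orientation, the rewritten else-branch of global_avals_to_results_handler (the SDA path) shrinks both the avals and the shardings to the current process before reusing the local handler path. A worked example with hypothetical numbers, since _global_to_local and local_mesh are per-process notions:

  # Hypothetical setup: mesh ("x",) of 4 devices across 2 processes,
  # global output aval f32[8, 2], spec P("x").
  #
  #   s.mesh.local_mesh       -> the 2 devices owned by this process
  #   _global_to_local(...)   -> f32[4, 2]; axis 0 shrinks by the number
  #                              of processes it is sharded across
  #   MeshPspecSharding(s.mesh.local_mesh, s.spec)
  #                           -> the same partitioning, restricted to the
  #                              local devices
  #
  # Each process thus builds result handlers only for the shard it
  # actually materializes.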
@@ -2305,7 +2264,7 @@ def lower_sharding_computation(
                  for aval, i in safe_zip(global_in_avals, in_shardings)]
  # TODO(yashkatariya): Fix the HLO produced if out_partitions is
  # [None, OpShardingProto] has the sharding annotations.
-  out_op_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
+  out_op_shardings = [o._to_xla_op_sharding(aval.ndim)
                      for aval, o in safe_zip(global_out_avals, out_shardings)]
  replicated_args = [False] * len(in_jaxpr_avals)
  axis_ctx = mlir.ShardingContext(first_sharding)
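
The last hunk makes the lowering convert every output sharding to an XLA op sharding unconditionally, dropping the None special case for unspecified outputs. _to_xla_op_sharding is internal API of this era; a small sketch using its assumed present-day internal counterpart _to_xla_hlo_sharding (both are private and subject to change, so treat the call and its output as illustrative):

  import jax
  from jax.sharding import Mesh, NamedSharding, PartitionSpec

  mesh = Mesh(jax.devices(), ("x",))
  s = NamedSharding(mesh, PartitionSpec("x"))
  # Prints the XLA sharding annotation for a rank-2 output, e.g.
  # {devices=[2,1]0,1} on a two-device mesh.
  print(s._to_xla_hlo_sharding(2))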
