Make all pmap tests pass with Array! I am skipping all soft pmap tests for now.

PiperOrigin-RevId: 467264992
yashk2810 authored and jax authors committed Aug 12, 2022
1 parent d20dcf4 commit 18b6a32
Showing 4 changed files with 163 additions and 27 deletions.
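
For orientation, here is a minimal usage sketch of the path this commit targets: calling jax.pmap on inputs of the experimental Array type. The sketch is not part of the commit; the jax_array config flag and the expectation that pmap outputs carry a PmapSharding are assumptions based on this era of JAX.

# Hypothetical usage sketch (not from this commit): run pmap with the
# experimental Array type enabled via the jax_array flag.
import jax
import jax.numpy as jnp
from jax.config import config

config.update('jax_array', True)   # opt in to the experimental Array path

def f(x):
  return jnp.sin(x) * 2.0

n = jax.local_device_count()
x = jnp.arange(n * 4.0).reshape(n, 4)
out = jax.pmap(f)(x)               # inputs/outputs are Arrays when jax_array is on
print(type(out), out.sharding)     # expected: an Array carrying a PmapSharding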
4 changes: 3 additions & 1 deletion jax/_src/api.py
@@ -1897,7 +1897,7 @@ class PmapCallInfo(NamedTuple):
 
 
 def _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices):
-  from jax.experimental.sharding import PmapSharding
+  from jax.experimental.sharding import PmapSharding, SingleDeviceSharding
   from jax.experimental.array import Array
 
   if not args:
@@ -1907,6 +1907,8 @@ def _check_in_pmap_sharding_with_arrays(args, in_axes_flat, in_devices):
   for a, i in safe_zip(args, in_axes_flat):
     if not isinstance(a, Array):
       continue
+    if isinstance(a.sharding, SingleDeviceSharding):
+      continue
     if not isinstance(a.sharding, PmapSharding):
       raise NotImplementedError('pmap only works with PmapSharding.')
     if first_device_assignment is None:
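
The new early continue means an Array committed to a single device is accepted by pmap and sharded like an ordinary device array instead of being rejected by the PmapSharding check. A hypothetical illustration of that input case (setup assumed, not taken from the commit):

# Hypothetical single-device input to pmap; with this change it no longer
# trips the "pmap only works with PmapSharding" error.
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jax.device_put(jnp.ones((n, 3)), jax.local_devices()[0])  # SingleDeviceSharding
y = jax.pmap(lambda a: a * 2)(x)   # accepted; pmap shards it across devices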
14 changes: 13 additions & 1 deletion jax/experimental/array.py
@@ -18,13 +18,15 @@
 from typing import Sequence, Tuple, Callable, Union, Optional, cast, List
 
 from jax import core
+from jax._src import ad_util
 from jax._src import api_util
 from jax._src import dispatch
+from jax._src.lax import lax as lax_internal
 from jax._src.config import config
 from jax._src.util import prod, safe_zip
 from jax._src.lib import xla_client as xc
 from jax._src.api import device_put
-from jax.interpreters import pxla, xla
+from jax.interpreters import pxla, xla, mlir
 from jax.experimental.sharding import (Sharding, SingleDeviceSharding,
                                        XLACompatibleSharding)
 
@@ -245,6 +247,14 @@ def make_array_from_callback(shape: Shape, sharding: Sharding,
 xla.canonicalize_dtype_handlers[Array] = pxla.identity
 api_util._shaped_abstractify_handlers[Array] = \
     lambda x: core.ShapedArray(x.shape, x.dtype)
+ad_util.jaxval_adders[Array] = lax_internal.add
+ad_util.jaxval_zeros_likers[Array] = lax_internal.zeros_like_array
+
+
+def _array_mlir_constant_handler(val, canonicalize_types=True):
+  return mlir.ir_constants(val._value,
+                           canonicalize_types=canonicalize_types)
+mlir.register_constant_handler(Array, _array_mlir_constant_handler)
 
 
 def _device_put_array(x, device: Optional[Device]):
@@ -267,6 +277,8 @@ def _array_shard_arg(x, devices, indices, mode):
   if mode == pxla.InputsHandlerMode.pmap:
     # sharding mismatch between `Array` and pmap sharding is checked in api.py's
     # `_check_in_pmap_sharding_with_arrays` function.
+    if isinstance(x.sharding, SingleDeviceSharding):
+      return pxla._shard_device_array(x, devices, indices, mode)
     return [buf if buf.device() == d else buf.copy_to_device(d)
             for buf, d in safe_zip(x._arrays, devices)]
   else:
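
The new registrations wire Array into reverse-mode AD bookkeeping (cotangent addition and zeros_like) and into MLIR lowering of closed-over constants. A hypothetical sketch of the kind of program these hooks support, assuming the jax_array flag is enabled:

# Hypothetical sketch: `c` is closed over by `g`, so under jit it lowers as a
# constant (constant handler), and reverse-mode AD relies on add/zeros_like
# being registered for Array values.
import jax
import jax.numpy as jnp

c = jnp.linspace(0.0, 1.0, 8)

def g(x):
  return jnp.sum(x * c)            # `c` becomes a constant in the jaxpr

print(jax.jit(jax.grad(g))(jnp.arange(8.0)))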
4 changes: 2 additions & 2 deletions jax/interpreters/pxla.py
@@ -1332,8 +1332,8 @@ def from_hlo(xla_computation,
                   parts.local_num_partitions, out_parts, aval, out_axis)
               for out_parts, aval, out_axis in safe_zip(
                   local_out_parts, local_out_avals, pci.out_axes)]
-  pmap_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
-  handle_outs = local_avals_to_results_handler(local_unmapped_avals, pmap_shardings)
+  out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
+  handle_outs = local_avals_to_results_handler(local_unmapped_avals, out_shardings)
 
   if hasattr(pci.backend, "compile_replicated"):
     execute_fun = pci.backend.compile_replicated(
