Skip to content

Commit

Permalink
Check the gda pjit mesh mismatch inside pjit and not during _shard_ar…
Browse files Browse the repository at this point in the history
…g which is shared by pmap, xmap and pjit. For pmap, pjit mesh has nothing to do with it. So this error should not be raised.

PiperOrigin-RevId: 422929245
  • Loading branch information
yashk2810 authored and jax authors committed Jan 20, 2022
1 parent 1a47076 commit 04e6786
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
5 changes: 0 additions & 5 deletions jax/experimental/global_device_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import numpy as np
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple

from jax.experimental import maps
from jax import core
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
Expand Down Expand Up @@ -431,10 +430,6 @@ def from_batched_callback_with_devices(
xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity

def _gda_shard_arg(x, devices, indices):
pjit_mesh = maps.thread_resources.env.physical_mesh
if x._global_mesh != pjit_mesh:
raise ValueError("Pjit's mesh and GDA's mesh should be equal. Got Pjit "
f"mesh: {pjit_mesh},\n GDA mesh: {x._global_mesh}")
return [s.data for s in x.local_shards]
pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg

Expand Down
8 changes: 8 additions & 0 deletions jax/experimental/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def infer_params(*args, **kwargs):
else:
donated_invars = (False,) * len(args_flat)

_maybe_check_pjit_gda_mesh(args_flat, mesh)

local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
# TODO(yashkatariya): This is a hack. This should go away when avals have
# is_global attribute.
Expand Down Expand Up @@ -1004,6 +1006,12 @@ def gda_mesh_axes_to_canonicalized_parsed_pspec(mesh_axes) -> CanonicalizedParse
return CanonicalizedParsedPartitionSpec(ParsedPartitionSpec.from_user_input(
pspec, arg_name='GDA mesh_axes'))

def _maybe_check_pjit_gda_mesh(args, mesh):
for x in args:
if isinstance(x, GDA) and x._global_mesh != mesh:
raise ValueError("Pjit's mesh and GDA's mesh should be equal. Got Pjit "
f"mesh: {mesh},\n GDA mesh: {x._global_mesh}")

# -------------------- XLA OpSharding to PartitionSpec --------------------
# Note that OpSharding is more expressive than PartitionSpecs, so it's not
# always possible to convert them, but the code below should at least
Expand Down

0 comments on commit 04e6786

Please sign in to comment.