* Check if the device assignment is the same across input and output shardings.

* Allow mixed inputs only if the sharding matches with what is specified in in_axis_resources.

PiperOrigin-RevId: 460326054
yashk2810 authored and jax authors committed Jul 11, 2022
1 parent 11896b6 commit 0bc8f8a
Showing 2 changed files with 115 additions and 45 deletions.
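
Before the diffs, a quick sketch of the two rules from the commit message, written against the same APIs the new tests below exercise (`pjit`, `MeshPspecSharding`, `jtu.create_global_mesh`, the `jax._src.config.jax_array` flag). The import layout is an assumption for the JAX tree at this commit (July 2022), and the shapes are arbitrary:

# Sketch only: import paths are assumed for the July-2022 JAX tree; the
# helpers (jtu.create_global_mesh, MeshPspecSharding, P) are the ones used
# in tests/pjit_test.py below.
import numpy as np
import jax
import jax._src.test_util as jtu
from jax.experimental.pjit import pjit
from jax.experimental.sharding import MeshPspecSharding
from jax.experimental import PartitionSpec as P

mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
f = pjit(lambda x: x,
         in_axis_resources=MeshPspecSharding(mesh, P(None)),
         out_axis_resources=MeshPspecSharding(mesh, P('x', 'y')))

with jax._src.config.jax_array(True), mesh:
  # A plain numpy input carries no sharding, so it adopts the sharding from
  # in_axis_resources (the "mixed inputs" case this commit allows).
  out = f(np.arange(16, dtype=np.float32).reshape(8, 2))
  # An `Array` input whose committed sharding disagrees with
  # in_axis_resources now fails with "Sharding passed to pjit does not
  # match the sharding on the respective arg".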
91 changes: 49 additions & 42 deletions jax/experimental/pjit.py
@@ -314,11 +314,8 @@ def infer_params(*args, _global_avals=False, **kwargs):
     donated_invars = (False,) * len(args_flat)
 
     if config.jax_array:
-      in_shardings = _get_and_check_in_shardings(
-          dyn_args, in_axis_resources, pjit_mesh, in_tree)
-      out_shardings = out_axis_resources
-      _check_array_device_assignment(
-          pjit_mesh, tuple(tree_flatten(out_shardings)[0]))
+      in_shardings, out_shardings = _get_and_check_in_and_out_shardings(
+          args_flat, in_axis_resources, out_axis_resources, pjit_mesh, in_tree)
     else:
       in_shardings = tree_map(lambda x: _create_mesh_pspec_sharding(pjit_mesh, x),
                               in_axis_resources)
@@ -406,34 +403,32 @@ def hashable_pytree(pytree):
                           closure=(treedef, vals))
 
 
-def _get_and_check_in_shardings(args, pjit_in_shardings, pjit_mesh, in_tree):
-  try:
-    # tree_map over `args` to preserve the pytree structure of args.
-    arg_in_shardings = tree_map(lambda x: x.sharding, args)
-  except AttributeError:
-    arg_in_shardings = None
-
-  arg_in_shardings_flat = tuple(tree_flatten(arg_in_shardings)[0])
+def _get_and_check_in_and_out_shardings(args_flat, pjit_in_shardings, out_shardings,
+                                        pjit_mesh, in_tree):
+  arg_in_shardings_flat = tuple(a.sharding if hasattr(a, 'sharding') else _UNSPECIFIED
+                                for a in args_flat)
+  arg_ndims = tuple(a.ndim for a in args_flat)
 
   if _is_unspecified(pjit_in_shardings):
-    if arg_in_shardings is None:
-      raise ValueError('Please specify sharding either on the args or on pjit.')
-    else:
-      # This function is cached.
-      _check_array_device_assignment(pjit_mesh, arg_in_shardings_flat)
-      return arg_in_shardings
+    # If pjit_in_shardings is unspecified, then arg_in_shardings cannot have
+    # unspecified in them.
+    for a in arg_in_shardings_flat:
+      if _is_unspecified(a):
+        raise ValueError('Please specify sharding either on the arg or on '
+                         f'pjit. Found sharding {a} which is invalid.')
+    in_shardings_flat = arg_in_shardings_flat
   else:
-    if arg_in_shardings is None:
-      _check_array_device_assignment(
-          pjit_mesh, tuple(tree_flatten(pjit_in_shardings)[0]))
-      return pjit_in_shardings
-    else:
-      # This function is cached.
-      _check_pjit_arg_shardings(
-          hashable_pytree(pjit_in_shardings), arg_in_shardings_flat, in_tree)
-      return arg_in_shardings
+    # This function is cached.
+    in_shardings_flat = _get_and_check_pjit_arg_shardings(
+        hashable_pytree(pjit_in_shardings), arg_in_shardings_flat, arg_ndims,
+        in_tree)
 
-  assert False, "Please open a bug report!"  # This should be unreachable.
+  out_shardings_flat = tuple(tree_flatten(out_shardings)[0])
+  # Check if the device assignment is the same across inputs and outputs.
+  # This function is cached.
+  _check_array_device_assignment(pjit_mesh, in_shardings_flat + out_shardings_flat)
+
+  return tree_unflatten(in_tree, in_shardings_flat), out_shardings
 
 
 def _create_mesh_pspec_sharding(mesh, x):
@@ -1325,7 +1320,8 @@ def _check_array_device_assignment(pjit_mesh, shardings):
     # If mesh is empty, then check if all devices across shardings are
     # equal
     if first_device_assignment != arr_device_assignment:
-      raise ValueError("Devices of all `Array` inputs should be the same. "
+      raise ValueError("Devices of all `Array` inputs and outputs should be "
+                       "the same. "
                        f"Got array devices: {first_device_assignment},\n "
                        f"another array devices: {arr_device_assignment}")
   else:
@@ -1337,21 +1333,32 @@ def _check_array_device_assignment(pjit_mesh, shardings):
                      f"Array devices: {arr_device_assignment}")
 
 @cache()
-def _check_pjit_arg_shardings(pjit_in_shardings, arg_in_shardings_flat,
-                              in_tree):
+def _get_and_check_pjit_arg_shardings(pjit_in_shardings, arg_in_shardings_flat,
+                                      arg_ndims, in_tree):
   pjit_in_shardings_flat = flatten_axis_resources(
       "pjit in_shardings", in_tree, pjit_in_shardings(), tupled_args=True)
 
-  if pxla._check_if_any_auto(pjit_in_shardings_flat):
-    raise ValueError('Passing sharding on pjit and on args while using the '
-                     'auto spmd partitioner is not allowed. Please call the '
-                     'compiled object on the inputs.')
-
-  for p, a in safe_zip(pjit_in_shardings_flat, arg_in_shardings_flat):
-    if p.normalize() != a.normalize():
-      raise ValueError('Sharding passed to pjit does not match the sharding '
-                       'on the respective arg. '
-                       f'Got pjit sharding: {p},\narg sharding: {a}')
+  out = []
+  for pjit_sharding, arg_sharding, ndim in safe_zip(
+      pjit_in_shardings_flat, arg_in_shardings_flat, arg_ndims):
+    # If the sharding of the arg is not known, replace it with the sharding on
+    # pjit.
+    if _is_unspecified(arg_sharding):
+      out.append(pjit_sharding)
+    elif _is_auto(pjit_sharding):
+      raise ValueError('Passing sharding on pjit and on args while using the '
+                       'auto spmd partitioner is not allowed. Please call the '
+                       'compiled object on the inputs.')
+    else:
+      if pjit_sharding._to_xla_op_sharding(ndim) != arg_sharding._to_xla_op_sharding(ndim):
+        raise ValueError('Sharding passed to pjit does not match the sharding '
+                         'on the respective arg. '
+                         f'Got pjit sharding: {pjit_sharding},\n'
+                         f'arg sharding: {arg_sharding}')
+      out.append(pjit_sharding)
+
+  assert not any(_is_unspecified(o) for o in out)
+  return tuple(out)
 
 
 def _maybe_check_pjit_gda_mesh(args, mesh):
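Distilled from the two functions above, the resolution rule fits in a few lines of plain Python. The following is a sketch with string stand-ins for `Sharding` objects and a hypothetical `op_sharding` callable in place of `Sharding._to_xla_op_sharding(ndim)`; the AUTO-partitioner branch is elided, and none of these names are library code:

# Hypothetical stand-in for the merged logic of
# _get_and_check_in_and_out_shardings / _get_and_check_pjit_arg_shardings.
UNSPECIFIED = object()

def resolve_in_shardings(arg_shardings, pjit_shardings, op_sharding):
  if pjit_shardings is UNSPECIFIED:
    # No in_axis_resources: every arg must bring its own sharding.
    if any(a is UNSPECIFIED for a in arg_shardings):
      raise ValueError('Please specify sharding either on the arg or on pjit.')
    return tuple(arg_shardings)
  out = []
  for p, a in zip(pjit_shardings, arg_shardings):
    if a is UNSPECIFIED:
      out.append(p)  # committed inputs (e.g. numpy) adopt pjit's sharding
    elif op_sharding(p) != op_sharding(a):
      raise ValueError('Sharding passed to pjit does not match the sharding '
                       'on the respective arg.')
    else:
      out.append(p)
  return tuple(out)

# Mixed inputs: the numpy-like arg (UNSPECIFIED) picks up pjit's sharding,
# while the Array arg must agree with in_axis_resources.
print(resolve_in_shardings(['P(x,y)', UNSPECIFIED],
                           ['P(x,y)', 'P(None)'],
                           op_sharding=lambda s: s))
# -> ('P(x,y)', 'P(None)')

Note also that equivalence is now decided on the lowered XLA `OpSharding` (`_to_xla_op_sharding(ndim)`) rather than the previous `normalize()` comparison, so specs that spell the same partitioning differently (e.g. `P('x')` vs `P('x', None)` on a rank-2 input) should no longer trip the mismatch error.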
69 changes: 66 additions & 3 deletions tests/pjit_test.py
@@ -1451,10 +1451,28 @@ def test_non_array_input_error(self):
             out_axis_resources=MeshPspecSharding(
                 global_mesh, P('x', 'y')))
         with self.assertRaisesRegex(
-            ValueError, ('Please specify sharding either on the args or on '
-                         'pjit.')):
+            ValueError, 'Please specify sharding either on the arg or on pjit'):
           f(input_data)
 
+  def test_numpy_array_input(self):
+    input_shape = (8, 2)
+    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    input_data = np.arange(
+        prod(input_shape), dtype=np.float32).reshape(input_shape)
+    with jax._src.config.jax_array(True):
+      with global_mesh:
+        f = pjit(lambda x: x,
+                 in_axis_resources=MeshPspecSharding(
+                     global_mesh, P(None)),
+                 out_axis_resources=MeshPspecSharding(
+                     global_mesh, P('x', 'y')))
+        out = f(input_data)
+        self.assertIsInstance(out, array.Array)
+        for s in out.addressable_shards:
+          self.assertEqual(s.data.shape, (2, 1))
+          self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+        self.assertArraysEqual(out._value, input_data)
+
   def test_unspecified_out_axis_resources(self):
     global_input_shape = (8, 2)
     global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
@@ -1561,13 +1579,25 @@ def test_in_axis_resources_same_as_array_sharding(self):
           in_axis_resources=MeshPspecSharding(global_mesh, P('x' ,'y')))(input_array)
       self.assertIsInstance(out, array.Array)
 
+  def test_in_axis_resources_error(self):
+    mesh = jtu.create_global_mesh((2,), ('x'))
+    with jax._src.config.jax_array(True):
+      with self.assertRaisesRegex(
+          ValueError,
+          ('When `config.jax_array` flag is enabled, '
+           'in_axis_resources should contain instances of `Sharding` '
+           'or `pjit.AUTO`.')):
+        pjit(lambda x: x,
+             in_axis_resources=(MeshPspecSharding(mesh, P('x')),
+                                pjit_lib._UNSPECIFIED))
+
   def test_out_axis_resources_error(self):
     with jax._src.config.jax_array(True):
       with self.assertRaisesRegex(
           ValueError,
           ('When `config.jax_array` flag is enabled, '
-           'out_axis_resources should contain instances of `Sharding`.')):
+           'out_axis_resources should contain instances of `Sharding` '
+           'or `pjit.AUTO`.')):
         pjit(lambda x: x, out_axis_resources=P('x'))
 
   def test_no_input_output(self):
@@ -1630,6 +1660,39 @@ def test_array_device_assignment_mismatch_out_shardings(self):
                out_axis_resources=(MeshPspecSharding(m1, spec),
                                    MeshPspecSharding(m2, spec)))(a1, a1)
 
+  def test_array_device_assignment_mismatch_in_and_out_shardings(self):
+    input_shape = (8, 2)
+    m1 = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    m2 = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    spec = P('x', 'y')
+
+    a1, _ = create_array(input_shape, m2, spec)
+
+    with jax._src.config.jax_array(True):
+      with m1:
+        with self.assertRaisesRegex(
+            ValueError, "Pjit's devices and Array's devices should be equal"):
+          pjit(lambda x, y: (x, y),
+               in_axis_resources=MeshPspecSharding(m2, spec),
+               out_axis_resources=MeshPspecSharding(m1, spec))(a1, a1)
+
+  def test_mixed_inputs(self):
+    input_shape = (8, 2)
+    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    spec = P('x', 'y')
+
+    a1, input_data = create_array(input_shape, global_mesh, spec)
+
+    with jax._src.config.jax_array(True):
+      with global_mesh:
+        f = pjit(lambda x, y: (x, y),
+                 in_axis_resources=MeshPspecSharding(global_mesh, P(None)))
+        with self.assertRaisesRegex(
+            ValueError,
+            ('Sharding passed to pjit does not match the sharding on the '
+             'respective arg')):
+          f(input_data, a1)
+
 
 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
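One direction the new tests do not show explicitly: the unspecified-shardings branch of `_get_and_check_in_and_out_shardings` implies that when every input is an `Array` that already carries a sharding, `in_axis_resources` can be left off entirely and the arg shardings are adopted wholesale. A hedged sketch, reusing the `create_array` helper from this test file; that simply omitting the argument takes this branch is an assumption based on the error message it guards:

# Sketch only: create_array / jtu.create_global_mesh are the helpers used in
# tests/pjit_test.py above; pjit() with no in_axis_resources hitting the
# unspecified branch is assumed, not shown by these tests.
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
a1, _ = create_array((8, 2), mesh, P('x', 'y'))
with jax._src.config.jax_array(True), mesh:
  out = pjit(lambda x: x)(a1)  # in sharding inferred from a1 itself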
