Add early support in pjit for single device shardings. Also lift the restriction of needing the mesh context manager when `config.jax_array` is enabled.

PiperOrigin-RevId: 465712981
yashk2810 authored and jax authors committed Aug 6, 2022
1 parent 81b6263 commit c02359b
Showing 4 changed files with 146 additions and 69 deletions.
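For context, a minimal usage sketch (not part of the diff) of what this change enables: with the `jax_array` flag on, `pjit` no longer requires a `with mesh:` context manager and works on a single device. The global `jax.config.update('jax_array', True)` idiom is assumed here to be the flag-level equivalent of the `jax._src.config.jax_array(True)` decorator used in the tests below.

# Sketch: pjit on one device, no mesh context manager (assumes one local device
# and the `jax_array` flag available at this commit).
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit

jax.config.update('jax_array', True)  # enable jax.Array outputs for pjit

@pjit
def add(x, y):
  return x + y

out = add(jnp.array([1., 2., 3.]), jnp.array([4., 5., 6.]))
print(out)  # [5. 7. 9.]; before this change the call raised "pjit requires a non-empty mesh!"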
30 changes: 20 additions & 10 deletions jax/experimental/pjit.py
@@ -295,8 +295,13 @@ def infer_params(*args, _global_avals=False, **kwargs):
resource_env = pxla.thread_resources.env
pjit_mesh = resource_env.physical_mesh
if pjit_mesh.empty:
raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
"it's defined at the call site?")
if config.jax_array:
# Don't enforce requiring a mesh when `jax_array` flag is enabled. But
# if mesh is not empty then pjit will respect it.
pass
else:
raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
"it's defined at the call site?")

f = lu.wrap_init(fun)
f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
@@ -955,6 +960,7 @@ def _pjit_batcher_for_sharding(
return OpShardingSharding(s._device_assignment, new_op)
else:
assert isinstance(s, OpShardingSharding)
assert not mesh.empty
parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0]
parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val)
mps = MeshPspecSharding._from_parsed_pspec(mesh, parsed_pspec)
@@ -1613,14 +1619,18 @@ def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[Parti

def _get_op_sharding_from_executable(
executable) -> Tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
input_op_shardings: List[xc.OpSharding] = []
for s in executable.hlo_modules()[0].spmd_parameters_shardings:
input_op_shardings.extend(_get_op_sharding(s))

output_op_shardings: Sequence[xc.OpSharding] = _get_op_sharding(
executable.hlo_modules()[0].spmd_output_sharding)

return input_op_shardings, output_op_shardings
in_op_shardings: List[xc.OpSharding] = []
parameter_shardings_from_xla = executable.hlo_modules()[0].spmd_parameters_shardings
if parameter_shardings_from_xla is not None:
for s in parameter_shardings_from_xla:
in_op_shardings.extend(_get_op_sharding(s))

out_op_shardings: List[xc.OpSharding] = []
output_shardings_from_xla = executable.hlo_modules()[0].spmd_output_sharding
if output_shardings_from_xla is not None:
out_op_shardings = _get_op_sharding(output_shardings_from_xla) # type: ignore

return in_op_shardings, out_op_shardings


def _get_ppspec_from_executable(executable, mesh) -> Tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
14 changes: 9 additions & 5 deletions jax/experimental/sharding.py
@@ -205,6 +205,13 @@ def _to_xla_op_sharding(
return sharding_spec.sharding_proto(special_axes=special_axes)


@functools.lru_cache()
def _get_replicated_op_sharding():
proto = xc.OpSharding()
proto.type = xc.OpSharding.Type.REPLICATED
return proto


class SingleDeviceSharding(XLACompatibleSharding):

def __init__(self, device: Device):
@@ -237,9 +244,7 @@ def _device_assignment(self) -> XLADeviceAssignment:
return [self._device]

def _to_xla_op_sharding(self, num_dimensions: int) -> Optional[xc.OpSharding]:
proto = xc.OpSharding()
proto.type = xc.OpSharding.Type.REPLICATED
return proto
return _get_replicated_op_sharding()


class PmapSharding(XLACompatibleSharding):
@@ -338,6 +343,5 @@ def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:

@classmethod
def get_replicated(cls, device_assignment):
proto = xc.OpSharding()
proto.type = xc.OpSharding.Type.REPLICATED
proto = _get_replicated_op_sharding()
return cls(device_assignment, proto)
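A standalone sketch of the memoization pattern introduced above: the replicated OpSharding proto is now built once behind `functools.lru_cache` and reused by `SingleDeviceSharding._to_xla_op_sharding` and `get_replicated`, instead of constructing a fresh proto on every call. The `xla_client` import path is an assumption about this era of the codebase.

# Sketch mirroring _get_replicated_op_sharding (hypothetical standalone copy).
import functools
from jax._src.lib import xla_client as xc

@functools.lru_cache()
def replicated_op_sharding():
  proto = xc.OpSharding()
  proto.type = xc.OpSharding.Type.REPLICATED
  return proto

# Repeated calls hand back the same cached proto object.
assert replicated_op_sharding() is replicated_op_sharding()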
17 changes: 14 additions & 3 deletions jax/interpreters/pxla.py
@@ -2630,11 +2630,21 @@ def _get_input_metadata(
return shardings, input_indices, input_avals


def _get_op_sharding_shardings_from_executable(xla_executable, device_assignment):
def _get_op_sharding_shardings_from_executable(
xla_executable, device_assignment, num_in_avals, num_out_avals):
from jax.experimental import pjit
from jax.experimental.sharding import OpShardingSharding
from jax.experimental.sharding import OpShardingSharding, SingleDeviceSharding

in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)

# When the device assignment only has 1 device, SPMD partitioner will not run.
# Hence the op shardings will not be set on the `hlo_module`. In that case,
# just return SingleDeviceShardings since we know the computation is running
# only on 1 device.
if not in_op_shardings and not out_op_shardings and len(device_assignment) == 1:
return ([SingleDeviceSharding(device_assignment[0]) for _ in range(num_in_avals)],
[SingleDeviceSharding(device_assignment[0]) for _ in range(num_out_avals)])

return ([OpShardingSharding(device_assignment, i) for i in in_op_shardings],
[OpShardingSharding(device_assignment, o) for o in out_op_shardings])

@@ -2746,7 +2756,8 @@ def from_hlo(name: str,
elif out_shardings and all(_is_unspecified(o) for o in out_shardings):
assert mesh is None
in_shardings, out_shardings = _get_op_sharding_shardings_from_executable(
xla_executable, first_sharding._device_assignment)
xla_executable, first_sharding._device_assignment,
len(global_in_avals), len(global_out_avals))

in_shardings, input_indices, input_avals = _get_input_metadata(
global_in_avals, in_shardings, in_is_global) # type: ignore
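A rough check of the single-device fallback added above (assumes one local device, the `jax_array` flag, and that the experimental `Array` exposes a `.sharding` attribute in this era of the API): with a one-device assignment the SPMD partitioner is skipped, so no op shardings are attached to the HLO module and the compiled outputs fall back to `SingleDeviceSharding`.

# Hedged illustration; names outside the diff are assumptions, not the library's documented API.
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.experimental.sharding import SingleDeviceSharding

jax.config.update('jax_array', True)

out = pjit(lambda x: x @ x.T)(jnp.arange(16.0).reshape(8, 2))
assert isinstance(out.sharding, SingleDeviceSharding)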
154 changes: 103 additions & 51 deletions tests/pjit_test.py
@@ -1443,27 +1443,53 @@ class ArrayPjitTest(jtu.JaxTestCase):
('fully_sharded_output', P('x', 'y'), (2, 4)),
('fully_replicated_output', P(None), (8, 8)),
)
@jax._src.config.jax_array(True)
def test_pjit_array_single_output(self, out_axis_resources, shard_shape):
global_input_shape = (8, 2)
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mesh_axes = P('x', 'y')

input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)

with jax._src.config.jax_array(True):
with global_mesh:
f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding(
global_mesh, out_axis_resources))
expected_matrix_mul = input_data @ input_data.T
f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding(
global_mesh, out_axis_resources))
expected_matrix_mul = input_data @ input_data.T

out = f(input_array)
self.assertIsInstance(out, array.Array)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
for s in out.addressable_shards:
self.assertLen(s.data._arrays, 1)
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
self.assertArraysEqual(out._value, expected_matrix_mul)
out = f(input_array)
self.assertIsInstance(out, array.Array)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
for s in out.addressable_shards:
self.assertLen(s.data._arrays, 1)
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
self.assertArraysEqual(out._value, expected_matrix_mul)

@parameterized.named_parameters(
('fully_sharded_output', P('x', 'y'), (2, 4)),
('fully_replicated_output', P(None), (8, 8)),
)
@jax._src.config.jax_array(True)
def test_pjit_array_single_output_with_mesh_context_manager(
self, out_axis_resources, shard_shape):
global_input_shape = (8, 2)
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mesh_axes = P('x', 'y')

input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)

with global_mesh:
f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding(
global_mesh, out_axis_resources))
expected_matrix_mul = input_data @ input_data.T

out = f(input_array)
self.assertIsInstance(out, array.Array)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
for s in out.addressable_shards:
self.assertLen(s.data._arrays, 1)
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
self.assertArraysEqual(out._value, expected_matrix_mul)

def test_non_array_input_error(self):
input_shape = (8, 2)
@@ -1498,6 +1524,7 @@ def test_numpy_array_input(self):
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertArraysEqual(out._value, input_data)

@jax._src.config.jax_array(True)
def test_unspecified_out_axis_resources(self):

def _checks(out, input_data):
@@ -1516,21 +1543,20 @@ def _checks(out, input_data):

input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)

with jax._src.config.jax_array(True):
with global_mesh:
f = pjit(lambda x: x)
f = pjit(lambda x: x)

out = f(input_array)
_checks(out, input_data)
out = f(input_array)
_checks(out, input_data)

out2 = f(out)
_checks(out2, input_data)
out2 = f(out)
_checks(out2, input_data)

@parameterized.named_parameters(
('mesh1', (4, 2), (2, 1), (2, 2), (1, 2), (8, 2)),
('mesh2', (2, 2), (4, 1), (4, 2), (2, 2), (8, 2)),
('mesh3', (2, 1), (4, 2), (4, 2), (4, 2), (8, 2)),
)
@jax._src.config.jax_array(True)
def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape,
s2_shape, s3_shape, s4_shape):
# Disable on SE runtime type because XLA sharding propagation is not
@@ -1549,37 +1575,35 @@ def test_pjit_array_multi_input_multi_output(self, mesh_shape,
spec4 = P(None)
a4, _ = create_array(global_input_shape, global_mesh, spec4)

with jax._src.config.jax_array(True):
with global_mesh:
@pjit
def f(tree):
return tree
out_tree = f((a1, (a2, (a3, a4))))
(out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)

self.assertIsInstance(out1, array.Array)
self.assertEqual(out1.shape, (8, 2))
self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
for s in out1.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

self.assertIsInstance(out2, array.Array)
self.assertEqual(out2.shape, (8, 2))
self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
for s in out2.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

self.assertIsInstance(out3, array.Array)
self.assertEqual(out3.shape, (8, 2))
self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
for s in out3.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

self.assertIsInstance(out4, array.Array)
self.assertEqual(out4.shape, (8, 2))
self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
for s in out4.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data)
@pjit
def f(tree):
return tree
out_tree = f((a1, (a2, (a3, a4))))
(out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)

self.assertIsInstance(out1, array.Array)
self.assertEqual(out1.shape, (8, 2))
self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
for s in out1.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

self.assertIsInstance(out2, array.Array)
self.assertEqual(out2.shape, (8, 2))
self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
for s in out2.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

self.assertIsInstance(out3, array.Array)
self.assertEqual(out3.shape, (8, 2))
self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
for s in out3.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

self.assertIsInstance(out4, array.Array)
self.assertEqual(out4.shape, (8, 2))
self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
for s in out4.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data)

def test_in_axis_resources_mismatch_error(self):
global_input_shape = (8, 2)
@@ -1736,6 +1760,34 @@ def test_pjit_array_same_sharding_aot(self):
compiled = f.lower(jax.ShapedArray(input_shape, jnp.float32)).compile()
compiled(a1) # no error

@jax._src.config.jax_array(True)
def test_pjit_single_device_sharding_add(self):
a = jnp.array([1, 2, 3], dtype=jnp.float32)
b = jnp.array([4, 5, 6], dtype=jnp.float32)

@pjit
def add(x, y):
return x + y
out = add(a, b)
self.assertIsInstance(out, array.Array)
self.assertArraysEqual(out, a + b)

out2 = add(out, out)
self.assertIsInstance(out2, array.Array)
self.assertArraysEqual(out2, 2 * (a + b))

@jax._src.config.jax_array(True)
def test_pjit_single_device_sharding_mul(self):
a = jnp.arange(16).reshape((8, 2))

@pjit
def mul(x):
return x @ x.T

out = mul(a)
self.assertIsInstance(out, array.Array)
self.assertArraysEqual(out, a @ a.T)


def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
