Internal change
PiperOrigin-RevId: 409591497
jax authors committed Nov 13, 2021
1 parent e94cc97 commit 6fa860d
Showing 3 changed files with 207 additions and 52 deletions.
6 changes: 6 additions & 0 deletions jax/_src/config.py
@@ -498,6 +498,12 @@ def update_thread_local_jit_state(**kw):
'option is set, the log level is WARNING; otherwise the level is '
'DEBUG.'))

gsda_out = config.define_bool_state(
name='jax_gsda_out',
default=False,
help='If True, pjit will output GSDAs.')
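(Annotation, not part of the diff: a flag defined via define_bool_state can typically be flipped either globally or within a scope. A minimal sketch, assuming the standard jax.config machinery; the context-manager form is the one exercised by the new tests in pjit_test.py below.)

import jax
from jax._src import config as jax_src_config

# Global toggle (standard pattern for flags created via define_bool_state).
jax.config.update('jax_gsda_out', True)

# Scoped toggle, the form used by the new tests in pjit_test.py.
with jax_src_config.gsda_out(True):
    pass  # pjit calls in this block would return GlobalShardedDeviceArrays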


distributed_debug = config.define_bool_state(
name='jax_distributed_debug',
default=False,
127 changes: 92 additions & 35 deletions jax/interpreters/pxla.py
@@ -412,41 +412,94 @@ def _shard_abstract_array(size, axis: int, x):
return x.update(shape=tuple_delete(x.shape, axis))
shard_aval_handlers[ShapedArray] = _shard_abstract_array

def aval_to_result_handler(sharding_spec: Optional[ShardingSpec],
indices: Optional[Tuple[Index]],
aval: core.AbstractValue) -> Callable[
[List[xb.xla_client.Buffer]], Any]:
MeshAxisName = Any
"""
ArrayMapping specifies how an ndarray should map to mesh axes.
Note that the ordering is crucial when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.
For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]
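(Annotation, not part of the diff: a small numpy sketch of the major-to-minor ordering described in the docstring above, using its mesh of shape {'x': 2, 'y': 3}.)

import numpy as np

devices = np.arange(6).reshape(2, 3)  # mesh {'x': 2, 'y': 3}

# Mapping {'x': 1, 'y': 1}: 'x' is major, 'y' minor, so the flat chunk
# list lines up with the flattened device grid as-is.
order_xy = devices.reshape(-1)        # [0 1 2 3 4 5]

# Mapping {'y': 1, 'x': 1}: 'y' is major, so the device grid must be
# transposed before flattening, as the docstring notes.
order_yx = devices.T.reshape(-1)      # [0 3 1 4 2 5]
# Chunk i of the value's second axis (6 chunks) is assigned to order[i].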

AxisResource = Tuple[Optional[Tuple[Any, ...]], ...]

def array_mapping_to_axis_resources(array_mapping: ArrayMapping) -> AxisResource:
if not array_mapping:
return tuple()
max_index = array_mapping[max(array_mapping, key=array_mapping.get)] # type: ignore
reverse_map = defaultdict(list)
for axis, index in array_mapping.items():
reverse_map[index].append(axis)
return tuple(
tuple(reverse_map[i]) if reverse_map[i] else None for i in range(max_index + 1)
)
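(Annotation, not part of the diff: a usage sketch; the expected output matches the "skip" case in the new test_array_mapping_to_axis_resources test below.)

from collections import OrderedDict
from jax.interpreters import pxla

# 'x' and 'y' both shard positional axis 0; 'z' shards axis 2. Axis 1 has
# no mesh axis, so it comes back as None.
mapping = OrderedDict([('x', 0), ('y', 0), ('z', 2)])
assert pxla.array_mapping_to_axis_resources(mapping) == (('x', 'y'), None, ('z',))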

def aval_to_result_handler(
sharding_spec: Optional[ShardingSpec],
indices: Optional[Tuple[Index]],
aval: core.AbstractValue,
global_aval: Optional[ShapedArray] = None,
out_axis_resources: Optional[AxisResource] = None,
global_mesh = None,
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
Args:
sharding_spec: indicates how the output is sharded across devices, or None
sharding_spec: Indicates how the output is sharded across devices, or None
for non-array avals.
indices: the pre-computed result of spec_to_indices, or None for non-array
indices: The pre-computed result of spec_to_indices, or None for non-array
avals.
aval: the output AbstractValue.
aval: The output AbstractValue.
global_aval: Global output AbstractValue. Used for creating GSDAs.
out_axis_resources: A tuple specifying the sharding of outputs.
Used for creating GSDAs.
global_mesh: The global device mesh that generated this output. Used
for creating GSDAs.
Returns:
A function for handling the Buffers that will eventually be produced
for this output. The function will return an object suitable for returning
to the user, e.g. a ShardedDeviceArray.
"""
try:
return pxla_result_handlers[type(aval)](sharding_spec, indices, aval)
return pxla_result_handlers[type(aval)](sharding_spec, indices, aval,
global_aval, out_axis_resources, global_mesh)
except KeyError as err:
raise TypeError("No pxla_result_handler for type: {}".format(type(aval))
) from err

PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client.Buffer]], Any]]
pxla_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
pxla_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
def array_result_handler(sharding_spec, indices, aval: ShapedArray):
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
indices)

def array_result_handler(sharding_spec, indices, aval: ShapedArray, global_aval,
out_axis_resources, global_mesh):
if config.jax_gsda_out:
return gsda_array_result_handler(global_aval, global_mesh, out_axis_resources)
else:
return sda_array_result_handler(sharding_spec, indices, aval)

pxla_result_handlers[ShapedArray] = array_result_handler
pxla_result_handlers[ConcreteArray] = array_result_handler

def sda_array_result_handler(sharding_spec, indices, aval: ShapedArray):
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
indices)

def gsda_array_result_handler(global_aval, global_mesh, out_axis_resources):
from ..experimental.gsda import GlobalShardedDeviceArray

return lambda bufs: GlobalShardedDeviceArray(
global_aval.shape, global_mesh, out_axis_resources, bufs)
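(Annotation, not part of the diff: schematically, the two-stage design works as sketched below. The first call selects a handler per output aval; the returned closure is applied to the raw buffers later. Here spec, indices, aval, global_aval, out_axis, mesh, and bufs are stand-ins for values already in scope at the call sites.)

# Sketch only: handler selection happens once, buffer handling per execution.
handler = aval_to_result_handler(spec, indices, aval,
                                 global_aval, out_axis, mesh)
out = handler(bufs)  # an SDA by default; a GSDA when jax_gsda_out is enabled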

### lazy device-memory persistence and result handling

@@ -1172,12 +1225,31 @@ def __init__(self, handlers, out_specs, out_indices, unmapped_local_out_avals):
def __call__(self, out_bufs):
return [h(bufs) for h, bufs in safe_zip(self.handlers, out_bufs)]

def avals_to_results_handler(nrep, npart, out_specs, unmapped_local_out_avals):

def avals_to_results_handler(
nrep,
npart,
out_specs,
unmapped_local_out_avals,
global_out_avals: Optional[Sequence[ShapedArray]] = None,
out_axis_resources: Optional[Sequence[AxisResource]] = None,
global_mesh=None):
out_indices = [spec_to_indices(aval.shape, spec)
if aval is not core.abstract_unit else None
for aval, spec in safe_zip(unmapped_local_out_avals, out_specs)] # pytype: disable=attribute-error
handlers = [aval_to_result_handler(spec, idcs, aval)
for spec, idcs, aval in safe_zip(out_specs, out_indices, unmapped_local_out_avals)]
if global_out_avals and out_axis_resources and global_mesh:
handlers = [
aval_to_result_handler(spec, idcs, aval, global_aval, out_axis, global_mesh)
for spec, idcs, aval, global_aval, out_axis in safe_zip(
out_specs, out_indices, unmapped_local_out_avals,
global_out_avals, out_axis_resources)
]
else:
handlers = [
aval_to_result_handler(spec, idcs, aval)
for spec, idcs, aval in safe_zip(out_specs, out_indices,
unmapped_local_out_avals)
]

return ResultsHandler(handlers, out_specs, out_indices, unmapped_local_out_avals)

@@ -1397,24 +1469,6 @@ def _unravel_index(c, axis_env):

# ------------------- xmap -------------------

MeshAxisName = Any
"""
ArrayMapping specifies how an ndarray should map to mesh axes.
Note that the ordering is crucial when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.
For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]

class Mesh:

def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]):
@@ -1670,8 +1724,8 @@ def lower_mesh_computation(
built = c.Build(out_tuple)
return MeshComputation(
built, mesh, local_in_untiled_avals,
local_out_untiled_avals, in_axes, out_axes,
spmd_lowering, tuple_args)
local_out_untiled_avals, (out_jaxpr_avals if spmd_lowering else None),
in_axes, out_axes, spmd_lowering, tuple_args)


class MeshComputation:
@@ -1703,6 +1757,7 @@ def __init__(self,
mesh: Mesh,
local_in_untiled_avals: Sequence[ShapedArray],
local_out_untiled_avals: Sequence[ShapedArray],
global_out_avals: Optional[Sequence[ShapedArray]],
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
spmd_lowering: bool, tuple_args: bool,
@@ -1744,8 +1799,10 @@ def __init__(self,

local_output_specs = [local_sharding_spec(aval, aval_out_axes)
for aval, aval_out_axes in safe_zip(local_out_untiled_avals, out_axes)]
out_axis_resources = [array_mapping_to_axis_resources(o) for o in out_axes]
handle_outs = avals_to_results_handler(num_local_replicas, num_local_partitions,
local_output_specs, local_out_untiled_avals)
local_output_specs, local_out_untiled_avals,
global_out_avals, out_axis_resources, mesh)

if _allow_compile_replicated and hasattr(backend, "compile_replicated"):
self.unsafe_call = backend.compile_replicated(
126 changes: 109 additions & 17 deletions tests/pjit_test.py
@@ -597,16 +597,99 @@ def cb(index):
gsda_obj = gsda.GlobalShardedDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)

@partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
def f(x):
return x @ x.T
with jax._src.config.gsda_out(True):
@partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
def f(x):
return x @ x.T
expected_matrix_mul = input_data @ input_data.T

out = f(gsda_obj)
self.assertIsInstance(out, gsda.GlobalShardedDeviceArray)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.local_shards[0].data.shape, (2, 4))
self.assertDictEqual(out._global_mesh.shape, {'x': 4, 'y': 2})
for s in out.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])

out1 = f(input_data)
self.assertIsInstance(out1, gsda.GlobalShardedDeviceArray)
self.assertEqual(out1.shape, (8, 8))
self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
for s in out1.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])

@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gsda_multi_input_multi_output(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]

mesh_axes1 = P('x', 'y')
gsda1 = gsda.GlobalShardedDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes1, cb)
mesh_axes2 = P('x')
gsda2 = gsda.GlobalShardedDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes2, cb)
mesh_axes3 = P(('x', 'y'))
gsda3 = gsda.GlobalShardedDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes3, cb)
mesh_axes4 = P(None)
gsda4 = gsda.GlobalShardedDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes4, cb)

with jax._src.config.gsda_out(True):
@partial(
pjit,
in_axis_resources=(mesh_axes1, mesh_axes2, mesh_axes3, mesh_axes4),
out_axis_resources=(mesh_axes1, mesh_axes4, mesh_axes2, mesh_axes3))
def f(x, y, z, a):
return x @ x.T, y, z, a
out1, out2, out3, out4 = f(gsda1, gsda2, gsda3, gsda4)

self.assertIsInstance(out1, gsda.GlobalShardedDeviceArray)
self.assertEqual(out1.shape, (8, 8))
self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
self.assertEqual(out1.local_shards[0].index, (slice(0, 2), slice(0, 4)))
self.assertEqual(out1.local_shards[1].index, (slice(0, 2), slice(4, 8)))
self.assertListEqual([s.replica_id for s in out1.local_shards],
[0, 0, 0, 0, 0, 0, 0, 0])
expected_matrix_mul = input_data @ input_data.T
for s in out1.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])

self.assertIsInstance(out2, gsda.GlobalShardedDeviceArray)
self.assertEqual(out2.shape, (8, 2))
self.assertEqual(out2.local_shards[0].data.shape, (8, 2))
self.assertEqual(out2.local_shards[0].index, (slice(None), slice(None)))
self.assertEqual(out2.local_shards[1].index, (slice(None), slice(None)))
self.assertListEqual([s.replica_id for s in out2.local_shards],
[0, 1, 2, 3, 4, 5, 6, 7])
for s in out2.local_shards:
self.assertArraysEqual(s.data, input_data)

self.assertIsInstance(out3, gsda.GlobalShardedDeviceArray)
self.assertEqual(out3.shape, (8, 2))
self.assertEqual(out3.local_shards[0].data.shape, (2, 2))
self.assertEqual(out3.local_shards[0].index, (slice(0, 2), slice(None)))
self.assertEqual(out3.local_shards[1].index, (slice(0, 2), slice(None)))
self.assertListEqual([s.replica_id for s in out3.local_shards],
[0, 1, 0, 1, 0, 1, 0, 1])
for s in out3.local_shards:
self.assertArraysEqual(s.data, input_data[s.index])

self.assertIsInstance(out4, gsda.GlobalShardedDeviceArray)
self.assertEqual(out4.shape, (8, 2))
self.assertEqual(out4.local_shards[0].data.shape, (1, 2))
self.assertEqual(out4.local_shards[0].index, (slice(0, 1), slice(None)))
self.assertEqual(out4.local_shards[1].index, (slice(1, 2), slice(None)))
self.assertListEqual([s.replica_id for s in out4.local_shards],
[0, 0, 0, 0, 0, 0, 0, 0])
for s in out4.local_shards:
self.assertArraysEqual(s.data, input_data[s.index])

out = f(gsda_obj)
# TODO(yashkatariya): Enable the gsda_out flag and check for GSDA as the
# output.
self.assertIsInstance(out, pxla.ShardedDeviceArray)
self.assertLen(out.device_buffers, 8)
self.assertEqual(out.device_buffers[0].shape, (2, 4))

@jtu.with_mesh([('x', 2), ('y', 2)])
def test_pjit_gsda_mesh_mismatch(self):
@@ -820,15 +903,15 @@ def testAxisResourcesMismatch(self):
pjit(lambda x, y: x, p, p)(x, x) # Error, but make sure we hint at tupling
# TODO(apaszke): Disable implicit list casts and enable this
# error = re.escape(
# r"pjit in_axis_resources specification must be a tree prefix of the "
# r"corresponding value, got specification (None, None, None) for value "
# r"tree PyTreeDef(([*, *, *],)). Note that pjit in_axis_resources that "
# r"are non-trivial pytrees should always be wrapped in a tuple representing "
# r"the argument list. In particular, you're passing in a single argument "
# r"which means that pjit in_axis_resources might need to be wrapped in a "
# r"singleton tuple.")
# r"pjit in_axis_resources specification must be a tree prefix of the "
# r"corresponding value, got specification (None, None, None) for value "
# r"tree PyTreeDef(([*, *, *],)). Note that pjit in_axis_resources that "
# r"are non-trivial pytrees should always be wrapped in a tuple representing "
# r"the argument list. In particular, you're passing in a single argument "
# r"which means that pjit in_axis_resources might need to be wrapped in a "
# r"singleton tuple.")
# with self.assertRaisesRegex(ValueError, error):
# pjit(lambda x: x, p, p)([x, x, x]) # Error, but make sure we hint at singleton tuple
# pjit(lambda x: x, p, p)([x, x, x]) # Error, but make sure we hint at singleton tuple
error = re.escape(
r"pjit out_axis_resources specification must be a tree prefix of the "
r"corresponding value, got specification [[None, None, None], None] for "
@@ -853,6 +936,7 @@ def h(x):


class UtilTest(jtu.JaxTestCase):

def testOpShardingRoundTrip(self):
FakeDevice = namedtuple('FakeDevice', ['id'])
mesh_named_shape = OrderedDict([('a', 2), ('b', 3), ('c', 4), ('d', 7), ('e', 4)])
@@ -879,6 +963,14 @@ def roundtrip(spec):
spec[rng.choice(dims)] += (axis,)
roundtrip(P(*spec))

@parameterized.named_parameters(
("linear", {'x': 0, 'y': 1, 'z': 2}, (('x',), ('y',), ('z',))),
("combine", {'x': 0, 'y': 0, 'z': 1}, (('x', 'y'), ('z',))),
("skip", {'x': 0, 'y': 0, 'z': 2}, (('x', 'y'), None, ('z',))),
("multi_skip", {'x': 0, 'y': 1, 'z': 3}, (('x',), ('y',), None, ('z',))),
)
def test_array_mapping_to_axis_resources(self, inp, expected_out):
self.assertEqual(pxla.array_mapping_to_axis_resources(inp), expected_out)

if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
