Skip to content

Commit

Permalink
Make mesh_axes on GDA strict by only allowing PartitionSpecs to be …
Browse files Browse the repository at this point in the history
…consistent with pjit.

PiperOrigin-RevId: 432957496
  • Loading branch information
yashk2810 authored and jax authors committed Mar 7, 2022
1 parent 17f11e0 commit 99a1037
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 73 deletions.
11 changes: 6 additions & 5 deletions jax/experimental/gda_serialization/serialization_test.py
Expand Up @@ -21,6 +21,7 @@
from jax._src import test_util as jtu
from jax._src import util
from jax.config import config
from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.gda_serialization import serialization
from jax.experimental.maps import Mesh
Expand All @@ -43,7 +44,7 @@ class CheckpointTest(jtu.JaxTestCase):
def test_checkpointing(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x', 'y']
mesh_axes = P('x', 'y')
num = util.prod(global_input_shape)

# First GDA
Expand All @@ -66,7 +67,7 @@ def cb2(index):
def cb3(index):
return np.array([])
global_mesh1d = create_global_mesh((8,), ('x',))
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, [None], cb3)
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, P(None), cb3)
ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)

ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
Expand All @@ -76,7 +77,7 @@ def cb3(index):

m1, m2, m3 = serialization.run_deserialization(
[global_mesh, global_mesh, global_mesh1d],
[mesh_axes, ['x'], [None]],
[mesh_axes, P('x'), P(None)],
tspecs)

self.assertArraysEqual(m1.local_shards[0].data.to_py(),
Expand Down Expand Up @@ -109,7 +110,7 @@ def test_checkpointing_with_bigger_shape(self):
def cb1(index):
return global_input_data1[index]
gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
['x', 'y'], cb1)
P('x', 'y'), cb1)
ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

ckpt_paths = [str(ckpt_dir1)]
Expand All @@ -119,7 +120,7 @@ def cb1(index):

m1, = serialization.run_deserialization(
[create_global_mesh((4, 2), ('x', 'y'))],
[['x', 'y']],
[P('x', 'y')],
tspecs,
[(12, 2)],
)
Expand Down
31 changes: 17 additions & 14 deletions jax/experimental/global_device_array.py
Expand Up @@ -27,7 +27,7 @@
from jax.interpreters.sharded_jit import PartitionSpec

Shape = Tuple[int, ...]
MeshAxes = Sequence[Union[str, Tuple[str], None]]
MeshAxes = PartitionSpec
DeviceArray = xc.Buffer
Device = xc.Device
ArrayLike = Union[np.ndarray, DeviceArray]
Expand All @@ -50,11 +50,7 @@ def _get_array_mapping(mesh_axes):
# Import here to avoid cyclic import error when importing gda in pjit.py.
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

if not isinstance(mesh_axes, PartitionSpec):
pspec = PartitionSpec(*mesh_axes)
else:
pspec = mesh_axes
parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
parsed_pspec, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
return get_array_mapping(parsed_pspec)


Expand Down Expand Up @@ -297,7 +293,7 @@ def __init__(self, global_shape: Shape, global_mesh: pxla.Mesh,

self._local_shards = self._create_local_shards()

ss = get_shard_shape(self._global_shape, self._global_mesh, self._mesh_axes)
ss = get_shard_shape(self._global_shape, self._global_mesh, self.mesh_axes)
assert all(db.shape == ss for db in device_buffers), (
f"Expected shard shape {ss} doesn't match the device buffer "
f"shape, got: {[db.shape for db in device_buffers]}")
Expand All @@ -322,8 +318,8 @@ def __str__(self):

def __repr__(self):
return (f'GlobalDeviceArray(shape={self.shape}, dtype={self.dtype}, '
f'global_mesh_shape={dict(self._global_mesh.shape)}, '
f'mesh_axes={self._mesh_axes})')
f'global_mesh_shape={dict(self.mesh.shape)}, '
f'mesh_axes={self.mesh_axes})')

@property
def shape(self) -> Shape:
Expand All @@ -341,6 +337,10 @@ def size(self):
def mesh(self):
return self._global_mesh

@property
def mesh_axes(self) -> MeshAxes:
return self._mesh_axes

@property
def is_fully_replicated(self) -> bool:
return self.shape == self.local_data(0).shape
Expand All @@ -350,7 +350,7 @@ def _create_local_shards(self) -> Sequence[Shard]:
global_indices_rid = self._gda_fast_path_args.global_indices_replica_ids
else:
global_indices_rid = get_shard_indices_replica_ids(
self._global_shape, self._global_mesh, self._mesh_axes)
self._global_shape, self._global_mesh, self.mesh_axes)

out = []
for db in self._device_buffers:
Expand Down Expand Up @@ -379,7 +379,7 @@ def global_shards(self) -> Sequence[Shard]:
# Also as this a cached property, once calculated, it should be cached. So
# multiple accesses should be cheap.
global_indices_rid = get_shard_indices_replica_ids(
self._global_shape, self._global_mesh, self._mesh_axes)
self._global_shape, self._global_mesh, self.mesh_axes)
device_to_buffer = dict((db.device(), db) for db in self._device_buffers)
global_shards = []
for device, (index, rid) in global_indices_rid.items():
Expand Down Expand Up @@ -410,10 +410,11 @@ def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
Example:
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> import numpy as np
...
>>> global_input_shape = (8, 8)
>>> mesh_axes = ['x', 'y']
>>> mesh_axes = P('x', 'y')
>>> global_mesh = global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
...
Expand Down Expand Up @@ -456,10 +457,11 @@ def from_batched_callback(cls, global_shape: Shape,
Example:
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> import numpy as np
...
>>> global_input_shape = (8, 2)
>>> mesh_axes = ['x']
>>> mesh_axes = P('x')
>>> global_mesh = global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
...
Expand Down Expand Up @@ -502,10 +504,11 @@ def from_batched_callback_with_devices(
Example:
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> import numpy as np
...
>>> global_input_shape = (8, 2)
>>> mesh_axes = [('x', 'y')]
>>> mesh_axes = P(('x', 'y'))
>>> global_mesh = global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
...
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/maps.py
Expand Up @@ -2028,7 +2028,7 @@ def _check_gda_xmap_partitioning(axis_resources, resource_env,
f"mesh: {resource_env.physical_mesh},\n"
f"GDA mesh: {arg.mesh}")

gda_array_mapping = _get_array_mapping(arg._mesh_axes)
gda_array_mapping = _get_array_mapping(arg.mesh_axes)
if gda_array_mapping != xmap_array_mapping:
raise ValueError(
"Got an input GDA to xmap with different partitioning than "
Expand Down
11 changes: 3 additions & 8 deletions jax/experimental/pjit.py
Expand Up @@ -1053,7 +1053,9 @@ def _create_cpspec(x):
def _maybe_replace_from_gda_with_pspec(
in_axis_resources_flat: CanonicalizedParsedPartitionSpec, arg) -> CanonicalizedParsedPartitionSpec:
if isinstance(arg, GDA):
gda_cpspec = gda_mesh_axes_to_canonicalized_parsed_pspec(arg._mesh_axes)
gda_cpspec = CanonicalizedParsedPartitionSpec(
ParsedPartitionSpec.from_user_input(
arg.mesh_axes, arg_name="GDA mesh_axes"))
assert type(gda_cpspec) is CanonicalizedParsedPartitionSpec
if (not _is_from_gda(in_axis_resources_flat) and
in_axis_resources_flat != gda_cpspec):
Expand All @@ -1066,13 +1068,6 @@ def _maybe_replace_from_gda_with_pspec(
return gda_cpspec
return in_axis_resources_flat

def gda_mesh_axes_to_canonicalized_parsed_pspec(mesh_axes) -> CanonicalizedParsedPartitionSpec:
if not isinstance(mesh_axes, PartitionSpec):
pspec = PartitionSpec(*mesh_axes)
else:
pspec = mesh_axes
return CanonicalizedParsedPartitionSpec(ParsedPartitionSpec.from_user_input(
pspec, arg_name='GDA mesh_axes'))

def _maybe_check_pjit_gda_mesh(args, mesh):
for x in args:
Expand Down
21 changes: 12 additions & 9 deletions jax/interpreters/pxla.py
Expand Up @@ -414,20 +414,24 @@ def _shard_abstract_array(size, axis: int, x):
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]

AxisResource = Tuple[Optional[Tuple[Any, ...]], ...]

def array_mapping_to_axis_resources(array_mapping: ArrayMapping) -> AxisResource:
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
# TODO(yashkatariya): Move PartitionSpec into a place where all files can
# import it without cyclic dependency.
from jax.interpreters.sharded_jit import PartitionSpec

if not array_mapping:
return tuple()
return PartitionSpec()
max_index = -1
reverse_map = defaultdict(list)
for axis, index in array_mapping.items():
reverse_map[index].append(axis)
if index > max_index:
max_index = index
return tuple(
tuple(reverse_map[i]) if reverse_map[i] else None for i in range(max_index + 1)
)
partitions = tuple(tuple(reverse_map[i]) if reverse_map[i] else None
for i in range(max_index + 1))
return PartitionSpec(*partitions)


def local_aval_to_result_handler(
aval: core.AbstractValue,
Expand Down Expand Up @@ -465,14 +469,13 @@ def sda_array_result_handler(aval: ShapedArray, sharding_spec, indices):


def global_aval_to_result_handler(
aval: core.AbstractValue,
out_axis_resources: Optional[AxisResource], global_mesh,
aval: core.AbstractValue, out_axis_resources, global_mesh,
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
Args:
aval: The global output AbstractValue.
out_axis_resources: A tuple specifying the sharding of outputs.
out_axis_resources: A PartitionSpec specifying the sharding of outputs.
Used for creating GSDAs.
global_mesh: The global device mesh that generated this output. Used
for creating GSDAs.
Expand Down

0 comments on commit 99a1037

Please sign in to comment.