Skip to content

Commit

Permalink
Xmap GDA integration. Non-contiguous mesh is allowed!
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 429376557
  • Loading branch information
yashk2810 authored and jax authors committed Feb 17, 2022
1 parent 83a5020 commit 6bb58e6
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 98 deletions.
16 changes: 8 additions & 8 deletions jax/experimental/global_device_array.py
Expand Up @@ -46,21 +46,21 @@ def wrapper(*args, **kwargs):
return wrapper


def _canonicalize_mesh_axes(mesh_axes):
def _get_array_mapping(mesh_axes):
# Import here to avoid cyclic import error when importing gda in pjit.py.
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

if not isinstance(mesh_axes, PartitionSpec):
pspec = PartitionSpec(*mesh_axes)
else:
pspec = mesh_axes
return pspec
parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
return get_array_mapping(parsed_pspec)


def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes) -> Tuple[Index, ...]:
# Import here to avoid cyclic import error when importing gda in pjit.py.
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

pspec = _canonicalize_mesh_axes(mesh_axes)
parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
array_mapping = get_array_mapping(parsed_pspec)
array_mapping = _get_array_mapping(mesh_axes)
# The dtype doesn't matter for creating sharding specs.
aval = core.ShapedArray(global_shape, np.float32)
sharding_spec = pxla.mesh_sharding_specs(
Expand Down

0 comments on commit 6bb58e6

Please sign in to comment.