Skip to content

Commit

Permalink
Remove isinstance checks
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 425745786
  • Loading branch information
yashk2810 authored and jax authors committed Feb 2, 2022
1 parent dcca99b commit 3acbd44
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions jax/experimental/global_device_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _canonicalize_mesh_axes(mesh_axes):
return pspec

def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes) -> Tuple[pxla.Index, ...]:
mesh_axes: MeshAxes) -> Tuple[Index, ...]:
# Import here to avoid cyclic import error when importing gda in pjit.py.
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

Expand All @@ -66,11 +66,7 @@ def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
sharding_spec = pxla.mesh_sharding_specs(
global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
indices = pxla.spec_to_indices(global_shape, sharding_spec)
for index in indices:
assert isinstance(index, tuple)
for idx in index:
assert isinstance(idx, slice)
return indices
return indices # type: ignore


@_convert_list_args_to_tuple
Expand Down

0 comments on commit 3acbd44

Please sign in to comment.