Skip to content

Commit

Permalink
Method for obtaining the sharding of jax.Array-like objects.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 496394589
  • Loading branch information
jpuigcerver authored and Copybara-Service committed Dec 19, 2022
1 parent 7c168db commit c28c6fa
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions vmoe/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@
UnparsedPartitionSpec = Union[str, Tuple[Union[str, Tuple[str, ...]], ...]]


def get_array_sharding_or_default(arr: jax.Array) -> sharding.Sharding:
if hasattr(arr, 'sharding'):
return arr.sharding
else:
op_sharding = jax.xla.xc.OpSharding()
op_sharding.type = jax.xla.xc.OpSharding.Type.REPLICATED
return sharding.OpShardingSharding(jax.devices(), op_sharding)


def process_has_contiguous_device_slice(devices: np.ndarray,
process_index: int) -> bool:
"""Checks if the devices of a process form a contiguous slice in the mesh."""
Expand Down

0 comments on commit c28c6fa

Please sign in to comment.