Skip to content

Commit

Permalink
Add a backwards compat path for op_sharding.clone() because it does…
Browse files Browse the repository at this point in the history
…n't exist with the latest jaxlib on pypi

PiperOrigin-RevId: 477034758
  • Loading branch information
yashk2810 authored and jax authors committed Sep 27, 2022
1 parent cbf34cb commit 389a2e5
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions jax/_src/prng.py
Expand Up @@ -43,8 +43,9 @@
from jax._src.numpy import lax_numpy
import jax._src.pretty_printer as pp
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip

from jax._src.lib import gpu_prng
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Expand Down Expand Up @@ -296,11 +297,24 @@ def physical_op_sharding(aval, sharding):
op_sharding = sharding._to_xla_op_sharding(aval.ndim)
key_shape = aval.dtype.impl.key_shape

new_op_sharding = op_sharding.clone()
tad = list(new_op_sharding.tile_assignment_dimensions)
tad.extend([1] * len(key_shape))
new_op_sharding.tile_assignment_dimensions = tad
return new_op_sharding
if xla_extension_version >= 83:
new_op_sharding = op_sharding.clone()
tad = list(new_op_sharding.tile_assignment_dimensions)
tad.extend([1] * len(key_shape))
new_op_sharding.tile_assignment_dimensions = tad
return new_op_sharding
else:
# TODO(yashkatariya): This is hacky. Remove this once
# minimum_jaxlib_version is bumped to 0.3.19
new_op_sharding = xc.OpSharding()
new_op_sharding.type = op_sharding.type
new_op_sharding.tile_assignment_devices = op_sharding.tile_assignment_devices
tad = list(op_sharding.tile_assignment_dimensions)
tad.extend([1] * len(key_shape))
new_op_sharding.tile_assignment_dimensions = tad
new_op_sharding.last_tile_dims = op_sharding.last_tile_dims
new_op_sharding.replicate_on_last_tile_dim = op_sharding.replicate_on_last_tile_dim
return new_op_sharding

@staticmethod
def result_handler(sticky_device, aval):
Expand Down

0 comments on commit 389a2e5

Please sign in to comment.