diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 136c6b9e66a3..e4cd19b7f577 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -43,8 +43,9 @@ from jax._src.numpy import lax_numpy import jax._src.pretty_printer as pp from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip - from jax._src.lib import gpu_prng +from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -296,11 +297,24 @@ def physical_op_sharding(aval, sharding): op_sharding = sharding._to_xla_op_sharding(aval.ndim) key_shape = aval.dtype.impl.key_shape - new_op_sharding = op_sharding.clone() - tad = list(new_op_sharding.tile_assignment_dimensions) - tad.extend([1] * len(key_shape)) - new_op_sharding.tile_assignment_dimensions = tad - return new_op_sharding + if xla_extension_version >= 83: + new_op_sharding = op_sharding.clone() + tad = list(new_op_sharding.tile_assignment_dimensions) + tad.extend([1] * len(key_shape)) + new_op_sharding.tile_assignment_dimensions = tad + return new_op_sharding + else: + # TODO(yashkatariya): This is hacky. Remove this once + # minimum_jaxlib_version is bumped to 0.3.19 + new_op_sharding = xc.OpSharding() + new_op_sharding.type = op_sharding.type + new_op_sharding.tile_assignment_devices = op_sharding.tile_assignment_devices + tad = list(op_sharding.tile_assignment_dimensions) + tad.extend([1] * len(key_shape)) + new_op_sharding.tile_assignment_dimensions = tad + new_op_sharding.last_tile_dims = op_sharding.last_tile_dims + new_op_sharding.replicate_on_last_tile_dim = op_sharding.replicate_on_last_tile_dim + return new_op_sharding @staticmethod def result_handler(sticky_device, aval):