Bump minimum jaxlib version to 0.4.6, which means xla_extension_version == 137 and mlir_api_version == 45

PiperOrigin-RevId: 516364523
yashk2810 authored and jax authors committed Mar 14, 2023
1 parent 31aee3a commit 136749d
Showing 12 changed files with 49 additions and 117 deletions.
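
Context for the deletions below: every removed branch is a jaxlib-compatibility gate that is now dead code, since the new minimum jaxlib (0.4.6) ships xla_extension_version 137 and mlir_api_version 45. The gates all follow one pattern, sketched here with hypothetical `_new_impl`/`_old_fallback` placeholders (not names from the JAX codebase):

# A minimal sketch of the version-gate pattern this commit deletes.
from jax._src.lib import xla_extension_version

def _new_impl(x):
  return x  # stands in for a code path that requires newer jaxlib

def _old_fallback(x):
  return x  # stands in for the pure-Python fallback

# With the minimum jaxlib at 0.4.6 (xla_extension_version == 137),
# this condition is always true, so the gate and fallback can go.
handler = _new_impl if xla_extension_version >= 136 else _old_fallback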
10 changes: 2 additions & 8 deletions jax/_src/api.py
@@ -64,7 +64,6 @@
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
 from jax._src.lib import pmap_lib
-from jax._src.lib import xla_extension_version
 from jax._src.sharding_impls import PmapSharding
 from jax._src.traceback_util import api_boundary
 from jax._src.tree_util import broadcast_prefix, _generate_key_paths
@@ -2453,13 +2452,8 @@ def cache_miss(*args, **kwargs):
 
       return out, fastpath_data
 
-    shard_arg_fallback: Callable[..., Any]
-    if xla_extension_version >= 133:
-      shard_arg_fallback = pxla.shard_arg
-    else:
-      shard_arg_fallback = lambda x, ds, idxs: pxla.shard_arg(x, ds, idxs, None)
     cpp_mapped_f = pmap_lib.pmap(
-        fun, cache_miss, static_broadcasted_tuple, shard_arg_fallback)
+        fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg)
 
     pmap_f = wraps(fun)(cpp_mapped_f)
 
@@ -3126,7 +3120,7 @@ def _device_put_sharded(*xs):
     return pxla.batched_device_put(
         stacked_aval,
         PmapSharding(np.array(devices), sharding_spec),
-        xs, devices)
+        xs, list(devices))
   else:
     buffers = [buf for x, d in zip(xs, devices)
                for buf in dispatch.device_put(x, d)]
34 changes: 10 additions & 24 deletions jax/_src/array.py
@@ -28,9 +28,8 @@
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src.config import config
-from jax._src.util import safe_zip, use_cpp_class, use_cpp_method
+from jax._src.util import use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src import api
 from jax._src.typing import ArrayLike
 from jax.interpreters import mlir
@@ -655,21 +654,14 @@ def _array_shard_arg(x, devices, indices, sharding):
   x_indices = x.sharding.addressable_devices_indices_map(x.shape).values()
   if not x.is_fully_addressable:
     if tuple(x_indices) == tuple(indices):
-      if xla_extension_version >= 136:
-        return x
-      else:
-        return x._arrays
+      return x
     else:
       raise NotImplementedError(
           "Cannot reshard an input that is not fully addressable")
   else:
     if tuple(x_indices) == tuple(indices):
-      if xla_extension_version >= 136:
-        return xc.copy_array_to_devices_with_sharding(
-            x, list(devices), sharding)
-      else:
-        return [buf if buf.device() == d else buf.copy_to_device(d)
-                for buf, d in safe_zip(x._arrays, devices)]
+      return xc.copy_array_to_devices_with_sharding(
+          x, list(devices), sharding)
     # Resharding starts here:
     if dispatch.is_single_device_sharding(x.sharding):
       return pxla.shard_device_array(x, devices, indices, sharding)
@@ -688,12 +680,9 @@ def _array_global_result_handler(global_aval, out_sharding, committed,
   if core.is_opaque_dtype(global_aval.dtype):
     return global_aval.dtype._rules.global_sharded_result_handler(
         global_aval, out_sharding, committed, is_out_sharding_from_xla)
-  if xla_extension_version >= 131:
-    return xc.array_result_handler(
-        global_aval, out_sharding, committed=committed, _skip_checks=True
-    )
-  return lambda bufs: ArrayImpl(global_aval, out_sharding, bufs,
-                                committed=committed, _skip_checks=True)
+  return xc.array_result_handler(
+      global_aval, out_sharding, committed=committed, _skip_checks=True
+  )
 pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
 pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
 pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambda *_: lambda *_: core.token
@@ -706,11 +695,8 @@ def _array_local_result_handler(aval, sharding, indices):
   if core.is_opaque_dtype(aval.dtype):
     return aval.dtype._rules.local_sharded_result_handler(
         aval, sharding, indices)
-  if xla_extension_version >= 131:
-    return xc.array_result_handler(
-        aval, sharding, committed=True, _skip_checks=True
-    )
-  return lambda bufs: ArrayImpl(aval, sharding, bufs, committed=True,
-                                _skip_checks=True)
+  return xc.array_result_handler(
+      aval, sharding, committed=True, _skip_checks=True
+  )
 pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
 pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler
5 changes: 1 addition & 4 deletions jax/_src/interpreters/mlir.py
@@ -1533,10 +1533,7 @@ def set_sharding(op, sharding_proto: xc.OpSharding):
 
 
 def get_sharding_attr(sharding_proto: xc.OpSharding):
-  if xc.mlir_api_version >= 44:
-    return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))
-  else:
-    return ir.StringAttr.get(sharding_proto.SerializeToString())
+  return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))
 
 
 # MLIR lowerings for lax primitives
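What changed here, in effect: the sharding attribute now carries the human-readable HloSharding text instead of the serialized OpSharding proto bytes. A rough sketch, assuming a standard jaxlib install (the exact repr format is an implementation detail):

# Sketch: OpSharding proto -> human-readable HloSharding text.
from jax._src.lib import xla_client as xc

proto = xc.OpSharding()
proto.type = xc.OpSharding.Type.REPLICATED
print(repr(xc.HloSharding.from_proto(proto)))  # e.g. "{replicated}"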
71 changes: 22 additions & 49 deletions jax/_src/interpreters/pxla.py
@@ -76,7 +76,6 @@
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib import pmap_lib
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
@@ -429,10 +428,7 @@ def _shard_token(x, devices, indices, sharding):
     zeros = np.zeros((), dtype=np.dtype(np.bool_))
     aval = api_util.shaped_abstractify(zeros)
     out = batched_device_put(aval, sharding, [zeros for i in indices], devices)
-    if xla_extension_version >= 136:
-      return out
-    else:
-      return out._arrays
+    return out
   return device_put(np.zeros((), dtype=np.dtype(np.bool_)), devices, replicate=True)
 shard_arg_handlers[core.Token] = _shard_token

@@ -447,10 +443,7 @@ def _shard_array(x, devices, indices, sharding):
   if jax.config.jax_array:
     aval = api_util.shaped_abstractify(x)
     out = batched_device_put(aval, sharding, [x[i] for i in indices], devices)
-    if xla_extension_version >= 136:
-      return out
-    else:
-      return out._arrays
+    return out
   return device_put([x[i] for i in indices], devices)
 for _t in array_types:
   shard_arg_handlers[_t] = _shard_array
@@ -462,22 +455,13 @@ def shard_device_array(x, devices, indices, sharding):
   if jax.config.jax_array:
     aval = api_util.shaped_abstractify(x)
     out = batched_device_put(aval, sharding, shards, devices)
-    if xla_extension_version >= 136:
-      return out
-    else:
-      return out._arrays
+    return out
   return device_put(shards, devices)
 for t in device_array.device_array_types:
   shard_arg_handlers[t] = shard_device_array
 
 
-if xla_extension_version >= 136:
-  batched_device_put = xc.batched_device_put  # pytype: disable=module-attr
-else:
-  def batched_device_put(aval, sharding, xs, devices, committed=True):
-    from jax._src.array import ArrayImpl
-    bufs = [d.client.buffer_from_pyval(x, d) for x, d in safe_zip(xs, devices)]
-    return ArrayImpl(aval, sharding, bufs, committed, _skip_checks=True)
+batched_device_put = xc.batched_device_put  # pytype: disable=module-attr

 # NOTE(skye): we could refactor to generate _multi_slice parameters directly
 # from the input ShardingSpec, rather than the indices. However, this would
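The deleted Python fallback above doubles as documentation for what `xc.batched_device_put` does: assemble one `ArrayImpl` from per-device shards. A rough usage sketch under the same signature, assuming a default CPU backend — this is internal API, so treat it as illustrative only:

# Sketch: build a single jax.Array from per-device shards via
# batched_device_put(aval, sharding, xs, devices). Internal API;
# signature taken from the deleted Python fallback above.
import jax
import numpy as np
from jax._src import api_util
from jax._src.interpreters import pxla
from jax.sharding import SingleDeviceSharding

dev = jax.devices()[0]
x = np.arange(4, dtype=np.float32)
aval = api_util.shaped_abstractify(x)
out = pxla.batched_device_put(aval, SingleDeviceSharding(dev), [x], [dev])
print(type(out), out)  # an ArrayImpl holding the device buffer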
@@ -980,7 +964,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
       break
   else:
     bufs.append(buf.copy_to_device(device))
-  if xla_extension_version >= 136 and isinstance(x, ArrayImpl):
+  if isinstance(x, ArrayImpl):
     return ArrayImpl(x.aval, sharding, bufs, committed=True)
   return bufs

@@ -2196,35 +2180,24 @@ def _call_with_tokens(self, input_bufs):
   def __call__(self, *args):
     args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
     input_bufs = self.in_handler(args)
-    if xla_extension_version >= 131:
-      if (self.ordered_effects or self.has_unordered_effects
-          or self.has_host_callbacks):
-        input_bufs = self._add_tokens_to_inputs(input_bufs)
-        results = self.xla_executable.execute_sharded(
-            input_bufs, with_tokens=True
-        )
-        self._handle_token_bufs(
-            results.disassemble_prefix_into_single_device_arrays(
-                len(self.ordered_effects)),
-            results.consume_token())
-      else:
-        results = self.xla_executable.execute_sharded(input_bufs)
-      if dispatch.needs_check_special():
-        out_arrays = results.disassemble_into_single_device_arrays()
-        for arrays in out_arrays:
-          dispatch.check_special(self.name, arrays)
-        return self.out_handler(out_arrays)
-      return results.consume_with_handlers(self.out_handler.handlers)
+    if (self.ordered_effects or self.has_unordered_effects
+        or self.has_host_callbacks):
+      input_bufs = self._add_tokens_to_inputs(input_bufs)
+      results = self.xla_executable.execute_sharded(
+          input_bufs, with_tokens=True
+      )
+      self._handle_token_bufs(
+          results.disassemble_prefix_into_single_device_arrays(
+              len(self.ordered_effects)),
+          results.consume_token())
     else:
-      if (self.ordered_effects or self.has_unordered_effects
-          or self.has_host_callbacks):
-        out_bufs = self._call_with_tokens(input_bufs)
-      else:
-        out_bufs = self.xla_executable.execute_sharded_on_local_devices(input_bufs)
-      if dispatch.needs_check_special():
-        for bufs in out_bufs:
-          dispatch.check_special(self.name, bufs)
-      return self.out_handler(out_bufs)
+      results = self.xla_executable.execute_sharded(input_bufs)
+    if dispatch.needs_check_special():
+      out_arrays = results.disassemble_into_single_device_arrays()
+      for arrays in out_arrays:
+        dispatch.check_special(self.name, arrays)
+      return self.out_handler(out_arrays)
+    return results.consume_with_handlers(self.out_handler.handlers)
 
 
 xla_pmap_p = core.MapPrimitive('xla_pmap')
5 changes: 1 addition & 4 deletions jax/_src/lax/lax.py
@@ -1988,10 +1988,7 @@ def _bessel_i1e_jvp(g, y, x):
 erf_inv_p = standard_unop(_float, 'erf_inv')
 ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
                                             mul(g, exp(square(ans)))))
-if xla_client.mlir_api_version >= 45:
-  mlir.register_lowering(erf_inv_p, partial(_nary_lower_hlo, chlo.ErfInvOp))
-else:
-  xla.register_translation(erf_inv_p, standard_translate(erf_inv_p))
+mlir.register_lowering(erf_inv_p, partial(_nary_lower_hlo, chlo.ErfInvOp))
 
 real_p = unop(_complex_basetype, _complex, 'real')
 ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))])
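User-visible behavior of `erf_inv` is unchanged by dropping the old translation rule; only the lowering path (now always the CHLO ErfInvOp) differs. A small sanity check, assuming a standard jax install:

# Sketch: erf_inv still works end to end after the lowering change;
# the JVP rule registered above gives it a gradient as well.
import jax
import jax.numpy as jnp
from jax.scipy.special import erfinv

x = jnp.array([0.0, 0.5, -0.5])
print(erfinv(x))                               # forward evaluation
print(jax.grad(lambda t: erfinv(t).sum())(x))  # via the defjvp2 rule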
3 changes: 1 addition & 2 deletions jax/_src/sharding_impls.py
@@ -25,7 +25,6 @@
 from jax._src import sharding
 from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax.interpreters import mlir
 from jax._src.interpreters import pxla
 
@@ -579,7 +578,7 @@ def __eq__(self, other) -> bool:
             self._ids == other._ids)
 
 
-@use_cpp_class(xc.GSPMDSharding if xla_extension_version >= 129 else xc.OpShardingSharding)  # type: ignore
+@use_cpp_class(xc.GSPMDSharding)
 class GSPMDSharding(XLACompatibleSharding):
 
   @use_cpp_method()
2 changes: 1 addition & 1 deletion jax/version.py
@@ -16,7 +16,7 @@
 # eval()-ed by setup.py, so it should not have any dependencies.
 
 __version__ = "0.4.7"
-_minimum_jaxlib_version = "0.4.4"
+_minimum_jaxlib_version = "0.4.6"
 
 def _version_as_tuple(version_str):
   return tuple(int(i) for i in version_str.split(".") if i.isdigit())
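`_version_as_tuple` matters here because version strings must be compared numerically, not lexicographically. A minimal sketch of the kind of check the bump feeds into — `check_jaxlib_version` is an illustrative name, not necessarily the exact helper jax uses:

# Sketch: numeric version comparison, as enabled by _version_as_tuple.
def _version_as_tuple(version_str):
  return tuple(int(i) for i in version_str.split(".") if i.isdigit())

def check_jaxlib_version(jaxlib_version, minimum="0.4.6"):
  # Tuples compare element-wise as ints: (0, 4, 10) >= (0, 4, 6),
  # whereas string comparison would wrongly say "0.4.10" < "0.4.6".
  if _version_as_tuple(jaxlib_version) < _version_as_tuple(minimum):
    raise RuntimeError(
        f"jaxlib {jaxlib_version} is too old; jax requires "
        f"jaxlib >= {minimum}.")

check_jaxlib_version("0.4.6")   # passes
check_jaxlib_version("0.4.10")  # passes: (0, 4, 10) >= (0, 4, 6)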
4 changes: 1 addition & 3 deletions tests/array_test.py
@@ -28,7 +28,6 @@
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.util import safe_zip
 from jax.interpreters import pxla
 from jax.experimental.pjit import pjit
@@ -83,8 +82,7 @@ def create_array(shape, sharding, global_data=None):
 class JaxArrayTest(jtu.JaxTestCase):
 
   def test_array_impl_name(self):
-    expected = "Array" if xla_extension_version < 135 else "ArrayImpl"
-    self.assertEqual(array.ArrayImpl.__name__, expected)
+    self.assertEqual(array.ArrayImpl.__name__, "ArrayImpl")
 
   @parameterized.named_parameters(
       ("mesh_x_y", P("x", "y")),
11 changes: 2 additions & 9 deletions tests/lax_numpy_operators_test.py
@@ -32,7 +32,6 @@
 
 from jax._src import dtypes
 from jax._src import test_util as jtu
-from jax._src.lib import xla_extension_version
 
 from jax.config import config
 config.parse_flags_with_absl()
@@ -499,10 +498,7 @@ def testOperatorOverloadErrors(self, name, othertype):
     other = othertype(data)
 
     if config.jax_array:
-      if xla_extension_version < 135:
-        val_str = 'Array'
-      else:
-        val_str = 'ArrayImpl'
+      val_str = 'ArrayImpl'
     else:
       val_str = 'DeviceArray'
     msg = f"unsupported operand type.* '{val_str}' and '{othertype.__name__}'"
@@ -521,10 +517,7 @@ def testRightOperatorOverloadErrors(self, name, othertype):
     other = othertype(data)
 
     if config.jax_array:
-      if xla_extension_version < 135:
-        val_str = 'Array'
-      else:
-        val_str = 'ArrayImpl'
+      val_str = 'ArrayImpl'
     else:
       val_str = 'DeviceArray'
     msg = f"unsupported operand type.* '{othertype.__name__}' and '{val_str}'"
3 changes: 1 addition & 2 deletions tests/lax_test.py
@@ -45,7 +45,6 @@
 from jax._src import dtypes
 from jax._src import test_util as jtu
 from jax._src import lax_reference
-from jax._src.lib import xla_extension_version
 from jax._src.lax import lax as lax_internal
 from jax._src.internal_test_util import lax_test_util
 
@@ -2987,7 +2986,7 @@ def shard_foo_array_handler(x, devices, indices, sharding):
   if isinstance(x.data, array.ArrayImpl):
     bufs = dispatch._device_put_jax_array(x.data, device)
   bufs = dispatch._device_put_array(x.data, device)
-  if config.jax_array and xla_extension_version >= 136:
+  if config.jax_array:
     aval = core.raise_to_shaped(core.get_aval(x.data))
     return array.ArrayImpl(aval, sharding, list(bufs), committed=True)
   return bufs
7 changes: 2 additions & 5 deletions tests/lax_vmap_op_test.py
@@ -24,7 +24,6 @@
 
 from jax._src import test_util as jtu
 from jax._src.internal_test_util import lax_test_util
-from jax._src import lib
 from jax._src import util
 
 from jax.config import config
@@ -70,10 +69,8 @@ def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
       dtype=rec.dtypes,
   ) for rec in lax_test_util.lax_ops()))
   def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
-    # TODO(pizzud): Make this unconditional after the next minimum jaxlib bump.
-    if lib.xla_extension_version >= 134:
-      if dtype == np.float64 or any(len(shape) > 2 for shape in shapes):
-        self.skipTest('Skipping big tests under sanitizers due to slowdown.')
+    if dtype == np.float64 or any(len(shape) > 2 for shape in shapes):
+      self.skipTest('Skipping big tests under sanitizers due to slowdown.')
 
     rng = rng_factory(self.rng())
     op = getattr(lax, op_name)
11 changes: 5 additions & 6 deletions tests/pjit_test.py
@@ -360,12 +360,11 @@ def f(x):
     self.assertAllClose(np.asarray(actual.device_buffers[0]), expected,
                         check_dtypes=False)
 
-    if xc.mlir_api_version >= 44:
-      hlo = f.lower(np.ones(shape)).compiler_ir()
-      # Annotation from with_sharding_constraint
-      self.assertIn('sharding = "{devices=[2,1]0,1}"', str(hlo))
-      # Annotation from pjit
-      self.assertIn('sharding = "{replicated}"', str(hlo))
+    hlo = f.lower(np.ones(shape)).compiler_ir()
+    # Annotation from with_sharding_constraint
+    self.assertIn('sharding = "{devices=[2,1]0,1}"', str(hlo))
+    # Annotation from pjit
+    self.assertIn('sharding = "{replicated}"', str(hlo))
 
   @jtu.with_mesh([('x', 2), ('y', 1)])
   def testShardingConstraint(self):
