diff --git a/jax/_src/api.py b/jax/_src/api.py index ceea2a5a0d53..fc72059a30d6 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -3268,11 +3268,6 @@ def clear_backends(): """ Clear all backend clients so that new backend clients can be created later. """ - - if xc._version < 79: - raise RuntimeError("clear_backends is not supported in the jaxlib used." - "Please update your jaxlib package.") - xb._clear_backends() jax.lib.xla_bridge._backends = {} dispatch.xla_callable.cache_clear() # type: ignore diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 6f39364cbbc8..d4c96dd37d1e 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -22,7 +22,6 @@ from jax._src import cloud_tpu_init from jax._src.config import config from jax._src.lib import xla_bridge -from jax._src.lib import xla_client from jax._src.lib import xla_extension class State: @@ -79,7 +78,7 @@ def initialize(self, logging.info('Connecting to JAX distributed service on %s', coordinator_address) self.client.connect() - if xla_client._version >= 77 and config.jax_coordination_service: + if config.jax_coordination_service: self.initialize_preemption_sync_manager() def shutdown(self): diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 52a8024803bd..dc03a4602e32 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -19,7 +19,6 @@ import numpy as np -from jax._src.lib import mlir_api_version from jax import core from jax._src import dtypes from jax._src.lax import lax @@ -704,23 +703,20 @@ def _conv_general_dilated_lower( if len(padding) == 0: padding = np.zeros((0, 2), dtype=np.int64) window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims) - if mlir_api_version < 24: - op = mhlo.ConvOp - else: - op = mhlo.ConvolutionOp return [ - op(mlir.aval_to_ir_type(aval_out), - lhs, - rhs, - dimension_numbers=dnums, - feature_group_count=mlir.i64_attr(feature_group_count), - batch_group_count=mlir.i64_attr(batch_group_count), - window_strides=mlir.dense_int_elements(window_strides), - padding=mlir.dense_int_elements(padding), - lhs_dilation=mlir.dense_int_elements(lhs_dilation), - rhs_dilation=mlir.dense_int_elements(rhs_dilation), - window_reversal=window_reversal, - precision_config=lax.precision_attr(precision)).result + mhlo.ConvolutionOp( + mlir.aval_to_ir_type(aval_out), + lhs, + rhs, + dimension_numbers=dnums, + feature_group_count=mlir.i64_attr(feature_group_count), + batch_group_count=mlir.i64_attr(batch_group_count), + window_strides=mlir.dense_int_elements(window_strides), + padding=mlir.dense_int_elements(padding), + lhs_dilation=mlir.dense_int_elements(lhs_dilation), + rhs_dilation=mlir.dense_int_elements(rhs_dilation), + window_reversal=window_reversal, + precision_config=lax.precision_attr(precision)).result ] mlir.register_lowering(conv_general_dilated_p, _conv_general_dilated_lower) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 6e9def67758a..aa433343aa9c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1755,17 +1755,11 @@ def logistic_impl(x): sin_p = standard_unop(_float | _complex, 'sin') ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) -if mlir_api_version < 27: - mlir.register_lowering(sin_p, partial(_nary_lower_mhlo, mhlo.SinOp)) -else: - mlir.register_lowering(sin_p, partial(_nary_lower_mhlo, mhlo.SineOp)) +mlir.register_lowering(sin_p, partial(_nary_lower_mhlo, mhlo.SineOp)) cos_p = standard_unop(_float | _complex, 'cos') ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) -if mlir_api_version < 28: - mlir.register_lowering(cos_p, partial(_nary_lower_mhlo, mhlo.CosOp)) -else: - mlir.register_lowering(cos_p, partial(_nary_lower_mhlo, mhlo.CosineOp)) +mlir.register_lowering(cos_p, partial(_nary_lower_mhlo, mhlo.CosineOp)) @_upcast_fp16_for_computation def _tan_impl(x): @@ -2170,10 +2164,7 @@ def _sub_transpose(t, x, y): sub_p = standard_naryop([_num, _num], 'sub') ad.primitive_jvps[sub_p] = _sub_jvp ad.primitive_transposes[sub_p] = _sub_transpose -if mlir_api_version < 29: - mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubOp)) -else: - mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubtractOp)) +mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubtractOp)) def _mul_transpose(ct, x, y): @@ -4244,11 +4235,8 @@ def _rng_uniform_lowering(ctx, a, b, *, shape): aval_out, = ctx.avals_out shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64), canonicalize_types=False) - if mlir_api_version <= 22: - return mhlo.RngUniformOp(a, b, shape).results - else: - return mhlo.RngOp(a, b, shape, - mhlo.RngDistributionAttr.get('UNIFORM')).results + return mhlo.RngOp(a, b, shape, + mhlo.RngDistributionAttr.get('UNIFORM')).results mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 7bd023ebe3df..2555ad2095f5 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -39,7 +39,6 @@ from jax._src.lax import eigh as lax_eigh from jax._src.lax import lax as lax_internal from jax._src.lax import svd as lax_svd -from jax._src.lib import mlir_api_version from jax._src.lib import lapack from jax._src.lib import gpu_linalg @@ -1189,11 +1188,7 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand): m = operand_aval.shape[-2] lu, pivot, info = getrf_impl(operand_aval.dtype, operand) # Subtract 1 from the pivot to get 0-based indices. - if mlir_api_version < 29: - op = mhlo.SubOp - else: - op = mhlo.SubtractOp - pivot = op(pivot, mlir.full_like_aval(1, pivot_aval)).result + pivot = mhlo.SubtractOp(pivot, mlir.full_like_aval(1, pivot_aval)).result ok = mlir.compare_mhlo( info, mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32))), "GE", "SIGNED") diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index 4b53dedde9c9..81d892a24d58 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -17,18 +17,5 @@ import jaxlib.mlir.dialects.chlo as chlo import jaxlib.mlir.dialects.mhlo as mhlo import jaxlib.mlir.dialects.func as func - -try: - import jaxlib.mlir.dialects.ml_program as ml_program -except (ModuleNotFoundError, ImportError): - # TODO(phawkins): make this unconditional when jaxlib > 0.3.14 - # is the minimum version. - pass - -try: - import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor -except (ModuleNotFoundError, ImportError): - # TODO(ajcbik,phawkins): make this unconditional when jaxlib > 0.3.7 - # is the minimum version. - pass - +import jaxlib.mlir.dialects.ml_program as ml_program +import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor diff --git a/jax/_src/lib/xla_bridge.py b/jax/_src/lib/xla_bridge.py index 4f94aad8f582..38e648296550 100644 --- a/jax/_src/lib/xla_bridge.py +++ b/jax/_src/lib/xla_bridge.py @@ -162,8 +162,7 @@ def get_compile_options( debug_options.xla_llvm_disable_expensive_passes = True debug_options.xla_test_all_input_layouts = False - if lib.xla_extension_version >= 68: - compile_options.profile_version = FLAGS.jax_xla_profile_version + compile_options.profile_version = FLAGS.jax_xla_profile_version return compile_options diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index d90257e1f983..3d7e259df34a 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -23,8 +23,6 @@ import warnings from jax._src.lib import pytree -from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version from jax._src.util import safe_zip, unzip2 @@ -34,10 +32,7 @@ T = TypeVar("T") U = TypeVar("U") -if TYPE_CHECKING or xla_extension_version >= 78: - PyTreeDef = pytree.PyTreeDef -else: - PyTreeDef = xla_extension.PyTreeDef # pytype: disable=module-attr +PyTreeDef = pytree.PyTreeDef def tree_flatten(tree, is_leaf: Optional[Callable[[Any], bool]] = None): diff --git a/jax/version.py b/jax/version.py index f79264cb068e..4769e008e9ca 100644 --- a/jax/version.py +++ b/jax/version.py @@ -16,7 +16,7 @@ # eval()-ed by setup.py, so it should not have any dependencies. __version__ = "0.3.18" -_minimum_jaxlib_version = "0.3.14" +_minimum_jaxlib_version = "0.3.15" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/tests/clear_backends_test.py b/tests/clear_backends_test.py index 66639f10b77e..446b790b75ec 100644 --- a/tests/clear_backends_test.py +++ b/tests/clear_backends_test.py @@ -19,7 +19,6 @@ from jax.config import config from jax._src import test_util as jtu from jax._src.lib import xla_bridge as xb -from jax._src.lib import xla_client as xc config.parse_flags_with_absl() @@ -29,11 +28,10 @@ class ClearBackendsTest(jtu.JaxTestCase): def test_clear_backends(self): g = jax.jit(lambda x, y: x * y) self.assertEqual(g(1, 2), 2) - if xc._version >= 79: - self.assertNotEmpty(xb.get_backend().live_executables()) - jax.clear_backends() - self.assertEmpty(xb.get_backend().live_executables()) - self.assertEqual(g(1, 2), 2) + self.assertNotEmpty(xb.get_backend().live_executables()) + jax.clear_backends() + self.assertEmpty(xb.get_backend().live_executables()) + self.assertEqual(g(1, 2), 2) if __name__ == "__main__":