Bump the minimum jaxlib version to 0.3.15.
hawkinsp committed Sep 8, 2022
1 parent 001ae22 commit 6c59d72
Showing 10 changed files with 29 additions and 77 deletions.

5 changes: 0 additions & 5 deletions jax/_src/api.py

@@ -3268,11 +3268,6 @@ def clear_backends():
   """
   Clear all backend clients so that new backend clients can be created later.
   """
-
-  if xc._version < 79:
-    raise RuntimeError("clear_backends is not supported in the jaxlib used."
-                       "Please update your jaxlib package.")
-
   xb._clear_backends()
   jax.lib.xla_bridge._backends = {}
   dispatch.xla_callable.cache_clear()  # type: ignore
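
With the jaxlib floor at 0.3.15, clear_backends no longer needs a runtime
probe. A minimal usage sketch (the jitted function and its arguments are
arbitrary placeholders):

import jax

f = jax.jit(lambda x, y: x * y)
f(1, 2)               # compiles and caches an executable
jax.clear_backends()  # drops backend clients and cached executables
f(1, 2)               # re-creates the backend and recompiles
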
3 changes: 1 addition & 2 deletions jax/_src/distributed.py

@@ -22,7 +22,6 @@
 from jax._src import cloud_tpu_init
 from jax._src.config import config
 from jax._src.lib import xla_bridge
-from jax._src.lib import xla_client
 from jax._src.lib import xla_extension
 
 class State:
@@ -79,7 +78,7 @@ def initialize(self,
     logging.info('Connecting to JAX distributed service on %s', coordinator_address)
     self.client.connect()
 
-    if xla_client._version >= 77 and config.jax_coordination_service:
+    if config.jax_coordination_service:
       self.initialize_preemption_sync_manager()
 
   def shutdown(self):
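
The preemption sync manager is now set up whenever the
jax_coordination_service config option is enabled, with no jaxlib version
probe. For context, a hedged sketch of the public entry point that reaches
this code; the address and process counts are placeholder values:

import jax

# Run on every process in the job; process 0 hosts the coordinator.
jax.distributed.initialize(
    coordinator_address='10.0.0.1:1234',  # placeholder address
    num_processes=2,
    process_id=0)
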
30 changes: 13 additions & 17 deletions jax/_src/lax/convolution.py

@@ -19,7 +19,6 @@
 
 import numpy as np
 
-from jax._src.lib import mlir_api_version
 from jax import core
 from jax._src import dtypes
 from jax._src.lax import lax
@@ -704,23 +703,20 @@ def _conv_general_dilated_lower(
   if len(padding) == 0:
     padding = np.zeros((0, 2), dtype=np.int64)
   window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims)
-  if mlir_api_version < 24:
-    op = mhlo.ConvOp
-  else:
-    op = mhlo.ConvolutionOp
   return [
-      op(mlir.aval_to_ir_type(aval_out),
-         lhs,
-         rhs,
-         dimension_numbers=dnums,
-         feature_group_count=mlir.i64_attr(feature_group_count),
-         batch_group_count=mlir.i64_attr(batch_group_count),
-         window_strides=mlir.dense_int_elements(window_strides),
-         padding=mlir.dense_int_elements(padding),
-         lhs_dilation=mlir.dense_int_elements(lhs_dilation),
-         rhs_dilation=mlir.dense_int_elements(rhs_dilation),
-         window_reversal=window_reversal,
-         precision_config=lax.precision_attr(precision)).result
+      mhlo.ConvolutionOp(
+          mlir.aval_to_ir_type(aval_out),
+          lhs,
+          rhs,
+          dimension_numbers=dnums,
+          feature_group_count=mlir.i64_attr(feature_group_count),
+          batch_group_count=mlir.i64_attr(batch_group_count),
+          window_strides=mlir.dense_int_elements(window_strides),
+          padding=mlir.dense_int_elements(padding),
+          lhs_dilation=mlir.dense_int_elements(lhs_dilation),
+          rhs_dilation=mlir.dense_int_elements(rhs_dilation),
+          window_reversal=window_reversal,
+          precision_config=lax.precision_attr(precision)).result
   ]
 
 mlir.register_lowering(conv_general_dilated_p, _conv_general_dilated_lower)
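
The op rename is invisible at the public API level. A small sketch that
exercises this lowering path, assuming the default NCHW/OIHW dimension
convention when dimension_numbers is omitted:

import jax
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 3, 8, 8))  # batch, features, height, width
rhs = jnp.ones((4, 3, 3, 3))  # out features, in features, kernel h/w
out = jax.jit(lambda l, r: lax.conv_general_dilated(
    l, r, window_strides=(1, 1), padding='SAME'))(lhs, rhs)
print(out.shape)  # (1, 4, 8, 8)
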
22 changes: 5 additions & 17 deletions jax/_src/lax/lax.py

@@ -1755,17 +1755,11 @@ def logistic_impl(x):
 
 sin_p = standard_unop(_float | _complex, 'sin')
 ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
-if mlir_api_version < 27:
-  mlir.register_lowering(sin_p, partial(_nary_lower_mhlo, mhlo.SinOp))
-else:
-  mlir.register_lowering(sin_p, partial(_nary_lower_mhlo, mhlo.SineOp))
+mlir.register_lowering(sin_p, partial(_nary_lower_mhlo, mhlo.SineOp))
 
 cos_p = standard_unop(_float | _complex, 'cos')
 ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
-if mlir_api_version < 28:
-  mlir.register_lowering(cos_p, partial(_nary_lower_mhlo, mhlo.CosOp))
-else:
-  mlir.register_lowering(cos_p, partial(_nary_lower_mhlo, mhlo.CosineOp))
+mlir.register_lowering(cos_p, partial(_nary_lower_mhlo, mhlo.CosineOp))
 
 @_upcast_fp16_for_computation
 def _tan_impl(x):
@@ -2170,10 +2164,7 @@ def _sub_transpose(t, x, y):
 sub_p = standard_naryop([_num, _num], 'sub')
 ad.primitive_jvps[sub_p] = _sub_jvp
 ad.primitive_transposes[sub_p] = _sub_transpose
-if mlir_api_version < 29:
-  mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubOp))
-else:
-  mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubtractOp))
+mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubtractOp))
 
 
 def _mul_transpose(ct, x, y):
@@ -4244,11 +4235,8 @@ def _rng_uniform_lowering(ctx, a, b, *, shape):
   aval_out, = ctx.avals_out
   shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64),
                              canonicalize_types=False)
-  if mlir_api_version <= 22:
-    return mhlo.RngUniformOp(a, b, shape).results
-  else:
-    return mhlo.RngOp(a, b, shape,
-                      mhlo.RngDistributionAttr.get('UNIFORM')).results
+  return mhlo.RngOp(a, b, shape,
+                    mhlo.RngDistributionAttr.get('UNIFORM')).results
 
 mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)
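
One way to observe the renamed ops from user code, assuming jit's
lower()/compiler_ir() introspection API is available:

import jax
import jax.numpy as jnp

lowered = jax.jit(jnp.sin).lower(1.0)
# The module text should now contain mhlo.sine rather than mhlo.sin.
print(lowered.compiler_ir(dialect='mhlo'))
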
7 changes: 1 addition & 6 deletions jax/_src/lax/linalg.py

@@ -39,7 +39,6 @@
 from jax._src.lax import eigh as lax_eigh
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import svd as lax_svd
-from jax._src.lib import mlir_api_version
 from jax._src.lib import lapack
 
 from jax._src.lib import gpu_linalg
@@ -1189,11 +1188,7 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand):
   m = operand_aval.shape[-2]
   lu, pivot, info = getrf_impl(operand_aval.dtype, operand)
   # Subtract 1 from the pivot to get 0-based indices.
-  if mlir_api_version < 29:
-    op = mhlo.SubOp
-  else:
-    op = mhlo.SubtractOp
-  pivot = op(pivot, mlir.full_like_aval(1, pivot_aval)).result
+  pivot = mhlo.SubtractOp(pivot, mlir.full_like_aval(1, pivot_aval)).result
   ok = mlir.compare_mhlo(
       info, mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32))),
       "GE", "SIGNED")
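
For context on the surrounding code: LAPACK's getrf returns 1-based pivots,
and JAX subtracts 1 to expose 0-based indices. A short sketch via the public
wrapper:

import jax.numpy as jnp
from jax.scipy.linalg import lu_factor

a = jnp.array([[0., 1.],
               [2., 3.]])
lu, piv = lu_factor(a)
print(piv)  # 0-based pivot indices, e.g. [1 1]
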
17 changes: 2 additions & 15 deletions jax/_src/lib/mlir/dialects/__init__.py

@@ -17,18 +17,5 @@
 import jaxlib.mlir.dialects.chlo as chlo
 import jaxlib.mlir.dialects.mhlo as mhlo
 import jaxlib.mlir.dialects.func as func
-
-try:
-  import jaxlib.mlir.dialects.ml_program as ml_program
-except (ModuleNotFoundError, ImportError):
-  # TODO(phawkins): make this unconditional when jaxlib > 0.3.14
-  # is the minimum version.
-  pass
-
-try:
-  import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
-except (ModuleNotFoundError, ImportError):
-  # TODO(ajcbik,phawkins): make this unconditional when jaxlib > 0.3.7
-  # is the minimum version.
-  pass
-
+import jaxlib.mlir.dialects.ml_program as ml_program
+import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
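
Since jaxlib 0.3.15 always ships the ml_program and sparse_tensor dialects,
the imports are unconditional. A quick check (internal module path, subject
to change):

# Expected to succeed on jaxlib >= 0.3.15 without any try/except guard.
from jax._src.lib.mlir.dialects import mhlo, ml_program, sparse_tensor
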
3 changes: 1 addition & 2 deletions jax/_src/lib/xla_bridge.py

@@ -162,8 +162,7 @@ def get_compile_options(
   debug_options.xla_llvm_disable_expensive_passes = True
   debug_options.xla_test_all_input_layouts = False
 
-  if lib.xla_extension_version >= 68:
-    compile_options.profile_version = FLAGS.jax_xla_profile_version
+  compile_options.profile_version = FLAGS.jax_xla_profile_version
   return compile_options
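
profile_version is now always populated from the --jax_xla_profile_version
flag. A hedged sketch of the internal helper this touches; the keyword
arguments are assumptions about its signature:

from jax._src.lib import xla_bridge as xb

opts = xb.get_compile_options(num_replicas=1, num_partitions=1)
print(opts.profile_version)  # mirrors --jax_xla_profile_version
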
7 changes: 1 addition & 6 deletions jax/_src/tree_util.py

@@ -23,8 +23,6 @@
 import warnings
 
 from jax._src.lib import pytree
-from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 
 from jax._src.util import safe_zip, unzip2
 
@@ -34,10 +32,7 @@
 T = TypeVar("T")
 U = TypeVar("U")
 
-if TYPE_CHECKING or xla_extension_version >= 78:
-  PyTreeDef = pytree.PyTreeDef
-else:
-  PyTreeDef = xla_extension.PyTreeDef  # pytype: disable=module-attr
+PyTreeDef = pytree.PyTreeDef
 
 
 def tree_flatten(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
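
PyTreeDef now always refers to jaxlib's pytree.PyTreeDef, so treedef types
are uniform across installs. A short example of the type in use:

import jax

leaves, treedef = jax.tree_util.tree_flatten({'a': 1, 'b': (2, 3)})
print(leaves)   # [1, 2, 3]
print(treedef)  # a PyTreeDef describing {'a': *, 'b': (*, *)}
restored = jax.tree_util.tree_unflatten(treedef, leaves)
assert restored == {'a': 1, 'b': (2, 3)}
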
2 changes: 1 addition & 1 deletion jax/version.py

@@ -16,7 +16,7 @@
 # eval()-ed by setup.py, so it should not have any dependencies.
 
 __version__ = "0.3.18"
-_minimum_jaxlib_version = "0.3.14"
+_minimum_jaxlib_version = "0.3.15"
 
 def _version_as_tuple(version_str):
   return tuple(int(i) for i in version_str.split(".") if i.isdigit())
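
The bump composes with _version_as_tuple above, which makes the comparison
numeric rather than lexicographic:

# "0.3.15" -> (0, 3, 15); compared as strings, "0.3.15" < "0.3.9" would hold.
assert _version_as_tuple("0.3.15") > _version_as_tuple("0.3.14")
assert _version_as_tuple("0.3.15") > _version_as_tuple("0.3.9")
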
10 changes: 4 additions & 6 deletions tests/clear_backends_test.py

@@ -19,7 +19,6 @@
 from jax.config import config
 from jax._src import test_util as jtu
 from jax._src.lib import xla_bridge as xb
-from jax._src.lib import xla_client as xc
 
 config.parse_flags_with_absl()
 
@@ -29,11 +28,10 @@ class ClearBackendsTest(jtu.JaxTestCase):
   def test_clear_backends(self):
     g = jax.jit(lambda x, y: x * y)
     self.assertEqual(g(1, 2), 2)
-    if xc._version >= 79:
-      self.assertNotEmpty(xb.get_backend().live_executables())
-      jax.clear_backends()
-      self.assertEmpty(xb.get_backend().live_executables())
-      self.assertEqual(g(1, 2), 2)
+    self.assertNotEmpty(xb.get_backend().live_executables())
+    jax.clear_backends()
+    self.assertEmpty(xb.get_backend().live_executables())
+    self.assertEqual(g(1, 2), 2)
 
 
 if __name__ == "__main__":
