Update minimum jaxlib version to 0.3.14.
hawkinsp committed Jul 8, 2022
1 parent 44bd311 commit 0b4b0ba
Showing 32 changed files with 231 additions and 699 deletions.
jax/_src/api.py (6 changes: 2 additions & 4 deletions)
@@ -59,7 +59,6 @@
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
-from jax._src.lib import xla_extension_version
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
@@ -585,9 +584,8 @@ def get_device_info():
    return _BackendAndDeviceInfo(default_device, committed_to_device)

  jitted_f_kwargs = {}
-  if xla_extension_version >= 71:
-    jitted_f_kwargs["has_explicit_device"] = (
-        device is not None or backend is not None)
+  jitted_f_kwargs["has_explicit_device"] = (
+      device is not None or backend is not None)
  cpp_jitted_f = jax_jit.jit(
      fun,
      cache_miss,
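Context for the change: `has_explicit_device` tells the C++ dispatch path whether the caller pinned placement via `jit`'s `device` or `backend` arguments, and the removed guard (`xla_extension_version >= 71`) is implied by the new minimum jaxlib. A minimal illustrative sketch of the user-facing knobs involved (not part of the diff; assumes a JAX build of this era):

import jax

def f(x):
  return x * 2

# No explicit placement: has_explicit_device would be False.
f_default = jax.jit(f)

# Explicit backend: has_explicit_device would be True.
f_cpu = jax.jit(f, backend="cpu")

print(f_default(1.0), f_cpu(1.0))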
jax/_src/debugging.py (5 changes: 2 additions & 3 deletions)
@@ -109,9 +109,8 @@ def _callback(*flat_args):
  return result
mlir.register_lowering(debug_callback_p, debug_callback_lowering,
                       platform="cpu")
-if jaxlib.version >= (0, 3, 11):
-  mlir.register_lowering(
-      debug_callback_p, debug_callback_lowering, platform="gpu")
+mlir.register_lowering(
+    debug_callback_p, debug_callback_lowering, platform="gpu")
if jaxlib.version >= (0, 3, 15):
  mlir.register_lowering(
      debug_callback_p, debug_callback_lowering, platform="tpu")
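The GPU registration no longer needs the `jaxlib >= (0, 3, 11)` guard, while the TPU path still waits on jaxlib 0.3.15. A hedged sketch of what `debug_callback_p` enables, via `jax.debug.print` (assuming a JAX version where that helper is available):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  y = jnp.sin(x)
  # Lowers through the debug callback machinery, so it works under jit
  # on CPU and GPU; TPU still requires jaxlib >= 0.3.15 per this diff.
  jax.debug.print("y = {y}", y=y)
  return y * 2

f(jnp.arange(3.0))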
jax/_src/distributed.py (16 changes: 4 additions & 12 deletions)
@@ -66,22 +66,14 @@ def initialize(self,
    if self.service is not None:
      raise RuntimeError('distributed.initialize should only be called once.')
    logging.info('Starting JAX distributed service on %s', coordinator_address)
-    if xla_client._version >= 72:
-      self.service = xla_extension.get_distributed_runtime_service(
-          coordinator_address, num_processes, config.jax_coordination_service)
-    else:
-      self.service = xla_extension.get_distributed_runtime_service(
-          coordinator_address, num_processes)
+    self.service = xla_extension.get_distributed_runtime_service(
+        coordinator_address, num_processes, config.jax_coordination_service)

    if self.client is not None:
      raise RuntimeError('distributed.initialize should only be called once.')

-    if xla_client._version >= 72:
-      self.client = xla_extension.get_distributed_runtime_client(
-          coordinator_address, process_id, config.jax_coordination_service)
-    else:
-      self.client = xla_extension.get_distributed_runtime_client(
-          coordinator_address, process_id)
+    self.client = xla_extension.get_distributed_runtime_client(
+        coordinator_address, process_id, config.jax_coordination_service)
    logging.info('Connecting to JAX distributed service on %s', coordinator_address)
    self.client.connect()
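Both runtime objects are now always built with the `config.jax_coordination_service` flag. A minimal multi-process sketch of the public wrapper around this code (addresses, counts, and IDs are placeholders):

import jax

# Run one copy per process; process 0 also hosts the coordination
# service at coordinator_address.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # placeholder
    num_processes=2,
    process_id=0)  # set per process

print(jax.devices())        # devices across all processes
print(jax.local_devices())  # devices attached to this process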
jax/_src/lax/ann.py (15 changes: 4 additions & 11 deletions)
@@ -132,8 +132,6 @@ def approx_max_k(operand: Array,
    >>> db = jax.numpy.array(np.random.rand(1024, 64))
    >>> dot_products, neighbors = mips(qy, db, k=10)
  """
-  if xc._version < 45:
-    aggregate_to_topk = True
  return approx_top_k_p.bind(
      operand,
      k=k,
@@ -197,8 +195,6 @@ def approx_min_k(operand: Array,
  ``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reason. The former uses less
  arithmetics and produces the same set of neighbors.
  """
-  if xc._version < 45:
-    aggregate_to_topk = True
  return approx_top_k_p.bind(
      operand,
      k=k,
@@ -225,13 +221,10 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
        dims[reduction_dimension], k))
  if not dtypes.issubdtype(operand.dtype, np.floating):
    raise ValueError('operand must be a floating type')
-  if xc._version >= 45:
-    reduction_input_size = dims[reduction_dimension]
-    dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize(
-        reduction_input_size, len(dims), k, recall_target, aggregate_to_topk,
-        reduction_input_size_override)[0]
-  else:
-    dims[reduction_dimension] = k
+  reduction_input_size = dims[reduction_dimension]
+  dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize(
+      reduction_input_size, len(dims), k, recall_target, aggregate_to_topk,
+      reduction_input_size_override)[0]
  return (operand.update(
      shape=dims, dtype=operand.dtype, weak_type=operand.weak_type),
          operand.update(shape=dims, dtype=np.dtype(np.int32)))
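With the `xc._version < 45` fallbacks gone, the output size always comes from `xc.ops.ApproxTopKReductionOutputSize`. A short usage sketch of the public API these primitives back (shapes illustrative):

import jax
import jax.numpy as jnp
import numpy as np

db = jnp.asarray(np.random.rand(1024, 64), dtype=jnp.float32)
qy = jnp.asarray(np.random.rand(8, 64), dtype=jnp.float32)

scores = qy @ db.T  # (8, 1024) similarity scores
# Approximate top-k along the last axis, as in the docstring's example.
values, indices = jax.lax.approx_max_k(scores, k=10)
print(values.shape, indices.shape)  # (8, 10) (8, 10)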
jax/_src/lax/lax.py (37 changes: 10 additions & 27 deletions)
@@ -3070,17 +3070,10 @@ def _pad_masking_rule(padded_vals, logical_shapes, padding_config):

def _pad_lower(ctx, x, padding_value, *, padding_config):
  low, high, interior = util.unzip3(padding_config)
-  if jax._src.lib.mlir_api_version < 15:
-    aval_out, = ctx.avals_out
-    return mhlo.PadOp(mlir.aval_to_ir_type(aval_out), x, padding_value,
-                      mlir.dense_int_elements(low),
-                      mlir.dense_int_elements(high),
-                      mlir.dense_int_elements(interior)).results
-  else:
-    return mhlo.PadOp(x, padding_value,
-                      mlir.dense_int_elements(low),
-                      mlir.dense_int_elements(high),
-                      mlir.dense_int_elements(interior)).results
+  return mhlo.PadOp(x, padding_value,
+                    mlir.dense_int_elements(low),
+                    mlir.dense_int_elements(high),
+                    mlir.dense_int_elements(interior)).results
mlir.register_lowering(pad_p, _pad_lower)
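Only the shape-inferring `mhlo.PadOp` builder remains. For reference, a quick sketch of the `lax.pad` semantics this lowers (one low/high/interior triple per dimension):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(4.0)                  # [0. 1. 2. 3.]
print(lax.pad(x, 0.0, [(1, 2, 0)]))  # [0. 0. 1. 2. 3. 0. 0.]
print(lax.pad(x, 0.0, [(0, 0, 1)]))  # [0. 0. 1. 0. 2. 0. 3.] (interior)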


@@ -3817,13 +3810,8 @@ def _reduce_precision_shape_rule(operand, *, exponent_bits, mantissa_bits):

def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
  aval_out, = ctx.avals_out
-  if jax._src.lib.mlir_api_version >= 21:
-    return mhlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
-                                  mlir.i32_attr(mantissa_bits)).results
-  else:
-    return mhlo.ReducePrecisionOp(mlir.aval_to_ir_type(aval_out), operand,
-                                  mlir.i32_attr(exponent_bits),
-                                  mlir.i32_attr(mantissa_bits)).results
+  return mhlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
+                                mlir.i32_attr(mantissa_bits)).results

mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)
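The lowering keeps only the type-inferring `mhlo.ReducePrecisionOp` form. A sketch of the op itself, which rounds a float to a narrower format while keeping its type:

import jax.numpy as jnp
from jax import lax

x = jnp.float32(1.0) / 3.0
# Emulate bfloat16 (8 exponent bits, 7 mantissa bits) within float32.
y = lax.reduce_precision(x, exponent_bits=8, mantissa_bits=7)
print(x, y)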

@@ -4059,12 +4047,9 @@ def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k):
top_k_p.multiple_results = True
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
-if jax._src.lib.mlir_api_version >= 16:
-  def _top_k_lower(ctx, operand, k):
-    return chlo.TopKOp(operand, mlir.i64_attr(k)).results
-  mlir.register_lowering(top_k_p, _top_k_lower)
-else:
-  xla.register_translation(top_k_p, _top_k_translation_rule)
+def _top_k_lower(ctx, operand, k):
+  return chlo.TopKOp(operand, mlir.i64_attr(k)).results
+mlir.register_lowering(top_k_p, _top_k_lower)
ad.primitive_jvps[top_k_p] = _top_k_jvp
batching.primitive_batchers[top_k_p] = _top_k_batch_rule
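`top_k` now always lowers to `chlo.TopKOp`; the `xla.register_translation` fallback is gone. Basic usage of the wrapper:

import jax.numpy as jnp
from jax import lax

x = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0])
values, indices = lax.top_k(x, k=2)
print(values)   # [5. 4.]
print(indices)  # [4 2]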

@@ -4315,9 +4300,7 @@ def _rng_bit_generator_lowering(
    key = mhlo.BitcastConvertOp(
        ir.RankedTensorType.get([2], u64_type),
        mhlo.ReshapeOp(ir.RankedTensorType.get([2, 2], u32_type), key)).result
-  algorithm_attr = (
-      _rng_algorithm(algorithm) if jax._src.lib.mlir_api_version >= 14
-      else mlir.i32_attr(algorithm))
+  algorithm_attr = _rng_algorithm(algorithm)
  out_key, out_vals = mhlo.RngBitGeneratorOp(
      key.type,
      ir.RankedTensorType.get(shape, rbg_etype),
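The algorithm attribute is now always the typed form produced by `_rng_algorithm`. A hedged sketch of the public op this lowers; the (4,) uint32 key layout for the default algorithm is an assumption here:

import jax.numpy as jnp
from jax import lax

# Assumed key layout for the default algorithm; other algorithms may
# expect a different key shape.
key = jnp.array([0, 1, 2, 3], dtype=jnp.uint32)
new_key, bits = lax.rng_bit_generator(key, shape=(2, 3))
print(new_key.shape, bits.shape, bits.dtype)  # (4,) (2, 3) uint32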
