Skip to content

Commit

Permalink
Increase minimum jaxlib version to 0.1.62.
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp committed Mar 16, 2021
1 parent d326b07 commit 328930b
Show file tree
Hide file tree
Showing 12 changed files with 58 additions and 132 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* New features:
* Bug fixes:
* Breaking changes:
* The minimum jaxlib version is now 0.1.62.

## jaxlib 0.1.63 (Unreleased)

Expand Down
2 changes: 1 addition & 1 deletion build/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ flake8
# For now, we pin the numpy version here
numpy>=1.16
# Must be kept in sync with the minimum jaxlib version in jax/lib/__init__.py
jaxlib==0.1.60
jaxlib==0.1.62
mypy==0.790
pillow
pytest-benchmark
Expand Down
99 changes: 45 additions & 54 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6086,75 +6086,66 @@ def _rng_uniform_translation_rule(c, a, b, *, shape):
xla.translations[rng_uniform_p] = _rng_uniform_translation_rule


if jax.lib.version >= (0, 1, 62):
def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm):
_ = dtype, algorithm
return (key.shape, tuple(shape))
def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm):
_ = dtype, algorithm
return (key.shape, tuple(shape))


def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm):
_ = key, shape, algorithm
return (key.dtype, dtype)
def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm):
_ = key, shape, algorithm
return (key.dtype, dtype)


def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm):
_ = shape, dtype, algorithm
return (key.weak_type, False)
def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm):
_ = shape, dtype, algorithm
return (key.weak_type, False)


def _rng_bit_generator_translation_rule(c, key, *, shape, dtype, algorithm):
_ = c
xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
return xops.RngBitGenerator(algorithm, key, xla_shape)
def _rng_bit_generator_translation_rule(c, key, *, shape, dtype, algorithm):
_ = c
xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
return xops.RngBitGenerator(algorithm, key, xla_shape)


def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
return [key.named_shape, key.named_shape]
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
return [key.named_shape, key.named_shape]

rng_bit_generator_p = Primitive("rng_bit_generator")
rng_bit_generator_p.multiple_results = True
rng_bit_generator_p.def_impl(
partial(xla.apply_primitive, rng_bit_generator_p))
rng_bit_generator_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
_rng_bit_generator_weak_type_rule,
_rng_bit_generator_named_shape_rule))
xla.translations[rng_bit_generator_p] = _rng_bit_generator_translation_rule
rng_bit_generator_p = Primitive("rng_bit_generator")
rng_bit_generator_p.multiple_results = True
rng_bit_generator_p.def_impl(
partial(xla.apply_primitive, rng_bit_generator_p))
rng_bit_generator_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
_rng_bit_generator_weak_type_rule,
_rng_bit_generator_named_shape_rule))
xla.translations[rng_bit_generator_p] = _rng_bit_generator_translation_rule

RandomAlgorithm = xops.RandomAlgorithm
RandomAlgorithm.__str__ = lambda algorithm: algorithm.name
RandomAlgorithm = xops.RandomAlgorithm
RandomAlgorithm.__str__ = lambda algorithm: algorithm.name


def rng_bit_generator(key,
shape,
dtype=np.uint32,
algorithm=RandomAlgorithm.RNG_DEFAULT):
"""Stateless PRNG bit generator. Experimental and its use is discouraged.
def rng_bit_generator(key,
shape,
dtype=np.uint32,
algorithm=RandomAlgorithm.RNG_DEFAULT):
"""Stateless PRNG bit generator. Experimental and its use is discouraged.
Returns uniformly distributed random bits with the specified shape and dtype
(what is requirted to be an integer type) using the platform specific
default algorithm or the one specified.
Returns uniformly distributed random bits with the specified shape and dtype
(what is requirted to be an integer type) using the platform specific
default algorithm or the one specified.
It provides direct acces to the RngBitGenerator primitive exposed by XLA
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) for low
level API access.
It provides direct acces to the RngBitGenerator primitive exposed by XLA
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) for low
level API access.
Most users should use `jax.random` instead for a stable and more user
friendly API.
"""
shape = jax.core.canonicalize_shape(shape)
return tuple(
rng_bit_generator_p.bind(
key, shape=shape, dtype=dtype, algorithm=algorithm))
else:
# TODO(tberghammer): Remove when minimum jaxlib version is past (0, 1, 62).
rng_bit_generator_p = Primitive("rng_bit_generator")
class RandomAlgorithm: pass # type: ignore


def rng_bit_generator(key, shape, dtype=np.uint32, algorithm=None):
raise "rng_bit_generator needs jaxlib 0.1.62 or newer"
Most users should use `jax.random` instead for a stable and more user
friendly API.
"""
shape = jax.core.canonicalize_shape(shape)
return tuple(
rng_bit_generator_p.bind(
key, shape=shape, dtype=dtype, algorithm=algorithm))


def _iota_abstract_eval(*, dtype, shape, dimension):
Expand Down
38 changes: 3 additions & 35 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@
FLAGS = flags.FLAGS
flags.DEFINE_bool("jax_disable_jit", bool_env("JAX_DISABLE_JIT", False),
"Disable JIT compilation and just call original Python.")
# TODO(jblespiau): Remove the `if` when jaxlib 0.1.62 is the minimal version.
if lib._xla_extension_version >= 5:
jax_jit.set_disable_jit_cpp_flag(bool_env("JAX_DISABLE_JIT", False))
jax_jit.set_disable_jit_cpp_flag(bool_env("JAX_DISABLE_JIT", False))

flags.DEFINE_bool(
"experimental_cpp_jit", bool_env("JAX_CPP_JIT", True),
Expand Down Expand Up @@ -347,39 +345,9 @@ def get_device_info():

return _BackendAndDeviceInfo(default_device, committed_to_device)

# TODO(jblespiau): Delete `get_jax_enable_x64` and `get_jax_disable_jit_flag`
# when jaxlib 0.1.62 is the minimal version.
def get_jax_enable_x64():
"""Returns the value of the flag after GoogleInit.
We must wait until flags have been parsed (in particular for top-level
functions decorated with jax.jit), so we delay inspecting the value
of the jax_enable_x64 flag until JIT time.
"""
# TODO(jblespiau): Delete when jaxlib 0.1.62 is the minimal version.
if lib._xla_extension_version >= 4:
return config.read("jax_enable_x64")
else:
return config.x64_enabled

def get_jax_disable_jit_flag():
"""Returns the value of the `jax_disable_jit` flag.
Both a flag and the `disable_jit` context manager can disable jit. We access
the flag only once, when jitting the function, and the context manager
modifies a C++ thread-local value.
"""
return config.read("jax_disable_jit")

static_argnums_ = (0,) + tuple(i + 1 for i in static_argnums)
# TODO(jblespiau): Remove when jaxlib 0.1.62 is the minimal version.
if lib._xla_extension_version >= 5:
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
static_argnums_)
else:
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
get_jax_enable_x64, get_jax_disable_jit_flag,
static_argnums_)
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
static_argnums_)

# TODO(mattjj): make cpp callable follow descriptor protocol for bound methods
@wraps(fun)
Expand Down
22 changes: 4 additions & 18 deletions jax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ def __init__(self):


class Config:
# TODO(jakevdp): Remove when minimum jaxlib is has extension version 4
_thread_local_state = _ThreadLocalState()

def __init__(self):
self.values = {}
self.meta = {}
Expand All @@ -70,10 +67,9 @@ def update(self, name, val):
raise Exception("Unrecognized config option: {}".format(name))
self.values[name] = val

# TODO(jblespiau): Remove when jaxlib 0.1.62 is the minimal version.
if lib._xla_extension_version >= 5 and name == "jax_disable_jit":
if name == "jax_disable_jit":
lib.jax_jit.set_disable_jit_cpp_flag(val)
elif lib._xla_extension_version >= 5 and name == "jax_enable_x64":
elif name == "jax_enable_x64":
lib.jax_jit.set_enable_x64_cpp_flag(val)

def read(self, name):
Expand Down Expand Up @@ -157,21 +153,11 @@ def disable_omnistaging(self):

@property
def x64_enabled(self):
if lib._xla_extension_version >= 5:
return lib.jax_jit.get_enable_x64()
else:
# TODO(jakevdp): Remove when minimum jaxlib is has extension version 4
if self._thread_local_state.enable_x64 is None:
self._thread_local_state.enable_x64 = bool(self.read('jax_enable_x64'))
return self._thread_local_state.enable_x64
return lib.jax_jit.get_enable_x64()

# TODO(jakevdp): make this public when thread-local x64 is fully implemented.
def _set_x64_enabled(self, state):
if lib._xla_extension_version >= 5:
lib.jax_jit.set_enable_x64_thread_local(bool(state))
else:
# TODO(jakevdp): Remove when minimum jaxlib is has extension version 4
self._thread_local_state.enable_x64 = bool(state)
lib.jax_jit.set_enable_x64_thread_local(bool(state))


class NameSpace(object):
Expand Down
6 changes: 2 additions & 4 deletions jax/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@
flags.DEFINE_bool('jax_enable_x64',
strtobool(os.getenv('JAX_ENABLE_X64', 'False')),
'Enable 64-bit types to be used.')
# TODO(jblespiau): Remove the `if` when jaxlib 0.1.62 is the minimal version.
if lib._xla_extension_version >= 5:
lib.jax_jit.set_enable_x64_cpp_flag(
strtobool(os.getenv('JAX_ENABLE_X64', 'False')))
lib.jax_jit.set_enable_x64_cpp_flag(
strtobool(os.getenv('JAX_ENABLE_X64', 'False')))

# bfloat16 support
bfloat16: type = xla_client.bfloat16
Expand Down
2 changes: 1 addition & 1 deletion jax/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
) from err

# Must be kept in sync with the jaxlib version in build/test-requirements.txt
_minimum_jaxlib_version = (0, 1, 60)
_minimum_jaxlib_version = (0, 1, 62)
try:
from jaxlib import version as jaxlib_version
except Exception as err:
Expand Down
6 changes: 1 addition & 5 deletions tests/jax_jit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import inspect
import unittest

from absl.testing import absltest
from absl.testing import parameterized
Expand All @@ -29,8 +28,6 @@
# It covers all JAX numpy types types except bfloat16 and numpy array.
# TODO(jblespiau): Add support for float0 in the C++ path.
_EXCLUDED_TYPES = [np.ndarray]
if jax.lib._xla_extension_version < 6:
_EXCLUDED_TYPES.append(jax.dtypes.bfloat16)

_SCALAR_NUMPY_TYPES = [
x for x in jax.abstract_arrays.array_types if x not in _EXCLUDED_TYPES
Expand Down Expand Up @@ -138,10 +135,9 @@ def test_device_put_on_python_scalars(self):
self.assertEqual(res.dtype, complex_type)
self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)

@unittest.skipIf(jax.lib._xla_extension_version < 3, "jaxlib too old")
def test_convert_int_overflow(self):
with self.assertRaisesRegex(
RuntimeError if jax.lib._xla_extension_version >= 6 else OverflowError,
RuntimeError,
"(Python int too large|Unable to convert Python scalar).*"):
jaxlib.jax_jit.device_put(int(1e100), True, jax.devices()[0])

Expand Down
3 changes: 0 additions & 3 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2582,9 +2582,6 @@ def scan(state, xs):

def test_xla_cpu_gpu_loop_cond_bug(self):
# https://github.com/google/jax/issues/5900
if jax.lib.version < (0, 1, 62):
raise SkipTest("test is broken on jaxlib==0.1.61 and 0.1.60")

def deriv(f):
return lambda x, *args: jax.linearize(lambda x: f(x, *args), x)[1](1.0)

Expand Down
2 changes: 0 additions & 2 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from functools import partial
import itertools
import operator
import unittest
from unittest import SkipTest

from absl.testing import absltest
Expand Down Expand Up @@ -2285,7 +2284,6 @@ def test_select_jvp_complexity(self):
(x,), (1.,)))(1.)
self.assertLen(jaxpr.jaxpr.eqns, 2)

@unittest.skipIf(jax.lib.version < (0, 1, 62), "Needs jaxlib 0.1.62 or newer")
def testRngBitGenerator(self):
if not config.x64_enabled:
raise SkipTest("RngBitGenerator requires 64bit key")
Expand Down
5 changes: 0 additions & 5 deletions tests/x64_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@
from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax import api
from jax import lax
from jax import partial
from jax import random
from jax.config import config
from jax.config import FLAGS
from jax.experimental import enable_x64, disable_x64
import jax.numpy as jnp
import jax.test_util as jtu
Expand Down Expand Up @@ -145,9 +143,6 @@ def func_x64():
def test_jit_cache(self):
if jtu.device_under_test() == "tpu":
self.skipTest("64-bit random not available on TPU")
if jax.lib._xla_extension_version < 4 and FLAGS.experimental_cpp_jit:
self.skipTest(
"Known failure due to https://github.com/google/jax/issues/5532")

f = partial(random.uniform, random.PRNGKey(0), (1,), 'float64', -1, 1)
with disable_x64():
Expand Down
4 changes: 0 additions & 4 deletions tests/xmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,6 @@ def divisors2(n: int) -> Iterator[Tuple[int, int]]:

class XMapTestCase(jtu.BufferDonationTestCase):
def setUp(self):
if jax.lib.version < (0, 1, 58):
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
super().setUp()
Expand Down Expand Up @@ -691,8 +689,6 @@ def testVarianceScaling(self, map_in, map_out, fan, distr):

class NewPrimitiveTest(XMapTestCase):
def setUp(self):
if jax.lib.version < (0, 1, 58):
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")

Expand Down

0 comments on commit 328930b

Please sign in to comment.