Skip to content

Commit

Permalink
Increase minimum NumPy version to 1.20.
Browse files Browse the repository at this point in the history
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
  • Loading branch information
hawkinsp committed Aug 6, 2022
1 parent c02359b commit c735c6b
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 49 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Expand Up @@ -10,6 +10,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.

## jax 0.3.16 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
* Breaking changes
* Support for NumPy 1.19 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to NumPy 1.20 or newer.
* Changes
* Added {mod}`jax.debug` that includes utilities for runtime value debugging such at {func}`jax.debug.print` and {func}`jax.debug.breakpoint`.
* Added new documentation for [runtime value debugging](debugging/index)
Expand Down
4 changes: 2 additions & 2 deletions build/build.py
Expand Up @@ -83,8 +83,8 @@ def check_numpy_version(python_bin_path):
version = shell(
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
numpy_version = tuple(map(int, version.split(".")[:2]))
if numpy_version < (1, 19):
print("ERROR: JAX requires NumPy 1.19 or newer, found " + version + ".")
if numpy_version < (1, 20):
print("ERROR: JAX requires NumPy 1.20 or newer, found " + version + ".")
sys.exit(-1)
return version

Expand Down
2 changes: 1 addition & 1 deletion jaxlib/setup.py
Expand Up @@ -43,7 +43,7 @@
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension'],
python_requires='>=3.7',
install_requires=['scipy>=1.5', 'numpy>=1.19', 'absl-py'],
install_requires=['scipy>=1.5', 'numpy>=1.20', 'absl-py'],
url='https://github.com/google/jax',
license='Apache-2.0',
classifiers=[
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -65,7 +65,7 @@ def generate_proto(source):
python_requires='>=3.7',
install_requires=[
'absl-py',
'numpy>=1.19',
'numpy>=1.20',
'opt_einsum',
'scipy>=1.5',
'typing_extensions',
Expand Down
6 changes: 1 addition & 5 deletions tests/fft_test.py
Expand Up @@ -30,11 +30,7 @@
from jax.config import config
config.parse_flags_with_absl()

numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
if numpy_version < (1, 20):
FFT_NORMS = [None, "ortho"]
else:
FFT_NORMS = [None, "ortho", "forward", "backward"]
FFT_NORMS = [None, "ortho", "forward", "backward"]


float_dtypes = jtu.dtypes.floating
Expand Down
45 changes: 5 additions & 40 deletions tests/lax_numpy_test.py
Expand Up @@ -1012,7 +1012,6 @@ def np_fun(x):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@unittest.skipIf(numpy_version < (1, 20), "where parameter not supported in older numpy")
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format(
Expand Down Expand Up @@ -1912,8 +1911,6 @@ def testPadSymmetricAndReflect(self, shape, dtype, mode, pad_width, reflect_type
# following types lack precision
dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16])))
def testPadLinearRamp(self, shape, dtype, pad_width, end_values):
if numpy_version < (1, 20) and np.issubdtype(dtype, np.integer):
raise unittest.SkipTest("NumPy 1.20 changed the semantics of np.linspace")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]

Expand Down Expand Up @@ -2316,9 +2313,6 @@ def testConcatenate(self, axis, dtype, base_shape, arg_dtypes):
def np_fun(*args):
args = [x if x.dtype != jnp.bfloat16 else x.astype(np.float32)
for x in args]
if numpy_version < (1, 20):
_dtype = dtype or jnp.result_type(*arg_dtypes)
return np.concatenate(args, axis=axis).astype(_dtype)
return np.concatenate(args, axis=axis, dtype=dtype, casting='unsafe')
jnp_fun = lambda *args: jnp.concatenate(args, axis=axis, dtype=dtype)

Expand Down Expand Up @@ -2425,16 +2419,13 @@ def testDeleteSlice(self, shape, dtype, axis, slc):
def testDeleteIndexArray(self, shape, dtype, axis, idx_shape):
rng = jtu.rand_default(self.rng())
max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis]
# Previous to numpy 1.19, negative indices were ignored so we don't test this.
low = 0 if numpy_version < (1, 19, 0) else -max_idx
idx = jtu.rand_int(self.rng(), low=low, high=max_idx)(idx_shape, int)
idx = jtu.rand_int(self.rng(), low=-max_idx, high=max_idx)(idx_shape, int)
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda arg: np.delete(arg, idx, axis=axis)
jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@unittest.skipIf(numpy_version < (1, 19), "boolean mask not supported in numpy < 1.19.0")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),
Expand Down Expand Up @@ -2606,8 +2597,6 @@ def testRepeatScalarFastPath(self):
for return_inverse in [False, True]
for return_counts in [False, True]))
def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts):
if axis is not None and numpy_version < (1, 19) and np.empty(shape).size == 0:
self.skipTest("zero-sized axis in unique leads to error in older numpy.")
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
extra_args = (return_index, return_inverse, return_counts)
Expand Down Expand Up @@ -3462,8 +3451,6 @@ def testOnesWithInvalidShape(self):
for fill_value_shape in s(_compatible_shapes(shape if out_shape is None else out_shape))
for out_dtype in s(default_dtypes))))
def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_dtype, out_shape):
if numpy_version < (1, 19) and out_shape == ():
raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")
rng = jtu.rand_default(self.rng())
np_fun = lambda x, fill_value: np.full_like(
x, fill_value, dtype=out_dtype, shape=out_shape)
Expand All @@ -3485,8 +3472,6 @@ def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_
for func in ["ones_like", "zeros_like"]
for out_dtype in default_dtypes))
def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype):
if numpy_version < (1, 19) and out_shape == ():
raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")
rng = jtu.rand_default(self.rng())
np_fun = lambda x: getattr(np, func)(x, dtype=out_dtype, shape=out_shape)
jnp_fun = lambda x: getattr(jnp, func)(x, dtype=out_dtype, shape=out_shape)
Expand All @@ -3509,8 +3494,6 @@ def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype):
for func, args in [("full_like", (-100,)), ("ones_like", ()), ("zeros_like", ())]
for out_dtype in [None, float]))
def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type, out_shape, out_dtype):
if numpy_version < (1, 19) and out_shape == ():
raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None")
rng = jtu.rand_default(self.rng())
x = lax_internal._convert_element_type(rng(shape, in_dtype),
weak_type=weak_type)
Expand Down Expand Up @@ -3806,8 +3789,7 @@ def testResize(self, arg_shape, out_shape, dtype):
np_fun = lambda x: np.resize(x, out_shape)
jnp_fun = lambda x: jnp.resize(x, out_shape)
args_maker = lambda: [rng(arg_shape, dtype)]
if len(out_shape) > 0 or numpy_version >= (1, 20, 0):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
Expand Down Expand Up @@ -5757,20 +5739,9 @@ def testLinspace(self, start_shape, stop_shape, num, endpoint, retstep, dtype):
jnp_op = lambda start, stop: jnp.linspace(
start, stop, num,
endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)
# NumPy 1.20.0 changed the semantics of linspace to floor for integer
# dtypes.
if numpy_version >= (1, 20) or not np.issubdtype(dtype, np.integer):
np_op = lambda start, stop: np.linspace(
start, stop, num,
endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)
else:
def np_op(start, stop):
out = np.linspace(start, stop, num, endpoint=endpoint,
retstep=retstep, axis=axis)
if retstep:
return np.floor(out[0]).astype(dtype), out[1]
else:
return np.floor(out).astype(dtype)
np_op = lambda start, stop: np.linspace(
start, stop, num,
endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis)

self._CheckAgainstNumpy(np_op, jnp_op, args_maker,
check_dtypes=False, tol=tol)
Expand Down Expand Up @@ -6448,12 +6419,6 @@ def testWrappedSignaturesMatch(self):
mismatches = {}

for name, (jnp_fun, np_fun) in func_pairs.items():
# broadcast_shapes is not available in numpy < 1.20
if numpy_version < (1, 20) and name == "broadcast_shapes":
continue
# Some signatures have changed; skip for older numpy versions.
if numpy_version < (1, 19) and name in ['einsum_path', 'gradient', 'isscalar']:
continue
if numpy_version < (1, 22) and name in ['quantile', 'nanquantile',
'percentile', 'nanpercentile']:
continue
Expand Down

0 comments on commit c735c6b

Please sign in to comment.