Skip to content

Commit

Permalink
Drop support for NumPy 1.16.
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp committed Jun 11, 2021
1 parent 87a533e commit b130257
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
os: ubuntu-latest
enable-x64: 1
# Test with numpy version that matches Google-internal version
package-overrides: "numpy==1.16.4 scipy==1.2.1"
package-overrides: "numpy==1.17.5 scipy==1.2.1"
num_generated_cases: 10
- name-prefix: "with 3.7"
python-version: 3.7
Expand Down
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* New features:

* Breaking changes:
* Support for NumPy 1.16 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).

* Bug fixes:
* Fixed bug that prevented round-tripping from JAX to TF and back:
Expand All @@ -33,8 +35,6 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
in TF ops. The code that XLA generates after jax2tf
has the same location information as JAX/XLA.

* Breaking changes:

* Bug fixes:
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules
for Python scalars and for choosing 32-bit vs. 64-bit computations
Expand Down
4 changes: 2 additions & 2 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def check_numpy_version(python_bin_path):
version = shell(
[python_bin_path, "-c", "import numpy as np; print(np.__version__)"])
numpy_version = tuple(map(int, version.split('.')[:2]))
if numpy_version < (1, 16):
print("ERROR: JAX requires NumPy 1.16 or newer, found " + version + ".")
if numpy_version < (1, 17):
print("ERROR: JAX requires NumPy 1.17 or newer, found " + version + ".")
sys.exit(-1)
return version

Expand Down
2 changes: 1 addition & 1 deletion jaxlib/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension-stubs'],
python_requires='>=3.6',
install_requires=['scipy', 'numpy>=1.16', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
install_requires=['scipy', 'numpy>=1.17', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
url='https://github.com/google/jax',
license='Apache-2.0',
package_data={
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
package_data={'jax': ['py.typed']},
python_requires='>=3.6',
install_requires=[
'numpy >=1.12',
'numpy>=1.17',
'absl-py',
'opt_einsum',
],
Expand Down
34 changes: 5 additions & 29 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
config.parse_flags_with_absl()
FLAGS = config.FLAGS

numpy_version = tuple(map(int, np.version.version.split('.')))
numpy_version = tuple(map(int, np.__version__.split('.')))

nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
Expand Down Expand Up @@ -140,6 +140,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
jtu.rand_default, [], check_dtypes=False),
op_record("greater", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("greater_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
check_dtypes=False),
op_record("ldexp", 2, int_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False),
op_record("less", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
op_record("less_equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal, []),
Expand Down Expand Up @@ -198,13 +200,6 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
inexact=True, tolerance={np.float64: 1e-9}),
]

# Skip np.i0() tests on older numpy: https://github.com/numpy/numpy/issues/11205
if numpy_version >= (1, 17, 0):
JAX_ONE_TO_ONE_OP_RECORDS.append(
op_record("i0", 1, float_dtypes, all_shapes, jtu.rand_default, [],
check_dtypes=False),
)

JAX_COMPOUND_OP_RECORDS = [
# angle has inconsistent 32/64-bit return types across numpy versions.
op_record("angle", 1, number_dtypes, all_shapes, jtu.rand_default, [],
Expand Down Expand Up @@ -821,7 +816,6 @@ def np_fun(x):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@unittest.skipIf(numpy_version < (1, 17), "where parameter not supported in older numpy")
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format(
Expand Down Expand Up @@ -1567,9 +1561,6 @@ def testPadSymmetricAndReflect(self, shape, dtype, mode, pad_width, reflect_type
tol={np.float32: 1e-3, np.complex64: 1e-3})
self._CompileAndCheck(jnp_fun, args_maker)

@unittest.skipIf(numpy_version < (1, 16, 6),
"numpy <= 1.16.5 has a bug in linear_ramp")
# https://github.com/numpy/numpy/commit/1c45e0df150b1f49982aaa3fc1a328407b5eff7e
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_mode={}_pad_width={}_end_values={}".format(
jtu.format_shape_dtype_string(shape, dtype), "linear_ramp", pad_width, end_values),
Expand Down Expand Up @@ -1618,7 +1609,6 @@ def testPadLinearRamp(self, shape, dtype, pad_width, end_values):
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
self._CompileAndCheck(jnp_fun, args_maker)

@unittest.skipIf(numpy_version < (1, 17, 0), "empty mode is new in numpy 1.17.0")
def testPadEmpty(self):
arr = np.arange(6).reshape(2, 3)

Expand Down Expand Up @@ -1670,7 +1660,6 @@ def testPadKwargs(self):
with self.assertRaisesRegex(NotImplementedError, match):
jnp.pad(arr, pad_width, mode)

@unittest.skipIf(numpy_version < (1, 17, 0), "function mode is new in numpy 1.17.0")
def testPadFunction(self):
def np_pad_with(vector, pad_width, iaxis, kwargs):
pad_value = kwargs.get('padder', 10)
Expand Down Expand Up @@ -2894,7 +2883,6 @@ def testOnesWithInvalidShape(self):
with self.assertRaises(TypeError):
jnp.ones((-1, 1))

@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": "_inshape={}_filldtype={}_fillshape={}_outdtype={}_outshape={}".format(
jtu.format_shape_dtype_string(shape, in_dtype),
Expand All @@ -2921,7 +2909,6 @@ def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_func={}_inshape={}_outshape={}_outdtype={}".format(
func, jtu.format_shape_dtype_string(shape, in_dtype),
Expand All @@ -2944,7 +2931,6 @@ def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype):
self._CompileAndCheck(jnp_fun, args_maker)


@unittest.skipIf(numpy_version < (1, 17), "shape parameter not supported in older numpy")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_func={}_inshape={}_weak_type={}_outshape={}_outdtype={}".format(
func, jtu.format_shape_dtype_string(shape, in_dtype),
Expand Down Expand Up @@ -3949,8 +3935,6 @@ def testRollaxis(self, shape, dtype, start, axis):
for shape in [(1, 2, 3, 4)]
for axis in [None, 0, 1, -2, -1]))
def testPackbits(self, shape, dtype, axis, bitorder):
if numpy_version < (1, 17, 0):
raise SkipTest("bitorder arg added in numpy 1.17.0")
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder)
Expand All @@ -3969,8 +3953,6 @@ def testPackbits(self, shape, dtype, axis, bitorder):
for axis in [None, 0, 1, -2, -1]
for count in [None, 20]))
def testUnpackbits(self, shape, dtype, axis, bitorder, count):
if numpy_version < (1, 17, 0):
raise SkipTest("bitorder arg added in numpy 1.17.0")
rng = jtu.rand_int(self.rng(), 0, 256)
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
Expand Down Expand Up @@ -4144,14 +4126,8 @@ def testIx_(self, shapes, dtypes):
for sparse in [True, False]))
def testIndices(self, dimensions, dtype, sparse):
def args_maker(): return []
if numpy_version < (1, 17):
if sparse:
raise SkipTest("indices does not have sparse on numpy < 1.17")
np_fun = partial(np.indices, dimensions=dimensions,
dtype=dtype)
else:
np_fun = partial(np.indices, dimensions=dimensions,
dtype=dtype, sparse=sparse)
np_fun = partial(np.indices, dimensions=dimensions,
dtype=dtype, sparse=sparse)
jnp_fun = partial(jnp.indices, dimensions=dimensions,
dtype=dtype, sparse=sparse)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
Expand Down

0 comments on commit b130257

Please sign in to comment.