Skip to content

Commit

Permalink
Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation war… (
Browse files Browse the repository at this point in the history
#3543)

* Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation warnings.

* Pin a newer tf-nightly to fix jax2tf tests for NumPy 1.19.0
  • Loading branch information
hawkinsp committed Jun 24, 2020
1 parent 90894a2 commit f036f5d
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion build/test-requirements.txt
Expand Up @@ -5,5 +5,5 @@ mypy==0.770
pytest-benchmark
pytest-xdist
# jax2tf needs some fixes that are not in tensorflow==2.2.0
tf-nightly==2.3.0.dev20200525
tf-nightly==2.3.0.dev20200624
wheel
6 changes: 4 additions & 2 deletions jax/dtypes.py
Expand Up @@ -144,8 +144,10 @@ def _issubclass(a, b):

def issubdtype(a, b):
if a == bfloat16:
return b in [bfloat16, _bfloat16_dtype, np.floating, np.inexact,
np.number]
if isinstance(b, np.dtype):
return b == _bfloat16_dtype
else:
return b in [bfloat16, np.floating, np.inexact, np.number]
if not _issubclass(b, np.generic):
# Workaround for JAX scalar types. NumPy's issubdtype has a backward
# compatibility behavior for the second argument of issubdtype that
Expand Down
2 changes: 1 addition & 1 deletion jax/numpy/lax_numpy.py
Expand Up @@ -139,7 +139,7 @@ def __hash__(self):
return hash(self.dtype.type)

def __eq__(self, other):
return id(self) == id(other) or self.dtype == other
return id(self) == id(other) or self.dtype.type == other

def __ne__(self, other):
return not (self == other)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -29,7 +29,7 @@
packages=find_packages(exclude=["examples"]),
python_requires='>=3.6',
install_requires=[
'numpy >=1.12, <1.19', 'absl-py', 'opt_einsum'
'numpy >=1.12', 'absl-py', 'opt_einsum'
],
url='https://github.com/google/jax',
license='Apache-2.0',
Expand Down

0 comments on commit f036f5d

Please sign in to comment.