From b3c31ebe7d16c45705efb6ebb7d63831bf240de6 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 13 Sep 2022 12:43:51 -0700 Subject: [PATCH] Add typing_test.py --- .pre-commit-config.yaml | 2 +- jax/_src/typing.py | 8 ++- tests/BUILD | 10 ++++ tests/typing_test.py | 117 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 134 insertions(+), 3 deletions(-) create mode 100644 tests/typing_test.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4688f0a1f4d9..159c9c9c56a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: rev: 'v0.971' hooks: - id: mypy - files: jax/ + files: (jax/|tests/typing_test\.py) additional_dependencies: [types-requests==2.27.16, jaxlib==0.3.5] - repo: https://github.com/mwouts/jupytext diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 38862c88442d..ea754dd43d99 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -35,8 +35,12 @@ class HasDtypeAttribute(Protocol): Dtype = np.dtype -# Any is here to allow scalar types like np.int32. -# TODO(jakevdp) figure out how to specify these more strictly. +# DtypeLike is meant to annotate inputs to np.dtype that return +# a valid JAX dtype. It's different than numpy.typing.DTypeLike +# because JAX doesn't support objects or structured dtypes. +# It does not include JAX dtype extensions such as KeyType and others. +# For now, we use Any to allow scalar types like np.int32 & jnp.int32. +# TODO(jakevdp) specify these more strictly. DtypeLike = Union[Any, str, np.dtype, HasDtypeAttribute] # Shapes are tuples of dimension sizes, which are normally integers. We allow diff --git a/tests/BUILD b/tests/BUILD index 4efd2c0727bf..153cb5961361 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -732,6 +732,16 @@ py_test( ], ) +# TODO(jakevdp): make this a py_strict_test +py_test( + name = "typing_test", + srcs = ["typing_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + py_test( name = "util_test", srcs = ["util_test.py"], diff --git a/tests/typing_test.py b/tests/typing_test.py new file mode 100644 index 000000000000..9eb87c1e6ab0 --- /dev/null +++ b/tests/typing_test.py @@ -0,0 +1,117 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Typing tests +------------ +This test is meant to be both a runtime test and a static type annotation test, +so it should be checked with pytype/mypy as well as being run with pytest. +""" +from typing import Union + +import jax +from jax._src import test_util as jtu +from jax._src import typing +from jax import lax +import jax.numpy as jnp + +from absl.testing import absltest +import numpy as np + + +# DtypeLike is meant to annotate inputs to np.dtype that return +# a valid JAX dtype, so we test with np.dtype. +def dtypelike_to_dtype(x: typing.DtypeLike) -> typing.Dtype: + return np.dtype(x) + + +# ArrayLike is meant to annotate object that are valid as array +# inputs to jax primitive functions; use convert_element_type here +# for simplicity. +def arraylike_to_array(x: typing.ArrayLike) -> typing.Array: + return lax.convert_element_type(x, np.result_type(x)) + + +class HasDtype: + dtype: np.dtype + def __init__(self, dt): + self.dtype = np.dtype(dt) + +float32_dtype = np.dtype("float32") + + +# Avoid test parameterization because we want to statically check these annotations. +class TypingTest(jtu.JaxTestCase): + + def testDtypeLike(self) -> None: + out1: typing.Dtype = dtypelike_to_dtype("float32") + self.assertEqual(out1, float32_dtype) + + out2: typing.Dtype = dtypelike_to_dtype(np.float32) + self.assertEqual(out2, float32_dtype) + + out3: typing.Dtype = dtypelike_to_dtype(jnp.float32) + self.assertEqual(out3, float32_dtype) + + out4: typing.Dtype = dtypelike_to_dtype(np.dtype('float32')) + self.assertEqual(out4, float32_dtype) + + out5: typing.Dtype = dtypelike_to_dtype(HasDtype("float32")) + self.assertEqual(out5, float32_dtype) + + def testArrayLike(self) -> None: + out1: typing.Array = arraylike_to_array(jnp.arange(4)) + self.assertArraysEqual(out1, jnp.arange(4)) + + out2: typing.Array = jax.jit(arraylike_to_array)(jnp.arange(4)) + self.assertArraysEqual(out2, jnp.arange(4)) + + out3: typing.Array = arraylike_to_array(np.arange(4)) + self.assertArraysEqual(out3, jnp.arange(4)) + + out4: typing.Array = arraylike_to_array(True) + self.assertArraysEqual(out4, jnp.array(True)) + + out5: typing.Array = arraylike_to_array(1) + self.assertArraysEqual(out5, jnp.array(1)) + + out6: typing.Array = arraylike_to_array(1.0) + self.assertArraysEqual(out6, jnp.array(1.0)) + + out7: typing.Array = arraylike_to_array(1 + 1j) + self.assertArraysEqual(out7, jnp.array(1 + 1j)) + + out8: typing.Array = arraylike_to_array(np.bool_(0)) + self.assertArraysEqual(out8, jnp.bool_(0)) + + out9: typing.Array = arraylike_to_array(np.float32(0)) + self.assertArraysEqual(out9, jnp.float32(0)) + + def testArrayInstanceChecks(self): + # TODO(jakevdp): enable this test when `typing.Array` instance checks are implemented. + self.skipTest("Test is broken for now.") + + def is_array(x: typing.ArrayLike) -> Union[bool, typing.Array]: + return isinstance(x, typing.Array) + + x = jnp.arange(5) + + self.assertFalse(is_array(1.0)) + self.assertTrue(jax.jit(is_array)(1.0)) + self.assertTrue(is_array(x)) + self.assertTrue(jax.jit(is_array)(x)) + self.assertTrue(jax.vmap(is_array)(x).all()) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader())