Skip to content

Commit

Permalink
Add typing_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 13, 2022
1 parent dc4922f commit b3c31eb
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Expand Up @@ -17,7 +17,7 @@ repos:
rev: 'v0.971'
hooks:
- id: mypy
files: jax/
files: (jax/|tests/typing_test\.py)
additional_dependencies: [types-requests==2.27.16, jaxlib==0.3.5]

- repo: https://github.com/mwouts/jupytext
Expand Down
8 changes: 6 additions & 2 deletions jax/_src/typing.py
Expand Up @@ -35,8 +35,12 @@ class HasDtypeAttribute(Protocol):

Dtype = np.dtype

# Any is here to allow scalar types like np.int32.
# TODO(jakevdp) figure out how to specify these more strictly.
# DtypeLike is meant to annotate inputs to np.dtype that return
# a valid JAX dtype. It's different than numpy.typing.DTypeLike
# because JAX doesn't support objects or structured dtypes.
# It does not include JAX dtype extensions such as KeyType and others.
# For now, we use Any to allow scalar types like np.int32 & jnp.int32.
# TODO(jakevdp) specify these more strictly.
DtypeLike = Union[Any, str, np.dtype, HasDtypeAttribute]

# Shapes are tuples of dimension sizes, which are normally integers. We allow
Expand Down
10 changes: 10 additions & 0 deletions tests/BUILD
Expand Up @@ -732,6 +732,16 @@ py_test(
],
)

# TODO(jakevdp): make this a py_strict_test
py_test(
name = "typing_test",
srcs = ["typing_test.py"],
deps = [
"//jax",
"//jax:test_util",
],
)

py_test(
name = "util_test",
srcs = ["util_test.py"],
Expand Down
117 changes: 117 additions & 0 deletions tests/typing_test.py
@@ -0,0 +1,117 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Typing tests
------------
This test is meant to be both a runtime test and a static type annotation test,
so it should be checked with pytype/mypy as well as being run with pytest.
"""
from typing import Union

import jax
from jax._src import test_util as jtu
from jax._src import typing
from jax import lax
import jax.numpy as jnp

from absl.testing import absltest
import numpy as np


# DtypeLike is meant to annotate inputs to np.dtype that return
# a valid JAX dtype, so we test with np.dtype.
def dtypelike_to_dtype(x: typing.DtypeLike) -> typing.Dtype:
return np.dtype(x)


# ArrayLike is meant to annotate object that are valid as array
# inputs to jax primitive functions; use convert_element_type here
# for simplicity.
def arraylike_to_array(x: typing.ArrayLike) -> typing.Array:
return lax.convert_element_type(x, np.result_type(x))


class HasDtype:
dtype: np.dtype
def __init__(self, dt):
self.dtype = np.dtype(dt)

float32_dtype = np.dtype("float32")


# Avoid test parameterization because we want to statically check these annotations.
class TypingTest(jtu.JaxTestCase):

def testDtypeLike(self) -> None:
out1: typing.Dtype = dtypelike_to_dtype("float32")
self.assertEqual(out1, float32_dtype)

out2: typing.Dtype = dtypelike_to_dtype(np.float32)
self.assertEqual(out2, float32_dtype)

out3: typing.Dtype = dtypelike_to_dtype(jnp.float32)
self.assertEqual(out3, float32_dtype)

out4: typing.Dtype = dtypelike_to_dtype(np.dtype('float32'))
self.assertEqual(out4, float32_dtype)

out5: typing.Dtype = dtypelike_to_dtype(HasDtype("float32"))
self.assertEqual(out5, float32_dtype)

def testArrayLike(self) -> None:
out1: typing.Array = arraylike_to_array(jnp.arange(4))
self.assertArraysEqual(out1, jnp.arange(4))

out2: typing.Array = jax.jit(arraylike_to_array)(jnp.arange(4))
self.assertArraysEqual(out2, jnp.arange(4))

out3: typing.Array = arraylike_to_array(np.arange(4))
self.assertArraysEqual(out3, jnp.arange(4))

out4: typing.Array = arraylike_to_array(True)
self.assertArraysEqual(out4, jnp.array(True))

out5: typing.Array = arraylike_to_array(1)
self.assertArraysEqual(out5, jnp.array(1))

out6: typing.Array = arraylike_to_array(1.0)
self.assertArraysEqual(out6, jnp.array(1.0))

out7: typing.Array = arraylike_to_array(1 + 1j)
self.assertArraysEqual(out7, jnp.array(1 + 1j))

out8: typing.Array = arraylike_to_array(np.bool_(0))
self.assertArraysEqual(out8, jnp.bool_(0))

out9: typing.Array = arraylike_to_array(np.float32(0))
self.assertArraysEqual(out9, jnp.float32(0))

def testArrayInstanceChecks(self):
# TODO(jakevdp): enable this test when `typing.Array` instance checks are implemented.
self.skipTest("Test is broken for now.")

def is_array(x: typing.ArrayLike) -> Union[bool, typing.Array]:
return isinstance(x, typing.Array)

x = jnp.arange(5)

self.assertFalse(is_array(1.0))
self.assertTrue(jax.jit(is_array)(1.0))
self.assertTrue(is_array(x))
self.assertTrue(jax.jit(is_array)(x))
self.assertTrue(jax.vmap(is_array)(x).all())


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit b3c31eb

Please sign in to comment.