forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 2
/
typing_test.py
184 lines (141 loc) · 5.85 KB
/
typing_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Typing tests
------------
This test is meant to be both a runtime test and a static type annotation test,
so it should be checked with pytype/mypy as well as being run with pytest.
"""
from __future__ import annotations
from typing import Any, TYPE_CHECKING
import jax
from jax._src import core
from jax._src import test_util as jtu
from jax._src import typing
from jax import lax
import jax.numpy as jnp
from jax._src.array import ArrayImpl
from absl.testing import absltest
import numpy as np
# DTypeLike is meant to annotate inputs to np.dtype that return
# a valid JAX dtype, so we test with np.dtype.
def dtypelike_to_dtype(x: typing.DTypeLike) -> typing.DType:
return np.dtype(x)
# ArrayLike is meant to annotate object that are valid as array
# inputs to jax primitive functions; use convert_element_type here
# for simplicity.
def arraylike_to_array(x: typing.ArrayLike) -> typing.Array:
return lax.convert_element_type(x, np.result_type(x))
class HasDType:
dtype: np.dtype
def __init__(self, dt):
self.dtype = np.dtype(dt)
float32_dtype = np.dtype("float32")
# Avoid test parameterization because we want to statically check these annotations.
class TypingTest(jtu.JaxTestCase):
def testDTypeLike(self) -> None:
out1: typing.DType = dtypelike_to_dtype("float32")
self.assertEqual(out1, float32_dtype)
out2: typing.DType = dtypelike_to_dtype(np.float32)
self.assertEqual(out2, float32_dtype)
out3: typing.DType = dtypelike_to_dtype(jnp.float32)
self.assertEqual(out3, float32_dtype)
out4: typing.DType = dtypelike_to_dtype(np.dtype('float32'))
self.assertEqual(out4, float32_dtype)
out5: typing.DType = dtypelike_to_dtype(HasDType("float32"))
self.assertEqual(out5, float32_dtype)
def testArrayLike(self) -> None:
out1: typing.Array = arraylike_to_array(jnp.arange(4))
self.assertArraysEqual(out1, jnp.arange(4))
out2: typing.Array = jax.jit(arraylike_to_array)(jnp.arange(4))
self.assertArraysEqual(out2, jnp.arange(4))
out3: typing.Array = arraylike_to_array(np.arange(4))
self.assertArraysEqual(out3, jnp.arange(4), check_dtypes=False)
out4: typing.Array = arraylike_to_array(True)
self.assertArraysEqual(out4, jnp.array(True))
out5: typing.Array = arraylike_to_array(1)
self.assertArraysEqual(out5, jnp.array(1))
out6: typing.Array = arraylike_to_array(1.0)
self.assertArraysEqual(out6, jnp.array(1.0))
out7: typing.Array = arraylike_to_array(1 + 1j)
self.assertArraysEqual(out7, jnp.array(1 + 1j))
out8: typing.Array = arraylike_to_array(np.bool_(0))
self.assertArraysEqual(out8, jnp.bool_(0))
out9: typing.Array = arraylike_to_array(np.float32(0))
self.assertArraysEqual(out9, jnp.float32(0))
def testArrayInstanceChecks(self):
def is_array(x: typing.ArrayLike) -> bool | typing.Array:
return isinstance(x, typing.Array)
x = jnp.arange(5)
self.assertFalse(is_array(1.0))
self.assertTrue(jax.jit(is_array)(1.0))
self.assertTrue(is_array(x))
self.assertTrue(jax.jit(is_array)(x))
self.assertTrue(jnp.all(jax.vmap(is_array)(x)))
def testAnnotations(self):
# This test is mainly meant for static type checking: we want to ensure that
# Tracer and ArrayImpl are valid as array.Array.
def f(x: Any) -> typing.Array | None:
if isinstance(x, core.Tracer):
return x
elif isinstance(x, ArrayImpl):
return x
else:
return None
x = jnp.arange(10)
y = f(x)
self.assertArraysEqual(x, y)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
if TYPE_CHECKING:
# Here we do a number of static type assertions. We purposely don't cover the
# entire public API here, but rather spot-check some potentially problematic
# areas. The goals are:
#
# - Confirm the correctness of a few basic APIs
# - Confirm that types from *.pyi files are correctly pulled-in
# - Confirm that non-trivial overloads are behaving as expected.
#
import sys
if sys.version_info >= (3, 11):
from typing import assert_type # pytype: disable=not-supported-yet # py311-upgrade
else:
from typing_extensions import assert_type # pytype: disable=not-supported-yet
mat = jnp.zeros((2, 5))
vals = jnp.arange(5)
mask = jnp.array([True, False, True, False])
assert_type(mat, jax.Array)
assert_type(vals, jax.Array)
assert_type(mask, jax.Array)
# Functions with non-trivial typing overloads:
# jnp.linspace
assert_type(jnp.linspace(0, 10), jax.Array)
assert_type(jnp.linspace(0, 10, retstep=False), jax.Array)
assert_type(jnp.linspace(0, 10, retstep=True), tuple[jax.Array, jax.Array])
# jnp.where
assert_type(mask, jax.Array)
assert_type(jnp.where(mask, 0, 1), jax.Array)
assert_type(jnp.where(mask), tuple[jax.Array, ...])
# jnp.einsum
assert_type(jnp.einsum('ij', mat), jax.Array)
assert_type(jnp.einsum('ij,j->i', mat, vals), jax.Array)
assert_type(jnp.einsum(mat, (0, 0)), jax.Array)
# jnp.indices
assert_type(jnp.indices([2, 3]), jax.Array)
assert_type(jnp.indices([2, 3], sparse=False), jax.Array)
assert_type(jnp.indices([2, 3], sparse=True), tuple[jax.Array, ...])
# jnp.average
assert_type(jnp.average(vals), jax.Array)
assert_type(jnp.average(vals, returned=False), jax.Array)
assert_type(jnp.average(vals, returned=True), tuple[jax.Array, jax.Array])