# dtypes.py (forked from google/jax)
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Array type functions.
#
# JAX dtypes differ from NumPy in both:
# a) their type promotion rules, and
# b) the set of supported types (e.g., bfloat16),
# so we need our own implementation that deviates from NumPy in places.
from distutils.util import strtobool
import functools
import os
import numpy as np
from . import util
from .config import flags
from .lib import xla_client
# Shared flag namespace; FLAGS.jax_enable_x64 is read by canonicalize_dtype.
FLAGS = flags.FLAGS
# Register the `jax_enable_x64` boolean flag. Its default comes from the
# JAX_ENABLE_X64 environment variable (parsed with strtobool), falling back to
# 'False' when the variable is unset.
flags.DEFINE_bool('jax_enable_x64',
                  strtobool(os.getenv('JAX_ENABLE_X64', 'False')),
                  'Enable 64-bit types to be used.')
# bfloat16 support
# The bfloat16 scalar type is supplied by xla_client; `_bfloat16_dtype` is the
# corresponding np.dtype, used below for equality checks against dtype objects.
bfloat16 = xla_client.bfloat16
_bfloat16_dtype = np.dtype(bfloat16)
class _bfloat16_finfo:
  """Floating-point limits for bfloat16, exposed as class attributes.

  ``finfo`` returns this class itself (not an instance) for bfloat16, so every
  value is read directly off the class. The attribute names follow those of
  ``np.finfo`` results.
  """
  # Storage width and the exponent / mantissa bit counts.
  bits = 16
  nexp = 8
  nmant = 7
  iexp = nexp

  # Exponents of eps and epsneg, followed by the values themselves
  # (2**-7 and 2**-8).
  machep = -7
  negep = -8
  eps = bfloat16(float.fromhex("0x1p-7"))
  epsneg = bfloat16(float.fromhex("0x1p-8"))

  # Range: `max` is 0x1.FEp127, `min` is its negation, `tiny` is 2**-126.
  max = bfloat16(float.fromhex("0x1.FEp127"))
  min = -max
  tiny = bfloat16(float.fromhex("0x1p-126"))

  # Decimal precision and the matching resolution, 10**-precision.
  precision = 2
  resolution = 10 ** -precision
# Default types.
# NumPy scalar types used as the defaults for Python bool/int/float/complex
# values (see python_scalar_dtypes).
bool_ = np.bool_
int_ = np.int64
float_ = np.float64
complex_ = np.complex128
# TODO(phawkins): change the above defaults to:
# int_ = np.int32
# float_ = np.float32
# complex_ = np.complex64
# Trivial vectorspace datatype needed for tangent values of int/bool primals
# It is a structured dtype with a single zero-sized void field named 'float0'.
float0 = np.dtype([('float0', np.void, 0)])
_dtype_to_32bit_dtype = {
np.dtype('int64'): np.dtype('int32'),
np.dtype('uint64'): np.dtype('uint32'),
np.dtype('float64'): np.dtype('float32'),
np.dtype('complex128'): np.dtype('complex64'),
}
@util.memoize
def canonicalize_dtype(dtype):
  """Returns the canonical ``np.dtype`` for ``dtype`` under the current flags.

  Args:
    dtype: anything ``np.dtype`` accepts, or the string ``"bfloat16"``.

  Returns:
    The corresponding ``np.dtype``. When ``FLAGS.jax_enable_x64`` is false,
    dtypes listed in ``_dtype_to_32bit_dtype`` are replaced by their 32-bit
    counterparts; all other dtypes are returned as-is.

  Raises:
    TypeError: if ``np.dtype`` rejects the specifier.
  """
  # NOTE(review): results are cached by util.memoize; whether the cache key
  # takes FLAGS.jax_enable_x64 into account is not visible here — confirm.
  is_bfloat16_name = isinstance(dtype, str) and dtype == "bfloat16"
  spec = bfloat16 if is_bfloat16_name else dtype
  try:
    canonical = np.dtype(spec)
  except TypeError as err:
    raise TypeError(f'dtype {spec!r} not understood') from err
  if not FLAGS.jax_enable_x64:
    canonical = _dtype_to_32bit_dtype.get(canonical, canonical)
  return canonical
# Default dtypes corresponding to Python scalars, keyed by the Python type.
python_scalar_dtypes = {
    scalar_type: np.dtype(default_type)
    for scalar_type, default_type in (
        (bool, bool_),
        (int, int_),
        (float, float_),
        (complex, complex_),
    )
}
# The float0 dtype is registered as its own default.
python_scalar_dtypes[float0] = float0
def scalar_type_of(x):
  """Returns the Python scalar type that corresponds to the dtype of ``x``.

  Args:
    x: a value whose dtype can be determined by ``dtype``.

  Returns:
    One of ``bool``, ``int``, ``float`` or ``complex``.

  Raises:
    TypeError: if the dtype of ``x`` is not boolean, integer, floating or
      complex.
  """
  typ = dtype(x)
  # Use this module's issubdtype rather than np.issubdtype: NumPy does not
  # consider bfloat16 a floating-point type, so np.issubdtype would reject
  # bfloat16 values, while _dtype_priority (which uses issubdtype) accepts them.
  if issubdtype(typ, np.bool_):
    return bool
  elif issubdtype(typ, np.integer):
    return int
  elif issubdtype(typ, np.floating):
    return float
  elif issubdtype(typ, np.complexfloating):
    return complex
  else:
    raise TypeError("Invalid scalar value {}".format(x))
def coerce_to_array(x):
  """Coerces a scalar or NumPy array to an np.array.

  Handles Python scalar type promotion according to JAX's rules, not NumPy's
  rules.

  Args:
    x: a Python scalar, NumPy scalar or NumPy array.

  Returns:
    An ``np.ndarray``. If ``type(x)`` has an entry in ``python_scalar_dtypes``
    the array is created with that dtype; otherwise NumPy infers the dtype.
  """
  scalar_dtype = python_scalar_dtypes.get(type(x))
  # Compare against None explicitly instead of relying on truthiness:
  # np.dtype defines __len__ as its number of fields (0 for non-structured
  # dtypes), so depending on the NumPy version a plain dtype can be falsy and
  # the JAX default would be silently ignored.
  if scalar_dtype is None:
    return np.array(x)
  return np.array(x, scalar_dtype)
# Integer limits come straight from NumPy.
iinfo = np.iinfo

def finfo(dtype):
  """Returns floating-point limits for ``dtype``, including bfloat16.

  Args:
    dtype: a dtype specifier; the string ``"bfloat16"`` is accepted.

  Returns:
    ``_bfloat16_finfo`` for bfloat16, otherwise the result of ``np.finfo``.
  """
  # NumPy does not treat bfloat16 as a floating-point type, so np.finfo cannot
  # be used for it; detect bfloat16 up front and answer from our own table.
  named_bfloat16 = isinstance(dtype, str) and dtype == "bfloat16"
  if named_bfloat16 or np.result_type(dtype) == _bfloat16_dtype:
    return _bfloat16_finfo
  return np.finfo(dtype)
def _issubclass(a, b):
"""Determines if ``a`` is a subclass of ``b``.
Similar to issubclass, but returns False instead of an exception if `a` is not
a class.
"""
try:
return issubclass(a, b)
except TypeError:
return False
def issubdtype(a, b):
  """Like ``np.issubdtype``, but aware of bfloat16 and JAX scalar types.

  Args:
    a: the dtype (or dtype specifier) being tested.
    b: the dtype or NumPy abstract type to test against.

  Returns:
    True if ``a`` is considered a sub-dtype of ``b``.
  """
  if a == bfloat16:
    # bfloat16 is only a sub-dtype of itself and of the floating hierarchy.
    if isinstance(b, np.dtype):
      return b == _bfloat16_dtype
    return b in (bfloat16, np.floating, np.inexact, np.number)
  if _issubclass(b, np.generic):
    return np.issubdtype(a, b)
  # Workaround for JAX scalar types: NumPy's issubdtype has a backward
  # compatibility behavior for its second argument that interacts badly with
  # JAX's custom scalar types, so convert it to a NumPy type object first.
  return np.issubdtype(a, np.dtype(b).type)
# Re-exported unchanged from NumPy.
# NOTE(review): np.issubsctype appears to have been removed in NumPy 2.0 —
# confirm the NumPy versions this module must support.
can_cast = np.can_cast
issubsctype = np.issubsctype
# List of all valid JAX dtypes, in the order they appear in the type promotion
# table.
# The order matters: _make_type_promotion_table unpacks this list positionally
# and _jax_type_nums turns positions into table indices.
_jax_types = [
  np.dtype('bool'),
  np.dtype('uint8'),
  np.dtype('uint16'),
  np.dtype('uint32'),
  np.dtype('uint64'),
  np.dtype('int8'),
  np.dtype('int16'),
  np.dtype('int32'),
  np.dtype('int64'),
  np.dtype(bfloat16),
  np.dtype('float16'),
  np.dtype('float32'),
  np.dtype('float64'),
  np.dtype('complex64'),
  np.dtype('complex128'),
]
# Mapping from types to their type numbers.
# A type number is the dtype's index in _jax_types.
_jax_type_nums = {t: i for i, t in enumerate(_jax_types)}
def _make_type_promotion_table():
  """Builds the table of pairwise type promotion results.

  Returns:
    A 2-D ``np.array`` of dtypes with one row and one column per entry of
    ``_jax_types``, in that order. The element at ``[i, j]`` is the dtype that
    ``_jax_types[i]`` and ``_jax_types[j]`` promote to.
  """
  # Short aliases, bound positionally; they must stay in _jax_types order.
  b1, u1, u2, u4, u8, s1, s2, s4, s8, bf, f2, f4, f8, c4, c8 = _jax_types
  #  b1, u1, u2, u4, u8, s1, s2, s4, s8, bf, f2, f4, f8, c4, c8
  return np.array([
    [b1, u1, u2, u4, u8, s1, s2, s4, s8, bf, f2, f4, f8, c4, c8],  # b1
    [u1, u1, u2, u4, u8, s2, s2, s4, s8, bf, f2, f4, f8, c4, c8],  # u1
    [u2, u2, u2, u4, u8, s4, s4, s4, s8, bf, f2, f4, f8, c4, c8],  # u2
    [u4, u4, u4, u4, u8, s8, s8, s8, s8, bf, f2, f4, f8, c4, c8],  # u4
    [u8, u8, u8, u8, u8, f8, f8, f8, f8, bf, f2, f4, f8, c4, c8],  # u8
    [s1, s2, s4, s8, f8, s1, s2, s4, s8, bf, f2, f4, f8, c4, c8],  # s1
    [s2, s2, s4, s8, f8, s2, s2, s4, s8, bf, f2, f4, f8, c4, c8],  # s2
    [s4, s4, s4, s8, f8, s4, s4, s4, s8, bf, f2, f4, f8, c4, c8],  # s4
    [s8, s8, s8, s8, f8, s8, s8, s8, s8, bf, f2, f4, f8, c4, c8],  # s8
    [bf, bf, bf, bf, bf, bf, bf, bf, bf, bf, f4, f4, f8, c4, c8],  # bf
    [f2, f2, f2, f2, f2, f2, f2, f2, f2, f4, f2, f4, f8, c4, c8],  # f2
    [f4, f4, f4, f4, f4, f4, f4, f4, f4, f4, f4, f4, f8, c4, c8],  # f4
    [f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, c8, c8],  # f8
    [c4, c4, c4, c4, c4, c4, c4, c4, c4, c4, c4, c4, c8, c4, c8],  # c4
    [c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8],  # c8
  ])

# Built once at import time; promote_types indexes it with type numbers.
_type_promotion_table = _make_type_promotion_table()
def promote_types(a, b):
  """Returns the type to which a binary operation should cast its arguments.

  The result is looked up in the module's type promotion table. For details of
  JAX's type promotion semantics, see :ref:`type-promotion`.

  Args:
    a: a :class:`numpy.dtype` or a dtype specifier.
    b: a :class:`numpy.dtype` or a dtype specifier.

  Returns:
    A :class:`numpy.dtype` object.

  Raises:
    TypeError: if either dtype has no entry in the promotion table.
  """
  lhs, rhs = np.dtype(a), np.dtype(b)
  if lhs in _jax_type_nums and rhs in _jax_type_nums:
    row, col = _jax_type_nums[lhs], _jax_type_nums[rhs]
    return _type_promotion_table[row, col]
  raise TypeError("Invalid type promotion of {} and {}".format(lhs, rhs))
def is_python_scalar(x):
  """Reports whether ``x`` should be treated as a Python scalar.

  Values exposing ``x.aval.weak_type`` count as Python scalars when that flag
  is set and they are zero-dimensional. For anything that raises
  ``AttributeError`` along the way, the answer is whether ``type(x)`` is a key
  of ``python_scalar_dtypes``.
  """
  try:
    weakly_typed = x.aval.weak_type
    # Kept inside the try block: an AttributeError raised while computing the
    # rank also falls back to the type lookup.
    return weakly_typed and np.ndim(x) == 0
  except AttributeError:
    return type(x) in python_scalar_dtypes
def _dtype_priority(dtype):
  """Ranks ``dtype`` by kind: bool (0) < integer (1) < floating (2) < complex (3).

  Raises:
    TypeError: if ``dtype`` belongs to none of these kinds.
  """
  kinds_by_rank = (np.bool_, np.integer, np.floating, np.complexfloating)
  for rank, kind in enumerate(kinds_by_rank):
    if issubdtype(dtype, kind):
      return rank
  raise TypeError("Dtype {} is not supported by JAX".format(dtype))
def dtype(x):
  """Returns the dtype of ``x``, using JAX's defaults for Python scalars.

  If ``type(x)`` is a key of ``python_scalar_dtypes`` the registered dtype is
  returned; every other value is handed to ``np.result_type``.
  """
  registered = python_scalar_dtypes.get(type(x))
  if registered is None:
    return np.result_type(x)
  return registered
def result_type(*args):
  """Convenience function to apply Numpy argument dtype promotion.

  Python scalars (see ``is_python_scalar``) only take part in the promotion
  when their kind (bool < integer < floating < complex) outranks the kind of
  every non-scalar argument.

  Args:
    *args: one or more arrays, scalars or dtype specifiers.

  Returns:
    The promoted dtype, canonicalized with ``canonicalize_dtype``.

  Raises:
    ValueError: if called without arguments.
    TypeError: if an argument has a dtype that JAX does not support.
  """
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if not args:
    # Fail with a clear message instead of an IndexError from args[0].
    raise ValueError("at least one array or dtype is required")
  if len(args) == 1:
    # Canonicalize here as well, so that a single argument yields the same
    # kind of result as the multi-argument path below.
    return canonicalize_dtype(dtype(args[0]))
  scalars = []
  dtypes = []
  for x in args:
    (scalars if is_python_scalar(x) else dtypes).append(dtype(x))
  # -1 ranks below every kind, so with no array arguments all scalars count.
  array_priority = max(map(_dtype_priority, dtypes)) if dtypes else -1
  dtypes += [x for x in scalars if _dtype_priority(x) > array_priority]
  return canonicalize_dtype(functools.reduce(promote_types, dtypes))