-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
api_util.py
674 lines (582 loc) · 25.9 KB
/
api_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Sequence
import inspect
import operator
from functools import partial
from typing import Any, Callable, Optional, Union
import warnings
import numpy as np
from jax._src import core
from jax._src import dtypes
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ShapedArray
from jax._src.tree_util import (
PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure,
treedef_children, generate_key_paths, keystr, broadcast_prefix,
prefix_errors)
from jax._src.tree_util import _replace_nones
from jax._src import linear_util as lu
from jax._src.linear_util import TracingDebugInfo
from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
Unhashable)
from jax._src import traceback_util
# Register this file with traceback_util's exclusion list (presumably so that
# frames from this module are filtered out of user-facing tracebacks).
traceback_util.register_exclusion(__file__)
# Module-wide rebinding: every `map(...)` below refers to jax's `safe_map`
# rather than the builtin (presumably a length-checking variant — see
# jax._src.util).
map = safe_map
def _ensure_index(x: Any) -> Union[int, tuple[int, ...]]:
  """Normalize ``x`` to a single integer index or a tuple of integer indices.

  ``x`` is first passed through ``core.concrete_or_error`` with a message
  about static indices. An index-like scalar is returned as an ``int``;
  anything else is iterated and each element converted with
  ``operator.index``.
  """
  concrete = core.concrete_or_error(
      None, x, "expected a static index or sequence of indices.")
  try:
    as_int = operator.index(concrete)
  except TypeError:
    # Not a scalar index: treat it as a sequence of indices.
    return tuple(map(operator.index, concrete))
  return as_int
def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
  """Normalize ``x`` to a tuple of integer indices.

  ``x`` is first passed through ``core.concrete_or_error``. A single
  index-like value becomes a one-element tuple; otherwise ``x`` is iterated
  and each element converted with ``operator.index``.
  """
  value = core.concrete_or_error(
      None, x, "expected a static index or sequence of indices.")
  try:
    single = operator.index(value)
  except TypeError:
    return tuple(operator.index(item) for item in value)
  return (single,)
def _ensure_str(x: str) -> str:
if not isinstance(x, str):
raise TypeError(f"argument is not a string: {x}")
return x
def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> tuple[str, ...]:
"""Convert x to a tuple of strings."""
if isinstance(x, str):
return (x,)
else:
return tuple(map(_ensure_str, x))
@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
  """Let a function of pytree ``(args, kwargs)`` be called with flat leaves.

  ``args_flat`` is unflattened with ``in_tree`` into ``(py_args, py_kwargs)``
  and passed to the wrapped function. Its result is flattened: the leaves are
  the transformed function's output and the output treedef is the auxiliary
  value (``flattened_fun_in_tree`` reads it back from the store).
  """
  py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
  # Hand (args, kwargs) to the wrapped function; its return value comes back.
  ans = yield py_args, py_kwargs
  # tree_flatten returns (leaves, treedef) -> (output, aux).
  yield tree_flatten(ans)
def apply_flat_fun(fun, io_tree, *py_args):
  """Call a flat-argument function with pytree positional arguments.

  ``io_tree`` is an ``(in_tree, out_tree)`` pair. ``(py_args, {})`` is
  flattened and its treedef must equal ``in_tree``; ``fun`` is then called on
  the leaves and its flat outputs are unflattened with ``out_tree``.

  Raises:
    TypeError: if the argument structure does not match ``in_tree``.
  """
  expected_in, out_tree = io_tree
  flat_args, actual_in = tree_flatten((py_args, {}))
  if actual_in != expected_in:
    raise TypeError(f"Expected {expected_in}, got {actual_in}")
  return tree_unflatten(out_tree, fun(*flat_args))
@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
  """Like ``flatten_fun`` but for positional arguments only.

  ``args_flat`` is unflattened with ``in_tree`` into the positional argument
  pytree, and the wrapped function is called with no keyword arguments. Its
  result is flattened: leaves are the output, the treedef is the aux value.
  """
  py_args = tree_unflatten(in_tree, args_flat)
  # Call the wrapped function with an empty kwargs dict.
  ans = yield py_args, {}
  yield tree_flatten(ans)
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  """Call a flat-argument function with a pytree of positional arguments.

  ``io_tree`` is an ``(in_tree, out_tree)`` pair. ``py_args`` is flattened
  and its treedef must equal ``in_tree``; ``fun``'s flat outputs are then
  unflattened with ``out_tree``.

  Raises:
    TypeError: if the structure of ``py_args`` does not match ``in_tree``.
  """
  expected_in, out_tree = io_tree
  flat_args, actual_in = tree_flatten(py_args)
  if actual_in != expected_in:
    raise TypeError(f"Expected {expected_in}, got {actual_in}")
  return tree_unflatten(out_tree, fun(*flat_args))
def flattened_fun_in_tree(
    fn: lu.WrappedFun
) -> Optional[tuple[PyTreeDef, Callable[[], PyTreeDef], bool]]:
  """Recover the input treedef of a WrappedFun built with ``flatten_fun*``.

  Returns ``(in_tree, out_tree_thunk, has_kwargs)``:
    * ``in_tree`` is the single parameter the flatten transform was applied
      with;
    * ``out_tree_thunk`` reads ``.val`` from that transform's store
      (presumably only populated once the wrapped function has been called);
    * ``has_kwargs`` is True for ``flatten_fun`` and False for
      ``flatten_fun_nokwargs``.
  Returns ``None`` unless exactly one of ``fn``'s transforms is one of those
  two (or if its parameters do not unpack as expected).
  """
  # This implementation relies on internal details of linear_util.py's
  # WrappedFun, but it's for the worthy cause of better user error messages.
  # It can fail (i.e. return None) if its WrappedFun argument is not transformed
  # with flatten_fun or flatten_fun_nokwargs, which could happen e.g. when
  # core.eval_jaxpr encounters a call primitive (though at that point we're just
  # round-tripping jaxprs and the user errors in question are impossible).
  assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
  assert (isinstance(flatten_fun_nokwargs, partial) and
          len(flatten_fun_nokwargs.args) == 1)
  # The undecorated generator functions, used to recognize the transforms.
  flattens = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
  try:
    # The trailing-comma unpacking demands exactly one matching transform;
    # zero or several matches raise ValueError.
    ((in_tree,), out_tree_store, has_kwargs), = (
        (args, store, f is flatten_fun.args[0])
        for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens)
  except ValueError:
    return None
  else:
    return in_tree, lambda: out_tree_store.val, has_kwargs
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
  """Like ``flatten_fun_nokwargs`` for functions returning ``(ans, aux)``.

  The wrapped function must return a two-element list or tuple. Both elements
  are flattened separately: ``(ans_flat, aux_flat)`` is the transformed
  output and ``(ans_tree, aux_tree)`` is the auxiliary value.

  Raises:
    TypeError: if the wrapped function does not return a 2-element list/tuple.
  """
  py_args = tree_unflatten(in_tree, args_flat)
  pair = yield py_args, {}
  if not isinstance(pair, (list, tuple)) or len(pair) != 2:
    raise TypeError("expected function with aux output to return a two-element "
                    f"tuple, but got type {type(pair)} with value {pair!r}")
  ans, aux = pair
  ans_flat, ans_tree = tree_flatten(ans)
  aux_flat, aux_tree = tree_flatten(aux)
  yield (ans_flat, aux_flat), (ans_tree, aux_tree)
class _HashableWithStrictTypeEquality:
"""Box object used when comparing static arguments as a jit key.
Requires exact type equality using `is` and value equality."""
__slots__ = ["val"]
def __init__(self, val):
self.val = val
def __hash__(self):
return hash(self.val)
def __eq__(self, other):
return type(self.val) is type(other.val) and self.val == other.val
# Parameter kinds that can be supplied positionally.
_POSITIONAL_ARGUMENTS = (
    inspect.Parameter.POSITIONAL_ONLY,
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
)

def validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None:
  """Warn when ``argnums`` cannot refer to positional parameters of ``sig``.

  Functions that accept a variable number of positional arguments
  (``f(..., *args)``) accept every argnum, so nothing is checked for them.
  Otherwise each entry must lie in ``[-n, n)``, where ``n`` counts the
  positional-only and positional-or-keyword parameters. Violations currently
  emit a ``SyntaxWarning`` (slated to become an error).
  """
  params = sig.parameters.values()
  if any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params):
    # We can have any number of positional arguments
    return
  n_pos_args = sum(1 for p in params if p.kind in _POSITIONAL_ARGUMENTS)
  if not argnums:
    return
  if -min(argnums) > n_pos_args or max(argnums) >= n_pos_args:
    # TODO: 2022-08-20 or later: replace with
    #   ValueError(f"Jitted function has {argnums_name}={argnums}, "
    #              f"but only accepts {n_pos_args} positional arguments.")
    warnings.warn(f"Jitted function has {argnums_name}={argnums}, "
                  f"but only accepts {n_pos_args} positional arguments. "
                  "This warning will be replaced by an error after 2022-08-20 "
                  "at the earliest.", SyntaxWarning)
# Parameter kinds that can never be passed by keyword.
_INVALID_KEYWORD_ARGUMENTS = (
    inspect.Parameter.POSITIONAL_ONLY,
    inspect.Parameter.VAR_POSITIONAL
)

# Parameter kinds that can be passed by keyword.
_KEYWORD_ARGUMENTS = (
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
    inspect.Parameter.KEYWORD_ONLY,
)

def validate_argnames(sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str) -> None:
  """Validate that the argnames are sensible for a given function.

  Two checks run, each emitting a ``SyntaxWarning`` (slated to become an
  error):
    1. names that refer to positional-only (or ``*args``) parameters;
    2. unless the function takes ``**kwargs``, names that are not keyword
       parameters of the function at all.
  For functions that accept variable keyword arguments (``f(..., **kwargs)``)
  all argnames are considered valid except those marked as positional-only
  (``f(pos_only, /, ...)``).
  """
  var_kwargs = False
  valid_kwargs: set[str] = set()
  invalid_kwargs: set[str] = set()
  for param_name, param in sig.parameters.items():
    if param.kind in _KEYWORD_ARGUMENTS:
      valid_kwargs.add(param_name)
    elif param.kind is inspect.Parameter.VAR_KEYWORD:
      var_kwargs = True
    elif param.kind in _INVALID_KEYWORD_ARGUMENTS:
      invalid_kwargs.add(param_name)

  # Check whether any kwargs are invalid due to position only
  invalid_argnames = invalid_kwargs & set(argnames)
  if invalid_argnames:
    # TODO: 2022-08-20 or later: replace with
    #   ValueError(f"Jitted function has invalid argnames {invalid_argnames} "
    #              f"in {argnames_name}. These are positional-only")
    warnings.warn(f"Jitted function has invalid argnames {invalid_argnames} "
                  f"in {argnames_name}. These are positional-only. "
                  "This warning will be replaced by an error after 2022-08-20 "
                  "at the earliest.", SyntaxWarning)

  # Takes any kwargs
  if var_kwargs:
    return

  # Check that all argnames exist on function
  invalid_argnames = set(argnames) - valid_kwargs
  if invalid_argnames:
    # TODO: 2022-08-20 or later: replace with
    #   ValueError(f"Jitted function has invalid argnames {invalid_argnames} "
    #              f"in {argnames_name}. Function does not take these args.")
    # Note the trailing space after "args." so the sentences don't run together.
    warnings.warn(f"Jitted function has invalid argnames {invalid_argnames} "
                  f"in {argnames_name}. Function does not take these args. "
                  "This warning will be replaced by an error after 2022-08-20 "
                  "at the earliest.", SyntaxWarning)
def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):
  """Split ``args`` into dynamic and static parts and close ``f`` over the latter.

  Returns ``(wrapped_f, dyn_args)``:
    * ``dyn_args`` holds the entries of ``args`` at ``dyn_argnums``;
    * ``wrapped_f`` re-inserts the remaining (static) arguments when called.
  Static arguments are boxed in ``_HashableWithStrictTypeEquality`` after a
  hashability check. When ``require_static_args_hashable`` is false they are
  boxed in ``Unhashable`` instead.

  Raises:
    ValueError: if an index is out of bounds, or a static argument is not
      hashable while ``require_static_args_hashable`` is true.
  """
  dyn_argnums = _ensure_inbounds(False, len(args),
                                 _ensure_index_tuple(dyn_argnums))
  # (position, value) of every static argument, in positional order.
  static_items = [(i, arg) for i, arg in enumerate(args)
                  if i not in dyn_argnums]
  if require_static_args_hashable:
    fixed_args = []
    for i, arg in static_items:
      if is_hashable(arg):
        fixed_args.append(_HashableWithStrictTypeEquality(arg))
        continue
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (index {i}) of type "
          f"{type(arg)} for function {f.__name__} is non-hashable.")
  else:
    fixed_args = [Unhashable(arg) for _, arg in static_items]
  dyn_args = tuple(args[i] for i in dyn_argnums)
  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
def _ensure_inbounds(allow_invalid: bool, num_args: int, argnums: Sequence[int]
) -> tuple[int, ...]:
"""Ensure argnum is within bounds. Also resolves negative argnums."""
result = []
for i in argnums:
if i >= num_args and allow_invalid: continue
if not -num_args <= i < num_args:
raise ValueError(
"Positional argument indices, e.g. for `static_argnums`, must have "
"value greater than or equal to -len(args) and less than len(args), "
f"but got value {i} for len(args) == {num_args}.")
result.append(i % num_args) # Resolve negative
return tuple(result)
def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
                           args: tuple[Any, ...], *, allow_invalid: bool):
  """Version of ``argnums_partial`` that checks hashability of static_argnums.

  Args:
    f: the function to wrap.
    static_argnums: positions in ``args`` to treat as static. Negative indices
      are resolved; with ``allow_invalid`` indices ``>= len(args)`` are
      ignored instead of raising.
    args: the positional arguments of the call.
    allow_invalid: forwarded to ``_ensure_inbounds``.

  Returns:
    ``(f, args)`` unchanged if ``static_argnums`` is empty. Otherwise
    ``(wrapped_f, dyn_args)``, where ``wrapped_f`` re-inserts the static
    arguments and ``dyn_args`` are the remaining arguments in order.

  Raises:
    ValueError: if a static index is out of bounds or a static argument is
      not hashable.
  """
  if not static_argnums:
    return f, args
  static_argnums = _ensure_inbounds(allow_invalid, len(args), static_argnums)
  dyn_argnums = tuple(i for i in range(len(args)) if i not in static_argnums)
  dyn_args = tuple(args[i] for i in dyn_argnums)

  fixed_args = []
  # `_argnums_partial` fills the non-dynamic slots left to right from
  # `fixed_args`, so it needs exactly one entry per static position, in
  # ascending order. Deduplicate and sort so that e.g. static_argnums=(1, 0)
  # or (1, -1) neither reorders nor duplicates the static arguments.
  for i in sorted(set(static_argnums)):
    # TODO(shoyer): set allow_invalid=True permanently after static_argnames.
    if allow_invalid and i >= len(args):
      continue
    static_arg = args[i]
    if not is_hashable(static_arg):
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (index {i}) of type "
          f"{type(static_arg)} for function {f.__name__} is non-hashable.")
    fixed_args.append(_HashableWithStrictTypeEquality(static_arg))  # type: ignore

  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
  """Re-interleave static and dynamic positional arguments before the call.

  Each ``dyn_args[k]`` is placed at position ``dyn_argnums[k]``. The remaining
  slots are filled, left to right, with the ``.val`` of each box in
  ``fixed_args``. Keyword arguments pass through unchanged.
  """
  # Unique placeholder marking slots not filled by a dynamic argument.
  sentinel = object()
  args = [sentinel] * (len(fixed_args) + len(dyn_args))
  for i, arg in zip(dyn_argnums, dyn_args):
    args[i] = arg
  fixed_args_ = iter(fixed_args)
  # Fixed args are consumed in ascending slot order, so `fixed_args` must be
  # ordered by position.
  args = [next(fixed_args_).val if x is sentinel else x for x in args]
  # Every fixed arg must have been used exactly once.
  assert next(fixed_args_, sentinel) is sentinel
  ans = yield args, kwargs
  yield ans
def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
                            kwargs: dict[str, Any]):
  """Close ``f`` over the keyword arguments named in ``static_argnames``.

  Returns ``(f, kwargs)`` unchanged when ``static_argnames`` is empty.
  Otherwise returns ``(wrapped_f, dyn_kwargs)``:
    * ``dyn_kwargs`` holds the non-static keyword arguments;
    * ``wrapped_f`` re-inserts the static ones, each boxed in ``Hashable``.

  Raises:
    ValueError: if a static keyword argument is not hashable.
  """
  if not static_argnames:
    return f, kwargs
  dyn_kwargs = {k: v for k, v in kwargs.items() if k not in static_argnames}

  fixed_kwargs: dict[str, Any] = {}
  for k, arg in kwargs.items():
    if k in dyn_kwargs:
      continue
    # Same hashability check as argnums_partial/argnums_partial_except; the
    # error is raised outside of an `except` block, so it is not implicitly
    # chained to the TypeError from `hash`.
    if not is_hashable(arg):
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (name {k}) of type "
          f"{type(arg)} for function {f.__name__} is non-hashable.")
    fixed_kwargs[k] = Hashable(arg)  # type: ignore
  return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
@lu.transformation
def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
  """Merge the boxed static kwargs back in with the dynamic kwargs.

  ``fixed_kwargs.val`` maps names to boxes whose ``.val`` is the static value.
  On a name collision the dynamic value wins (``dict(mapping, **kw)``
  semantics). Positional arguments pass through unchanged.
  """
  kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs)
  ans = yield args, kwargs
  yield ans
def donation_vector(donate_argnums, donate_argnames, args, kwargs) -> tuple[bool, ...]:
  """Return one donation flag per pytree leaf of ``args`` followed by ``kwargs``.

  Every leaf of positional argument ``i`` is flagged when
  ``i in donate_argnums``. Every leaf of keyword argument ``name`` is flagged
  when ``name in donate_argnames``.

  Users may specify ``donate_argnums`` but call with kwargs (or vice versa).
  ``resolve_argnums`` uses the function signature to compute the counterpart
  (``donate_argnames`` or ``donate_argnums``), so both are available here.
  That lets JAX donate kwargs when only ``donate_argnums`` was given, and
  vice versa. When both are specified, only the listed args and kwargs are
  donated.
  """
  flags: list[bool] = []
  for idx, arg in enumerate(args):
    flags += [idx in donate_argnums] * tree_structure(arg).num_leaves
  for name, val in kwargs.items():
    flags += [name in donate_argnames] * tree_structure(val).num_leaves
  return tuple(flags)
def rebase_donate_argnums(donate_argnums, static_argnums) -> tuple[int, ...]:
  """Shifts donate to account for static.

  Each donated index is reduced by the number of static indices below it,
  giving its position among the dynamic (non-static) arguments.

  >>> rebase_donate_argnums((3, 4), (0, 1))
  (1, 2)

  Args:
    donate_argnums: An iterable of ints.
    static_argnums: An iterable of ints.

  Returns:
    A tuple of unique, sorted integer values based on donate_argnums with each
    element offset to account for static_argnums.

  Raises:
    ValueError: if an index is both static and donated.
  """
  if not static_argnums and not donate_argnums:
    return tuple(sorted(donate_argnums))
  static_set = set(static_argnums)
  statics = sorted(static_set)
  donates = sorted(set(donate_argnums))
  rebased = []
  for d in donates:
    if d in static_set:
      raise ValueError(f"`static_argnums` {statics} and "
                       f"`donate_argnums` {donates} cannot intersect.")
    rebased.append(d - sum(1 for s in statics if s < d))
  return tuple(rebased)
def is_hashable(arg):
  """Return True if ``hash(arg)`` succeeds, False if it raises TypeError."""
  try:
    hash(arg)
  except TypeError:
    return False
  return True
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
  """Broadcast an axis-spec prefix tree to ``treedef`` and flatten it.

  Returns a list with one entry per leaf of ``treedef`` (checked by the final
  assert). ``None`` entries of ``axis_tree`` come back as ``None``.

  Raises:
    ValueError: if ``axis_tree`` is not a tree prefix of ``treedef``. The
      message is tailored by ``kws`` (report only the positional part) and
      ``tupled_args`` (add hints about wrapping in a tuple).
  """
  # given an axis spec tree axis_tree (a pytree with integers and Nones at the
  # leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of
  # the given treedef, build a complete axis spec tree with the same structure
  # and return the flattened result
  # TODO(mattjj,phawkins): improve this implementation
  # `proxy` stands in for None leaves (presumably _replace_nones substitutes it
  # so that Nones act as leaves in tree_map); it is mapped back to None below.
  proxy = object()
  # A tree with treedef's structure whose leaves are placeholder objects.
  dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
  axes = []
  # For each axis-spec leaf i, repeat it once per leaf of the matching subtree.
  add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
  try:
    tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
  except ValueError:
    if kws:
      # if keyword arguments are included in the tree, we adapt the error
      # message to be only about the positional arguments
      treedef, _ = treedef_children(treedef)
      axis_tree, _ = axis_tree
    hint = ""
    if tupled_args:
      hint += (f" Note that {name} that are non-trivial pytrees should always be "
               f"wrapped in a tuple representing the argument list.")
      if len(treedef.children()) == 1:
        # If wrapping the spec in a 1-tuple makes it match, suggest that.
        try:
          flatten_axes(name, treedef, (axis_tree,))
        except ValueError:
          pass  # That's not the issue.
        else:
          hint += (f" In particular, you're passing in a single argument which "
                   f"means that {name} might need to be wrapped in "
                   f"a singleton tuple.")
    raise ValueError(f"{name} specification must be a tree prefix of the "
                     f"corresponding value, got specification {axis_tree} "
                     f"for value tree {treedef}.{hint}") from None
  axes = [None if a is proxy else a for a in axes]
  assert len(axes) == treedef.num_leaves
  return axes
def flat_out_axes(
    f: lu.WrappedFun, out_spec: Any
) -> tuple[lu.WrappedFun, Callable]:
  """Wrap ``f`` so that ``out_spec`` is broadcast against its output.

  ``out_spec`` is flattened into leaves and a treedef, which parameterize
  ``_flat_out_axes``. Returns the wrapped function together with its
  auxiliary-output thunk (the flattened per-leaf spec). The thunk is wrapped
  in a ``HashableFunction`` keyed on the spec's leaves and treedef.
  """
  spec_leaves, spec_tree = tree_flatten(out_spec)
  spec_leaves = tuple(spec_leaves)
  wrapped, out_axes_thunk = _flat_out_axes(f, spec_leaves, spec_tree)
  return wrapped, HashableFunction(out_axes_thunk,
                                   closure=(spec_leaves, spec_tree))
@lu.transformation_with_aux
def _flat_out_axes(leaves, treedef, *args, **kwargs):
  """Run the function, then broadcast the out-axes spec against its output.

  The spec is rebuilt from ``leaves``/``treedef`` and broadcast as a pytree
  prefix of the output, with ``None`` treated as a leaf. The output passes
  through unchanged and the flattened spec is the auxiliary value.

  Raises:
    ValueError: (pmap-specific message) if the spec is not a pytree prefix of
      the output.
  """
  ans = yield args, kwargs
  spec = tree_unflatten(treedef, leaves)
  try:
    spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None))
  except ValueError:
    # Use the first prefix error to build a detailed message.
    e, *_ = prefix_errors(spec, ans)
    # TODO(mattjj): currently hardcoded for pmap; generalize to vmap in followup
    msg, = e('pmap out_axes').args
    msg += ("\n\nThe full pytree is the output of the pmapped function. Ensure "
            "that the `out_axes` argument to `pmap` is a pytree prefix of the "
            "pmapped function's output.")
    raise ValueError(msg) from None
  yield ans, spec_flat
def check_callable(fun):
  """Raise ``TypeError`` unless ``fun`` is a usable, non-generator callable.

  ``staticmethod`` objects are rejected explicitly, as are non-callables and
  generator functions.
  """
  # In Python 3.10+, the only thing stopping us from supporting staticmethods
  # is that we can't take weak references to them, which the C++ JIT requires.
  if isinstance(fun, staticmethod):
    problem = f"staticmethod arguments are not supported, got {fun}"
  elif not callable(fun):
    problem = f"Expected a callable value, got {fun}"
  elif inspect.isgeneratorfunction(fun):
    problem = f"Expected a function, got a generator function: {fun}"
  else:
    return
  raise TypeError(problem)
_POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD

def infer_argnums_and_argnames(
    sig: inspect.Signature,
    argnums: Union[int, Iterable[int], None],
    argnames: Union[str, Iterable[str], None],
) -> tuple[tuple[int, ...], tuple[str, ...]]:
  """Infer missing argnums and argnames for a function with inspect.

  Cases:
    * both ``None``: returns ``((), ())``;
    * both given: normalizes each to a tuple, with no cross-checking;
    * only one given: the other is derived from ``sig``, matching only
      positional-or-keyword parameters by index/name.
  """
  if argnums is None and argnames is None:
    return (), ()
  if argnums is not None and argnames is not None:
    return _ensure_index_tuple(argnums), _ensure_str_tuple(argnames)

  # (index in the full signature, name) of each positional-or-keyword param.
  pos_or_kw = [(i, name)
               for i, (name, param) in enumerate(sig.parameters.items())
               if param.kind == _POSITIONAL_OR_KEYWORD]
  if argnums is None:
    names = _ensure_str_tuple(argnames)
    return tuple(i for i, name in pos_or_kw if name in names), names
  nums = _ensure_index_tuple(argnums)
  return nums, tuple(name for i, name in pos_or_kw if i in nums)
def resolve_argnums(
    fun, donate_argnums, donate_argnames, static_argnums, static_argnames
) -> tuple[tuple[int, ...], tuple[str, ...], tuple[int, ...], tuple[str, ...]]:
  """Normalize and cross-fill the static/donate argnums and argnames of ``fun``.

  If ``inspect.signature(fun)`` succeeds, missing argnums are inferred from
  argnames and vice versa (``infer_argnums_and_argnames``), and each set is
  validated against the signature. If it fails with ``ValueError`` (some
  built-ins), the values are only normalized to tuples.

  Returns:
    ``(donate_argnums, donate_argnames, static_argnums, static_argnames)``.
    ``donate_argnums`` is rebased to count only non-static positions.

  Raises:
    ValueError: if the signature is unavailable but ``donate_argnames`` was
      given, or if static and donated arguments overlap.
  """
  try:
    sig = inspect.signature(fun)
  except ValueError as e:
    # Some built-in functions don't support signature.
    # See: https://github.com/python/cpython/issues/73485
    # In this case no validation is done
    static_argnums = (
        () if static_argnums is None else _ensure_index_tuple(static_argnums))
    static_argnames = (
        () if static_argnames is None else _ensure_str_tuple(static_argnames))
    donate_argnums = (
        () if donate_argnums is None else _ensure_index_tuple(donate_argnums))
    if donate_argnames is not None:
      raise ValueError(f"Getting the signature of function {fun} failed. "
                       "Pass donate_argnums instead of donate_argnames.") from e
    donate_argnames = ()
  else:
    # Fill in whichever of nums/names is missing from the other.
    static_argnums, static_argnames = infer_argnums_and_argnames(
        sig, static_argnums, static_argnames)
    donate_argnums, donate_argnames = infer_argnums_and_argnames(
        sig, donate_argnums, donate_argnames)
    # Validate static first, then donate; argnums before argnames in each.
    for kind, nums, names in (("static", static_argnums, static_argnames),
                              ("donate", donate_argnums, donate_argnames)):
      validate_argnums(sig, nums, f"{kind}_argnums")
      validate_argnames(sig, names, f"{kind}_argnames")

  assert_no_intersection(static_argnames, donate_argnames)
  # Compensate for static argnums absorbing args.
  donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)
  return donate_argnums, donate_argnames, static_argnums, static_argnames
def assert_no_intersection(static_argnames, donate_argnames):
  """Raise ``ValueError`` if any argument name is both static and donated."""
  overlap = set(static_argnames) & set(donate_argnames)
  if not overlap:
    return
  raise ValueError(
      "static_argnames and donate_argnames cannot intersect. Argument names "
      f"{overlap} appear in both static_argnames and donate_argnames")
def _dtype(x):
  """Return ``dtypes.result_type(x)``, retrying with ``x.dtype`` on ValueError."""
  try:
    return dtypes.result_type(x)
  except ValueError:
    fallback = x.dtype
    return dtypes.result_type(fallback)
def _shaped_abstractify_slow(x):
  """Abstractify ``x`` when no type-specific handler is registered.

  First tries ``core.raise_to_shaped`` on ``x`` itself (if it is already an
  ``AbstractValue``) or on ``core.get_aval(x)``. If that raises
  ``TypeError``, it falls back to building a ``ShapedArray`` from:
    * ``np.shape(x)``;
    * the canonicalized ``x.dtype``;
    * the optional ``weak_type`` and ``named_shape`` attributes.

  Raises:
    TypeError: if the fallback is needed and ``x`` has no ``dtype``.
  """
  try:
    aval = x if isinstance(x, core.AbstractValue) else core.get_aval(x)
    return core.raise_to_shaped(aval)
  except TypeError:
    pass  # Not directly abstractifiable; fall back to duck typing below.

  weak_type = getattr(x, 'weak_type', False)
  named_shape = getattr(x, 'named_shape', {})
  if not hasattr(x, 'dtype'):
    raise TypeError(
        f"Cannot interpret value of type {type(x)} as an abstract array; it "
        "does not have a dtype attribute")
  dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
  return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
                          named_shape=named_shape)
# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
def shaped_abstractify(x):
  """Return the ``ShapedArray`` abstraction of ``x``.

  Dispatches on the exact ``type(x)`` via ``_shaped_abstractify_handlers``.
  Types without a registered handler (subclasses included) go through
  ``_shaped_abstractify_slow``.
  """
  handler = _shaped_abstractify_handlers.get(type(x))
  if handler is None:
    return _shaped_abstractify_slow(x)
  # The handler is called outside any try/except, so a KeyError raised
  # *inside* a handler propagates instead of being mistaken for "no handler
  # registered" and silently rerouted to the slow path.
  return handler(x)

# Maps an exact Python type to a function computing its ShapedArray.
_shaped_abstractify_handlers: dict[Any, Callable[[Any], core.ShapedArray]] = {}
def _str_abstractify(x):
  """Reject Python strings: always raises ``TypeError``."""
  message = f"Argument '{x}' of type {type(x)} is not a valid JAX type"
  raise TypeError(message)

_shaped_abstractify_handlers[str] = _str_abstractify
def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
  """Abstractify a NumPy array from its shape and its validated, canonical dtype."""
  dt = x.dtype
  dtypes.check_valid_dtype(dt)
  canonical = dtypes.canonicalize_dtype(dt, allow_extended_dtype=True)
  return ShapedArray(x.shape, canonical)

_shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
  """Abstractify a NumPy scalar from ``np.shape(x)`` and its canonical dtype."""
  dt = np.dtype(x)
  dtypes.check_valid_dtype(dt)
  canonical = dtypes.canonicalize_dtype(dt, allow_extended_dtype=True)
  return ShapedArray(np.shape(x), canonical)

# Register the same handler for every NumPy scalar type.
_shaped_abstractify_handlers.update(
    dict.fromkeys(numpy_scalar_types, _np_scalar_abstractify))
# This decorator exists to make it easier to monkey-patch APIs in JAX.
# By default it does nothing, but it can be monkey-patched to do other things.
def api_hook(fun, tag: str):
  """Identity decorator: return ``fun`` unchanged; ``tag`` is ignored by default."""
  del tag  # Unused by the default implementation.
  return fun
def debug_info(traced_for: str, fun: Callable, args: tuple[Any],
               kwargs: dict[str, Any], static_argnums: tuple[int, ...],
               static_argnames: tuple[str, ...]) -> Optional[TracingDebugInfo]:
  """Try to build trace-time debug info for ``fun`` applied to args/kwargs.

  Returns ``None`` if either the source location or the argument names cannot
  be determined. Result paths are left unset (``None``).
  """
  src = fun_sourceinfo(fun)
  arg_names = _arg_names(fun, args, kwargs, static_argnums, static_argnames)
  if src is not None and arg_names is not None:
    return TracingDebugInfo(traced_for, src, arg_names, None)
  return None
# TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> Optional[str]:
  """Describe where ``fun`` is defined as ``"<name> at <file>:<line>"``.

  Any ``functools.partial`` layers are peeled off first, then ``__wrapped__``
  chains are followed with ``inspect.unwrap``. Returns ``None`` if the
  underlying object lacks ``__code__`` (e.g. builtins) or ``__name__``.
  """
  target = fun
  while isinstance(target, partial):
    target = target.func
  target = inspect.unwrap(target)
  try:
    code = target.__code__
    return f"{target.__name__} at {code.co_filename}:{code.co_firstlineno}"
  except AttributeError:
    return None
def _arg_names(fn, args, kwargs, static_argnums, static_argnames,
               ) -> Optional[tuple[str, ...]]:
  """Return path-qualified names of the non-static argument leaves of a call.

  Static positional/keyword arguments are replaced by a private marker before
  binding to ``inspect.signature(fn)``, and their leaves are skipped. Each
  remaining leaf is named ``<param><keystr(path)>``. Returns ``None`` if the
  signature is unavailable or the arguments do not bind.
  """
  static = object()
  static_idx = _ensure_inbounds(True, len(args), static_argnums)
  static_names = set(static_argnames)
  masked_args = [static if i in static_idx else x for i, x in enumerate(args)]
  masked_kwargs = {k: static if k in static_names else x
                   for k, x in kwargs.items()}
  try:
    bound = inspect.signature(fn).bind(*masked_args, **masked_kwargs)
  except (ValueError, TypeError):
    return None
  names = []
  for name, value in bound.arguments.items():
    for path, leaf in generate_key_paths(value):
      if leaf is not static:
        names.append(f'{name}{keystr(path)}')
  return tuple(names)
@lu.transformation_with_aux
def result_paths(*args, **kwargs):
  "linear_util transform to get output pytree paths of pre-flattened function."
  # The output passes through unchanged; the auxiliary value is the list of
  # keystr-formatted paths of every output leaf, in generate_key_paths order.
  ans = yield args, kwargs
  yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
                     result_paths: Optional[tuple[Optional[str], ...]] = None,
                     ) -> core.Jaxpr:
  """Add debug info to ``jaxpr``, given trace-time debug info and result paths.

  Returns ``jaxpr`` unchanged when ``trace_debug`` is None. Exactly one of the
  explicit ``result_paths`` argument and the ``trace_debug.result_paths`` thunk
  must be provided; the thunk is called when the argument is absent.
  """
  if trace_debug is None:
    return jaxpr
  assert (result_paths is not None) ^ (trace_debug.result_paths is not None)
  if result_paths is not None:
    paths = result_paths
  else:
    paths = trace_debug.result_paths()  # type: ignore
  info = core.JaxprDebugInfo(
      trace_debug.traced_for, trace_debug.func_src_info,
      trace_debug.arg_names, tuple(paths))
  return jaxpr.replace(debug_info=info)
def debug_info_final(f: lu.WrappedFun, dbg: Optional[TracingDebugInfo],
                     res_paths: Callable[[], tuple[str, ...]]) -> lu.WrappedFun:
  """Attach trace-time debug info and a lazy result-paths thunk to ``f``.

  Returns ``f`` unchanged when ``dbg`` is None. ``dbg`` must not already carry
  result paths; ``res_paths`` is wrapped in a ``HashableFunction`` with an
  empty closure before being attached.
  """
  if dbg is not None:
    assert dbg.result_paths is None
    thunk = HashableFunction(res_paths, closure=())
    return lu.add_debug_info(f, dbg._replace(result_paths=thunk))
  return f