-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
checkify.py
1216 lines (1022 loc) · 47.2 KB
/
checkify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import functools
import itertools as it
import types
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List
import jax
from jax import core
from jax import lax
from jax import linear_util as lu
from jax._src import prng
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
from jax._src.lax import control_flow as cf
from jax._src.sharding import OpShardingSharding
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
safe_zip)
from jax.api_util import flatten_fun
from jax.api_util import flatten_fun_nokwargs
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten
from jax.tree_util import tree_map
from jax.tree_util import tree_unflatten
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
# Exclude this module's own frames from source-info summaries and filtered
# tracebacks, so reported locations point at user code.
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
# Rebind `map`/`zip` to the `safe_*` variants from jax._src.util. The right-hand
# side is evaluated before rebinding, so `unsafe_map`/`unsafe_zip` keep the
# Python builtins.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# Type aliases used in annotations throughout this module.
Bool = Union[bool, Array]
Int = Union[int, Array]
# An error *class* (a JaxException subclass), not an instance.
ErrorCategory = Type['JaxException']
# The array leaves recorded for one error effect.
Payload = List[Union[np.ndarray, Array]]
PyTreeDef = jtu.PyTreeDef
## Utils
def popattr(obj, attrname):
  """Deletes ``obj.<attrname>`` and hands back the value it held.

  Raises AttributeError (from ``getattr``) when the attribute is not set.
  """
  value_before_delete = getattr(obj, attrname)
  delattr(obj, attrname)
  return value_before_delete
def setnewattr(obj, name, val):
  """Sets ``obj.<name> = val``, asserting the attribute was not already set."""
  missing = object()
  already_there = getattr(obj, name, missing)
  assert already_there is missing
  setattr(obj, name, val)
# Concrete errors
class JaxException(Exception):
  """Python exception which can contain an error message with JAX run-time info.

  Every subclass is registered as a pytree node class when it is defined. By
  default an instance flattens to no leaves, with `traceback_info` as its
  static metadata.
  """

  def __init__(self, traceback_info):
    self.traceback_info = traceback_info
    # TODO(lenamartens): re-enable tracebacks when they don't leak tracers.
    # self.with_traceback(self.traceback_info)

  def __init_subclass__(cls):
    # Make each concrete error type flattenable into (payload, treedef).
    jtu.register_pytree_node_class(cls)

  def tree_flatten(self):
    no_leaves = []
    return (no_leaves, self.traceback_info)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    del payload  # The base error carries no array payload.
    traceback_info = metadata
    return cls(traceback_info)

  def get_effect_type(self) -> core.Effect:
    # Subclasses override this; the base implementation returns None.
    return None
@functools.total_ordering
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect:
  """Effect describing one error type and the shapes/dtypes of its payload."""
  error_type: Type[JaxException]
  shape_dtypes: Tuple[jax.ShapeDtypeStruct, ...]

  def __post_init__(self):
    # Register every instance with the control-flow `allowed_effects` set and
    # the MLIR `lowerable_effects` set (in that order).
    for registry in (cf.allowed_effects, mlir.lowerable_effects):
      registry.add(self)

  def __lt__(self, other: 'ErrorEffect'):
    def sort_key(effect):
      # dtypes themselves are not orderable, so compare their string forms.
      avals = tuple((sd.shape, str(sd.dtype)) for sd in effect.shape_dtypes)
      return (str(effect.error_type), avals)
    return sort_key(self) < sort_key(other)
class DivisionByZeroError(JaxException):
  """Error recorded when a checked division sees a zero divisor."""

  def get_effect_type(self):
    # This error carries no array payload.
    return ErrorEffect(DivisionByZeroError, ())

  def __str__(self):
    location = self.traceback_info
    return f'division by zero at {location}'
class NaNError(JaxException):
  """Error recorded when a primitive's output contains a NaN.

  Both the traceback info and the primitive name are static metadata; there is
  no array payload.
  """

  def __init__(self, traceback_info, primitive_name):
    super().__init__(traceback_info)
    self.prim = primitive_name

  def tree_flatten(self):
    static = (self.traceback_info, self.prim)
    return ([], static)

  @classmethod
  def tree_unflatten(cls, metadata, _):
    traceback_info, primitive_name = metadata
    return cls(traceback_info, primitive_name)

  def get_effect_type(self):
    return ErrorEffect(NaNError, ())

  def __str__(self):
    return f'nan generated by primitive: {self.prim} at {self.traceback_info}'
class OOBError(JaxException):
  """Out-of-bounds indexing error.

  The payload is a length-3 int32 array laid out as
  [offending index, operand axis, size of that axis].
  """

  def __init__(self, traceback_info, primitive_name, operand_shape, payload):
    super().__init__(traceback_info)
    self.prim = primitive_name
    self.operand_shape = operand_shape
    self._payload = payload

  def tree_flatten(self):
    static = (self.traceback_info, self.prim, self.operand_shape)
    return ([self._payload], static)

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    payload_array = payload[0]
    return cls(*metadata, payload_array)

  def __str__(self):
    bad_index = self._payload[0]
    axis = self._payload[1]
    axis_size = self._payload[2]
    return (f'out-of-bounds indexing for array of '
            f'shape {self.operand_shape}: '
            f'index {bad_index} is out of bounds for axis '
            f'{axis} with size {axis_size}. '
            f'Failed at {self.traceback_info}')

  def get_effect_type(self):
    payload_aval = jax.ShapeDtypeStruct((3,), jnp.int32)
    return ErrorEffect(OOBError, (payload_aval,))
class FailedCheckError(JaxException):
  """Error recorded by a failed user `check`.

  The formatting arguments are the array payload. They are substituted into
  `fmt_string` only when the message is rendered.
  """

  def __init__(self, traceback_info, fmt_string, *a, **k):
    super().__init__(traceback_info)
    self.fmt_string = fmt_string
    # NOTE: this deliberately overwrites BaseException.args with the
    # positional formatting arguments.
    self.args = a
    self.kwargs = k

  def tree_flatten(self):
    # Leading empty int32 leaf, mirrored by the (0,) entry in get_effect_type.
    placeholder = jnp.array([], jnp.int32)
    children = (placeholder, self.args, self.kwargs)
    aux = (self.traceback_info, self.fmt_string)
    return children, aux

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    _placeholder, fmt_args, fmt_kwargs = payload
    return cls(*metadata, *fmt_args, **fmt_kwargs)

  def __str__(self):
    message = self.fmt_string.format(*self.args, **self.kwargs)
    return message + f' (check failed at {self.traceback_info})'

  def get_effect_type(self):
    leaves = jtu.tree_leaves((self.args, self.kwargs))
    leaf_avals = tuple(jax.ShapeDtypeStruct(v.shape, v.dtype) for v in leaves)
    # Need a 0-size array here for data-dependence.
    return ErrorEffect(
        FailedCheckError,
        (jax.ShapeDtypeStruct((0,), jnp.int32), *leaf_avals))
@dataclasses.dataclass
class BatchedError(JaxException):
  """Errors found at individual mapped indices of a batched Error."""
  error_mapping: Dict[Tuple[int, ...], JaxException]

  def __post_init__(self):
    # Take the traceback of the first recorded error (mapping must be
    # non-empty; an empty one raises IndexError here).
    first_error = [*self.error_mapping.values()][0]
    super().__init__(first_error.traceback_info)

  def __str__(self):
    lines = []
    for idx, err in self.error_mapping.items():
      where = ", ".join(str(i) for i in idx)
      lines.append(f'at mapped index {where}: {err}')
    return '\n'.join(lines)
# Error Value
@jtu.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Error:
  """Functional error value threaded through a checkified computation.

  `_pred`, `_code` and `_payload` are keyed by ErrorEffect (one entry per
  error type / payload signature). `_metadata` maps an error code to the
  treedef that rebuilds the matching JaxException from its payload leaves.
  Non-scalar predicates mean the Error is batched.
  """
  _pred: Dict[ErrorEffect, Bool]  # whether an error of each effect occurred.
  _code: Dict[ErrorEffect, Int]  # code of the recorded error per effect.
  _metadata: Dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
  _payload: Dict[ErrorEffect, Payload]  # payload leaves per effect.

  def get(self) -> Optional[str]:
    """Returns error message if error happened, None if no error happened."""
    exp = self.get_exception()
    if exp is not None:
      return str(exp)
    return None

  def get_exception(self) -> Optional[JaxException]:
    """Returns Python exception if error happened, None if no error happened.

    When several effects report an error, the one with the smallest code is
    rebuilt. For a batched Error (non-scalar predicates), a BatchedError with
    one entry per failing mapped index is returned, or None if no index failed.
    """
    if any(map(np.shape, self._pred.values())):
      return self._get_batched_exception()
    min_code = None
    cur_effect = None
    for error_effect, code in self._code.items():
      if self._pred[error_effect]:
        if min_code is None or code < min_code:
          min_code = code
          cur_effect = error_effect
    if cur_effect is not None:
      return tree_unflatten(self._metadata[int(min_code)],  # type: ignore
                            self._payload[cur_effect])
    return None

  def throw(self):
    """Raises (via `check_error`) if this Error holds a failure."""
    check_error(self)

  def __str__(self):
    return f'Error({self.get()})'

  # Internal helpers

  def _get_batched_exception(self) -> Optional[BatchedError]:
    """Per-mapped-index version of `get_exception` for batched predicates.

    Returns None when no mapped index recorded an error.
    """
    shape = np.shape(list(self._pred.values())[0])
    error_mapping = {}
    for idx in np.ndindex(*shape):
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect][idx]:  # type: ignore
          if min_code is None or code[idx] < min_code:
            min_code = code[idx]  # type: ignore
            cur_effect = error_effect
      if cur_effect is not None:
        payload = tree_map(lambda x, i=idx: x[i], self._payload[cur_effect])
        jax_error = tree_unflatten(self._metadata[int(min_code)], payload)  # type: ignore
        error_mapping[idx] = jax_error
    if not error_mapping:
      # BatchedError takes its traceback from its first entry and so cannot be
      # built from an empty mapping; no failing index means no exception.
      return None
    return BatchedError(error_mapping)

  def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload):
    """Returns a new Error with the entry for `effect_type` replaced."""
    new_errs = {**self._pred, **{effect_type: pred}}  # type: ignore
    new_codes = {**self._code, **{effect_type: code}}  # type: ignore
    new_payload = {**self._payload, **{effect_type: payload}}  # type: ignore
    new_metadata = {**self._metadata, **metadata}
    return Error(new_errs, new_codes, new_metadata, new_payload)

  def _add_placeholder_effects(self, effects: Set[ErrorEffect]):
    """Fill out Error with `effects` and jnp.ones arrays of their payloads.

    Effects already present are left untouched; missing ones get a False
    predicate, code -1 and all-ones payload arrays of the declared avals.
    """
    new_err = self._pred.copy()
    new_code = self._code.copy()
    new_payload = self._payload.copy()
    for effect in effects:
      if effect not in self._pred.keys():
        new_err[effect] = False
        new_payload[effect] = list(
            tree_map(lambda a: jnp.ones(a.shape, a.dtype), effect.shape_dtypes))
        # The error value associated with this effect will never become True, so
        # we don't need to set a meaningful code.
        new_code[effect] = -1
    return Error(new_err, new_code, self._metadata, new_payload)

  def _replace(self, *args, **kwargs):
    return dataclasses.replace(self, *args, **kwargs)

  # PyTree methods

  def tree_flatten(self):
    # Codes/treedefs in `_metadata` are static; preds, codes and payloads are
    # the (possibly traced) leaves.
    return ((self._pred, self._code, self._payload), self._metadata)

  @classmethod
  def tree_unflatten(cls, metadata, data):
    pred, code, payload = data
    return cls(pred, code, metadata, payload)
# Starting point for every checkified computation: no effects, codes,
# metadata or payloads recorded.
init_error = Error({}, {}, {}, {})  # value used as initial (empty) error.
# Yields 1, 2, 3, ... on successive calls. assert_func draws one code per
# recorded error and uses it as the key into Error._metadata.
next_code = it.count(1).__next__  # globally unique ids, could be uuid4
def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error:
  """Returns `error` extended with `new_error`, recorded when `pred` is True.

  A fresh code is drawn for `new_error`. The exception's payload leaves and
  treedef are merged via `update_error` under its effect type. `error` itself
  is not modified.
  """
  fresh_code = next_code()
  effect = new_error.get_effect_type()
  payload_leaves, exc_treedef = tree_flatten(new_error)
  return update_error(error, pred, fresh_code, {fresh_code: exc_treedef},
                      payload_leaves, effect)
def update_error(error, pred, code, metadata, payload, effect_type):
  """Merges a new (pred, code, payload) for `effect_type` into `error`.

  If an error of this effect was already recorded, its code and payload win
  (selected with `lax.select`). The predicate becomes the OR of old and new.
  """
  already_failed = error._pred.get(effect_type, False)
  merged_pred = already_failed | pred
  merged_code = lax.select(already_failed,
                           error._code.get(effect_type, -1), code)
  previous_payload = error._payload.get(effect_type, None)
  if previous_payload is None:
    merged_payload = payload
  else:
    keep_old_if_failed = functools.partial(lax.select, already_failed)
    merged_payload = tree_map(keep_old_if_failed, previous_payload, payload)
  return error._update(effect_type, merged_pred, merged_code, metadata,
                       merged_payload)
## Checkify transformation for plumbing functional error values.
class CheckifyTracer(core.Tracer):
  """Tracer for CheckifyTrace; it simply wraps an underlying value `val`."""

  def __init__(self, trace, val):
    self._trace = trace
    self.val = val

  @property
  def aval(self):
    return core.get_aval(self.val)

  def full_lower(self):
    return self
class CheckifyTrace(core.Trace):
  """Trace that threads an Error value alongside the traced computation.

  The current Error lives on `self.main.error` and the enabled categories on
  `self.main.enabled_errors`. Methods that hand control to a sub-computation
  pop the Error off `main` and pass its leaves as extra leading
  inputs/outputs. The resulting Error is reinstalled afterwards.
  """
  # Lifting a plain value just wraps it in a CheckifyTracer.
  pure = lift = lambda self, val: CheckifyTracer(self, val)

  def __init__(self, main: core.MainTrace, sublevel: core.Sublevel,
               enabled_errors: FrozenSet['ErrorCategory']) -> None:
    self.main = main
    self.level = main.level
    self.sublevel = sublevel
    # Kept on the shared `main`; the process_* rules read it from there.
    self.main.enabled_errors = enabled_errors

  def sublift(self, tracer):
    return CheckifyTracer(self, tracer.val)

  def process_primitive(self, primitive, tracers, params):
    in_vals = [t.val for t in tracers]
    # Primitives with a registered checkify rule return (outputs, new Error);
    # all others are bound unchanged.
    rule = error_checks.get(primitive)
    if rule:
      out, self.main.error = rule(self.main.error, self.main.enabled_errors,  # type: ignore
                                  *in_vals, **params)
    else:
      out = primitive.bind(*in_vals, **params)
    if primitive.multiple_results:
      return [CheckifyTracer(self, x) for x in out]
    else:
      return CheckifyTracer(self, out)

  def process_call(self, primitive, f, tracers, params):
    in_vals = [t.val for t in tracers]
    # Pop the Error so the callee's checkify_subtrace can install it; its
    # leaves are passed as leading inputs.
    e = popattr(self.main, 'error')
    flat_vals, in_tree = tree_flatten((e, *in_vals))
    f = checkify_subtrace(f, self.main)
    f, out_tree = flatten_fun_nokwargs(f, in_tree)
    if 'donated_invars' in params:
      # Error leaves are never donated.
      params = dict(params, donated_invars=(*[False]*len(jtu.tree_leaves(e)),
                                            *params['donated_invars']))
    all_vals = primitive.bind(f, *flat_vals, **params)
    # The callee returns (error, *outputs); reinstall the updated error.
    error, *out_vals = tree_unflatten(out_tree(), all_vals)
    setnewattr(self.main, 'error', error)
    return [CheckifyTracer(self, x) for x in out_vals]

  def process_map(self, primitive, f, tracers, params):
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    flat_vals, in_tree = tree_flatten((e, *in_vals))
    num_error_vals = len(jtu.tree_leaves(e))
    f = checkify_subtrace(f, self.main)
    f, out_tree = flatten_fun_nokwargs(f, in_tree)

    @as_hashable_function(closure=params['out_axes_thunk'])
    def new_out_axes_thunk():
      # Output leaves beyond the user's out axes are error leaves; they come
      # back mapped along axis 0.
      out_val_axes = params['out_axes_thunk']()
      out_err_num = out_tree().num_leaves - len(out_val_axes)
      return (*(0,)*out_err_num, *out_val_axes)

    # Error leaves are broadcast into the map (in_axes None), never donated.
    params_ = dict(params, in_axes=(*(None,)*num_error_vals, *params['in_axes']),
                   out_axes_thunk=new_out_axes_thunk,
                   donated_invars=(*(False,)*num_error_vals, *params['donated_invars']))
    all_vals = primitive.bind(f, *flat_vals, **params_)
    error, *out_vals = tree_unflatten(out_tree(), all_vals)
    # Collapse the per-instance errors along the mapped axis into one Error.
    error = _reduce_any_error(error)
    setnewattr(self.main, 'error', error)
    return [CheckifyTracer(self, x) for x in out_vals]

  def post_process_call(self, primitive, tracers, params):
    vals = [t.val for t in tracers]
    main = self.main
    e = popattr(main, 'error')
    err_leaves, err_tree = tree_flatten(e)
    # Stash the error treedef on `main` for `todo` to pick up.
    setnewattr(main, 'err_tree', err_tree)
    def todo(vals):
      # Split the leading error leaves back off and reinstall the Error.
      err_tree = popattr(main, 'err_tree')
      err_vals, vals = split_list(vals, [err_tree.num_leaves])
      setnewattr(main, 'error', tree_unflatten(err_tree, err_vals))
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]
    # Error leaves travel in front of the outputs.
    return (*err_leaves, *vals), todo

  def post_process_map(self, primitive, tracers, params):
    vals = [t.val for t in tracers]
    main = self.main
    e = popattr(main, 'error')
    err_leaves, err_tree = tree_flatten(e)
    num_err_leaves = len(err_leaves)
    setnewattr(main, 'err_tree', err_tree)
    def todo(vals):
      err_tree = popattr(main, 'err_tree')
      err_vals, vals = split_list(vals, [err_tree.num_leaves])
      error = tree_unflatten(err_tree, err_vals)
      # Error leaves come back mapped; reduce them to a single Error.
      error = _reduce_any_error(error)
      setnewattr(main, 'error', error)
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]
    def out_axes_transform(out_axes):
      # Error leaves are mapped along axis 0.
      return (*(0,)*num_err_leaves, *out_axes)
    return (*err_leaves, *vals), (todo, out_axes_transform)

  def process_custom_jvp_call(self, prim, f, jvp, tracers):
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    err_vals, err_tree = tree_flatten(e)
    flat_vals, in_tree = tree_flatten((e, *in_vals))
    num_error_vals = len(err_vals)
    f = checkify_subtrace(f, self.main)
    f, f_out_tree = flatten_fun_nokwargs(f, in_tree)
    # The JVP rule is not checked; it only forwards the error leaves.
    jvp, jvp_err_tree = checkify_custom_jvp_subtrace(jvp, self.main,
                                                     num_error_vals, err_tree)
    all_outs = prim.bind(f, jvp, *flat_vals)
    # `fst` (from lu.merge_linear_aux) tells whether the aux came from `f`,
    # in which case the outputs carry a full (error, *outs) tree.
    fst, out_tree = lu.merge_linear_aux(f_out_tree, jvp_err_tree)
    if fst:
      out_err, *out_vals = tree_unflatten(out_tree, all_outs)
    else:
      err_vals, out_vals = split_list(all_outs, [num_error_vals])
      # forward input error values to output
      out_err = tree_unflatten(out_tree, err_vals)
    setattr(self.main, 'error', out_err)
    return [CheckifyTracer(self, x) for x in out_vals]

  def post_process_custom_jvp_call(self, tracers, jvp_was_run):
    if jvp_was_run:
      msg = ('support for custom_jvp rules which close over checkify values is '
             'not implemented. If you see this, open an issue at '
             'https://github.com/google/jax/issues!')
      raise NotImplementedError(msg)
    vals = [t.val for t in tracers]
    main = self.main
    e = popattr(main, 'error')
    err_leaves, err_tree = tree_flatten(e)
    def todo(vals):
      # Split the leading error leaves back off and reinstall the Error.
      err_vals, vals = split_list(vals, [len(err_leaves)])
      setnewattr(main, 'error', tree_unflatten(err_tree, err_vals))
      trace = main.with_cur_sublevel()
      return [CheckifyTracer(trace, x) for x in vals]
    return (*err_leaves, *vals), todo

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    in_vals = [t.val for t in tracers]
    e = popattr(self.main, 'error')
    err_vals, err_tree = tree_flatten(e)
    flat_vals, in_tree = tree_flatten((e, *in_vals))
    num_error_vals = len(err_vals)
    fun = checkify_subtrace(fun, self.main)
    fun, fun_out_tree = flatten_fun_nokwargs(fun, in_tree)
    # The fwd rule is not checked; its input error leaves are dropped.
    fwd, fwd_err_tree = checkify_custom_vjp_subtrace(fwd, self.main,
                                                     err_tree, num_error_vals)
    all_out_vals = prim.bind(fun, fwd, bwd, *flat_vals, out_trees=out_trees)
    fst, out_tree = lu.merge_linear_aux(fun_out_tree, fwd_err_tree)
    if fst:
      error, *out = tree_unflatten(out_tree, all_out_vals)
    else:
      _, out = split_list(all_out_vals, [num_error_vals])
      # forward input error values to output
      error = tree_unflatten(err_tree, err_vals)
    setattr(self.main, 'error', error)
    return [CheckifyTracer(self, x) for x in out]
def _reduce_any_error(error: Error):
  """Collapses a batched Error into an unbatched one.

  For every effect, an index whose predicate is True (if any) is chosen. That
  index's pred, code and payload are kept. Metadata is carried over unchanged.
  NOTE(review): assumes a single mapped axis on the predicates.
  """
  reduced = init_error
  for effect in error._pred.keys():
    preds = error._pred[effect]
    codes = error._code[effect]
    payloads = error._payload[effect]
    # Booleans sort False < True, so the last argsort entry points at a True
    # predicate whenever one exists.
    chosen = jnp.argsort(preds)[-1]
    pred, code, payload = tree_map(lambda leaf: leaf[chosen],
                                   (preds, codes, payloads))
    reduced = reduced._update(effect, pred, code, {}, payload)
  return reduced._replace(_metadata=error._metadata)
# A checkify rule takes the current Error, the enabled error categories and the
# primitive's inputs/params, and returns the outputs plus the updated Error.
ErrorCheckRule = Callable  # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
# Registry of checkify rules keyed by primitive. CheckifyTrace.process_primitive
# binds primitives without an entry unchanged.
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'],
                  *args):
  """Runs `fun` on flat `args` under checkify, starting from `init_error`.

  Returns a pair `(error, outvals)` where `outvals` is a list.
  """
  checked = checkify_traceable(checkify_subtrace(fun), enabled_errors)
  error, *outvals = checked.call_wrapped(init_error, *args)
  return error, outvals
@lu.transformation
def checkify_traceable(enabled_errors, error, *args):
  """Opens a new CheckifyTrace main and runs the wrapped function under it.

  The inner function receives `(main, error, *args)`; its outputs are passed
  through unchanged.
  """
  with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main:
    results = yield (main, error, *args), {}
    del main  # Release our reference to `main` before leaving the context.
  yield results
@lu.transformation
def checkify_subtrace(main, error, *args):
  """Runs the wrapped function with CheckifyTracer inputs at a new sublevel.

  `error` is installed on `main` for the duration and popped off afterwards.
  Yields `(error, *output_values)`.
  """
  setnewattr(main, 'error', error)
  trace = main.with_cur_sublevel()
  in_tracers = [CheckifyTracer(trace, arg) for arg in args]
  outs = yield in_tracers, {}
  out_vals = [trace.full_raise(o).val for o in outs]
  final_error = popattr(main, 'error')
  yield (final_error, *out_vals)
@lu.transformation_with_aux
def checkify_custom_jvp_subtrace(main, num_error_vals, out_tree, *args):
  """Forwards error values unchanged through a custom JVP rule.

  Like checkify_subtrace, but used specifically on the custom JVP rules
  associated with a custom_jvp. This code is called in the context of a
  jvp-of-checkify-of-custom_jvp. It takes both primal and tangent inputs,
  flattened into a single args tuple, and similarly must produce flattened
  primal and tangent outputs. Both primals and tangents include error values,
  but the tangent error values are trivially zero.

  The types to have in mind are:
    jvp : (a -> b) -> (a, T a) -> (b, T b)
    checkify : (a -> b) -> a -> Err b
    jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
  where because Err is a pytree, we necessarily have T (Err b) = Err' (T b)
  where the other Err' components are trivial (of float0 dtype).

  Semantically, we don't add checks to the JVP rule. To check the result of a
  JVP rule, one must instead use checkify-of-jvp. Thus this implementation
  just forwards the input error and code (and trivial tangents) to the output.
  The aux output is `out_tree`, unchanged.
  """
  del main
  num_in, odd = divmod(len(args), 2)
  assert not odd
  primal_args, tangent_args = args[:num_in], args[num_in:]
  err_primals, primals = split_list(primal_args, [num_error_vals])
  err_tangents, tangents = split_list(tangent_args, [num_error_vals])
  outs = yield (*primals, *tangents), {}
  num_out, odd = divmod(len(outs), 2)
  assert not odd
  out_primals, out_tangents = outs[:num_out], outs[num_out:]
  yield (*err_primals, *out_primals, *err_tangents, *out_tangents), out_tree
@lu.transformation_with_aux
def checkify_custom_vjp_subtrace(main, err_tree, num_error_vals, *args):
  """Runs a custom VJP fwd rule without adding checks.

  The leading `num_error_vals` inputs (error leaves) are dropped. Outputs pass
  through unchanged, and the aux output is `err_tree`.
  """
  del main
  _dropped_error_vals, user_args = split_list(args, [num_error_vals])
  outs = yield user_args, {}
  yield outs, err_tree
@lu.transformation_with_aux
def query_error_effects(*args):
  """Passes `(error, *outs)` through; aux is the set of the error's effects."""
  error, *rest = yield args, {}
  effects = set(error._pred)
  yield (error, *rest), effects
def checkify_jaxpr(jaxpr, error,
                   enabled_errors) -> Tuple[core.ClosedJaxpr,
                                            Tuple[PyTreeDef,
                                                  FrozenSet[ErrorEffect]]]:
  """Checkifies a closed jaxpr by tracing it as a function of its in_avals."""
  as_fun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  return checkify_fun_to_jaxpr(as_fun, error, enabled_errors, jaxpr.in_avals)
def checkify_fun_to_jaxpr(
    f, error, enabled_errors,
    in_avals) -> Tuple[core.ClosedJaxpr, Tuple[PyTreeDef, FrozenSet[ErrorEffect]]]:
  """Traces checkified `f` to a ClosedJaxpr.

  Args:
    f: wrapped function taking values of types `in_avals`.
    error: input Error; its leaves become the leading jaxpr inputs.
    enabled_errors: error categories to check.
    in_avals: abstract values of `f`'s (non-error) inputs.

  Returns:
    `(closed_jaxpr, (out_tree, error_effects))`. `out_tree` is the treedef of
    the `(error, *outs)` output, and `error_effects` is the set of
    ErrorEffects present in the output Error.
  """
  flat_error_vals = jtu.tree_leaves(error)
  f = checkify_subtrace(f)
  f = checkify_traceable(f, enabled_errors)
  f, error_effect = query_error_effects(f)
  in_tree = jtu.tree_structure((error, *in_avals))
  f, out_tree = flatten_fun_nokwargs(f, in_tree)
  err_avals = [core.raise_to_shaped(core.get_aval(x)) for x in flat_error_vals]
  avals_in = [*err_avals, *in_avals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return (core.ClosedJaxpr(jaxpr_out, literals_out), (out_tree(), error_effect()))
def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Check a predicate, add an error with msg if predicate is False.

  This is an effectful operation, and can't be staged (jitted/scanned/...).
  Before staging a function with checks, :func:`~checkify` it!

  Args:
    pred: if False, an error is added.
    msg: error message if error is added. Can be a format string.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``
      Note that these arguments can be traced values allowing you to add
      run-time values to the error message.
      Note that tracking these run-time arrays will increase your memory usage,
      even if no error happens. Leaves without a shape/dtype (e.g. Python
      scalars) are converted with ``jnp.asarray`` before being recorded.

  Raises:
    TypeError: if ``pred`` is not a Python bool or a scalar boolean array.

  For example:

  >>> import jax
  >>> import jax.numpy as jnp
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "{x} needs to be positive!", x=x)
  ...   return 1/x
  >>> checked_f = checkify.checkify(f)
  >>> err, out = jax.jit(checked_f)(-3.)
  >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: -3. needs to be positive!
  """
  if not is_scalar_pred(pred):
    raise TypeError(f'check takes a scalar pred as argument, got {pred}')

  # Formatting leaves become the error payload, and
  # FailedCheckError.get_effect_type reads `.shape`/`.dtype` off each one. A
  # plain Python scalar has neither and used to fail with AttributeError.
  # Leaves that already carry shape/dtype (arrays, tracers) are left untouched.
  def as_array(leaf):
    if hasattr(leaf, 'shape') and hasattr(leaf, 'dtype'):
      return leaf
    return jnp.asarray(leaf)

  fmt_args, fmt_kwargs = tree_map(as_array, (fmt_args, fmt_kwargs))
  new_error = FailedCheckError(summary(), msg, *fmt_args, **fmt_kwargs)
  error = assert_func(init_error, jnp.logical_not(pred), new_error)
  return check_error(error)
def is_scalar_pred(pred) -> bool:
  """Returns True if `pred` is usable as a `check` predicate.

  Accepted are Python ``bool`` values and scalar (shape ``()``) boolean arrays.
  The arrays may be JAX arrays/tracers, NumPy bool scalars such as
  ``np.bool_(True)``, or 0-d NumPy bool arrays. Everything else is rejected:
  non-scalar arrays, non-bool dtypes, and Python ints.
  """
  if isinstance(pred, bool):
    return True
  # np.generic covers NumPy scalars (e.g. np.bool_), which are not ndarrays.
  if not isinstance(pred, (jnp.ndarray, np.ndarray, np.generic)):
    return False
  return pred.shape == () and pred.dtype == jnp.dtype('bool')
def check_error(error: Error) -> None:
  """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.

  The semantics of this function are equivalent to:

  >>> def check_error(err: Error) -> None:
  ...   err.throw()  # can raise ValueError

  But unlike that implementation, ``check_error`` can be functionalized using
  the :func:`~checkify` transformation.

  This function is similar to :func:`~check` but with a different signature:
  whereas :func:`~check` takes as arguments a boolean predicate and a new error
  message string, this function takes an ``Error`` value as argument. Both
  :func:`~check` and this function raise a Python Exception on failure (a
  side-effect), and thus cannot be staged out by :func:`~jax.jit`,
  :func:`~jax.pmap`, :func:`~jax.lax.scan`, etc. Both also can be
  functionalized by using :func:`~checkify`.

  But unlike :func:`~check`, this function is like a direct inverse of
  :func:`~checkify`. :func:`~checkify` takes as input a function which can
  raise a Python Exception and produces a new function without that effect but
  which produces an ``Error`` value as output. This ``check_error`` function
  can accept an ``Error`` value as input and can produce the side-effect of
  raising an Exception. That is, :func:`~checkify` goes from functionalizable
  Exception effect to error value, while ``check_error`` goes from error value
  to functionalizable Exception effect.

  ``check_error`` is useful when you want to turn checks represented by an
  ``Error`` value (produced by functionalizing ``check`` calls via
  :func:`~checkify`) back into Python Exceptions.

  Args:
    error: Error to check.

  Raises:
    ValueError: if ``error`` is not an ``Error``.
    JaxRuntimeError: (a ``ValueError`` subclass) when executed eagerly on an
      ``Error`` that holds a failure.

  For example, you might want to functionalize part of your program through
  checkify, stage out your functionalized code through :func:`~jax.jit`, then
  re-inject your error value outside of the :func:`~jax.jit`:

  >>> import jax
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "must be positive!")
  ...   return x
  >>> def with_inner_jit(x):
  ...   checked_f = checkify.checkify(f)
  ...   # a checkified function can be jitted
  ...   error, out = jax.jit(checked_f)(x)
  ...   checkify.check_error(error)
  ...   return out
  >>> _ = with_inner_jit(1)  # no failed check
  >>> with_inner_jit(-1)  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: must be positive!
  >>> # can re-checkify
  >>> error, _ = checkify.checkify(with_inner_jit)(-1)
  """
  if not isinstance(error, Error):
    raise ValueError('check_error takes an Error as argument, '
                     f'got type {type(error)} instead.')
  error = tree_map(core.raise_as_much_as_possible, error)
  # Batched errors are reduced to a single Error before binding `check`.
  if any(map(np.shape, error._pred.values())):
    error = _reduce_any_error(error)
  err_args, tree_def = tree_flatten(error)
  return check_p.bind(*err_args, err_tree=tree_def)
## check primitive
# Primitive bound by check_error. Its operands are the flattened leaves of an
# Error, with the treedef passed as the `err_tree` param.
check_p = core.Primitive('check')
check_p.multiple_results = True  # zero results
# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
  """Raised by the `check` primitive's implementation when an Error failed."""
@check_p.def_impl
def check_impl(*args, err_tree):
  """Eager `check`: rebuilds the Error and raises JaxRuntimeError on failure."""
  exc = tree_unflatten(err_tree, args).get_exception()
  if not exc:
    return []
  raise JaxRuntimeError(str(exc)) from exc
@check_p.def_effectful_abstract_eval
def check_abstract_eval(*args, err_tree):
  """No outputs; the effects are the ErrorEffect keys of the rebuilt Error."""
  rebuilt = tree_unflatten(err_tree, args)
  return [], set(rebuilt._pred)
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
# Single shared exception instance raised by the `check` lowering rules when a
# check reaches lowering without having been functionalized.
functionalization_error = ValueError(
    'Cannot abstractly evaluate a checkify.check which was not'
    ' functionalized. This probably means you tried to stage'
    ' (jit/scan/pmap/...) a `check` without functionalizing it'
    ' through `checkify.checkify`.'
)
def check_lowering_rule(ctx, *args, err_tree):
  """Lowers `check` to a side-effecting Python callback.

  Allowed only when `jax_experimental_unsafe_xla_runtime_errors` is enabled;
  otherwise `functionalization_error` is raised.
  """
  if not config.jax_experimental_unsafe_xla_runtime_errors:
    raise functionalization_error
  host_callback = functools.partial(python_err, err_tree)
  out_op, _unused, keep_alive = mlir.emit_python_callback(
      ctx,
      callback=host_callback,
      token=None,
      operands=args,
      operand_avals=list(ctx.avals_in),
      result_avals=list(ctx.avals_out),
      has_side_effect=True)
  # Keep the callback's host-side objects alive with the module.
  ctx.module_context.add_keepalive(keep_alive)
  return out_op
def check_lowering_rule_unsupported(*a, **k):
  """Lowering rule for platforms without `check` support; always raises."""
  del a, k  # Unused.
  raise functionalization_error
def python_err(err_tree, *args):
  """Host callback: rebuild the Error from its leaves and check it."""
  check_error(tree_unflatten(err_tree, args))
  return []
# `check` cannot be lowered on TPU. On CPU and GPU it lowers to a Python
# callback, which check_lowering_rule only permits when
# jax_experimental_unsafe_xla_runtime_errors is enabled.
mlir.register_lowering(check_p, check_lowering_rule_unsupported,
                       platform='tpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='cpu')
mlir.register_lowering(check_p, check_lowering_rule,
                       platform='gpu')
def check_batching_rule(batched_args, batch_dims, *, err_tree):
  """Batching rule for `check`: move batch dims to the front and re-check.

  The batch size is read from the first mapped argument. `check` has no
  outputs, so empty output values and dims are returned.
  """
  size = next(arg.shape[bdim]
              for arg, bdim in zip(batched_args, batch_dims)
              if bdim is not batching.not_mapped)
  moved_args = (batching.bdim_at_front(arg, bdim, size)
                for arg, bdim in zip(batched_args, batch_dims))
  check_error(tree_unflatten(err_tree, moved_args))
  return [], []
batching.primitive_batchers[check_p] = check_batching_rule
def check_jvp_rule(primals, _, *, err_tree):
  """JVP rule for `check`: checks the primals; tangents are discarded."""
  check_p.bind(*primals, err_tree=err_tree)
  no_primal_outs, no_tangent_outs = [], []
  return no_primal_outs, no_tangent_outs
ad.primitive_jvps[check_p] = check_jvp_rule
## checkify rules
def _get_current_traceback(skip_frames = 0) -> Optional[types.TracebackType]:
# TODO(lenamartens): use c++ version from XLA?
tb = None
import inspect
for frame_info in inspect.stack():
frame = frame_info.frame
if skip_frames:
skip_frames -= 1
elif not traceback_util.include_frame(frame):
continue
else:
tb = types.TracebackType(tb, frame, frame.f_lasti, frame.f_lineno)
return tb
def summary() -> str:
  """Returns a string summary of the current source info (user location)."""
  current_info = source_info_util.current()
  return str(source_info_util.summarize(current_info))
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  """Generic checkify rule: bind `prim`, then check its output(s) for NaNs."""
  result = prim.bind(*in_vals, **params)
  return result, check_nans(prim, error, enabled_errors, result)
def check_nans(prim, error, enabled_errors, out):
  """Adds a NaNError to `error` if any output of `prim` contains a NaN.

  Returns `error` unchanged when NaNError is not enabled. PRNG key arrays are
  never treated as containing NaNs.
  """
  if NaNError not in enabled_errors:
    return error

  def has_nan(value):
    if isinstance(value, prng.PRNGKeyArray):
      return False
    return jnp.any(jnp.isnan(value))

  if prim.multiple_results:
    any_nans = jnp.any(jnp.array([has_nan(v) for v in out]))
  else:
    any_nans = has_nan(out)
  return assert_func(error, any_nans, NaNError(summary(), prim.name))
# All primitives which can generate a NaN. Each is registered below with a
# checkify rule that binds the primitive and then checks its outputs for NaNs.
nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p,
                  lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p,
                  lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p,
                  lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p,
                  lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
                  lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p,
                  lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p,
                  lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p,
                  lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p,
                  lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
                  lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p,
                  lax.reduce_sum_p, lax.reduce_window_p,
                  lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p,
                  lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p,
                  lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p]

# functools.partial binds each primitive eagerly, avoiding late-binding issues.
# (The loop variable `prim` remains defined at module level afterwards.)
for prim in nan_primitives:
  error_checks[prim] = functools.partial(nan_error_check, prim)
def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Binds ``gather`` and, if OOB checks are enabled, flags OOB start indices.

  For each operand axis ``d`` in ``dimension_numbers.start_index_map``, a
  start index ``i`` is treated as in bounds when
  ``0 <= i <= operand.shape[d] - slice_sizes[d]``.

  Returns:
    ``(out, error)``. ``error`` is returned unchanged when OOBError is not
    enabled; otherwise it carries an OOBError check whose payload describes
    the first offending index (see oob_payload).
  """
  out = lax.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
  if OOBError not in enabled_errors:
    return out, error

  # compare to OOB masking logic in lax._gather_translation_rule
  dnums = dimension_numbers
  operand_dims = np.array(operand.shape)
  # The last axis of start_indices holds the index vector; the rest are batch.
  num_batch_dims = len(start_indices.shape) - 1
  upper_bound = operand_dims[np.array(dnums.start_index_map)]
  upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)]
  # Clamp to the index dtype's range before casting (as scatter_oob does).
  # Otherwise a bound that does not fit in a narrow index dtype (e.g. 299 with
  # int8 indices) wraps on `astype`, and valid indices get reported as OOB.
  upper_bound = np.minimum(upper_bound, np.iinfo(start_indices.dtype).max)
  upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims)))
  oob_mask = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype))
  payload = oob_payload(oob_mask, start_indices, dnums.start_index_map, operand.shape)
  return out, assert_func(error, jnp.any(oob_mask), OOBError(summary(), "gather", operand.shape, payload))
error_checks[lax.gather_p] = gather_error_check
def div_error_check(error, enabled_errors, x, y):
  """Checkify rule for ``div``: zero-divisor check followed by the NaN rule.

  When DivisionByZeroError is enabled, flags the case where any element of
  ``y`` equals zero. The division itself is bound through nan_error_check,
  which also applies the NaN check.
  """
  checked = error
  if DivisionByZeroError in enabled_errors:
    has_zero_divisor = jnp.any(jnp.equal(y, 0))
    checked = assert_func(checked, has_zero_divisor,
                          DivisionByZeroError(summary()))
  return nan_error_check(lax.div_p, checked, enabled_errors, x, y)
error_checks[lax.div_p] = div_error_check
def oob_payload(oob_mask, indices, dims_map, operand_shape):
  """Describes the first out-of-bounds index, for the OOB error message.

  Args:
    oob_mask: boolean array with the same shape as ``indices``; True where the
      corresponding index is out of bounds.
    indices: index array whose last axis has one entry per element of
      ``dims_map``.
    dims_map: for each position along the last axis of ``indices``, the
      operand axis that the index addresses.
    operand_shape: shape of the indexed operand.

  Returns:
    An int32 array ``[oob_index, oob_axis, oob_axis_size]`` for the first
    (row-major) True entry of ``oob_mask``. If no entry is True, the values
    describe entry 0 and carry no meaning; callers in this file pair the
    payload with ``jnp.any(oob_mask)``. If ``oob_mask`` is empty, a zero
    placeholder of the same shape and dtype is returned.
  """
  if oob_mask.size == 0:
    # Nothing can be out of bounds, and argmin below raises on empty input.
    return jnp.zeros(3, dtype=jnp.int32)
  # argmin of the negated mask is the flat position of the first True entry.
  flat_idx = jnp.argmin(jnp.logical_not(oob_mask))
  multi_idx = jnp.unravel_index(flat_idx, indices.shape)
  # The position along the last (index-vector) axis selects the operand axis.
  oob_axis = jnp.array(dims_map)[multi_idx[-1]]
  oob_axis_size = jnp.array(operand_shape)[oob_axis]
  oob_index = jnp.ravel(indices)[flat_idx]
  payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)
  return payload
def scatter_oob(operand, indices, updates, dnums):
  """Returns ``(any_oob, payload)`` for the indices of a scatter.

  Follows the clamping code used in scatter_translation_rule. An index along
  operand axis ``d`` is in bounds when ``0 <= i <= operand.shape[d] - w[d]``,
  where ``w[d]`` is the window size for that axis.
  """
  # Window extent per operand axis: 1 for inserted window dims, otherwise the
  # size of the next update window dim, taken in order.
  window_sizes = []
  next_window = 0
  for axis in range(len(operand.shape)):
    if axis not in dnums.inserted_window_dims:
      window_sizes.append(updates.shape[dnums.update_window_dims[next_window]])
      next_window += 1
    else:
      window_sizes.append(1)

  # Largest valid start index per indexed axis, clamped to the index dtype.
  max_start = np.array(
      [operand.shape[axis] - window_sizes[axis]
       for axis in dnums.scatter_dims_to_operand_dims],
      np.int64)
  max_start = np.minimum(max_start, np.iinfo(indices.dtype).max)
  index_axis = len(indices.shape) - 1
  max_start = lax.broadcast_in_dim(max_start, indices.shape, (index_axis,))

  too_small = jnp.less(indices, 0)
  too_large = jnp.greater(indices, max_start.astype(indices.dtype))
  oob_mask = jnp.logical_or(too_small, too_large)
  payload = oob_payload(oob_mask, indices,
                        dnums.scatter_dims_to_operand_dims, operand.shape)
  return jnp.any(oob_mask), payload
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN.

  Binds ``prim`` (one of the scatter primitives registered below), then
  applies each enabled check independently:
    * OOBError: any scatter index out of bounds (see scatter_oob).
    * NaNError: any NaN in the output (check_nans is a no-op when NaN checks
      are disabled).

  Returns:
    ``(out, error)`` with ``error`` updated by the enabled checks.
  """
  out = prim.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)
  if OOBError in enabled_errors:
    out_of_bounds, payload = scatter_oob(operand, indices, updates, dimension_numbers)
    oob_error = OOBError(summary(), prim.name, operand.shape, payload)
    error = assert_func(error, out_of_bounds, oob_error)
  # Run the NaN check regardless of whether OOB checks are enabled; an early
  # return on disabled OOB checks would silently drop the NaN check.
  return out, check_nans(prim, error, enabled_errors, out)
error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p)
error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_min_p)
error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_max_p)
def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
  """Checkify rule for ``cond``: checkifies every branch and threads the error.

  Args:
    error: incoming error state; its flattened leaves are passed to ``cond_p``
      as extra operands placed before ``ops``.
    enabled_errors: enabled error classes, forwarded to checkify_jaxpr.
    index: the branch-index operand of the cond.
    *ops: the remaining cond operands.
    branches: the branch jaxprs to checkify.
    linear: linearity flags aligned with ``ops`` (error flags are prepended
      below to match the prepended error leaves).

  Returns:
    ``(out, err)``: the cond outputs and the error state they produced, with
    ``_metadata`` merged across the error trees of all branches.
  """
  # First pass: checkify each branch only to collect the effects it can add.
  _, out_trees_and_effects = unzip2(checkify_jaxpr(jxpr, error,
                                                   enabled_errors)
                                    for jxpr in branches)
  _, effects = unzip2(out_trees_and_effects)
  # Add placeholders for the union of all branches' effects, so that every
  # branch is checkified against the same error structure (presumably so the
  # branches' error outputs line up — TODO confirm against checkify_jaxpr).
  merged_error = error._add_placeholder_effects(set().union(*effects))
  # Second pass: checkify each branch against the merged error.
  new_branches, out_trees_and_effects = unzip2(checkify_jaxpr(jxpr, merged_error,
                                                              enabled_errors)
                                               for jxpr in branches)
  out_trees, _ = unzip2(out_trees_and_effects)
  flat_error, _ = tree_flatten(merged_error)
  # Error leaves are prepended to the operands and marked non-linear.
  new_linear = (*[False] * len(flat_error), *linear)
  err_and_outs = lax.cond_p.bind(
      index, *flat_error, *ops,
      branches=tuple(new_branches), linear=new_linear)
  # we need to merge metadata across out_trees (a tuple)
  # maybe there's a better way to do this, but we can use the outs
  # to unflatten all trees.
  err0, *out = tree_unflatten(out_trees[0], err_and_outs)
  merged_metadata = err0._metadata
  for tr in out_trees[1:]:
    err, *_ = tree_unflatten(tr, err_and_outs)
    # On key collisions, metadata from later branches wins.
    merged_metadata = {**merged_metadata, **err._metadata}
  return out, err0._replace(_metadata=merged_metadata)
error_checks[lax.cond_p] = cond_error_check
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll):
  """Checkify rule for ``scan``: threads the error state through the carry.

  The body is checkified twice. The first pass only discovers which effects
  the body can add; those are merged into ``error`` as placeholders, so the
  error carried through the loop keeps one structure across iterations. The
  flattened error values become extra leading carry entries.

  Args:
    error: incoming error state.
    enabled_errors: enabled error classes, forwarded to checkify_jaxpr.
    *in_flat: scan operands laid out as ``[*consts, *carry, *xs]``.
    reverse, length, unroll: forwarded unchanged to ``scan_p``.
    jaxpr: the scan body.
    num_consts, num_carry: sizes of the const and carry groups in ``in_flat``.
    linear: one linearity flag per entry of ``in_flat``.

  Returns:
    ``(out, err)``: the scan outputs and the final error state.
  """
  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
  _, (_, effects) = checkify_jaxpr(jaxpr, error, enabled_errors)
  merged_error = error._add_placeholder_effects(effects)
  checked_jaxpr_, (out_tree, _) = checkify_jaxpr(jaxpr, merged_error, enabled_errors)
  flat_error_vals, _ = tree_flatten(merged_error)
  # The checked body takes [*error_vals, *consts, *carry, *xs]; scan wants the
  # consts first, so move the const binders to the front.
  tomove = [False] * len(flat_error_vals) + [True] * len(consts) + [False] * (len(carry) + len(xs))
  checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
  # `linear` is aligned with in_flat = [*consts, *carry, *xs]. The error values
  # are inserted between consts and carry in new_in_flat, so their
  # (non-linear) flags go in the same position. Prepending them would shift
  # every flag onto the wrong operand whenever num_consts > 0.
  new_linear = (*linear[:num_consts], *[False] * len(flat_error_vals),
                *linear[num_consts:])
  new_in_flat = [*consts, *flat_error_vals, *carry, *xs]
  err_and_out = lax.scan_p.bind(
      *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
      num_consts=len(consts), num_carry=len(carry)+len(flat_error_vals),
      linear=new_linear, unroll=unroll)
  err, *out = tree_unflatten(out_tree, err_and_out)
  return out, err
error_checks[lax.scan_p] = scan_error_check
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts):
cond_f = core.jaxpr_as_fun(cond_jaxpr)
body_f = core.jaxpr_as_fun(body_jaxpr)