-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
core.py
2701 lines (2226 loc) · 94.3 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections
from collections import namedtuple
from contextlib import contextmanager
from dataclasses import dataclass
import functools
from functools import partial, partialmethod, total_ordering
import gc
import itertools as it
import operator
from operator import attrgetter
import threading
import types
from typing import (Any, Callable, ClassVar, DefaultDict, Dict, Generator,
Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple,
Type, Union, cast, Iterable, Hashable)
import warnings
from weakref import ref
import numpy as np
from jax._src import dtypes
from jax._src import config as jax_config
from jax._src.config import FLAGS, config
from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
from jax import linear_util as lu
from jax._src import source_info_util
from jax._src.util import (safe_zip, safe_map, curry, prod, tuple_insert,
tuple_delete, as_hashable_function,
HashableFunction, weakref_lru_cache)
import jax._src.pretty_printer as pp
from jax._src import lib
from jax._src.lib import jax_jit
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map
# -------------------- jaxprs --------------------
# An effect is any hashable token; jaxprs and equations carry sets of them.
Effect = Hashable
Effects = Set[Effect]
# Shared empty set used throughout this file as the "no effects" default
# (default argument of `Jaxpr.__init__`, return value of `join_effects`).
no_effects: Effects = set()
# Starts out empty in this file.
# NOTE(review): presumably other modules register effects here — confirm.
ordered_effects: Set[Effect] = set()
class Jaxpr:
  """Container for a jaxpr: its variables, equations and effects."""
  constvars: List[Var]
  invars: List[Var]
  outvars: List[Atom]
  eqns: List[JaxprEqn]
  effects: Effects

  def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
               outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
               effects: Effects = no_effects):
    """Build a jaxpr; every sequence argument is copied into a fresh list.

    Args:
      constvars: variables standing in for constants. Array constants are
        replaced by such variables, scalar constants stay inline.
      invars: input variables. `constvars` and `invars` together make up the
        inputs of the jaxpr.
      outvars: output variables.
      eqns: the equations.
      effects: set of effects; a superset of the union of the effects of the
        individual equations. Stored as given (not copied).
    """
    self.constvars, self.invars = list(constvars), list(invars)
    self.outvars, self.eqns = list(outvars), list(eqns)
    self.effects = effects

  def __str__(self):
    doc = pp_jaxpr(self, JaxprPpContext(), JaxprPpSettings())
    return str(doc)
  __repr__ = __str__

  def pretty_print(self, *, source_info=False, print_shapes=True,
                   custom_pp_eqn_rules=True, name_stack=False, **kw):
    """Render this jaxpr; extra keyword arguments go to the doc's `format`."""
    context = JaxprPpContext()
    settings = JaxprPpSettings(source_info=source_info,
                               print_shapes=print_shapes,
                               custom_pp_eqn_rules=custom_pp_eqn_rules,
                               name_stack=name_stack)
    return pp_jaxpr(self, context, settings).format(**kw)

  def _repr_pretty_(self, p, cycle):
    rendered = self.pretty_print(use_color=True)
    return p.text(rendered)

  def replace(self, *, constvars=None, invars=None, outvars=None, eqns=None,
              effects=None):
    """Return a new `Jaxpr`; fields passed as `None` keep the current value."""
    def pick(override, current):
      return current if override is None else override

    return Jaxpr(constvars=pick(constvars, self.constvars),
                 invars=pick(invars, self.invars),
                 outvars=pick(outvars, self.outvars),
                 eqns=pick(eqns, self.eqns),
                 effects=pick(effects, self.effects))
def join_effects(*effects: Effects) -> Effects:
  """Return the union of the given effect sets, or `no_effects` if none given."""
  if not effects:
    return no_effects
  return set.union(*effects)
def jaxprs_in_params(params) -> Iterator[Jaxpr]:
  """Yield every jaxpr found directly among the values of `params`.

  Each value is inspected on its own or, when it is a tuple, element by
  element. `Jaxpr` instances are yielded as they are; for `ClosedJaxpr`
  instances the wrapped `.jaxpr` is yielded. Anything else is ignored.
  """
  for param in params.values():
    candidates = param if isinstance(param, tuple) else (param,)
    for item in candidates:
      if isinstance(item, Jaxpr):
        yield item
        continue
      if isinstance(item, ClosedJaxpr):
        yield item.jaxpr
def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
  """Yield the jaxprs stored in the params of each equation of `jaxpr`.

  Only one level is inspected: the jaxprs that are found are not searched in
  turn.
  """
  for equation in jaxpr.eqns:
    for sub in jaxprs_in_params(equation.params):
      yield sub
class ClosedJaxpr:
  """A `Jaxpr` paired with one concrete value per entry of its `constvars`."""
  jaxpr: Jaxpr
  consts: List[Any]

  def __init__(self, jaxpr: Jaxpr, consts: Sequence):
    # There must be exactly one constant per constvar.
    assert len(consts) == len(jaxpr.constvars)
    self.jaxpr = jaxpr
    self.consts = list(consts)

  @property
  def in_avals(self):
    return [invar.aval for invar in self.jaxpr.invars]

  @property
  def out_avals(self):
    return [outvar.aval for outvar in self.jaxpr.outvars]

  @property
  def literals(self):
    return self.consts  # backwards compatible alias

  @property
  def eqns(self):
    return self.jaxpr.eqns

  @property
  def effects(self) -> Effects:
    return self.jaxpr.effects

  def map_jaxpr(self, f):
    """Return a new `ClosedJaxpr` wrapping `f(self.jaxpr)` with the same consts."""
    mapped = f(self.jaxpr)
    return ClosedJaxpr(mapped, self.consts)

  def replace(self, *, jaxpr=None, consts=None):
    """Return a new `ClosedJaxpr`; arguments left as `None` keep current values."""
    new_jaxpr = jaxpr if jaxpr is not None else self.jaxpr
    new_consts = consts if consts is not None else self.consts
    return ClosedJaxpr(new_jaxpr, new_consts)

  def __str__(self):
    return str(self.jaxpr)

  def __repr__(self):
    return repr(self.jaxpr)

  def pretty_print(self, *, source_info=False, print_shapes=True,
                   name_stack=False, custom_pp_eqn_rules=True, **kw):
    """Render the wrapped jaxpr; extra keyword arguments go to `format`."""
    settings = JaxprPpSettings(source_info=source_info,
                               print_shapes=print_shapes,
                               name_stack=name_stack,
                               custom_pp_eqn_rules=custom_pp_eqn_rules)
    doc = pp_jaxpr(self.jaxpr, JaxprPpContext(), settings)
    return doc.format(**kw)

  def _repr_pretty_(self, p, cycle):
    rendered = self.pretty_print(use_color=True)
    return p.text(rendered)
@curry
def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
  """Evaluate `closed_jaxpr` on `args`, supplying its stored constants.

  NOTE(review): decorated with `curry` from jax._src.util; presumably
  `jaxpr_as_fun(closed_jaxpr)` returns a callable awaiting `args` — confirm.
  """
  jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts
  return eval_jaxpr(jaxpr, consts, *args)
class JaxprEqn(NamedTuple):
  """One equation of a jaxpr: a primitive applied to `invars` giving `outvars`."""
  invars: List[Atom]
  outvars: List[Var]
  primitive: Primitive
  params: Dict[str, Any]
  effects: Effects
  source_info: source_info_util.SourceInfo

  def __repr__(self):
    doc = pp_eqn(self, JaxprPpContext(), JaxprPpSettings())
    text = str(doc)
    return text.rstrip()

  def replace(self, *args, **kwargs):
    """Public spelling of the namedtuple `_replace` method."""
    return self._replace(*args, **kwargs)
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
  """Create a `JaxprEqn`, filling in a fresh source info when none is given.

  When `config.jax_enable_checks` is set, asserts that every input is a `Var`
  or `Literal` and every output is a `Var`.
  """
  if not source_info:
    source_info = source_info_util.new_source_info()
  if config.jax_enable_checks:
    assert all(isinstance(invar, (Var, Literal)) for invar in invars)
    assert all(isinstance(outvar, Var) for outvar in outvars)
  return JaxprEqn(invars, outvars, primitive, params, effects, source_info)
@total_ordering
class Var:
  """A jaxpr variable, identified for printing by a count and a suffix.

  Ordering compares `(count, suffix)`; the remaining comparison operators are
  derived by `total_ordering`.
  """
  # TODO(frostig,mattjj): We don't override __eq__ or __hash__, so comparison is
  # by object id, but pretty printing might collide.
  count: int
  suffix: str
  aval: AbstractValue

  def __init__(self, count: int, suffix: str, aval: AbstractValue):
    self.count = count
    self.suffix = suffix
    # The abstract value is stored after passing through `raise_to_shaped`.
    self.aval = raise_to_shaped(aval)

  def __lt__(self, other):
    if isinstance(other, Var):
      return (self.count, self.suffix) < (other.count, other.suffix)
    return NotImplemented

  def __repr__(self):
    letters = _encode_digits_alphabetic(self.count)
    return letters + self.suffix
def _encode_digits_alphabetic(n):
if n == -1:
return '*'
s = ''
while len(s) == 0 or n:
n, i = n // 26, n % 26
s = chr(97 + i % 26) + s
return s
def _jaxpr_vars(jaxpr):
return it.chain(
jaxpr.invars, jaxpr.constvars,
(v for eqn in jaxpr.eqns for v in eqn.outvars))
def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
           suffix: str = '') -> Callable[[AbstractValue], Var]:
  """Return a factory of distinct variables, printed with the given suffix.

  Counts start at 0, or, when `jaxprs` is given, one past the largest count
  used by any variable of those jaxprs, so the new variables do not clash with
  the existing ones.
  """
  start = 0
  if jaxprs is not None:
    used_counts = (v.count for j in jaxprs for v in _jaxpr_vars(j))
    start = 1 + max(used_counts, default=-1)
  counter = it.count(start=start)
  return lambda aval: Var(next(counter), suffix, aval)
# In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that
# the assignment is dropped, i.e. that an expression's output value will never
# be read. In that sense, `dropvar` is not a variable, but it is convenient to
# treat it as a special case of one. Its `aval` is similarly inexact.
class DropVar(Var):
  """Placeholder variable for a dropped binding; has count -1 and prints as '_'."""

  def __init__(self, aval: AbstractValue):
    super().__init__(-1, '', aval)

  def __repr__(self):
    return '_'
class Literal:
  """A constant value written inline in a jaxpr, paired with its abstract value.

  The `hash` slot caches a hash of the value. It is set to `None` when a value
  of a registered `literalable_types` type cannot be hashed through
  `(val.item(), val.dtype)`, and is left unset for any other unhashable value.
  Instances themselves are unhashable (`__hash__` is `None`).
  """
  __slots__ = ["val", "aval", "hash"]

  val: Any
  aval: AbstractValue
  hash: Optional[int]

  def __init__(self, val, aval):
    self.val = val
    self.aval = aval
    try:
      self.hash = hash(val)
    except TypeError:
      if type(val) not in literalable_types:
        return  # `hash` deliberately stays unset; `__repr__` checks for it.
      try:
        self.hash = hash((val.item(), val.dtype))
      except (TypeError, AttributeError, ValueError):
        self.hash = None

  __hash__ = None  # type: ignore

  def __repr__(self):
    if not hasattr(self, 'hash'):
      return f'Literal(val={self.val})'
    return f'{self.val}'
# Types for which `Literal.__init__` falls back to hashing
# `(val.item(), val.dtype)` when the value itself is unhashable. Empty here.
# NOTE(review): presumably filled in by other modules — confirm.
literalable_types: Set[type] = set()
# An atom is either a jaxpr variable or an inline literal.
Atom = Union[Var, Literal]
class Primitive:
  """A named operation with pluggable implementation and abstract-eval rules.

  Rules are attached with the `def_*` methods, which store the given callable
  on the instance and return a callable so they can be used as decorators.
  """
  name: str
  # set for multi-output primitives.
  multiple_results: bool = False
  # set for call primitives processed in final style.
  call_primitive: bool = False
  # set for map primitives processed in final style.
  map_primitive: bool = False

  def __init__(self, name: str):
    self.name = name

  def __repr__(self):
    return '{}'.format(self.name)

  def bind(self, *args, **params):
    """Apply this primitive to `args` under the top trace found for them."""
    if config.jax_enable_checks:
      assert all(isinstance(arg, Tracer) or valid_jaxtype(arg)
                 for arg in args), args
    return self.bind_with_trace(find_top_trace(args), args, params)

  def bind_with_trace(self, trace, args, params):
    """Raise `args` into `trace`, process the primitive, lower the result."""
    # NOTE: `map` is rebound to `safe_map` at the top of this file.
    tracers = map(trace.full_raise, args)
    out = trace.process_primitive(self, tracers, params)
    if self.multiple_results:
      return map(full_lower, out)
    return full_lower(out)

  def def_impl(self, impl):
    """Install `impl` as the evaluation rule and return it."""
    self.impl = impl
    return impl

  def def_abstract_eval(self, abstract_eval):
    """Install an effect-free abstract evaluation rule and return it unchanged."""
    wrapped = _effect_free_abstract_eval(abstract_eval)
    self.abstract_eval = wrapped  # type: ignore[assignment]
    return abstract_eval

  def def_effectful_abstract_eval(self, effectful_abstract_eval):
    """Install an abstract evaluation rule that reports effects itself."""
    self.abstract_eval = effectful_abstract_eval  # type: ignore[assignment]
    return effectful_abstract_eval

  def def_custom_bind(self, bind):
    """Replace `bind` on this instance with the given callable and return it."""
    self.bind = bind
    return bind

  def impl(self, *args, **params):
    raise NotImplementedError(
        f"Evaluation rule for '{self.name}' not implemented")

  def abstract_eval(self, *args, **params):
    raise NotImplementedError(
        f"Abstract evaluation for '{self.name}' not implemented")

  def get_bind_params(self, params):
    """Return `(subfuns, bind_params)`; by default no subfuns and `params` as is."""
    return [], params
def _effect_free_abstract_eval(abstract_eval):
def abstract_eval_(*args, **kwargs):
return abstract_eval(*args, **kwargs), no_effects
return abstract_eval_
# -------------------- lifting --------------------
# TODO(mattjj): replace this approach with a primitive-keyed table of rules
def traverse_jaxpr_params(f, params):
  """Apply `f` to each jaxpr-valued parameter and collect the results in a dict.

  Each value of `params` is examined directly or, if it is a tuple or list,
  element by element. Entries whose type is exactly `Jaxpr` or `ClosedJaxpr`
  are passed to `f`. The result maps the parameter name to `f`'s return value;
  when one parameter holds several jaxprs, the last one wins.
  """
  results = {}
  for name, param in params.items():
    entries = param if isinstance(param, (tuple, list)) else [param]
    for entry in entries:
      if type(entry) in (Jaxpr, ClosedJaxpr):
        results[name] = f(entry)
  return results
def eval_jaxpr(jaxpr: Jaxpr, consts, *args):
  """Evaluate `jaxpr` one equation at a time and return its output values.

  Args:
    jaxpr: the jaxpr to evaluate.
    consts: values bound to `jaxpr.constvars`.
    *args: values bound to `jaxpr.invars`.

  Returns:
    The values of `jaxpr.outvars` (a literal yields its `.val`).
  """
  def read(v: Atom) -> Any:
    # Literals carry their value inline; variables are looked up in `env`.
    return v.val if isinstance(v, Literal) else env[v]

  def write(v: Var, val: Any) -> None:
    # The type check only runs with checks enabled and dynamic shapes disabled.
    if config.jax_enable_checks and not config.jax_dynamic_shapes:
      assert typecheck(v.aval, val), (v.aval, val)
    env[v] = val

  env: Dict[Var, Any] = {}
  # NOTE: `map` is rebound to `safe_map` at the top of this file; these calls
  # are made for the side effect of `write`, not for their result.
  map(write, jaxpr.constvars, consts)
  map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
    # Extend the ambient name stack with the one recorded on the equation.
    name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
    with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack):
      ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
    # Multi-result primitives return one value per outvar.
    if eqn.primitive.multiple_results:
      map(write, eqn.outvars, ans)
    else:
      write(eqn.outvars[0], ans)
  return map(read, jaxpr.outvars)
# -------------------- tracing --------------------
class Trace:
  """Base class for traces; subclasses override `pure`, `lift`, `sublift` and
  the `process_*` hooks, all of which raise `NotImplementedError` here.

  A trace records the `MainTrace` it belongs to, that main trace's level, and
  a `Sublevel`.
  """
  __slots__ = ['main', 'level', 'sublevel']

  main: MainTrace
  level: int
  sublevel: Sublevel

  def __init__(self, main: MainTrace, sublevel: Sublevel) -> None:
    self.main = main
    self.level = main.level
    self.sublevel = sublevel

  def full_raise(self, val) -> Tracer:
    """Return `val` as a tracer of this trace.

    Non-tracers go through `pure`. A tracer of the same main trace is returned
    unchanged (equal sublevel) or passed to `sublift` (lower sublevel). A
    tracer of a lower-level main trace is passed to `lift`.

    Raises:
      UnexpectedTracerError: (built by `escaped_tracer_error`) when `val`
        comes from a higher sublevel or level, or from a different main trace
        at the same level.
    """
    if not isinstance(val, Tracer):
      return self.pure(val)
    val._assert_live()
    level = self.level
    sublevel = self.sublevel
    if val._trace.main is self.main:
      if val._trace.sublevel == sublevel:
        return val
      elif val._trace.sublevel < sublevel:
        return self.sublift(val)
      else:
        raise escaped_tracer_error(
            val, f"Can't lift sublevels {val._trace.sublevel} to {sublevel}")
    elif val._trace.level < level:
      # Lifting from a lower level still requires a compatible sublevel.
      if val._trace.sublevel > sublevel:
        raise escaped_tracer_error(
            val, f"Incompatible sublevel: {val._trace}, {(level, sublevel)}")
      return self.lift(val)
    elif val._trace.level > level:
      raise escaped_tracer_error(
          val, f"Can't lift level {val} to {self}")
    else:  # val._trace.level == self.level:
      raise escaped_tracer_error(
          val, f"Different traces at same level: {val}, {self}")

  def pure(self, val):
    raise NotImplementedError("must override")

  def lift(self, tracer):
    raise NotImplementedError("must override")

  def sublift(self, tracer):
    raise NotImplementedError("must override")

  def process_primitive(self, primitive, tracers, params):
    raise NotImplementedError("must override")

  def __repr__(self):
    return '{}(level={}/{})'.format(
        self.__class__.__name__, self.level, self.sublevel)

  def process_call(self, call_primitive, f, tracers, params):
    msg = (f"{type(self)} must override process_call to handle call-like "
           "primitives")
    raise NotImplementedError(msg)

  def process_map(self, map_primitive, f, tracers, params):
    msg = (f"{type(self)} must override process_map to handle map-like "
           "primitives")
    raise NotImplementedError(msg)

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    msg = (f"{type(self)} must override process_custom_jvp_call "
           "to handle custom_jvp primitives")
    raise NotImplementedError(msg)

  def process_custom_transpose(self, prim, call, tracers, **params):
    msg = (f"{type(self)} must override process_custom_transpose "
           "to handle custom_transpose_call primitives")
    raise NotImplementedError(msg)

  # NOTE(review): `EvalTrace` overrides this with `**_` in place of
  # `out_trees` — confirm how callers pass `out_trees`.
  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    msg = (f"{type(self)} must override process_custom_vjp_call "
           "to handle custom_vjp primitives")
    raise NotImplementedError(msg)
def raise_as_much_as_possible(tracer) -> Tracer:
  """Raise `tracer` through every main trace from the dynamic one upwards.

  For each main trace at or above the current dynamic trace, the value is
  raised into that trace when it is not yet a tracer or when its trace has a
  lower level.
  """
  stack_state = thread_local_state.trace_state.trace_stack
  mains = stack_state.stack
  # Find effective bottom of trace stack (highest dynamic Trace on the stack).
  bottom = next(i for i, m in enumerate(mains) if m is stack_state.dynamic)
  # Only the part of the stack from that point on is relevant.
  for main in mains[bottom:]:
    trace = main.with_cur_sublevel()
    is_below = (not isinstance(tracer, Tracer) or
                tracer._trace.level < trace.level)
    if is_below:
      tracer = trace.full_raise(tracer)
  return tracer
def escaped_tracer_error(tracer, detail=None):
  """Build (not raise) an `UnexpectedTracerError` describing a leaked tracer.

  The message reports the tracer's shape and dtype and, when available, the
  debug info of the main trace and where the value was created. `detail`, if
  truthy, is appended at the end.
  """
  num_frames = FLAGS.jax_tracer_error_num_traceback_frames
  parts = [
      'Encountered an unexpected tracer. A function transformed by JAX '
      'had a side effect, allowing for a reference to an intermediate value '
      f'with shape {tracer.shape} and dtype {tracer.dtype} to escape.\n'
      'JAX transformations require that functions explicitly return their '
      'outputs, and disallow saving intermediate values to global state.']
  dbg = getattr(tracer._trace.main, 'debug_info', None)
  if dbg is not None:
    parts.append('\nThe function being traced when the value leaked was '
                 f'{dbg.func_src_info} traced for {dbg.traced_for}.')
  line_info = getattr(tracer, '_line_info', None)
  if line_info is not None:
    divider = '\n' + '-'*30 + '\n'
    parts.append(divider)
    parts.append('The leaked intermediate value was created on line '
                 f'{source_info_util.summarize(line_info)}. ')
    parts.append(divider)
    if num_frames > 0:
      parts.append(f'When the value was created, the final {num_frames} stack '
                   'frames (most recent last) excluding JAX-internal frames were:')
      frames = source_info_util.summarize(line_info, num_frames=num_frames)
      parts.append(divider + frames + divider)
  parts.append('\nTo catch the leak earlier, try setting the environment variable '
               'JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context '
               'manager.')
  if detail:
    parts.append(f'Detail: {detail}')
  return UnexpectedTracerError(''.join(parts))
class Tracer:
  """Base class for traced values.

  A tracer stores the `Trace` that created it and forwards operators and
  attribute access to its abstract value (`self.aval`), which subclasses must
  provide by overriding the `aval` property.
  """
  __array_priority__ = 1000
  __slots__ = ['_trace', '__weakref__', '_line_info']

  def __array__(self, *args, **kw):
    raise TracerArrayConversionError(self)

  def __index__(self):
    raise TracerIntegerConversionError(self)

  def __init__(self, trace: Trace):
    self._trace = trace

  def __iter__(self):
    return iter(self.aval._iter(self))

  def __len__(self):
    return self.aval._len(self)

  @property
  def aval(self):
    raise NotImplementedError("must override")

  def _assert_live(self) -> None:
    pass  # Override for liveness checking

  # Python looks up special methods only on classes, not instances. This means
  # these methods needs to be defined explicitly rather than relying on
  # __getattr__.
  def __neg__(self): return self.aval._neg(self)
  def __pos__(self): return self.aval._pos(self)
  def __eq__(self, other): return self.aval._eq(self, other)
  def __ne__(self, other): return self.aval._ne(self, other)
  def __lt__(self, other): return self.aval._lt(self, other)
  def __le__(self, other): return self.aval._le(self, other)
  def __gt__(self, other): return self.aval._gt(self, other)
  def __ge__(self, other): return self.aval._ge(self, other)
  def __abs__(self): return self.aval._abs(self)
  def __add__(self, other): return self.aval._add(self, other)
  def __radd__(self, other): return self.aval._radd(self, other)
  def __sub__(self, other): return self.aval._sub(self, other)
  def __rsub__(self, other): return self.aval._rsub(self, other)
  def __mul__(self, other): return self.aval._mul(self, other)
  def __rmul__(self, other): return self.aval._rmul(self, other)
  def __div__(self, other): return self.aval._div(self, other)
  def __rdiv__(self, other): return self.aval._rdiv(self, other)
  def __truediv__(self, other): return self.aval._truediv(self, other)
  def __rtruediv__(self, other): return self.aval._rtruediv(self, other)
  def __floordiv__(self, other): return self.aval._floordiv(self, other)
  def __rfloordiv__(self, other): return self.aval._rfloordiv(self, other)
  def __divmod__(self, other): return self.aval._divmod(self, other)
  def __rdivmod__(self, other): return self.aval._rdivmod(self, other)
  def __mod__(self, other): return self.aval._mod(self, other)
  def __rmod__(self, other): return self.aval._rmod(self, other)
  def __pow__(self, other): return self.aval._pow(self, other)
  def __rpow__(self, other): return self.aval._rpow(self, other)
  def __matmul__(self, other): return self.aval._matmul(self, other)
  def __rmatmul__(self, other): return self.aval._rmatmul(self, other)
  def __and__(self, other): return self.aval._and(self, other)
  def __rand__(self, other): return self.aval._rand(self, other)
  def __or__(self, other): return self.aval._or(self, other)
  def __ror__(self, other): return self.aval._ror(self, other)
  def __xor__(self, other): return self.aval._xor(self, other)
  def __rxor__(self, other): return self.aval._rxor(self, other)
  def __invert__(self): return self.aval._invert(self)
  def __lshift__(self, other): return self.aval._lshift(self, other)
  def __rlshift__(self, other): return self.aval._rlshift(self, other)
  def __rshift__(self, other): return self.aval._rshift(self, other)
  def __rrshift__(self, other): return self.aval._rrshift(self, other)
  def __getitem__(self, idx): return self.aval._getitem(self, idx)
  def __nonzero__(self): return self.aval._nonzero(self)
  def __bool__(self): return self.aval._bool(self)
  def __int__(self): return self.aval._int(self)
  def __long__(self): return self.aval._long(self)
  def __hex__(self): return self.aval._hex(self)
  def __oct__(self): return self.aval._oct(self)
  def __float__(self): return self.aval._float(self)
  def __complex__(self): return self.aval._complex(self)
  def __copy__(self): return self.aval._copy(self)
  def __deepcopy__(self, memo): return self.aval._deepcopy(self, memo)

  # raises a useful error on attempts to pickle a Tracer.
  def __reduce__(self):
    raise ConcretizationTypeError(
        self, ("The error occurred in the __reduce__ method, which may "
               "indicate an attempt to serialize/pickle a traced value."))

  # raises the better error message from ShapedArray
  def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val)

  # NumPy also only looks up special methods on classes.
  def __array_module__(self, types): return self.aval._array_module(self, types)

  def __getattr__(self, name):
    """Forward lookup of attributes missing on the tracer to `self.aval`.

    `aval_property` values are evaluated on the tracer and `aval_method`
    values are bound to it; any other attribute is returned as found.

    Raises:
      AttributeError: if the aval has no attribute `name`; the message names
        the tracer class and the original error is chained as the cause.
    """
    # if the aval property raises an AttributeError, gets caught here
    assert not config.jax_enable_checks or name != "aval"

    try:
      attr = getattr(self.aval, name)
    # `getattr` reports a missing attribute with AttributeError; KeyError is
    # still caught for avals that resolve attributes through a mapping.
    except (AttributeError, KeyError) as err:
      raise AttributeError(
          f"{self.__class__.__name__} has no attribute {name}"
      ) from err
    else:
      t = type(attr)
      if t is aval_property:
        return attr.fget(self)
      elif t is aval_method:
        return types.MethodType(attr.fun, self)
      else:
        return attr

  def _pretty_print(self):
    """Return a pretty-printer doc for this tracer and its `_contents()`."""
    base = pp.text(f'Traced<{self.aval}>with<{self._trace}>')
    contents = [(name, attr._pretty_print() if isinstance(attr, Tracer)
                 else pp.text(repr(attr))) for name, attr in self._contents()]
    if contents:
      base = pp.group(pp.nest(2, pp.concat([
          base, pp.text(' with'), pp.brk(), pp.join(pp.brk(), [
              pp.text(f'{name} = ') + pp_payload
              for name, pp_payload in contents])
      ])))
    return base

  def __repr__(self):
    return self._pretty_print().format()

  def _contents(self):
    """Return `(name, value)` pairs for the slots, or `()` if any is unset."""
    try:
      return [(name, getattr(self, name)) for name in self.__slots__]
    except AttributeError:
      return ()

  def _origin_msg(self) -> str:
    return ""
# these can be used to set up forwarding of properties and instance methods from
# Tracer instances to the underlying avals
# (`Tracer.__getattr__` calls `fget(tracer)` for an `aval_property` and binds
# `fun` to the tracer for an `aval_method`).
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])
class EvalTrace(Trace):
  """Trace that evaluates directly: values pass through and `impl` is called."""
  # See comments in https://github.com/google/jax/pull/3370

  def pure(self, x):
    return x

  lift = pure
  sublift = pure

  def process_primitive(self, primitive, tracers, params):
    impl = primitive.impl
    return impl(*tracers, **params)

  def process_call(self, primitive, f, tracers, params):
    impl = primitive.impl
    return impl(f, *tracers, **params)

  process_map = process_call

  def process_custom_transpose(self, primitive, call, tracers, **_):
    del primitive
    with new_sublevel():
      result = call.call_wrapped(*tracers)
      return result

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    del primitive, jvp  # Unused.
    with new_sublevel():
      result = fun.call_wrapped(*tracers)
      return result

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_):
    del primitive, fwd, bwd  # Unused.
    with new_sublevel():
      result = fun.call_wrapped(*tracers)
      return result
class MainTrace:
  """Identifies one entry of the trace stack: level, trace class and payload.

  Equality compares level, trace type and payload; the hash uses only level
  and trace type.
  """
  level: int
  trace_type: Type[Trace]
  payload: Dict[str, Any]

  def __init__(self, level, trace_type, **payload) -> None:
    self.level, self.trace_type, self.payload = level, trace_type, payload

  def __repr__(self) -> str:
    return "MainTrace({},{})".format(self.level, self.trace_type.__name__)

  def __hash__(self) -> int:
    key = (self.level, self.trace_type)
    return hash(key)

  def __eq__(self, other: object) -> bool:
    if not isinstance(other, MainTrace):
      return False
    return (self.level == other.level and
            self.trace_type == other.trace_type and
            self.payload == other.payload)

  def with_cur_sublevel(self):
    """Instantiate `trace_type` for this main trace at the current sublevel."""
    return self.trace_type(self, cur_sublevel(), **self.payload)
class TraceStack:
  """Stack of `MainTrace`s together with the currently dynamic one.

  A new stack holds a single level-0 `EvalTrace` main trace, which is also the
  initial `dynamic` trace.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack: List[MainTrace]
  dynamic: MainTrace

  def __init__(self):
    eval_trace = MainTrace(0, EvalTrace)
    self.stack = [eval_trace]
    self.dynamic = eval_trace

  def next_level(self) -> int:
    """Return the level a newly pushed main trace would get (the stack size)."""
    return len(self.stack)

  def push(self, main_trace: MainTrace) -> None:
    self.stack.append(main_trace)

  def pop(self) -> None:
    self.stack.pop()

  def __repr__(self) -> str:
    # Join the per-trace lines (top of stack first). Interpolating the mapped
    # sequence directly would print a list/iterator repr instead of the lines.
    stack_str = ''.join(' {}\n'.format(main) for main in self.stack[::-1])
    return f'Trace stack\n{stack_str}\n{self.dynamic}'

  def copy(self):
    """Return a new `TraceStack` with a copied list and the same `dynamic`."""
    new = self.__new__(TraceStack)
    new.stack = self.stack[:]
    new.dynamic = self.dynamic
    return new
@total_ordering
class Sublevel:
  """Integer sublevel wrapper, comparable only with other `Sublevel`s.

  Comparing with any other type yields False for both `==` and `<`.
  """

  def __init__(self, level: int):
    self.level = level

  def __repr__(self):
    return str(self.level)

  def __eq__(self, other):
    if type(other) is not Sublevel:
      return False
    return self.level == other.level

  def __lt__(self, other):
    if type(other) is not Sublevel:
      return False
    return self.level < other.level
# One entry of `TraceState.axis_env`: an axis name, its size, and a main trace.
AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace'])
# Axis names may be any hashable value.
AxisName = Hashable
# Unique sentinel object.
# NOTE(review): presumably stands for "no axis name given" — confirm at use sites.
no_axis_name = object()
class TraceState:
  """Tracing state: the trace stack, the sublevel stack and the axis env."""
  trace_stack: TraceStack
  substack: List[Sublevel]
  axis_env: List[AxisEnvFrame]

  def __init__(self) -> None:
    self.trace_stack = TraceStack()
    self.substack = [Sublevel(0)]
    self.axis_env = []

  def copy(self):
    """Return a new `TraceState` with copies of all three containers."""
    duplicate = self.__new__(TraceState)
    duplicate.trace_stack = self.trace_stack.copy()
    duplicate.substack = self.substack[:]
    duplicate.axis_env = self.axis_env[:]
    return duplicate
def _update_thread_local_jit_state(dynamic):
  """Publish a stripped copy of `dynamic` as the thread-local jit trace state."""
  # Copies the MainTrace instance, removing any .debug_info or .jaxpr_stack
  # fields that should not be kept alive as part of a cache key.
  # TODO(mattjj): split debug_info and jaxpr_stack out of MainTrace.
  # TODO(mattjj): add a test that verifies that JIT-ted functions are not kept
  # alive by the JIT cache, particularly for nested JIT-ted functions.
  level, trace_type, payload = dynamic.level, dynamic.trace_type, dynamic.payload
  stripped = MainTrace(level, trace_type, **payload)
  jax_config.update_thread_local_jit_state(dynamic_trace_state=stripped)
# The global state of the tracer is accessed by a thread-local object.
# This allows concurrent tracing in separate threads; passing traced objects
# between threads is forbidden.
class ThreadLocalState(threading.local):
  """Thread-local holder of the `TraceState` (subclass of `threading.local`)."""
  def __init__(self):
    # Starts from a fresh, clean trace state.
    self.trace_state = TraceState()
# Module-level instance through which the tracing state is accessed.
thread_local_state = ThreadLocalState()
def _initialize_jax_jit_thread_local_state():
  """Initializes the C++ thread-local context.

  When the user spawns threads, the C++ `jax_jit.thread_local_state` is None.
  The C++ accessor calls this function if it realizes the thread_local_state
  is None (which means it's not yet initialized for this thread).

  This function does not live in `config.py`, to prevent circular imports.
  """
  tls = jax_jit.thread_local_state()
  if tls.extra_jit_context is not None:
    return  # Already initialized for this thread.
  dynamic = thread_local_state.trace_state.trace_stack.dynamic
  stripped = MainTrace(dynamic.level, dynamic.trace_type, **dynamic.payload)
  tls.extra_jit_context = jax_config._ThreadLocalExtraJitContext(
      dynamic_trace_state=stripped)
# TODO(phawkins): remove after minimum jaxlib version is > 0.3.11
# The initializer is registered only when the extension version is >= 70.
# NOTE(review): presumably older versions lack this hook — confirm.
if lib.xla_extension_version >= 70:
  jax_jit.set_thread_local_state_initialization_callback(
      _initialize_jax_jit_thread_local_state)
def trace_state_clean() -> bool:
  """Return True if the trace state equals a freshly initialized one.

  That is: a single `Sublevel(0)`, an empty axis env, and a trace stack whose
  only entry and dynamic trace both equal `MainTrace(0, EvalTrace)`.
  """
  state = thread_local_state.trace_state
  if not (state.substack == [Sublevel(0)] and state.axis_env == []):
    return False
  stack = state.trace_stack
  pristine = MainTrace(0, EvalTrace)
  return stack.stack == [pristine] and stack.dynamic == pristine
def reset_trace_state() -> bool:
  """Resets the global trace state and returns True if it was already clean."""
  was_clean = trace_state_clean()
  if not was_clean:
    # Re-run __init__ in place so existing references see the reset state.
    thread_local_state.trace_state.__init__()  # type: ignore
  return was_clean
def cur_sublevel() -> Sublevel:
  """Return the last entry of the thread-local sublevel stack."""
  substack = thread_local_state.trace_state.substack
  return substack[-1]
# Warning text emitted by `maybe_find_leaked_tracers` when the current thread
# has a falsy `pydev_do_not_trace` attribute.
TRACER_LEAK_DEBUGGER_WARNING = """\
JAX check_tracer_leaks behavior can trigger false positives when used with a debugger.
To avoid false positives and silence this warning, you can disable thread tracing using
the following:
import threading
threading.current_thread().pydev_do_not_trace = True
"""
def maybe_find_leaked_tracers(x: Optional[Union[MainTrace, Sublevel]]
                              ) -> List[Tracer]:
  """Find the leaked tracers holding a reference to the MainTrace or SubLevel.

  The search goes through the garbage collector: first the `Trace` objects
  referring to `x`, then the `Tracer` objects referring to those traces.

  It's possible there's none! eg. there's some cases where JAX itself holds a
  reference to `x` inside of a lambda closure, and no tracers were leaked
  by the user. In this case an empty list is returned.
  """
  if not getattr(threading.current_thread(), 'pydev_do_not_trace', True):
    warnings.warn(TRACER_LEAK_DEBUGGER_WARNING)
  # Trigger garbage collection to filter out cyclical dependency false positives
  gc.collect()
  traces = [r for r in gc.get_referrers(x) if isinstance(r, Trace)]
  tracers = [r for r in gc.get_referrers(*traces) if isinstance(r, Tracer)]
  return tracers
def leaked_tracer_error(name: str, t, tracers: List[Tracer]) -> Exception:
  """Build (not raise) an `Exception` listing the leaked tracers.

  `tracers` must be non-empty. Each one is shown by its string form followed
  by its `_origin_msg()`.
  """
  assert tracers
  descriptions = [f'{tracer}{tracer._origin_msg()}' for tracer in tracers]
  listing = '\n\n'.join(descriptions)
  return Exception(f'Leaked {name} {t}. Leaked tracer(s):\n\n{listing}\n')
@contextmanager
def new_main(trace_type: Type[Trace],
             dynamic: bool = False,
             **payload) -> Generator[MainTrace, None, None]:
  """Context manager that pushes a new `MainTrace` for the duration of the block.

  Args:
    trace_type: the `Trace` subclass the new main trace is for.
    dynamic: if True, the new main trace also becomes the stack's dynamic
      trace until exit, and the thread-local jit state is updated to match.
    **payload: stored on the `MainTrace`.

  Yields:
    The new `MainTrace`, whose level is the stack size at entry.

  Raises:
    Exception: (from `leaked_tracer_error`) on exit, when
      `config.jax_check_tracer_leaks` is set and tracers still referring to
      the main trace are found.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  level = stack.next_level()
  main = MainTrace(level, trace_type, **payload)
  stack.push(main)
  if dynamic:
    prev_dynamic, stack.dynamic = stack.dynamic, main
    _update_thread_local_jit_state(stack.dynamic)

  try:
    yield main
  finally:
    # Always undo the push (and the dynamic swap), even on error.
    stack.pop()
    if dynamic:
      stack.dynamic = prev_dynamic
      _update_thread_local_jit_state(stack.dynamic)

  if config.jax_check_tracer_leaks:
    # Keep only a weak reference; if the object is still alive after dropping
    # ours, something else is holding on to it.
    t = ref(main)
    del main
    if t() is not None:
      leaked_tracers = maybe_find_leaked_tracers(t())
      if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
@contextmanager
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
  """Context manager that temporarily replaces the bottom of the trace stack.

  A level-0 `MainTrace` of `trace_type` is installed both as `stack.stack[0]`
  and as the dynamic trace; both are restored on exit.

  Yields:
    The new level-0 `MainTrace`.

  Raises:
    Exception: (from `leaked_tracer_error`) on exit, when
      `config.jax_check_tracer_leaks` is set and tracers still referring to
      the main trace are found.
  """
  # See comments in https://github.com/google/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
  main = MainTrace(0, trace_type)
  prev_dynamic, stack.dynamic = stack.dynamic, main
  prev_base, stack.stack[0] = stack.stack[0], main
  _update_thread_local_jit_state(stack.dynamic)
  try:
    yield main
  finally:
    # Restore the previous dynamic and base traces, even on error.
    stack.dynamic = prev_dynamic
    stack.stack[0] = prev_base
    _update_thread_local_jit_state(stack.dynamic)

  if config.jax_check_tracer_leaks:
    # Keep only a weak reference; if the object is still alive after dropping
    # ours, something else is holding on to it.
    t = ref(main)
    del main
    if t() is not None:
      leaked_tracers = maybe_find_leaked_tracers(t())
      if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
@contextmanager
def ensure_compile_time_eval():
  """Context manager to ensure evaluation at trace/compile time (or error).

  Some JAX APIs like ``jax.jit`` and ``jax.lax.scan`` involve staging, i.e.
  delaying the evaluation of numerical expressions (like jax.numpy function
  applications) so that instead of performing those computations eagerly while
  evaluating the corresponding Python expressions, their computation is carried
  out separately, e.g. after optimized compilation. But this delay can be
  undesirable. For example, numerical values might be needed to evaluate Python
  control flow and so their evaluation cannot be delayed. As another example, it
  may be beneficial to ensure compile time evaluation (or "constant folding")
  for performance reasons.

  This context manager ensures that JAX computations are evaluated eagerly. If
  eager evaluation is not possible, a ``ConcretizationTypeError`` is raised.

  Here's a contrived example::

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      with jax.ensure_compile_time_eval():
        y = jnp.sin(3.0)
        z = jnp.sin(y)
        z_positive = z > 0
      if z_positive:  # z_positive is usable in Python control flow
        return jnp.sin(x)
      else:
        return jnp.cos(x)

  Here's a real-world example from https://github.com/google/jax/issues/3974::

    import jax
    import jax.numpy as jnp
    from jax import random

    @jax.jit
    def jax_fn(x):
      with jax.ensure_compile_time_eval():
        y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
        y2 = y @ y
      x2 = jnp.sum(y2) * x
      return x2

  A similar behavior can often be achieved simply by 'hoisting' the constant
  expression out of the corresponding staging API::

    y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)

    @jax.jit
    def jax_fn(x):
      y2 = y @ y
      x2 = jnp.sum(y2)*x
      return x2

  But in some cases it can be more convenient to use this context manager.
  """
  # Implemented by installing an EvalTrace as the base (and dynamic) trace for
  # the duration of the block.
  with new_base_main(EvalTrace):
    yield
eval_context = ensure_compile_time_eval  # alias, backward compatibility
@contextmanager