/
partial_eval.py
1180 lines (1006 loc) · 49.1 KB
/
partial_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools as it
from collections import namedtuple
import contextlib
import functools
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
List, Union, cast, Type, no_type_check)
from weakref import ref
import numpy as np
from .. import core
from .. import dtypes
from .. import linear_util as lu
from ..abstract_arrays import ConcreteArray, raise_to_shaped
from ..ad_util import Zero
from ..util import (unzip2, safe_zip, safe_map, toposort, partial, split_list,
cache)
from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
dropvar)
from .. import source_info_util
from ..config import config
# Module-wide replacements for the builtins, taken from ..util. Code in this
# module relies on `map` being eager (e.g. JaxprTrace.process_call calls it
# purely for its side effects); the `safe_` variants presumably also check
# that their arguments have matching lengths -- see ..util.
map = safe_map
zip = safe_zip
def identity(x):
  """Return `x` unchanged."""
  result = x
  return result
class PartialVal(tuple):
  """A value that is either known or only abstractly described.

  Stored as the pair ``(aval_opt, const)`` in one of two forms:

  * ``(None, const)``: a known value; ``const`` is a regular Python value or a
    Tracer.
  * ``(aval, unit)``: an unknown value, characterized only by the
    ``AbstractValue`` ``aval``; the constant slot holds ``core.unit``.
  """

  def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
    aval_opt, const = xs
    if not core.skip_checks:
      # Slot types.
      assert isinstance(aval_opt, (AbstractValue, type(None))), xs
      assert (isinstance(const, core.Tracer) or type(const) is Zero
              or core.valid_jaxtype(const)), xs
      # An unknown value must carry a unit in the constant slot.
      if isinstance(aval_opt, AbstractValue):
        assert get_aval(const) == core.abstract_unit, xs
    return tuple.__new__(cls, xs)

  @classmethod
  def known(cls, const: core.Value) -> 'PartialVal':
    """Build a PartialVal for the known value `const`."""
    return PartialVal((None, const))

  @classmethod
  def unknown(cls, aval: AbstractValue) -> 'PartialVal':
    """Build a PartialVal for an unknown value described by `aval`."""
    return PartialVal((aval, core.unit))

  def is_known(self) -> bool:
    """True iff the abstract-value slot is empty."""
    return self[0] is None

  def get_known(self) -> Optional[core.Value]:
    """Return the known value, or None when the value is unknown."""
    aval_opt, const = self
    return const if aval_opt is None else None

  def get_aval(self) -> AbstractValue:
    """Return the stored aval if unknown, else the aval of the known constant."""
    known = self.get_known()
    if known is None:
      return self[0]
    return get_aval(known)

  def merge_with_known(self, val: core.Value) -> core.Value:
    """Return the stored known value if there is one, otherwise `val`."""
    known = self.get_known()
    return val if known is None else known
class JaxprTrace(Trace):
  """Partial-evaluation trace.

  Operations whose inputs are all known are executed immediately.
  Operations with any unknown input are recorded as equation recipes
  (JaxprEqnRecipe) on unknown JaxprTracers. `tracers_to_jaxpr` later turns
  those recipes into a Jaxpr.
  """

  def pure(self, val) -> 'JaxprTracer':
    # Treat `val` as a known constant.
    return self.new_const(val)

  def lift(self, val) -> 'JaxprTracer':
    # Treat `val` as a known constant.
    return self.new_const(val)

  def sublift(self, val) -> 'JaxprTracer':
    # Re-wrap the JaxprTracer `val` with a FreeVar recipe; tracers_to_jaxpr
    # turns FreeVar tracers into environment inputs of the jaxpr.
    return JaxprTracer(self, val.pval, FreeVar(val))

  def new_const(self, val) -> 'JaxprTracer':
    """Wrap `val` as a known value (recipe `unit`).

    Raises a bare Exception if `val` is a tracer of this trace's own level.
    """
    if isinstance(val, Tracer) and val._trace.level == self.level:
      raise Exception
    return JaxprTracer(self, PartialVal.known(val), unit)

  def new_instantiated_literal(self, val) -> 'JaxprTracer':
    # Unknown tracer that tracers_to_jaxpr emits inline as Literal(val).
    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), Literal(val))

  def new_instantiated_const(self, val) -> 'JaxprTracer':
    # Unknown tracer whose value becomes a constvar (bound to `val`) of the jaxpr.
    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), ConstVar(val))

  def new_arg(self, pval: PartialVal) -> 'JaxprTracer':
    """Create the tracer for a function argument.

    Unknown arguments become lambda binders (LambdaBinding). Known arguments
    become constants.
    """
    const = pval.get_known()
    if const is None:
      return JaxprTracer(self, pval, LambdaBinding())
    else:
      return self.new_const(const)

  def instantiate_const(self, tracer) -> Tracer:
    """Force a known tracer to be unknown, so it is staged into the jaxpr.

    Scalar values of a literalable type are inlined as literals; anything
    else becomes a constvar. Already-unknown tracers are returned as-is.
    """
    const = tracer.pval.get_known()
    if const is None:
      return tracer
    else:
      if type(const) in core.literalable_types and np.shape(const) == ():
        return self.new_instantiated_literal(const)
      else:
        return self.new_instantiated_const(const)

  def instantiate_const_abstracted(self, tracer) -> 'JaxprTracer':
    """Like `instantiate_const`, but always uses a ConstVar.

    The constant's aval is passed through `raise_to_shaped` instead of being
    kept as-is.
    """
    const = tracer.pval.get_known()
    if const is None:
      return tracer
    else:
      aval = raise_to_shaped(get_aval(const), np.isscalar(const))
      return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))

  def process_primitive(self, primitive, tracers, params):
    # Primitive-specific rules take precedence over the default behavior.
    if primitive in custom_partial_eval_rules:
      return custom_partial_eval_rules[primitive](self, *tracers, **params)
    else:
      return self.default_process_primitive(primitive, tracers, params)

  def default_process_primitive(self, primitive, tracers, params):
    """By default, if all the input tracers are known, then execute the primitive
    and all the outputs are known. Otherwise, all the outputs are unknown."""
    consts = [t.pval.get_known() for t in tracers]
    if all(c is not None for c in consts):
      return primitive.bind(*consts, **params)
    # Some input is unknown: stage out the whole primitive, instantiating any
    # known inputs as literals/constvars.
    tracers = map(self.instantiate_const, tracers)
    avals = [t.aval for t in tracers]
    out_aval = primitive.abstract_eval(*avals, **params)
    source = source_info_util.current()
    if primitive.multiple_results:
      out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
                     for aval in out_aval]
      # All outputs share one recipe (and hence one eqn_id).
      eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, source)
      for t in out_tracers: t.recipe = eqn
      return out_tracers
    else:
      out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
      out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
                                         params, source)
      return out_tracer

  # We use process_call to handle both call and map primitives.
  def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
    """Partially evaluate a call/map primitive applied to `f`.

    Known outputs are returned as known constants. Unknown outputs get a
    single eqn recipe that calls `primitive` on a jaxpr restricted to the
    unknown inputs and outputs.
    """
    if not config.omnistaging_enabled:
      # NOTE(review): StagingJaxprTrace / staged_out_calls are not defined in
      # this chunk; presumably provided by the omnistaging-disabled path.
      if (self.main.trace_type is StagingJaxprTrace  # type: ignore
          and primitive in staged_out_calls):  # type: ignore
        tracers = map(self.instantiate_const_abstracted, tracers)

    if primitive in call_partial_eval_rules:
      return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)

    in_pvals = [t.pval for t in tracers]
    if primitive.map_primitive:
      # Inside the map body, unknown mapped inputs are seen with the mapped
      # (per-element) aval.
      mapped_aval = partial(core.mapped_aval, params['axis_size'])
      in_pvals = [pval if pval.is_known() or not is_mapped
                  else PartialVal.unknown(mapped_aval(pval[0]))
                  for pval, is_mapped in zip(in_pvals, params['mapped_invars'])]
    jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
        f, in_pvals, partial(primitive.bind, **params))
    if primitive.map_primitive:
      # Unknown outputs leave the map with the unmapped (stacked) aval.
      unmapped_aval = partial(core.unmapped_aval, params['axis_size'])
      out_pvals = [pval if pval.is_known()
                   else PartialVal.unknown(unmapped_aval(pval[0]))
                   for pval in out_pvals]

    # Avoid staging out trivial calls, but maps may involve broadcasting.
    if not jaxpr.eqns and not primitive.map_primitive:
      # With no equations, each output is a literal, a known value, or one of
      # the jaxpr's inputs/constvars; resolve it directly.
      env = {core.unitvar: core.unit}
      map(env.setdefault, jaxpr.invars, (*env_tracers, *tracers))
      map(env.setdefault, jaxpr.constvars, consts)
      return [v.val if type(v) is Literal
              else pval.get_known() if pval.is_known()
              else env[v]
              for v, pval in zip(jaxpr.outvars, out_pvals)]

    # Skip known invars and outvars, and lift constants as regular invars
    in_knowns = tuple(t.pval.is_known() for t in it.chain(env_tracers, tracers))
    out_unknowns = tuple(not pval.is_known() for pval in out_pvals)
    jaxpr = _drop_invars(jaxpr, in_knowns)
    jaxpr = _dce_open_jaxpr(jaxpr, out_unknowns, drop_outputs=True)

    # Known tracers get propagated as if they were constants
    known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals
                         if pval.is_known()]

    # Unknown tracers need to have the jaxpr set up as their recipe
    unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals
                           if not pval.is_known()]
    unknown_tracers_in = [t for t in tracers if not t.pval.is_known()]
    const_tracers = map(self.new_instantiated_const, consts)
    # NOTE(review): every env tracer is passed here, although `_drop_invars`
    # above removed the invars of env tracers that are known; this appears to
    # assume env tracers are always unknown -- confirm.
    in_tracers = (*const_tracers, *env_tracers, *unknown_tracers_in)

    # Set up new params
    new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
    if primitive.map_primitive:
      # Constants are mapped, env values are broadcast, and unknown explicit
      # inputs keep their original mapped flag.
      mapped_invars = params['mapped_invars']
      new_mapped_invars = ((True,) * len(const_tracers) +
                           (False,) * len(env_tracers) +
                           tuple(v for v, t in zip(mapped_invars, tracers)
                                 if not t.pval.is_known()))
      new_params = dict(new_params, mapped_invars=new_mapped_invars)
    update_params = call_param_updaters.get(primitive)
    if update_params:
      new_params = update_params(new_params, [not t.pval.is_known() for t in tracers])

    eqn = new_eqn_recipe(in_tracers, unknown_tracers_out, primitive, new_params,
                         source_info_util.current())
    for t in unknown_tracers_out: t.recipe = eqn
    # Re-interleave known and unknown outputs into the original output order.
    return _zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns)

  process_map = process_call

  # We use post_process_call to handle both call and map primitives.
  def post_process_call(self, primitive, out_tracers, params):
    """Handle tracers of this trace that escape a call.

    Returns `(out, todo)`. `out` holds the known parts of `out_tracers`
    followed by the jaxpr's constants. `todo` rebuilds tracers from those
    values and attaches a single eqn recipe for the residual jaxpr.
    """
    jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
    out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers)
    out = out_pv_consts + consts
    # Drop references so `todo` only closes over what it needs.
    del consts, out_pv_consts
    main = self.main

    if primitive.map_primitive:
      sz = params['axis_size']
      out_pvs = [None if pv is None else core.unmapped_aval(sz, pv)
                 for pv in out_pvs]

    def todo(x):
      # `x` is `out` after passing through the call: split it back into the
      # output constants and the jaxpr constants.
      n = len(jaxpr.outvars)
      out_pv_consts, consts = x[:n], x[n:]
      trace = JaxprTrace(main, core.cur_sublevel())
      const_tracers = map(trace.new_instantiated_const, consts)
      out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
                     for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
      in_tracers = (*const_tracers, *map(trace.full_raise, env))

      new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
      if primitive.map_primitive:
        new_mapped_invars = (True,) * len(const_tracers) + (False,) * len(env)
        new_params = dict(new_params, mapped_invars=new_mapped_invars)
      update_params = call_param_updaters.get(primitive)
      if update_params:
        new_params = update_params(new_params, [])

      eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
                           source_info_util.current())
      for t in out_tracers:
        t.recipe = eqn
      return out_tracers
    return out, todo

  post_process_map = post_process_call

  def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
                   app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]]):
    """Partially evaluate f on a sequence of PartialVals.

    `app` applies the wrapped function to the known input values (e.g. a
    primitive's `bind`). Returns `(jaxpr, out_pvals, consts, env_tracers)`.
    """
    in_avals, in_consts = unzip2(pvals)
    f = trace_to_subjaxpr(f, self.main, False)
    f, aux = partial_eval_wrapper(f, tuple(in_avals))
    # `app` must run before `aux()` is read, since it populates the aux output.
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
    # The flat output is the output constants followed by the jaxpr's consts.
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvs = map(PartialVal, zip(out_avals, out_consts))
    env_tracers = map(self.full_raise, env)
    return jaxpr, out_pvs, consts, env_tracers
@lu.transformation_with_aux
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
  """Adapt a `trace_to_subjaxpr`-wrapped function to flat positional values.

  The inputs are the known constants, paired with `pvs` to rebuild PartialVals.
  The main output is the output constants followed by the jaxpr's residual
  constants. The aux output is `(out_pvs, jaxpr, env)`.
  """
  in_pvals = map(PartialVal, zip(pvs, consts))
  jaxpr, (out_pvals, residuals, env) = yield (in_pvals,), {}
  out_pvs, out_consts = unzip2(out_pvals)
  yield (*out_consts, *residuals), (out_pvs, jaxpr, env)
# Per-primitive overrides consulted by JaxprTrace.process_primitive; a rule is
# called as rule(trace, *tracers, **params).
custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}
# Overrides for call/map primitives consulted by JaxprTrace.process_call; a
# rule is called as rule(trace, primitive, f, tracers, params).
call_partial_eval_rules: Dict[core.Primitive, Callable] = {}
# Hooks that rewrite the params of a staged-out call; called as
# updater(new_params, unknown_inputs_mask) and must return the new params.
call_param_updaters: Dict[core.Primitive, Callable] = {}
def abstract_eval_fun(fun, *avals, **params):
  """Return the output abstract values of `fun(*args, **params)`.

  `fun` is traced on the abstract input values `avals`.
  """
  wrapped = lu.wrap_init(fun, params)
  if not config.omnistaging_enabled:
    # NOTE(review): `trace_to_jaxpr` as defined in this chunk has no
    # `stage_out` parameter; presumably a legacy variant is installed when
    # omnistaging is disabled -- confirm.
    unknown_pvals = [PartialVal.unknown(aval) for aval in avals]
    _, pvals_out, _ = trace_to_jaxpr(wrapped, unknown_pvals,
                                     instantiate=True, stage_out=True)  # type: ignore
    avals_out, _ = unzip2(pvals_out)
  else:
    _, avals_out, _ = trace_to_jaxpr_dynamic(wrapped, avals)
  # instantiate=True guarantees every output is an AbstractValue.
  assert all(isinstance(aval_out, AbstractValue) for aval_out in avals_out)
  return avals_out
# The kinds of recipe a JaxprTracer may carry (besides None before one is
# assigned); tracers_to_jaxpr dispatches on exactly these.
JaxprTracerRecipe = Union['JaxprEqnRecipe', 'LambdaBinding', 'FreeVar',
                          'ConstVar', Literal, core.Unit]
class JaxprTracer(Tracer):
  """Tracer of a JaxprTrace.

  Holds a PartialVal (`pval`) and the recipe describing how the value is
  produced (`recipe`).
  """
  __slots__ = ['pval', 'recipe']

  def __init__(self, trace: JaxprTrace, pval: PartialVal,
               recipe: Optional[JaxprTracerRecipe]):
    assert isinstance(pval, PartialVal)
    const = pval[1]
    # A known constant must not be a tracer from this or a higher level.
    if isinstance(const, Tracer) and const._trace.level >= trace.level:
      raise core.escaped_tracer_error(
          f"Tracer from a higher level: {const} in trace {trace}")
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  def __repr__(self):
    return f'Traced<{self.aval}:{self._trace}>'

  @property
  def aval(self) -> AbstractValue:
    return self.pval.get_aval()

  @property
  def parents(self) -> Sequence['JaxprTracer']:
    """Input tracers of the producing equation; empty for non-equation recipes."""
    recipe = self.recipe
    return recipe.invars if isinstance(recipe, JaxprEqnRecipe) else []

  def full_lower(self):
    known = self.pval.get_known()
    return self if known is None else core.full_lower(known)

  def is_known(self):
    return self.pval.is_known()
# TODO(necula): this could return a ClosedJaxpr with out_pvals
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
                   instantiate: Union[bool, Sequence[bool]] = False,
                   ) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
  """Traces a function into a Jaxpr, given PartialVals for inputs.

  Returns (`jaxpr`, `out_pvals`, `consts`).

  - `jaxpr` contains only the computation that depends on unknown inputs.
  - `out_pvals` are the PartialVals for the outputs.
  - Intermediate values that depend only on known inputs, but are needed to
    compute the outputs of `jaxpr`, are returned in `consts` and passed in as
    the constvars of `jaxpr`.

  How known outputs are handled depends on `instantiate`, a bool or one bool
  per output. Instantiated outputs are forced into the jaxpr even when known.

  For example, given `fun` defined as follows::

    def fun(ki, ui):  # ki will be a known input in this example
      ka = ki + 2
      kb = ka + 3
      return (kb, ui + ka)

  with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
  computation that depends on unknown inputs is `ui + ka`, so it is the only
  computation in the body of `jaxpr`. It depends on the known intermediate
  value `ka`, which is computed statically. Such constants are either
  embedded in the Jaxpr as literals (scalars), or passed as a constvar of
  `jaxpr` with the actual value in `consts`.

  When `instantiate=False` we get::

    jaxpr =
      { lambda ka ; ki ui.
        let c = add ui ka
        in (*, c) }   # known outputs are `*`
    out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
    consts = [3]  # the constant for `ka`

  When `instantiate=True` we get::

    jaxpr =
      { lambda ka kb ; ki ui.
        let c = add ui ka
        in (kb, c) }   # known outputs are explicit
    out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
    consts = [3, 6]  # values for `ka` and `kb` constvars
  """
  with core.new_main(JaxprTrace) as main:
    fun = trace_to_subjaxpr(fun, main, instantiate)
    jaxpr, aux = fun.call_wrapped(pvals)
    out_pvals, consts, env = aux
    # A top-level trace has no enclosing sublevel, so there is no environment.
    assert not env
    # Release our reference to the main trace before the context exits.
    del main

  return jaxpr, out_pvals, consts
@lu.transformation
def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
                      pvals: Sequence[PartialVal]):
  """Trace the wrapped function under `main`.

  Yields `jaxpr, (out_pvals, consts, env)`. `instantiate` (a bool, or one bool
  per output) selects the outputs that are forced to be unknown, i.e. staged
  into the jaxpr even when their values are known.
  """
  assert all(isinstance(pv, PartialVal) for pv in pvals), pvals
  trace = JaxprTrace(main, core.cur_sublevel())
  in_tracers = map(trace.new_arg, pvals)
  ans = yield in_tracers, {}
  if isinstance(instantiate, bool):
    instantiate = [instantiate] * len(ans)
  out_tracers = map(trace.full_raise, map(core.full_lower, ans))
  out_tracers = [instantiate_const_at(trace, inst, tracer)
                 for inst, tracer in zip(instantiate, out_tracers)]
  jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers)
  out_pvals = [tracer.pval for tracer in out_tracers]
  # Drop references to the trace and its tracers before yielding back.
  del trace, in_tracers, out_tracers
  yield jaxpr, (out_pvals, consts, env)
def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer):
  """Raise `tracer` into `trace` and force it unknown when `instantiate` is set.

  When `instantiate` is false, `tracer` is returned untouched.
  """
  if not instantiate:
    return tracer
  raised = trace.full_raise(tracer)
  return trace.instantiate_const(raised)
# Recipes for JaxprTracers that are not produced by an equation.
# FreeVar: wraps the tracer `val` (created by JaxprTrace.sublift);
#   tracers_to_jaxpr turns it into an environment input of the jaxpr.
FreeVar = namedtuple('FreeVar', ['val'])
# ConstVar: the tracer stands for the constant `val`, which becomes a
#   constvar of the jaxpr.
ConstVar = namedtuple('ConstVar', ['val'])
# LambdaBinding: the tracer is one of the traced function's input arguments.
LambdaBinding = namedtuple('LambdaBinding', [])
class JaxprEqnRecipe(NamedTuple):
  """A pending equation producing the JaxprTracers in `outvars`.

  Converted into a core.JaxprEqn by `recipe_to_eqn`.
  """
  eqn_id: object  # fresh identity; lets a multi-output eqn be emitted once
  invars: Sequence[JaxprTracer]
  outvars: 'Sequence[ref[JaxprTracer]]'  # weak refs; dead outputs become dropvar
  primitive: core.Primitive
  params: Dict[str, Any]
  source_info: Optional[source_info_util.Traceback]
def new_eqn_recipe(invars: Sequence[JaxprTracer],
                   outvars: Sequence[JaxprTracer],
                   primitive: core.Primitive,
                   params: Dict[str, Any],
                   source_info: Optional[source_info_util.Traceback]
                   ) -> JaxprEqnRecipe:
  """Constructs a new JaxprEqnRecipe.

  Params:
    invars: the tracers for the primitive inputs.
    outvars: the tracers for the primitive outputs; the recipe keeps only
      weak references to them.
    primitive: the primitive.
    params: the primitive params.
    source_info: the source location to attach to the equation, if any.
  """
  # TODO(necula): move these checks to core.check_jaxpr, and call in more places
  if primitive.call_primitive or primitive.map_primitive:
    assert "call_jaxpr" in params
  if primitive.map_primitive:
    # Map primitives carry one mapped/donated flag per call_jaxpr input.
    assert ("mapped_invars" in params and
            len(params["mapped_invars"]) == len(params["call_jaxpr"].invars))
    assert ("donated_invars" in params and
            len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
  out_refs = map(ref, outvars)
  return JaxprEqnRecipe(eqn_id=object(), invars=tuple(invars),
                        outvars=out_refs, primitive=primitive,
                        params=params, source_info=source_info)
def recipe_to_eqn(getvar: Callable[[JaxprTracer], core.Atom],
                  recipe: JaxprEqnRecipe) -> core.JaxprEqn:
  """Materialize `recipe` as a core.JaxprEqn, naming tracers via `getvar`.

  Outputs whose tracer has been garbage-collected become `core.dropvar`.
  """
  # Dereference every output up front so they stay alive while we build vars.
  live_outs = [out_ref() for out_ref in recipe.outvars]
  invars = [getvar(tracer) for tracer in recipe.invars]
  outvars = []
  for tracer in live_outs:
    if tracer is None:
      outvars.append(core.dropvar)
    else:
      outvars.append(cast(core.Var, getvar(tracer)))
  return new_jaxpr_eqn(invars, outvars, recipe.primitive, recipe.params,
                       recipe.source_info)
def tracers_to_jaxpr(
  in_tracers: List[JaxprTracer],
  out_tracers: List[JaxprTracer]
  ) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]:
  """Constructs Jaxpr given tracers for inputs and outputs.

  Params:
    in_tracers: the tracers that were created for the function inputs
    out_tracers: the tracers that were output by the function.

  Returns: a triple of a `Jaxpr`, a tuple of constant values corresponding to
    the `constvars` in the returned Jaxpr, and a tuple of environment values.
    The vars for the environment values have been prepended to the Jaxpr's
    `invars`.
  """
  newvar = core.gensym()
  # Tracers (and constants below) are keyed by id(); they stay alive for the
  # duration of this function via sorted_tracers / consts.
  t_to_var: Dict[int, core.Atom] = {}
  def getvar(t: JaxprTracer) -> core.Atom:
    # Allocate a fresh var on first use; known tracers get a unit-typed var.
    var = t_to_var.get(id(t))
    if var is None:
      aval = t.pval.get_aval() if not t.pval.is_known() else abstract_unit
      var = t_to_var[id(t)] = newvar(aval)
    return var
  sorted_tracers = toposort(out_tracers)
  # Input vars are allocated first, so they get the first gensym names.
  invars = map(getvar, in_tracers)
  eqns: List[core.JaxprEqn] = []
  env: Dict[core.Var, Any] = {}
  consts: Dict[core.Var, Any] = {}
  const_to_var: Dict[int, core.Var] = {}
  def getconstvar(c):
    # One constvar per distinct constant object.
    var = const_to_var.get(id(c))
    if var is None:
      var = const_to_var[id(c)] = newvar(get_aval(c))
    return var
  processed_eqn_ids = set()
  for t in sorted_tracers:
    recipe = t.recipe
    if isinstance(recipe, JaxprEqnRecipe):
      # Multi-output equations share a recipe; emit each one only once.
      if recipe.eqn_id not in processed_eqn_ids:
        eqns.append(recipe_to_eqn(getvar, recipe))
        processed_eqn_ids.add(recipe.eqn_id)
    elif isinstance(recipe, LambdaBinding):
      if not any(t is in_tracer for in_tracer in in_tracers):
        raise core.escaped_tracer_error(
            "Tracer not among input tracers {}".format(t))
      # Never fails: an empty in_tracers already raised above.
      assert in_tracers, "Lambda binding with no args"
    elif isinstance(recipe, FreeVar):
      env[cast(core.Var, getvar(t))] = recipe.val
    elif isinstance(recipe, ConstVar):
      v = t_to_var[id(t)] = getconstvar(recipe.val)
      consts[v] = recipe.val
    elif isinstance(recipe, Literal):
      t_to_var[id(t)] = recipe
    elif recipe is unit:
      t_to_var[id(t)] = unitvar
    else:
      raise TypeError(recipe)

  env_vars, env_vals = unzip2(env.items())
  const_vars, const_vals = unzip2(consts.items())
  # The env_vars are pre-pended to the invars
  jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
  core.skip_checks or core.check_jaxpr(jaxpr)
  return jaxpr, const_vals, env_vals
@cache()
def convert_constvars_jaxpr(jaxpr: Jaxpr):
  """Return an equivalent Jaxpr without constvars.

  The former constvars become the leading invars.
  """
  if not core.skip_checks:
    core.check_jaxpr(jaxpr)
  lifted_invars = jaxpr.constvars + jaxpr.invars
  lifted = Jaxpr(constvars=(), invars=lifted_invars,
                 outvars=jaxpr.outvars, eqns=jaxpr.eqns)
  if not core.skip_checks:
    core.check_jaxpr(lifted)
  return lifted
def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, AbstractValue]:
  """Return `(known_side, unknown_side)` avals; the unused side is a unit."""
  if unknown:
    return abstract_unit, aval
  return aval, abstract_unit
def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
                       instantiate: Union[bool, Sequence[bool]],
                       ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
  """Specializes a Jaxpr given an indication of which inputs are known.

  Returns: (jaxpr_known, jaxpr_unknown, out_unknowns).

  `out_unknowns` specifies which outputs are unknown (depend on some unknown inputs).
  `jaxpr_known` takes the same inputs as `jaxpr`, ignores the unknown inputs,
  and performs *all* the computation in `jaxpr` that depends only on the known inputs.
  Outputs correspond to those of `jaxpr`, with the unknown ones replaced with `*`,
  appended with the known residuals (the intermediate computations in `jaxpr`
  that depend only on known inputs and that are needed to compute the unknown outputs).

  `jaxpr_unknown` takes the same inputs as `jaxpr` along with the known residuals
  computed by `jaxpr_known` and returns the same outputs as `jaxpr` with the known
  outputs replaced by `*`.

  Roughly, `jaxpr(ki, ui)` is decomposed assuming `ki` and `ui` are the known and respectively
  unknown inputs into:

    jaxpr(ki, ui) = let kout, _, kresidual = jaxpr_known(ki, *)
                    let _, uout = jaxpr_unknown(ki, ui, kresidual)
                    in (kout, uout)

  For example, if `jaxpr` is lambda ki, ui: let ka = ki + 2
                                            in (ki + 3, ui + ka)
  then
    `jaxpr_known` = lambda ki, ui: let ka = ki + 2
                                   in (ki + 3, *, ka)
    `jaxpr_unknown` = lambda ki, ui, ka: (*, ui + ka)

  Note that if instantiate is True for a given output, then jaxpr_known always returns a
  unit in its place. So when instantiate is True, the expectation is that one doesn't
  run `jaxpr_known` for any of its outputs, but only to generate residuals that will allow
  one to obtain the full outputs once `jaxpr_unknown` is run. Outputs known ahead of time will
  simply get passed as residual constants and returned immediately.
  """
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

  # `fun` is traced (below) to become jaxpr_known; as a side effect it records
  # the jaxpr for the unknown part, which is returned through `cell`.
  cell = []
  def fun(*vals):
    pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
            for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
    jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
    out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
    cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
    # Known outputs followed by the residuals.
    return out_consts_2 + consts_2

  # For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the
  # known inputs.
  in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)]
  jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
  # `fun` must have been traced exactly once.
  (out_pvs_2, jaxpr_2, num_res), = cell
  assert len(jaxpr_2.constvars) == num_res

  #   jaxpr :: a -> b
  # jaxpr_1 :: a1 -> [b1, res]
  # jaxpr_2 :: res | a2 -> b2
  # jaxpr_2 :: [a2, res] -> b2
  jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
  # NOTE(review): this mutates, in place, the object returned by the @cache'd
  # `convert_constvars_jaxpr` (and the avals of its vars below). It is harmless
  # only as long as jaxpr_2 is a freshly traced object -- confirm.
  jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
  # Known inputs are not used by jaxpr_unknown; mark them as units.
  for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
    if not unknown:
      var.aval = abstract_unit

  uk_out = [pv is not None for pv in out_pvs_2]

  # NOTE(review): in_avals_1, in_avals_2, out_avals_1 and out_avals_2 are
  # computed but never used; only the residual-count assertion has an effect.
  in_avals_1, in_avals_2 = unzip2(map(_split_aval, unknowns, jaxpr.in_avals))
  out_avals_1, out_avals_2 = unzip2(map(_split_aval, uk_out, jaxpr.out_avals))
  # out_avals_1 and in_avals_2 need the residuals added
  res_avals = out_avals[len(jaxpr.out_avals):]
  assert len(res_avals) == num_res
  out_avals_1 = [*out_avals_1, *res_avals]
  in_avals_2 = [*in_avals_2, *res_avals]

  return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out
# Call primitive used for rematerialization (checkpointing). Its partial-eval
# rule is `_remat_partial_eval`, registered below.
remat_call_p = core.CallPrimitive('remat_call')
remat_call = remat_call_p.bind
# Evaluated eagerly like an ordinary call.
remat_call_p.def_impl(core.call_impl)
def _remat_partial_eval(trace, _, f, tracers, params):
  """Partial-eval rule for `remat_call_p` (see `call_partial_eval_rules`).

  Unlike the default call rule, the staged-out eqn for the unknown outputs
  contains the whole function, so known intermediates are recomputed rather
  than saved as residuals.
  """
  concrete = params['concrete']

  # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
  # the function being called, not just for the unknown parts. To do that, we
  # instantiate all the input tracers as constants in the jaxpr being formed.
  # Those tracers might have concrete avals, and doing abstract interpretation
  # on concrete avals engenders a tradeoff: it allows data-dependent Python
  # control flow to work, but it can in some cases lead to redundant FLOPs (done
  # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
  # `concrete` parameter to switch this behavior, and if `concrete` is False
  # then we raise the avals to the Shaped level.
  if concrete:
    instantiated_tracers = map(trace.instantiate_const, tracers)
  else:
    instantiated_tracers = map(trace.instantiate_const_abstracted, tracers)

  # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
  in_pvals = [t.pval for t in instantiated_tracers]
  if config.omnistaging_enabled:
    jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
      f, in_pvals, partial(remat_call_p.bind, **params))
  else:
    with core.initial_style_staging():  # type: ignore
      jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
        f, in_pvals, partial(remat_call_p.bind, **params))

  # Convert consts to inputs, since they may contain Tracer instances.
  jaxpr = convert_constvars_jaxpr(jaxpr)
  const_tracers = map(trace.new_instantiated_const, consts)

  # Since we traced with everything marked as unknown, but we need to know which
  # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
  closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
  # Input order matches convert_constvars_jaxpr: consts, then env, then args.
  in_unknowns = ([False] * len(consts) +
                 [not t.is_known() for t in it.chain(env_tracers, tracers)])
  if config.omnistaging_enabled:
    jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
        closed_jaxpr, in_unknowns, instantiate=False)  # type: ignore
  else:
    # NOTE(review): partial_eval_jaxpr as defined in this chunk takes no
    # `trace_type` argument; presumably a legacy variant is installed when
    # omnistaging is disabled -- confirm.
    jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
        closed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type)  # type: ignore
  out_knowns = [not b for b in out_unknowns]
  out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)

  # Next, we need values for the outputs that should be known. Since consts
  # weren't passed through Python for evaluation, we need to evaluate jaxpr_known,
  # minus the residual outputs that we don't need. When `concrete=True`, as an
  # optimization we can avoid redoing *some* redundant FLOPs, namely those that
  # produced concrete avals at the output, simply by using those as computed
  # values. For the use case of inverse-mode ad in op-by-op ("eager mode")
  # evaluation, all the primal outputs should be concrete (thus not recomputed).
  to_compute = [type(pval[0]) is not ConcreteArray
                for uk, pval in zip(out_unknowns, eval_out_pvals) if not uk]
  num_outputs = len(jaxpr_unknown.out_avals)
  num_res = len(jaxpr_known.out_avals) - num_outputs
  # Keep only the known primal outputs (drop residuals), then replace with
  # units those outputs whose concrete value is already available.
  jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res, drop_outputs=True)
  jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
  _, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
  reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
  out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts)

  # Known outputs should keep propagating as constants
  assert all(pv.is_known() for pv in out_known_pvals)
  known_output_tracers = [trace.new_const(pval.get_known())
                          for pval in out_known_pvals]
  # Unknown outputs get wrapped in tracers with the appropriate recipe
  unknown_output_tracers = [JaxprTracer(trace, out_pval, None)
                            for out_pval in out_unknown_pvals]

  # dce jaxpr outputs
  new_jaxpr = _dce_jaxpr(closed_jaxpr, out_unknowns, drop_outputs=True).jaxpr
  new_params = dict(params, call_jaxpr=new_jaxpr)

  # set up eqn for unknown outputs
  in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
  eqn = new_eqn_recipe(in_tracers, unknown_output_tracers, remat_call_p, new_params,
                       source_info_util.current())
  for t in unknown_output_tracers: t.recipe = eqn
  return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
# JaxprTrace.process_call routes remat_call through this rule.
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
def _partition_knowns(pvals, unknowns: Sequence[bool]):
return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
[e for e, unknown in zip(pvals, unknowns) if unknown])
def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
known_iter, unknown_iter = iter(known_list), iter(unknown_list)
return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]
def _dce_jaxpr(closed_jaxpr: ClosedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> ClosedJaxpr:
  """Dead-code-eliminate a ClosedJaxpr, keeping its consts.

  See `_dce_open_jaxpr` for the meaning of `outputs` and `drop_outputs`.
  """
  inner = closed_jaxpr.jaxpr
  pruned = _dce_open_jaxpr(inner, tuple(outputs), drop_outputs)
  return core.ClosedJaxpr(pruned, closed_jaxpr.consts)
@cache()
def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr:
  """Remove equations of ``jaxpr`` that do not feed any selected output.

  Args:
    jaxpr: the open jaxpr to prune.
    outputs: one flag per outvar; true means the output is needed.
    drop_outputs: if true, unneeded outvars are removed entirely; otherwise
      they are replaced by ``unitvar`` so the output arity is unchanged.

  Returns:
    A new Jaxpr with the same constvars/invars and the surviving equations in
    their original order.
  """
  # This DCE is rudimentary: higher-order primitives (scan, call, ...) are
  # kept or dropped as a whole, never pruned internally.
  # TODO(mattjj): better DCE
  paired = list(zip(jaxpr.outvars, outputs))
  if drop_outputs:
    kept_outvars = [v for v, wanted in paired if wanted]
  else:
    kept_outvars = [v if wanted else unitvar for v, wanted in paired]
  live = {v for v in kept_outvars if type(v) is not Literal}
  kept_eqns = []
  # Walk backwards so an equation is kept iff a later live use needs it.
  for eqn in reversed(jaxpr.eqns):
    if live.isdisjoint(eqn.outvars):
      continue
    kept_eqns.append(eqn)
    live.update(v for v in eqn.invars if type(v) is not Literal)
  kept_eqns.reverse()
  return core.Jaxpr(jaxpr.constvars, jaxpr.invars, kept_outvars, kept_eqns)
@cache()
def _drop_invars(jaxpr: Jaxpr, drop: Tuple[bool, ...]):
  """Return a copy of ``jaxpr`` without the invars whose ``drop`` flag is set.

  Constvars, outvars and equations are reused unchanged.
  """
  remaining = [v for v, dropped in zip(jaxpr.invars, drop) if not dropped]
  return core.Jaxpr(jaxpr.constvars, remaining, jaxpr.outvars, jaxpr.eqns)
def _reconstruct_pval(pval1: PartialVal, const2: core.Value):
  """Turn ``pval1`` into a known PartialVal.

  Already-known values are returned as-is. An unknown value whose abstract
  part is a ConcreteArray uses that array's value; any other unknown value is
  replaced by the computed constant ``const2``.
  """
  abstract_part, _ = pval1
  if pval1.is_known():
    return pval1
  if type(abstract_part) is ConcreteArray:
    return PartialVal.known(abstract_part.val)  # pytype: disable=attribute-error
  return PartialVal.known(const2)
def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr:
  """Reorder the `invars` to move to front the ones for which `to_move` is True.

  Requires a jaxpr without constvars and one flag per input. Relative order
  within the moved and the unmoved groups is preserved.
  """
  inner = closed_jaxpr.jaxpr
  assert not inner.constvars
  assert len(closed_jaxpr.in_avals) == len(to_move)
  reordered_invars = _move_to_front(inner.invars, to_move)
  reordered = core.Jaxpr((), reordered_invars, inner.outvars, inner.eqns)
  return core.ClosedJaxpr(reordered, closed_jaxpr.consts)
def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence:
return ([elt for elt, move in zip(lst, to_move) if move] +
[elt for elt, move in zip(lst, to_move) if not move])
class DynamicJaxprTracer(core.Tracer):
  """Tracer for a value staged out by DynamicJaxprTrace.

  Only its abstract value (``aval``) is stored on the tracer; the equation that
  produces it is recorded in the owning trace's current JaxprStackFrame.
  """
  __slots__ = ['aval', 'line_info']

  def __init__(self, trace, aval, line_info=None):
    self._trace = trace
    self.aval = aval
    # Source info captured at creation; used in the escaped-tracer message.
    self.line_info = line_info

  def full_lower(self):
    return self

  def _contents(self):
    return ()

  def _origin_msg(self):
    """Explain, for error messages, why this value is a tracer.

    Reports either the flattened argument positions it depends on, or the
    operations (consuming constants) that produced it, or a generic fallback.
    """
    progenitors = self._trace.frame.find_progenitors(self)
    # find_progenitors may return a bare empty list when this tracer has no
    # variable in the current frame; unpacking that into two names would raise
    # ValueError while building an error message, so treat it as "no info".
    invar_pos, progenitor_eqns = progenitors if progenitors else ([], [])
    fun_src = self._trace.main.source_info
    if invar_pos:
      origin = (f"While tracing the function {fun_src}, "
                "this concrete value was not available in Python because it "
                "depends on the value of the arguments to "
                f"{fun_src} at flattened positions {invar_pos}, "
                "and the computation of these values is being staged out "
                "(that is, delayed rather than executed eagerly).\n\n"
                "You can use transformation parameters such as `static_argnums` "
                "for `jit` to avoid tracing particular arguments of transformed "
                "functions, though at the cost of more recompiles.")
    elif progenitor_eqns:
      msgs = [f" operation {core.pp_eqn(eqn, print_shapes=True)}\n"
              f" from line {source_info_util.summarize(eqn.source_info)}"
              for eqn in progenitor_eqns]
      origin = (f"While tracing the function {fun_src}, "
                "this value became a tracer due to JAX operations on these lines:"
                "\n\n" + "\n\n".join(msgs))
    else:
      origin = ("The error occurred while tracing the function "
                f"{fun_src}.")
    return origin

  def _assert_live(self) -> None:
    """Raise an escaped-tracer error if the owning main has no active frame."""
    if not self._trace.main.jaxpr_stack:  # type: ignore
      msg = f"tracer created on line {source_info_util.summarize(self.line_info)}"
      raise core.escaped_tracer_error(msg)
class JaxprStackFrame:
  """Mutable state of one jaxpr being built by DynamicJaxprTrace.

  Attributes (all initialized in ``__init__``):
    newvar: fresh-variable generator (``core.gensym()``).
    tracer_to_var: maps ``id(tracer)`` to the jaxpr variable bound to it.
    constid_to_var: maps ``id(const)`` to the constvar created for that object.
    constvar_to_val: maps each constvar to its concrete value.
    tracers: tracers registered with this frame (presumably kept here so the
      ids used as dict keys stay valid while the frame is alive).
    eqns: equations recorded so far, in program order.
    invars: input binders, appended by ``DynamicJaxprTrace.new_arg``.
  """
  __slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
               'tracers', 'eqns', 'invars']

  def __init__(self):
    self.newvar = core.gensym()
    self.tracer_to_var = {}
    self.constid_to_var = {}
    self.constvar_to_val = {}
    self.tracers = []  # circ refs, frame->tracer->trace->main->frame,
    self.eqns = []  # cleared when we pop frame from main
    self.invars = []

  def to_jaxpr(self, in_tracers, out_tracers):
    """Build ``(jaxpr, out_avals, constvals)`` from the recorded equations.

    Inputs/outputs are the vars bound to ``in_tracers``/``out_tracers``
    (KeyError if a tracer has no var in this frame). Scalar constants are
    inlined as literals by ``_inline_literals``.
    """
    invars = [self.tracer_to_var[id(t)] for t in in_tracers]
    outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
    constvars, constvals = unzip2(self.constvar_to_val.items())
    jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
    jaxpr, constvals = _inline_literals(jaxpr, constvals)
    # core.skip_checks or core.check_jaxpr(jaxpr)
    out_avals = [t.aval for t in out_tracers]
    return jaxpr, out_avals, constvals

  def find_progenitors(self, tracer):
    """Find what ``tracer``'s value transitively depends on in this frame.

    Returns:
      A pair ``(invar_positions, const_eqns)``: the positions in
      ``self.invars`` of inputs it depends on, and every recorded equation
      that consumes one of the constvars it depends on. Both are empty lists
      when ``tracer`` has no variable in this frame.
    """
    var = self.tracer_to_var.get(id(tracer))
    if var is None:
      # Always return a pair: callers unpack the result into two names, so a
      # bare [] would raise ValueError there.
      return [], []
    active_vars = {var}
    for eqn in reversed(self.eqns):
      produced = set(eqn.outvars) & active_vars
      if produced:
        active_vars.difference_update(produced)
        active_vars.update(eqn.invars)
    invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars]
    constvars = active_vars & set(self.constvar_to_val)
    const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
    return invar_positions, const_eqns
def _inline_literals(jaxpr, constvals):
  """Inline scalar constants as Literals and renumber the jaxpr's variables.

  Constvars whose value has a type in ``core.literalable_types`` and an empty
  ``np.shape`` are replaced at every use by a ``Literal`` and removed from the
  constvar list. All remaining variables are renamed with a fresh ``gensym``,
  and equation outputs that are never read (by an equation or as an outvar)
  become ``dropvar``.

  Returns:
    ``(new_jaxpr, new_constvals)`` where ``new_constvals`` lines up with the
    new jaxpr's constvars.
  """
  consts = dict(zip(jaxpr.constvars, constvals))
  newvar = core.gensym()
  # dict subclass that, on first lookup of an old variable, allocates (and
  # memoizes) a fresh variable with the same aval. The order of first lookups
  # therefore determines the new variable numbering.
  class var(dict):
    def __missing__(self, v):
      new_v = self[v] = newvar(v.aval)
      return new_v
  var = var()
  def lit(var: core.Var) -> Optional[Any]:
    # A Literal if `var` is a constvar holding a literalable scalar; None for
    # anything else (including vars that are not constvars at all).
    val = consts.get(var)
    if type(val) in core.literalable_types and not np.shape(val):
      return Literal(val)
    else:
      return None
  # Every variable that is read somewhere; eqn outputs outside this set are dropped.
  used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars)
  new_constvars = [var[v] for v in jaxpr.constvars if not lit(v)]
  new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) if not lit(v)]
  new_invars = [var[v] for v in jaxpr.invars]
  new_eqns = [new_jaxpr_eqn([lit(v) or var[v] for v in eqn.invars],
                            [var[v] if v in used else dropvar for v in eqn.outvars],
                            eqn.primitive, eqn.params, eqn.source_info)
              for eqn in jaxpr.eqns]
  new_outvars = [lit(v) or var[v] for v in jaxpr.outvars]
  new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
  return new_jaxpr, new_constvals
class DynamicJaxprTrace(core.Trace):
__slots__ = [] # type: ignore
@property
def frame(self):
  """The innermost JaxprStackFrame, i.e. the top of ``main.jaxpr_stack``."""
  stack = self.main.jaxpr_stack  # pytype: disable=attribute-error
  return stack[-1]
def new_arg(self, aval):
  """Create a tracer for a new input binder of the current frame."""
  tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
  current = self.frame
  current.tracers.append(tracer)
  binder = current.newvar(aval)
  current.tracer_to_var[id(tracer)] = binder
  current.invars.append(binder)
  return tracer
def new_const(self, val):
  """Stage ``val`` as a constant and return a tracer bound to its constvar.

  Python scalars get a weakly-typed aval.
  """
  concrete_aval = get_aval(val)
  aval = raise_to_shaped(concrete_aval, weak_type=dtypes.is_python_scalar(val))
  tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
  current = self.frame
  current.tracers.append(tracer)
  constvar = self.getconstvar(val)
  current.tracer_to_var[id(tracer)] = constvar
  current.constvar_to_val[constvar] = val
  return tracer
# The core.Trace hooks pure, lift and sublift are all implemented by staging
# the incoming value as a new constant of the current frame.
pure = lift = sublift = new_const
def getvar(self, tracer):
  """Return the jaxpr variable for ``tracer``, allocating one on first use."""
  current = self.frame
  key = id(tracer)
  existing = current.tracer_to_var.get(key)
  if existing is not None:
    return existing
  current.tracers.append(tracer)
  fresh = current.tracer_to_var[key] = current.newvar(tracer.aval)
  return fresh
def getconstvar(self, c):
  """Return the constvar for object ``c`` (keyed by identity), creating it once."""
  by_id = self.frame.constid_to_var
  found = by_id.get(id(c))
  if found is None:
    found = by_id[id(c)] = self.frame.newvar(get_aval(c))
  return found
def instantiate_const(self, val):
  """Return ``val`` if it is already a tracer of this trace, else stage it as a const."""
  belongs_here = (isinstance(val, Tracer)
                  and val._trace.main is self.main
                  and val._trace.sublevel == self.sublevel)
  return val if belongs_here else self.new_const(val)
def process_primitive(self, primitive, tracers, params):
  """Abstractly evaluate ``primitive`` and record it as one equation.

  Returns a list of output tracers for multiple-result primitives, otherwise
  the single output tracer.
  """
  in_avals = [t.aval for t in tracers]
  result_avals = primitive.abstract_eval(*in_avals, **params)
  if not primitive.multiple_results:
    result_avals = [result_avals]
  src = source_info_util.current()
  outs = [DynamicJaxprTracer(self, a, src) for a in result_avals]
  # Input vars are resolved before output vars, matching allocation order.
  in_vars = [self.getvar(t) for t in tracers]
  out_vars = [self.getvar(t) for t in outs]
  self.frame.eqns.append(new_jaxpr_eqn(in_vars, out_vars, primitive, params, src))
  if primitive.multiple_results:
    return outs
  return outs[0]
def process_call(self, call_primitive, f, tracers, params):
  """Stage out a call primitive as one equation holding a traced sub-jaxpr.

  ``f`` is traced on the avals of ``tracers``. If the resulting jaxpr has no
  equations it is evaluated inline instead of staging a call. Otherwise one
  equation is appended to the current frame whose inputs are the sub-jaxpr's
  constants followed by the vars of ``tracers``; the output tracers are returned.
  """
  in_avals = [t.aval for t in tracers]
  jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
  if not jaxpr.eqns:
    # Nothing worth staging: evaluate the trivial jaxpr directly on the tracers.
    return core.eval_jaxpr(jaxpr, consts, *tracers)
  source_info = source_info_util.current()
  out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
  invars = map(self.getvar, tracers)
  outvars = map(self.getvar, out_tracers)
  # The sub-jaxpr's consts become leading inputs of the call equation;
  # convert_constvars_jaxpr presumably turns its constvars into matching
  # leading invars.
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
  update_params = call_param_updaters.get(call_primitive)
  if update_params:
    # NOTE(review): flags cover only `tracers`, not the prepended consts;
    # presumably the updater pads for those -- confirm.
    new_params = update_params(new_params, [True] * len(tracers))
  eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive,
                      new_params, source_info)
  self.frame.eqns.append(eqn)
  return out_tracers
def post_process_call(self, call_primitive, out_tracers, params):
  """Never called for DynamicJaxprTrace.

  Raises:
    AssertionError: always. This is an explicit raise rather than
      ``assert False`` so the guard survives ``python -O``.
  """
  raise AssertionError("unreachable")
def process_map(self, map_primitive, f, tracers, params):
  """Stage out a map primitive as one equation holding a traced sub-jaxpr.

  ``f`` is traced on per-shard avals: inputs flagged in
  ``params['mapped_invars']`` are reduced with ``core.mapped_aval`` over
  ``params['axis_size']``, inside an axis env for ``params['axis_name']``.
  The traced outputs are re-expanded with ``core.unmapped_aval``. Constants of
  the sub-jaxpr are prepended as unmapped inputs of the equation.
  Returns the list of output tracers.
  """
  in_avals = [t.aval for t in tracers]
  axis_name, axis_size = params['axis_name'], params['axis_size']
  reduced_in_avals = [core.mapped_aval(axis_size, a) if m else a
                      for m, a in zip(params['mapped_invars'], in_avals)]
  with core.extend_axis_env(axis_name, axis_size, None):  # type: ignore
    jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
        f, self.main, reduced_in_avals)
  out_avals = [core.unmapped_aval(params['axis_size'], a) for a in reduced_out_avals]
  source_info = source_info_util.current()
  out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
  invars = map(self.getvar, tracers)
  outvars = map(self.getvar, out_tracers)
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  # Prepended consts are never mapped; assumes mapped_invars is a tuple.
  new_mapped_invars = (False,) * len(consts) + params['mapped_invars']
  new_params = dict(params, mapped_invars=new_mapped_invars,
                    call_jaxpr=convert_constvars_jaxpr(jaxpr))
  update_params = call_param_updaters.get(map_primitive)
  if update_params:
    # NOTE(review): flags cover only `tracers`, not the prepended consts;
    # presumably the updater pads for those -- confirm.
    new_params = update_params(new_params, [True] * len(tracers))
  eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
                      new_params, source_info)
  self.frame.eqns.append(eqn)
  return out_tracers
def post_process_map(self, map_primitive, out_tracers, params):