-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
partial_eval.py
1932 lines (1696 loc) · 85.6 KB
/
partial_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
import contextlib
import functools
from functools import partial
import inspect
import itertools as it
import operator as op
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
List, Union, Set, cast)
from weakref import ref
import numpy as np
from jax import core
from jax._src import dtypes
from jax import linear_util as lu
from jax._src import profiler
from jax._src.ad_util import Zero
from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
tree_leaves)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache)
from jax.core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
ConcreteArray, raise_to_shaped, Var, Atom, JaxprEqn,
Primitive, DShapedArray, mapped_aval, unmapped_aval)
from jax._src import source_info_util
from jax.config import config
# For the rest of this module, the builtin map/zip names refer to
# jax._src.util's safe_map/safe_zip.
map, zip = safe_map, safe_zip
def identity(x):
  """Return the argument unchanged."""
  result = x
  return result
class PartialVal(tuple):
  """Partial value: either a known value or an unknown (abstract) value.

  Represented as a pair `(aval_opt, const)` of one of two kinds:
  * `(None, <Constant>)` indicates a known value, either a regular Python
    value or a Tracer.
  * `(<AbstractValue>, *)` indicates an unknown value characterized by an
    abstract value. When `jax_enable_checks` is on, the second component is
    required to have abstract value `core.abstract_unit` (see `unknown`, which
    stores `core.unit` there).
  """

  def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
    pv, const = xs
    if config.jax_enable_checks:
      # type checks
      assert isinstance(pv, (AbstractValue, type(None))), xs
      assert isinstance(const, core.Tracer) or type(const) is Zero or core.valid_jaxtype(const), xs
      # invariant checks
      if isinstance(pv, AbstractValue):
        assert get_aval(const) == core.abstract_unit, xs
    return tuple.__new__(cls, xs)

  @classmethod
  def known(cls, const: core.Value) -> 'PartialVal':
    """Build a PartialVal for the known value `const`."""
    return PartialVal((None, const))

  @classmethod
  def unknown(cls, aval: AbstractValue) -> 'PartialVal':
    """Build a PartialVal for an unknown value described by `aval`."""
    return PartialVal((aval, core.unit))

  def is_known(self) -> bool:
    """Whether this PartialVal holds a known value."""
    return self[0] is None

  def get_known(self) -> Optional[core.Value]:
    """Get the known value, if known, else None."""
    return self[1] if self[0] is None else None

  def get_aval(self) -> AbstractValue:
    """Get AbstractValue directly (if unknown) or from the constant (known)."""
    known = self.get_known()
    if known is not None:
      return get_aval(known)
    else:
      return self[0]

  def merge_with_known(self, val: core.Value) -> core.Value:
    """Either the stored known value, or the given 'val'."""
    known = self.get_known()
    return known if known is not None else val
class JaxprTrace(Trace):
  """Trace that partially evaluates a computation.

  A primitive whose inputs are all known is executed immediately (see
  `default_process_primitive`). Anything that depends on an unknown input is
  recorded as a JaxprEqnRecipe attached to unknown-valued JaxprTracers, which
  `tracers_to_jaxpr` later turns into jaxpr equations.
  """

  def __init__(self, *args, name_stack: source_info_util.NameStack):
    super().__init__(*args)
    # Name stack in effect when this trace was created. Its length is sliced
    # off the current name stack by `_current_truncated_name_stack`.
    self.name_stack = name_stack

  def pure(self, val) -> 'JaxprTracer':
    # Wrap `val` as a known constant of this trace.
    return self.new_const(val)

  def lift(self, val) -> 'JaxprTracer':
    # Wrap `val` as a known constant of this trace.
    return self.new_const(val)

  def sublift(self, val) -> 'JaxprTracer':
    # Re-wrap tracer `val` (its pval is reused) with a FreeVar recipe.
    # tracers_to_jaxpr maps FreeVar recipes to environment values.
    return JaxprTracer(self, val.pval, FreeVar(val))

  def new_const(self, val) -> 'JaxprTracer':
    """Wrap `val` as a known-valued tracer of this trace (recipe `unit`)."""
    # Reject tracers that already live at this trace's level.
    # NOTE(review): this raises a bare `Exception` with no message.
    if isinstance(val, Tracer) and val._trace.level == self.level:
      raise Exception
    return JaxprTracer(self, PartialVal.known(val), unit)

  def new_instantiated_literal(self, val) -> 'JaxprTracer':
    """Unknown-valued tracer whose recipe is a Literal holding `val`."""
    aval = get_aval(val)
    return JaxprTracer(self, PartialVal.unknown(aval),
                       Literal(val, raise_to_shaped(aval)))

  def new_instantiated_const(self, val) -> 'JaxprTracer':
    """Unknown-valued tracer whose recipe is a ConstVar holding `val`.

    tracers_to_jaxpr turns ConstVar recipes into jaxpr constvars.
    """
    return JaxprTracer(self, PartialVal.unknown(get_aval(val)), ConstVar(val))

  def new_arg(self, pval: PartialVal) -> 'JaxprTracer':
    """Tracer for a function input.

    An unknown input gets a LambdaBinding recipe. A known input becomes a
    known constant.
    """
    const = pval.get_known()
    # XXX: Think twice before changing this constant argument pruning!
    # This has really important consequences for partial_eval_jaxpr.
    # Most importantly, this guarantees that the unknown jaxpr never uses
    # known inputs (if it needs them, then they get passed through residuals).
    if const is None:
      return JaxprTracer(self, pval, LambdaBinding())
    else:
      return self.new_const(const)

  def instantiate_const(self, tracer) -> Tracer:
    """Force `tracer` to be unknown-valued; return it as-is if it already is.

    A known scalar of a literalable type becomes a Literal. Any other known
    value becomes a ConstVar.
    """
    const = tracer.pval.get_known()
    if const is None:
      return tracer
    else:
      if type(const) in core.literalable_types and np.shape(const) == ():
        return self.new_instantiated_literal(const)
      else:
        return self.new_instantiated_const(const)

  def instantiate_const_abstracted(self, tracer) -> 'JaxprTracer':
    """Like `instantiate_const`, but always uses a ConstVar with a shaped aval."""
    const = tracer.pval.get_known()
    if const is None:
      return tracer
    else:
      # NOTE(review): the second argument to raise_to_shaped is presumably
      # weak_type (True for scalars); confirm against core.raise_to_shaped.
      aval = raise_to_shaped(get_aval(const), np.isscalar(const))
      return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const))

  def process_primitive(self, primitive, tracers, params):
    # A primitive with a registered custom partial-eval rule uses that rule
    # instead of the default behavior.
    if primitive in custom_partial_eval_rules:
      return custom_partial_eval_rules[primitive](self, *tracers, **params)
    else:
      return self.default_process_primitive(primitive, tracers, params)

  def default_process_primitive(self, primitive, tracers, params):
    """By default, if all the input tracers are known, then execute the primitive
    and all the outputs are known. Otherwise, all the outputs are unknown."""
    consts = [t.pval.get_known() for t in tracers]
    if all(c is not None for c in consts):
      return primitive.bind(*consts, **params)
    tracers = map(self.instantiate_const, tracers)
    avals = [t.aval for t in tracers]
    out_aval = primitive.abstract_eval(*avals, **params)
    name_stack = self._current_truncated_name_stack()
    source = source_info_util.current().replace(name_stack=name_stack)
    # The recipe needs the output tracers, so they are created with
    # recipe=None and the recipe is attached afterwards.
    if primitive.multiple_results:
      out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
                     for aval in out_aval]
      eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, source)
      for t in out_tracers: t.recipe = eqn
      return out_tracers
    else:
      out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
      out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
                                         params, source)
      return out_tracer

  def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
    """Partially evaluate a call primitive.

    The known part is bound now. The unknown part is staged out as a new
    equation.
    """
    if primitive in call_partial_eval_rules:
      return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)

    update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p)
    in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers])

    # We want to partially evaluate this call into two calls: one evaluated now
    # taking known values (in_consts) as inputs and producing known values
    # (out_consts) as outputs, and the other staged out as an eqn into the jaxpr
    # being built. The latter takes as input residuals (res) produced as outputs
    # of the first call, shared closed-over values (env), and explicit arguments
    # which were unknown to the first call (corresponding to in_avals).

    # Wrap f to perform the partial evaluation and plumb out aux data.
    f = trace_to_subjaxpr_nounits(f, self.main, False)
    f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
    # Adjust parameters (e.g. donated_invars) for the call to be evaluated now.
    const_params = update_params(params, in_knowns, 0)

    # Run the call, getting known out vals and aux data used for staged-out call
    out = primitive.bind(f, *in_consts, **const_params)
    out_knowns, out_avals, jaxpr, env = aux()
    # Split apart known outputs from the original call and residuals.
    out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])

    # Create the input tracers for the staged-out (unknown-value) call.
    const_tracers = map(self.new_instantiated_const, res)
    env_tracers = map(self.full_raise, env)
    unknown_arg_tracers = [t for t in tracers if not t.is_known()]
    # Adjust parameters (e.g. donated_invars) for the staged-out call's args.
    num_new_args = len(const_tracers) + len(env_tracers)
    staged_params = update_params(params, map(op.not_, in_knowns), num_new_args)
    staged_params = dict(staged_params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
    # The outputs of the staged-out call are Tracers with the new eqn as recipe.
    out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
                   for a in out_avals]
    name_stack = self._current_truncated_name_stack()
    source = source_info_util.current().replace(name_stack=name_stack)
    eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
                         out_tracers, primitive, staged_params, source)
    for t in out_tracers: t.recipe = eqn
    return merge_lists(out_knowns, out_tracers, out_consts)

  def process_map(self, primitive, f: lu.WrappedFun, tracers, params):
    """Partially evaluate a map primitive.

    Works like `process_call`, with axis bookkeeping added (see the comment
    below).
    """
    update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p)
    in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers])

    # This method is like process_call above, except:
    #   1. we delete an axis from mapped-over input avals' shapes, and
    #      analogously add an axis to mapped-over output avals' shapes;
    #   2. we update the in_axes and out_axes/out_axes_thunk parameters to
    #      reflect the inputs and outputs pruned from the unknown/known sides.

    # Map (delete an axis from) unknown inputs' avals as dictated by in_axes.
    unk_in_axes, const_in_axes = partition_list(in_knowns, params['in_axes'])
    in_avals_mapped = [mapped_aval(params['axis_size'], ax, aval)
                       for ax, aval in zip(unk_in_axes, in_avals)]

    # Wrap f to perform partial evaluation and plumb out aux data.
    f = trace_to_subjaxpr_nounits(f, self.main, False)
    f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns),
                                          tuple(in_avals_mapped))
    # Adjust params for knowns (e.g. donated_invars, in_axes, out_axes_thunk)
    const_params = update_params(params, in_knowns, 0)  # handles donated_invars
    out_axes_thunk = params['out_axes_thunk']
    @as_hashable_function(closure=out_axes_thunk)
    def const_out_axes_thunk():
      # Out axes of the known-side call: the known outputs' axes, then axis 0
      # for every residual.
      out_knowns, _, jaxpr, _ = aux()
      _, out_axes = partition_list(out_knowns, out_axes_thunk())
      return tuple(out_axes) + (0,) * len(jaxpr.constvars)  # res mapped axis 0
    const_params = dict(const_params, in_axes=tuple(const_in_axes),
                        out_axes_thunk=const_out_axes_thunk)

    # Run the map, getting known out vals and aux data used for staged-out map.
    with core.extend_axis_env(params['axis_name'], params['axis_size'], None):
      out = primitive.bind(f, *in_consts, **const_params)
    out_knowns, out_avals_mapped, jaxpr, env = aux()
    # Split apart known outputs from the original call and residuals.
    out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])

    # We can only check_jaxpr with the dynamic axis environment extended:
    with core.extend_axis_env(params['axis_name'], params['axis_size'], None):
      call_jaxpr = convert_constvars_jaxpr(jaxpr)

    # Compute staged and const out_axes, taking into account residuals.
    out_axes = params['out_axes_thunk']()
    staged_out_axes, _ = partition_list(out_knowns, out_axes)
    # Staged inputs: residuals mapped on axis 0, env values unmapped (None),
    # then the original unknown arguments with their original in_axes.
    staged_in_axes = (0,) * len(res) + (None,) * len(env) + (*unk_in_axes,)

    # Create the input tracers for the staged-out (unknown-value) call.
    const_tracers = map(self.new_instantiated_const, res)
    env_tracers = map(self.full_raise, env)
    unknown_arg_tracers = [t for t in tracers if not t.is_known()]
    # Adjust params for staged-out call on unknown values.
    num_new_args = len(const_tracers) + len(env_tracers)
    staged_params = update_params(params, map(op.not_, in_knowns), num_new_args)
    staged_params = dict(staged_params, in_axes=staged_in_axes,
                         out_axes=tuple(staged_out_axes), call_jaxpr=call_jaxpr)

    # The outputs of the staged-out call are Tracers with the new eqn as recipe.
    out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], ax, a)
                 for ax, a in zip(staged_out_axes, out_avals_mapped)]
    out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
                   for a in out_avals]
    # NOTE(review): unlike process_call, this uses source_info_util.current()
    # without truncating the name stack via _current_truncated_name_stack.
    eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
                         out_tracers, primitive, staged_params,
                         source_info_util.current())
    for t in out_tracers: t.recipe = eqn

    return merge_lists(out_knowns, out_tracers, out_consts)

  def post_process_call(self, primitive, out_tracers, params):
    """Split a call body's `out_tracers` into values returned now and a `todo`.

    Returns `(out, todo)`. `out` is the known output values followed by the
    residuals of the jaxpr built from the unknown outputs. `todo(out)` takes
    those values back, in the same layout, and returns the full output list.
    In that list, known outputs are the values passed in and unknown outputs
    are tracers of a newly staged-out call of `primitive`.
    """
    unknown_out_tracers = [t for t in out_tracers if not t.is_known()]
    jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers)
    out_pvals = [t.pval for t in out_tracers]
    out_knowns, out_avals, out_consts = partition_pvals(out_pvals)
    out = [*out_consts, *res]
    main = self.main

    def todo(out):
      trace = main.with_cur_sublevel()
      out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
      const_tracers = map(trace.new_instantiated_const, res)
      in_tracers = (*const_tracers, *map(trace.full_raise, env))
      out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None)
                     for a in out_avals]
      update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p)
      new_params = update_params(params, [], len(in_tracers))
      new_params = dict(new_params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
      name_stack = self._current_truncated_name_stack()
      source = source_info_util.current().replace(name_stack=name_stack)
      eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, source)
      for t in out_tracers: t.recipe = eqn
      return merge_lists(out_knowns, out_tracers, out_consts)

    return out, todo

  def post_process_map(self, primitive, out_tracers, params):
    """Map-primitive analogue of `post_process_call`.

    Returns `(out, (todo, out_axes_transform))`. `todo` reads
    `out_axes_unknown`, which starts as None and is only set by
    `out_axes_transform`, so `out_axes_transform` must be called before
    `todo`.
    """
    unknown_out_tracers = [t for t in out_tracers if not t.is_known()]
    jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers)
    out_pvals = [t.pval for t in out_tracers]
    out_knowns, out_avals_mapped, out_consts = partition_pvals(out_pvals)
    out = [*out_consts, *res]
    main = self.main

    with core.extend_axis_env(params['axis_name'], params['axis_size'], None):
      call_jaxpr = convert_constvars_jaxpr(jaxpr)

    def todo(out):
      trace = main.with_cur_sublevel()
      out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
      const_tracers = map(trace.new_instantiated_const, res)
      env_tracers = map(trace.full_raise, env)

      staged_out_axes = tuple(out_axes_unknown)  # set by out_axes_transform
      staged_in_axes = (0,) * len(res) + (None,) * len(env)

      update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p)
      staged_params = update_params(params, [], len(res) + len(env))
      staged_params = dict(staged_params, in_axes=staged_in_axes,
                           out_axes=tuple(staged_out_axes),
                           call_jaxpr=call_jaxpr)

      out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], d, a)
                   for d, a in zip(staged_out_axes, out_avals_mapped)]
      out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None)
                     for a in out_avals]
      name_stack = self._current_truncated_name_stack()
      source = source_info_util.current().replace(name_stack=name_stack)
      eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
                           primitive, staged_params, source)
      for t in out_tracers: t.recipe = eqn
      return merge_lists(out_knowns, out_tracers, out_consts)

    def out_axes_transform(out_axes):
      # Record the unknown outputs' axes for `todo`, and return the known
      # outputs' axes followed by axis 0 for each residual.
      nonlocal out_axes_unknown
      out_axes_unknown, out_axes_known = partition_list(out_knowns, out_axes)
      return tuple(out_axes_known) + (0,) * len(jaxpr.constvars)
    out_axes_unknown: Optional[list] = None

    return out, (todo, out_axes_transform)

  def _current_truncated_name_stack(self):
    # Current name stack with the first len(self.name_stack) entries removed.
    return source_info_util.current_name_stack()[len(self.name_stack):]

  def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
                   app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]],
                   instantiate: bool):
    """Partially evaluate f on a sequence of PartialVals.

    `app` is called as `app(f, *in_consts)` to run the wrapped function.
    Returns `(jaxpr, out_pvals, residual_consts, env_tracers)`.
    """
    in_avals, in_consts = unzip2(pvals)
    f = trace_to_subjaxpr(f, self.main, instantiate)
    f, aux = partial_eval_wrapper(f, tuple(in_avals))
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvs = map(PartialVal, zip(out_avals, out_consts))
    env_tracers = map(self.full_raise, env)
    return jaxpr, out_pvs, consts, env_tracers

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    """Stage out a custom_jvp call in full.

    All inputs are instantiated as unknown, so every output is unknown. The
    JVP rule is traced only when the memoized `jvp_jaxpr_thunk` is called.
    """
    tracers = map(self.instantiate_const_abstracted, tracers)
    in_avals, in_consts = unzip2(t.pval for t in tracers)  # in_consts are units
    fun = trace_to_subjaxpr(fun, self.main, True)
    fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
    out_flat = prim.bind(fun, jvp, *in_consts)
    out_avals, jaxpr, env = aux()
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvals = map(PartialVal, zip(out_avals, out_consts))  # out_consts are units
    env_tracers = map(self.full_raise, env)
    out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *tracers)
    closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())

    @_memoize
    def jvp_jaxpr_thunk():
      # Trace the JVP rule on (primals, tangents), both described by in_avals.
      jvp_ = trace_to_subjaxpr(jvp, self.main, True)
      jvp_, aux = partial_eval_wrapper(jvp_, tuple(in_avals) * 2)
      with core.new_sublevel():
        out_flat = jvp_.call_wrapped(*(in_consts * 2))  # in_consts are units
      out_avals, jaxpr, env = aux()
      _, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
      converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
      return converted_jaxpr, (*consts, *env)

    name_stack = self._current_truncated_name_stack()
    source = source_info_util.current().replace(name_stack=name_stack)
    eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
                         dict(fun_jaxpr=closed_jaxpr,
                              jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                              num_consts=len(consts) + len(env)),
                         source)
    for t in out_tracers: t.recipe = eqn
    return out_tracers

  def post_process_custom_jvp_call(self, out_tracers, _):
    # This path should only be reachable if we expose a partial eval API
    # unrelated to autodiff, since we raise an error when differentiation with
    # respect to values over which a custom_jvp function closes is detected.
    raise NotImplementedError  # TODO(mattjj)

  def process_custom_transpose(self, prim, call, tracers, **params):
    """Bind directly if every linear input is known, otherwise stage out.

    Residual inputs (the first `params['res_tree'].num_leaves` tracers) are
    asserted to be known.
    """
    res_ts, lin_ts = split_list(tracers, [params['res_tree'].num_leaves])
    assert all(t.is_known() for t in res_ts)
    lin_all_known = all(t.is_known() for t in lin_ts)
    if lin_all_known:
      res_cvals = [t.pval[1] for t in res_ts]
      lin_cvals = [t.pval[1] for t in lin_ts]
      return prim.bind(call, *res_cvals, *lin_cvals, **params)
    else:
      out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
                     for aval in params['out_types']]
      in_tracers = map(self.instantiate_const, tracers)
      new_params = dict(params, call=call)
      # NOTE(review): uses source_info_util.current() without name-stack
      # truncation, unlike most other staging paths in this class.
      eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params,
                           source_info_util.current())
      for t in out_tracers: t.recipe = eqn
      return out_tracers

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    """Stage out a custom_vjp call in full.

    Mirrors `process_custom_jvp_call`: all inputs are instantiated as unknown,
    and the forward rule is traced only when the memoized `fwd_jaxpr_thunk`
    is called.
    """
    tracers = map(self.instantiate_const_abstracted, tracers)
    in_avals, in_consts = unzip2(t.pval for t in tracers)  # in_consts are units
    fun = trace_to_subjaxpr(fun, self.main, True)
    fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
    out_flat = prim.bind(fun, fwd, bwd, *in_consts, out_trees=out_trees)
    out_avals, jaxpr, env = aux()
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvals = map(PartialVal, zip(out_avals, out_consts))  # out_consts are units
    env_tracers = map(self.full_raise, env)
    out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *tracers)
    closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())

    @_memoize
    def fwd_jaxpr_thunk():
      fwd_ = trace_to_subjaxpr(fwd, self.main, True)
      fwd_, aux = partial_eval_wrapper(fwd_, tuple(in_avals))
      with core.new_sublevel():
        out_flat = fwd_.call_wrapped(*in_consts)  # in_consts are units
      out_avals, jaxpr, env = aux()
      _, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
      converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
      return converted_jaxpr, (*consts, *env)

    name_stack = self._current_truncated_name_stack()
    source = source_info_util.current().replace(name_stack=name_stack)
    eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
                         dict(fun_jaxpr=closed_jaxpr,
                              fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                              num_consts=len(consts) + len(env),
                              bwd=bwd, out_trees=out_trees),
                         source)
    for t in out_tracers: t.recipe = eqn
    return out_tracers

  def post_process_custom_vjp_call(self, out_tracers, _):
    # This path should only be reachable if we expose a partial eval API
    # unrelated to autodiff, since we raise an error when differentiation with
    # respect to values over which a custom_vjp function closes is detected.
    raise NotImplementedError  # TODO(mattjj)
def partition_pvals(
    pvals: List[PartialVal]
  ) -> Tuple[List[bool], List[AbstractValue], List[Any]]:
  """Split PartialVals into a known-mask, unknown avals and known values.

  Returns `(knowns, avals, consts)`:
    knowns: one bool per entry of `pvals`, True where the value is known.
    avals: abstract values of the unknown entries, in order.
    consts: the known values, in order.
  """
  knowns, avals, consts = [], [], []
  for pval in pvals:
    pval_is_known = pval.is_known()
    knowns.append(pval_is_known)
    if pval_is_known:
      consts.append(pval.get_known())
    else:
      avals.append(pval.get_aval())
  return knowns, avals, consts
@lu.transformation_with_aux
def partial_eval_wrapper_nounits(
    in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue],
    *in_consts: Any):
  """Adapt a function over PartialVals so it is called with known values only.

  `in_knowns` marks, per original argument, whether that argument is known.
  The known values arrive positionally as `*in_consts`, and the unknown ones
  are described by `in_avals`. The inner function is called with the rebuilt
  list of PartialVals. It must yield `jaxpr, (out_pvals, res, env)`, as
  `trace_to_subjaxpr_nounits` does.

  Result: the known output values followed by the residuals `res`.
  Aux data: `(out_knowns, out_avals, jaxpr, env)`.
  """
  in_avals_, in_consts_ = iter(in_avals), iter(in_consts)
  # Re-interleave known and unknown inputs in their original order.
  in_pvals = [PartialVal.known(next(in_consts_)) if known else
              PartialVal.unknown(next(in_avals_)) for known in in_knowns]
  # Both iterators must now be exhausted, i.e. the number of avals and consts
  # matches the number of False and True entries in in_knowns respectively.
  sentinel = object()
  assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel
  jaxpr, (out_pvals, res, env) = yield (in_pvals,), {}
  out_knowns, out_avals, out_consts = partition_pvals(out_pvals)
  yield (*out_consts, *res), (out_knowns, out_avals, jaxpr, env)
# Primitive -> rule, called as rule(trace, *tracers, **params) by
# JaxprTrace.process_primitive instead of default_process_primitive.
custom_partial_eval_rules: Dict[Primitive, Callable] = dict()
# Primitive -> rule, called as rule(trace, primitive, f, tracers, params) by
# JaxprTrace.process_call instead of the generic call handling.
call_partial_eval_rules: Dict[Primitive, Callable] = dict()
# Primitive -> params updater, called as updater(params, kept_mask,
# num_new_args) by the call/map handling in JaxprTrace. When a primitive has
# no updater, the params are used unchanged.
call_param_updaters: Dict[Primitive, Callable] = dict()
def abstract_eval_fun(fun, *avals, debug_info=None, **params):
  """Return the output abstract values of tracing `fun` on `avals`.

  `fun` is wrapped together with `params` via `lu.wrap_init` and traced with
  `trace_to_jaxpr_dynamic`. Only the output avals are kept.
  """
  wrapped = lu.wrap_init(fun, params)
  _, out_avals, _ = trace_to_jaxpr_dynamic(wrapped, avals, debug_info)
  assert all(isinstance(a, AbstractValue) for a in out_avals)
  return out_avals
# The kinds of `recipe` a JaxprTracer can carry. The class names are forward
# references because those classes are defined further down in this module.
JaxprTracerRecipe = Union[
    'JaxprEqnRecipe',
    'LambdaBinding',
    'FreeVar',
    'ConstVar',
    Literal,
    core.Unit,
]
class JaxprTracer(Tracer):
  """Tracer of a JaxprTrace, carrying a PartialVal and a recipe.

  `pval` records whether the value is known (and holds it) or unknown (and
  holds its abstract value). `recipe` records how an unknown value is
  produced (see JaxprTracerRecipe). It may be None at construction and
  assigned afterwards.
  """
  __slots__ = ['pval', 'recipe']

  def __init__(self, trace: JaxprTrace, pval: PartialVal,
               recipe: Optional[JaxprTracerRecipe]):
    assert isinstance(pval, PartialVal)
    _, const = pval
    # A known value that is itself a tracer must belong to a strictly lower
    # trace level than `trace`.
    const_from_higher_level = (isinstance(const, Tracer)
                               and const._trace.level >= trace.level)
    if const_from_higher_level:
      raise core.escaped_tracer_error(
          const, "Tracer from a higher level: {} in trace {}".format(const, trace))
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  def __repr__(self):
    return 'Traced<{}:{}>'.format(self.aval, self._trace)

  @property
  def aval(self) -> AbstractValue:
    """Abstract value of this tracer, whether its value is known or not."""
    return self.pval.get_aval()

  @property
  def parents(self) -> Sequence['JaxprTracer']:
    """Input tracers of the equation that produced this tracer, if any."""
    recipe = self.recipe
    return recipe.invars if isinstance(recipe, JaxprEqnRecipe) else []

  def full_lower(self):
    """Return the fully lowered known value, or this tracer if unknown."""
    known = self.pval.get_known()
    if known is None:
      return self
    return core.full_lower(known)

  def is_known(self):
    """Whether this tracer's value is known."""
    return self.pval.is_known()
@profiler.annotate_function
def trace_to_jaxpr(
    fun: lu.WrappedFun, pvals: Sequence[PartialVal],
    instantiate: Union[bool, Sequence[bool]] = False,
  ) -> Tuple[Jaxpr, List[PartialVal], List[core.Value]]:
  """
  Partially evaluate a function, building a jaxpr for un-evaluated computation.

  Args:
    fun: lu.WrappedFun representing the function to be partially evaluated. The
      function must be flattened, in the sense of accepting jaxpr type arguments
      and returning a flat list of jaxpr type outputs.
    pvals: sequence of PartialVals of length equal to the number of inputs to
      `fun` indicating which inputs are known or unknown.
    instantiate: optional bool or sequence of bools of length equal to the
      number of outputs of `fun` indicating which outputs should be forced to be
      treated as unknown and hence instantiated in the jaxpr. If a single bool,
      the value is applied to all outputs. Default False.

  Returns:
    A triple where the first element is a jaxpr representing the computation
    which depends on unknown inputs; the second element is a list of PartialVals
    of length equal to the length of the output of `fun` representing which
    outputs are known and unknown (along with their values and abstract values,
    respectively); the third element is a list of known residual values. The
    returned jaxpr takes as inputs the known residual values followed by values
    of the originally unknown inputs.
  """
  current_name_stack = source_info_util.current_name_stack()
  with core.new_main(JaxprTrace, name_stack=current_name_stack) as main:
    fun = trace_to_subjaxpr(fun, main, instantiate)
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    # No environment (closed-over) values are expected at this top level.
    assert not env
    # Drop the local references to `main` and objects built from it before
    # the `with` block exits (presumably for leak checking; keep it here).
    del main, fun, env

  return jaxpr, out_pvals, consts
@lu.transformation
def trace_to_subjaxpr_nounits(
    main: core.MainTrace, instantiate: Union[bool, Sequence[bool]],
    in_pvals: Sequence[PartialVal]):
  """Trace the wrapped function under `main` with partially-known inputs.

  Known inputs are passed to the wrapped function as their raw values.
  Unknown inputs become fresh JaxprTracers (via `JaxprTrace.new_arg`).
  Outputs selected by `instantiate` are forced to be unknown. Yields
  `jaxpr, (out_pvals, consts, env)`, where `jaxpr` is built from the unknown
  outputs only.
  """
  assert all([isinstance(pv, PartialVal) for pv in in_pvals]), in_pvals
  trace = main.with_cur_sublevel()
  in_knowns = [pval.is_known() for pval in in_pvals]
  in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
  in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
  # Re-interleave known values and unknown-input tracers in original order.
  in_args = merge_lists(in_knowns, in_tracers, in_consts)
  ans = yield in_args, {}
  assert isinstance(ans, (list, tuple)), (
      f"Got unexpected return type when tracing function to jaxpr: {ans}")
  assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), (
      f"Got unexpected return type when tracing function to jaxpr: {ans}")
  # A single bool applies to every output.
  if isinstance(instantiate, bool):
    instantiate = [instantiate] * len(ans)
  out_tracers = map(trace.full_raise, map(core.full_lower, ans))
  out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers)
  out_pvals = [t.pval for t in out_tracers]
  out_tracers_ = [t for t in out_tracers if not t.is_known()]
  jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers_)
  # Drop local references to tracers and the trace before yielding.
  del trace, in_tracers, out_tracers, out_tracers_
  yield jaxpr, (out_pvals, consts, env)
def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer):
  """Optionally force `tracer` to be an unknown-valued tracer of `trace`.

  When `instantiate` is falsy, `tracer` is returned untouched. Otherwise it is
  first raised into `trace` and then instantiated via
  `JaxprTrace.instantiate_const`.
  """
  if not instantiate:
    return tracer
  raised = trace.full_raise(tracer)
  return trace.instantiate_const(raised)
# Recipe for a tracer re-wrapped from an existing tracer (see
# JaxprTrace.sublift). tracers_to_jaxpr maps it to an environment value.
FreeVar = namedtuple('FreeVar', 'val')
# Recipe for a tracer standing for a concrete constant `val`.
# tracers_to_jaxpr maps it to a jaxpr constvar.
ConstVar = namedtuple('ConstVar', 'val')
# Recipe for a tracer bound as an (unknown) input of the traced function.
LambdaBinding = namedtuple('LambdaBinding', ())
class JaxprEqnRecipe(NamedTuple):
  """A staged-out primitive application recorded during partial evaluation.

  Built by `new_eqn_recipe` and converted into a `core.JaxprEqn` by
  `recipe_to_eqn`, which unpacks the fields positionally. Do not reorder them.
  """
  # Fresh object used as an identity token, so tracers_to_jaxpr emits each
  # equation only once even when it has several output tracers.
  eqn_id: object
  # Tracers for the primitive's inputs.
  invars: Sequence[JaxprTracer]
  # Weak references to the output tracers. recipe_to_eqn emits a DropVar
  # for any output whose tracer has been garbage collected.
  outvars: 'Sequence[ref[JaxprTracer]]'
  primitive: Primitive
  params: Dict[str, Any]
  source_info: source_info_util.SourceInfo
def new_eqn_recipe(invars: Sequence[JaxprTracer],
                   outvars: Sequence[JaxprTracer],
                   primitive: Primitive,
                   params: Dict[str, Any],
                   source_info: source_info_util.SourceInfo
                  ) -> JaxprEqnRecipe:
  """Construct a new JaxprEqnRecipe recording one primitive application.

  Args:
    invars: the tracers for the primitive inputs (stored as a tuple).
    outvars: the tracers for the primitive outputs (stored as weak refs).
    primitive: the primitive.
    params: the primitive params.
    source_info: source information attached to the equation.

  Returns:
    A JaxprEqnRecipe whose `eqn_id` is a fresh, unique object.
  """
  # TODO(necula): move these checks to core.check_jaxpr, and call in more places
  if primitive.call_primitive or primitive.map_primitive:
    assert "call_jaxpr" in params
    # assert len(invars) == len(params["call_jaxpr"].invars)  # TODO constvars?
    assert len(outvars) == len(params["call_jaxpr"].outvars)
    assert ("donated_invars" not in params or
            len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
    # Map primitives must additionally carry per-input in_axes and
    # donated_invars.
    if primitive.map_primitive:
      assert ("in_axes" in params and
              len(params["in_axes"]) == len(params["call_jaxpr"].invars))
      assert ("donated_invars" in params and
              len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
  out_refs = [ref(tracer) for tracer in outvars]
  return JaxprEqnRecipe(eqn_id=object(), invars=tuple(invars), outvars=out_refs,
                        primitive=primitive, params=params,
                        source_info=source_info)
def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
                  recipe: JaxprEqnRecipe) -> core.JaxprEqn:
  """Materialize a JaxprEqnRecipe as a core.JaxprEqn.

  Inputs and live outputs are mapped to atoms/vars through `getvar`. An
  output whose weakly-referenced tracer is gone becomes a DropVar of the
  abstract unit.
  """
  _eqn_id, ins, out_refs, prim, eqn_params, src = recipe
  # Dereference every output first, so the output tracers are strongly held
  # while `getvar` runs on the inputs.
  live_outs = [out_ref() for out_ref in out_refs]
  in_atoms = [getvar(tracer) for tracer in ins]
  out_vars = [cast(Var, getvar(tracer)) if tracer is not None
              else core.DropVar(core.abstract_unit)
              for tracer in live_outs]
  return new_jaxpr_eqn(in_atoms, out_vars, prim, eqn_params, src)
def tracers_to_jaxpr(
  in_tracers: Sequence[JaxprTracer],
  out_tracers: Sequence[JaxprTracer]
  ) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]:
  """Constructs a Jaxpr given tracers for inputs and outputs.

  Params:
    in_tracers: the tracers that were created for the function inputs.
    out_tracers: the tracers that were output by the function.

  Returns: a triple of a `Jaxpr`, a tuple of constant values corresponding to
    the `constvars` in the returned Jaxpr, and a tuple of environment values.
    The vars for the environment values have been prepended to the Jaxpr's
    `invars`.

  Raises:
    The error built by `core.escaped_tracer_error` if a tracer with a
    `LambdaBinding` recipe is reachable from `out_tracers` but is not one of
    `in_tracers`; `TypeError` if a tracer carries an unrecognized recipe.
  """
  newvar = core.gensym()
  # Tracer identity -> Atom (Var, Literal, or unitvar) representing it.
  t_to_var: Dict[int, Atom] = {}
  def getvar(t: JaxprTracer) -> Atom:
    var = t_to_var.get(id(t))
    if var is None:
      # Known tracers get a unit-typed var; unknown ones keep their aval.
      aval = t.pval.get_aval() if not t.pval.is_known() else abstract_unit
      var = t_to_var[id(t)] = newvar(aval)
    return var
  # Presumably orders every tracer after the tracers it depends on, so
  # equations are appended in a valid evaluation order.
  sorted_tracers = toposort(out_tracers)
  # Input vars are created before any other var (gensym order matters).
  invars = map(getvar, in_tracers)
  eqns: List[core.JaxprEqn] = []
  env: Dict[Var, Any] = {}
  consts: Dict[Var, Any] = {}
  # Constant identity -> constvar, so tracers wrapping the same constant
  # object share a single constvar.
  const_to_var: Dict[int, Var] = {}
  def getconstvar(c):
    var = const_to_var.get(id(c))
    if var is None:
      var = const_to_var[id(c)] = newvar(get_aval(c))
    return var
  # A multi-output recipe is shared by several tracers; emit it only once.
  processed_eqn_ids = set()
  for t in sorted_tracers:
    recipe = t.recipe
    if isinstance(recipe, JaxprEqnRecipe):
      if recipe.eqn_id not in processed_eqn_ids:
        eqns.append(recipe_to_eqn(getvar, recipe))
        processed_eqn_ids.add(recipe.eqn_id)
    elif isinstance(recipe, LambdaBinding):
      if not any(t is in_tracer for in_tracer in in_tracers):
        raise core.escaped_tracer_error(
            t, "Tracer not among input tracers {}".format(t))
      # NOTE(review): always true here, since the check above found t in
      # in_tracers.
      assert in_tracers, "Lambda binding with no args"
    elif isinstance(recipe, FreeVar):
      env[cast(Var, getvar(t))] = recipe.val
    elif isinstance(recipe, ConstVar):
      # Overrides any getvar-style var so that t maps to the shared constvar.
      v = t_to_var[id(t)] = getconstvar(recipe.val)
      consts[v] = recipe.val
    elif isinstance(recipe, Literal):
      t_to_var[id(t)] = recipe
    elif recipe is unit:
      t_to_var[id(t)] = unitvar
    else:
      raise TypeError(recipe)

  env_vars, env_vals = unzip2(env.items())
  const_vars, const_vals = unzip2(consts.items())
  # The env_vars are pre-pended to the invars
  jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns)
  config.jax_enable_checks and core.check_jaxpr(jaxpr)
  return jaxpr, const_vals, env_vals
@weakref_lru_cache
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
  """Moves the constvars to the start of invars.

  Returns a new Jaxpr with no constvars whose invars are
  `jaxpr.constvars + jaxpr.invars`; outvars and eqns are shared with `jaxpr`.
  The result is memoized by `weakref_lru_cache`, so the same object may be
  returned for repeated calls on the same input.
  """
  if config.jax_enable_checks:
    core.check_jaxpr(jaxpr)
  all_invars = jaxpr.constvars + jaxpr.invars
  lifted = Jaxpr(constvars=(), invars=all_invars,
                 outvars=jaxpr.outvars, eqns=jaxpr.eqns)
  if config.jax_enable_checks:
    core.check_jaxpr(lifted)
  return lifted
def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
  """Turns the leading `num_env_vars` invars of `jaxpr` into constvars.

  The environment vars are appended after the existing constvars; the
  remaining invars, outvars and eqns are carried over unchanged.
  """
  if config.jax_enable_checks:
    core.check_jaxpr(jaxpr)
  env_part, remaining_invars = split_list(jaxpr.invars, [num_env_vars])
  result = Jaxpr(constvars=jaxpr.constvars + env_part,
                 invars=remaining_invars,
                 outvars=jaxpr.outvars,
                 eqns=jaxpr.eqns)
  if config.jax_enable_checks:
    core.check_jaxpr(result)
  return result
def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, AbstractValue]:
  """Returns a (known-side, unknown-side) pair of avals.

  `aval` goes on the side selected by `unknown`; the other side gets
  `abstract_unit`.
  """
  if unknown:
    return abstract_unit, aval
  return aval, abstract_unit
def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
                       instantiate: Union[bool, Sequence[bool]],
                       ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
  """Specializes a Jaxpr given an indication of which inputs are known.

  Args:
    jaxpr: the closed jaxpr to split.
    unknowns: one bool per input of `jaxpr`; True marks an unknown input.
    instantiate: a bool, or one bool per output, requesting that the
      corresponding outputs be produced by `jaxpr_unknown`.

  Returns: (jaxpr_known, jaxpr_unknown, out_unknowns).

  `out_unknowns` specifies which outputs are unknown (depend on some unknown
  inputs).

  `jaxpr_known` takes the same inputs as `jaxpr`, ignores the unknown inputs,
  and performs *all* the computation in `jaxpr` that depends only on the known
  inputs. Its outputs correspond to those of `jaxpr`, with the unknown ones
  replaced by `*`, followed by the known residuals (the intermediate values in
  `jaxpr` that depend only on known inputs and are needed to compute the
  unknown outputs).

  `jaxpr_unknown` takes the same inputs as `jaxpr` followed by the known
  residuals computed by `jaxpr_known`, and returns the same outputs as `jaxpr`
  with the known outputs replaced by `*`.

  Roughly, writing `ki` and `ui` for the known and unknown inputs respectively,
  `jaxpr(ki, ui)` is decomposed as:

    jaxpr(ki, ui) = let kout, _, kresidual = jaxpr_known(ki, *)
                    let _, uout = jaxpr_unknown(ki, ui, kresidual)
                    in (kout, uout)

  For example, if `jaxpr` is
    lambda ki, ui: let ka = ki + 2
                   in (ki + 3, ui + ka)
  then
    `jaxpr_known`   = lambda ki, ui: let ka = ki + 2
                                     in (ki + 3, *, ka)
    `jaxpr_unknown` = lambda ki, ui, ka: (*, ui + ka)

  Note that if `instantiate` is True for a given output, then `jaxpr_known`
  always returns a unit in its place. So when `instantiate` is True, the
  expectation is that one does not run `jaxpr_known` for any of its outputs,
  but only to generate the residuals that allow the full outputs to be
  obtained once `jaxpr_unknown` is run. Outputs known ahead of time are simply
  passed as residual constants and returned immediately.
  """
  # The implementation is cached, so its arguments must be hashable: lists
  # are converted to tuples.
  if isinstance(instantiate, list):
    instantiate = tuple(instantiate)
  return _partial_eval_jaxpr(jaxpr, tuple(unknowns), instantiate)
@weakref_lru_cache
def _partial_eval_jaxpr(jaxpr, unknowns, instantiate):
  """Cached implementation of `partial_eval_jaxpr`.

  Args:
    jaxpr: the ClosedJaxpr to split.
    unknowns: tuple of bools, one per input of `jaxpr`.
    instantiate: a bool or a tuple of bools (hashable, for the cache).

  Returns:
    (ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out), where
    jaxpr_1 is the known part, jaxpr_2 the unknown part, and uk_out[i] is True
    iff output i is unknown.
  """
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

  # `fun` is traced exactly once below; it stashes the unknown-part jaxpr and
  # the number of residuals in `cell` as a side effect of that tracing.
  cell = []
  def fun(*vals):
    pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
             for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
    jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate)
    out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
    cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
    # Known-part outputs: the original outputs followed by the residuals.
    return out_consts_2 + consts_2

  # For jaxpr_known we pass core.unit for the unknown inputs, and known
  # PartialVal for the known inputs.
  in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)]
  jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
  (out_pvs_2, jaxpr_2, num_res), = cell
  assert len(jaxpr_2.constvars) == num_res

  #   jaxpr :: a -> b
  # jaxpr_1 :: a1 -> [b1, res]
  # jaxpr_2 (as traced) :: res | a2 -> b2
  # jaxpr_2 (returned)  :: [a2, res] -> b2
  jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
  # Rotate the lifted residual invars from the front to the back.
  jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
  # Known inputs are not needed by the unknown part; type them as units.
  for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
    if not unknown:
      var.aval = abstract_unit

  uk_out = [pv is not None for pv in out_pvs_2]

  # Sanity checks: one unknown-flag per original output, and jaxpr_1 returns
  # exactly `num_res` residuals after the original outputs.
  assert len(uk_out) == len(jaxpr.out_avals)
  assert len(out_avals[len(jaxpr.out_avals):]) == num_res

  return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out
# Call primitive named 'remat_call'. Its partial-evaluation rule
# (_remat_partial_eval) stages out the whole callee rather than only its
# unknown part.
remat_call_p: Primitive = core.CallPrimitive('remat_call')
remat_call = remat_call_p.bind
# Eager evaluation simply runs the callee via core.call_impl.
remat_call_p.def_impl(core.call_impl)
def _remat_partial_eval(trace, _, f, tracers, params):
  """Partial-evaluation rule for `remat_call_p`.

  The whole callee is traced with every input treated as unknown. The result
  is then split into a known part, which is evaluated immediately, and an
  unknown part, which is staged out as a new `remat_call_p` equation with
  `differentiated=True`.

  Two strategies are used:
    * If `params['policy']` is set, `_partial_eval_jaxpr_custom` splits the
      jaxpr according to the policy. Residuals from the known part become
      extra inputs of the staged equation.
    * Otherwise (the older path), `partial_eval_jaxpr` determines which
      outputs are unknown. The staged equation runs the traced jaxpr after
      `_dce_jaxpr` restricts it to those outputs.

  Args:
    trace: the trace performing partial evaluation (presumably a JaxprTrace).
    _: unused (presumably the primitive being bound).
    f: the callee, traced via `trace.partial_eval`.
    tracers: the input tracers.
    params: remat_call params; `concrete` and `policy` are read here.

  Returns:
    A list of output tracers in which the known and unknown outputs are
    interleaved according to `out_unknowns`.

  Raises:
    NotImplementedError: if `params['policy']` is set and `concrete` is True.
  """
  concrete = params['concrete']
  # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
  # the function being called, not just for the unknown parts. To do that, we
  # instantiate all the input tracers as constants in the jaxpr being formed.
  # Those tracers might have concrete avals, and doing abstract interpretation
  # on concrete avals engenders a tradeoff: it allows data-dependent Python
  # control flow to work, but it can in some cases lead to redundant FLOPs (done
  # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
  # `concrete` parameter to switch this behavior, and if `concrete` is False
  # then we raise the avals to the Shaped level.
  if concrete:
    instantiated_tracers = map(trace.instantiate_const, tracers)
  else:
    instantiated_tracers = map(trace.instantiate_const_abstracted, tracers)

  # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
  in_pvals = [PartialVal.unknown(t.pval[0]) for t in instantiated_tracers]
  jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
      f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
  # Lift constvars to leading invars: invars are now [consts, env, inputs].
  jaxpr = convert_constvars_jaxpr(jaxpr)

  # Since we traced with everything marked as unknown, but we need to know which
  # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
  # Lifted consts are always known; env/input knownness comes from the
  # original (non-instantiated) tracers.
  in_unknowns = ([False] * len(consts) +
                 [not t.is_known() for t in it.chain(env_tracers, tracers)])

  if params['policy']:
    # unzip into jaxpr_known and jaxpr_unknown
    jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = _partial_eval_jaxpr_custom(
        jaxpr, in_unknowns, params['policy'])
    # Keep every output of the known part but drop unused inputs.
    jaxpr_known, in_used_known = dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
    # The unknown part only needs to produce the outputs selected by out_inst.
    _, used_outs_unknown = partition_list(out_inst, out_unknowns)
    jaxpr_unknown, in_used_unknown = dce_jaxpr(jaxpr_unknown, used_outs_unknown)

    # compute known outputs and residuals (hoisted out of a remat_call)
    if concrete: raise NotImplementedError  # TODO(mattjj)
    _, in_consts_ = unzip2(t.pval for t in it.chain(env_tracers, tracers)
                           if t.pval.is_known())
    # Feed only the known inputs that survived DCE of jaxpr_known.
    _, in_consts = partition_list(in_used_known, [*consts, *in_consts_])
    out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
    out_consts_ = iter(out_consts)
    # reconstruct known outs, inserting units
    # Known outputs whose var is unit-typed take their value from
    # eval_out_pvals; the others consume the next value from out_consts.
    outs1 = [pval.get_known() if x.aval is abstract_unit else next(out_consts_)
             for uk, pval, x in zip(out_unknowns, eval_out_pvals, jaxpr.outvars)
             if not uk]
    # form known outputs and collect residual tracers
    out_known_tracers = [JaxprTracer(trace, PartialVal.known(c), None)
                         for c in outs1]
    # Whatever jaxpr_known returned after the known outputs is residuals.
    residuals = list(out_consts_)

    # set up unknown outputs with a recipe to call remat
    res_tracers = map(trace.new_instantiated_const, residuals)
    const_tracers = map(trace.new_instantiated_const, consts)
    # Candidate inputs of jaxpr_unknown, presumably ordered to match its
    # invars; in_used_unknown then filters out the ones DCE removed.
    in_jaxpr_tracers = [*res_tracers, *const_tracers, *env_tracers,
                        *instantiated_tracers]
    _, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
    out_jaxpr_tracers = [JaxprTracer(trace, PartialVal.unknown(x.aval), None)
                         for x in jaxpr_unknown.outvars]
    new_params = dict(params, call_jaxpr=jaxpr_unknown, differentiated=True)
    recipe = new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_call_p,
                            new_params, source_info_util.current())
    for t in out_jaxpr_tracers: t.recipe = recipe

    # zip together known and unknown outputs
    return _zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
  else:
    # TODO(mattjj): this is an old parallel code path, to be deleted once the
    # new path is fully functional

    closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
    jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
        closed_jaxpr, in_unknowns, instantiate=False)  # type: ignore
    out_knowns = [not b for b in out_unknowns]
    out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)

    # Next, we need values for the outputs that should be known. Since consts
    # weren't passed through Python for evaluation, we need to evaluate
    # jaxpr_known, minus the residual outputs that we don't need. When
    # `concrete=True`, as an optimization we can avoid redoing *some* redundant
    # FLOPs, namely those that produced concrete avals at the output, simply by
    # using those as computed values. For the use case of inverse-mode ad in
    # op-by-op ("eager mode") evaluation, all the primal outputs should be
    # concrete (thus not recomputed).
    to_compute = [type(pval[0]) is not ConcreteArray
                  for uk, pval in zip(out_unknowns, eval_out_pvals) if not uk]
    num_outputs = len(jaxpr_unknown.out_avals)
    # jaxpr_known's outputs are the original outputs followed by residuals.
    num_res = len(jaxpr_known.out_avals) - num_outputs
    jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res,
                                   drop_outputs=True)
    jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
    _, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
    reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
    out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts)

    # Known outputs should keep propagating as constants
    assert all(pv.is_known() for pv in out_known_pvals)
    known_output_tracers = [trace.new_const(pval.get_known())
                            for pval in out_known_pvals]
    # Unknown outputs get wrapped in tracers with the appropriate recipe
    unknown_output_tracers = [JaxprTracer(trace, out_pval, None)
                              for out_pval in out_unknown_pvals]

    # dce jaxpr outputs
    new_jaxpr = _dce_jaxpr(closed_jaxpr, out_unknowns, drop_outputs=True).jaxpr
    new_params = dict(params, call_jaxpr=new_jaxpr, differentiated=True)

    # set up eqn for unknown outputs
    const_tracers = map(trace.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
    eqn = new_eqn_recipe(in_tracers, unknown_output_tracers, remat_call_p,
                         new_params, source_info_util.current())
    for t in unknown_output_tracers: t.recipe = eqn
    return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
# Register the remat-specific partial-evaluation rule for remat_call_p.
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
def _partition_knowns(pvals, unknowns: Sequence[bool]):
  """Splits `pvals` into (knowns, unknowns) lists according to `unknowns`.

  Relative order is preserved within each list.
  """
  known_part, unknown_part = [], []
  for item, is_unknown in zip(pvals, unknowns):
    bucket = unknown_part if is_unknown else known_part
    bucket.append(item)
  return known_part, unknown_part
def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
  """Inverse of `_partition_knowns`: interleaves two lists back together.

  Element i of the result comes from `unknown_list` if `which_unknown[i]` is
  True, otherwise from `known_list`, consuming each list in order.

  Raises:
    AssertionError: if the list lengths do not add up to len(which_unknown).
  """
  assert len(known_list) + len(unknown_list) == len(which_unknown)
  knowns = iter(known_list)
  unknowns = iter(unknown_list)
  merged = []
  for is_unknown in which_unknown:
    merged.append(next(unknowns) if is_unknown else next(knowns))
  return merged
def _partial_eval_jaxpr_custom(
jaxpr: Jaxpr, in_unknowns: Sequence[bool], saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, Sequence[bool], Sequence[bool], int]:
if jaxpr.constvars: raise NotImplementedError # TODO(mattjj)
env: Dict[Var, Tuple[bool, bool]] = {}
residuals: OrderedSet[Var] = OrderedSet()
def read(x: Atom) -> Tuple[bool, bool]:
if type(x) is Var:
return env[x]
return (False, True)
def write(unk: bool, inst: bool, v: Var) -> None:
assert (unk, inst) != (True, False)
env[v] = (unk, inst)
def ensure_instantiated(inst: bool, x: Atom) -> Atom:
if type(x) is Var and not inst:
residuals.add(x)
return x
known_eqns, staged_eqns = [], []