-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
control_flow.py
3056 lines (2553 loc) · 125 KB
/
control_flow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Control flow primitives.
"""
import collections
import functools
from functools import partial
import inspect
import itertools
import operator
import os
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar
import numpy as np
import jax
from jax._src import api
from jax import core
from jax._src import ad_checkpoint
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax import linear_util as lu
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax._src.api_util import flatten_fun_nokwargs
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import masking
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import xla_client
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip,
split_list, cache, extend_name_stack, wrap_name)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
treedef_children, treedef_tuple, tree_map,
tree_leaves, tree_structure)
from jax._src import ad_util
from jax.config import config
# Short alias for the legacy XLA client op builders used by the translation
# rules below (e.g. xops.Tuple, xops.While).
xops = xla_client.ops
# safe_map / safe_zip come from jax._src.util (presumably length-checking
# variants of map/zip -- confirm). NOTE: this deliberately shadows the builtin
# `zip` for the rest of the module.
_map = safe_map
zip = safe_zip
_reduce = functools.reduce
# Generic type variable for loop-carried values (see while_loop's signature).
T = TypeVar('T')
# Informal aliases used only in type annotations.
Array = Any
BooleanNumeric = Any  # A bool, or a Boolean array.
@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
                              primitive_name: Optional[str] = None):
  """Trace ``fun`` into an open jaxpr, returning its constants separately.

  Args:
    fun: Python callable to trace; its positional arguments form the pytree
      described by ``in_tree``.
    in_tree: pytree definition of ``fun``'s positional arguments.
    in_avals: abstract values of the flattened arguments.
    primitive_name: label recorded in the tracing debug info; ``"<unknown>"``
      is used when it is not given.

  Returns:
    A ``(jaxpr, consts, out_tree)`` triple: the traced jaxpr, the constants it
    closes over, and the pytree definition of ``fun``'s output.
  """
  flat_fun, get_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  dbg = pe.debug_info(fun, in_tree, False, primitive_name or "<unknown>")
  open_jaxpr, _, constants = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, dbg)
  # The output tree is only known once tracing has run.
  return open_jaxpr, constants, get_out_tree()
@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals,
                         primitive_name: Optional[str] = None):
  """Trace ``fun`` into a closed jaxpr whose constants become explicit inputs.

  Args:
    fun: Python callable to trace.
    in_tree: pytree definition of ``fun``'s positional arguments.
    in_avals: abstract values of the flattened arguments.
    primitive_name: label recorded in the tracing debug info.

  Returns:
    ``(closed_jaxpr, consts, out_tree)``. The jaxpr's constvars are turned into
    ordinary inputs by ``pe.convert_constvars_jaxpr``, so callers pass
    ``consts`` ahead of the real operands. The ``ClosedJaxpr`` itself carries
    no constants.
  """
  open_jaxpr, constants, out_tree = _initial_style_open_jaxpr(
      fun, in_tree, in_avals, primitive_name)
  lifted = pe.convert_constvars_jaxpr(open_jaxpr)
  return core.ClosedJaxpr(lifted, ()), constants, out_tree
@cache()
def _initial_style_jaxprs_with_common_consts(
    funs: Sequence[Callable], in_tree, in_avals, primitive_name: str):
  """Trace several branch functions into closed jaxprs sharing one const list.

  When the branches of a conditional are staged into jaxprs, each branch's
  constants are extracted and turned into jaxpr arguments. To use the staged
  jaxprs as branches of a conditional *primitive*, their input signatures must
  match. This function "joins" them: every resulting jaxpr accepts the
  constants of *all* branches, in branch order, but only uses its own and
  ignores the rest.

  Returns:
    ``(closed_jaxprs, consts, out_trees)``. These are one closed jaxpr per
    function, the concatenation of every branch's constants, and each
    function's output pytree definition.
  """
  traced = [_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
            for fun in funs]
  jaxprs, all_consts, all_out_trees = unzip3(traced)

  newvar = core.gensym(jaxprs, suffix='_')
  # For every branch, fresh placeholder vars that stand in for that branch's
  # constants inside the *other* branches' jaxprs.
  const_avals_per_branch = [
      [raise_to_shaped(core.get_aval(c)) for c in branch_consts]
      for branch_consts in all_consts]
  placeholders = [[newvar(aval) for aval in avals]
                  for avals in const_avals_per_branch]

  consts = util.concatenate(all_consts)

  padded = []
  for idx, jaxpr in enumerate(jaxprs):
    leading = util.concatenate(placeholders[:idx])
    trailing = util.concatenate(placeholders[idx + 1:])
    padded.append(core.Jaxpr(constvars=[*leading, *jaxpr.constvars, *trailing],
                             invars=jaxpr.invars, outvars=jaxpr.outvars,
                             eqns=jaxpr.eqns))

  closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(j), ())
                   for j in padded]
  return closed_jaxprs, consts, all_out_trees
def _abstractify(x):
  """Return the abstract value of ``x``, passed through ``raise_to_shaped``."""
  aval = core.get_aval(x)
  return raise_to_shaped(aval)
def _typecheck_param(prim, param, name, msg_required, pred):
if not pred:
msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
f'{msg_required} required:')
param_str = str(param)
sep = os.linesep if os.linesep in param_str else ' '
msg = sep.join([msg, param_str])
raise core.JaxprTypeError(msg)
### fori_loop and while_loop
def _fori_cond_fun(loop_carry):
i, upper, _ = loop_carry
return lax.lt(i, upper)
@cache()
def _fori_body_fun(body_fun):
  """Adapt ``body_fun(i, val)`` into a ``while_loop`` body over ``(i, upper, val)``.

  The wrapper is memoized per ``body_fun`` by ``@cache()``, so repeated calls
  with the same body return the same wrapper object.
  """
  def while_body_fun(loop_carry):
    index, stop, value = loop_carry
    next_index = lax.add(index, lax._const(index, 1))
    return next_index, stop, body_fun(index, value)
  return while_body_fun
@cache()
def _fori_scan_body_fun(body_fun):
  """Adapt ``body_fun(i, val)`` into a ``scan`` body over the carry ``(i, val)``.

  The scanned-over element is ignored and each step emits ``None`` as output.
  The wrapper is memoized per ``body_fun`` by ``@cache()``.
  """
  def scanned_fun(loop_carry, _):
    index, value = loop_carry
    new_carry = (index + 1, body_fun(index, value))
    return new_carry, None
  return scanned_fun
@api_boundary
def fori_loop(lower, upper, body_fun, init_val):
  """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.

  The type signature in brief is

  .. code-block:: haskell

    fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a

  The semantics of ``fori_loop`` are given by this Python implementation::

    def fori_loop(lower, upper, body_fun, init_val):
      val = init_val
      for i in range(lower, upper):
        val = body_fun(i, val)
      return val

  As in that Python version, if ``upper <= lower`` the body is never run and
  ``init_val`` is returned unchanged.

  Unlike that Python version, ``fori_loop`` is implemented in terms of either a
  call to :func:`jax.lax.while_loop` or a call to :func:`jax.lax.scan`. If the
  trip count is static (meaning known at tracing time, perhaps because ``lower``
  and ``upper`` are Python integer literals) then the ``fori_loop`` is
  implemented in terms of ``scan`` and reverse-mode autodiff is supported;
  otherwise, a ``while_loop`` is used and reverse-mode autodiff is not
  supported. See those functions' docstrings for more information.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Args:
    lower: an integer representing the loop index lower bound (inclusive)
    upper: an integer representing the loop index upper bound (exclusive)
    body_fun: function of type ``(int, a) -> a``.
    init_val: initial loop carry value of type ``a``.

  Returns:
    Loop value from the final iteration, of type ``a``.

  Raises:
    TypeError: if ``lower`` and ``upper`` have different (canonical) dtypes.
  """
  # TODO(phawkins): perhaps do more type checking here, better error messages.
  lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower))
  upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper))
  if lower_dtype != upper_dtype:
    msg = ("lower and upper arguments to fori_loop must have equal types, "
           "got {} and {}")
    raise TypeError(msg.format(lower_dtype.name, upper_dtype.name))

  # If we can specialize on the trip count, call scan instead of a while_loop
  # to enable efficient reverse-mode differentiation.
  if (isinstance(core.get_aval(lower), ConcreteArray) and
      isinstance(core.get_aval(upper), ConcreteArray)):
    try:
      lower_ = int(lower)
      upper_ = int(upper)
    except TypeError:
      use_scan = False
    else:
      use_scan = True
  else:
    use_scan = False

  if use_scan:
    # Clamp at zero so that `upper < lower` runs zero iterations, matching
    # `range(lower, upper)` and the while_loop path (whose `lower < upper`
    # predicate is false immediately), instead of passing a negative length
    # to scan.
    length = max(upper_ - lower_, 0)
    if config.jax_disable_jit and length == 0:
      # non-jit implementation of scan does not support length=0
      return init_val
    (_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
                          None, length=length)
  else:
    _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
                              (lower, upper, init_val))
  return result
@api_boundary
def while_loop(cond_fun: Callable[[T], BooleanNumeric],
               body_fun: Callable[[T], T],
               init_val: T) -> T:
  """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

  The type signature in brief is

  .. code-block:: haskell

    while_loop :: (a -> Bool) -> (a -> a) -> a -> a

  The semantics of ``while_loop`` are given by this Python implementation::

    def while_loop(cond_fun, body_fun, init_val):
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val

  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
  to a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
  fixed shape and dtype across all iterations (and not just be consistent up to
  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
  other words, the type ``a`` in the type signature above represents an array
  with a fixed shape and dtype (or a nested tuple/list/dict container data
  structure with a fixed structure and arrays with fixed shape and dtype at the
  leaves).

  Another difference from using Python-native loop constructs is that
  ``while_loop`` is not reverse-mode differentiable because XLA computations
  require static bounds on memory requirements.

  When ``jax_disable_jit`` is set, the loop is first attempted as a plain
  Python loop, falling back to the primitive if that raises
  ``core.ConcretizationTypeError``.

  Args:
    cond_fun: function of type ``a -> Bool``.
    body_fun: function of type ``a -> a``.
    init_val: value of type ``a``, a type that can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof, representing the initial
      loop carry value.

  Returns:
    The output from the final iteration of body_fun, of type ``a``.

  Raises:
    TypeError: if ``cond_fun`` does not return a single boolean scalar.
  """
  if config.jax_disable_jit:
    try:
      val = init_val
      while cond_fun(val):
        val = body_fun(val)
      return val
    except core.ConcretizationTypeError:
      # Can't run this while_loop in Python (e.g. because there's a vmap
      # transformation on it), so we fall back to the primitive version.
      pass

  def _create_jaxpr(init_val):
    # Trace cond_fun and body_fun against the avals of `init_val`. Returns
    # (init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts,
    #  body_consts, body_tree); raises TypeError unless cond_fun yields exactly
    # one boolean scalar (weak type / named shape ignored).
    init_vals, in_tree = tree_flatten((init_val,))
    init_avals = tuple(_map(_abstractify, init_vals))
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
        cond_fun, in_tree, init_avals, "while_cond")
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
        body_fun, in_tree, init_avals, "while_loop")
    if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
      msg = "cond_fun must return a boolean scalar, but got pytree {}."
      raise TypeError(msg.format(cond_tree))
    pred_aval = cond_jaxpr.out_avals[0]
    if (not isinstance(pred_aval, ShapedArray)
        or pred_aval.strip_weak_type().strip_named_shape() != ShapedArray((), np.bool_)):
      msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
      raise TypeError(msg.format(cond_jaxpr.out_avals))
    return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree

  # The body input and output avals must match exactly. However, we want to account for
  # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
  # may not match the output despite being compatible by virtue of their weak type.
  # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
  # necessary, a second time with modified init values.
  init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals)
  if changed:
    new_init_val, = tree_unflatten(in_tree, new_init_vals)
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
  cond_jaxpr, cond_consts, body_consts, body_tree = rest

  # in_tree wraps init_val in a 1-tuple (see tree_flatten((init_val,)) above),
  # so its single child is the carry's own tree.
  in_tree_children = in_tree.children()
  assert len(in_tree_children) == 1
  _check_tree_and_avals("body_fun output and input",
                        body_tree, body_jaxpr.out_avals,
                        in_tree_children[0], init_avals)
  # Operand layout of while_p: cond consts, then body consts, then the carry.
  outs = while_p.bind(*cond_consts, *body_consts, *init_vals,
                      cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr,
                      body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
  return tree_unflatten(body_tree, outs)
def _while_loop_abstract_eval(*args, **kwargs):
  """Abstract evaluation rule for ``while_p``.

  The operands are ignored. The outputs are the body jaxpr's output avals,
  each passed through ``raise_to_shaped``.
  """
  body_out_avals = kwargs["body_jaxpr"].out_avals
  return [raise_to_shaped(aval) for aval in body_out_avals]
def _while_loop_translation_rule(ctx, avals_in, avals_out, *args, cond_jaxpr,
                                 body_jaxpr, cond_nconsts, body_nconsts):
  """Legacy XLA-builder translation of ``while_p`` into a single ``xops.While``.

  ``args`` is laid out as (cond consts, body consts, initial carry). The whole
  layout is threaded through the While as one tuple; only the carry part is
  returned. If the cond jaxpr's output is non-scalar (a batched predicate),
  the loop runs while *any* element is true (OR-reduction). The body then uses
  a per-element select so that elements whose predicate is false keep their
  previous carry values.
  """
  c = ctx.builder
  cond_consts, body_consts, init_vals = split_list(args, [cond_nconsts, body_nconsts])
  batched = bool(cond_jaxpr.out_avals[0].shape)

  # Since jaxprs don't have tuples and have multiple return values, but we need
  # the HLO While loop to take a single tuple input and output a single boolean
  # (for the cond computation) or a single tuple output (for the body
  # computation), we build XLA computations that handle the tuple munging before
  # generating a Call into the computations formed from the jaxprs.

  init_carry = xops.Tuple(c, cond_consts + body_consts + init_vals)

  # --- cond computation: sees cond consts (x) and carry (z); body consts unused.
  cond_c = xla_client.XlaBuilder("cond_computation")
  cond_carry = xla.parameter(cond_c, 0, c.get_shape(init_carry))
  cond_carry_elts = [xops.GetTupleElement(cond_carry, i) for i in range(len(args))]
  x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
  name_stack = extend_name_stack(ctx.name_stack, 'while')
  cond_ctx = ctx.replace(builder=cond_c,
                         name_stack=extend_name_stack(name_stack, 'cond'))
  pred, = xla.jaxpr_subcomp(
      cond_ctx, cond_jaxpr.jaxpr,
      _map(partial(xla.pyval_to_ir_constant, cond_c), cond_jaxpr.consts),
      *(x + z))
  if batched:
    # Reduce the batched predicate to a scalar with logical OR over all axes.
    scalar = ShapedArray((), np.bool_)
    or_ = xla.primitive_subcomputation(ctx.platform, ctx.axis_env, lax.or_p,
                                       scalar, scalar)
    pred = xops.Reduce(cond_c, [pred], [xops.Constant(cond_c, np.array(False))],
                       or_, list(range(cond_jaxpr.out_avals[0].ndim)))

  # --- body computation: consts are passed through unchanged, carry updated.
  body_c = xla_client.XlaBuilder("body_computation")
  body_carry = xla.parameter(body_c, 0, c.get_shape(init_carry))
  body_carry_elts = [xops.GetTupleElement(body_carry, i) for i in range(len(args))]
  x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
  body_ctx = ctx.replace(builder=body_c,
                         name_stack=extend_name_stack(name_stack, 'body'))
  new_z = xla.jaxpr_subcomp(
      body_ctx, body_jaxpr.jaxpr,
      _map(partial(xla.pyval_to_ir_constant, body_c), body_jaxpr.consts),
      *(y + z))
  if batched:
    # Recompute the (unreduced) predicate on the incoming carry and keep the
    # old value wherever it is false.
    body_pred_ctx = body_ctx.replace(
        name_stack=extend_name_stack(name_stack, 'body_pred'))
    body_pred, = xla.jaxpr_subcomp(
        body_pred_ctx, cond_jaxpr.jaxpr,
        _map(partial(xla.pyval_to_ir_constant, body_c), cond_jaxpr.consts),
        *(x + z))
    new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z,
                 body_jaxpr.out_avals)
    assert _map(body_c.get_shape, new_z) == _map(body_c.get_shape, z) # no broadcast
  new_carry = xops.Tuple(body_c, [*x, *y, *new_z])

  ans = xops.While(cond_c.build(pred), body_c.build(new_carry), init_carry)
  ans_elts = [xops.GetTupleElement(ans, i) for i in range(len(args))]
  # Only the carry portion of the final tuple is the primitive's output.
  _, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
  return z
def _pred_bcast_select(c, pred, x, y, x_y_aval: core.AbstractValue):
  """Elementwise ``pred ? x : y`` on the legacy XLA builder ``c``.

  ``pred``'s shape must be a prefix of the (identical) shapes of ``x`` and
  ``y``; it is broadcast along the leading dimensions. Units return ``x``
  unchanged, and tokens are joined with ``AfterAll`` instead of selected.
  """
  pred_dims = c.get_shape(pred).dimensions()
  x_dims = c.get_shape(x).dimensions()
  y_dims = c.get_shape(y).dimensions()
  assert x_dims == y_dims
  if x_y_aval is core.abstract_unit:
    return x
  if x_y_aval is core.abstract_token:
    return xops.AfterAll(c, [x, y])
  prefix_len = len(pred_dims)
  assert pred_dims == x_dims[:prefix_len] == y_dims[:prefix_len], (pred_dims, x_dims, y_dims)
  bcast_pred = xops.BroadcastInDim(pred, x_dims, list(range(prefix_len)))
  return xops.Select(bcast_pred, x, y)
def _while_loop_batching_rule(axis_size, axis_name, main_type, args, dims,
                              cond_nconsts, cond_jaxpr,
                              body_nconsts, body_jaxpr):
  """Batching (vmap) rule for ``while_p``.

  ``args``/``dims`` are laid out as (cond consts, body consts, carry). The rule
  first decides which carry elements must be batched, then re-batches both
  jaxprs and binds a new ``while_p``.

  Returns:
    ``(outs, carry_dims)``: the batched loop outputs and the batch axis of each
    output (``None`` for unbatched outputs).
  """
  orig_batched = [d is not batching.not_mapped for d in dims]
  cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts])
  cconsts, bconsts, init = split_list(args, [cond_nconsts, body_nconsts])
  cconst_dims, bconst_dims, init_dims = split_list(dims, [cond_nconsts, body_nconsts])

  carry_bat = init_bat
  # Fixpoint computation of which carry are batched: either
  # batched from init, or the carry out is batched. Each iteration promotes
  # at least one carry to batched. We need at most len(carry) iterations to
  # reach a fixpoint.
  for _ in range(1 + len(carry_bat)):
    _, carry_bat_out = batching.batch_jaxpr(
        body_jaxpr, axis_size, bconst_bat + carry_bat, instantiate=carry_bat,
        axis_name=axis_name, main_type=main_type)
    if carry_bat == carry_bat_out:
      break
    carry_bat = safe_map(operator.or_, carry_bat, carry_bat_out)
  else:
    assert False, "Fixpoint not reached"

  # Knowing how the carry is batched now, we can determine if the predicate is
  # batched.
  _, (pred_bat,) = batching.batch_jaxpr(
      cond_jaxpr, axis_size, cconst_bat + carry_bat, instantiate=False,
      axis_name=axis_name, main_type=main_type)

  if pred_bat:
    # If the predicate is batched, we have to batch *all* of the carry
    # regardless of if the body needs it.
    carry_bat = [True] * len(carry_bat)
    carry_dims = [0] * len(carry_bat)
    body_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        body_jaxpr, axis_size, bconst_dims + carry_dims,
        carry_dims, axis_name=axis_name, main_type=main_type)
    cond_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        cond_jaxpr, axis_size, cconst_dims + carry_dims, [0],
        axis_name=axis_name, main_type=main_type)
  else:
    # If the predicate is not batched, we can look at the `cond_jaxpr`'s out
    # shape to determine the rank of the predicate. From this rank we pick the
    # dims of the carry to be batched to ensure that the predicate shape is a
    # prefix of the carry in and out shapes. We can then batch the `body_jaxpr`
    # according to these new batch dims.
    cond_rank = len(cond_jaxpr.out_avals[0].shape)
    # NOTE(review): `None` here is compared against batching.not_mapped below,
    # so this assumes not_mapped is None -- confirm.
    carry_dims = [cond_rank if b else None for b in carry_bat]
    body_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        body_jaxpr, axis_size, bconst_dims + carry_dims, carry_dims,
        axis_name=axis_name, main_type=main_type)
    # Now we need to rebatch the `cond_jaxpr` according to the new dims of the
    # carry.
    cond_jaxpr_batched, _ = batching.batch_jaxpr_axes(
        cond_jaxpr, axis_size, cconst_dims + carry_dims, (None,),
        axis_name=axis_name, main_type=main_type)

  # To prepare the `init` to the `while_p`, we broadcast values if they are
  # unbatched and need to have an out axis. If their current batch axis does not
  # match the one it needs to be for the translation rule to work, we move it
  # into place.
  new_init = []
  for x, old_axis, new_axis in zip(init, init_dims, carry_dims):
    if old_axis is batching.not_mapped and new_axis is not batching.not_mapped:
      new_init.append(batching.broadcast(x, axis_size, new_axis))
    elif old_axis is batching.not_mapped and new_axis is batching.not_mapped:
      new_init.append(x)
    else:
      # A batched init always yields a batched carry (fixpoint above starts
      # from init_bat), so new_axis must be mapped here.
      assert new_axis is not batching.not_mapped
      new_init.append(batching.moveaxis(x, old_axis, new_axis))

  outs = while_p.bind(*(cconsts + bconsts + new_init),
                      cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched,
                      body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  return outs, carry_dims
def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
                    body_jaxpr):
  """Forward-mode (JVP) rule for ``while_p``.

  Emits one ``while_p`` whose carry is (primal carry, nonzero carry tangents)
  and whose body consts are (primal body consts, nonzero body-const
  tangents). The cond jaxpr is extended with unused inputs for the tangent
  carry.

  Returns:
    ``(out_carry, out_tangents)``. Tangents that are provably zero are
    returned as symbolic ``ad_util.Zero`` values.
  """
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  cconst_nz, bconst_nz, init_nz = split_list(nonzeros, [cond_nconsts, body_nconsts])

  # Fixpoint over which carry tangents are nonzero: a carry tangent is nonzero
  # if it is nonzero initially or the body can make it nonzero.
  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    body_nonzeros = bconst_nz + carry_nz
    body_jvp, nonzeros_out = ad.jvp_jaxpr(
        body_jaxpr, body_nonzeros, instantiate=carry_nz)
    if nonzeros_out == carry_nz:
      break
    carry_nz = _map(operator.or_, carry_nz, nonzeros_out)
  else:
    assert False, "Fixpoint not reached"

  # Materialize the zero tangents that the fixpoint decided must be carried.
  nonzeros = cconst_nz + body_nonzeros
  tangents = [ad.instantiate_zeros(t) if nz else t
              for t, nz in zip(tangents, nonzeros)]

  cconst, bconst, init = split_list(primals, [cond_nconsts, body_nconsts])
  _, bconst_dot, init_dot = split_list(tangents, [cond_nconsts, body_nconsts])
  bconst_dot = _prune_zeros(bconst_dot)
  init_dot = _prune_zeros(init_dot)

  num_carry = len(primals) - cond_nconsts - body_nconsts

  # Reorder the JVP jaxpr's binders so that all consts (primal then tangent)
  # come before the carry (primal then tangent), matching while_p's layout.
  body_jvp_rearranged = ad.rearrange_binders(
      body_jvp,
      [body_nconsts, num_carry], [len(bconst_dot), len(init_dot)],
      [num_carry], [len(init_dot)])

  # The cond only depends on primals; append dummy invars for the tangents.
  newvar = core.gensym([cond_jaxpr.jaxpr])
  invars_aug = (
      cond_jaxpr.jaxpr.invars + [newvar(core.get_aval(x)) for x in init_dot])
  cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
                                    invars_aug,
                                    cond_jaxpr.jaxpr.outvars,
                                    cond_jaxpr.jaxpr.eqns)
  cond_jaxpr_augmented = core.ClosedJaxpr(cond_jaxpr_augmented, cond_jaxpr.consts)

  out = while_p.bind(
      *(cconst + bconst + bconst_dot + init + init_dot),
      cond_nconsts=cond_nconsts,
      cond_jaxpr=cond_jaxpr_augmented,
      body_nconsts=len(bconst) + len(bconst_dot),
      body_jaxpr=body_jvp_rearranged)

  out_carry, out_carry_dot = split_list(out, [num_carry])
  out_tangents_iter = iter(out_carry_dot)
  out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(out_carry, nonzeros_out)]
  return out_carry, out_tangents
def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int,
                        cond_jaxpr: pe.ClosedJaxpr, body_nconsts: int,
                        body_jaxpr: pe.ClosedJaxpr) -> Sequence[pe.Tracer]:
  """An implementation of partial evaluation for while.

  As long as some carry (and hence output) are known and the output
  of `cond_jaxpr` is known, we use a portion of the loop body to compute the known
  outputs of the `while_loop`. For the unknown outputs we generate Jaxpr to run
  the whole while, including recomputing the known parts.

  This means that we don't actually save any computation by partial
  evaluation if there are unknown outputs.

  What this achieves is that we can give a proper error for reverse
  differentiation of `while`, because in that use of partial evaluation the
  primal inputs are considered "known", and only the tangent computation is
  unknown (see issue #2129).
  """
  unknowns = [not t.pval.is_known() for t in tracers]
  params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
                body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)

  cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts])

  # Fixpoint computation of unknown carry. Each iteration promotes
  # at least one carry to unknown. We need one last iteration to prepare the jaxpr.
  carry_uk = carry_init_uk
  for _ in range(1 + len(carry_uk)):
    body_jaxpr_known, _, carry_out_uk = pe.partial_eval_jaxpr(  # type: ignore
        body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk)
    if carry_out_uk == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_out_uk)
  else:
    assert False, "Fixpoint not reached"

  cond_jaxpr_known, _, cond_uk = pe.partial_eval_jaxpr(  # type: ignore
      cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)

  if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
    # If conditional is unknown, or all inputs are known, or all are unknown,
    # just do the default processing.
    return trace.default_process_primitive(while_p, tracers, params)

  # Run the known part of the while. Prepare the inputs, as constants (if known), or
  # as core.unit.
  in_consts = [ core.unit if uk else t.pval.get_known()
                for uk, t in zip(cond_consts_uk + body_consts_uk + carry_uk,
                                 tracers)]
  # There should be no residuals for the cond_jaxpr_known
  assert 1 == len(cond_jaxpr_known.out_avals)
  # We ignore the residuals from the body_jaxpr_known, so the type of inputs matches
  # the type of outputs; residuals are at the end
  if len(body_jaxpr_known.out_avals) > len(body_jaxpr.out_avals):
    # TODO(necula): this is not quite enough; we should drop the residual computations also
    # NOTE: this mutates body_jaxpr_known's jaxpr in place.
    body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:len(body_jaxpr.out_avals)]
  out_known = while_p.bind(
      *in_consts,
      cond_nconsts=cond_nconsts,
      cond_jaxpr=cond_jaxpr_known,
      body_nconsts=body_nconsts,
      body_jaxpr=body_jaxpr_known)

  # Run the whole while_loop to get all the outputs, then merge with known ones
  out_all: Sequence[pe.Tracer] = trace.default_process_primitive(while_p, tracers, params)
  # Known outputs reuse the full loop's recipe but carry the known value.
  out_tracers: Sequence[pe.Tracer] = [
      out_unknown if uk
      else pe.JaxprTracer(trace, pe.PartialVal.known(known), out_unknown.recipe)
      for uk, out_unknown, known in zip(carry_uk, out_all, out_known)]

  return out_tracers
def _while_transpose_error(*_, **kwargs):
raise ValueError("Reverse-mode differentiation does not work for "
"lax.while_loop or lax.fori_loop. "
"Try using lax.scan instead.")
# The `while` primitive. It returns a list of outputs (one per carry element)
# and is registered with each transformation/lowering rule defined above.
while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_translation(while_p, _while_loop_translation_rule,
                         initial_style=True)
# Transposition always raises: reverse-mode AD of while_loop is unsupported.
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
# Per its name, this custom partial-eval rule is a "not implemented" stub
# (presumably it raises when used -- confirm in partial_eval).
pe.partial_eval_jaxpr_custom_rules[while_p] = \
    partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'while_loop')
def _pred_bcast_select_mhlo(
    pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
    ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
  """MHLO version of an elementwise ``pred ? x : y`` for one carry element.

  ``xs``/``ys`` hold the IR values of a single aval. Units lower to no values,
  and tokens are joined with ``AfterAll``. For arrays, ``pred``'s shape must be
  a prefix of ``x_y_aval``'s shape; ``pred`` is broadcast along the leading
  dimensions before the select.
  """
  if x_y_aval is core.abstract_unit:
    return []
  if x_y_aval is core.abstract_token:
    [x] = xs
    [y] = ys
    token = mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result
    return [token]
  assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
  [x] = xs
  [y] = ys
  assert x.type == y.type, (x.type, y.type)
  pred_rank = len(pred_aval.shape)
  assert (pred_aval.shape == x_y_aval.shape[:pred_rank]), (
      pred_aval.shape, x_y_aval)
  bool_type = mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_)))
  bcast_pred = mhlo.BroadcastInDimOp(
      bool_type, pred, mlir.dense_int_elements(list(range(pred_rank)))).result
  return mhlo.SelectOp(bcast_pred, x, y).results
def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
                    body_nconsts):
  """MLIR lowering of ``while_p`` to a single ``mhlo.WhileOp``.

  ``args`` is laid out as (cond consts, body consts, carry). All of them become
  loop-carried values of the WhileOp, but only the carry portion of its
  results is returned. A non-scalar (batched) predicate is OR-reduced to a
  scalar in the cond region. In the body region, the unreduced predicate
  selects per element between the new and the old carry.
  """
  pred_aval = cond_jaxpr.out_avals[0]
  batched = bool(pred_aval.shape)

  # Each aval may lower to several IR types; keep both the nested and the
  # flat form so values can be regrouped per aval.
  loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
  flat_loop_carry_types = util.flatten(loop_carry_types)

  flat_args = mlir.flatten_lowering_ir_args(args)
  while_op = mhlo.WhileOp(flat_loop_carry_types, flat_args)

  # Loop condition
  cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
  with ir.InsertionPoint(cond_block):
    flat_cond_args = [
        cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
    ]
    cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
    x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
    cond_ctx = ctx.module_context.replace(
        name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
                                         'cond'))
    (pred,), = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr,
                                  _map(mlir.ir_constants, cond_jaxpr.consts),
                                  *(x + z))
    if batched:
      # OR-reduce the batched predicate over all of its axes to a scalar.
      pred_ctx = mlir.LoweringRuleContext(
          module_context=ctx.module_context,
          primitive=None,
          avals_in=[pred_aval],
          avals_out=[pred_aval.update(shape=())])
      pred, = lax._unary_reduce_lower(
          mhlo.OrOp,
          lambda dtype: np.array(False, dtype),
          pred_ctx,
          pred,
          axes=tuple(range(len(pred_aval.shape))))
    mhlo.ReturnOp([pred])

  # Loop body
  body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types)
  with ir.InsertionPoint(body_block):
    flat_body_args = [
        body_block.arguments[i] for i in range(len(flat_loop_carry_types))
    ]
    body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
    x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
    body_ctx = ctx.module_context.replace(
        name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
                                         'body'))
    new_z = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
                               _map(mlir.ir_constants, body_jaxpr.consts),
                               *(y + z))
    if batched:
      # Recompute the per-element predicate on the incoming carry and keep
      # the old carry wherever it is false.
      body_pred_ctx = ctx.module_context.replace(
          name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
                                           'body_pred'))
      (body_pred,), = mlir.jaxpr_subcomp(
          body_pred_ctx, cond_jaxpr.jaxpr,
          _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z))
      new_z = _map(
          partial(_pred_bcast_select_mhlo, pred_aval, body_pred), new_z, z,
          body_jaxpr.out_avals)

    # Consts (x, y) are passed through unchanged; only the carry is updated.
    mhlo.ReturnOp([*util.flatten(x), *util.flatten(y), *util.flatten(new_z)])

  outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
  _, _, z = util.split_list(outputs, [cond_nconsts, body_nconsts])
  return z
# Register the MLIR (MHLO) lowering of while_p.
mlir.register_lowering(while_p, _while_lowering)
### cond and switch
# For backward compatibility with a previous switch/cond calling convention,
# we allow a single (pytree) `operand` argument to be passed by keyword. We use
# a sentinel object as its default value to indicate when it is _not_ passed
# (None cannot serve this purpose, since None is itself a valid pytree operand).
_no_operand_sentinel = object()
@api_boundary
def switch(index, branches: Sequence[Callable], *operands,
           operand=_no_operand_sentinel):
  """Apply exactly one of ``branches`` given by ``index``.

  If ``index`` is out of bounds, it is clamped to within bounds.

  Has the semantics of the following Python::

    def switch(index, branches, *operands):
      index = clamp(0, index, len(branches) - 1)
      return branches[index](*operands)

  Args:
    index: Integer scalar type, indicating which branch function to apply.
    branches: Sequence of functions (A -> B) to be applied based on ``index``.
    operands: Operands (A) input to whichever branch is applied.
    operand: Legacy keyword form, kept for backward compatibility, that passes
      a single (pytree) operand. If given, no positional ``operands`` may be
      passed.

  Returns:
    Value (B) of ``branch(*operands)`` for the branch that was selected based
    on ``index``.

  Raises:
    TypeError: if ``operand`` is passed together with positional operands, or
      if ``index`` is not an integer scalar.
    ValueError: if ``branches`` is empty.
  """
  # Legacy convention: a single pytree operand passed by keyword.
  if operand is not _no_operand_sentinel:
    if operands:
      raise TypeError("if 'operand' keyword is passed then no positional "
                      f"operands can be passed, got operand={operand} "
                      f"and positional operands {operands}")
    operands = (operand,)
  del operand

  # Validate the branch index: it must be a scalar of integer dtype.
  if len(np.shape(index)) != 0:
    raise TypeError(
        "Branch index must be scalar, "
        f"got {index} of shape {np.shape(index)}.")
  try:
    index_dtype = dtypes.result_type(index)
  except TypeError as err:
    msg = f"Index type must be an integer, got {index}."
    raise TypeError(msg) from err
  if index_dtype.kind not in 'iu':
    raise TypeError(
        f"Index type must be an integer, got {index} as {index_dtype}")

  branches = tuple(branches)
  if len(branches) == 0:
    raise ValueError("Empty branch sequence")
  elif len(branches) == 1:
    # Only one possible branch: no need to stage out a conditional.
    return branches[0](*operands)

  # Clamp out-of-range indices into [0, len(branches) - 1].
  index = lax.convert_element_type(index, np.int32)
  lo = np.array(0, np.int32)
  hi = np.array(len(branches) - 1, np.int32)
  index = lax.clamp(lo, index, hi)

  # With jit disabled and a concrete index, just call the chosen branch.
  if (config.jax_disable_jit and
      isinstance(core.get_aval(index), ConcreteArray)):
    return branches[int(index)](*operands)

  ops, ops_tree = tree_flatten(operands)
  ops_avals = tuple(_map(_abstractify, ops))

  jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
      branches, ops_tree, ops_avals, primitive_name='switch')
  # Every branch must produce the same output pytree structure and avals.
  for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
    _check_tree_and_avals(f"branch 0 and {i + 1} outputs",
                          out_trees[0], jaxprs[0].out_avals,
                          out_tree, jaxpr.out_avals)

  linear = (False,) * (len(consts) + len(ops))
  out = cond_p.bind(
      index, *consts, *ops, branches=tuple(jaxprs), linear=linear)
  return tree_unflatten(out_trees[0], out)
def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
          operand=_no_operand_sentinel, linear=None):
  """Conditionally apply ``true_fun`` or ``false_fun``.

  ``cond()`` has equivalent semantics to this Python implementation::

    def cond(pred, true_fun, false_fun, *operands):
      if pred:
        return true_fun(*operands)
      else:
        return false_fun(*operands)

  ``pred`` must be a scalar type. Integer and floating-point predicates are
  also accepted and are interpreted as ``pred != 0``.

  Args:
    pred: Boolean scalar type, indicating which branch function to apply.
    true_fun: Function (A -> B), to be applied if ``pred`` is True.
    false_fun: Function (A -> B), to be applied if ``pred`` is False.
    operands: Operands (A) input to either branch depending on ``pred``. The
      type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
      thereof.

  Returns:
    Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
    depending on the value of ``pred``. The type can be a scalar, array, or any
    pytree (nested Python tuple/list/dict) thereof.
  """
  # Legacy convention: a single pytree operand passed by keyword.
  if operand is not _no_operand_sentinel:
    if operands:
      raise TypeError("if 'operand' keyword is passed then no positional "
                      f"operands can be passed, got operand={operand} "
                      f"and positional operands {operands}")
    operands = (operand,)
  del operand

  # The predicate must be a rank-0 value, not a Python collection.
  pred_is_seq = isinstance(pred, Sequence)
  if pred_is_seq or np.ndim(pred) != 0:
    detail = (f"type {type(pred)}" if pred_is_seq
              else f"shape {np.shape(pred)}.")
    raise TypeError(f"Pred must be a scalar, got {pred} of " + detail)

  bad_pred_msg = "Pred type must be either boolean or number, got {}."
  try:
    pred_dtype = dtypes.result_type(pred)
  except TypeError as err:
    raise TypeError(bad_pred_msg.format(pred)) from err
  # Booleans pass through; integer/float predicates become a boolean test.
  if pred_dtype.kind != 'b':
    if pred_dtype.kind not in 'iuf':
      raise TypeError(bad_pred_msg.format(pred_dtype))
    pred = pred != 0

  # With jit disabled and a concrete predicate, just call the chosen branch.
  if config.jax_disable_jit and isinstance(core.get_aval(pred), ConcreteArray):
    return true_fun(*operands) if pred else false_fun(*operands)

  ops, ops_tree = tree_flatten(operands)
  if linear is None:
    linear_ops = [False] * len(ops)
  else:
    linear_ops, linear_tree = tree_flatten(linear)
    if ops_tree != linear_tree:
      raise TypeError('linear tree and operand tree mismatch')
  ops_avals = tuple(_map(_abstractify, ops))

  (true_jaxpr, false_jaxpr), consts, (out_tree, false_out_tree) = (
      _initial_style_jaxprs_with_common_consts(
          (true_fun, false_fun), ops_tree, ops_avals, 'cond'))
  _check_tree_and_avals("true_fun and false_fun output",
                        out_tree, true_jaxpr.out_avals,
                        false_out_tree, false_jaxpr.out_avals)

  # cond_p selects branches by integer index, so False (0) picks false_jaxpr
  # and True (1) picks true_jaxpr. Hoisted consts are never marked linear.
  index = lax.convert_element_type(pred, np.int32)
  all_linear = (False,) * len(consts) + tuple(linear_ops)
  out = cond_p.bind(index, *consts, *ops,
                    branches=(false_jaxpr, true_jaxpr), linear=all_linear)
  return tree_unflatten(out_tree, out)
@api_boundary
@functools.wraps(_cond)
def cond(*args, **kwargs):
  # Detect an attempt to use the former, deprecated calling convention
  # ``cond(pred, true_operand, true_fun, false_operand, false_fun)``. If the
  # arguments bind to that signature and both function slots hold callables,
  # dispatch to the legacy implementation; otherwise fall through to _cond.
  # (No docstring here: functools.wraps copies _cond's docstring.)
  try:
    legacy = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs)
  except TypeError:
    legacy = None

  if legacy is not None:
    assert not legacy.kwargs  # no catch-all **kwargs in _cond_with_per_branch
    _pred, _t_operand, true_slot, _f_operand, false_slot = legacy.args
    if callable(true_slot) and callable(false_slot):
      return _cond_with_per_branch_args(*legacy.args)

  return _cond(*args, **kwargs)
def _cond_with_per_branch_args(pred,
                               true_operand, true_fun: Callable,
                               false_operand, false_fun: Callable):
  """Conditionally apply ``true_fun`` or ``false_fun`` (legacy signature).

  Has equivalent semantics to this Python implementation::

    def cond(pred, true_operand, true_fun, false_operand, false_fun):
      if pred:
        return true_fun(true_operand)
      else:
        return false_fun(false_operand)

  Pred has to be a scalar type, collection types (list, tuple) are not supported
  """
  # Both operands travel to _cond as a single pair; each branch then picks
  # out only the element that belongs to it.
  paired_operands = (true_operand, false_operand)
  return _cond(pred,
               lambda pair: true_fun(pair[0]),
               lambda pair: false_fun(pair[1]),
               paired_operands)
def _cond_abstract_eval(*args, **kwargs):
  """Abstract evaluation rule for ``cond_p``.

  The result avals are taken from the first branch's ``out_avals``, each
  passed through ``raise_to_shaped``. The positional operands are not
  inspected.
  """
  del args  # Only the ``branches`` parameter is consulted.
  first_branch = kwargs["branches"][0]
  return _map(raise_to_shaped, first_branch.out_avals)
def _cond_translation_rule(ctx, avals_in, avals_out, index, *args, branches,
                           linear):
  """XLA translation rule for ``cond_p``.

  Steps:

  * All operands are packed into one XLA tuple.
  * Each branch jaxpr is built into its own XLA computation, named
    ``branch_{i}_comp``, whose single parameter is that tuple.
  * An ``xops.Conditional`` chooses among those computations using ``index``.
  * The conditional's result is returned via ``xla.xla_destructure``.
  """
  del linear  # Unused.
  cond_stack = extend_name_stack(ctx.name_stack, "cond")

  def build_branch(branch_label, closed_jaxpr, packed_shape):
    # Build one branch: unpack the operand tuple, lower the jaxpr body, and
    # re-pack its outputs as the computation's tuple result.
    sub_builder = xla_client.XlaBuilder(branch_label + '_comp')
    packed_param = xla.parameter(sub_builder, 0, packed_shape)
    branch_inputs = [xops.GetTupleElement(packed_param, pos)
                     for pos in range(len(closed_jaxpr.in_avals))]
    branch_ctx = ctx.replace(
        builder=sub_builder,
        name_stack=extend_name_stack(cond_stack, branch_label + '_fun'))
    const_ops = _map(partial(xla.pyval_to_ir_constant, sub_builder),
                     closed_jaxpr.consts)
    branch_outs = xla.jaxpr_subcomp(
        branch_ctx, closed_jaxpr.jaxpr, const_ops, *branch_inputs)
    return sub_builder.build(xops.Tuple(sub_builder, branch_outs))

  builder = ctx.builder
  packed = xops.Tuple(builder, args)
  packed_shape = builder.get_shape(packed)

  computations = []
  for branch_idx, closed_jaxpr in enumerate(branches):
    computations.append(
        build_branch(f'branch_{branch_idx}', closed_jaxpr, packed_shape))

  # Every branch receives the same packed operand tuple.
  conditional = xops.Conditional(index, computations, [packed] * len(branches))
  return xla.xla_destructure(builder, conditional)
def _bcast_select(pred, on_true, on_false):
if np.ndim(pred) != np.ndim(on_true):
idx = list(range(np.ndim(pred)))
pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx)
return lax.select(pred, on_true, on_false)
def _bcast_select_n(pred, *cases):
if np.ndim(pred) != np.ndim(cases[0]):
idx = list(range(np.ndim(pred)))
pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx)
return lax.select_n(pred, *cases)
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear):
index, *ops = args
index_dim, *op_dims = dims
if index_dim is not batching.not_mapped:
# Convert to a lax.select. While we could get away with not broadcasting
# some operands yet, because all outputs must be broadcast together anyway
# for the select we broadcast the input operands for simplicity and leave
# optimizations to XLA.
# TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
index, *ops = [
batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims)]
in_batched = [a is not core.abstract_unit for a in branches[0].in_avals]
out_batched = [a is not core.abstract_unit for a in branches[0].out_avals]
branches_batched = [
batching.batch_jaxpr(
jaxpr, axis_size, in_batched, out_batched, axis_name, main_type)[0]
for jaxpr in branches]
branch_outs = []
for i, jaxpr in enumerate(branches_batched):
# Perform a select on the inputs for safety of reverse-mode autodiff; see
# https://github.com/google/jax/issues/1052
predicate = lax.eq(index, lax._const(index, i))
ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x))
if x is not core.unit else x for x in ops]
branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
out = [_bcast_select_n(index, *outs) if outs[0] is not core.unit else outs[0]
for outs in zip(*branch_outs)]