-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
loops.py
2001 lines (1734 loc) · 86.2 KB
/
loops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the loop primitives."""
from functools import partial
import itertools
import operator
from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar
import jax
import weakref
from jax import core
from jax import linear_util as lu
from jax.config import config
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
import jax._src.pretty_printer as pp
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
tree_map)
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.traceback_util import api_boundary
from jax._src.util import (
cache,
extend_name_stack,
partition_list,
safe_map,
safe_zip,
split_list,
unzip2,
weakref_lru_cache,
)
import numpy as np
from jax._src.lax.control_flow.common import (
_abstractify,
_avals_short,
_check_tree_and_avals,
_initial_style_jaxpr,
_make_closed_jaxpr,
_prune_zeros,
_typecheck_param,
allowed_effects,
)
# Module-wide aliases for the jax._src.util helpers `safe_map` / `safe_zip`
# (presumably map/zip variants that check argument lengths -- verify in
# jax._src.util). Rebinding `zip` shadows the builtin for the rest of this
# module, so every `zip(...)` below is `safe_zip`.
_map = safe_map
zip = safe_zip
# Generic type variable for annotations in this module.
T = TypeVar('T')
# Loose aliases used only for type annotations.
Array = Any
BooleanNumeric = Any  # A bool, or a Boolean array.
### Helper functions
def _promote_weak_typed_inputs(in_vals, in_avals, out_avals):
"""Promote weakly-typed in_vals to be compatible with out_avals.
Args:
in_vals : flattened list of input values.
in_avals : corresponding list of avals.
out_avals : list of target output avals.
Returns:
in_vals_new : flattened list of modified in_vals with no weak types.
changed : bool; true if in_vals required modification.
"""
if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals):
# Calling function is responsible for catching this.
return in_vals, False
weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals))
if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)]
if not weak_mismatches:
return in_vals, False
for i in weak_mismatches:
new_dtype = dtypes.result_type(in_vals[i], out_avals[i])
in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype)
return in_vals, True
### scan

# Type variables naming the pytree types in `scan`'s signature:
#   f :: (Carry, X) -> (Carry, Y)
Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')
@api_boundary
def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
         init: Carry,
         xs: X,
         length: Optional[int] = None,
         reverse: bool = False,
         unroll: int = 1) -> Tuple[Carry, Y]:
  """Scan a function over leading array axes while carrying along state.

  The `Haskell-like type signature`_ in brief is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  where we use [t] here to denote the type t with an additional leading axis.
  That is, if t is an array type then [t] represents the type with an additional
  leading axis, and if t is a pytree (container) type with array leaves then [t]
  represents the type with the same pytree structure and corresponding leaves
  each with an additional leading axis.

  When ``a`` is an array type or None, and ``b`` is an array type, the semantics
  of ``scan`` are given roughly by this Python implementation::

    def scan(f, init, xs, length=None):
      if xs is None:
        xs = [None] * length
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  Unlike that Python version, both ``a`` and ``b`` may be arbitrary pytree
  types, and so multiple arrays can be scanned over at once and produce multiple
  output arrays. (None is actually an empty pytree.)

  Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
  a single XLA While HLO. That makes it useful for reducing compilation times
  for jit-compiled functions, since native Python loop constructs in an ``@jit``
  function are unrolled, leading to large XLA computations.

  Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
  across all iterations (and not just be consistent up to NumPy rank/shape
  broadcasting and dtype promotion rules, for example). In other words, the type
  ``c`` in the type signature above represents an array with a fixed shape and
  dtype (or a nested tuple/list/dict container data structure with a fixed
  structure and arrays with fixed shape and dtype at the leaves).

  .. note::
    :py:func:`scan` compiles ``f``, so while it can be combined with
    :py:func:`jit`, it's usually unnecessary.

  Args:
    f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
      that ``f`` accepts two arguments where the first is a value of the loop
      carry and the second is a slice of ``xs`` along its leading axis, and that
      ``f`` returns a pair where the first element represents a new value for
      the loop carry and the second represents a slice of the output.
    init: an initial loop carry value of type ``c``, which can be a scalar,
      array, or any pytree (nested Python tuple/list/dict) thereof, representing
      the initial loop carry value. This value must have the same structure as
      the first element of the pair returned by ``f``.
    xs: the value of type ``[a]`` over which to scan along the leading axis,
      where ``[a]`` can be an array or any pytree (nested Python
      tuple/list/dict) thereof with consistent leading axis sizes. Every leaf
      must be an array with at least one axis.
    length: optional integer specifying the number of loop iterations, which
      must agree with the sizes of leading axes of the arrays in ``xs`` (but can
      be used to perform scans where no input ``xs`` are needed).
    reverse: optional boolean specifying whether to run the scan iteration
      forward (the default) or in reverse, equivalent to reversing the leading
      axes of the arrays in both ``xs`` and in ``ys``.
    unroll: optional positive int specifying, in the underlying operation of the
      scan primitive, how many scan iterations to unroll within a single
      iteration of a loop.

  Returns:
    A pair of type ``(c, [b])`` where the first element represents the final
    loop carry value and the second element represents the stacked outputs of
    the second output of ``f`` when scanned over the leading axis of the inputs.

  Raises:
    TypeError: if ``f`` is not callable, or its output is not a pair.
    ValueError: if a leaf of ``xs`` has no leading axis (no ``shape`` at all,
      or a 0-d array); if the leading axis sizes disagree with each other or
      with ``length``; if there is nothing to scan over and ``length`` is not
      given; or for a zero-length scan under ``disable_jit()``.

  .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
  """
  if not callable(f):
    raise TypeError("lax.scan: f argument should be a callable.")
  xs_flat, xs_tree = tree_flatten(xs)

  try:
    lengths = [x.shape[0] for x in xs_flat]
  except (AttributeError, IndexError) as err:
    # AttributeError: a leaf has no `.shape` at all (e.g. a Python scalar).
    # IndexError: a leaf is 0-d, so its shape `()` has no leading axis.
    # Both mean there is nothing to scan over; report every such leaf.
    msg = "scan got value with no leading axis to scan over: {}."
    raise ValueError(
      msg.format(', '.join(str(x) for x in xs_flat
                           if not getattr(x, 'shape', None)))) from err

  if length is not None:
    length = int(length)
    if not all(length == n for n in lengths):
      msg = ("scan got `length` argument of {} which disagrees with "
             "leading axis sizes {}.")
      raise ValueError(msg.format(length, lengths))
  else:
    unique_lengths = set(lengths)
    if len(unique_lengths) > 1:
      msg = "scan got values with different leading axis sizes: {}."
      raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
    elif len(unique_lengths) == 0:
      msg = "scan got no values to scan over and `length` not provided."
      raise ValueError(msg)
    else:
      length, = unique_lengths

  if config.jax_disable_jit:
    # Eager Python loop. The output type is only known from calling `f`, so a
    # zero-length scan cannot be supported here.
    if length == 0:
      raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
    carry = init
    ys = []
    maybe_reversed = reversed if reverse else lambda x: x
    for i in maybe_reversed(range(length)):
      xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
      carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
      ys.append(y)
    stack = lambda *ys: jax.numpy.stack(ys)
    # Undo the reversal so ys[k] corresponds to xs[k] in both directions.
    stacked_y = tree_map(stack, *maybe_reversed(ys))
    return carry, stacked_y

  xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat]
  x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]

  def _create_jaxpr(init):
    # Trace `f` on abstract (carry, x-slice) values and check it returns a pair.
    init_flat, init_tree = tree_flatten(init)
    in_flat, in_tree = tree_flatten((init, xs))

    carry_avals = tuple(_map(_abstractify, init_flat))
    jaxpr, consts, out_tree = _initial_style_jaxpr(
        f, in_tree, (*carry_avals, *x_avals), "scan")
    out_tree_children = out_tree.children()
    if len(out_tree_children) != 2:
      msg = "scan body output must be a pair, got {}."
      raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
    carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves]
    return init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children

  # The carry input and output avals must match exactly. However, we want to account for
  # the case when init contains weakly-typed values (e.g. Python scalars), with avals that
  # may not match the output despite being compatible by virtue of their weak type.
  # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
  # necessary, a second time with modified init values.
  init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out)
  if changed:
    new_init = tree_unflatten(init_tree, new_init_flat)
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init)
  in_flat, jaxpr, consts, out_tree, out_tree_children = rest

  _check_tree_and_avals("scan carry output and input",
                        # Extract the subtree and avals for the first element of the return tuple
                        out_tree_children[0], carry_avals_out,
                        init_tree, carry_avals)
  disallowed_effects = jaxpr.effects - allowed_effects
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `scan`: {disallowed_effects}')

  out = scan_p.bind(*consts, *in_flat,
                    reverse=reverse, length=length, jaxpr=jaxpr,
                    num_consts=len(consts), num_carry=len(init_flat),
                    linear=(False,) * (len(consts) + len(in_flat)),
                    unroll=unroll)
  return tree_unflatten(out_tree, out)
def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
                        f_impl, x_avals, y_avals):
  """Run ``f_impl`` ``length`` times as straight-line Python (fully unrolled).

  Each step takes a statically-indexed slice of every ``xs`` operand. The
  per-step ``ys`` are then stacked along a new leading axis in iteration
  order. ``linear`` is accepted only to keep the signature aligned with the
  other ``_scan_impl_*`` helpers; it is not used here.
  """
  consts, carry, xs = split_list(args, [num_consts, num_carry])
  step_outputs = []
  for step in range(length):
    idx = (length - 1 - step) if reverse else step
    x_slices = _map(partial(_index_array, idx), x_avals, xs)
    carry, y = split_list(f_impl(*consts, *carry, *x_slices), [num_carry])
    step_outputs.append(y)
  if reverse:
    # Steps ran back to front; restore index order before stacking.
    step_outputs.reverse()
  per_output = list(zip(*step_outputs))
  stacked = _map(_stack, y_avals, per_output)
  return (*carry, *stacked)
def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
                    f_impl, x_avals, y_avals):
  """Implement a scan as a ``while_loop`` over an explicit iteration counter.

  The loop state is ``[i, *carry, *ys]``. The ``ys`` are buffers of leading
  size ``length`` (from ``_empty_array``), and each iteration writes one row
  into them. A zero-length scan skips the loop and returns the initial carry
  together with those (empty) buffers. ``linear`` is unused.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])
  ys_init = _map(partial(_empty_array, length), y_avals)
  if length == 0:
    return init + ys_init

  def cond_fun(state):
    return state[0] < length

  def body_fun(state):
    [step], carry, ys = split_list(state, [1, num_carry])
    idx = length - step - 1 if reverse else step
    x = _map(partial(_dynamic_index_array, idx), x_avals, xs)
    carry_out, y_updates = split_list(f_impl(*consts, *carry, *x), [num_carry])
    ys_out = _map(partial(_update_array, idx), y_avals, ys, y_updates)
    return [step + 1, *carry_out, *ys_out]

  loop_init = [lax._const(length, 0), *init, *ys_init]
  _, *outs = while_loop(cond_fun, body_fun, loop_init)
  return outs
def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
                              linear, block_length, f_impl, x_avals, y_avals):
  """Run a scan as a rolled loop over blocks of ``block_length`` unrolled steps.

  ``length`` must be a multiple of ``block_length``. Each ``xs`` operand is
  reshaped to ``(num_blocks, block_length, ...)``. Every loop iteration runs
  one block through ``_scan_impl_unrolled``, and the blocked ``ys`` are
  collapsed back to a single leading axis at the end.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])
  num_blocks, leftover = divmod(length, block_length)
  assert leftover == 0

  def with_block_dim(aval):
    return _prepend_dim_to_aval(block_length, aval)

  blocked_xs = _map(partial(_partition_leading, num_blocks, block_length),
                    x_avals, xs)
  run_one_block = partial(
      _scan_impl_unrolled, reverse=reverse, length=block_length,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
  outs = _scan_impl_loop(
      *consts, *init, *blocked_xs, reverse=reverse, length=num_blocks,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=run_one_block, x_avals=_map(with_block_dim, x_avals),
      y_avals=_map(with_block_dim, y_avals))
  carry, blocked_ys = split_list(outs, [num_carry])
  ys = _map(partial(_combine_leading, num_blocks, block_length),
            y_avals, blocked_ys)
  return (*carry, *ys)
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
               unroll):
  """Evaluate a scan of ``jaxpr`` by expressing it with loops.

  With ``unroll == 1`` this is a single ``while_loop``. Otherwise
  ``length // unroll`` blocks of ``unroll`` unrolled steps run in a loop. The
  leftover ``length % unroll`` steps are then fully unrolled, and their ``ys``
  are concatenated with the blocked ones in index order. Because the leftover
  steps are visited last, they come from the end of ``xs`` for a forward scan
  and from the start of ``xs`` for a reverse scan.
  """
  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  _, y_avals = split_list(jaxpr.out_avals, [num_carry])
  f_impl = core.jaxpr_as_fun(jaxpr)
  common = dict(num_consts=num_consts, num_carry=num_carry, linear=linear,
                f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)

  if unroll == 1:
    return _scan_impl_loop(*args, reverse=reverse, length=length, **common)

  consts, init, xs = split_list(args, [num_consts, num_carry])
  num_blocks, rem = divmod(length, unroll)
  length_div = num_blocks * unroll
  if rem > 0:
    cut = rem if reverse else length_div
    head, tail = unzip2(_map(partial(_split_leading_dim, cut), x_avals, xs))
    xs, xs_rem = (tail, head) if reverse else (head, tail)

  outs = _scan_impl_block_unrolled(
      *consts, *init, *xs, reverse=reverse, length=length_div,
      block_length=unroll, **common)
  carry, ys = split_list(outs, [num_carry])

  if rem > 0:
    outs = _scan_impl_unrolled(
        *consts, *carry, *xs_rem, reverse=reverse, length=rem, **common)
    carry, ys_rem = split_list(outs, [num_carry])
    pieces = (ys_rem, ys) if reverse else (ys, ys_rem)
    ys = _map(_concatenate, y_avals, *pieces)
  return (*carry, *ys)
def _stack(aval, vals):
vals = [lax.expand_dims(x, (0,)) for x in vals]
return lax.concatenate(vals, 0)
def _concatenate(aval, x1, x2):
return lax.concatenate([x1, x2], 0)
def _split_leading_dim(i, aval, x):
assert x.ndim >= 1
return (slicing.slice_in_dim(x, 0, i),
slicing.slice_in_dim(x, i, x.shape[0]))
def _dynamic_index_array(i, aval, x):
return slicing.dynamic_index_in_dim(x, i, keepdims=False)
def _index_array(i, aval, x):
return slicing.index_in_dim(x, i, keepdims=False)
def _empty_array(sz, aval):
return lax.broadcast(lax.empty(aval.dtype), (sz, *aval.shape))
def _update_array(i, aval, xs, x):
return slicing.dynamic_update_index_in_dim(xs, x, i, 0)
def _partition_leading(sz0, sz1, aval, x):
assert x.ndim >= 1
assert x.shape[0] == sz0 * sz1
return lax.reshape(x, (sz0, sz1, *x.shape[1:]))
def _combine_leading(sz0, sz1, aval, x):
assert x.ndim >= 2
assert x.shape[0] == sz0
assert x.shape[1] == sz1
return lax.collapse(x, 0, 2)
def _prepend_dim_to_aval(sz, aval):
  """Return ``aval`` with a new leading (unmapped) axis of size ``sz``."""
  leading_axis = 0
  return core.unmapped_aval(sz, core.no_axis_name, leading_axis, aval)
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
                        linear, unroll):
  """Compute a scan's output avals and effects from its body ``jaxpr``.

  Carry avals pass through unchanged. Each per-iteration ``y`` aval gains a
  leading axis of size ``length``. The effects are the body jaxpr's effects.
  """
  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  stacked_y_avals = [_prepend_dim_to_aval(length, a) for a in y_avals]
  return carry_avals + stacked_y_avals, jaxpr.effects
def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
              linear, unroll):
  """Forward-mode (JVP) rule for a scan: one scan over the JVP of the body.

  Args:
    primals: flat primal inputs, ordered [consts, init carry, xs].
    tangents: matching tangents; symbolic zeros are `ad_util.Zero` instances.
    reverse, length, jaxpr, num_consts, num_carry, linear, unroll: parameters
      of the scan being differentiated.

  Returns:
    (primals_out, tangents_out), both ordered [carry, ys]. Output tangents
    that remain symbolically zero are returned as `ad_util.Zero`.
  """
  num_xs = len(jaxpr.in_avals) - num_carry - num_consts
  num_ys = len(jaxpr.out_avals) - num_carry
  nonzeros = [type(t) is not ad_util.Zero for t in tangents]
  const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry])

  # Fixpoint computation of which carry are not ad.zero: either
  # non-zero from init, or the carry out is non-zero. Each iteration promotes
  # at least one carry to non-zero. We need at most len(carry) iterations,
  # but we need one last iteration to prepare the jaxpr based on the final
  # carry_nz.
  carry_nz = init_nz
  for _ in range(1 + len(carry_nz)):
    nonzeros = const_nz + carry_nz + xs_nz
    # Request instantiated tangents for every carry output currently marked
    # non-zero, so carry tangents have the same structure in and out.
    jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
        jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys)
    carry_nz_out, _ = nonzeros_out[:num_carry], nonzeros_out[num_carry:]
    if carry_nz_out == carry_nz:
      break
    else:
      carry_nz = _map(operator.or_, carry_nz, carry_nz_out)
  else:
    assert False, "Fixpoint not reached"

  # Inputs the fixpoint marked non-zero (including carries promoted from a
  # symbolic zero) get concrete zero tangents.
  tangents = [ad.instantiate_zeros(t) if nz else t
              for t, nz in zip(tangents, nonzeros)]

  consts, init, xs = split_list(primals, [num_consts, num_carry])
  all_tangents = split_list(tangents, [num_consts, num_carry])
  # Presumably drops the remaining symbolic-zero tangents in each group.
  consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents)

  # Regroup the JVP jaxpr's binders so they line up with the operand order
  # passed to scan_p.bind below (consts, consts_dot, init, init_dot, xs,
  # xs_dot) and with the output order split below (carry, carry_dot, ys,
  # ys_dot).
  jaxpr_jvp_rearranged = ad.rearrange_binders(
      jaxpr_jvp,
      [num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)],
      [num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)])

  # Tangent operands are always linear; primal operands keep their flags.
  consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
  jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot)
                           + init_linear + [True] * len(init_dot)
                           + xs_linear + [True] * len(xs_dot))

  out_flat = scan_p.bind(
      *(consts + consts_dot + init + init_dot + xs + xs_dot),
      reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
      num_consts=num_consts + len(consts_dot),
      num_carry=num_carry + len(init_dot),
      linear=jaxpr_jvp_linear, unroll=unroll)

  carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
  primals_out = carry + ys
  # Only non-zero output tangents were produced; rebuild the full list with
  # symbolic zeros in the other positions.
  tangents_out_iter = iter(carry_dot + ys_dot)
  tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p)
                  for p, nz in zip(primals_out, nonzeros_out)]
  return primals_out, tangents_out
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
                       jaxpr, linear, unroll):
  """Partial-evaluation rule for a scan.

  Splits the scan into two scans:
    * a "known" scan, bound immediately on the known inputs (scan_p.bind), and
    * an "unknown" scan, staged as a new equation recipe on `trace` whose
      outputs are fresh unknown JaxprTracers.
  Residuals flow from the known scan to the unknown one. Loop-invariant
  ("intensive") residuals are hoisted out as constants. Per-iteration
  ("extensive") residuals are passed as scanned inputs.

  Returns:
    The scan outputs ordered [carry, ys]: known values where the output is
    known, unknown tracers elsewhere (merged according to `out_uk`).
  """
  num_ys = len(jaxpr.out_avals) - num_carry
  unknowns = [not t.pval.is_known() for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  # Fixpoint computation of which carry elements are unknown. Each iteration
  # promotes at least one carry to unknown. We need at most len(carry)
  # iterations, but we need one last iteration to prepare the jaxpr based on the
  # final carry_uk.
  carry_uk = init_uk
  for _ in range(1 + len(carry_uk)):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out, ys_uk = split_list(out_uk, [num_carry])
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
  else:
    assert False, "Fixpoint not reached"
  num_res = len(res_avals)
  del res_avals, carry_uk_out

  # Instantiate those inputs which must be treated as unknown from the fixpoint.
  tracers = [trace.instantiate_const(t) if uk else t
             for t, uk in zip(tracers, unknowns)]

  # The residual inputs and outputs of the jaxprs produced haven't yet been
  # adapted to the scan calling convention; in particular, jaxpr_known has its
  # residual outputs all at the end, meaning they're extensive outputs (which is
  # fully general but may be wasteful for residuals which are loop-invariant)
  # while jaxpr_unknown has its corresponding residual inputs at the front (just
  # as a convention with partial_eval_jaxpr_nounits), making them constant
  # inputs. To make them consistent, we move the residual inputs on
  # jaxpr_unknown to the end, even though we may move some back in the sequel.
  jaxpr_unknown = pe.move_binders_to_back(
      jaxpr_unknown, [True] * num_res + [False] * sum(unknowns))

  # At this point, all residuals are treated as extensive outputs of jaxpr_known
  # (and extensive inputs to jaxpr_unknown). But residuals that are loop-
  # invariant can be hoisted out of the scan, rather than letting them get
  # broadcast (as in e.g. scanning multiplication by a constant matrix; we don't
  # want to broadcast the matrix!). So, outside the loop we perform a partial
  # evaluation with known 'const' inputs (but all other inputs unknown).
  const_pvals = [pe.PartialVal.known(t.pval.get_known())
                 for t in tracers[:num_consts] if t.pval.is_known()]
  other_pvals = [pe.PartialVal.unknown(aval)
                 for aval in jaxpr_known.in_avals[len(const_pvals):]]
  with source_info_util.reset_name_stack():
    jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits(
        lu.wrap_init(core.jaxpr_as_fun(jaxpr_known)), const_pvals + other_pvals,
        instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res)
  jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ())
  # The above trace_to_jaxpr_nounits call computed loop-invariant residuals
  # (known values in invar_pvals_out) and also computed loop-invariant values
  # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the
  # previous consts). We need to collect the computed intensive residuals, and
  # move corresponding intensive residual binders in jaxpr_unknown to the front.
  res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:]
  intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()]
  jaxpr_unknown = pe.move_binders_to_front(
      jaxpr_unknown,
      [False] * sum(unknowns) + [pval.is_known() for pval in res_pvals])
  del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals
  # We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and
  # we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown.

  # As another optimization, for any extensive inputs that are just forwarded to
  # extensive outputs, to avoid a copy (which would be looping over
  # dynamic-update-slice) we'd rather forward the input tracer/value. That means
  # pruning some outputs from jaxpr_known here, and updating `out_flat` below.
  fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr)
  # Prune fwds_known to include only extensive input to extensive output.
  # (jaxpr_known's inputs are [consts, known carry, xs] and its outputs are
  # [known carry, known ys, residuals]; the index thresholds below skip the
  # const and carry positions.)
  fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and
                in_idx is not None and
                in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk)
                else None for out_idx, in_idx in enumerate(fwds_known)]
  # Drop any extensive output we can instead get by forwarding an input.
  # TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint
  jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts
  jaxpr_known_.outvars = [x for x, i in zip(jaxpr_known_.outvars, fwds_known)
                          if i is None]
  jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ())
  del jaxpr_known_
  # We use `fwds_known` below when forming the output of scanning jaxpr_known.

  # Run the known part of the scan (if it has any outputs or effects).
  known_inputs = (list(jaxpr_known_consts) +
                  [t.pval.get_known() for t in tracers[num_consts:]
                   if t.pval.is_known()])
  if not jaxpr_known.out_avals and not jaxpr_known.effects:
    out_known = []
  else:
    linear_known = [False] * len(known_inputs)  # conservative!
    out_known = scan_p.bind(
        *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known,
        num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk),
        linear=tuple(linear_known), unroll=unroll)
    del linear_known
  # Complete the known output by filling in forwarded values using fwds_known.
  out_known_iter = iter(out_known)
  out_known = [next(out_known_iter) if f is None
               else _maybe_put(known_inputs[f]) for f in fwds_known]
  assert next(out_known_iter, None) is None
  del known_inputs, out_known_iter

  # Split known outputs from residuals.
  out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)])
  assert len(intensive_res) + len(extensive_res) == num_res

  # Create input tracers for jaxpr_unknown bind.
  unknown_inputs = [t for t in tracers if not t.pval.is_known()]
  intensive_res = _map(trace.new_instantiated_const, intensive_res)
  extensive_res = _map(trace.new_instantiated_const, extensive_res)
  # Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
  carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)])
  ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval)
              for y_aval in y_avals]
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                 for a in itertools.chain(carry_avals, ys_avals)]
  del carry_avals, y_avals
  # Create equation.
  # Operand order of the staged scan: [intensive residuals (as consts),
  # unknown inputs, extensive residuals (as trailing xs)].
  linear_unknown = tuple([False] * len(intensive_res) +
                         [l for l, uk in zip(linear, unknowns) if uk] +
                         [False] * len(extensive_res))
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)
  assert len(out_tracers) == len(jaxpr_unknown.out_avals)
  eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res],
                          out_tracers, scan_p,
                          dict(reverse=reverse, length=length, unroll=unroll,
                               jaxpr=jaxpr_unknown, linear=linear_unknown,
                               num_consts=len(intensive_res) + sum(const_uk),
                               num_carry=sum(carry_uk)),
                          jaxpr_unknown.effects, source)
  for t in out_tracers: t.recipe = eqn

  # Merge known and unknown outputs into final result.
  return util.merge_lists(out_uk, out_known, out_tracers)
def _maybe_put(x):
if isinstance(x, np.ndarray):
return jax.device_put(x, jax.devices('cpu')[0])
else:
return x
def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts,
                    num_carry, jaxpr, linear, unroll):
  """Transpose rule for a scan.

  Only one linearity pattern is supported, as checked below:
    * consts must be [non-linear residuals (ires)..., linear consts...], and
    * xs must be [linear xs..., non-linear residuals (eres)...].
  Any other pattern raises NotImplementedError. Non-linear carries are not
  checked yet (see the TODO).

  The transposed body runs as a scan in the opposite direction. Its consts
  are `ires`. Its carry is [const cotangents, carry cotangents], with the
  const cotangents starting from zeros and accumulating. Its xs are
  [ys cotangents, eres].

  Returns:
    Cotangents for [consts, init, xs], with None at the residual positions.
  """
  # we've only implemented transposing scans with specific lin/nonlin patterns
  consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
  num_ires = len(consts_lin) - sum(consts_lin)
  num_eres = len(xs_lin) - sum(xs_lin)
  if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires):
    raise NotImplementedError
  if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
    raise NotImplementedError
  if not all(init_lin):
    pass  # TODO(mattjj): error check https://github.com/google/jax/issues/1963

  consts, _, xs = split_list(args, [num_consts, num_carry])
  ires, _ = split_list(consts, [num_ires])
  _, eres = split_list(xs, [sum(xs_lin)])
  assert not any(ad.is_undefined_primal(r) for r in ires)
  assert not any(ad.is_undefined_primal(r) for r in eres)

  # Materialize symbolic-zero cotangents so they can be fed to the scan.
  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals)
  ct_carry, ct_ys = split_list(cts, [num_carry])
  ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry)
  ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys)
  # Initial accumulator for the linear consts' cotangents.
  ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts])

  #       jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
  # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
  jaxpr_trans = _transpose_scan_jaxpr(
      num_ires, num_consts - num_ires, num_eres, jaxpr, reduce_axes)
  linear_trans = ([False] * num_ires +
                  [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
                  [False] * num_eres)
  # `reverse=not reverse`: the transposed iterations run in the opposite order.
  outs = scan_p.bind(
      *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
      length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
      num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans),
      unroll=unroll)
  ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
  return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres
# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
#                         -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes):
  """Build the closed jaxpr for one iteration of a transposed scan body.

  Args:
    num_res1: number of leading non-linear residual inputs of `jaxpr`.
    num_c: number of linear inputs (after res1) whose cotangents are
      accumulated across iterations: the incoming `c_bar` is added to the
      newly computed one.
    num_res2: number of trailing non-linear residual inputs of `jaxpr`.
    jaxpr: closed jaxpr of the forward body; its remaining inputs are `a`.
    reduce_axes: forwarded to `ad.backward_pass`.

  Returns:
    A closed jaxpr taking [res1, c_bar, b_bar, res2] and returning
    [c_bar, a_bar], with zero cotangents instantiated.
  """
  num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
  # TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals
  # if an axis isn't reduced
  res1_avals, c_avals, a_avals, res2_avals = split_list(
      jaxpr.in_avals, [num_res1, num_c, num_a])
  num_b = len(jaxpr.out_avals)
  b_avals = list(jaxpr.out_avals)

  @lu.wrap_init
  def transposed(*res1_cbar_bbar_res2):
    res1, c_bar, b_bar, res2 = split_list(
        res1_cbar_bbar_res2, [num_res1, num_c, num_b])
    # Residuals are known primals; the linear inputs `c` and `a` are the
    # undefined primals whose cotangents the backward pass computes.
    primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
               [ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
    cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
                                 primals, b_bar)
    _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
    a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
    # Accumulate this iteration's `c` cotangent into the carried one.
    c_bar = _map(ad.instantiate_zeros_aval, c_avals,
                 _map(ad.add_tangents, c_bar, new_c_bar))
    return c_bar + a_bar
  return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)
def _scan_batching_rule(axis_size, axis_name, main_type, args, dims, reverse, length,
                        jaxpr, num_consts, num_carry, linear, unroll):
  """Batching (vmap) rule for scan_p.

  The rule works in three steps:

  1. Determine which carries must be batched.
  2. Batch the body with ``batching.batch_jaxpr``.
  3. Move operand batch dimensions into place and rebind ``scan_p`` on the
     batched body. Consts and init carries go to axis 0. Xs go to axis 1,
     since axis 0 of xs is the scanned axis.

  Returns:
    ``(outs, out_dims)``. Batched carries report batch dim 0, batched ys
    report batch dim 1, and all other outputs report ``batching.not_mapped``.
  """
  num_ys = len(jaxpr.out_avals) - num_carry
  in_batched = [bdim is not batching.not_mapped for bdim in dims]
  const_batched, init_batched, xs_batched = split_list(
      in_batched, [num_consts, num_carry])

  # A carry must be batched if its initial value is batched or if the body
  # produces it batched. Because the current carry_batched is instantiated on
  # the outputs, a non-converging pass marks at least one more carry as
  # batched. So len(carry) + 1 passes suffice; the final pass also yields the
  # batched body for the final carry_batched.
  carry_batched = init_batched
  for _ in range(len(carry_batched) + 1):
    jaxpr_batched, batched_out = batching.batch_jaxpr(
        jaxpr, axis_size, const_batched + carry_batched + xs_batched,
        instantiate=carry_batched + [False] * num_ys,
        axis_name=axis_name,
        main_type=main_type)
    carry_batched_out = batched_out[:num_carry]
    ys_batched = batched_out[num_carry:]
    if carry_batched_out == carry_batched:
      break
    carry_batched = _map(operator.or_, carry_batched, carry_batched_out)
  else:
    assert False, "Fixpoint not reached"

  consts, init, xs = split_list(args, [num_consts, num_carry])
  consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry])

  def move_to(x, bdim, axis):
    # Only mapped operands whose batch dim sits elsewhere need a moveaxis.
    if bdim is batching.not_mapped or bdim == axis:
      return x
    return batching.moveaxis(x, bdim, axis)

  def prepare_init(x, bdim, was_batched, now_batched):
    if not now_batched:
      return x
    if was_batched:
      return batching.moveaxis(x, bdim, 0)
    # The carry became batched only through the body: broadcast its init.
    return batching.broadcast(x, axis_size, 0)

  new_consts = [move_to(x, bdim, 0) for x, bdim in zip(consts, consts_bdims)]
  new_init = [prepare_init(x, bdim, was, now) for x, bdim, was, now
              in zip(init, init_bdims, init_batched, carry_batched)]
  new_xs = [move_to(x, bdim, 1) for x, bdim in zip(xs, xs_bdims)]

  outs = scan_p.bind(
      *new_consts, *new_init, *new_xs, reverse=reverse, length=length,
      jaxpr=jaxpr_batched, num_consts=num_consts, num_carry=num_carry,
      linear=linear, unroll=unroll)
  carry_bdims = [0 if batched else batching.not_mapped for batched in carry_batched]
  ys_bdims = [1 if batched else batching.not_mapped for batched in ys_batched]
  return outs, carry_bdims + ys_bdims
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
  """Wrap a scan body so that carry updates stop after a dynamic trip count.

  The returned closed jaxpr takes
  ``[dynamic_length, *consts, i, *carry, *xs]``, where both ``dynamic_length``
  and ``i`` are scalar ints. It returns ``[i + 1, *new_carry, *ys]``.

  Each carry output is the body's new carry when ``i < dynamic_length`` and
  the incoming carry otherwise. The ``ys`` are passed through unmasked.
  """
  body_fn = core.jaxpr_as_fun(jaxpr)

  @lu.wrap_init
  def masked(*args):
    [dynamic_length], consts, [i], carry, xs = split_list(
        args, [1, num_consts, 1, num_carry])
    carry_out, ys = split_list(body_fn(*consts, *carry, *xs), [num_carry])
    kept = [lax.select(i < dynamic_length, new, old)
            for new, old in zip(carry_out, carry)]
    return [i + 1, *kept, *ys]

  index_aval = ShapedArray((), dtypes.canonicalize_dtype(dtypes.int_))
  const_avals, carry_avals, x_avals = split_list(
      jaxpr.in_avals, [num_consts, num_carry])
  return _make_closed_jaxpr(
      masked, [index_aval, *const_avals, index_aval, *carry_avals, *x_avals])
def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
  """Padding rule for scan_p: rebind scan with a padded body.

  The body is padded with ``pe.pad_jaxpr``. ``in_avals`` and ``out_avals`` are
  not used by this rule.
  """
  padded_body, padded_consts = pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts)
  return scan_p.bind(
      *args, jaxpr=core.ClosedJaxpr(padded_body, padded_consts), **params)
def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
                   ) -> Tuple[List[bool], core.JaxprEqn]:
  """Dead-code elimination rule for scan_p equations.

  Args:
    used_outputs: one flag per scan output, ordered ``[*carry_out, *ys]``.
    eqn: the scan_p equation to prune.

  Returns:
    ``(used_inputs, new_eqn)``. ``used_inputs`` marks which scan operands the
    pruned equation still consumes. ``new_eqn`` is a scan over the DCE'd body
    that keeps only the used operands and outputs.
  """
  jaxpr = eqn.params['jaxpr']
  num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
  num_xs = len(jaxpr.in_avals) - num_consts - num_carry
  used_carry_out, used_extensive_out = split_list(used_outputs, [num_carry])
  # Fixpoint: a used carry output forces its carry input to be kept
  # (`instantiate`). Conversely, a carry input the body reads forces the
  # matching carry output to be kept (the OR update). Each non-converging pass
  # adds at least one carry, so 1 + num_carry passes suffice.
  for _ in range(1 + num_carry):
    used_outputs = used_carry_out + used_extensive_out
    jaxpr_dce, used_inputs = pe.dce_jaxpr(
        jaxpr.jaxpr, used_outputs,
        instantiate=[False] * num_consts + used_carry_out + [False] * num_xs)
    used_consts, used_carry_in, used_extensive_in = \
        split_list(used_inputs, [num_consts, num_carry])
    if list(used_carry_in) == list(used_carry_out):
      break
    else:
      used_carry_out = _map(operator.or_, used_carry_out, used_carry_in)
  else:
    assert False, "Fixpoint not reached"
  # Validate the pruned body that is installed below. Checking the unmodified
  # input body (as previously done) never exercised the DCE result.
  if config.jax_enable_checks: core.check_jaxpr(jaxpr_dce)

  new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u]
  new_params = dict(eqn.params, num_consts=sum(used_consts),
                    num_carry=sum(used_carry_in), linear=tuple(new_linear),
                    jaxpr=core.ClosedJaxpr(jaxpr_dce, jaxpr.consts))
  # TODO(mattjj,sharadmv): don't assume effects are never DCE'd?
  new_eqn = pe.new_jaxpr_eqn(
      [v for v, used in zip(eqn.invars, used_inputs) if used],
      [v for v, used in zip(eqn.outvars, used_outputs) if used],
      eqn.primitive, new_params, eqn.effects, eqn.source_info)
  assert len(new_eqn.invars ) == len(new_params['jaxpr'].in_avals )
  assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals)
  return used_inputs, new_eqn
# TODO(mattjj): de-duplicate code with _scan_partial_eval
def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  """Split a scan equation into a known equation and a staged scan equation.

  Args:
    saveable: policy forwarded to ``pe.partial_eval_jaxpr_custom``.
    unks_in: one flag per scan operand marking it unknown, ordered
      ``[*consts, *carry, *xs]``.
    inst_in: one flag per scan operand marking it already instantiated.
    eqn: the scan_p equation to split.

  Returns:
    ``(eqn_known, eqn_staged, unks_out, inst_out, new_vars)``, where:

    - ``eqn_known`` is a ``core.closed_call_p`` equation. It evaluates the
      hoisted loop-invariant part of the known body, then scans the remaining
      known part.
    - ``eqn_staged`` is a scan over the staged part of the body.
    - ``unks_out`` and ``inst_out`` are per-output flags.
    - ``new_vars`` lists the input Vars that were not previously instantiated,
      followed by the new intensive and extensive residual Vars.
  """
  jaxpr = eqn.params['jaxpr']
  num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
  num_ys = len(jaxpr.out_avals) - num_carry

  # Fixpoint (trivial on 'inst_in', since we might as well make all inputs
  # available as DCE can subsequently prune any unused ones)
  # A carry unknown on input is forced unknown on output (ensure_out_unknowns).
  # A carry the body makes unknown is OR'd into carry_uk, so carry_uk only
  # grows until it is stable.
  const_uk, carry_uk, xs_uk = split_list(unks_in, [num_consts, num_carry])
  for _ in range(1 + len(carry_uk)):
    unks_in = const_uk + carry_uk + xs_uk
    jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \
        pe.partial_eval_jaxpr_custom(
            jaxpr.jaxpr, in_unknowns=unks_in, in_inst=True,
            ensure_out_unknowns=carry_uk + [False] * num_ys,
            ensure_out_inst=True, saveable=saveable)
    carry_uk_out, ys_uk = split_list(unks_out, [num_carry])
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
  else:
    assert False, "Fixpoint not reached"
  jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts)
  jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts)

  # Move all residual binders to the back of jaxpr_staged so they're extensive.
  # TODO(mattjj): make jaxpr_staged only take instantiated inputs
  # The residuals are the first num_res inputs of jaxpr_staged at this point.
  res_avals = jaxpr_staged.in_avals[:num_res]
  jaxpr_staged = pe.move_binders_to_back(
      jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals))

  # Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to
  # passing in_inst argument to partial_eval_jaxpr_custom above).
  # Vars not yet instantiated are reported to the caller via new_vars.
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is core.Var and not inst]
  inst_in = [True] * len(inst_in)

  # As an optimization, hoist loop-invariant residuals out of the loop rather
  # than using extensive outputs for them. See _scan_partial_eval for comments.
  num_const_known = len(const_uk) - sum(const_uk)
  num_carry_known = len(carry_uk) - sum(carry_uk)
  num_xs_known = len( xs_uk) - sum( xs_uk)
  # Treat the known carry and xs as "unknown" here. The first part of the
  # result (jaxpr_known_hoist) then depends only on known consts and can run
  # once outside the loop. loop_dep flags which outputs depend on carry/xs.
  jaxpr_known_hoist, jaxpr_known_loop, loop_dep, consts_known_lp_avals = \
      pe.partial_eval_jaxpr_nounits(
          jaxpr_known,
          [False] * num_const_known + [True] * (num_carry_known + num_xs_known),
          [True] * (len(unks_out) - sum(unks_out)) + [False] * num_res)
  # jaxpr_known_hoist produces intensive residuals followed by the constants for
  # jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts.
  _, loop_dep_res = split_list(loop_dep, [len(loop_dep) - num_res])
  jaxpr_staged = pe.move_binders_to_front(
      jaxpr_staged, [False] * sum(inst_in) + _map(operator.not_, loop_dep_res))
  num_intensive_res = len(loop_dep_res) - sum(loop_dep_res)
  del loop_dep, num_carry_known, num_xs_known, const_uk

  # Create residual variables.
  # Intensive (loop-invariant) residuals keep their aval. Extensive residuals
  # get a new leading axis of size `length` to hold one value per iteration.
  intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
  ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a)
               for a in ext_avals_mapped]
  newvar = core.gensym()
  intensive_res = _map(newvar, intensive_avals)
  extensive_res = _map(newvar, ext_avals)

  # Create known eqn, which is a call_p combining evaluation of
  # jaxpr_known_hoist and a scan of jaxpr_known_loop.
  ins_known, _ = partition_list(unks_in, eqn.invars)
  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
  # jaxpr_known_loop takes as input constants output as res by jaxpr_known_hoist
  # (corresponding to consts_known_lp_avals) followed by known carry and xs.
  # Linearity flags: drop the known consts' flags (they are consumed by the
  # hoisted part) and mark the new loop constants as non-linear.
  linear_known_ = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk]
  _, linear_known_ = split_list(linear_known_, [num_const_known])
  linear_known = [False] * len(consts_known_lp_avals) + linear_known_
  params_known = dict(eqn.params, jaxpr=jaxpr_known_loop,
                      num_consts=len(consts_known_lp_avals),
                      num_carry=len(carry_uk)-sum(carry_uk),
                      linear=tuple(linear_known))

  @lu.wrap_init
  def known(*ins_known):
    # Run the hoisted part once, then scan the loop part over its outputs.
    consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known])
    out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist)
    intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res])
    out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known)
    return [*intensive_res, *out_loop]
  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
      known, [v.aval for v in ins_known])
  call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
  eqn_known = pe.new_jaxpr_eqn(
      ins_known, [*intensive_res, *out_binders_known, *extensive_res],
      core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects,
      eqn.source_info)

  # Create the staged eqn.
  # Operand order: intensive residuals as extra consts, then all original
  # operands, then extensive residuals as extra xs.
  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
  linear_staged = ([False] * len(intensive_res) + list(eqn.params['linear']) +
                   [False] * len(extensive_res))
  params_staged = dict(eqn.params, jaxpr=jaxpr_staged,
                       num_consts=len(intensive_res) + eqn.params['num_consts'],
                       linear=tuple(linear_staged))
  eqn_staged = pe.new_jaxpr_eqn([*intensive_res, *eqn.invars, *extensive_res],
                                out_binders_staged, eqn.primitive,
                                params_staged, jaxpr_staged.effects,
                                eqn.source_info)

  new_vars = [*new_inst, *intensive_res, *extensive_res]
  return eqn_known, eqn_staged, unks_out, inst_out, new_vars
def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry,
                    jaxpr, linear, unroll):
  """Type-check an application of scan_p.

  Args:
    bind_time: not used by this function.
    *in_atoms: operand atoms, ordered ``[*consts, *init_carry, *xs]``.
    reverse, length, num_consts, num_carry, jaxpr, linear, unroll: the scan_p
      params being validated.

  Returns:
    ``([*init_avals, *y_avals], jaxpr.effects)``. The ``y_avals`` are the
    body's non-carry outputs with a leading axis of size ``length`` added.

  Raises:
    core.JaxprTypeError: if a param is malformed, or if operand types do not
      match the body's input types, or if the body's carry input and output
      types disagree.
  """
  avals = [atom.aval for atom in in_atoms]

  # Static parameter checks, run in a fixed order so the first bad one reports.
  check_param = partial(_typecheck_param, 'scan')
  for value, param, requirement, ok in [
      (reverse, 'reverse', 'bool', type(reverse) is bool),
      (num_consts, 'num_consts', 'non-negative int',
       type(num_consts) is int and num_consts >= 0),
      (num_carry, 'num_carry', 'non-negative int',
       type(num_carry) is int and num_carry >= 0),
      (jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr),
      (linear, 'linear', 'tuple of bool',
       type(linear) is tuple and all(type(x) is bool for x in linear)),
      (unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0),
      (length, 'length', 'non-negative int',
       type(length) is int and length >= 0),
  ]:
    check_param(value, param, requirement, ok)

  if len(linear) != len(avals):
    raise core.JaxprTypeError(
      f'scan param linear has length {len(linear)} for {len(avals)} operands')

  const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry])
  const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list(
      jaxpr.in_avals, [num_consts, num_carry])
  carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry])
  # The body sees one slice of each xs operand: drop the leading scanned axis.
  x_avals_mapped = [core.mapped_aval(length, 0, aval) for aval in x_avals]
  y_avals = [core.unmapped_aval(length, core.no_axis_name, 0, a)
             for a in y_avals_mapped]

  if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)):
    raise core.JaxprTypeError(
      f'scan input carry input and output types mismatch: '
      f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}')
  if not all(_map(core.typecompat, const_avals_jaxpr, const_avals)):
    raise core.JaxprTypeError(
      f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n'
      f'called with consts of type\n{_avals_short(const_avals)}')
  if not all(_map(core.typecompat, init_avals_jaxpr, init_avals)):
    raise core.JaxprTypeError(
      f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n'
      f'called with initial carry of type\n{_avals_short(init_avals)}')
  if not all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)):
    raise core.JaxprTypeError(
      f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
      f'called with sequence of type\n{_avals_short(x_avals)}')
  return [*init_avals, *y_avals], jaxpr.effects
def _scan_pp_rule(eqn, context, settings):
  """Pretty-print a scan equation, hiding params at uninteresting values.

  Params that are omitted from the printout:

  - ``linear``: never shown.
  - ``unroll``: hidden when 1.
  - ``num_carry`` and ``num_consts``: hidden when 0.
  - ``reverse``: hidden when falsy.
  - ``length``: hidden when the operands are exactly the consts and carries,
    i.e. there are no xs.
  """
  printed_params = dict(eqn.params)
  del printed_params['linear']
  # NOTE(review): this drops `length` precisely when there are no xs, which is
  # when it cannot be recovered from an operand's shape -- confirm intended.
  no_xs = eqn.params['num_consts'] + eqn.params['num_carry'] == len(eqn.invars)
  if no_xs:
    del printed_params['length']
  for key, hide_if in (('unroll', lambda v: v == 1),
                       ('num_carry', lambda v: v == 0),
                       ('num_consts', lambda v: v == 0),
                       ('reverse', lambda v: not v)):
    if hide_if(printed_params[key]):
      del printed_params[key]

  # Keep the outvars -> params -> invars order: the printing context names vars.
  out_doc = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
  params_doc = core.pp_kv_pairs(sorted(printed_params.items()), context, settings)
  in_doc = pp.text(" ") + core.pp_vars(eqn.invars, context)
  annotation = None
  if settings.source_info:
    annotation = source_info_util.summarize(eqn.source_info)
  return [out_doc, pp.text(" = ", annotation=annotation),
          pp.text(eqn.primitive.name), params_doc, in_doc]
def scan_bind(*args, **params):
  """Bind ``scan_p``, optionally type-checking first.

  When ``config.jax_enable_checks`` is set, two checks run before binding:

  - The operands are type-checked against ``params`` via ``_scan_typecheck``,
    using placeholder Vars that carry only the operands' avals.
  - The body jaxpr is checked with ``core.check_jaxpr``.
  """
  if not config.jax_enable_checks:
    return core.AxisPrimitive.bind(scan_p, *args, **params)
  placeholder_atoms = [core.Var(0, '', core.get_aval(arg)) for arg in args]
  _scan_typecheck(True, *placeholder_atoms, **params)
  core.check_jaxpr(params['jaxpr'].jaxpr)
  return core.AxisPrimitive.bind(scan_p, *args, **params)