-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
batching.py
966 lines (861 loc) · 41.6 KB
/
batching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections
import dataclasses
from functools import partial
from typing import (Any, Callable, Dict, Iterable, Optional, Sequence, Set,
Tuple, Type, Union)
import numpy as np
import jax
from jax.config import config
from jax._src import core
from jax._src import source_info_util
from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName
from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node)
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, Zero, SymbolicZero,
replace_rule_output_symbolic_zeros, instantiate)
from jax._src import linear_util as lu
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize, weakref_lru_cache)
from jax.interpreters import partial_eval as pe
# Array is only used in annotations in this module; it is an alias for Any.
Array = Any
# Shadow the builtins with `safe_map` / `safe_zip` from jax._src.util
# (presumably length-checked), keeping the builtin versions reachable as
# `unsafe_map` / `unsafe_zip`. The right-hand sides are evaluated before the
# rebinding, so `unsafe_map` / `unsafe_zip` really are the builtins.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# Piles
# i:(Fin 3) => f32[[3, 1, 4].i]
@dataclasses.dataclass(frozen=True)
class PileTy:
  """Type of a pile: `binder` ranges over `length` and maps to `elt_ty`.

  The example above is a pile of three f32 arrays whose sizes are given by
  indexing [3, 1, 4] with the binder.
  """
  binder: core.Var
  length: Union[int, Tracer, core.Var]
  elt_ty: core.DShapedArray

  def __repr__(self) -> str:
    binder_id = id(self.binder)
    return 'Var{}:{} => {}'.format(binder_id, self.length, self.elt_ty)

  replace = dataclasses.replace
# [3, 1, 4].i
@dataclasses.dataclass(frozen=True)
class IndexedAxisSize:
  """An axis size obtained by indexing `lengths` with the variable `idx`."""
  idx: core.Var
  lengths: Union[Array, core.Var, Tracer]

  def __repr__(self) -> str:
    return '{!s}.Var{}'.format(self.lengths, id(self.idx))

  replace = dataclasses.replace
# Pile(aval=a:3 => f32[[3 1 4].a],
#      data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
@dataclasses.dataclass(frozen=True)
class Pile:
  """A ragged value: its pile type plus the data of all elements concatenated."""
  aval: PileTy
  data: Array


class PileAxis:
  """Sentinel type of the axis spec that must be used to vmap over a pile."""


# To vmap over a pile, one must specify the axis as PileAxis.
pile_axis = PileAxis()
# As a temporary measure before we have more general JITable / ADable interfaces
# (analogues to vmappable), to enable Piles to be used with other
# transformations and higher-order primitives (primarily jit, though also grad
# with allow_int=True) we register them as pytrees.
# TODO(mattjj): add JITable / ADable interfaces, remove this pytree registration
def _pile_flatten(pile):
  """Flatten a Pile into ((lengths, data), aux_type) for pytree registration.

  Each IndexedAxisSize in the element shape gives up its `lengths` as a pytree
  child. In the aux type it keeps only its 1-based position among those
  children, which `_pile_unflatten` uses to put the lengths back.
  """
  lengths = []
  new_shape = []
  for dim in pile.aval.elt_ty.shape:
    if type(dim) is IndexedAxisSize:
      lengths.append(dim.lengths)
      dim = dim.replace(lengths=len(lengths))
    new_shape.append(dim)
  elt_ty = pile.aval.elt_ty.update(shape=tuple(new_shape))
  aux_type = pile.aval.replace(elt_ty=elt_ty)
  return (lengths, pile.data), aux_type
def _pile_unflatten(aval, x):
  """Inverse of `_pile_flatten`: rebuild a Pile from its aux type and children."""
  lengths, data = x

  def restore(dim):
    if type(dim) is not IndexedAxisSize:
      return dim
    # `dim.lengths` holds the 1-based child index assigned by _pile_flatten.
    return dim.replace(lengths=lengths[dim.lengths - 1])

  elt_ty = aval.elt_ty.update(shape=tuple(restore(d) for d in aval.elt_ty.shape))
  return Pile(aval.replace(elt_ty=elt_ty), data)

register_pytree_node(Pile, _pile_flatten, _pile_unflatten)
def _pile_result(axis_size, axis, segment_lens, x):
  """Package `x` as a Pile of `axis_size` elements, ragged along `axis`.

  Dimension `axis` of `x` is replaced in the element type by an
  IndexedAxisSize over `segment_lens`, bound by a fresh int32 binder.
  """
  binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
  dims = list(x.shape)
  dims[axis] = IndexedAxisSize(binder, segment_lens)
  pile_ty = PileTy(binder, axis_size,
                   core.DShapedArray(tuple(dims), x.dtype, x.weak_type))
  return Pile(pile_ty, x)
@dataclasses.dataclass(frozen=True)
class ConcatAxis:
  """Batch dim whose elements are concatenated along dimension `axis`.

  `segment_lengths` holds one length per batch element. In the "unpacked"
  convention used by unpack_concat_axes it is a `pe.DBIdx` naming an argument.
  """
  axis: int
  segment_lengths: Array
def _update_annotation(
    f: lu.WrappedFun, orig_type: Optional[core.InputType],
    axis_size: core.AxisSize, axis_name: AxisName,
    explicit_in_dims: Sequence[Optional[Union[int, ConcatAxis]]],
    segment_lens: Sequence[Array],
  ) -> lu.WrappedFun:
  """Rewrite `f`'s input-type annotation to describe its batched arguments.

  Returns `f` unchanged when it carries no annotation (`orig_type` is None).
  Otherwise the new annotation lists, in order:
    * implicit binders mentioned by the new shapes,
    * one explicit binder per entry of `segment_lens`,
    * one explicit binder per explicit argument of `orig_type`, with the batch
      axis described by `explicit_in_dims` added.
  """
  if orig_type is None: return f
  # By convention, `explicit_in_dims` only accounts for explicit arguments.
  assert len(explicit_in_dims) == sum(explicit for _, explicit in orig_type)
  # We need to:
  # * if `axis_size` is dynamic, add a new implicit binder (type) for it;
  # * for each element of `segment_lengths`, add a new explicit binder for it;
  # * drop other implicit binders, replacing DBIdx which refer to them with
  #   Name objects;
  # * for each (aval, in_dim) pair: if int-valued in_dim, add batch axis (int
  #   size if `axis_size` is int, otherwise Name); if ConcatAxis-valued in_dim,
  #   add batch axis (int if corresponding segment_lengths is concrete, Name if
  #   not);
  # * generate full in_type with implicit args too.
  # Wrapper giving each aval an identity; it defines no __eq__/__hash__, so
  # Name objects compare and hash by identity.
  class Name:
    def __init__(self, a): self.a = a
  # One Name per entry of orig_type (implicit and explicit); DBIdx values index
  # into this list.
  names = [Name(a) for a, _ in orig_type]
  # Explicit arguments only, with DBIdx shape entries replaced by their Name.
  avals = [a.update(shape=tuple(names[d.val] if type(d) is pe.DBIdx else d  # type: ignore
                                for d in a.shape))
           if type(a) is core.DShapedArray else a for a, e in orig_type if e]
  # Segment-length arguments come first among the new explicit binders.
  new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens]
  # A traced axis size is referred to through a fresh Name of its aval.
  sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size
  for a, d in zip(avals, explicit_in_dims):
    if isinstance(d, ConcatAxis):
      # Here d.segment_lengths is a DBIdx into `segment_lens`.
      s = segment_lens[d.segment_lengths.val]
      if isinstance(core.get_aval(s), core.ConcreteArray):
        shape = list(a.shape)  # type: ignore
        shape[d.axis] = int(s.sum())  # specialize on shape if we can
        new_avals.append(a.update(shape=tuple(shape)))
      else:
        new_avals.append(a)
    else:
      new_avals.append(core.unmapped_aval(sz, axis_name, d, a))  # type: ignore
  mentioned = {d for a in new_avals if type(a) is core.DShapedArray
               for d in a.shape if type(d) is Name}
  # NOTE(review): these are freshly constructed Name objects, and Name hashes
  # by identity, so none of them can be in `mentioned`; `impl_names` therefore
  # equals `mentioned`. The explicit Names in `name_map` are likewise never hit
  # by the `.get` lookups below. Confirm this is intended.
  expl_names = set(map(Name, new_avals))
  impl_names = mentioned - expl_names  # type: ignore
  # NOTE(review): the order of implicit binders follows set iteration order over
  # identity-hashed objects, not first-mention order.
  impl_part = [(n.a, False) for n in impl_names]  # type: ignore
  name_map = {n: pe.DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))}
  # Replace Name shape entries with DBIdx positions in the final annotation.
  expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape))
                if type(a) is core.DShapedArray else a, True) for a in new_avals]
  return lu.annotate(f, (*impl_part, *expl_part))
### vmappable typeclass
# Loose (Any) aliases; they only document the handler signatures below.
Vmappable = Any
Elt = Any
MapSpec = Any
AxisSize = Any
# Thunk returning the batch-index tracer; to_elt handlers call it on demand.
GetIdx = Callable[[], Tracer]  # TODO(mattjj): revise this laziness
# handler(recursive_to_elt, get_idx, value, spec) -> element
ToEltHandler = Callable[[Callable, GetIdx, Vmappable, MapSpec], Elt]
# handler(recursive_from_elt, axis_size, element, spec) -> value
FromEltHandler = Callable[[Callable, AxisSize, Elt, MapSpec], Vmappable]
# handler(axis_size) -> array (presumably an iota of that size, as in make_iota)
MakeIotaHandler = Callable[[AxisSize], Array]
def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
  """Convert a vmapped input `x` into its per-element representation.

  Dispatch order: a handler registered via `register_vmappable` for `type(x)`;
  then Piles (which require `spec is pile_axis`); then plain values whose spec
  is an int or None.

  Args:
    trace: the BatchTrace that produced tracers belong to.
    get_idx: thunk producing the batch-index tracer; only handlers receive it.
    x: the input value.
    spec: the in_axes entry for `x`.

  Returns:
    A BatchTracer, `x` itself when `spec` is None, or whatever a registered
    handler returns.

  Raises:
    TypeError: if `x` is a Pile but `spec` is not `pile_axis`, or if `spec` is
      not a recognized axis spec for `x`.
  """
  handler = to_elt_handlers.get(type(x))
  if handler:
    return handler(partial(to_elt, trace, get_idx), get_idx, x, spec)
  elif type(x) is Pile:
    if spec is not pile_axis:
      raise TypeError("pile input without using pile_axis in_axes spec")
    # The element type must have exactly one ragged (IndexedAxisSize) dim.
    (d, ias), = ((i, sz) for i, sz in enumerate(x.aval.elt_ty.shape)
                 if type(sz) is IndexedAxisSize)
    return BatchTracer(trace, x.data, ConcatAxis(d, ias.lengths))  # type: ignore
  elif isinstance(spec, int) or spec is None:
    # `spec and ...` leaves None and 0 untouched; negative axes are normalized
    # against the rank of `x`.
    spec = spec and canonicalize_axis(spec, len(np.shape(x)))
    return (BatchTracer(trace, x, spec, source_info_util.current())
            if spec is not None else x)
  else:
    # Previously `assert False`, which silently returned None under `python -O`.
    raise TypeError(f"Unrecognized in_axes spec {spec!r} for an input of type "
                    f"{type(x)}; expected an int, None, or a spec type "
                    "registered with register_vmappable.")
# Handlers registered via `register_vmappable`, keyed by data type.
to_elt_handlers: Dict[Type, ToEltHandler] = {}
def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
             ) -> Vmappable:
  """Inverse of `to_elt`: turn a per-element output back into a batched value.

  A handler registered for `type(x)` takes priority. Otherwise `x` is raised
  into `trace`. A ragged (ConcatAxis) result becomes a Pile, which requires
  `spec is pile_axis`. Any other result has its batch axis moved to `spec`
  via `matchaxis`.

  Raises:
    TypeError: for a ragged output whose spec is not `pile_axis`.
  """
  handler = from_elt_handlers.get(type(x))
  if handler:
    return handler(partial(from_elt, trace), axis_size, x, spec)
  tracer = trace.full_raise(x)
  bdim = tracer.batch_dim
  if type(bdim) is not ConcatAxis:
    return matchaxis(trace.axis_name, axis_size, bdim, spec, tracer.val)
  if spec is not pile_axis:
    # TODO(mattjj): improve this error message
    raise TypeError("ragged output without using pile_axis out_axes spec")
  return _pile_result(axis_size, bdim.axis, bdim.segment_lengths, tracer.val)
# Handlers registered via `register_vmappable`, keyed by data type.
from_elt_handlers: Dict[Type, FromEltHandler] = {}
def make_iota(axis_size: AxisSize) -> Array:
  """Return an int32 iota of length `axis_size`.

  If a handler is registered for `type(axis_size)`, that handler is used
  instead.
  """
  handler = make_iota_handlers.get(type(axis_size))
  if not handler:
    return jax.lax.iota('int32', int(axis_size))
  return handler(axis_size)
# Iota handlers registered via `register_vmappable`, keyed by axis-size type.
make_iota_handlers: Dict[Type, MakeIotaHandler] = {}
def register_vmappable(data_type: Type, spec_type: Type, axis_size_type: Type,
                       to_elt: Callable, from_elt: Callable,
                       make_iota: Optional[Callable]):
  """Register `data_type` as a custom vmappable type.

  Args:
    data_type: type of values dispatched to `to_elt` / `from_elt`.
    spec_type: type of the in/out axis specs used with `data_type`.
    axis_size_type: type of axis sizes that `make_iota` is dispatched on.
    to_elt: handler with the `ToEltHandler` signature.
    from_elt: handler with the `FromEltHandler` signature.
    make_iota: optional `MakeIotaHandler`; not registered when falsy.
  """
  vmappables[data_type] = (spec_type, axis_size_type)
  spec_types.add(spec_type)
  to_elt_handlers[data_type] = to_elt
  from_elt_handlers[data_type] = from_elt
  if make_iota: make_iota_handlers[axis_size_type] = make_iota
# data_type -> (spec_type, axis_size_type) for every registered vmappable.
vmappables: Dict[Type, Tuple[Type, Type]] = {}
# Spec types currently accepted; the built-in PileAxis is always present.
spec_types: Set[Type] = {PileAxis}
def unregister_vmappable(data_type: Type) -> None:
  """Undo `register_vmappable(data_type, ...)`.

  A spec type or axis-size type that is still used by another registered data
  type (or, for PileAxis, by the built-in pile support) stays registered, so
  unregistering one type cannot break another.

  Raises:
    KeyError: if `data_type` is not registered.
  """
  spec_type, axis_size_type = vmappables.pop(data_type)
  del to_elt_handlers[data_type]
  del from_elt_handlers[data_type]
  specs_in_use = {s for s, _ in vmappables.values()}
  sizes_in_use = {a for _, a in vmappables.values()}
  if spec_type is not PileAxis and spec_type not in specs_in_use:
    spec_types.discard(spec_type)
  if axis_size_type not in sizes_in_use:
    make_iota_handlers.pop(axis_size_type, None)
def is_vmappable(x: Any) -> bool:
  """True if `x` is a Pile or an instance of a registered vmappable type."""
  x_type = type(x)
  return x_type is Pile or x_type in vmappables
@lu.transformation_with_aux
def flatten_fun_for_vmap(in_tree, *args_flat):
  # Rebuild (args, kwargs) from the flat inputs using `in_tree`, call the
  # wrapped function, then flatten its result. Piles and registered vmappable
  # values are treated as leaves (`is_leaf=is_vmappable`) rather than being
  # flattened further. The final yield is tree_flatten's (leaves, treedef)
  # pair; under transformation_with_aux these are presumably the outputs and
  # the aux value.
  py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
  ans = yield py_args, py_kwargs
  yield tree_flatten(ans, is_leaf=is_vmappable)
### tracer
# Marker for "this value has no batch dimension" and its type.
# TODO(mattjj): use a special sentinel type rather than None
not_mapped = None
NotMapped = type(not_mapped)
class BatchTracer(Tracer):
  """Tracer holding a batched value `val` and the location of its batch axis.

  `batch_dim` is one of:
    * `not_mapped` (None): `val` has no batch axis;
    * an int: the position of the batch axis in `val`;
    * a ConcatAxis: the batch elements are concatenated along one axis of `val`.
  """
  __slots__ = ['val', 'batch_dim', 'source_info']

  def __init__(self, trace, val, batch_dim: Union[NotMapped, int, ConcatAxis],
               source_info: Optional[source_info_util.SourceInfo] = None):
    # Consistency checks run only when jax_enable_checks is on.
    if config.jax_enable_checks:
      assert type(batch_dim) in (NotMapped, int, ConcatAxis)
      if type(batch_dim) is int:
        aval = raise_to_shaped(core.get_aval(val))
        assert 0 <= batch_dim < len(aval.shape)  # type: ignore
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim
    self.source_info = source_info

  @property
  def aval(self):
    """Abstract value of a single batch element."""
    aval = raise_to_shaped(core.get_aval(self.val))
    if self.batch_dim is not_mapped:
      return aval
    elif type(self.batch_dim) is int:
      # core.mapped_aval presumably drops the batch axis from `aval`.
      return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
    elif type(self.batch_dim) is ConcatAxis:
      shape = list(aval.shape)
      # Each element's size along the ragged axis is itself a batched value:
      # a tracer over `segment_lengths` with batch dim 0.
      size_tracer = BatchTracer(self._trace, self.batch_dim.segment_lengths, 0)
      shape[self.batch_dim.axis] = size_tracer
      return core.DShapedArray(shape=tuple(shape), dtype=aval.dtype,
                               weak_type=aval.weak_type)

  def full_lower(self):
    # An unmapped tracer carries nothing batch-specific, so unwrap its value.
    if self.batch_dim is not_mapped:
      return core.full_lower(self.val)
    else:
      return self

  def _origin_msg(self):
    if self.source_info is None:
      return ""
    return (f"\nThis BatchTracer with object id {id(self)} was created on line:"
            f"\n {source_info_util.summarize(self.source_info)}")

  def _contents(self):
    return [('val', self.val), ('batch_dim', self.batch_dim)]

  def get_referent(self):
    # Unmapped and int-mapped tracers defer to the underlying value's referent.
    if self.batch_dim is None or type(self.batch_dim) is int:
      return core.get_referent(self.val)
    else:  # TODO(mattjj): could handle the ConcatAxis case?
      return self
class BatchTrace(Trace):
  """Trace that evaluates primitives on BatchTracers via batching rules.

  `axis_name` names the vmapped axis. When `spmd_axis_name` is not None, it is
  forwarded to rules in `spmd_axis_primitive_batchers`.
  """

  def __init__(self, *args, axis_name, spmd_axis_name = None):
    super().__init__(*args)
    self.axis_name = axis_name
    self.spmd_axis_name = spmd_axis_name

  # pure / lift wrap values as unmapped tracers; sublift rewraps a tracer from
  # a lower sublevel, keeping its value and batch dim.
  def pure(self, val):
    return BatchTracer(self, val, not_mapped, source_info_util.current())

  def lift(self, val):
    return BatchTracer(self, val, not_mapped, source_info_util.current())

  def sublift(self, val):
    return BatchTracer(self, val.val, val.batch_dim, source_info_util.current())

  def get_primitive_batcher(self, primitive, frame):
    # Lookup order: plain rules, then SPMD-aware rules (only when this trace has
    # an spmd_axis_name), then axis-aware rules.
    if primitive in primitive_batchers:
      return primitive_batchers[primitive]
    elif self.spmd_axis_name is not None and primitive in spmd_axis_primitive_batchers:
      return partial(spmd_axis_primitive_batchers[primitive],
                     self.spmd_axis_name, frame.size, frame.name,
                     frame.main_trace.trace_type)
    elif primitive in axis_primitive_batchers:
      return self.get_axis_primitive_batcher(primitive, frame)
    msg = "Batching rule for '{}' not implemented"
    raise NotImplementedError(msg.format(primitive))

  def get_axis_primitive_batcher(self, primitive, frame):
    # Axis-aware rules receive the axis size, name and trace type up front.
    return partial(axis_primitive_batchers[primitive],
        frame.size, frame.name, frame.main_trace.trace_type)

  def get_frame(self, vals, dims) -> core.AxisEnvFrame:
    if self.axis_name is core.no_axis_name:
      # If axis name is `no_axis_name` we can't find it via `core.axis_name` so
      # we reconstruct it from the information we have available
      # (for a ConcatAxis, the batch size is the number of segments).
      sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
               for x, d in zip(vals, dims) if d is not not_mapped)
      axis_size, = core.dedup_referents(sizes)
      return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
    return core.axis_frame(self.axis_name)

  def process_primitive(self, primitive, tracers, params):
    vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    is_axis_primitive = primitive in axis_primitive_batchers
    used_names = core.used_axis_names(primitive, params)
    if is_axis_primitive and _main_trace_for_axis_names(self.main, used_names):
      # The primitive uses this trace's own axis, so its axis-aware rule runs
      # even if none of the inputs is batched (this branch precedes that check).
      frame = self.get_frame(vals_in, dims_in)
      batcher_primitive = self.get_axis_primitive_batcher(primitive, frame)
      val_out, dim_out = batcher_primitive(vals_in, dims_in, **params)
    elif all(bdim is not_mapped for bdim in dims_in):
      # Nothing batched: bind directly without wrapping the outputs.
      return primitive.bind(*vals_in, **params)
    else:
      frame = self.get_frame(vals_in, dims_in)
      batched_primitive = self.get_primitive_batcher(primitive, frame)
      val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
    src = source_info_util.current()
    if primitive.multiple_results:
      return [BatchTracer(self, x, d, src) for x, d in zip(val_out, dim_out)]
    else:
      return BatchTracer(self, val_out, dim_out, src)

  def process_call(self, call_primitive, f, tracers, params):
    assert call_primitive.multiple_results
    params = dict(params, name=params.get('name', f.__name__))
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims):
      return call_primitive.bind(f, *vals, **params)
    sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
             for x, d in zip(vals, dims) if d is not not_mapped)
    axis_size, = core.dedup_referents(sizes)
    # Ragged segment lengths become extra leading arguments of the call, and
    # ConcatAxis dims refer to them by DBIdx.
    segment_lens, dims = unpack_concat_axes(dims)
    f_, dims_out = batch_subtrace(f, self.main, tuple(dims))
    f_ = _update_annotation(f_, f.in_type, axis_size, self.axis_name, dims,
                            segment_lens)
    vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
    vals_out, dims_out = reassemble_concat_axes(vals_out, dims_out())
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                              for t in out_tracers)
    main = self.main
    # `todo` rewraps the call's outputs as BatchTracers with the same dims.
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, dims, srcs)
    return vals, todo

  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(dim is not_mapped for dim in dims):
      return map_primitive.bind(f, *vals, **params)
    else:
      assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
      # The logic for the dimension math below is as follows:
      # ╔═════════════╦════════════════════════════════════════╦═══════════╗
      # ║ d / in_axis ║ None                                   ║ int       ║
      # ╠═════════════╬════════════════════════════════════════╩═══════════╣
      # ║ None        ║ No extra axis, so in_axis unaffected               ║
      # ╠═════════════╬════════════════════════════════════════╦═══════════╣
      # ║ int         ║ Not mapped, so batching dim unaffected ║ See below ║
      # ╚═════════════╩════════════════════════════════════════╩═══════════╝
      # When both d and in_axis are defined then:
      # - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
      # - If `d >  in_axis`, we have to decrement `d` (as `in_axis` will get removed).
      def both_mapped(in_out_axis, d):
        return in_out_axis is not None and d is not not_mapped
      new_in_axes = tuple(
        in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
        for d, in_axis in zip(dims, params['in_axes']))
      new_dims = tuple(
        d - 1 if both_mapped(in_axis, d) and in_axis < d else d
        for d, in_axis in zip(dims, params['in_axes']))
      f, dims_out = batch_subtrace(f, self.main, new_dims)
      out_axes_thunk = params['out_axes_thunk']
      # NOTE: This assumes that the choice of the dimensions over which outputs
      # are batched is entirely dependent on the function and not e.g. on the
      # data or its shapes.
      @as_hashable_function(closure=out_axes_thunk)
      def new_out_axes_thunk():
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes_thunk(), dims_out()))
      new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
      vals_out = map_primitive.bind(f, *vals, **new_params)
      dims_out_ = [d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
                   for d, out_axis in zip(dims_out(), out_axes_thunk())]
      src = source_info_util.current()
      return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out_)]

  def post_process_map(self, call_primitive, out_tracers, params):
    vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                              for t in out_tracers)
    main = self.main
    def both_mapped(in_out_axis, d):
      return in_out_axis is not None and d is not not_mapped
    # Same out-axis adjustment as in process_map, applied after the fact.
    def todo(vals):
      trace = main.with_cur_sublevel()
      return [BatchTracer(trace, v, d + 1 if both_mapped(oa, d) and oa <= d else d, s)
              for v, d, oa, s in zip(vals, dims, params['out_axes_thunk'](), srcs)]
    if call_primitive.map_primitive:
      def out_axes_transform(out_axes):
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes, dims))
      todo = (todo, out_axes_transform)
    return vals, todo

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
    out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros)
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      # The dims came from the jvp trace, which covers primal and tangent
      # outputs; the two halves must agree, and only the primal half is kept.
      assert out_dims == out_dims[:len(out_dims) // 2] * 2
      out_dims = out_dims[:len(out_dims) // 2]
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]

  def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
    vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                              for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = main.with_cur_sublevel()
      if jvp_was_run:
        # dims/srcs cover primals followed by tangents; rewrap only primals.
        primal_dims, tangent_dims = dims[:len(vals)], dims[len(vals):]
        assert primal_dims == tangent_dims
        primal_srcs = srcs[:len(vals)]
        return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs)
      else:
        return map(partial(BatchTracer, trace), vals, dims, srcs)
    return vals, todo

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):  # pytype: disable=signature-mismatch
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    # NOTE(review): assumes int batch dims; a ConcatAxis dim would fail here.
    axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
                  if d is not not_mapped}
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
    bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
                               out_dims2, in_dims, self.main.trace_type,
                               self.spmd_axis_name)
    out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      # fwd's outputs lead with residuals; drop their dims.
      _, res_tree = out_trees()
      _, out_dims = split_list(out_dims, [res_tree.num_leaves])
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]

  def post_process_custom_vjp_call(self, out_tracers, _):
    vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                              for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, dims, srcs)
    return vals, todo

  def post_process_custom_vjp_call_fwd(self, out_tracers, out_trees):
    vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                              for t in out_tracers)
    axis_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
    main, trace_type = self.main, self.main.trace_type
    axis_name = self.axis_name
    _, res_tree = out_trees()
    num_res = res_tree.num_leaves
    # Outputs are residuals followed by primal outputs.
    res_dims, primal_dims = split_list(dims, [num_res])
    _, primal_srcs = split_list(srcs, [num_res])
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs)
    def bwd_transform(bwd):
      return batch_custom_vjp_bwd(bwd, axis_name, axis_size, dims, (None,),
                                  trace_type, self.spmd_axis_name)
    return vals, todo, bwd_transform
def _main_trace_for_axis_names(main_trace: core.MainTrace,
                               axis_name: Iterable[AxisName],
                               ) -> bool:
  """Return True if any name in `axis_name` is bound by `main_trace`.

  Axis names can shadow one another, so a name match alone is not enough. The
  main trace recorded in each name's axis frame serves as the tag.
  """
  for name in axis_name:
    if core.axis_frame(name).main_trace is main_trace:
      return True
  return False
### API for batching callables with vmappable inputs and outputs
def batch(fun: lu.WrappedFun, axis_name: AxisName, axis_size,
          in_dims, out_dim_dests, main_type: Type[BatchTrace] = BatchTrace,
          spmd_axis_name: Optional[Tuple[AxisName, ...]] = None
          ) -> lu.WrappedFun:
  """Transform `fun` so it maps over an axis of size `axis_size`.

  `in_dims` gives the axis spec of each input and `out_dim_dests` the
  destination spec of each output. Either may be passed as a thunk.
  """
  # _batch_inner and _batch_outer are separate transformations for the leak
  # checker's sake.
  per_element = _batch_inner(fun, axis_size, out_dim_dests)
  return _batch_outer(per_element, axis_name, axis_size, in_dims, main_type,
                      spmd_axis_name)
@lu.transformation
def _batch_outer(axis_name, axis_size, in_dims, main_type, spmd_axis_name,
                 *in_vals):
  # Outer half of `batch`: open a new main trace of `main_type`, extend the
  # axis environment with (axis_name, axis_size), and run the wrapped (inner)
  # transformation under a 'vmap' transform name-stack entry. The inner half
  # receives (main, in_dims, *in_vals).
  with core.new_main(
      main_type, axis_name=axis_name, spmd_axis_name=spmd_axis_name) as main:
    with core.extend_axis_env(axis_name, axis_size, main):
      with source_info_util.transform_name_stack('vmap'):
        outs = yield (main, in_dims, *in_vals), {}
      # Drop our reference to `main` before the contexts exit (this split
      # exists for the leak checker, per `batch`).
      del main
  yield outs
@lu.transformation
def _batch_inner(axis_size, out_dim_dests, main, in_dims, *in_vals):
  # Inner half of `batch`: convert inputs to per-element values with `to_elt`,
  # run the function, and convert outputs back with `from_elt` according to
  # `out_dim_dests`. Both `in_dims` and `out_dim_dests` may be thunks.
  in_dims = in_dims() if callable(in_dims) else in_dims
  trace = main.with_cur_sublevel()
  # Memoized thunk for a tracer over iota(axis_size) with batch dim 0. It is
  # only forwarded to registered to_elt handlers, so it is built on demand.
  idx = memoize(lambda: BatchTracer(trace, make_iota(axis_size), 0,
                                    source_info_util.current()))
  in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
  outs = yield in_tracers, {}
  out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
  out_vals = map(partial(from_elt, trace, axis_size), outs, out_dim_dests)
  yield out_vals
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat: lu.WrappedFun,
          in_axes_flat: Tuple[Optional[int], ...],
          out_axes_flat: Tuple[Optional[int], ...],
          tile_size: Optional[int],
          axis_name: AxisName,
          main_type: Type[BatchTrace] = BatchTrace):
  """Batch `f_flat` over tiles of its mapped axes.

  Each mapped input axis of size n is reshaped to (tile, n // tile) and
  batched over its first part. Each mapped output axis pair is merged back
  into one. When `tile_size` is falsy, the size of the first mapped input axis
  is used as the tile.
  """
  def tile_axis(arg, axis: Optional[int], tile_size):
    # Split dimension `axis` into [tile_size, size // tile_size].
    if axis is None:
      return arg
    new_shape = list(arg.shape)
    new_shape[axis:axis + 1] = [tile_size, new_shape[axis] // tile_size]
    return arg.reshape(new_shape)

  def untile_axis(out, axis: Optional[int]):
    # Merge dimensions `axis` and `axis + 1` back into one.
    if axis is None:
      return out
    new_shape = list(out.shape)
    new_shape[axis:axis + 2] = [new_shape[axis] * new_shape[axis + 1]]
    return out.reshape(new_shape)

  @lu.transformation
  def _map_to_tile(*args_flat):
    sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat)
             if i is not None)
    tile_size_ = tile_size or next(sizes, None)
    assert tile_size_ is not None, "No mapped arguments?"
    tile = partial(tile_axis, tile_size=tile_size_)
    outputs_flat = yield map(tile, args_flat, in_axes_flat), {}
    yield map(untile_axis, outputs_flat, out_axes_flat)

  batched = batch(f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat,
                  main_type=main_type)
  return _map_to_tile(batched)
### API for batching functions with jaxpr type inputs and outputs
@lu.transformation_with_aux
def batch_subtrace(main, in_dims, *in_vals):
  # Run the wrapped function on BatchTracers at a new sublevel of `main`.
  # Inputs and outputs use the "unpacked" ragged convention: segment-length
  # arrays for ConcatAxis dims come first and the dims refer to them by DBIdx
  # (see unpack_concat_axes / reassemble_concat_axes). The output batch dims
  # are the auxiliary value. `in_dims` may be a thunk.
  trace = main.with_cur_sublevel()
  in_dims = in_dims() if callable(in_dims) else in_dims
  in_vals, in_dims = reassemble_concat_axes(in_vals, in_dims)
  # Unmapped inputs are passed through as plain values.
  in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
                if dim is not None else x for x, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  segment_lens, out_dims = unpack_concat_axes(out_dims)
  yield (*segment_lens, *out_vals), out_dims
def unpack_concat_axes(dims):
  """Split ragged segment lengths out of `dims`.

  Returns `(segment_lens, new_dims)`. Each ConcatAxis in `new_dims` refers to
  its lengths by a `pe.DBIdx` into `segment_lens`. ConcatAxes whose lengths
  share a referent share one entry. When there is no ConcatAxis, returns
  `([], dims)` unchanged.
  """
  if all(type(d) is not ConcatAxis for d in dims):
    return [], dims
  # referent id -> (segment_lengths, DBIdx), in first-seen order
  seen = collections.OrderedDict()
  new_dims = []
  for d in dims:
    if isinstance(d, ConcatAxis):
      key = id(core.get_referent(d.segment_lengths))
      if key not in seen:
        seen[key] = (d.segment_lengths, pe.DBIdx(len(seen)))
      d = ConcatAxis(d.axis, seen[key][1])
    new_dims.append(d)
  segment_lens = [lens for lens, _ in seen.values()]
  return segment_lens, new_dims
def reassemble_concat_axes(vals, dims):
  """Inverse of `unpack_concat_axes` applied to `(*segment_lens, *vals)`.

  Each ConcatAxis in `dims` has its DBIdx replaced by the referenced entry of
  `vals`. The referenced entries are then dropped from `vals`.
  """
  consumed = set()
  new_dims = []
  for d in dims:
    if isinstance(d, ConcatAxis):
      pos = d.segment_lengths.val
      consumed.add(pos)
      d = ConcatAxis(d.axis, vals[pos])
    new_dims.append(d)
  new_vals = [v for pos, v in enumerate(vals) if pos not in consumed]
  return new_vals, new_dims
### API for batching jaxprs
def batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
                 axis_size: core.AxisSize,
                 in_axes: Tuple[Union[int, NotMapped], ...],
                 axis_name: AxisName,
                 spmd_axis_name: AxisName,
                 main_type: Type[BatchTrace],
                 ) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
  """Batch `closed_jaxpr`; return the batched jaxpr and its output batch axes.

  `in_axes` is converted to a tuple before calling the cached `_batch_jaxpr2`,
  presumably so it can serve as part of the cache key.
  """
  in_axes_tuple = tuple(in_axes)
  return _batch_jaxpr2(closed_jaxpr, axis_size, in_axes_tuple, axis_name,
                       spmd_axis_name, main_type)
@weakref_lru_cache
def _batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
                  axis_size: core.AxisSize,
                  in_axes: Tuple[Union[int, NotMapped], ...],
                  axis_name: AxisName,
                  spmd_axis_name: AxisName,
                  main_type: Type[BatchTrace],
                  ) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
  """Cached implementation of `batch_jaxpr2`."""
  fun = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  fun, out_axes_thunk = _batch_jaxpr_inner(fun, axis_size)
  fun = _batch_jaxpr_outer(fun, axis_name, spmd_axis_name, axis_size, in_axes,
                           main_type)
  # Input avals of the batched jaxpr: mapped inputs gain the batch axis.
  batched_in_avals = []
  for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes):
    if b is not_mapped:
      batched_in_avals.append(aval)
    else:
      batched_in_avals.append(core.unmapped_aval(axis_size, axis_name, b, aval))
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(fun, batched_in_avals)
  return core.ClosedJaxpr(jaxpr_out, consts), out_axes_thunk()
def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name,
                spmd_axis_name, main_type):
  """Batch `closed_jaxpr` given a per-input "is batched" flag.

  `instantiate` is a bool or a per-output sequence of bools; a list is turned
  into a tuple, as is `in_batched`, before delegating to `_batch_jaxpr`.
  """
  if isinstance(instantiate, list):
    instantiate = tuple(instantiate)
  return _batch_jaxpr(closed_jaxpr, axis_size, tuple(in_batched), instantiate,
                      axis_name, spmd_axis_name, main_type)
def _batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name,
                 spmd_axis_name, main_type):
  """Translate batched/instantiate flags into axes and call batch_jaxpr_axes.

  Batched inputs use axis 0. Outputs that must be instantiated go to axis 0;
  the others use `zero_if_mapped` (axis 0 only if they end up mapped).
  """
  if isinstance(instantiate, bool):
    instantiate = [instantiate] * len(closed_jaxpr.out_avals)
  else:
    assert isinstance(instantiate, (list, tuple))
    assert all(isinstance(b, bool) for b in instantiate)
  in_axes = [0 if batched else not_mapped for batched in in_batched]
  out_axes_dest = [0 if inst else zero_if_mapped for inst in instantiate]
  return batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
                          axis_name, spmd_axis_name, main_type)
def batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name,
                     spmd_axis_name, main_type):
  """Batch `closed_jaxpr` with explicit input axes and output destinations.

  Returns the batched ClosedJaxpr and a list of per-output "is batched" flags.
  Sequence arguments are turned into tuples for the cached implementation.
  """
  in_axes_tuple, out_axes_dest_tuple = tuple(in_axes), tuple(out_axes_dest)
  return _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes_tuple,
                           out_axes_dest_tuple, axis_name, spmd_axis_name,
                           main_type)
@weakref_lru_cache
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
                      axis_name, spmd_axis_name, main_type):
  """Cached implementation of `batch_jaxpr_axes`."""
  fun = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  fun, out_axes_thunk = _batch_jaxpr_inner(fun, axis_size)
  fun, out_batched_thunk = _match_axes_jaxpr(fun, axis_size, out_axes_dest,
                                             out_axes_thunk)
  fun = _batch_jaxpr_outer(fun, axis_name, spmd_axis_name, axis_size, in_axes,
                           main_type)
  batched_in_avals = [
      aval if b is not_mapped
      else core.unmapped_aval(axis_size, axis_name, b, aval)
      for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(fun, batched_in_avals)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched_thunk()
@lu.transformation_with_aux
def _batch_jaxpr_inner(axis_size, main, in_axes, *in_vals):
  # Wrap mapped inputs as BatchTracers at a new sublevel of `main`, run the
  # function, and yield the raw output values; the output batch axes are the
  # auxiliary value. `axis_size` is accepted but not used in this body.
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_axes)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_axes = unzip2((t.val, t.batch_dim) for t in out_tracers)
  yield out_vals, out_axes
@lu.transformation_with_aux
def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes,
                      *in_vals):
  # Move each output's batch axis to its requested destination with
  # `matchaxis`. `out_axes` is the aux thunk of the wrapped _batch_jaxpr_inner.
  # The auxiliary value records which outputs end up batched.
  trace = main.with_cur_sublevel()
  out_vals = yield (main, in_axes, *in_vals), {}
  out_axes = out_axes()
  # Resolve `zero_if_mapped`: axis 0 for mapped outputs, None for unmapped.
  out_axes_dest = [(None if src is not_mapped else 0)
                   if dst is zero_if_mapped else dst
                   for src, dst in unsafe_zip(out_axes, out_axes_dest)]
  # A single destination is broadcast to every output.
  # NOTE(review): unsafe_zip is the builtin zip and truncates, so in this case
  # a `zero_if_mapped` destination is resolved using only the first output's
  # source axis; confirm this is intended.
  if len(out_axes_dest) != len(out_axes):
    out_axis_dest, = out_axes_dest
    out_axes_dest = [out_axis_dest] * len(out_axes)
  out_vals = map(partial(matchaxis, trace.axis_name, axis_size),
                 out_axes, out_axes_dest, out_vals)
  out_batched = [dst is not None for dst in out_axes_dest]
  yield out_vals, out_batched
@lu.transformation
def _batch_jaxpr_outer(axis_name, spmd_axis_name, axis_size, in_dims, main_type,
                       *in_vals):
  """Set up a new batching main trace and axis env, then run the function.

  Args:
    axis_name, spmd_axis_name, main_type: used to create the new main trace.
    axis_size: batch size. If None, it is inferred from the mapped inputs,
      which must all agree.
    in_dims: per-input batch dims, or a thunk returning them.
    *in_vals: the (batched) input values.
  """
  # Resolve a thunk first: inferring axis_size below iterates over in_dims.
  in_dims = in_dims() if callable(in_dims) else in_dims
  if axis_size is None:
    axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
             else ax for x, ax in unsafe_zip(in_vals, in_dims)]
  with core.new_main(main_type, axis_name=axis_name,
                     spmd_axis_name=spmd_axis_name) as main:
    with core.extend_axis_env(axis_name, axis_size, main):
      out_vals = yield (main, in_dims, *in_vals), {}
      # Drop our reference to `main` before its context manager exits.
      del main
  yield out_vals
def _merge_bdims(x, y):
  """Pick one batch dim to represent a pair of batch dims.

  If the two agree, or only one is mapped, the mapped/shared one is returned.
  If both are mapped but disagree, `x` is chosen arbitrarily.
  """
  if x == y or y is not_mapped:
    return x
  if x is not_mapped:
    return y
  return x  # both mapped and different: arbitrary choice
class ZeroIfMapped:
  """Sentinel type for an output batch-axis destination.

  An output whose destination is `zero_if_mapped` stays unbatched if it came
  out unmapped, and is moved to axis 0 otherwise.
  """

  def __repr__(self):
    # Readable in debug output instead of the default `<... object at 0x...>`.
    return 'zero_if_mapped'

zero_if_mapped = ZeroIfMapped()
### functions for handling custom_jvp and custom_vjp
@lu.transformation_with_aux
def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
  """Batch a custom_jvp rule whose inputs are primals followed by tangents.

  `in_dims` gives the batch dims of the primals. The tangents reuse the same
  dims, hence `in_dims * 2`. Symbolic-zero tangents stay symbolic, with their
  aval mapped. On output, each primal/tangent pair is moved to a common batch
  dim (see `_merge_bdims`). The aux output is those dims, repeated for the
  primals and the tangents.
  """
  all_dims = in_dims * 2
  size, = {x.shape[d] for x, d in zip(in_vals, all_dims)
           if d is not not_mapped}
  trace = main.with_cur_sublevel()

  def lift(val, dim):
    if dim is None:
      return val
    if type(val) is SymbolicZero:
      return SymbolicZero(core.mapped_aval(size, dim, val.aval))
    return BatchTracer(trace, val, dim)

  outs = yield [lift(val, dim) for val, dim in zip(in_vals, all_dims)], {}
  # TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can
  # be wasteful in the rare case it actually triggers; handle symbolically!
  outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs]
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  half = len(out_vals) // 2
  out_primals, out_tangents = split_list(out_vals, [half])
  primal_bds, tangent_bds = split_list(out_dims, [half])
  merged = map(_merge_bdims, primal_bds, tangent_bds)
  align = partial(matchaxis, trace.axis_name, size)
  out_primals = map(align, primal_bds, merged, out_primals)
  out_tangents = map(align, tangent_bds, merged, out_tangents)
  yield out_primals + out_tangents, merged * 2
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests,
                         main_type, spmd_axis_name):
  """Return a batched version of the custom_vjp backward function `bwd`.

  The returned callable batches `bwd` (via `batch_subtrace` and `_batch_outer`).
  It then moves its outputs to `out_dim_dests`, reduce-summing over the batch
  axis where a mapped output must become unmapped (see `_match_axes_and_sum`).
  """
  def new_bwd(*args):
    wrapped, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd))
    batched = _batch_outer(wrapped, axis_name, axis_size, in_dims, main_type,
                           spmd_axis_name)
    return _match_axes_and_sum(batched, axis_size, axis_name, out_dims_thunk,
                               out_dim_dests).call_wrapped(*args)
  return new_bwd
@lu.transformation
def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals):
  """Like `_match_axes`, but reduce-sums mapped outputs with unmapped dests.

  Outputs may be symbolic zeros (handled by `_matchaxis_symbolic_zeros`).
  """
  out_vals = yield in_vals, {}
  match_one = partial(_matchaxis_symbolic_zeros, axis_name, axis_size,
                      axis_name, sum_match=True)
  src_dims = out_dims_thunk()
  yield map(match_one, src_dims, out_dim_dests, out_vals)
def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
  """Like `matchaxis`, but also handles symbolic `Zero` values.

  For a `Zero`, only its aval is adjusted to reflect the new batch axis.
  Any other value is delegated to `matchaxis`.

  Raises:
    ValueError: for a `Zero` whose src/dst combination is not supported.
  """
  # TODO(mattjj): dedup with matchaxis
  if not isinstance(x, Zero):
    return matchaxis(axis_name, sz, src, dst, x, sum_match=sum_match)
  if src == dst:
    return x
  if type(src) == type(dst) == int:
    unbatched = core.mapped_aval(sz, src, x.aval)
    return Zero(core.unmapped_aval(sz, name, dst, unbatched))
  if src is not_mapped and dst is not not_mapped:
    return Zero(core.unmapped_aval(sz, name, dst, x.aval))
  if dst is not_mapped and sum_match:
    return Zero(core.mapped_aval(sz, src, x.aval))
  raise ValueError((axis_name, x, src, dst))
### utilities for defining primitives' batching rules
# A batching rule returns the output(s) together with the output batch dim(s).
BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
# Registry of batching rules, keyed by primitive.
primitive_batchers: Dict[core.Primitive, BatchingRule] = dict()
# NOTE(review): presumably rules that also need the vmapped axis; confirm the
# calling convention against the code that dispatches to these registries.
axis_primitive_batchers: Dict[core.Primitive, Callable] = dict()
spmd_axis_primitive_batchers: Dict[core.Primitive, Callable] = dict()
def defvectorized(prim):
  """Register `vectorized_batcher` as the batching rule for `prim`."""
  rule = partial(vectorized_batcher, prim)
  primitive_batchers[prim] = rule
def vectorized_batcher(prim, batched_args, batch_dims, **params):
  """Batching rule for elementwise primitives whose args share one batch dim.

  All entries of `batch_dims` must be equal. The primitive is bound directly
  and the output keeps that shared batch dim.
  """
  others = batch_dims[1:]
  assert all(batch_dims[0] == other for other in others), batch_dims
  result = prim.bind(*batched_args, **params)
  return result, batch_dims[0]
def defbroadcasting(prim):
  """Register `broadcast_batcher` as the batching rule for `prim`."""
  rule = partial(broadcast_batcher, prim)
  primitive_batchers[prim] = rule
def broadcast_batcher(prim, args, dims, **params):
  """Process a primitive with built-in broadcasting.

  Args:
    prim: the primitive to bind.
    args: the possibly-batched arguments
    dims: list or tuple of the same length as `args`, where each
      entry indicates the batching state of the corresponding entry to `args`:
      either an int indicating the batch dimension, or else `not_mapped`
      indicating no batching.

  Returns:
    `(out, out_dim)`, or `(outs, out_dims)` when `prim.multiple_results`.
  """
  assert len(args) > 1
  # Reference shape and batch dim come from the first mapped argument.
  ref_shape, ref_dim = next((x.shape, d) for x, d in zip(args, dims)
                            if d is not not_mapped)
  all_agree = all(core.symbolic_equal_shape(ref_shape, x.shape) and d == ref_dim
                  for x, d in zip(args, dims) if np.ndim(x))
  if all_agree:
    # if there's only agreeing batch dims and scalars, just call the primitive
    out = prim.bind(*args, **params)
    out_dim = ref_dim
  else:
    # We pass size of 1 here because (1) at least one argument has a real batch
    # dimension and (2) all unmapped axes can have a singleton axis inserted and
    # then rely on the primitive's built-in broadcasting.
    fronted = [bdim_at_front(x, d, 1) if np.ndim(x) else x
               for x, d in zip(args, dims)]
    rank = max(np.ndim(x) for x in fronted)  # special-case scalar broadcasting
    fronted = [_handle_scalar_broadcasting(rank, x, d)
               for x, d in zip(fronted, dims)]
    out = prim.bind(*fronted, **params)
    out_dim = 0
  if prim.multiple_results:
    return out, (out_dim,) * len(out)
  return out, out_dim
def _handle_scalar_broadcasting(nd, x, d):
  """Pad a mapped `x` with trailing singleton axes up to rank `nd`.

  Unmapped values, and values already at rank `nd`, are returned unchanged.
  """
  rank = np.ndim(x)
  if d is not not_mapped and nd != rank:
    return jax.lax.expand_dims(x, tuple(range(rank, nd)))
  return x
def defreducer(prim):
  """Register `reducer_batcher` (which takes an `axes` param) for `prim`."""
  rule = partial(reducer_batcher, prim)
  primitive_batchers[prim] = rule
def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
  """Batching rule for reduction primitives that take an `axes` param.

  Args:
    prim: the reduction primitive.
    batched_args: a single-element sequence holding the operand.
    batch_dims: a single-element sequence holding the operand's batch dim,
      either an int or a `ConcatAxis`.
    axes: the reduced axes, in unbatched coordinates for the int case.
    **params: other primitive params. An `input_shape` entry is refreshed to
      the batched operand shape in the int case.

  Returns:
    `(out, out_bdim)`.

  Raises:
    NotImplementedError: for a `ConcatAxis` that is not among `axes`.
    AssertionError: for any other kind of batch dim.
  """
  operand, = batched_args
  bdim, = batch_dims
  if isinstance(bdim, int):
    # Shift reduced axes at or past the batch dim by one to skip over it.
    axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
    # Position of the batch dim among the axes that survive the reduction.
    bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
    if 'input_shape' in params:
      params = dict(params, input_shape=operand.shape)
    return prim.bind(operand, axes=axes, **params), bdim_out
  elif isinstance(bdim, ConcatAxis):
    if bdim.axis in axes:
      # Tuple (not list) to match the int branch: primitive params should be
      # hashable.
      other_axes = tuple(i for i in axes if i != bdim.axis)
      if other_axes:
        operand = prim.bind(operand, axes=other_axes, **params)
      c_axis = bdim.axis - sum(d < bdim.axis for d in other_axes)
      operand = bdim_at_front(operand, c_axis, operand.shape[c_axis])
      # NOTE(review): segment_sum is applied whatever `prim` is, so this path
      # looks correct only for sum-like reductions; confirm.
      return segment_sum(operand, bdim.segment_lengths), 0
    else:
      raise NotImplementedError  # TODO(mattjj)
  else:
    # Explicit raise rather than `assert False`, so it survives `python -O`.
    raise AssertionError(f"unexpected batch dimension: {bdim!r}")
# TODO(mattjj): replace with jax.lax.ops.segment_sum (once it's easier to trace
# under dynamic shapes)
def segment_sum(operand, segment_lens):
  """Sum consecutive runs of rows of `operand` along axis 0.

  Args:
    operand: array whose leading axis is the concatenation of the segments
      (presumably of length `sum(segment_lens)`; TODO confirm at callers).
    segment_lens: 1-D integer array of segment lengths. Zero-length segments
      are allowed and produce all-zero rows.

  Returns:
    An array of shape `(len(segment_lens), *operand.shape[1:])` whose row `i`
    is the sum of the rows belonging to segment `i`.
  """
  starts = jax.numpy.cumsum(segment_lens) - segment_lens
  # Mark each segment start with `add`, not `set`. Several segments can start
  # at the same row (zero-length segments), and each must advance the segment
  # id. Starts equal to operand.shape[0] (trailing empty segments) are out of
  # bounds, and JAX scatters drop such updates.
  boundaries = jax.numpy.zeros(operand.shape[0], 'int32').at[starts].add(1)
  segment_ids = jax.numpy.cumsum(boundaries) - 1
  out = jax.numpy.zeros((len(segment_lens), *operand.shape[1:]),
                        operand.dtype).at[segment_ids].add(operand)
  return out
### general utilities for manipulating axes on jaxpr types (not vmappables)
def broadcast(x, sz, axis):
  """Broadcast `x` to gain a new axis of size `sz` at position `axis`.

  `axis` indexes the *output*, which has rank `np.ndim(x) + 1`. Negative values
  count from the end of the output, so `axis=-1` appends a trailing axis.
  """
  shape = list(np.shape(x))
  if axis < 0:
    # Canonicalize against the output rank. For negative indices, `list.insert`
    # and `np.delete` below would otherwise disagree on the position.
    axis += len(shape) + 1
  shape.insert(axis, sz)
  broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
  return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)
def matchaxis(axis_name, sz, src, dst, x, sum_match=False):
  """Move the batch axis of `x` from `src` to `dst`.

  Cases handled:
    * `dst == pile_axis`: move the batch axis to the front and wrap `x` in a
      `Pile`.
    * `src == dst`: return `x` unchanged.
    * both ints: move the axis with `moveaxis`.
    * `src` unmapped: broadcast in a new axis of size `sz` at `dst`.
    * `dst` unmapped and `sum_match`: sum over `src`.

  Raises:
    TypeError: if `x` is not a valid JAX type.
    ValueError: if `x` is mapped but `dst` requests an unmapped output and
      `sum_match` is False (or any other unsupported combination).
  """
  if dst == pile_axis:
    fronted = bdim_at_front(x, src, sz)
    elt_ty = fronted.aval.update(shape=fronted.shape[1:])
    size_var = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
    return Pile(PileTy(size_var, fronted.shape[0], elt_ty), fronted)
  try:
    _ = core.get_aval(x)
  except TypeError as e:
    raise TypeError(f"Output from batched function {repr(x)} with type "
                    f"{type(x)} is not a valid JAX type") from e
  if src == dst:
    return x
  if type(src) == type(dst) == int:
    return moveaxis(x, src, dst)
  if src is not_mapped and dst is not not_mapped:
    return broadcast(x, sz, canonicalize_axis(dst, np.ndim(x) + 1))
  if dst is not_mapped and sum_match:
    return x.sum(src)
  has_user_name = (not isinstance(axis_name, core._TempAxisName) and
                   axis_name is not core.no_axis_name)
  if has_user_name:
    raise ValueError(f'vmap has mapped output ({axis_name=}) but out_axes is {dst}')
  raise ValueError(f'vmap has mapped output but out_axes is {dst}')
def bdim_at_front(x, bdim, size):
  """Return `x` with its batch axis at position 0.

  An unmapped `x` gets a new leading axis of length `size` by broadcasting.
  """
  return broadcast(x, size, 0) if bdim is not_mapped else moveaxis(x, bdim, 0)
# sets up primitive batchers for ad_util and xla primitives
def add_batched(batched_args, batch_dims):
  """Batching rule for `add_jaxvals_p`.

  Aligns the two operands' batch dims before adding. An unmapped operand is
  broadcast to match the mapped one. If both are mapped at different dims,
  `x` is moved to `y`'s batch dim.
  """
  bdx, bdy = batch_dims
  x, y = batched_args
  if bdx == bdy:
    return add_jaxvals(x, y), bdx
  if bdx is not_mapped:
    return add_jaxvals(broadcast(x, y.shape[bdy], bdy), y), bdy
  if bdy is not_mapped:
    return add_jaxvals(x, broadcast(y, x.shape[bdx], bdx)), bdx
  return add_jaxvals(moveaxis(x, bdx, bdy), y), bdy
primitive_batchers[add_jaxvals_p] = add_batched
def zeros_like_batched(batched_args, batch_dims):
  """Batching rule for `zeros_like_p`: the zeros keep the input's batch dim."""
  (val,), (bdim,) = batched_args, batch_dims
  return zeros_like_jaxval(val), bdim
primitive_batchers[zeros_like_p] = zeros_like_batched