-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
maps.py
2175 lines (1900 loc) · 98.2 KB
/
maps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import contextlib
import numpy as np
import itertools as it
from collections import OrderedDict, abc
from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
NamedTuple, Union, Sequence)
from warnings import warn
from functools import wraps, partial, partialmethod
from enum import Enum
from jax import numpy as jnp
from jax import core
from jax import linear_util as lu
from jax import stages
from jax._src.api import _check_callable, _check_arg
from jax._src import dispatch
from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
tree_leaves, treedef_tuple)
from jax._src.tree_util import _replace_nones
from jax._src.api_util import (flatten_fun_nokwargs, flatten_axes,
_ensure_index_tuple, donation_vector,
shaped_abstractify)
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
from jax.errors import JAXTypeError
from jax.experimental.global_device_array import GlobalDeviceArray, _get_array_mapping
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import ad
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.util import (safe_map, safe_zip, HashableFunction,
as_hashable_function, unzip2, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name)
from jax import lax
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
map, unsafe_map = safe_map, map
zip = safe_zip
xops = xc.ops
class _PositionalSemantics(Enum):
"""Indicates whether the positional shapes of inputs should be interpreted as
global or local with respect to the multi-host mesh.
While named axes are always associated with global sizes, the outermost pjit
is the boundary between the local shapes in the outer scope and global
positional shapes in its inner scope. pjits nested inside that one should not
attempt to increase the sizes of avals again, and xmap has to take this into
account when inferring the global size of a named axis.
"""
LOCAL = 0
GLOBAL = 1
class _PSThreadLocalState(threading.local):
def __init__(self):
self.val = _PositionalSemantics.LOCAL
_positional_semantics = _PSThreadLocalState()
class FrozenDict(abc.Mapping):
def __init__(self, *args, **kwargs):
self.contents = dict(*args, **kwargs)
def __iter__(self):
return iter(self.contents)
def __len__(self):
return len(self.contents)
def __getitem__(self, name):
return self.contents[name]
def __eq__(self, other):
return isinstance(other, FrozenDict) and self.contents == other.contents
def __hash__(self):
return hash(tuple(self.contents.items()))
def __repr__(self):
return f"FrozenDict({self.contents})"
# Multi-dimensional generalized map
AxisName = core.AxisName
ResourceAxisName = AxisName # Different name just for documentation purposes
Mesh = pxla.Mesh
ResourceEnv = pxla.ResourceEnv
EMPTY_ENV = pxla.EMPTY_ENV
thread_resources = pxla.thread_resources
class SerialLoop:
"""Create an anonymous serial loop resource for use in a single xmap axis.
A use of :py:class:`SerialLoop` in :py:func:`xmap`'s ``axis_resources``
extends the resource environment with a new serial loop with a unique
unspecified name, that will only be used to partition the axis that
used a given instance.
This is unlike :py:func:`serial_loop`, which makes it possible to iterate
jointly over chunks of multiple axes (with the usual requirement that they
do not coincide in a named shape of any value in the program).
Example::
# Processes `x` in a vectorized way, but in 20 micro-batches.
xmap(f, in_axes=['i'], out_axes=[i], axis_resources={'i': SerialLoop(20)})(x)
# Computes the result in a vectorized way, but in 400 micro-batches,
# once for each coordinate (0, 0) <= (i, j) < (20, 20). Each `SerialLoop`
# creates a fresh anonymous loop.
xmap(h, in_axes=(['i'], ['j']), out_axes=['i', 'j'],
axis_resources={'i': SerialLoop(20), 'j': SerialLoop(20)})(x, y)
"""
length: int
def __init__(self, length):
self.length = length
def __eq__(self, other):
return self.length == other.length
def __hash__(self):
return hash(self.length)
@contextlib.contextmanager
def serial_loop(name: ResourceAxisName, length: int):
"""Define a serial loop resource to be available in scope of this context manager.
This is similar to :py:func:`mesh` in that it extends the resource
environment with a resource called ``name``. But, any use of this resource
axis in ``axis_resources`` argument of :py:func:`xmap` will cause the
body of :py:func:`xmap` to get executed ``length`` times with each execution
only processing only a slice of inputs mapped along logical axes assigned
to this resource.
This is especially useful in that it makes it possible to lower the memory
usage compared to :py:func:`vmap`, because it will avoid simultaneous
materialization of intermediate values for every point in the batch.
Note that collectives over loop axes are not supported, so they are less
versatile than physical mesh axes.
Args:
name: Name of the loop in the resource environment.
length: Number of iterations.
Example::
with loop('l', 4):
out = xmap(
lambda x: jnp.sin(x) * 5, # This will be called 4 times with different
# slices of x.
in_axes=['i'], out_axes=['i'],
axis_resources={'i': 'l'})(x)
"""
old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV)
thread_resources.env = old_env.with_extra_loop(pxla._Loop(name, length))
try:
yield
finally:
thread_resources.env = old_env
@contextlib.contextmanager
def mesh(devices: np.ndarray, axis_names: Sequence[ResourceAxisName]):
"""Declare the hardware resources available in the scope of this manager.
In particular, all ``axis_names`` become valid resource names inside the
managed block and can be used e.g. in the ``axis_resources`` argument of
:py:func:`xmap`.
If you are compiling in multiple threads, make sure that the
``with mesh`` context manager is inside the function that the threads will
execute.
Args:
devices: A NumPy ndarray object containing JAX device objects (as
obtained e.g. from :py:func:`jax.devices`).
axis_names: A sequence of resource axis names to be assigned to the
dimensions of the ``devices`` argument. Its length should match the
rank of ``devices``.
Example::
devices = np.array(jax.devices())[:4].reshape((2, 2))
with mesh(devices, ('x', 'y')): # declare a 2D mesh with axes 'x' and 'y'
distributed_out = xmap(
jnp.vdot,
in_axes=({0: 'left', 1: 'right'}),
out_axes=['left', 'right', ...],
axis_resources={'left': 'x', 'right': 'y'})(x, x.T)
"""
warn("`maps.mesh` context manager is deprecated. Please use `maps.Mesh`.",
FutureWarning)
old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV)
thread_resources.env = old_env.with_mesh(Mesh(np.asarray(devices, dtype=object), axis_names))
try:
yield
finally:
thread_resources.env = old_env
_next_resource_id = 0
class _UniqueResourceName:
def __init__(self, uid, tag=None):
self.uid = uid
self.tag = tag
def __eq__(self, other):
return type(other) is _UniqueResourceName and self.uid == other.uid
def __hash__(self):
return hash(self.uid)
def __repr__(self):
return f"<UniqueResource {self.tag} {self.uid}>"
def fresh_resource_name(tag=None):
global _next_resource_id
try:
return _UniqueResourceName(_next_resource_id, tag)
finally:
_next_resource_id += 1
# This is really a Dict[AxisName, int], but we don't define a
# pytree instance for it, so that it is treated as a leaf.
class AxisNamePos(FrozenDict):
user_repr: str
expected_rank: Optional[int] = None
def __init__(self, *args, user_repr, **kwargs):
super().__init__(*args, **kwargs)
self.user_repr = user_repr
class AxisNamePosWithRank(AxisNamePos):
def __init__(self, *args, expected_rank, **kwargs):
super().__init__(*args, **kwargs)
self.expected_rank = expected_rank
# str(...) == 'Ellipsis' which is really annoying
class DotDotDotRepr:
def __repr__(self): return '...'
def _parse_entry(arg_name, entry):
# Dictionaries mapping axis names to positional axes
if isinstance(entry, dict) and all(isinstance(v, int) for v in entry.keys()):
result = AxisNamePos(((name, axis) for axis, name in entry.items()),
user_repr=str(entry))
num_mapped_dims = len(entry)
# Non-empty lists or tuples that optionally terminate with an ellipsis
elif isinstance(entry, (tuple, list)):
if entry and entry[-1] == ...:
constr = AxisNamePos
entry = entry[:-1]
tail = [DotDotDotRepr()] if isinstance(entry, list) else (DotDotDotRepr(),)
user_repr = str(entry + tail)
else:
constr = partial(AxisNamePosWithRank, expected_rank=len(entry))
user_repr = str(entry)
result = constr(((name, axis) for axis, name in enumerate(entry)
if name is not None),
user_repr=user_repr)
num_mapped_dims = sum(name is not None for name in entry)
else:
raise TypeError(f"""\
Value mapping specification in xmap {arg_name} pytree can be either:
- lists of axis names (possibly ending with the ellipsis object: ...)
- dictionaries that map positional axes (integers) to axis names (e.g. {2: 'name'})
but got: {entry}""")
if len(result) != num_mapped_dims:
raise ValueError(f"Named axes should be unique within each {arg_name} argument "
f"specification, but one them is: {entry}")
for axis in result.values():
if axis < 0:
raise ValueError(f"xmap doesn't support negative axes in {arg_name}")
return result
def _is_axes_leaf(entry):
if isinstance(entry, dict) and all_leaves(entry.values()):
return True
# NOTE: `None`s are not considered leaves by `all_leaves`
if isinstance(entry, (tuple, list)) and all_leaves(v for v in entry if v is not None):
return True
return False
def _prepare_axes(axes, arg_name):
entries, treedef = tree_flatten(axes, is_leaf=_is_axes_leaf)
entries = map(partial(_parse_entry, arg_name), entries)
return tree_unflatten(treedef, entries), entries, treedef
Resource = Union[ResourceAxisName, SerialLoop]
ResourceSet = Union[Resource, Tuple[Resource, ...]]
# TODO: Some syntactic sugar to make the API more usable in a single-axis case?
# TODO: Are the resource axes scoped lexically or dynamically? Dynamically for now!
def xmap(fun: Callable,
in_axes,
out_axes,
*,
axis_sizes: Dict[AxisName, int] = {},
axis_resources: Dict[AxisName, ResourceSet] = {},
donate_argnums: Union[int, Sequence[int]] = (),
backend: Optional[str] = None) -> stages.Wrapped:
"""Assign a positional signature to a program that uses named array axes.
.. warning::
This is an experimental feature and the details can change at
any time. Use at your own risk!
.. warning::
This docstring is aspirational. Not all features of the named axis
programming model have been implemented just yet.
The usual programming model of JAX (or really NumPy) associates each array
with two pieces of metadata describing its type: the element type (``dtype``)
and the ``shape``. :py:func:`xmap` extends this model by adding support for
*named axes*. In particular, each array used in a function wrapped by
:py:func:`xmap` can additionally have a non-empty ``named_shape`` attribute,
which can be used to query the set of named axes (introduced by
:py:func:`xmap`) appearing in that value along with their shapes.
Furthermore, in most places where positional axis indices are allowed (for
example the `axes` arguments in :py:func:`sum`), bound axis names are also
accepted. The :py:func:`einsum` language is extended inside :py:func:`xmap`
to additionally allow contractions that involve named axes. Broadcasting of
named axes happens *by name*, i.e. all axes with equal names are expected to
have equal shapes in all arguments of a broadcasting operation, while the
result has a (set) union of all named axes. The positional semantics of the
program remain unchanged, and broadcasting still implicitly right-aligns
positional axes for unification. For an extended description of the
:py:func:`xmap` programming model, please refer to the :py:func:`xmap`
tutorial notebook in main JAX documentation.
Note that since all top-level JAX expressions are interpreted in the NumPy
programming model, :py:func:`xmap` can also be seen as an adapter that
converts a function that uses named axes (including in arguments and returned
values) into one that takes and returns values that only have positional
axes.
The default lowering strategy of :py:func:`xmap` converts all named axes into
positional axes, working similarly to multiple applications of
:py:func:`vmap`. However, this behavior can be further customized by the
``axis_resources`` argument. When specified, each axis introduced by
:py:func:`xmap` can be assigned to one or more *resource axes*. Those include
the axes of the hardware mesh, as defined by the :py:func:`mesh` context
manager. Each value that has a named axis in its ``named_shape`` will be
partitioned over all mesh axes that axis is assigned to. Hence,
:py:func:`xmap` can be seen as an alternative to :py:func:`pmap` that also
exposes a way to automatically partition the computation over multiple
devices.
.. warning::
While it is possible to assign multiple axis names to a single resource axis,
care has to be taken to ensure that none of those named axes co-occur in a
``named_shape`` of any value in the named program. At the moment this is
**completely unchecked** and will result in **undefined behavior**. The
final release of :py:func:`xmap` will enforce this invariant, but it is a
work in progress.
Note that you do not have to worry about any of this for as long as no
resource axis is repeated in ``axis_resources.values()``.
Note that any assignment of ``axis_resources`` doesn't ever change the
results of the computation, but only how it is carried out (e.g. how many
devices are used). This makes it easy to try out various ways of
partitioning a single program in many distributed scenarios (both small- and
large-scale), to maximize the performance. As such, :py:func:`xmap` can be
seen as a way to seamlessly interpolate between :py:func:`vmap` and
:py:func:`pmap`-style execution.
Args:
fun: Function that uses named axes. Its arguments and return
value should be arrays, scalars, or (nested) standard Python containers
(tuple/list/dict) thereof (in general: valid pytrees).
in_axes: A Python object with the same container (pytree) structure as the
signature of arguments to ``fun``, but with a positional-to-named axis
mapping in place of every array argument. The valid positional-to-named
mappings are: (1) a ``Dict[int, AxisName]`` specifying that a positional
dimensions given by dictionary keys are to be converted to named axes
of given names (2) a list of axis names that ends with the Ellipsis object
(``...``) in which case a number of leading positional axes of the argument
will be converted into named axes inside the function. Note that ``in_axes``
can also be a prefix of the argument container structure, in which case the
mapping is repeated for all arrays in the collapsed subtree.
out_axes: A Python object with the same container (pytree) structure as the
returns of ``fun``, but with a positional-to-named axis mapping in place
of every returned array. The valid positional-to-named mappings are the same
as in ``in_axes``. Note that ``out_axes`` can also be a prefix of the return
container structure, in which case the mapping is repeated for all arrays
in the collapsed subtree.
axis_sizes: A dict mapping axis names to their sizes. All axes defined by xmap
have to appear either in ``in_axes`` or ``axis_sizes``. Sizes of axes
that appear in ``in_axes`` are inferred from arguments whenever possible.
In multi-host scenarios, the user-specified sizes are expected to be the
global axis sizes (and might not match the expected size of local inputs).
axis_resources: A dictionary mapping the axes introduced in this
:py:func:`xmap` to one or more resource axes. Any array that has in its
shape an axis with some resources assigned will be partitioned over the
resources associated with the respective resource axes.
donate_argnums: Specify which argument buffers are "donated" to the computation.
It is safe to donate argument buffers if you no longer need them once the
computation has finished. In some cases XLA can make use of donated
buffers to reduce the amount of memory needed to perform a computation,
for example recycling one of your input buffers to store a result. You
should not reuse buffers that you donate to a computation, JAX will raise
an error if you try to.
For more details on buffer donation see the [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'.
Returns:
A version of ``fun`` that takes in arrays with positional axes in place of
named axes bound in this :py:func:`xmap` call, and results with all named
axes converted to positional axes. If ``axis_resources`` is specified,
``fun`` can additionally execute in parallel on multiple devices.
For example, :py:func:`xmap` makes it very easy to convert a function that
computes the vector inner product (such as :py:func:`jax.numpy.vdot`) into
one that computes a matrix multiplication:
>>> import jax.numpy as jnp
>>> x = jnp.arange(10).reshape((2, 5))
>>> xmap(jnp.vdot,
... in_axes=({0: 'left'}, {1: 'right'}),
... out_axes=['left', 'right', ...])(x, x.T)
DeviceArray([[ 30, 80],
[ 80, 255]], dtype=int32)
Note that the contraction in the program is performed over the positional axes,
while named axes are just a convenient way to achieve batching. While this
might seem like a silly example at first, it might turn out to be useful in
practice, since with conjuction with ``axis_resources`` this makes it possible
to implement a distributed matrix-multiplication in just a few lines of code::
devices = np.array(jax.devices())[:4].reshape((2, 2))
with mesh(devices, ('x', 'y')): # declare a 2D mesh with axes 'x' and 'y'
distributed_out = xmap(
jnp.vdot,
in_axes=({0: 'left'}, {1: 'right'}),
out_axes=['left', 'right', ...],
axis_resources={'left': 'x', 'right': 'y'})(x, x.T)
Still, the above examples are quite simple. After all, the xmapped
computation was a simple NumPy function that didn't use the axis names at all!
So, let's explore a slightly larger example which is linear regression::
def regression_loss(x, y, w, b):
# Contract over in_features. Batch and out_features are present in
# both inputs and output, so they don't need to be mentioned
y_pred = jnp.einsum('{in_features},{in_features}->{}', x, w) + b
error = jnp.sum((y - y_pred) ** 2, axis='out_features')
return jnp.mean(error, axis='batch')
xmap(regression_loss,
in_axes=(['batch', 'in_features', ...],
['batch', 'out_features', ...],
['in_features', 'out_features', ...],
['out_features', ...]),
out_axes={}) # Loss is reduced over all axes, including batch!
.. note::
When using ``axis_resources`` along with a mesh that is controlled by
multiple JAX hosts, keep in mind that in any given process :py:func:`xmap`
only expects the data slice that corresponds to its local devices to be
specified. This is in line with the current multi-host :py:func:`pmap`
programming model.
"""
warn("xmap is an experimental feature and probably has bugs!")
_check_callable(fun)
if isinstance(in_axes, list) and not _is_axes_leaf(in_axes):
# To be a tree prefix of the positional args tuple, in_axes can never be a
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here
# rather than raising an error. https://github.com/google/jax/issues/2367
in_axes = tuple(in_axes)
if in_axes == (): # Allow empty argument lists
in_axes, in_axes_entries = (), []
else:
in_axes, in_axes_entries, _ = _prepare_axes(in_axes, "in_axes")
if out_axes == ():
raise ValueError("xmapped functions cannot have no return values")
else:
out_axes, out_axes_entries, out_axes_treedef = _prepare_axes(out_axes, "out_axes")
out_axes_entries = tuple(out_axes_entries) # Make entries hashable
axis_sizes_names = set(axis_sizes.keys())
in_axes_names = set(it.chain(*(spec.keys() for spec in in_axes_entries)))
defined_names = axis_sizes_names | in_axes_names
out_axes_names = set(it.chain(*(spec.keys() for spec in out_axes_entries)))
anon_serial_loops = []
def normalize_resource(r) -> ResourceAxisName:
if isinstance(r, SerialLoop):
name = fresh_resource_name()
anon_serial_loops.append((name, r.length))
return name
return r
normalized_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]] = {}
for axis in defined_names:
resources = axis_resources.get(axis, ())
if not isinstance(resources, tuple):
resources = (resources,)
normalized_axis_resources[axis] = tuple(unsafe_map(normalize_resource, resources))
frozen_axis_resources = FrozenDict(normalized_axis_resources)
necessary_resources = set(it.chain(*frozen_axis_resources.values()))
axes_with_resources = set(frozen_axis_resources.keys())
if axes_with_resources > defined_names:
raise ValueError(f"All axes that were assigned resources have to appear in "
f"in_axes or axis_sizes, but the following are missing: "
f"{axes_with_resources - defined_names}")
if out_axes_names > defined_names:
raise ValueError(f"All axis names appearing in out_axes must also appear in "
f"in_axes or axis_sizes, but the following are missing: "
f"{out_axes_names - defined_names}")
for axis, resources in frozen_axis_resources.items():
if len(set(resources)) != len(resources): # type: ignore
raise ValueError(f"Resource assignment of a single axis must be a tuple of "
f"distinct resources, but specified {resources} for axis {axis}")
donate_argnums = _ensure_index_tuple(donate_argnums)
# A little performance optimization to avoid iterating over all args unnecessarily
has_input_rank_assertions = any(spec.expected_rank is not None for spec in in_axes_entries)
has_output_rank_assertions = any(spec.expected_rank is not None for spec in out_axes_entries)
def infer_params(*args):
# Putting this outside of fun_mapped would make resources lexically scoped
resource_env = thread_resources.env
available_resources = set(resource_env.shape.keys())
if necessary_resources - available_resources:
raise ValueError(f"In-scope resources are insufficient to execute the "
f"xmapped function. The missing resources are: "
f"{necessary_resources - available_resources}")
args_flat, in_tree = tree_flatten(args)
fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
if donate_argnums:
donated_invars = donation_vector(donate_argnums, args, ())
else:
donated_invars = (False,) * len(args_flat)
in_axes_flat = _flatten_axes("xmap in_axes", in_tree, in_axes, tupled_args=True)
# Some pytree containers might be unhashable, so we flatten the out_axes
# pytree into a treedef and entries which are guaranteed to be hashable.
out_axes_thunk = HashableFunction(
lambda: tuple(_flatten_axes("xmap out_axes", out_tree(), out_axes, tupled_args=False)),
closure=(out_axes_entries, out_axes_treedef))
in_positional_semantics = tuple(
_PositionalSemantics.GLOBAL
if isinstance(a, GlobalDeviceArray) else _positional_semantics.val
for a in args_flat)
out_positional_semantics = _positional_semantics.val
axis_resource_count = _get_axis_resource_count(
frozen_axis_resources, resource_env, in_positional_semantics)
for axis, size in axis_sizes.items():
resources = axis_resource_count[axis]
if size % resources.nglobal != 0:
global_size = "Global size" if resources.distributed else "Size"
raise ValueError(f"{global_size} of axis {axis} ({size}) is not divisible "
f"by the total number of resources assigned to this axis "
f"({frozen_axis_resources[axis]}, {resources.nglobal} in total)")
frozen_global_axis_sizes = _get_axis_sizes(
args_flat, in_axes_flat, axis_sizes, axis_resource_count,
in_positional_semantics)
missing_sizes = defined_names - set(frozen_global_axis_sizes.keys())
if missing_sizes:
raise ValueError(f"Failed to infer size of axes: {', '.join(unsafe_map(str, missing_sizes))}. "
f"You've probably passed in empty containers in place of arguments that had "
f"those axes in their in_axes. Provide the sizes of missing axes explicitly "
f"via axis_sizes to fix this error.")
if has_input_rank_assertions:
for arg, spec in zip(args_flat, in_axes_flat):
if spec.expected_rank is not None and spec.expected_rank != arg.ndim:
raise ValueError(f"xmap argument has an in_axes specification of {spec.user_repr}, "
f"which asserts that it should be of rank {spec.expected_rank}, "
f"but the argument has rank {arg.ndim} (and shape {arg.shape})")
_check_gda_xmap_partitioning(frozen_axis_resources, resource_env,
frozen_global_axis_sizes, in_axes_flat,
in_positional_semantics, args_flat)
params = dict(
name=getattr(fun, '__name__', '<unnamed function>'),
in_axes=tuple(in_axes_flat),
out_axes_thunk=out_axes_thunk,
donated_invars=donated_invars,
global_axis_sizes=frozen_global_axis_sizes,
axis_resources=frozen_axis_resources,
resource_env=resource_env,
backend=backend,
spmd_in_axes=None,
spmd_out_axes_thunk=None,
in_positional_semantics=in_positional_semantics,
out_positional_semantics=out_positional_semantics)
return fun_flat, args_flat, params, in_tree, out_tree
def verify_outputs(out_flat, out_tree, params):
if has_output_rank_assertions:
for out, spec in zip(out_flat, params['out_axes_thunk']()):
if spec.expected_rank is not None and spec.expected_rank != out.ndim:
raise ValueError(f"xmap output has an out_axes specification of {spec.user_repr}, "
f"which asserts that it should be of rank {spec.expected_rank}, "
f"but the output has rank {out.ndim} (and shape {out.shape})")
return tree_unflatten(out_tree(), out_flat)
def decorate_serial(f):
for loop_params in reversed(anon_serial_loops):
f = serial_loop(*loop_params)(f)
return f
@wraps(fun)
@decorate_serial
def fun_mapped(*args):
tree_map(_check_arg, args)
fun_flat, args_flat, params, _, out_tree = infer_params(*args)
out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
return verify_outputs(out_flat, out_tree, params)
@decorate_serial
def lower(*args):
fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
avals_flat = [shaped_abstractify(arg) for arg in args_flat]
computation = make_xmap_callable(
fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'],
params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
params['resource_env'], params['backend'], params['spmd_in_axes'],
params['spmd_out_axes_thunk'], params['in_positional_semantics'],
params['out_positional_semantics'], *avals_flat)
in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
in_avals = in_tree.unflatten(avals_flat)
return stages.Lowered.from_flat_info(
computation, in_tree, in_avals, donate_argnums, out_tree(),
no_kwargs=True)
fun_mapped.lower = lower
return fun_mapped
def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_invars,
global_axis_sizes, axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics,
out_positional_semantics):
in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
xmap_callable = make_xmap_callable(
fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics, out_positional_semantics,
*in_avals).compile().unsafe_call
distributed_debug_log(("Running xmapped function", name),
("python function", fun.f),
("mesh", resource_env.physical_mesh),
("abstract args", in_avals))
return xmap_callable(*args)
@lu.cache
def make_xmap_callable(fun: lu.WrappedFun,
name,
in_axes, out_axes_thunk, donated_invars,
global_axis_sizes, axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics,
out_positional_semantics, *in_avals):
assert out_positional_semantics == _PositionalSemantics.LOCAL
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
# TODO: Making axis substitution final style would allow us to avoid
# tracing to jaxpr here
mapped_in_avals = [_delete_aval_axes(aval, in_axes, global_axis_sizes)
for aval, in_axes in zip(in_avals, in_axes)]
with core.extend_axis_env_nd(global_axis_sizes.items()):
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for xmap in {elapsed_time} sec"):
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
out_axes = out_axes_thunk()
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
# NOTE: We don't use avals and all params, so only pass in the relevant parts (too lazy...)
_resource_typing_xmap([], dict(axis_resources=axis_resources,
out_axes=out_axes,
call_jaxpr=jaxpr,
resource_env=resource_env,
name=name),
source_info_util.new_source_info(), resource_env, {})
jaxpr = plan.subst_axes_with_resources(jaxpr)
use_spmd_lowering = config.experimental_xmap_spmd_lowering
ensure_fixed_sharding = config.experimental_xmap_ensure_fixed_sharding
if use_spmd_lowering and ensure_fixed_sharding:
jaxpr = _fix_inferred_spmd_sharding(jaxpr, resource_env)
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts)))
f = hide_mapped_axes(f, tuple(in_axes), tuple(out_axes))
f = plan.vectorize_and_loop(f, in_axes, out_axes)
used_resources = _jaxpr_resources(jaxpr, resource_env) | set(it.chain(*axis_resources.values()))
used_mesh_axes = used_resources & resource_env.physical_resource_axes
if used_mesh_axes:
assert spmd_in_axes is None and spmd_out_axes_thunk is None # No outer xmaps, so should be None
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
mesh = resource_env.physical_mesh
global_in_avals = [
av if ips == _PositionalSemantics.GLOBAL else mesh._local_to_global(ax, av)
for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics)
]
in_is_global = [ips == _PositionalSemantics.GLOBAL or not ia
for ips, ia in safe_zip(in_positional_semantics, mesh_in_axes)]
tiling_method: pxla.TilingMethod
if config.experimental_xmap_spmd_lowering_manual:
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
tiling_method = pxla.TileManual(manual_mesh_axes)
else:
tiling_method = pxla.TileVectorize()
return pxla.lower_mesh_computation(
f, 'xmap', name, mesh,
mesh_in_axes, mesh_out_axes, donated_invars,
use_spmd_lowering, global_in_avals,
tiling_method=tiling_method, in_is_global=in_is_global)
else:
return dispatch.lower_xla_callable(
f, None, backend, name, donated_invars, *((a, None) for a in in_avals))
class EvaluationPlan(NamedTuple):
"""Encapsulates preprocessing common to top-level xmap invocations and its translation rule."""
resource_env: ResourceEnv
physical_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]
loop_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]
axis_subst_dict: Dict[AxisName, Tuple[ResourceAxisName, ...]]
axis_vmap_size: Dict[AxisName, Optional[int]]
@property
def axis_subst(self) -> core.AxisSubst:
return lambda name: self.axis_subst_dict.get(name, (name,))
@property
def resource_axis_env(self):
env = dict(self.resource_env.shape)
for axis, size in self.axis_vmap_size.items():
if size is None:
continue
vmap_axis = self.axis_subst_dict[axis][-1]
env[vmap_axis] = size
return env
@classmethod
def from_axis_resources(cls,
axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
resource_env: ResourceEnv,
global_axis_sizes: Dict[AxisName, int],
in_positional_semantics: Sequence[bool]):
physical_axis_resources, loop_axis_resources = _unzip_axis_resources(
axis_resources, resource_env)
axis_resource_count = _get_axis_resource_count(
axis_resources, resource_env, in_positional_semantics)
axis_subst_dict = dict(axis_resources)
axis_vmap_size: Dict[AxisName, Optional[int]] = {}
for naxis, raxes in sorted(axis_resources.items(), key=lambda x: str(x[0])):
num_resources = axis_resource_count[naxis]
assert global_axis_sizes[naxis] % num_resources.nglobal == 0
local_tile_size = global_axis_sizes[naxis] // num_resources.nglobal
# We have to vmap when there are no resources (to handle the axis name!) or
# when every resource gets chunks of values.
if not raxes or local_tile_size > 1:
axis_vmap_size[naxis] = local_tile_size
axis_subst_dict[naxis] += (fresh_resource_name(naxis),)
else:
axis_vmap_size[naxis] = None
return cls(resource_env,
physical_axis_resources, loop_axis_resources,
axis_subst_dict, axis_vmap_size)
def subst_axes_with_resources(self, jaxpr):
try:
if any(self.loop_axis_resources.values()):
_check_no_loop_collectives(jaxpr, self.loop_axis_resources)
with core.extend_axis_env_nd(self.resource_axis_env.items()):
return core.subst_axis_names_jaxpr(jaxpr, self.axis_subst)
except core.DuplicateAxisNameError:
raise AssertionError("Incomplete resource type-checking? Please open a bug report!")
def vectorize_and_loop(self, f: lu.WrappedFun, in_axes, out_axes) -> lu.WrappedFun:
vmap_axes = {
naxis: raxes[-1]
for naxis, raxes in self.axis_subst_dict.items()
if self.axis_vmap_size[naxis] is not None
}
for naxis, vaxis in sorted(vmap_axes.items(), key=lambda x: x[1].uid):
local_tile_size = self.axis_vmap_size[naxis]
map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
f = batching.vtile(f, map_in_axes, map_out_axes, tile_size=local_tile_size, axis_name=vaxis)
used_loops = set(it.chain.from_iterable(self.loop_axis_resources.values()))
if not used_loops:
return f
if len(used_loops) > 1:
# TODO: Support multiple loops
raise NotImplementedError("Only one loop per xmap is supported")
loop_in_axes = _to_resource_axes(in_axes, self.loop_axis_resources)
loop_out_axes = _to_resource_axes(out_axes, self.loop_axis_resources)
loop_name, = used_loops
loop_length = self.resource_env.shape[loop_name]
def looped_f(*args):
def body(i, _):
# XXX: This call_wrapped is only valid under the assumption that scan
# only ever traces the body once (which it does at the moment).
result = f.call_wrapped(
*(_slice_tile(arg, spec.get(loop_name, None), i, loop_length)
for arg, spec in zip(args, loop_in_axes)))
return i + 1, result
_, stacked_results = lax.scan(body, 0, (), length=loop_length)
return [_merge_leading_axis(sresult, spec.get(loop_name, None))
for sresult, spec in zip(stacked_results, loop_out_axes)]
return lu.wrap_init(looped_f)
def to_mesh_axes(self, in_axes, out_axes=None):
"""
Convert in/out_axes parameters ranging over logical dimensions to
in/out_axes that range over the mesh dimensions.
"""
if out_axes is None:
return _to_resource_axes(in_axes, self.physical_axis_resources)
else:
return (_to_resource_axes(in_axes, self.physical_axis_resources),
_to_resource_axes(out_axes, self.physical_axis_resources))
# -------- xmap primitive and its transforms --------
# xmap has a different set of parameters than pmap, so we make it its own primitive type
class XMapPrimitive(core.MapPrimitive):
def __init__(self):
super().__init__('xmap')
self.def_impl(xmap_impl)
self.def_custom_bind(self.bind)
def bind(self, fun, *args, in_axes, **params):
assert len(in_axes) == len(args), (in_axes, args)
return core.map_bind(self, fun, *args, in_axes=in_axes, **params)
def process(self, trace, fun, tracers, params):
return trace.process_xmap(self, fun, tracers, params)
def post_process(self, trace, out_tracers, params):
raise NotImplementedError
def get_bind_params(self, params):
new_params = dict(params)
subfun = lu.wrap_init(partial(core.eval_jaxpr, new_params.pop('call_jaxpr'), ()))
axes = new_params.pop('out_axes')
new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes)
spmd_axes = new_params.pop('spmd_out_axes')
if spmd_axes is not None:
new_params['spmd_out_axes_thunk'] = \
HashableFunction(lambda: spmd_axes, closure=spmd_axes)
else:
new_params['spmd_out_axes_thunk'] = None
return [subfun], new_params
xmap_p = XMapPrimitive()
core.EvalTrace.process_xmap = core.EvalTrace.process_call # type: ignore
def _process_xmap_default(self, call_primitive, f, tracers, params):
raise NotImplementedError(f"{type(self)} must override process_xmap to handle xmap")
core.Trace.process_xmap = _process_xmap_default # type: ignore
def _xmap_axis_subst(params, subst, traverse):
if 'call_jaxpr' not in params: # TODO(apaszke): This feels sketchy, but I'm not sure why
return params
if not traverse:
return params
def shadowed_subst(name):
return (name,) if name in params['global_axis_sizes'] else subst(name)
with core.extend_axis_env_nd(params['global_axis_sizes'].items()):
new_jaxpr = core.subst_axis_names_jaxpr(params['call_jaxpr'], shadowed_subst)
return dict(params, call_jaxpr=new_jaxpr)
core.axis_substitution_rules[xmap_p] = _xmap_axis_subst
# NOTE: We don't have to handle spmd_{in|out}_axes here, because
# SPMD batching always gets involved as the last transform before XLA translation
ad.JVPTrace.process_xmap = ad.JVPTrace.process_call # type: ignore
ad.call_param_updaters[xmap_p] = ad.call_param_updaters[xla.xla_call_p]
def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes):
all_args, in_tree_def = tree_flatten(((), args, cts_in)) # empty consts
fun = lu.hashable_partial(
lu.wrap_init(ad.backward_pass),
call_jaxpr, reduce_axes + tuple(params['global_axis_sizes'].keys()), False)
fun, nz_arg_cts = ad.nonzero_outputs(fun)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).
in_axes, out_axes = params['in_axes'], params['out_axes']
new_in_axes = (*(axis for axis, x in zip(in_axes, args) if not ad.is_undefined_primal(x)),
*(axis for axis, x in zip(out_axes, cts_in) if type(x) is not ad.Zero))
# NOTE: This assumes that the output cotangents being zero is a deterministic
# function of which input cotangents were zero.
@as_hashable_function(closure=(in_axes, tuple(type(c) is ad.Zero for c in cts_in)))
def out_axes_thunk():
return tuple(axis for axis, nz in zip(in_axes, nz_arg_cts()) if nz)
new_params = dict(params,
name=wrap_name(params['name'], 'transpose'),
in_axes=new_in_axes,
out_axes_thunk=out_axes_thunk,
donated_invars=(False,) * len(new_in_axes),
spmd_out_axes_thunk=None)
del new_params['out_axes']
del new_params['spmd_out_axes']
out_flat = xmap_p.bind(fun, *all_args, **new_params)
arg_cts = tree_unflatten(out_tree(), out_flat)
axis_resource_count = _get_axis_resource_count(
params['axis_resources'], params['resource_env'],
params['in_positional_semantics'])
local_axis_sizes = {
axis: axis_resource_count[axis].to_local(params['out_positional_semantics'], global_size)
for axis, global_size in params['global_axis_sizes'].items()
}
def unmap_zero(zero, axes):
return ad.Zero(_insert_aval_axes(zero.aval, axes, local_axis_sizes))
return tuple(unmap_zero(arg_ct, in_axis) if type(arg_ct) is ad.Zero else arg_ct
for arg_ct, in_axis in zip(arg_cts, in_axes))
ad.primitive_transposes[xmap_p] = _xmap_transpose
def _typecheck_xmap(
*in_avals, call_jaxpr, name, in_axes, out_axes, donated_invars,
global_axis_sizes, axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes, in_positional_semantics,
out_positional_semantics):
axis_resource_count = _get_axis_resource_count(
axis_resources, resource_env, in_positional_semantics)
local_axis_sizes = {
axis: axis_resource_count[axis].to_local(out_positional_semantics, global_size)
for axis, global_size in global_axis_sizes.items()
}
binder_in_avals = [_insert_aval_axes(v.aval, a_in_axes, local_axis_sizes)
for v, a_in_axes in zip(call_jaxpr.invars, in_axes)]
for binder_in_aval, in_aval in zip(binder_in_avals, in_avals):
if not core.typecompat(binder_in_aval, in_aval):
raise core.JaxprTypeError(
f"xmap passes operand {in_aval} to jaxpr expecting {binder_in_aval}")
mapped_in_avals = [_delete_aval_axes(a, a_in_axes, global_axis_sizes)
for a, a_in_axes in zip(in_avals, in_axes)]
with core.extend_axis_env_nd(global_axis_sizes.items()):
core._check_jaxpr(lambda: core.JaxprPpContext(), call_jaxpr,
mapped_in_avals)
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)
for a, a_out_axes in zip(mapped_out_avals, out_axes)]
return out_avals
core.custom_typechecks[xmap_p] = _typecheck_xmap
def _resource_typing_xmap(avals,
params,
source_info: source_info_util.SourceInfo,
resource_env,
outer_axis_resources):
axis_resources = params['axis_resources']
inner_axis_resources = dict(outer_axis_resources)
inner_axis_resources.update(axis_resources)