-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
pxla.py
3069 lines (2651 loc) · 125 KB
/
pxla.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of pmap and related functionality."""
from __future__ import annotations
import enum
from contextlib import contextmanager
from collections import namedtuple
from collections.abc import Sequence, Iterable
import dataclasses
from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
from typing import (Any, Callable, NamedTuple, Optional, Union, cast, TypeVar)
import warnings
import numpy as np
import jax
from jax.errors import JAXTypeError
from jax._src import api_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding_specs
from jax._src import profiler
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.core import DShapedArray
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified,
AUTO, UnspecifiedValue, UNSPECIFIED,
get_array_mapping as _get_array_mapping, is_auto, is_unspecified
)
from jax._src.util import (safe_map, safe_zip, partition_list,
wrap_name, tuple_delete, distributed_debug_log,
unzip2, HashableFunction, weakref_lru_cache)
# Built in Python lists don't support weak refs but subclasses of lists do.
class WeakRefList(list):
  """A ``list`` subclass that adds nothing except weak-referenceability."""
# Short alias for the private `_xla` attribute of `xla_client`.
xe = xc._xla

# Keep the builtin `map` reachable as `unsafe_map`; from here on, the bare name
# `map` in this module refers to `safe_map` from `jax._src.util`.
unsafe_map, map = map, safe_map # type: ignore

logger = logging.getLogger(__name__)

# An index into an array: a single int/slice, or a tuple of ints/slices.
Index = Union[int, slice, tuple[Union[int, slice], ...]]

# Local aliases for names defined in `sharding_specs`, `mesh_lib` and
# `sharding_impls`.
NoSharding = sharding_specs.NoSharding
Chunked = sharding_specs.Chunked
Unstacked = sharding_specs.Unstacked

ShardedAxis = sharding_specs.ShardedAxis
Replicated = sharding_specs.Replicated

AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
Mesh = mesh_lib.Mesh
MeshAxisName = sharding_impls.MeshAxisName
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = sharding_specs.ShardingSpec
### util

def identity(x):
  """Return ``x`` unchanged."""
  return x
def shard_arg(arg, devices, arg_indices, sharding, canonicalize=True):
  """Returns a list of size len(devices) containing per-device buffers.

  For the C++ pmap path, we fallback to Python (this function) to shard
  arguments that are not supported by the C++ `ShardArg`.

  The work is delegated to the handler registered in `shard_arg_handlers` for
  the exact type of the (optionally canonicalized) argument.

  Args:
    arg: The Python argument.
    devices: The list of devices to shard over.
    arg_indices: A list of `len(devices)` indices to use to shard the argument.
    sharding: Forwarded unchanged to the type-specific handler.
    canonicalize: If true (the default), `arg` is first passed through
      `xla.canonicalize_dtype`.
  """
  value = xla.canonicalize_dtype(arg) if canonicalize else arg
  handler = shard_arg_handlers[type(value)]
  return handler(value, devices, arg_indices, sharding)
@profiler.annotate_function
def shard_args(
    devices: Sequence[xb.xla_client.Device],
    indices: Sequence[Sequence[Index]],
    shardings: Sequence[sharding_impls.XLACompatibleSharding],
    args,
) -> Sequence[jax.Array]:
  """Shard each argument data array along its leading axis.

  Args:
    devices: sequence of Devices mapping replica index to a physical device.
    indices: sequence of the same length as `args` describing how each arg
      should be sharded/replicated across `devices`. Each element in `indices`
      is the same length as `devices`.
    shardings: sequence indexed in step with `args`; element `i` is forwarded
      to `shard_arg` together with `args[i]`.
    args: a sequence of JaxTypes representing arguments to be sharded according
      to `indices` and placed on `devices`.

  Returns:
    A list of length matching args, containing the result of `shard_arg` for
    each argument.
  """
  sharded = []
  for position, arg in enumerate(args):
    sharded.append(
        shard_arg(arg, devices, indices[position], shardings[position]))
  return sharded
# Registry keyed by argument type; `shard_arg` looks up `type(arg)` here and
# calls the stored function as `handler(arg, devices, indices, sharding)`.
shard_arg_handlers: dict[Any, Callable[[Any, Any, Any, Any], Any]] = {}
def _shard_token(x, devices, indices, sharding):
  """Shard handler for `core.Token`.

  The token value `x` itself is not used: a scalar boolean zero is put on the
  devices instead, once per entry of `indices`.
  """
  placeholder = np.zeros((), dtype=np.dtype(np.bool_))
  placeholder_aval = api_util.shaped_abstractify(placeholder)
  per_device = [placeholder for _ in indices]
  return batched_device_put(placeholder_aval, sharding, per_device, devices)


shard_arg_handlers[core.Token] = _shard_token
def _masked_array_error(x, devices, indices, sharding):
  """Shard handler for `np.ma.MaskedArray`: always raises `ValueError`."""
  message = (
      "numpy masked arrays are not supported as direct inputs to JAX functions. "
      "Use arr.filled() to convert the value to a standard numpy array.")
  raise ValueError(message)


shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error
def _shard_array(x, devices, indices, sharding):
  """Shard handler for every type in `array_types`.

  Slices `x` with each entry of `indices` and hands the pieces to
  `batched_device_put`. A `float0` input is first replaced by a boolean zeros
  array of the same shape.
  """
  if x.dtype == dtypes.float0:
    x = np.zeros(x.shape, dtype=np.dtype(bool))
  aval = api_util.shaped_abstractify(x)
  pieces = [x[index] for index in indices]
  return batched_device_put(aval, sharding, pieces, devices)


for _t in array_types:
  shard_arg_handlers[_t] = _shard_array
def _shard_darray(x, devices, indices, sharding):
  """Shard handler for `core.DArray`: re-dispatches on the wrapped `_data`."""
  underlying = x._data
  return shard_arg(underlying, devices, indices, sharding)


shard_arg_handlers[core.DArray] = _shard_darray
def batched_device_put(aval: core.ShapedArray,
                       sharding: jax.sharding.Sharding, xs: Sequence[Any],
                       devices: Sequence[jax.Device], committed: bool = True):
  """Assemble one array from the per-device values `xs`.

  If every `xs[i]` is already an `ArrayImpl` with a single-device sharding
  whose device equals `devices[i]`, the values are wrapped directly in a new
  `ArrayImpl` (with `_skip_checks=True`). Otherwise the call is forwarded to
  `xc.batched_device_put`.
  """
  from jax._src import array

  already_placed = []
  for x, device in safe_zip(xs, devices):
    if not isinstance(x, array.ArrayImpl):
      continue
    if not dispatch.is_single_device_sharding(x.sharding):
      continue
    if x.device() == device:
      already_placed.append(x)

  if len(already_placed) != len(xs):
    return xc.batched_device_put(aval, sharding, xs, devices, committed)  # type: ignore
  return array.ArrayImpl(
      aval, sharding, already_placed, committed=committed, _skip_checks=True)
def shard_aval(size, axis: int, aval):
  """Dispatch to the `shard_aval_handlers` entry for `type(aval)`.

  Raises:
    TypeError: if a `KeyError` is raised while looking up or running the
      handler (in particular when no handler is registered for the type).
  """
  aval_type = type(aval)
  try:
    return shard_aval_handlers[aval_type](size, axis, aval)
  except KeyError as err:
    raise TypeError(f"No shard_aval handler for type: {aval_type}") from err


# Registry keyed by abstract-value type; handlers are called as
# `handler(size, axis, aval)`.
shard_aval_handlers: dict[type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
def _shard_abstract_array(size, axis: int, x):
  """Return `x` with dimension `axis` removed from its shape.

  Args:
    size: The size that dimension `axis` of `x` is required to have.
    axis: Index of the dimension to drop.
    x: An abstract array exposing `shape` and `update(shape=...)`.

  Returns:
    `x.update(shape=...)` where the new shape is `x.shape` without `axis`.

  Raises:
    ValueError: if `x.shape[axis]` differs from `size`, or if `axis` is not a
      valid index into `x.shape`.
  """
  try:
    if x.shape[axis] != size:
      raise ValueError(f"Axis size {size} does not match dimension {axis} of "
                       f"shape {x.shape}")
  except IndexError:
    # The message must be an f-string, and the rank is derived from the shape
    # that was just indexed, so building the error cannot itself fail.
    raise ValueError(
        f"Cannot split a {len(x.shape)}D value along axis {axis}") from None
  return x.update(shape=tuple_delete(x.shape, axis))
# Make `shard_aval` route `ShapedArray` avals to `_shard_abstract_array`.
shard_aval_handlers[ShapedArray] = _shard_abstract_array
def local_aval_to_result_handler(
    aval: core.AbstractValue,
    sharding: sharding_impls.XLACompatibleSharding,
    indices: tuple[Index, ...] | None,
) -> Callable[[list[xc.ArrayImpl]], Any]:
  """Returns a function for handling the raw buffers of a single output aval.

  The handler factory is looked up in `local_result_handlers` by `type(aval)`
  and called with `(aval, sharding, indices)`.

  Args:
    aval: The local output AbstractValue.
    sharding: Indicates how the output is sharded across devices.
    indices: The pre-computed result of spec_to_indices, or None for non-array
      avals.

  Returns:
    A function for handling the Buffers that will eventually be produced
    for this output. The function will return an object suitable for returning
    to the user.

  Raises:
    TypeError: if a `KeyError` is raised while looking up or running the
      handler factory (in particular when none is registered for the type).
  """
  aval_type = type(aval)
  try:
    return local_result_handlers[aval_type](aval, sharding, indices)
  except KeyError as err:
    raise TypeError(
        f"No pxla_result_handler for type: {aval_type}") from err


PxlaResultHandler = Callable[..., Callable[[Any], Any]]
# Registry of result-handler factories for local (per-process) output avals.
local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
def global_aval_to_result_handler(
    aval: core.AbstractValue, out_sharding, committed: bool,
    is_out_sharding_from_xla: bool
) -> Callable[[Sequence[xc.ArrayImpl]], Any]:
  """Returns a function for handling the raw buffers of a single output aval.

  The handler factory is looked up in `global_result_handlers` by `type(aval)`
  and called with `(aval, out_sharding, committed, is_out_sharding_from_xla)`.

  Args:
    aval: The global output AbstractValue.
    out_sharding: The sharding of the output; forwarded to the handler factory.
    committed: Forwarded unchanged to the handler factory.
    is_out_sharding_from_xla: True, if the out_sharding comes from XLA i.e.
      the sharding is extracted from the HLO.

  Returns:
    A function for handling the Buffers that will eventually be produced
    for this output. The function will return an object suitable for returning
    to the user.

  Raises:
    TypeError: if a `KeyError` is raised while looking up or running the
      handler factory (in particular when none is registered for the type).
  """
  aval_type = type(aval)
  try:
    make_handler = global_result_handlers[aval_type]
    return make_handler(aval, out_sharding, committed, is_out_sharding_from_xla)
  except KeyError as err:
    raise TypeError(
        f"No pxla_result_handler for type: {aval_type}") from err


# Registry of result-handler factories for global output avals.
global_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
### lazy device-memory persistence and result handling

### the xla_pmap primitive and its rules are comparable to xla_call in xla.py


def xla_pmap_impl_lazy(
    fun: lu.WrappedFun,
    *args,
    backend: str | None,
    axis_name: core.AxisName,
    axis_size: int,
    global_axis_size: int,
    devices: Sequence[Any] | None,
    name: str,
    in_axes: Sequence[int | None],
    out_axes_thunk: Callable[[], Sequence[int | None]],
    donated_invars: Sequence[bool],
    is_explicit_global_axis_size: bool,
) -> Callable:
  """Return a callable that runs the pmapped `fun`; it is not invoked here.

  Two paths:

  * If `config.jax_disable_jit` and `config.jax_eager_pmap` are both set, no
    explicit global axis size was given and nothing is donated, the result is
    a closure that forwards its arguments to `_emap_impl`.
  * Otherwise `args` are abstractified and `parallel_callable` is asked for a
    compiled function, which is returned.

  The caller is expected to invoke the returned callable with the concrete
  arguments (see `xla_pmap_impl`).
  """
  if (config.jax_disable_jit and config.jax_eager_pmap and
      not is_explicit_global_axis_size and not any(d for d in donated_invars)):
    # Eager path: defer everything to `_emap_impl` at call time. Note that the
    # closure's own `*args` shadows the outer `args`.
    def _emap_apply_fn(*args):
      return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
                        axis_size=axis_size, global_axis_size=global_axis_size,
                        devices=devices, name=name, in_axes=in_axes,
                        out_axes_thunk=out_axes_thunk,
                        donated_invars=donated_invars,
                        is_explicit_global_axis_size=is_explicit_global_axis_size)
    return _emap_apply_fn
  # `unsafe_map` is the builtin `map`, so this is a one-shot iterator that is
  # consumed by the `*abstract_args` unpacking below.
  abstract_args = unsafe_map(xla.abstractify, args)
  compiled_fun, fingerprint = parallel_callable(
      fun, backend, axis_name, axis_size, global_axis_size, devices, name,
      in_axes, out_axes_thunk, donated_invars,
      is_explicit_global_axis_size, *abstract_args)

  # Don't re-abstractify args unless logging is enabled for performance.
  if config.jax_distributed_debug:
    distributed_debug_log(("Running pmapped function", name),
                          ("python function", fun.f),
                          ("devices", devices),
                          ("abstract args", map(xla.abstractify, args)),
                          ("fingerprint", fingerprint))
  return compiled_fun
def xla_pmap_impl(fun: lu.WrappedFun, *args, **params):
  """Build the callable via `xla_pmap_impl_lazy` and apply it to `args`."""
  runner = xla_pmap_impl_lazy(fun, *args, **params)
  return runner(*args)
class EmapInfo(NamedTuple):
  """The `backend` and `devices` settings carried through an eager pmap.

  `_emap_impl` attaches an instance to the `MapTrace` main trace, and
  `_multi_pmap` passes both fields on to `jax.pmap`.
  """
  backend: str | None
  devices: Sequence[Any] | None
def _emap_impl(fun: lu.WrappedFun, *args,
               backend: str | None,
               axis_name: core.AxisName,
               axis_size: int,
               global_axis_size: int,
               devices: Sequence[Any] | None,
               name: str,
               in_axes: Sequence[int | None],
               out_axes_thunk: Callable[[], Sequence[int | None]],
               donated_invars: Sequence[bool],
               is_explicit_global_axis_size: bool,
               ):
  """Eager ("emap") implementation of pmap.

  Runs `fun` on `MapTracer`s under a `MapTrace` main trace, then passes each
  output through an identity `jax.pmap` so that it ends up with the axis
  requested by `out_axes_thunk()`.

  Returns:
    A list with one entry per output of `fun`.

  Raises:
    NotImplementedError: if any input is donated or an explicit global axis
      size was supplied.
  """
  from jax._src import array
  # TODO(sharadmv,mattjj): implement these cases
  if any(d for d in donated_invars):
    raise NotImplementedError("Buffer donation not supported in eager pmap.")
  if is_explicit_global_axis_size:
    raise NotImplementedError("Non-default global_axis_size not supported in "
                              "eager pmap.")

  emap_info = EmapInfo(backend, devices)
  # One dict per argument: empty for unmapped arguments, otherwise recording
  # which dimension of the argument carries `axis_name`.
  shard_axes = [{} if in_axis is None else {axis_name: in_axis} for in_axis in in_axes]
  with core.new_base_main(MapTrace, emap_info=emap_info) as main:
    with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main):
      t = main.with_cur_sublevel()
      tracers = [
          MapTracer(t, arg, s) for arg, s in zip(args, shard_axes)]
      ans = fun.call_wrapped(*tracers)
      out_tracers = map(t.full_raise, ans)
      outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers)
    # Drop the local reference to the main trace before leaving its context.
    del main
  out_axes = out_axes_thunk()

  platform = xb.get_backend(backend).platform
  # Argument 1 of the identity pmap below (the output value) is donated only on
  # these platforms.
  donate_argnums = (1,) if platform in {"cuda", "rocm", "tpu"} else ()
  new_outvals = []
  for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
    # Explicitly re-enable jit for the identity pmap.
    with jax.disable_jit(False):
      donate_argnums_ = donate_argnums
      if isinstance(outval, array.ArrayImpl):
        # We don't want to donate if it's already sharded.
        donate_argnums_ = ()
      # The `np.arange(axis_size)` operand is ignored by the lambda; it is
      # always mapped over axis 0.
      out = jax.pmap(
          lambda _, x: x,
          in_axes=(0, out_axis_src.get(axis_name)),
          out_axes=out_axis,
          devices=(None if devices is None else list(devices)),
          backend=backend,
          donate_argnums=donate_argnums_)(np.arange(axis_size), outval)
      new_outvals.append(out)
  return new_outvals
def _map_schedule(idx: tuple[int | None, ...]) -> tuple[int | None, ...]:
  """Rewrite simultaneous map axes as the axes seen by successive nested maps.

  A multi-map (one map over several axes at once) is carried out as nested
  maps. Every map "removes" one input axis, so each later axis has to be
  shifted down by the number of earlier-listed axes that sit before it.
  For example, mapping over axes 0, 3 and 4 becomes three pmap calls whose
  in_axes are 0, 2 and 2. `None` entries are passed through untouched.
  """
  schedule = []
  for position, axis in enumerate(idx):
    if axis is None:
      schedule.append(None)
      continue
    removed_before = sum(
        1 for earlier in idx[:position]
        if earlier is not None and earlier < axis)
    schedule.append(axis - removed_before)
  return tuple(schedule)
# We're often creating `f`s on the fly and we try to carefully make them have
# the right __hash__ and __eq__. However, despite our attempts pmap's caching
# still ends up not working, because it has a separate cache per
# _function object_. Adding this annotation here lets us reuse the same pmap
# callable for all equivalent primitive pmaps.
@lru_cache
def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName],
                all_axes: list[tuple[int | None, ...]]
                ) -> tuple[Callable, dict[core.AxisName, int]]:
  """Wrap `f` in one `jax.pmap` per axis name that some argument is mapped on.

  Names are processed from last to first. A name for which every argument's
  axis entry is `None` gets no pmap.

  Args:
    f: The function to wrap.
    info: Supplies the `backend` and `devices` handed to every `jax.pmap`.
    names: The axis names.
    all_axes: One entry per argument; entry `[i]` of it is that argument's
      axis for `names[i]`, or `None`.

  Returns:
    The wrapped function, and a dict giving the output dimension assigned to
    each name that received a pmap.
  """
  mapped_names = []
  for position, axis_name in reversed(tuple(enumerate(names))):
    per_arg_axes = tuple(axes[position] for axes in all_axes)
    if all(axis is None for axis in per_arg_axes):
      continue
    f = jax.pmap(
        f,
        in_axes=per_arg_axes,
        axis_name=axis_name,
        out_axes=0,
        backend=info.backend,
        devices=(None if info.devices is None else list(info.devices)))
    mapped_names.append(axis_name)
  out_shard_axes = {
      axis_name: position
      for position, axis_name in enumerate(reversed(mapped_names))}
  return f, out_shard_axes
# Stand-in exposing just the two attributes `MapTrace.process_primitive` reads
# from a primitive: `multiple_results` and `bind`.
FakePrimitive = namedtuple("FakePrimitive", ["multiple_results", "bind"])
class MapTrace(core.Trace):
  """Trace used by eager pmap; it evaluates operations on `MapTracer`s.

  Each tracer carries a concrete value plus a dict recording which dimension
  of that value corresponds to which named axis. Primitives are executed by
  wrapping them in `jax.pmap` calls (see `_multi_pmap`).
  """

  def __init__(self, *args, emap_info):
    super().__init__(*args)
    self.emap_info = emap_info

  def pure(self, val):
    # A fresh value is not mapped along any named axis.
    return MapTracer(self, val, {})

  def sublift(self, tracer):
    return MapTracer(self, tracer.val, tracer.shard_axes)

  def process_primitive(self, primitive, tracers, params):
    """Run `primitive` on the tracers' values through nested pmaps."""
    info = self.main.payload["emap_info"]
    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
    # Names of the axis-env frames that belong to this main trace.
    names = tuple(f.name for f in core.thread_local_state.trace_state.axis_env
                  if f.main_trace is self.main)
    # For every operand: its dimension for each name (None if absent), adjusted
    # by `_map_schedule` for use with nested maps.
    all_axes = tuple(_map_schedule(map(s.get, names)) for s in shard_axes)  # pytype: disable=wrong-arg-types  # always-use-return-annotations
    # Hashable wrapper so that `_multi_pmap`'s lru_cache can key on the
    # primitive and its params.
    f = HashableFunction(lambda *args: primitive.bind(*args, **params),
                         (primitive, tuple(params.items())))
    f_mapped, out_shard_axes = _multi_pmap(f, info, names, all_axes)
    # Explicitly re-enable jit while running the mapped function.
    with core.eval_context(), jax.disable_jit(False):
      outvals = f_mapped(*vals)
    if primitive.multiple_results:
      return [MapTracer(self, val, out_shard_axes) for val in outvals]
    return MapTracer(self, outvals, out_shard_axes)

  def process_call(self, call_primitive, fun, tracers, params):
    raise NotImplementedError

  def process_map(self, map_primitive, fun, tracers, params):
    """Handle a map primitive encountered under this trace.

    Raises:
      ValueError: if `params['devices']` is not None.
    """
    if params['devices'] is not None:
      raise ValueError("Nested pmap with explicit devices argument.")
    if not config.jax_disable_jit:
      # With jit enabled, treat the whole inner map as one multi-result
      # primitive and reuse `process_primitive`.
      bind = HashableFunction(
          lambda *args, **kwargs: map_primitive.bind(fun, *args, **kwargs),
          (map_primitive, fun))
      fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
      return self.process_primitive(fake_primitive, tracers, params)
    axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
        params["in_axes"], params["out_axes_thunk"], params["axis_size"])
    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
    # For each mapped operand, record the new axis name at the dimension that
    # `_annot_to_flat` derives from `ax`, skipping already-mapped dimensions.
    shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s}
                  if ax is not None else s
                  for v, ax, s in zip(vals, in_axes, shard_axes)]
    with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
      t = self.main.with_cur_sublevel()
      in_tracers = map(partial(MapTracer, t), vals, shard_axes)
      ans = fun.call_wrapped(*in_tracers)
      out_tracers = map(t.full_raise, ans)
      out, outaxes = unzip2((t.val, t.shard_axes) for t in out_tracers)
      del t, in_tracers, ans, out_tracers
    # Move or insert the mapped axis so each output matches its requested
    # out-axis annotation.
    out, outaxes = unzip2(_match_annot(axis_name, axis_size, v, s, dst)
                           for v, s, dst in zip(out, outaxes, out_axes_thunk()))
    return map(partial(MapTracer, self), out, outaxes)

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
    bind = HashableFunction(
        lambda *args, **kwargs: prim.bind(
            fun, jvp, *args, symbolic_zeros=symbolic_zeros, **kwargs),
        (prim, fun, jvp, symbolic_zeros))
    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
    return self.process_primitive(fake_primitive, tracers, {})

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
                              out_trees, symbolic_zeros):
    # NOTE(review): the hash key below omits `out_trees` and `symbolic_zeros`
    # although the lambda closes over both, whereas process_custom_jvp_call
    # includes `symbolic_zeros` in its key. Confirm this is intentional.
    bind = HashableFunction(
        lambda *args, **kwargs: primitive.bind(
            fun, fwd, bwd, *args, out_trees=out_trees,
            symbolic_zeros=symbolic_zeros, **kwargs),
        (primitive, fun, fwd, bwd))
    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
    return self.process_primitive(fake_primitive, tracers, {})

  def process_axis_index(self, frame):
    """Evaluate `axis_index` for `frame` by mapping over a dummy iota."""
    bind = HashableFunction(
        lambda _: jax.lax.axis_index(frame.name),
        (jax.lax.axis_index, frame.name))
    fake_primitive = FakePrimitive(multiple_results=False, bind=bind)
    with core.eval_context():
      # Note: this local name shadows the builtin `range` inside this method.
      range = jax.lax.iota(np.int32, frame.size)
    # The dummy operand is mapped along dimension 0 under `frame.name`; the
    # bound function ignores its value.
    dummy_tracer = MapTracer(self, range, {frame.name: 0})
    return self.process_primitive(fake_primitive, (dummy_tracer,), {})
def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
                   annotation: int | None) -> int | None:
  """Translate an axis index that skips mapped dimensions into a real one.

  `annotation` indexes into the dimensions `0..ndim-1` that are *not* listed
  in `mapped_axes`; the matching dimension number is returned. `None` is
  returned unchanged.
  """
  if annotation is None:
    return None
  taken = set(mapped_axes)
  free_dims = [dim for dim in range(ndim) if dim not in taken]
  return free_dims[annotation]
def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any,
                 shard_axis_src: dict[core.AxisName, int],
                 dst_annotation: int | None
                 ) -> tuple[Any, dict[core.AxisName, int]]:
  """Place the dimension for `axis_name` of `val` where `dst_annotation` asks.

  Returns:
    The (possibly moved or broadcast) value and the shard-axes dict describing
    it, with `axis_name` itself removed.

  Raises:
    NotImplementedError: for any source/destination combination other than
      "equal", "both ints" or "source absent, destination present".
  """
  remaining = dict(shard_axis_src)
  src = remaining.pop(axis_name, None)
  # If `val` has no dimension for `axis_name` yet, the result has one more.
  out_ndim = np.ndim(val) + (src is None)
  dst = _annot_to_flat(out_ndim, remaining.values(), dst_annotation)
  with core.eval_context():
    if src == dst:
      # Already in place (this includes both being None).
      return val, remaining
    if type(src) == type(dst) == int:
      moved = batching.moveaxis(val, src, dst)
      return moved, _moveaxis(np.ndim(val), shard_axis_src, src, dst)
    if src is None and dst is not None:
      expanded = batching.broadcast(val, axis_size, dst)
      # Dimensions at or after the insertion point shift up by one.
      shifted = {n: d + (dst <= d) for n, d in remaining.items()}
      return expanded, shifted
    raise NotImplementedError
def _moveaxis(ndim: int, shard_axes: dict[core.AxisName, int],
              src: int, dst: int) -> dict[core.AxisName, int]:
  """Recompute a name -> dimension dict after moving one dimension.

  A list of `ndim` slots is filled with the names from `shard_axes`; the slot
  at `src` is popped and re-inserted at `dst - (src < dst)`. The returned dict
  maps every remaining name to its new position.
  """
  slots: list[core.AxisName | None] = [None] * ndim
  for axis_name, position in shard_axes.items():
    slots[position] = axis_name
  moved = slots.pop(src)
  slots.insert(dst - (src < dst), moved)
  relocated = {}
  for position, axis_name in enumerate(slots):
    if axis_name is not None:
      relocated[axis_name] = position
  return relocated
class MapTracer(core.Tracer):
  """Tracer for `MapTrace`: a concrete value plus its named-axis dimensions.

  `val` is the underlying value and `shard_axes` maps an axis name to the
  dimension of `val` that corresponds to it.
  """
  __slots__ = ["val", "shard_axes"]

  def __init__(self, trace: MapTrace, val, shard_axes: dict[core.AxisName, int]):
    self._trace = trace
    self.val = val
    self.shard_axes = shard_axes
    # `self.val.ndim` is only touched when there is at least one named axis.
    assert all(dim < self.val.ndim for dim in self.shard_axes.values())

  @property
  def aval(self):
    """Abstract value of `val` with every named dimension mapped away."""
    result = xla.abstractify(self.val)
    # Go from the highest dimension down so earlier removals don't shift the
    # indices still to be processed.
    for dim in sorted(self.shard_axes.values(), reverse=True):
      result = core.mapped_aval(result.shape[dim], dim, result)
    return result

  def full_lower(self):
    return self

  def __str__(self):
    axes_text = ','.join(
        f"{name}={dim}" for name, dim in self.shard_axes.items())
    return f"{self.val}{{{axes_text}}}"
@lu.cache
def parallel_callable(fun: lu.WrappedFun,
                      backend_name: str | None,
                      axis_name: core.AxisName,
                      axis_size: int,
                      global_axis_size: int,
                      devices: Sequence[Any] | None,
                      name: str,
                      in_axes: Sequence[int | None],
                      out_axes_thunk: Callable[[], Sequence[int | None]],
                      donated_invars: Sequence[bool],
                      is_explicit_global_axis_size: bool,
                      *avals):
  """Lower and compile a pmapped function (wrapped in `lu.cache`).

  Returns:
    A `WeakRefList` holding the executable's `unsafe_call` and its
    `fingerprint`, in that order.
  """
  lowering = lower_parallel_callable(
      fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
      in_axes, out_axes_thunk, donated_invars,
      is_explicit_global_axis_size, avals, lowering_platform=None)
  executable = lowering.compile()
  return WeakRefList([executable.unsafe_call, executable.fingerprint])
@dataclasses.dataclass(frozen=True)
class ParallelCallableInfo:
  """Immutable bundle of the parameters describing one pmap invocation."""
  name: str
  backend: xc.Client
  axis_name: core.AxisName
  axis_size: int
  global_axis_size: int
  devices: Sequence[xc.Device] | None
  in_axes: Iterable[int | None]
  out_axes_thunk: Callable[[], Sequence[int | None]]
  avals: Sequence[core.AbstractValue]

  @cached_property
  def local_devices(self):
    """Entries of `devices` owned by this process, or None if `devices` is empty/None."""
    if not self.devices:
      return None  # type: ignore
    this_process = [
        d for d in self.devices
        if d.process_index == xb.process_index(self.backend)]
    assert len(this_process) > 0
    return this_process

  @cached_property
  def out_axes(self):
    """The value of `out_axes_thunk()`, computed on first access and cached."""
    return self.out_axes_thunk()
class ShardInfo(NamedTuple):
  """Avals and shard counts produced by `stage_parallel_callable`."""
  # Input avals after removing the mapped axis from each mapped argument.
  sharded_avals: Sequence[core.AbstractValue]
  # Output avals returned by tracing the function on `sharded_avals`.
  out_sharded_avals: Sequence[core.ShapedArray]
  # `stage_parallel_callable` fills this with the same tuple as `sharded_avals`.
  global_sharded_avals: Sequence[core.AbstractValue]
  # Set from `ReplicaInfo.num_local_replicas` / `num_global_replicas`.
  num_local_shards: int
  num_global_shards: int
class ReplicaInfo(NamedTuple):
  """Replica counts computed by `find_replicas`."""
  # Value of `dispatch.jaxpr_replicas(jaxpr)`.
  jaxpr_replicas: int
  # `axis_size * jaxpr_replicas`.
  num_local_replicas: int
  # `global_axis_size * jaxpr_replicas`.
  num_global_replicas: int
def find_replicas(
    jaxpr: core.Jaxpr, axis_size: int, global_axis_size: int
) -> ReplicaInfo:
  """Compute local and global replica counts for `jaxpr`.

  Both counts are the respective axis size multiplied by
  `dispatch.jaxpr_replicas(jaxpr)`.
  """
  # TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
  per_jaxpr = dispatch.jaxpr_replicas(jaxpr)
  return ReplicaInfo(
      jaxpr_replicas=per_jaxpr,
      num_local_replicas=axis_size * per_jaxpr,
      num_global_replicas=global_axis_size * per_jaxpr)
def stage_parallel_callable(
    pci: ParallelCallableInfo, fun: lu.WrappedFun
) -> tuple[core.Jaxpr, list[Any], ReplicaInfo, ShardInfo]:
  """Trace `fun` to a jaxpr on per-shard avals and gather replica/shard info.

  Returns:
    A tuple `(jaxpr, consts, replicas, shards)`.
  """
  # Drop the mapped axis from every mapped input aval; unmapped avals (axis is
  # None) are used as-is.
  sharded_avals = tuple(
      shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
      for axis, aval in safe_zip(pci.in_axes, pci.avals))

  # Trace with the named axis in scope, bound to the *global* axis size.
  with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None):  # type: ignore
    with dispatch.log_elapsed_time(
        "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec",
        fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
      jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
          fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
  jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
  jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

  assert len(out_sharded_avals) == len(pci.out_axes), (
      len(out_sharded_avals), len(pci.out_axes))

  replicas = find_replicas(jaxpr, pci.axis_size, pci.global_axis_size)
  num_local_shards = replicas.num_local_replicas
  num_global_shards = replicas.num_global_replicas

  # `sharded_avals` is passed for both the local and the global field.
  shards = ShardInfo(
      sharded_avals, out_sharded_avals, sharded_avals,
      num_local_shards, num_global_shards)

  return jaxpr, consts, replicas, shards
@profiler.annotate_function
def lower_parallel_callable(
    fun: lu.WrappedFun,
    backend_name: str | None,
    axis_name: core.AxisName,
    axis_size: int,
    global_axis_size: int,
    devices: Sequence[xc.Device] | None,
    name: str,
    in_axes: Iterable[int | None],
    out_axes_thunk: Callable[[], Sequence[int | None]],
    donated_invars: Sequence[bool],
    is_explicit_global_axis_size: bool,
    avals: Sequence[core.AbstractValue],
    *,
    lowering_platform: str | None):
  """Trace a pmapped function and lower it to an MLIR module.

  Steps: validate the axis sizes, pick the backend, trace `fun` via
  `stage_parallel_callable`, check multi-process constraints, then lower the
  jaxpr with `mlir.lower_jaxpr_to_module`.

  Args:
    lowering_platform: Platform to lower for; when falsy, `backend.platform`
      is used instead.

  Returns:
    A `PmapComputation` wrapping the lowered module and the arguments needed
    to compile it.

  Raises:
    ValueError: if an explicit global axis size disagrees with `axis_size` in
      a single-process run, if a multi-process constraint on device or
      replica counts is violated, or if the jaxpr has ordered effects.
  """
  # Determine global_axis_size for use in AxisEnv.
  # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
  # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
  #   raise ValueError("'axis_size' must be specified for nested multi-host pmaps")
  if (xb.process_count() == 1 and is_explicit_global_axis_size
      and global_axis_size != axis_size):
    raise ValueError(
        f"Specified axis_size {global_axis_size} doesn't match received "
        f"axis_size {axis_size}.")

  # With explicit devices and no backend name, use the backend of the first
  # device.
  if devices is not None and backend_name is None:
    backend = xb.get_device_backend(devices[0])
  else:
    backend = xb.get_backend(backend_name)

  no_nested_sharding = False
  must_run_on_all_devices = False
  if not is_explicit_global_axis_size:
    if xb.process_count(backend) > 1:
      if devices:
        # This allows each host in a multi-host pmap to run on a different number
        # of devices, but precludes nested sharding (i.e. inner pmaps).
        no_nested_sharding = True
      else:
        # This assumes all hosts run on the same number of devices. We make sure
        # this assumption is true by requiring that the pmap is run on all devices
        # (and making the further assumption that each host has the same number of
        # devices). Nested sharding is ok in this case.
        must_run_on_all_devices = True

  pci = ParallelCallableInfo(
      name, backend, axis_name, axis_size, global_axis_size, devices,
      in_axes, out_axes_thunk, avals)
  jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
  if logger.isEnabledFor(logging.DEBUG):
    logger.debug("sharded_avals: %s", shards.sharded_avals)
    logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
    logger.debug("num_replicas: %d num_local_replicas: %d",
                 replicas.num_global_replicas, replicas.num_local_replicas)
    logger.debug("devices: %s", devices)
    logger.debug("local_devices: %s", pci.local_devices)

  if (xb.process_count(backend) > 1 and must_run_on_all_devices and
      shards.num_local_shards != xb.local_device_count(backend)):
    # Choose the message depending on whether the local shard count equals
    # `axis_size`.
    if shards.num_local_shards == axis_size:
      raise ValueError(
          f"On multi-host platforms, the input to pmapped functions must have "
          f"leading axis size equal to the number of local devices if no "
          f"`devices` argument is specified. Got {axis_size=}, "
          f"num_local_devices={xb.local_device_count(backend)}")
    else:
      raise ValueError(
          f"On multi-host platforms, pmapped functions must run across all "
          f"devices, i.e. num_replicas * num_partitions should equal the "
          f"number of local devices. Got "
          f"num_replicas={replicas.num_local_replicas}, and "
          f"num_local_devices={xb.local_device_count(backend)}")

  if no_nested_sharding and replicas.jaxpr_replicas > 1:
    raise ValueError(
        f"On multi-host platforms, pmapped functions that both have `devices` "
        f"specified and contain an inner_pmap must specify an "
        f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
        f"{replicas.jaxpr_replicas}")

  # Compilation is logged at WARNING when `jax_log_compiles` is set, else DEBUG.
  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  if logger.isEnabledFor(log_priority):
    logger.log(log_priority,
               "Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
               fun.__name__, id(fun),
               shards.num_global_shards, avals, replicas.num_global_replicas)

  axis_env = sharding_impls.AxisEnv(
      replicas.num_global_replicas, (axis_name,), (global_axis_size,))
  name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap'))
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  # An argument is flagged as replicated when it has no mapped axis.
  replicated_args = [axis is None for axis in in_axes]
  tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
                                          backend.platform)
  module_name = f"pmap_{fun.__name__}"
  with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
    ordered_effects = list(
        effects.ordered_effects.filter_in(closed_jaxpr.effects))
    if ordered_effects:
      raise ValueError("Ordered effects not supported in `pmap`.")
    unordered_effects = list(
        effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
    with dispatch.log_elapsed_time(
        "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
        fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
      lowering_result = mlir.lower_jaxpr_to_module(
          module_name,
          closed_jaxpr,
          ordered_effects,
          backend,
          lowering_platform or backend.platform,
          sharding_impls.ReplicaAxisContext(axis_env),
          name_stack,
          donated_invars,
          replicated_args=replicated_args,
          arg_shardings=None,
          result_shardings=None,
          arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
          result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
          num_replicas=replicas.num_global_replicas)
  return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
                         shards=shards, tuple_args=tuple_args,
                         unordered_effects=unordered_effects,
                         ordered_effects=ordered_effects,
                         keepalive=lowering_result.keepalive,
                         host_callbacks=lowering_result.host_callbacks,
                         jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
class PmapComputation(stages.XlaLowering):
  """A lowered pmap: the MLIR module plus the arguments needed to compile it."""
  _hlo: ir.Module
  _executable: PmapExecutable | None

  def __init__(self, hlo: ir.Module, **compile_args):
    self._executable = None
    self._hlo = hlo
    self.compile_args = compile_args

  # -- stages.XlaLowering overrides

  def stablehlo(self) -> ir.Module:
    """Return the stored module."""
    return self._hlo

  @profiler.annotate_function
  def compile(self, compiler_options=None) -> PmapExecutable:
    """Compile the module, caching the result only for default options.

    A call with `compiler_options=None` reuses the cached executable when one
    exists. A call with explicit options always compiles afresh and leaves the
    cache untouched.
    """
    use_default_options = compiler_options is None
    if use_default_options and self._executable is not None:
      return self._executable
    executable = UnloadedPmapExecutable.from_hlo(
        self._hlo, **self.compile_args,
        compiler_options=compiler_options)
    if use_default_options:
      self._executable = executable
    return executable
def _cast_to_shaped_array(aval: core.AbstractValue) -> ShapedArray:
  """Returns `aval` unchanged after asserting that it is a `ShapedArray`."""
  assert isinstance(aval, ShapedArray), aval
  # The assertion narrows the type; `typing.cast` would be a runtime no-op.
  return aval
@dataclasses.dataclass
class UnloadedPmapExecutable:
  """Compiled pmap computation plus the metadata needed to run it.

  Bundles the compiled object with the input/output avals and shardings,
  effects, keepalive objects, host callbacks and debug info that
  `build_execute_fun` and `load` use to produce a `PmapExecutable`.
  """
  compiled: Any
  backend: xb.XlaBackend
  local_input_avals: Sequence[core.AbstractValue]
  input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  local_output_avals: Sequence[ShapedArray]
  output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
  unordered_effects: list[core.Effect]
  ordered_effects: list[core.Effect]
  keepalive: Sequence[Any]
  host_callbacks: Sequence[Any]
  jaxpr_debug_info: core.JaxprDebugInfo

  def build_execute_fun(self):
    """Builds the `ExecuteReplicated` callable for `self.compiled`.

    Every input sharding must be a `PmapSharding` and every input aval a
    `ShapedArray` (both are asserted). An input whose sharding carries no
    `sharding_spec` gets `None` in place of its indices.
    """
    arg_indices = []
    for in_aval, in_sharding in safe_zip(self.local_input_avals,
                                         self.input_shardings):
      assert isinstance(in_sharding, sharding_impls.PmapSharding), in_sharding
      assert isinstance(in_aval, core.ShapedArray), in_aval
      if in_sharding.sharding_spec is None:
        arg_indices.append(None)
      else:
        arg_indices.append(sharding_specs.spec_to_indices(
            in_aval.shape, in_sharding.sharding_spec))

    results_handler = local_avals_to_results_handler(
        self.local_output_avals, self.output_shardings)
    args_handler = InputsHandler(
        self.compiled.local_devices(), self.input_shardings, arg_indices)
    # Positions 0..n-1 of the inputs, i.e. every argument.
    all_arg_positions = set(range(len(arg_indices)))
    return ExecuteReplicated(
        self.compiled, "parallel computation", self.backend, args_handler,
        results_handler, self.unordered_effects, self.ordered_effects,
        self.keepalive, bool(self.host_callbacks), all_arg_positions)

  def load(self) -> PmapExecutable:
    """Wraps this object in a `PmapExecutable`.

    The execute function is not built here; the bound `build_execute_fun`
    method is handed over instead. The fingerprint is read from
    `self.compiled` if it has one, else `None`.
    """
    return PmapExecutable(
        xla_executable=self.compiled,
        build_unsafe_call=self.build_execute_fun,
        fingerprint=getattr(self.compiled, "fingerprint", None),
        in_avals=self.local_input_avals,
        jaxpr_debug_info=self.jaxpr_debug_info,
        unloaded_executable=self)

  @staticmethod
  def from_hlo(hlo: ir.Module,
               pci: ParallelCallableInfo,
               replicas: ReplicaInfo,
               shards: ShardInfo,
               tuple_args: bool,
               unordered_effects: list[core.Effect],
               ordered_effects: list[core.Effect],
               host_callbacks: list[Any],
               keepalive: Any,
               jaxpr_debug_info: core.JaxprDebugInfo,
               compiler_options=None):
    """Compiles `hlo` for pmap execution and returns a `PmapExecutable`.

    Steps: choose (or validate) the devices, build the compile options, derive
    the pmap sharding specs for inputs and outputs, then compile. If the
    backend exposes `compile_replicated`, compilation is delegated to
    `_compile_replicated_pmap_executable_from_hlo`; otherwise the module goes
    through `dispatch.compile_or_get_cached` and the result is loaded.

    Args:
      hlo: the lowered module to compile.
      pci: description of the pmapped callable (devices, backend, axis size
        and name, in/out axes, avals, name are read from it).
      replicas: replica counts (`num_global_replicas`, `num_local_replicas`).
      shards: shard counts and sharded avals.
      tuple_args: stored on the compile options as
        `parameter_is_tupled_arguments`.
      unordered_effects: unordered effects of the computation.
      ordered_effects: ordered effects of the computation.
      host_callbacks: host callbacks forwarded to the compiler.
      keepalive: objects stored on the resulting executable.
      jaxpr_debug_info: debug info stored on the resulting executable.
      compiler_options: optional overrides passed as `env_options_overrides`.

    Returns:
      A `PmapExecutable`.

    Raises:
      ValueError: if more shards are required than devices are available, or
        if explicitly provided devices do not match the local/global shard
        counts.
    """
    devices = pci.devices
    if devices is not None:
      # Devices were given explicitly: they must match the shard counts.
      if shards.num_local_shards != len(pci.local_devices):
        local_devices_str = ", ".join(map(str, pci.local_devices))
        if shards.num_local_shards == pci.axis_size:
          raise ValueError(
              f"Leading axis size of input to pmapped function must equal the "
              f"number of local devices passed to pmap. Got axis_size="
              f"{pci.axis_size}, num_local_devices={len(pci.local_devices)}.\n"
              f"(Local devices available to pmap: {local_devices_str})")
        raise ValueError(
            f"pmapped function requires {shards.num_local_shards} local "
            f"devices to run due to nested pmapped or other parallel "
            f"functions, but only {len(pci.local_devices)} are available.\n"
            f"(outer axis size: {pci.axis_size}, local devices available to "
            f"pmap: {local_devices_str})")
      if shards.num_global_shards != len(devices):
        raise ValueError("compiling computation that creates %s shards, "
                         "but %s devices were specified" %
                         (shards.num_global_shards, len(devices)))
    else:
      if shards.num_global_shards > xb.device_count(pci.backend):
        msg = ("compiling computation that requires {} logical devices, but only {} XLA "
               "devices are available (num_replicas={})")
        raise ValueError(msg.format(shards.num_global_shards,
                                    xb.device_count(pci.backend),
                                    replicas.num_global_replicas))
      # On a single host, we simply grab the first N devices from
      # jax.devices(), so that the default device order of pmap matches
      # jax.devices().
      # On multiple hosts, we create a default device assignment that ensures
      # each host is responsible for a contiguous set of replicas.
      if shards.num_global_shards <= shards.num_local_shards:
        devices = xb.local_devices(backend=pci.backend)[:shards.num_local_shards]
      else:
        # TODO(skye): use a locality-aware assignment that satisfies the above
        # constraint.
        devices = []
        for host_index in range(xb.process_count(pci.backend)):
          devices.extend(xb.local_devices(host_index, pci.backend))

    # 'devices' may be 1D or 2D at this point (e.g.
    # get_default_device_assignment() returns 2D assignment, caller may have
    # provided 1D list of devices).
    # Convert to 2D in case it's 1D and we have > 1 partitions.
    num_partitions = 1
    assignment_shape = (replicas.num_global_replicas, num_partitions)
    device_assignment: np.ndarray = np.array(devices).reshape(assignment_shape)

    compile_options = xb.get_compile_options(
        num_replicas=replicas.num_global_replicas,
        num_partitions=num_partitions,
        device_assignment=device_assignment,
        use_spmd_partitioning=False,
        env_options_overrides=compiler_options,
    )
    compile_options.parameter_is_tupled_arguments = tuple_args

    # Keep only the devices that belong to the current process.
    this_process = xb.process_index(pci.backend)
    local_device_assignment = np.array([
        dev for dev in device_assignment.flat
        if dev.process_index == this_process
    ])

    input_sharding_specs = []
    for in_aval, in_axis in safe_zip(shards.sharded_avals, pci.in_axes):
      in_shape = cast(ShapedArray, in_aval).shape
      input_sharding_specs.append(sharding_specs.pmap_sharding_spec(
          replicas.num_local_replicas, pci.axis_size, in_shape, in_axis))
    in_shardings = _get_pmap_sharding(local_device_assignment,
                                      input_sharding_specs)

    # Outputs with no mapped axis keep their aval as is.
    local_unmapped_avals = []
    for out_aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes):
      if out_axis is None:
        local_unmapped_avals.append(out_aval)
      else:
        local_unmapped_avals.append(_cast_to_shaped_array(
            core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis,
                               out_aval)))

    out_specs = []
    for out_aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes):
      out_specs.append(sharding_specs.pmap_sharding_spec(
          replicas.num_local_replicas, pci.axis_size, out_aval.shape,
          out_axis))
    out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)

    if hasattr(pci.backend, "compile_replicated"):
      input_indices = []
      for arg_aval, arg_spec in safe_zip(pci.avals, input_sharding_specs):
        if arg_spec is None:
          input_indices.append(None)
        else:
          input_indices.append(
              sharding_specs.spec_to_indices(arg_aval.shape, arg_spec))
      handle_outs = local_avals_to_results_handler(local_unmapped_avals,
                                                   out_shardings)
      return _compile_replicated_pmap_executable_from_hlo(
          hlo, pci, input_indices, in_shardings, handle_outs,
          compile_options, host_callbacks,
          has_unordered_effects=bool(unordered_effects),
          ordered_effects=ordered_effects,
          jaxpr_debug_info=jaxpr_debug_info)

    with dispatch.log_elapsed_time(
        "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
        fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
      compiled = dispatch.compile_or_get_cached(
          pci.backend, hlo, device_assignment, compile_options,
          host_callbacks)

    unloaded = UnloadedPmapExecutable(
        compiled=compiled,
        backend=pci.backend,
        local_input_avals=pci.avals,
        input_shardings=in_shardings,
        local_output_avals=local_unmapped_avals,
        output_shardings=out_shardings,
        unordered_effects=unordered_effects,
        ordered_effects=ordered_effects,
        keepalive=keepalive,
        host_callbacks=host_callbacks,
        jaxpr_debug_info=jaxpr_debug_info)
    return unloaded.load()
def _compile_replicated_pmap_executable_from_hlo(
    hlo: ir.Module, pci, input_indices, in_shardings, handle_outs,
    compile_options, host_callbacks, has_unordered_effects, ordered_effects,
    jaxpr_debug_info):
  """Compiles `hlo` through the backend's `compile_replicated` entry point.

  All of `pci.avals` are passed as kept inputs. The returned `PmapExecutable`
  carries no XLA executable, fingerprint or unloaded executable; it only
  exposes the execute function obtained from the backend.
  """
  in_avals = pci.avals
  every_arg = set(range(len(in_avals)))
  # Use the standard out_handler.
  replicated_call = pci.backend.compile_replicated(
      is_trivial=False,
      name=pci.name,
      computation=hlo,
      compile_options=compile_options,
      host_callbacks=host_callbacks,
      has_unordered_effects=has_unordered_effects,
      ordered_effects=ordered_effects,
      in_avals=in_avals,
      in_indices=input_indices,
      in_shardings=in_shardings,
      kept_var_idx=every_arg,
      out_handler=handle_outs)

  def build_unsafe_call():
    # The call is already built; just hand it back.
    return replicated_call

  # TODO(frostig): need `compile_replicated` to give us the XLA executable
  return PmapExecutable(
      xla_executable=None,
      build_unsafe_call=build_unsafe_call,
      fingerprint=None,
      in_avals=pci.avals,
      jaxpr_debug_info=jaxpr_debug_info,
      unloaded_executable=None)
class PmapExecutable(stages.XlaExecutable):
__slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
"fingerprint", "in_avals", "_jaxpr_debug_info",
"_unloaded_executable"]
def __init__(self, xla_executable, build_unsafe_call, fingerprint,
in_avals, jaxpr_debug_info, unloaded_executable):
self.xla_executable = xla_executable
self._unsafe_call = None