# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of pmap and related functionality."""
# A ShardingSpec describes at a high level how a logical array is sharded across
# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
# describe how to shard inputs to a parallel computation). spec_to_indices()
# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
# how the sharded array is "laid out" across devices. Given a sequence of
# devices, we shard the data across the devices in row-major order, with
# replication treated as an extra inner dimension.
#
# For example, given the logical data array [1, 2, 3, 4], if we were to
# partition this array 4 ways with a replication factor of 2, for a total of 8
# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
#
# This encoding is assumed by various parts of the system, e.g. generating
# replica groups for collective operations.
from __future__ import annotations
from contextlib import contextmanager, ContextDecorator
from collections import defaultdict, OrderedDict
import dataclasses
from functools import partial
import itertools as it
import operator as op
import threading
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast)
import sys
from absl import logging
import numpy as np
import jax
from jax import core
from jax import linear_util as lu
from jax.core import ConcreteArray, ShapedArray
from jax.errors import JAXTypeError
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_map
from jax._src import abstract_arrays
from jax._src import device_array
from jax._src import source_info_util
from jax._src import util
from jax._src import dispatch
from jax._src import profiler
from jax._src import stages
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.util import (unzip3, prod, safe_map, safe_zip,
extend_name_stack, new_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log)
# Built-in Python lists don't support weak refs, but subclasses of lists do.
class WeakRefList(list):
pass
if sys.version_info >= (3, 8):
from functools import cached_property as maybe_cached_property
else:
maybe_cached_property = property
if sys.version_info >= (3, 9):
OrderedDictType = OrderedDict
else:
OrderedDictType = Dict
xops = xc.ops
xe = xc._xla
unsafe_map, map = map, safe_map # type: ignore
Index = Union[int, slice, Tuple[Union[int, slice], ...]]
NoSharding = pmap_lib.NoSharding
Chunked = pmap_lib.Chunked
Unstacked = pmap_lib.Unstacked
ShardedAxis = pmap_lib.ShardedAxis
Replicated = pmap_lib.Replicated
_UNSHARDED_INSTANCE = NoSharding()
AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = pmap_lib.ShardingSpec
MeshAxisName = Any
OpShardingType = Any
def sharding_spec_mesh_shape(self):
sharded_axis_sizes = []
for sharding in self.sharding:
if isinstance(sharding, NoSharding):
continue
elif isinstance(sharding, Unstacked):
sharded_axis_sizes.append(sharding.size)
elif isinstance(sharding, Chunked):
sharded_axis_sizes.extend(sharding.chunks)
else:
assert_unreachable(sharding)
return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis) else a.replicas
for a in self.mesh_mapping)
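# For example (an illustrative sketch): a spec with
# sharding=(Chunked([2, 3]),) and
# mesh_mapping=(ShardedAxis(0), ShardedAxis(1), Replicated(4))
# has mesh_shape == (2, 3, 4).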
def sharding_spec_sharding_proto(self, special_axes: Mapping[int, OpShardingType] = {}):
"""Converts a ShardingSpec to an OpSharding proto.
See
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
for details on the OpSharding proto.
Unfortunately the semantics are not very well described in the proto spec, but the code here might help:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
"""
mesh_shape = cast(Tuple[int, ...], self.mesh_shape)
mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
sharded_axes = {} # maps sharded axis identifiers to mesh axis indices to which they're mapped
replicated_maxes = [] # lists mesh axis identifiers to replicate over
for maxis, assignment in enumerate(self.mesh_mapping):
if isinstance(assignment, Replicated):
replicated_maxes.append((maxis, assignment.replicas))
elif isinstance(assignment, ShardedAxis):
sharded_axes[assignment.axis] = maxis
else:
assert_unreachable(assignment)
proto = xc.OpSharding()
if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes:
proto.type = xc.OpSharding.Type.REPLICATED
return proto
else:
proto.type = xc.OpSharding.Type.OTHER
mesh_permutation = []
new_mesh_shape = []
next_sharded_axis = 0
for axis, sharding in enumerate(self.sharding):
if isinstance(sharding, NoSharding):
new_mesh_shape.append(1) # Add a dummy mesh axis we won't be sharding over
elif isinstance(sharding, Chunked):
for nchunks in sharding.chunks:
maxis = sharded_axes[next_sharded_axis]
assert mesh_shape[maxis] == nchunks
mesh_permutation.append(maxis)
next_sharded_axis += 1
new_mesh_shape.append(int(np.prod(sharding.chunks)))
elif isinstance(sharding, Unstacked):
raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
else:
assert_unreachable(sharding)
# Create a partial sharding proto if the tensor is replicated or partitioned
# specially over some mesh axes.
if replicated_maxes:
last_tile_dims = []
axes_by_type: Dict[OpShardingType, List[MeshAxisName]] = {}
size_by_type: Dict[OpShardingType, int] = defaultdict(lambda: 1)
assert set(x[0] for x in replicated_maxes).issuperset(set(special_axes.keys()))
for axis, size in replicated_maxes:
ty = special_axes.get(axis, xc.OpSharding.Type.REPLICATED)
axes_by_type.setdefault(ty, []).append(axis)
size_by_type[ty] *= size
for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
last_tile_dims.append(ty)
new_mesh_shape.append(size_by_type[ty])
mesh_permutation.extend(axes)
proto.last_tile_dims = last_tile_dims
proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape)
proto.tile_assignment_dimensions = list(proto_mesh.shape)
proto.tile_assignment_devices = list(proto_mesh.flat)
return proto
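# For example (an illustrative sketch, not executed here): a fully replicated
# spec such as
#   ShardingSpec(sharding=(NoSharding(),), mesh_mapping=(Replicated(8),))
# maps to a proto with type REPLICATED, while any sharded spec maps to type
# OTHER with an explicit tile assignment over the permuted mesh devices.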
def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
"""Returns NumPy-style indices corresponding to a sharding spec.
Args:
shape: The shape of the logical array being sharded.
Returns:
An ndarray with the same shape as the logical mesh (as derived from
`mesh_mapping`). Each entry is a NumPy-style index selecting the subset of
the data array to be placed on a corresponding device. The indices can be
ints, slice objects with step=1, or tuples of those.
"""
assert len(shape) == len(self.sharding), (shape, self.sharding)
axis_indices: List[Sequence[Index]] = []
shard_indices_shape = []
for dim, sharding in enumerate(self.sharding):
axis_size = shape[dim]
if isinstance(sharding, NoSharding):
axis_indices.append([slice(None)])
# NOTE: We don't append unsharded dimensions to shard_indices_shape here,
# because they do not appear in the mesh mapping.
elif isinstance(sharding, Unstacked):
assert axis_size == sharding.size, f'{axis_size} != {sharding.size}'
axis_indices.append(range(axis_size))
shard_indices_shape.append(axis_size)
elif isinstance(sharding, Chunked):
total_chunks = int(np.prod(sharding.chunks))
shard_size, ragged = divmod(axis_size, total_chunks)
assert not ragged, (axis_size, total_chunks, dim)
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
for i in range(total_chunks)])
shard_indices_shape.extend(sharding.chunks)
else:
assert_unreachable(sharding)
# shard_indices is an ndarray representing the sharded axes of the logical array,
# with each dimension having size equal to the number of shards across the corresponding
# logical array dimension, and each element containing the multi-dimensional index that
# is used to extract the corresponding shard of the logical array.
shard_indices = np.empty([prod(shard_indices_shape)], dtype=np.object_)
for i, idxs in enumerate(it.product(*axis_indices)):
shard_indices[i] = idxs
shard_indices = shard_indices.reshape(shard_indices_shape)
# Ensure that each sharded axis is used exactly once in the mesh mapping
num_sharded_dim = len(shard_indices_shape)
sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)]
assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and
len(sharded_dim_perm) == num_sharded_dim)
# Replicate/reorder the indices according to the mesh mapping
replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated))
replica_dim, sharded_dim = it.count(0), iter(sharded_dim_perm)
perm = [next(replica_dim) if isinstance(a, Replicated) else
len(replica_sizes) + next(sharded_dim)
for a in self.mesh_mapping]
return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape)
.transpose(perm))
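# For example (an illustrative sketch): an Unstacked leading axis yields
# integer indices, one per shard:
#   ShardingSpec(sharding=(Unstacked(3),),
#                mesh_mapping=(ShardedAxis(0),)).indices((3,))
#   # -> array([(0,), (1,), (2,)], dtype=object)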
def sharding_spec_repr(self):
return f'ShardingSpec({self.sharding}, {self.mesh_mapping})'
ShardingSpec.mesh_shape = property(sharding_spec_mesh_shape)
ShardingSpec.sharding_proto = sharding_spec_sharding_proto
ShardingSpec.indices = sharding_spec_indices
# mypy raises: error: Cannot assign to a method [assignment]
ShardingSpec.__repr__ = sharding_spec_repr # type: ignore
# Do not pollute the namespace
del sharding_spec_mesh_shape, sharding_spec_indices, sharding_spec_repr
def spec_to_indices(shape: Tuple[int, ...],
spec: ShardingSpec) -> Tuple[Index, ...]:
"""Returns numpy-style indices corresponding to a sharding spec.
Each index describes a shard of the array. The order of the indices is the
same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out
row-major).
Args:
shape: The shape of the logical array being sharded.
spec: Describes how the array is sharded and how the shards are assigned to
the logical mesh.
Returns:
A tuple of length equal to the size of the mesh (inferred as the product of
sharded dimension sizes and all replication factors). Each element is an
int, a slice object with step=1, or a tuple thereof, to be treated as an
index into the full logical array.
"""
return tuple(spec.indices(shape).flat) # type: ignore
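# For example (an illustrative sketch): chunking the first axis of a (4, 6)
# array into two shards:
#   spec = ShardingSpec(sharding=(Chunked([2]), NoSharding()),
#                       mesh_mapping=(ShardedAxis(0),))
#   spec_to_indices((4, 6), spec)
#   # -> ((slice(0, 2), slice(None)), (slice(2, 4), slice(None)))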
### util
def identity(x): return x
def _shard_arg(arg, devices, arg_indices):
"""Returns a list of size len(devices) containing per-device buffers.
For the C++ pmap path, we fall back to Python (this function) to shard
arguments that are not supported by the C++ `ShardArg`.
Args:
arg: The Python argument.
devices: The list of devices to shard over.
arg_indices: A list of `len(devices)` indices to use to shard the argument.
"""
if isinstance(arg, ShardedDeviceArray) and arg_indices == arg.indices:
# The shard_arg_handlers allow an extensible set of types to be sharded, but
# inline handling for ShardedDeviceArray as a special case for performance
# NOTE: we compare indices instead of sharding_spec because
# pmap_benchmark.pmap_shard_args_benchmark indicates this is faster.
return [
buf if buf.device() == d else buf.copy_to_device(d)
for d, buf in zip(devices, arg.device_buffers)
]
else:
arg = xla.canonicalize_dtype(arg)
return shard_arg_handlers[type(arg)](arg, devices, arg_indices)
@profiler.annotate_function
def shard_args(devices: Sequence[xb.xla_client.Device],
indices: Sequence[Sequence[Index]],
args) -> Sequence[Sequence[xb.xla_client.Buffer]]:
"""Shard each argument data array along its leading axis.
Args:
devices: sequence of Devices mapping replica index to a physical device.
indices: sequence of the same length as `args` describing how each arg
should be sharded/replicated across `devices`. Each element in `indices`
is the same length as `devices`.
args: a sequence of JaxTypes representing arguments to be sharded according
to `indices` and placed on `devices`.
Returns:
A list of length matching args, containing lists of per-device buffers
for each argument.
"""
return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
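# For example (an illustrative sketch, assumes at least two local devices):
#   devs = jax.local_devices()[:2]
#   bufs = shard_args(devs, [[slice(0, 1), slice(1, 2)]], [np.arange(2.0)])
#   # bufs[0] holds two per-device buffers, one per device in `devs`.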
shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
shard_arg_handlers[core.Unit] = \
lambda x, devices, _: device_put(core.unit, devices, replicate=True) # type: ignore[has-type]
def _shard_array(x, devices, indices):
return device_put([x[i] for i in indices], devices)
for _t in array_types:
shard_arg_handlers[_t] = _shard_array
def _shard_device_array(x, devices, indices):
start_indices, limit_indices, removed_dims = unzip3(
_as_slice_indices(x, idx) for idx in indices)
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
return device_put(shards, devices)
for t in device_array.device_array_types:
shard_arg_handlers[t] = _shard_device_array
# NOTE(skye): we could refactor to generate _multi_slice parameters directly
# from the input ShardingSpec, rather than the indices. However, this would
# require duplicating the ordering logic of spec_to_indices, which is more
# subtle and more likely to change than the index logic we have to support here.
def _as_slice_indices(arr: device_array.DeviceArrayProtocol, idx: Index) -> Tuple[
Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
"""Returns start_indices, limit_indices, removed_dims"""
start_indices = [0] * arr.ndim
limit_indices = list(arr.shape)
removed_dims = []
tuple_idx = idx if isinstance(idx, tuple) else (idx,)
for dim, sub_idx in enumerate(tuple_idx):
if isinstance(sub_idx, int):
start_indices[dim] = sub_idx
limit_indices[dim] = sub_idx + 1
removed_dims.append(dim)
elif sub_idx == slice(None):
continue
else:
assert isinstance(sub_idx, slice), sub_idx
assert isinstance(sub_idx.start, int), sub_idx
assert isinstance(sub_idx.stop, int), sub_idx
start_indices[dim] = sub_idx.start
limit_indices[dim] = sub_idx.stop
return tuple(start_indices), tuple(limit_indices), tuple(removed_dims) # type: ignore
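# For example (an illustrative sketch): for an array `arr` of shape (4, 6),
#   _as_slice_indices(arr, (2, slice(0, 3)))
#   # -> start_indices=(2, 0), limit_indices=(3, 3), removed_dims=(0,)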
def shard_aval(size, axis: int, aval):
try:
return shard_aval_handlers[type(aval)](size, axis, aval)
except KeyError as err:
raise TypeError(f"No shard_aval handler for type: {type(aval)}") from err
shard_aval_handlers: Dict[Type[core.AbstractValue], Callable[[int, int, Any], Any]] = {}
shard_aval_handlers[core.AbstractUnit] = lambda size, axis, x: x
def _shard_abstract_array(size, axis: int, x):
try:
if x.shape[axis] != size:
raise ValueError(f"Axis size {size} does not match dimension {axis} of "
f"shape {x.shape}")
except IndexError:
raise ValueError("Cannot split a {x.dim}D value along axis {axis}") from None
return x.update(shape=tuple_delete(x.shape, axis))
shard_aval_handlers[ShapedArray] = _shard_abstract_array
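# For example (an illustrative sketch): sharding an (8, 3) array 8 ways along
# axis 0 removes that axis from the per-device aval:
#   _shard_abstract_array(8, 0, ShapedArray((8, 3), np.float32))
#   # -> ShapedArray((3,), np.float32)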
"""
ArrayMapping specifies how an ndarray should map to mesh axes.
Note that the ordering is crucial for the cases when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.
For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]
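# For example (an illustrative sketch): with a mesh of shape {'x': 2, 'y': 3},
#   OrderedDict([('x', 1), ('y', 1)])
# maps both mesh axes onto positional axis 1, chunking it into 6 pieces with
# 'y' as the minor (fastest changing) mesh dimension.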
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
# TODO(yashkatariya): Move PartitionSpec into a place where all files can
# import it without cyclic dependency.
from jax.interpreters.sharded_jit import PartitionSpec
if not array_mapping:
return PartitionSpec()
max_index = -1
reverse_map = defaultdict(list)
for axis, index in array_mapping.items():
reverse_map[index].append(axis)
if index > max_index:
max_index = index
partitions = tuple(tuple(reverse_map[i]) if reverse_map[i] else None
for i in range(max_index + 1))
return PartitionSpec(*partitions)
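# For example (an illustrative sketch, not executed here):
#   array_mapping_to_axis_resources(OrderedDict([('x', 1)]))
#   # -> PartitionSpec(None, ('x',))
#   array_mapping_to_axis_resources(OrderedDict([('x', 0), ('y', 0)]))
#   # -> PartitionSpec(('x', 'y'))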
def local_aval_to_result_handler(
aval: core.AbstractValue,
sharding_spec: Optional[ShardingSpec],
indices: Optional[Tuple[Index, ...]],
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
Args:
aval: The local output AbstractValue.
sharding_spec: Indicates how the output is sharded across devices, or None
for non-array avals.
indices: The pre-computed result of spec_to_indices, or None for non-array
avals.
Returns:
A function for handling the Buffers that will eventually be produced
for this output. The function will return an object suitable for returning
to the user, e.g. a ShardedDeviceArray.
"""
try:
return local_result_handlers[type(aval)](aval, sharding_spec, indices)
except KeyError as err:
raise TypeError(
"No pxla_result_handler for type: {}".format(type(aval))) from err
PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client.Buffer]], Any]]
local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
local_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
def sda_array_result_handler(aval: ShapedArray, sharding_spec, indices):
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
indices)
local_result_handlers[ShapedArray] = sda_array_result_handler
local_result_handlers[ConcreteArray] = sda_array_result_handler
def global_aval_to_result_handler(
aval: core.AbstractValue, out_axis_resources, global_mesh,
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
Args:
aval: The global output AbstractValue.
out_axis_resources: A PartitionSpec specifying the sharding of outputs.
Used for creating GSDAs.
global_mesh: The global device mesh that generated this output. Used
for creating GSDAs.
Returns:
A function for handling the Buffers that will eventually be produced
for this output. The function will return an object suitable for returning
to the user, e.g. a ShardedDeviceArray.
"""
try:
return global_result_handlers[type(aval)](aval, out_axis_resources,
global_mesh)
except KeyError as err:
raise TypeError(
"No pxla_result_handler for type: {}".format(type(aval))) from err
global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
global_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
### lazy device-memory persistence and result handling
# TODO(jblespiau): Consider removing this option.
_USE_CPP_SDA = True
def make_sharded_device_array(
aval: ShapedArray,
sharding_spec: Optional[ShardingSpec],
# Any is for JAX extensions implementing their own buffer.
device_buffers: List[Union[Any, xb.xla_client.Buffer]],
indices: Optional[Tuple[Index, ...]] = None,
):
"""Returns a ShardedDeviceArray implementation based on arguments.
Returns either a C++ SDA, or a Python _ShardedDeviceArray when the buffers
are not JAX buffers.
Args:
aval: The `ShapedArray` for this array.
sharding_spec: If `None`, assumes a pmap-style ShardedDeviceArrays over the
first dimension.
device_buffers: If a list of Jax `Buffer` objects, a C++ SDA will be
returned (if the version is high enough). Otherwise, a Python object will
be returned, for JAX extensions not implementing the C++ API.
indices: For caching purposes, will be computed if `None`.
"""
if sharding_spec is None:
sharded_aval = aval.update(shape=aval.shape[1:])
sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0], 1, None,
sharded_aval, 0)
if indices is None:
indices = spec_to_indices(aval.shape, sharding_spec)
if (_USE_CPP_SDA and
(not device_buffers or
isinstance(device_buffers[0], xb.xla_client.Buffer))):
return pmap_lib.ShardedDeviceArray.make(
aval, sharding_spec, device_buffers,
indices, aval.weak_type)
return _ShardedDeviceArray(aval, sharding_spec, device_buffers, indices)
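# A hypothetical construction (assumes two local devices; `sharding_spec=None`
# requests the default pmap-style sharding over the leading axis):
#   aval = ShapedArray((2, 3), np.float32)
#   bufs = [jax.device_put(np.ones((3,), np.float32), d).device_buffer
#           for d in jax.local_devices()[:2]]
#   sda = make_sharded_device_array(aval, None, bufs)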
if _USE_CPP_SDA:
ShardedDeviceArrayBase = pmap_lib.ShardedDeviceArrayBase # type: ignore
# We want the C++ SDA to extend the DeviceArrayBase. We want this both to
# benefit from its methods, and to have isinstance(x, DeviceArray) return True.
ShardedDeviceArrayBase.__bases__ = ((device_array.DeviceArray,) + # type: ignore
ShardedDeviceArrayBase.__bases__)
_SDA_BASE_CLASS = pmap_lib.ShardedDeviceArrayBase # type: ignore
else:
_SDA_BASE_CLASS: Type[device_array.DeviceArray] = device_array.DeviceArray # type: ignore
class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore
"""A ShardedDeviceArray is an ndarray sharded across devices.
The purpose of a ShardedDeviceArray is to reduce the number of transfers when
executing replicated computations, by allowing results to persist on the
devices that produced them. That way dispatching a similarly replicated
computation that consumes the same sharded memory layout does not incur any
transfers.
A ShardedDeviceArray represents one logical ndarray value, and simulates the
behavior of an ndarray so that it can be treated by user code as an ndarray;
that is, it is only an optimization to reduce transfers.
Attributes:
aval: A ShapedArray indicating the shape and dtype of this array.
sharding_spec: describes how this array is sharded across `device_buffers`.
device_buffers: the buffers containing the data for this array. Each buffer
is the same shape and on a different device. Buffers are in row-major
order, with replication treated as an extra innermost dimension.
indices: the result of spec_to_indices(sharding_spec). Can optionally be
precomputed for efficiency. A list of the same length as
`device_buffers`. Each index indicates what portion of the full array is
stored in the corresponding device buffer, i.e. `array[indices[i]] ==
device_buffers[i].to_py()`.
"""
__slots__ = [
"aval", "device_buffers", "sharding_spec", "indices",
"_one_replica_buffer_indices", "_npy_value"
]
def __init__(self,
aval: ShapedArray,
sharding_spec: ShardingSpec,
device_buffers: List[xb.xla_client.Buffer],
indices: Optional[Tuple[Index, ...]] = None):
super().__init__()
# TODO(skye): assert invariants. Keep performance in mind though.
if indices is None:
indices = spec_to_indices(aval.shape, sharding_spec)
self.aval = aval
self.device_buffers = device_buffers
self.sharding_spec = sharding_spec
self.indices = indices
self._npy_value = None
self._one_replica_buffer_indices = None
if config.jax_enable_checks:
assert type(aval) is ShapedArray
@property
def shape(self):
return self.aval.shape
@property
def dtype(self):
return self.aval.dtype
@property
def size(self):
return prod(self.aval.shape)
@property
def ndim(self):
return len(self.aval.shape)
def delete(self):
if self.device_buffers is None:
return
for buf in self.device_buffers:
buf.delete()
self.device_buffers = None
self._npy_value = None
def _sda_one_replica_buffer_indices(self):
"""Indices of buffers containing one complete copy of the array data."""
if self._one_replica_buffer_indices is None:
one_replica_indices = []
seen_index_hashes = set()
for i, index in enumerate(self.indices):
hashed_index = _hashable_index(index)
if hashed_index not in seen_index_hashes:
one_replica_indices.append(i)
seen_index_hashes.add(hashed_index)
self._one_replica_buffer_indices = one_replica_indices
return self._one_replica_buffer_indices
def _sda_copy_to_host_async(self):
for buffer_index in self.one_replica_buffer_indices:
self.device_buffers[buffer_index].copy_to_host_async()
def _sda_check_if_deleted(self):
if self.device_buffers is None:
raise ValueError("ShardedDeviceArray has been deleted.")
def _sda_block_until_ready(self):
self._check_if_deleted()
for buf in self.device_buffers:
buf.block_until_ready()
return self
def _sda_value(self):
if self._npy_value is None:
self.copy_to_host_async()
npy_value = np.empty(self.aval.shape, self.aval.dtype)
for i in self.one_replica_buffer_indices:
npy_value[self.indices[i]] = self.device_buffers[i].to_py()
self._npy_value = npy_value
return self._npy_value
def _sda__getitem__(self, idx):
self._check_if_deleted()
if not isinstance(idx, tuple):
cidx = (idx,) + (slice(None),) * (len(self.aval.shape) - 1)
else:
cidx = idx + (slice(None),) * (len(self.aval.shape) - len(idx))
if self._npy_value is None:
try:
buf_idx = self.indices.index(cidx)
except ValueError:
buf_idx = None
if buf_idx is not None:
buf = self.device_buffers[buf_idx]
aval = ShapedArray(buf.xla_shape().dimensions(), self.aval.dtype)
return device_array.make_device_array(aval, None, buf)
return super(self.__class__, self).__getitem__(idx)
def _sda__iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return (self[i] for i in range(self.shape[0]))
def _sda__reversed__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return (self[i] for i in range(self.shape[0] - 1, -1, -1))
for sda in [_ShardedDeviceArray, pmap_lib.ShardedDeviceArray]:
setattr(sda, "one_replica_buffer_indices",
property(_sda_one_replica_buffer_indices))
setattr(sda, "copy_to_host_async", _sda_copy_to_host_async)
setattr(sda, "_check_if_deleted", _sda_check_if_deleted)
setattr(sda, "block_until_ready", _sda_block_until_ready)
setattr(sda, "_value", property(_sda_value))
setattr(sda, "__getitem__", _sda__getitem__)
setattr(sda, "__iter__", _sda__iter__)
setattr(sda, "__reversed__", _sda__reversed__)
del (_sda_one_replica_buffer_indices, _sda_copy_to_host_async,
_sda_check_if_deleted, _sda_block_until_ready, _sda_value, _sda__getitem__)
ShardedDeviceArray: Type[object]
if _USE_CPP_SDA:
ShardedDeviceArray = pmap_lib.ShardedDeviceArrayBase
else:
ShardedDeviceArray = _ShardedDeviceArray
def _hashable_index(idx):
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x,
idx)
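# For example (an illustrative sketch):
#   _hashable_index((slice(0, 2), 3))  # -> ((0, 2), 3)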
# The fast path is handled directly in shard_args().
# TODO(skye): is there a simpler way to rewrite this using sharding_spec?
def _shard_sharded_device_array_slow_path(x, devices, indices):
candidates = defaultdict(list)
for buf, idx in safe_zip(x.device_buffers, x.indices):
candidates[_hashable_index(idx)].append(buf)
bufs = []
for idx, device in safe_zip(indices, devices):
# Look up all buffers that contain the correct slice of the logical array.
candidates_list = candidates[_hashable_index(idx)]
if not candidates_list:
# This array isn't sharded correctly. Reshard it via host roundtrip.
# TODO(skye): more efficient reshard?
return shard_arg_handlers[type(x._value)](x._value, devices, indices)
# Try to find a candidate buffer already on the correct device,
# otherwise copy one of them.
for buf in candidates_list:
if buf.device() == device:
bufs.append(buf)
break
else:
bufs.append(buf.copy_to_device(device))
return bufs
def _sharded_device_array_constant_handler(c, val, canonicalize_types=True):
return xla.pyval_to_ir_constants(c, np.asarray(val),
canonicalize_types=canonicalize_types)
def _sharded_device_array_mlir_constant_handler(val, canonicalize_types=True):
return mlir.ir_constants(np.asarray(val),
canonicalize_types=canonicalize_types)
def _register_handlers_for_sharded_device_array(sda):
shard_arg_handlers[sda] = _shard_sharded_device_array_slow_path
xla.register_constant_handler(sda, _sharded_device_array_constant_handler)
mlir.register_constant_handler(sda,
_sharded_device_array_mlir_constant_handler)
core.pytype_aval_mappings[sda] = abstract_arrays.canonical_concrete_aval
dispatch.device_put_handlers[sda] = dispatch._device_put_array
xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
xla.canonicalize_dtype_handlers[sda] = identity
_register_handlers_for_sharded_device_array(_ShardedDeviceArray)
_register_handlers_for_sharded_device_array(pmap_lib.ShardedDeviceArray)
### the xla_pmap primitive and its rules are comparable to xla_call in xla.py
def xla_pmap_impl(fun: lu.WrappedFun, *args,
backend: Optional[str],
axis_name: core.AxisName,
axis_size: int,
global_axis_size: Optional[int],
devices: Optional[Sequence[Any]],
name: str,
in_axes: Sequence[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
abstract_args = unsafe_map(xla.abstractify, args)
compiled_fun, fingerprint = parallel_callable(
fun, backend, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
*abstract_args)
# For performance, don't re-abstractify the args unless debug logging is enabled.
if config.jax_distributed_debug:
distributed_debug_log(("Running pmapped function", name),
("python function", fun.f),
("devices", devices),
("abstract args", map(xla.abstractify, args)),
("fingerprint", fingerprint))
return compiled_fun(*args)
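# At the user level, this machinery backs jax.pmap; an illustrative sketch:
#   n = jax.local_device_count()
#   out = jax.pmap(lambda x: x * 2)(np.arange(n))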
@lu.cache
def parallel_callable(fun: lu.WrappedFun,
backend_name: Optional[str],
axis_name: core.AxisName,
axis_size: int,
global_axis_size: Optional[int],
devices: Optional[Sequence[Any]],
name: str,
in_axes: Sequence[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
*avals):
pmap_computation = lower_parallel_callable(
fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars, global_arg_shapes, avals)
pmap_executable = pmap_computation.compile()
return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
@dataclasses.dataclass(frozen=True)
class ParallelCallableInfo:
name: str
backend: xla.Backend
axis_name: core.AxisName
axis_size: int
global_axis_size: Optional[int]
devices: Optional[Sequence[xla.Device]]
in_axes: Iterable[Optional[int]]
out_axes_thunk: Callable[[], Sequence[Optional[int]]]
avals: Sequence[core.AbstractValue]
@maybe_cached_property
def local_devices(self):
if self.devices:
out = [d for d in self.devices
if d.process_index == xb.process_index(self.backend)]
assert len(out) > 0
else:
out = None # type: ignore
return out
@maybe_cached_property
def out_axes(self):
return self.out_axes_thunk()
class ShardInfo(NamedTuple):
sharded_avals: Sequence[core.AbstractValue]
out_sharded_avals: Sequence[core.AbstractValue]
global_sharded_avals: Sequence[core.AbstractValue]
num_local_shards: int
num_global_shards: int
class ReplicaInfo(NamedTuple):
jaxpr_replicas: int
num_local_replicas: int
num_global_replicas: int
def find_replicas(jaxpr, axis_size, global_axis_size):
# TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
jaxpr_replicas = dispatch.jaxpr_replicas(jaxpr)
num_local_replicas = axis_size * jaxpr_replicas
num_global_replicas = global_axis_size * jaxpr_replicas
return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas)
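# For example (an illustrative sketch): a jaxpr with jaxpr_replicas == 2 under
# axis_size=4 and global_axis_size=8 yields
# ReplicaInfo(jaxpr_replicas=2, num_local_replicas=8, num_global_replicas=16).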
def should_tuple_args(shards: ShardInfo):
# tuplify long arg lists for TPU
return len(shards.global_sharded_avals) > 100
def stage_parallel_callable(
pci: ParallelCallableInfo,
fun: lu.WrappedFun,
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
sharded_avals = tuple(
shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
for axis, aval in safe_zip(pci.in_axes, pci.avals))
if any(s is not None for s in global_arg_shapes):
# TODO(skye): we could take this branch unconditionally if we handled
# grad of global_arg_shapes correctly.
global_sharded_avals = [
aval.update(shape=shape) if shape is not None else aval
for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
else:
global_sharded_avals = sharded_avals # type: ignore
with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for pmap in {elapsed_time} sec"):
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
assert len(out_sharded_avals) == len(pci.out_axes), (
len(out_sharded_avals), len(pci.out_axes))
# TODO(skye,mattjj): allow more collectives on multi-host as we test them, but
# for now raise an error
if pci.devices is not None:
is_multi_host_pmap = len(pci.local_devices) != len(pci.devices)
else:
is_multi_host_pmap = xb.process_count(pci.backend) > 1
if is_multi_host_pmap:
check_multihost_collective_allowlist(jaxpr)
replicas = find_replicas(jaxpr, pci.axis_size, pci.global_axis_size)
parts = find_partitions(jaxpr)
num_local_shards = replicas.num_local_replicas * parts.local_num_partitions
num_global_shards = replicas.num_global_replicas * parts.num_partitions
shards = ShardInfo(
sharded_avals, out_sharded_avals, global_sharded_avals,
num_local_shards, num_global_shards)
return jaxpr, consts, replicas, parts, shards
def _shardings_to_mlir_shardings(
shardings: Optional[Sequence['PartitionsOrReplicated']]
) -> Optional[Sequence[Optional[xc.OpSharding]]]:
if shardings is None:
return None
return [xla.sharding_to_proto(s) for s in shardings]
@profiler.annotate_function
def lower_parallel_callable(
fun: lu.WrappedFun,
backend_name: Optional[str],
axis_name: core.AxisName,
axis_size: int,
global_axis_size: Optional[int],
devices: Optional[Sequence[xla.Device]],
name: str,
in_axes: Iterable[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
avals: Sequence[core.AbstractValue]):
if devices is not None and len(devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")
# Determine global_axis_size for use in AxisEnv.
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
# if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
# raise ValueError("'axis_size' must be specified for nested multi-host pmaps")
if (xb.process_count() == 1 and global_axis_size is not None and
global_axis_size != axis_size):
raise ValueError(
f"Specified axis_size {global_axis_size} doesn't match received "
f"axis_size {axis_size}.")
if devices is not None and backend_name is None:
backend = xb.get_device_backend(devices[0])
else:
backend = xb.get_backend(backend_name)
must_run_on_all_devices = False
no_nested_sharding = False
if global_axis_size is None:
if xb.process_count(backend) == 1:
global_axis_size = axis_size
elif devices:
# This allows each host in a multi-host pmap to run on a different number
# of devices, but precludes nested sharding (i.e. inner pmaps or
# sharded_jits).
global_axis_size = len(devices)
no_nested_sharding = True
else:
# This assumes all hosts run on the same number of devices. We make sure
# this assumption is true by requiring that the pmap is run on all devices
# (and making the further assumption that each host has the same number of
# devices). Nested sharding is ok in this case.
global_axis_size = axis_size * xb.process_count(backend)
assert all(
len(xb.local_devices(process_index, backend)) == xb.local_device_count(backend)
for process_index in range(xb.process_count(backend)))
must_run_on_all_devices = True
pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals)
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
pci, fun, global_arg_shapes)
if logging.vlog_is_on(2):
logging.vlog(2, "sharded_avals: %s", shards.sharded_avals)