-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
sharding.py
419 lines (334 loc) · 15 KB
/
sharding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import functools
from collections import Counter
from typing import Sequence, Tuple, Optional, Mapping, Dict, Set, Union, cast
from jax._src.util import safe_zip
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax.interpreters import pxla, mlir
import numpy as np
# Type aliases shared by the sharding classes in this module.
# Shape of a (global) array, one int per dimension.
Shape = Tuple[int, ...]
Device = xc.Device
# Per-dimension selection of the global array held by one device. pmap indices
# may contain a plain int instead of a slice (see `_hashed_index`).
Index = Tuple[slice, ...]
# Ordered sequence of devices, as returned by `_device_assignment`.
XLADeviceAssignment = Sequence[Device]
@pxla.use_cpp_class(xc.Sharding if xc._version >= 94 else None)
class Sharding(metaclass=abc.ABCMeta):
  """Abstract description of how an array is laid out across devices.

  Subclasses implement `device_set`, `devices_indices_map` and `shard_shape`;
  the remaining methods are derived from those.
  """

  # Abstract methods below that subclasses should implement.

  # `abc.abstractproperty` is deprecated since Python 3.3; stacking
  # `property` on top of `abc.abstractmethod` is the documented replacement
  # and marks the property abstract in exactly the same way.
  @property
  @abc.abstractmethod
  def device_set(self) -> Set[Device]:
    """A unique set of devices that this sharding represents.

    Devices can be non-addressable too.
    """
    raise NotImplementedError('Subclasses should implement this method.')

  @abc.abstractmethod
  def devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
    """Maps each device to the index of an array of `global_shape` it holds."""
    raise NotImplementedError('Subclasses should implement this method.')

  @abc.abstractmethod
  def shard_shape(self, global_shape: Shape) -> Shape:
    """Returns the shape of one shard of an array of `global_shape`."""
    raise NotImplementedError('Subclasses should implement this method.')

  #############################################################################
  # Default implementations below that all subclasses will inherit.

  @pxla.maybe_cached_property
  def addressable_devices(self) -> Set[Device]:
    """A set of addressable devices by the current process"""
    process_index = xb.process_index()
    return {d for d in self.device_set if d.process_index == process_index}

  def is_fully_addressable(self) -> bool:
    """True when every device in `device_set` belongs to the current process."""
    # The `type: ignore` is because type checkers can't recognize
    # `maybe_cached_property` as producing a property.
    return len(self.device_set) == len(self.addressable_devices)  # type: ignore

  def device_indices(self, device: Device,
                     global_shape: Shape) -> Optional[Index]:
    """Returns the index of an array of `global_shape` held by `device`.

    Raises:
      KeyError: if `device` is not part of this sharding.
    """
    return self.devices_indices_map(global_shape)[device]

  @functools.lru_cache(maxsize=4096)
  def addressable_devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
    """`devices_indices_map` restricted to devices of the current process."""
    process_index = xb.process_index()
    return {d: ind for d, ind in self.devices_indices_map(global_shape).items()
            if d.process_index == process_index}
@pxla.use_cpp_class(xc.XLACompatibleSharding if xc._version >= 94 else None)
class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
  """A `Sharding` that can be lowered to an XLA `OpSharding` over an ordered
  device assignment."""

  # Abstract methods below that subclasses should implement.

  # `abc.abstractproperty` is deprecated since Python 3.3; `property` over
  # `abc.abstractmethod` is the equivalent, non-deprecated spelling.
  @property
  @abc.abstractmethod
  def _device_assignment(self) -> XLADeviceAssignment:
    """Ordered list of devices the `OpSharding` tile assignment refers to."""
    raise NotImplementedError('Subclasses should implement this method.')

  @abc.abstractmethod
  def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
    """Returns the `OpSharding` for an array of rank `num_dimensions`."""
    raise NotImplementedError('Subclasses should implement this method.')

  #############################################################################
  # Default implementations below that all subclasses will inherit.

  @pxla.maybe_cached_property
  def _addressable_device_assignment(self) -> XLADeviceAssignment:
    """`_device_assignment` filtered to devices of the current process."""
    process_index = xb.process_index()
    return [d for d in self._device_assignment if d.process_index == process_index]

  @functools.lru_cache(maxsize=4096)
  def shard_shape(self, global_shape: Shape) -> Shape:
    """Returns the per-device shard shape for an array of `global_shape`.

    Each dimension is divided by the number of ways the `OpSharding` splits it.

    Raises:
      ValueError: if some dimension is not evenly divisible by its partition
        count.
    """
    op_sharding = cast(xc.OpSharding, self._to_xla_op_sharding(len(global_shape)))
    # Replicated: every shard is the full array.
    if pxla.is_op_sharding_replicated(op_sharding):
      return global_shape
    partitions, _ = pxla._get_num_ways_dim_sharded(op_sharding)
    assert len(partitions) == len(global_shape), (len(partitions), len(global_shape))
    out = []
    for dim, (s, p) in enumerate(safe_zip(global_shape, partitions)):
      quotient, remainder = divmod(s, p)
      if remainder != 0:
        raise ValueError(
            f"Sharding {self} implies that array axis {dim} is partitioned "
            f"{p} times, but the dimension size is {s} "
            f"(full shape: {global_shape}, "
            f"per-dimension tiling factors: {partitions} should evenly divide "
            "the shape)")
      out.append(quotient)
    return tuple(out)
@functools.lru_cache()
def _check_mesh_resource_axis(mesh, parsed_pspec):
try:
[mesh.shape[r] for p in parsed_pspec if p is not None
for r in p]
except KeyError as e:
raise ValueError(f"Resource axis: {e.args[0]} of {parsed_pspec.user_spec} is "
"undefined.") from None
def _hashed_index(x) -> int:
# This works for both `pjit`/`xmap` indices and `pmap` indices (which might
# have an integer instead of a slice).
assert all(v.step is None for v in x if isinstance(v, slice))
return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x))
@functools.lru_cache(maxsize=4096)
def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
  """Gives every device a replica id among devices holding the same shard.

  Indices from `sharding.devices_indices_map(global_shape)` are grouped by
  `_hashed_index`. Within each group, devices are numbered 0, 1, 2, ... in the
  iteration order of that mapping.

  Raises:
    ValueError: if `sharding` has no `devices_indices_map` attribute.
  """
  try:
    indices_of = sharding.devices_indices_map
  except AttributeError:
    raise ValueError(
        f'Cannot calculate replica ids from sharding: {sharding}. Please '
        'create a device to index mapping for your sharding from which replica '
        'ids will be calculated.') from None

  # Number of devices already seen for each hashed index.
  seen_per_index: Dict[int, int] = Counter()
  replica_ids = {}
  for dev, idx in indices_of(global_shape).items():
    key = _hashed_index(idx)
    replica_ids[dev] = seen_per_index[key]
    seen_per_index[key] += 1
  return replica_ids
@pxla.use_cpp_class(xc.MeshPspecSharding if xc._version >= 95 else None)
class MeshPspecSharding(XLACompatibleSharding):
  """Sharding described by a device `Mesh` and a `PartitionSpec` over its axes."""

  @pxla.use_cpp_method
  def __init__(
      self, mesh: pxla.Mesh, spec: pxla.PartitionSpec, _parsed_pspec = None):
    self.mesh = mesh
    self.spec = spec
    self._parsed_pspec = _parsed_pspec
    self._preprocess()

  def _preprocess(self):
    # This split exists because you can pass `_parsed_pspec` that has been
    # modified from the original. For example: Adding extra dimension to
    # axis_resources for vmap handlers. In such cases you need to preserve the
    # `sync` attribute of parsed pspecs.
    # PartitionSpec is inferred from the parsed pspec in this case.
    # TODO(yashkatariya): Remove this and replace this with a normalized
    # representation of Parsed Pspec
    if self._parsed_pspec is None:
      from jax.experimental import pjit
      self._parsed_pspec, _, _, _ = pjit._prepare_axis_resources(
          self.spec, "MeshPspecSharding spec")

    # Raises ValueError if the spec names an axis the mesh does not have.
    _check_mesh_resource_axis(self.mesh, self._parsed_pspec)

  def __repr__(self):
    return f'MeshPspecSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'

  def __hash__(self):
    # Computed on first use and memoized on the instance.
    if not hasattr(self, '_hash'):
      self._hash = hash((self.mesh, self._parsed_pspec))
    return self._hash

  def __eq__(self, other):
    if not isinstance(other, MeshPspecSharding):
      return False
    if id(self) == id(other):
      return True
    # Identity fast path for the mesh; presumably avoids a full mesh
    # comparison in the common case where the same Mesh object is shared.
    if id(self.mesh) == id(other.mesh) and self._parsed_pspec == other._parsed_pspec:
      return True
    return self.mesh == other.mesh and self._parsed_pspec == other._parsed_pspec

  def is_compatible_aval(self, aval_shape: Shape):
    """Raises ValueError if `aval_shape` has lower rank than the spec."""
    if len(aval_shape) < len(self._parsed_pspec):
      raise ValueError(
          f"Sharding {self} is only valid for values of rank at least "
          f"{len(self._parsed_pspec)}, but was applied to a value of rank "
          f"{len(aval_shape)}")

  @classmethod
  def _from_parsed_pspec(cls, mesh, parsed_pspec):
    """Builds a sharding whose `spec` is recovered from `parsed_pspec`."""
    return cls(mesh, parsed_pspec.get_partition_spec(), parsed_pspec)

  @pxla.maybe_cached_property
  def device_set(self) -> Set[Device]:
    return set(self.mesh.devices.flat)

  def devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Index]:
    # TODO(yashkatariya): Remove this when utilities are moved to pxla.py.
    from jax.experimental import global_device_array
    # `get_shard_indices` is cached.
    return global_device_array.get_shard_indices(global_shape, self.mesh, self.spec)

  @pxla.maybe_cached_property
  def _device_assignment(self) -> XLADeviceAssignment:
    return list(self.mesh.devices.flat)

  @functools.lru_cache(maxsize=4096)
  def _to_xla_op_sharding(
      self,
      num_dimensions: int,
      axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
  ) -> xc.OpSharding:
    """Lowers this sharding to an `OpSharding` for a rank-`num_dimensions` array.

    When `axis_ctx` is an `mlir.SPMDAxisContext`, its `manual_axes` are marked
    as MANUAL in the resulting proto.
    """
    from jax.experimental.pjit import get_array_mapping
    array_mapping = get_array_mapping(self._parsed_pspec)
    # TODO(yashkatariya): Move away from sharding spec in MeshPspecSharding
    # since we don't really need sharding spec.
    sharding_spec = pxla.new_mesh_sharding_specs(
        self.mesh.shape, self.mesh.axis_names)(num_dimensions, array_mapping)
    # Used in `with_sharding_constraint`.
    special_axes = {}
    # Manual axes are only used with xmap. `isinstance` already rejects None,
    # so no separate `is not None` check is needed.
    if isinstance(axis_ctx, mlir.SPMDAxisContext):
      axis_names = self.mesh.axis_names
      # `axis_ctx` is narrowed by the isinstance check above; the ignore is
      # kept in case the type checker in use does not narrow it.
      for manual_axis in axis_ctx.manual_axes:  # type: ignore
        special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
    return sharding_spec.sharding_proto(special_axes=special_axes)
@functools.lru_cache()
def _get_replicated_op_sharding():
proto = xc.OpSharding()
proto.type = xc.OpSharding.Type.REPLICATED
return proto
@pxla.use_cpp_class(xc.SingleDeviceSharding if xc._version >= 95 else None)
class SingleDeviceSharding(XLACompatibleSharding):
  """Sharding that places the whole array on exactly one device."""

  @pxla.use_cpp_method
  def __init__(self, device: Device):
    self._device = device

  def __repr__(self):
    return f"SingleDeviceSharding(device={repr(self._device)})"

  def __hash__(self):
    return hash(self._device)

  def __eq__(self, other):
    # Equal exactly when the other sharding targets the same device.
    if isinstance(other, SingleDeviceSharding):
      return self is other or self._device == other._device
    return False

  @property
  def device_set(self) -> Set[Device]:
    return {self._device}

  def devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Index]:
    # One full slice per dimension: the lone device holds everything.
    whole_array = (slice(None),) * len(global_shape)
    return {self._device: whole_array}

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    return [self._device]

  def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
    # Rank does not matter: a single device is always "replicated".
    return _get_replicated_op_sharding()
@pxla.use_cpp_class(xc.PmapSharding if xc._version >= 94 else None)
class PmapSharding(XLACompatibleSharding):
  """Sharding used by `pmap`: a device array plus pmap's `ShardingSpec`."""

  @pxla.use_cpp_method
  def __init__(self, devices: np.ndarray, sharding_spec: pxla.ShardingSpec):
    self.devices = devices
    # The sharding spec should be pmap's sharding spec.
    self.sharding_spec = sharding_spec

  def __eq__(self, other):
    if not isinstance(other, PmapSharding):
      return False
    if self is other:
      return True
    specs_match = self.sharding_spec == other.sharding_spec
    return specs_match and np.array_equal(self.devices, other.devices)

  def __hash__(self):
    # Memoized on the instance after the first call.
    try:
      return self._hash
    except AttributeError:
      self._hash = hash((tuple(self.devices.flat), self.sharding_spec))
      return self._hash

  @pxla.maybe_cached_property
  def device_set(self) -> Set[Device]:
    return set(self.devices.flat)

  @functools.lru_cache(maxsize=4096)
  def devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
    shard_indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
    return dict(safe_zip(self.devices.flat, shard_indices))  # type: ignore

  @pxla.maybe_cached_property
  def _device_assignment(self) -> XLADeviceAssignment:
    return list(self.devices.flat)

  def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
    raise NotImplementedError("pmap doesn't use OpSharding.")

  @functools.lru_cache(maxsize=4096)
  def shard_shape(self, global_shape: Shape) -> Shape:
    """Drops the first `Unstacked` dimension from `global_shape`.

    If no dimension is `Unstacked`, `global_shape` is returned unchanged.
    """
    unstacked_dim = next(
        (pos for pos, entry in enumerate(self.sharding_spec.sharding)
         if isinstance(entry, pxla.Unstacked)),
        None)
    if unstacked_dim is None:
      return global_shape
    return global_shape[:unstacked_dim] + global_shape[unstacked_dim+1:]
# TODO(yashkatariya): Remove this when minimum_jaxlib version is 0.3.17
def _hash_op_sharding(op: xc.OpSharding):
if op.type == xc.OpSharding.Type.TUPLE:
return hash(tuple(_hash_op_sharding(o) for o in op.tuple_shardings))
return hash((tuple(op.tile_assignment_devices), tuple(op.tile_assignment_dimensions),
op.type, op.replicate_on_last_tile_dim, tuple(op.last_tile_dims)))
@pxla.use_cpp_class(xc.OpShardingSharding if xc._version >= 95 else None)
class OpShardingSharding(XLACompatibleSharding):
  """Sharding given directly by an XLA `OpSharding` over a device sequence."""

  @pxla.use_cpp_method
  def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
    self._devices = tuple(devices)
    self._op_sharding = op_sharding

  @pxla.maybe_cached_property
  def _op_sharding_hash(self):
    # Older extension versions lack a usable HloSharding hash, so fall back to
    # hashing the proto fields in Python.
    if xla_extension_version < 81:
      return _hash_op_sharding(self._op_sharding)
    return hash(xc.HloSharding.from_proto(self._op_sharding))

  def __eq__(self, other):
    if not isinstance(other, OpShardingSharding):
      return False
    if self is other:
      return True
    protos_equal = pxla.are_op_shardings_equal(self._op_sharding, other._op_sharding)
    return protos_equal and self._devices == other._devices

  def __hash__(self):
    # Memoized on the instance after the first call.
    try:
      return self._hash
    except AttributeError:
      self._hash = hash((self._devices, self._op_sharding_hash))
      return self._hash

  def __repr__(self):
    if not pxla.is_op_sharding_replicated(self._op_sharding):
      return f'OpShardingSharding({repr(self._op_sharding)})'
    return 'OpShardingSharding(REPLICATED)'

  def is_compatible_aval(self, aval_shape: Shape):
    """Raises ValueError if `aval_shape` has fewer dims than the proto tiles."""
    num_ways_dim_sharded, _ = pxla._get_num_ways_dim_sharded(self._op_sharding)
    if len(num_ways_dim_sharded) > len(aval_shape):
      raise ValueError(
          f"Sharding {self} is only valid for values of rank at least "
          f"{len(num_ways_dim_sharded)}, but was applied to a value of rank "
          f"{len(aval_shape)}")

  @pxla.maybe_cached_property
  def device_set(self) -> Set[Device]:
    return set(self._devices)

  @functools.lru_cache(maxsize=4096)
  def devices_indices_map(
      self, global_shape: Shape) -> Mapping[Device, Index]:
    device_count = len(self._devices)
    per_device = pxla.op_sharding_to_indices(
        self._op_sharding, global_shape, device_count)
    return dict(safe_zip(self._devices, per_device))

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    return list(self._devices)

  def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
    # The stored proto is returned unchanged regardless of `num_dimensions`.
    return self._op_sharding

  @classmethod
  def get_replicated(cls, device_assignment):
    """Builds a fully replicated sharding over `device_assignment`."""
    return cls(device_assignment, _get_replicated_op_sharding())