# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface and utility functions to XLA.
This module wraps the XLA client(s) and builders to standardize their interfaces
and provide some automatic type mapping logic for converting between Numpy and
XLA. There are also a handful of related casting utilities.
"""
from functools import partial
import os
from typing import Callable, Dict
import warnings
from absl import logging
from ..config import flags
from .. import util
from .. import dtypes
import numpy as onp # 'onp' rather than 'np' to distinguish from autograd.numpy
import threading
try:
  from . import tpu_client
except ImportError:
  tpu_client = None
from . import version
from . import xla_client
FLAGS = flags.FLAGS
flags.DEFINE_string(
    'jax_xla_backend', 'xla',
    'Default is "xla" for the XLA service directly, '
    'or "tpu_driver" for using high-performance access to Cloud TPU hardware.')
flags.DEFINE_string(
    'jax_backend_target', 'local',
    'Either "local" or "rpc:address" to connect to a remote service target.')
flags.DEFINE_string(
    'jax_platform_name',
    os.getenv('JAX_PLATFORM_NAME', ''),
    'Platform name for XLA. The default is to attempt to use a GPU if '
    'available, but fall back to CPU otherwise. To set the platform manually, '
    'pass "cpu" for CPU or "gpu" for GPU.')
def get_compile_options(num_replicas, num_partitions, device_assignment=None):
  """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: int indicating the number of replicas for which to compile.
    num_partitions: int indicating the number of partitions for which to
      compile.
    device_assignment: Optional tuple of integers indicating the assignment of
      logical replicas to physical devices (default inherited from
      xla_client.CompileOptions). Must be consistent with `num_replicas` and
      `num_partitions`.
  """
  compile_options = xla_client.CompileOptions()
  compile_options.num_replicas = num_replicas
  compile_options.num_partitions = num_partitions
  if device_assignment is not None:
    logging.vlog(
        2,
        'get_compile_options: num_replicas=%s num_partitions=%s '
        'device_assignment=%s',
        num_replicas, num_partitions, device_assignment)
    device_assignment = onp.array(device_assignment)

    # Allow 1D device assignment if num_partitions is 1.
    if (device_assignment.ndim == 1) and (num_partitions == 1):
      device_assignment = device_assignment[:, None]

    if num_replicas != device_assignment.shape[0]:
      msg = 'device_assignment does not match num_replicas: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_replicas))

    if num_partitions != device_assignment.shape[1]:
      msg = 'device_assignment does not match num_partitions: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_partitions))

    device_assignment = xla_client.DeviceAssignment.create(device_assignment)
    assert device_assignment.replica_count() == num_replicas
    assert device_assignment.computation_count() == num_partitions
    compile_options.device_assignment = device_assignment
  return compile_options
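# Usage sketch for get_compile_options (illustrative only; the device ids
# below are hypothetical and depend on the machine):
#
#   options = get_compile_options(num_replicas=2, num_partitions=1,
#                                 device_assignment=(0, 1))
#   # options.device_assignment now maps logical replicas 0 and 1 to physical
#   # devices 0 and 1; omitting device_assignment keeps the default.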
_backends = {}

def register_backend(name, factory):
  _backends[name] = factory
def _get_local_backend(platform=None):
  if not platform:
    platform = FLAGS.jax_platform_name

  # Canonicalize platform names.
  cpu = 'cpu'
  gpu = 'gpu'
  if platform == 'Host':
    platform = cpu
  elif platform == 'CUDA':
    platform = gpu
  elif platform == '':
    platform = None

  backend = xla_client.get_local_backend(platform)
  if backend is None:
    raise RuntimeError("No local XLA backends found.")

  if backend.platform == cpu and platform != cpu:
    warnings.warn('No GPU/TPU found, falling back to CPU.')

  return backend
register_backend('xla', _get_local_backend)
# memoize the TPU driver to be consistent with xla_client behavior
_tpu_backend = None

def _get_tpu_driver_backend(platform):
  del platform
  global _tpu_backend
  if _tpu_backend is None:
    backend_target = FLAGS.jax_backend_target
    if backend_target is None:
      raise ValueError('When using TPU Driver as the backend, you must specify '
                       '--jax_backend_target=<hostname>:8470.')
    _tpu_backend = tpu_client.TpuBackend.create(worker=backend_target)
  return _tpu_backend

if tpu_client:
  register_backend('tpu_driver', _get_tpu_driver_backend)
_backend_lock = threading.Lock()
@util.memoize
def get_backend(platform=None):
  # TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
  # 'backend' values are handled
  if isinstance(platform, xla_client.Backend):
    return platform

  with _backend_lock:
    backend = _backends.get(FLAGS.jax_xla_backend)
    if backend is None:
      msg = 'Unknown jax_xla_backend value "{}".'
      raise ValueError(msg.format(FLAGS.jax_xla_backend))
    return backend(platform)
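# Usage sketch for the backend registry (illustrative only; the name
# 'my_backend' and its factory are hypothetical): a factory registered here is
# selected by setting --jax_xla_backend=my_backend, and get_backend() memoizes
# the result per platform argument.
#
#   def _my_backend_factory(platform):
#     return xla_client.get_local_backend(platform)
#   register_backend('my_backend', _my_backend_factory)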
def get_device_backend(device=None):
  """Returns the Backend associated with `device`, or the default Backend."""
  platform = device.platform if device else None
  return get_backend(platform)
def device_count(backend=None):
  """Returns the total number of devices.

  On most platforms, this is the same as ``local_device_count()``. However, on
  multi-host platforms, this will return the total number of devices across
  all hosts.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend. 'cpu', 'gpu', or 'tpu'.

  Returns:
    Number of devices.
  """
  return int(get_backend(backend).device_count())
def local_device_count(backend=None):
  """Returns the number of devices on this host."""
  return int(get_backend(backend).local_device_count())
def devices(backend=None):
  """Returns a list of all devices.

  Each device is represented by a subclass of Device (e.g. CpuDevice,
  GpuDevice). The length of the returned list is equal to
  ``device_count()``. Local devices can be identified by comparing
  ``Device.host_id`` to ``host_id()``.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend. 'cpu', 'gpu', or 'tpu'.

  Returns:
    List of Device subclasses.
  """
  return get_backend(backend).devices()
def local_devices(host_id=None, backend=None):
  """Returns a list of devices local to a given host (this host by default)."""
  if host_id is None:
    host_id = get_backend(backend).host_id()
  return [d for d in devices(backend) if d.host_id == host_id]
def host_id(backend=None):
  """Returns the integer host ID of this host.

  On most platforms, this will always be 0. This will vary on multi-host
  platforms though.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend. 'cpu', 'gpu', or 'tpu'.

  Returns:
    Integer host ID.
  """
  return get_backend(backend).host_id()
def host_ids(backend=None):
  """Returns a sorted list of all host IDs."""
  return sorted(list(set(d.host_id for d in devices(backend))))

def host_count(backend=None):
  """Returns the number of hosts."""
  return len(host_ids(backend))
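# Usage sketch for the device/host queries above (illustrative only; actual
# values depend on the machine). On a hypothetical single-host, CPU-only setup
# one would expect roughly:
#
#   device_count()         # 1
#   local_device_count()   # 1
#   devices()[0].platform  # 'cpu'
#   host_id()              # 0
#   host_ids()             # [0]
#   host_count()           # 1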
### utility functions
@util.memoize
def dtype_to_etype(dtype):
  """Convert from dtype to canonical etype (reading FLAGS.jax_enable_x64)."""
  return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype))

@util.memoize
def supported_numpy_dtypes():
  return {dtypes.canonicalize_dtype(dtype)
          for dtype in xla_client.XLA_ELEMENT_TYPE_TO_DTYPE.values()}
# TODO(mattjj,frostig): try to remove this function
def normalize_to_xla_dtypes(val):
  """Normalize dtypes in a value."""
  if hasattr(val, '__array__') or onp.isscalar(val):
    return onp.asarray(val,
                       dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
  elif isinstance(val, (tuple, list)):
    return tuple(normalize_to_xla_dtypes(x) for x in val)
  raise TypeError('Can\'t convert to XLA: {}'.format(val))
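# Usage sketch for normalize_to_xla_dtypes (illustrative only; assumes the
# default configuration where 64-bit mode is disabled, so float64
# canonicalizes down to float32):
#
#   normalize_to_xla_dtypes(onp.arange(3.0)).dtype  # dtype('float32')
#   normalize_to_xla_dtypes((1, 2.0))               # tuple of canonicalized arrays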
class _JaxComputationBuilder(xla_client.ComputationBuilder):
  """Base class implementing all of JaxComputationBuilder.

  This class is intended to override and augment the interface of an XLA
  ComputationBuilder to form JaxComputationBuilder.
  """

  # Method name case follows that of the XLA ComputationBuilder
  # pylint: disable=invalid-name

  def Build(self, *args, **kwargs):
    return super(_JaxComputationBuilder, self).Build(
        *args, **kwargs)

  def NumpyArrayConstant(self, value, canonicalize_types=True):
    if canonicalize_types:
      value = normalize_to_xla_dtypes(value)
    return super(_JaxComputationBuilder, self).Constant(value)

  def ConstantLike(self, example_value, value, canonicalize_types=True):
    example_value = onp.asarray(example_value)
    return self.Constant(onp.array(value, dtype=example_value.dtype))

  def Constant(self, py_val, canonicalize_types=True):
    """Translate constant `py_val` to a constant for this ComputationBuilder.

    Args:
      py_val: a Python value to be translated to a constant.

    Returns:
      A representation of the constant, either a ComputationDataHandle or None.
    """
    py_type = type(py_val)
    if py_type in _constant_handlers:
      return _constant_handlers[py_type](self, py_val, canonicalize_types)
    else:
      raise TypeError("No constant handler for type: {}".format(py_type))

  # TODO(mattjj): remove when CrossReplicaSum is added to XLA:CPU
  def CrossReplicaSum(self, operand, replica_groups):
    """Workaround for CrossReplicaSum not being implemented on some backends."""
    if len(replica_groups[0]) == 1:
      return operand
    else:
      return super(_JaxComputationBuilder, self).CrossReplicaSum(
          operand, replica_groups)

  # TODO(mattjj): remove when AllToAll is added to XLA:CPU
  def AllToAll(self, operand, split_axis, concat_axis, replica_groups):
    """Workaround for AllToAll not being implemented on some backends."""
    if len(replica_groups[0]) == 1:
      return operand
    else:
      return super(_JaxComputationBuilder, self).AllToAll(
          operand, split_axis, concat_axis, replica_groups)
def make_computation_builder(name):
  return _JaxComputationBuilder(name)
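# Usage sketch for the builder (illustrative only; the builder name is
# arbitrary and the exact methods available depend on the installed
# xla_client version):
#
#   c = make_computation_builder('example')
#   c.Constant(onp.ones((2, 2), onp.float32))  # dispatched to a registered handler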
def register_constant_handler(type_, handler_fun):
  _constant_handlers[type_] = handler_fun

_constant_handlers: Dict[type, Callable] = {}
def _ndarray_constant_handler(c, val, canonicalize_types=True):
  """Constant handler for ndarray literals, handling zero-size strides.

  This function essentially calls c.NumpyArrayConstant(val) except it has
  special handling of arrays with any strides of size zero: for those, it
  generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
  to avoid staging in large literals that might arise from np.zeros or np.ones
  or the output of lax.broadcast (which uses onp.broadcast_to which in turn
  uses size-zero strides).

  Args:
    c: XLA client ComputationBuilder.
    val: an ndarray.

  Returns:
    An XLA ComputationDataHandle / XlaOp representing the constant ndarray
    staged into the XLA Computation.
  """
  # TODO(mattjj): revise this to use c.BroadcastInDim rather than Transpose
  if onp.any(onp.equal(0, val.strides)) and val.size > 0:
    zero_stride_axes, = onp.where(onp.equal(0, val.strides))
    other_axes, = onp.where(onp.not_equal(0, val.strides))
    collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
                              for ax in range(val.ndim))]
    xla_val = c.Broadcast(
        c.NumpyArrayConstant(collapsed_val, canonicalize_types),
        onp.take(val.shape, zero_stride_axes))
    permutation = onp.argsort(tuple(zero_stride_axes) + tuple(other_axes))
    return c.Transpose(xla_val, permutation)
  else:
    return c.NumpyArrayConstant(val, canonicalize_types)
register_constant_handler(onp.ndarray, _ndarray_constant_handler)
def _scalar_constant_handler(c, val, canonicalize_types=True):
  return c.NumpyArrayConstant(val, canonicalize_types)

for scalar_type in [onp.int8, onp.int16, onp.int32, onp.int64,
                    onp.uint8, onp.uint16, onp.uint32, onp.uint64,
                    onp.float16, onp.float32, onp.float64, onp.float128,
                    onp.bool_, onp.longlong]:
  register_constant_handler(scalar_type, _scalar_constant_handler)
def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
  return c.NumpyArrayConstant(dtype.type(val))

for ptype, dtype in dtypes.python_scalar_dtypes.items():
  register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
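# Usage sketch for register_constant_handler (illustrative only; the Fraction
# type and its float32 conversion are hypothetical, not handlers that JAX
# itself registers):
#
#   from fractions import Fraction
#   def _fraction_handler(c, val, canonicalize_types=True):
#     return c.NumpyArrayConstant(onp.float32(float(val)))
#   register_constant_handler(Fraction, _fraction_handler)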