-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
xla.py
1477 lines (1247 loc) · 56.7 KB
/
xla.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict, deque
import itertools as it
import operator as op
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
Tuple, Union, NamedTuple)
from warnings import warn
from absl import logging
import numpy as np
from ..config import flags, bool_env, config
from .. import core
from .. import ad_util
from .. import dtypes
from .. import lazy
from .. import linear_util as lu
from jax._src import source_info_util
from ..abstract_arrays import (make_shaped_array, array_types)
from ..core import (ConcreteArray, ShapedArray, AbstractToken,
Literal, pp_eqn_compact, raise_to_shaped, abstract_token)
from jax._src.pprint_util import pp
from ..util import (partial, partialmethod, cache, prod, unzip2,
extend_name_stack, wrap_name, safe_zip, safe_map)
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from . import partial_eval as pe
from . import ad
from . import masking
# Rebind map/zip to safe_map/safe_zip from ..util for the rest of this module
# (presumably length-checking variants; see ..util). The right-hand side is
# evaluated before assignment, so unsafe_map/unsafe_zip keep the builtins.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# Short aliases for the XLA extension module and its op constructors.
xe = xc._xla
xops = xc._xla.ops
# Types
# Loose aliases used only in annotations; all of them are just Any.
Backend = Any  # xc.LocalBackend (why does mypy not like this?)
Device = Any  # xc.Device
PyLocalBuffer = Any
XlaOp = Any  # xla_extension.XlaOp
XlaShape = Any  # xla_client.Shape
XlaComputationBuilder = Any  # xla_bridge._JaxComputationBuilder
XlaExecutable = Any  # xla_extension.LocalExecutable

FLAGS = flags.FLAGS
# Read by _execute_compiled_primitive (and asserted in _xla_call_impl) to
# decide whether outputs are checked for NaNs.
flags.DEFINE_bool('jax_debug_nans',
                  bool_env('JAX_DEBUG_NANS', False),
                  'Add nan checks to every operation.')
# Read by _xla_callable: when set, the "Compiling ..." message is logged at
# WARNING rather than DEBUG.
flags.DEFINE_bool('jax_log_compiles',
                  bool_env('JAX_LOG_COMPILES', False),
                  'Print a message each time a `jit` computation is compiled.')

# This flag is set on exit; no logging should be attempted
# (checked in _xla_callable before logging a compilation).
_on_exit = False
def identity(x):
  """Return ``x`` unchanged (used as a no-op canonicalization handler)."""
  result = x
  return result
# Python scalar types that have an entry in dtypes.python_scalar_dtypes; the
# handler tables below register one entry per type.
_scalar_types = dtypes.python_scalar_dtypes.keys()

# unit representation: core.unit is lowered to a boolean scalar zero.


def _make_unit_constant(c):
  """Emit into builder ``c`` the constant standing for ``core.unit``."""
  false_scalar = np.zeros((), dtype=np.dtype('bool'))
  return xb.constant(c, false_scalar)


def _make_unit_shape(_):
  """XLA shapes for the unit aval: a single boolean scalar shape."""
  scalar_bool_shape = xc.Shape.array_shape(np.dtype('bool'), ())
  return (scalar_bool_shape,)


def _device_put_unit(_, device):
  """Place the unit representation on ``device``; returns a 1-tuple of buffers."""
  target_backend = xb.get_device_backend(device)
  false_scalar = np.zeros((), dtype=np.dtype('bool'))
  return (target_backend.buffer_from_pyval(false_scalar, device),)
def _make_array_shape(a):
  """Return a 1-tuple holding the XLA array shape for aval ``a``.

  float0 avals are represented as bool arrays of the same shape.
  """
  element_dtype = np.dtype('bool') if a.dtype is dtypes.float0 else a.dtype
  return (xc.Shape.array_shape(element_dtype, a.shape),)
### handlers

xb.register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c))


def aval_to_xla_shapes(aval):
  """Return the XLA shapes used to represent abstract value ``aval``.

  Dispatches on the exact type of ``aval`` via ``xla_shape_handlers``. The
  handlers registered in this file return tuples of shapes.

  Raises:
    TypeError: if no handler is registered for ``type(aval)``.
  """
  # Only the lookup is guarded: a KeyError raised *inside* a handler is a bug
  # in that handler and must not be reported as a missing registration.
  try:
    handler = xla_shape_handlers[type(aval)]
  except KeyError as err:
    raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err
  return handler(aval)


# Maps abstract-value types to functions returning their XLA shapes.
xla_shape_handlers: Dict[Type[core.AbstractValue], Callable] = {
    core.AbstractUnit: _make_unit_shape,
    ShapedArray: _make_array_shape,
    ConcreteArray: _make_array_shape,
}
def aval_to_result_handler(device: Optional[Device], aval: core.AbstractValue) -> Callable:
  """Return the function that turns output buffer(s) for ``aval`` into a value.

  Dispatches on the exact type of ``aval`` via ``xla_result_handlers``.

  Raises:
    TypeError: if no result handler is registered for ``type(aval)``.
  """
  # Guard only the lookup, so a KeyError raised by the handler factory itself
  # propagates instead of being misreported as a missing registration.
  try:
    make_handler = xla_result_handlers[type(aval)]
  except KeyError as err:
    raise TypeError(f"No xla_result_handler for type: {type(aval)}") from err
  return make_handler(device, aval)


def array_result_handler(device: Optional[Device], aval: core.ShapedArray):
  """Result handler factory for array avals.

  For float0 avals the returned function ignores its buffer and returns a
  numpy zeros array of dtype float0. Otherwise it returns a partial
  application of ``make_device_array`` for ``aval`` on ``device``.
  """
  if aval.dtype is dtypes.float0:
    return lambda _: np.zeros(aval.shape, dtypes.float0)
  return partial(make_device_array, raise_to_shaped(aval), device,
                 lazy.array(aval.shape))


# Maps abstract-value types to factories ``(device, aval) -> handler``.
xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = {
    core.AbstractUnit: lambda _, __: lambda _: core.unit,
    ShapedArray: array_result_handler,
    ConcreteArray: array_result_handler,
}
def device_put(x, device: Optional[Device] = None) -> Tuple[Any, ...]:
  """Transfer ``x`` to ``device`` and return the resulting buffer(s).

  Dispatches on the exact type of ``x`` via ``device_put_handlers``. ``x`` is
  dtype-canonicalized before being handed to its handler. Objects without a
  handler that define ``__jax_array__`` are converted through it first.

  Args:
    x: the value to transfer.
    device: target device, or ``None`` (presumably the default device).

  Returns:
    A tuple of device buffers. The annotation is variadic because callers
    treat the result as a sequence of buffers (they flatten it with
    ``it.chain``), not as a fixed 1-tuple.

  Raises:
    TypeError: if no handler applies to ``type(x)``.
  """
  handler = device_put_handlers.get(type(x))
  if handler:
    x = canonicalize_dtype(x)
    return handler(x, device)
  elif hasattr(x, '__jax_array__'):
    return device_put(x.__jax_array__(), device)
  else:
    raise TypeError(f"No device_put handler for type: {type(x)}")


def _device_put_array(x, device: Optional[Device]):
  """Copy array ``x`` to ``device``; float0 arrays are sent as bool zeros."""
  backend = xb.get_device_backend(device)
  if x.dtype is dtypes.float0:
    x = np.zeros(x.shape, dtype=np.dtype(bool))
  return (backend.buffer_from_pyval(x, device),)


def _device_put_scalar(x, device):
  """Copy Python scalar ``x`` to ``device`` by first coercing it to an array."""
  return _device_put_array(dtypes.coerce_to_array(x), device)


device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any, ...]]] = {
    core.Unit: _device_put_unit
}
device_put_handlers.update((t, _device_put_array) for t in array_types)
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
# TODO(mattjj): try to remove this canonicalize_dtype stuff
def canonicalize_dtype(x):
  """Convert ``x`` to a value whose dtype is canonical.

  The handler is looked up by exact ``type(x)`` first, then along the MRO.
  Objects without a handler that define ``__jax_array__`` are converted
  through it.

  Raises:
    TypeError: if no handler applies.
  """
  cls = type(x)
  handler = canonicalize_dtype_handlers.get(cls)
  if not handler:
    for base in cls.mro():
      handler = canonicalize_dtype_handlers.get(base)
      if handler:
        break
  if handler:
    return handler(x)
  if hasattr(x, '__jax_array__'):
    return canonicalize_dtype(x.__jax_array__())
  raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")


def _canonicalize_ndarray_dtype(x):
  """Return ``x`` as a numpy array with its canonicalized result dtype."""
  target = dtypes.canonicalize_dtype(dtypes.result_type(x))
  return np.asarray(x, target)


def _canonicalize_python_scalar_dtype(typ, x):
  """Return Python scalar ``x`` (of type ``typ``) as an array of canonical dtype."""
  target = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[typ])
  return np.asarray(x, target)


canonicalize_dtype_handlers: Dict[Any, Callable] = {core.Unit: identity}
canonicalize_dtype_handlers.update(
    {t: _canonicalize_ndarray_dtype for t in array_types})
canonicalize_dtype_handlers.update(
    {t: partial(_canonicalize_python_scalar_dtype, t) for t in _scalar_types})
def abstractify(x) -> core.AbstractValue:
  """Return the abstract value (aval) describing ``x``.

  Looks up ``pytype_aval_mappings`` by exact type, then along the MRO, then
  falls back to ``__jax_array__``.

  Raises:
    TypeError: if ``x`` is not a valid JAX type.
  """
  cls = type(x)
  to_aval = pytype_aval_mappings.get(cls)
  if not to_aval:
    for base in cls.mro():
      to_aval = pytype_aval_mappings.get(base)
      if to_aval:
        break
  if to_aval:
    return to_aval(x)
  if hasattr(x, '__jax_array__'):
    return abstractify(x.__jax_array__())
  raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")


def _make_abstract_python_scalar(typ, _):
  """Weakly-typed scalar ShapedArray for a Python scalar of type ``typ``."""
  scalar_dtype = dtypes.python_scalar_dtypes[typ]
  return ShapedArray((), scalar_dtype, weak_type=True)


pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {
    core.Unit: lambda _: core.abstract_unit,
}
pytype_aval_mappings.update({t: make_shaped_array for t in array_types})
pytype_aval_mappings.update(
    {t: partial(_make_abstract_python_scalar, t) for t in _scalar_types})
# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap; we can
# remove it once we bring the id_tap implementation into the core.
outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None


def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
  """Apply the installed ``outfeed_rewriter`` to ``jaxpr``; identity if unset."""
  rewriter = outfeed_rewriter
  return jaxpr if rewriter is None else rewriter(jaxpr)
# Primitives treated as outfeed-producing; presumably populated by the
# modules that define them.
outfeed_primitives: Set[core.Primitive] = set()


def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
  """Finds if there are outfeed primitives anywhere inside a Jaxpr."""
  for eqn in jaxpr.eqns:
    if primitive_uses_outfeed(eqn.primitive, eqn.params):
      return True
  return False


def _param_uses_outfeed(param):
  """True if ``param`` is a (Closed)Jaxpr that uses outfeed; else False."""
  if type(param) is core.Jaxpr:
    return jaxpr_uses_outfeed(param)
  if type(param) is core.ClosedJaxpr:
    return jaxpr_uses_outfeed(param.jaxpr)
  return False


def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
  """True if ``prim`` is an outfeed primitive or a jaxpr param uses outfeed.

  Tuple-valued params are searched element-wise (one level deep).
  """
  if prim in outfeed_primitives:
    return True
  for value in params.values():
    candidates = value if isinstance(value, tuple) else (value,)
    if any(_param_uses_outfeed(p) for p in candidates):
      return True
  return False
### op-by-op execution

def arg_spec(x):
  """Return the ``(aval, device)`` spec used to key compilation caches.

  ``device`` is ``x._device`` when ``x`` has that attribute and ``None``
  otherwise.
  """
  aval = abstractify(x)
  # Catch only the missing-attribute case. The previous bare ``except:`` also
  # swallowed KeyboardInterrupt/SystemExit and real bugs in a ``_device``
  # accessor.
  try:
    return aval, x._device
  except AttributeError:
    return aval, None
def apply_primitive(prim, *args, **params):
  """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
  specs = [arg_spec(a) for a in args]
  compiled_fun = xla_primitive_callable(prim, *specs, **params)
  return compiled_fun(*args)
def _partition_outputs(avals, outs):
  """Split the flat buffer sequence ``outs`` into one list per aval.

  Each aval consumes ``aval._num_buffers`` consecutive entries of ``outs``.
  """
  counts = [aval._num_buffers for aval in avals]
  if not core.skip_checks:
    assert sum(counts) == len(outs), f"Internal error: sum(nouts)={sum(counts)} should equal len(outs)={len(outs)}."
  flat = iter(outs)
  grouped = []
  for n in counts:
    grouped.append([next(flat) for _ in range(n)])
  return grouped
@cache()
def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
                                                    Optional[Device]], **params):
  """Compile ``prim`` for the given argument specs and return an executor.

  Memoized via ``@cache()`` (from ..util).

  Args:
    prim: the primitive to compile.
    *arg_specs: one ``(aval, device_or_None)`` pair per argument, as produced
      by ``arg_spec``.
    **params: the primitive's parameters.

  Returns:
    A callable taking the concrete arguments. It executes the compiled
    computation and returns its outputs wrapped by the result handler(s).
    Primitives that use outfeed are instead routed through ``_xla_callable``.

  Raises:
    ValueError: if the arguments live on different devices, or the
      computation needs more replicas than the backend has devices.
  """
  avals, arg_devices = unzip2(arg_specs)
  # A single primitive never donates its input buffers.
  donated_invars = (False,) * len(arg_specs)
  device = _device_from_arg_devices(arg_devices)
  backend = xb.get_device_backend(device)
  if primitive_uses_outfeed(prim, params):
    # We use the _xla_callable path, where we pre-process the primitives
    def prim_fun(*args):
      return prim.bind(*args, **params)
    return _xla_callable(lu.wrap_init(prim_fun), device, None, "prim", donated_invars,
                         *arg_specs)
  aval_out = prim.abstract_eval(*avals, **params)
  if not prim.multiple_results:
    handle_result = aval_to_result_handler(device, aval_out)
  else:
    handlers = map(partial(aval_to_result_handler, device), aval_out)
    # Regroup the flat output buffers per output aval, then apply each
    # aval's handler to its group.
    handle_result = lambda *bufs:\
      tuple(handler(*bs) for handler, bs in zip(handlers, _partition_outputs(aval_out, bufs)))
  # Long argument lists are passed to XLA as a single tuple parameter.
  tuple_args = len(avals) > 100
  # Initial-style primitives carry jaxpr params that may need several
  # replicas; everything else runs on one.
  if prim in initial_style_translations:
    nreps = initial_style_primitive_replicas(params)
  else:
    nreps = 1

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling a primitive computation `{prim}` that requires {nreps} "
        f"replicas, but only {xb.device_count(backend)} XLA devices are "
        f"available on backend {backend.platform}.")
  built_c = primitive_computation(prim, AxisEnv(nreps, (), ()), backend,
                                  tuple_args, *avals, **params)
  # device_assignment is None when no argument pinned a device.
  options = xb.get_compile_options(
      num_replicas=nreps,
      num_partitions=1,
      device_assignment=device and (device.id,))
  options.parameter_is_tupled_arguments = tuple_args
  compiled = backend_compile(backend, built_c, options)
  if nreps == 1:
    return partial(_execute_compiled_primitive, prim, compiled, handle_result)
  else:
    return partial(_execute_replicated_primitive, prim, compiled, handle_result)
def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[Device]:
  """Given devices of inputs, determine where to perform a computation.

  Args:
    devices: list where each element is either a `Device` instance or `None`.

  Returns:
    The single device shared by all non-``None`` entries, or ``None`` when
    every entry is ``None``.

  Raises:
    ValueError: if the non-``None`` entries name more than one device.
  """
  distinct = {d for d in devices if d is not None}
  candidates = distinct or (None,)
  try:
    (device,) = candidates
  except ValueError as err:
    msg = "primitive arguments must be colocated on the same device, got {}"
    raise ValueError(msg.format(", ".join(map(str, devices)))) from err
  return device
@cache()
def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params):
  """Build (but do not compile) an XLA computation applying ``prim`` to ``avals``.

  Memoized via ``@cache()`` (from ..util).

  Args:
    prim: the primitive to lower.
    axis_env: ``AxisEnv`` forwarded to initial-style translation rules.
    backend: backend passed to ``xb.get_backend`` to pick the platform.
    tuple_args: whether inputs are passed as a single tuple parameter.
    *avals: abstract values of the inputs.
    **params: the primitive's parameters.

  Returns:
    The built XLA computation.

  Raises:
    NotImplementedError: if no translation rule exists for ``prim``.
    RuntimeError: if XLA rejects the built computation (reported as a JAX
      shape-rule bug).
  """
  c = xb.make_computation_builder(f"primitive_computation_{prim.name}")
  c.set_op_metadata(xc.OpMetadata(
      op_type=prim.name,
      op_name=str(pp_eqn_compact(prim.name, params))))
  platform = xb.get_backend(backend).platform
  xla_args, _ = _xla_callable_args(c, avals, tuple_args)
  # return val always set as a side-effect on c
  # Rule lookup order: backend-specific, plain, then initial-style.
  if prim in backend_specific_translations[platform]:
    rule = backend_specific_translations[platform][prim]
    ans = rule(c, *xla_args, **params)
  elif prim in translations:
    rule = translations[prim]
    ans = rule(c, *xla_args, **params)
  elif prim in initial_style_translations:
    rule = initial_style_translations[prim]
    ans = rule(c, axis_env, extend_name_stack(prim.name), avals, backend,
               *xla_args, **params)
  else:
    raise NotImplementedError(f"XLA translation rule for {prim} not found")
  assert isinstance(ans, xe.XlaOp)
  c.clear_op_metadata()
  try:
    return c.build()
  except RuntimeError as e:
    msg = (" ".join(map(str, e.args)) + "\n"
           "This is a bug in JAX's shape-checking rules; please report it!\n"
           "https://github.com/google/jax/issues\n")
    raise RuntimeError(msg) from e
def primitive_subcomputation(prim, *avals, **params):
  """Build a single-replica, untupled computation for ``prim``.

  Passes ``backend=None`` to ``primitive_computation`` (presumably the
  default backend).
  """
  return primitive_computation(prim, AxisEnv(1, (), ()), None, False,
                               *avals, **params)
def backend_compile(backend, built_c, options):
  """Compile ``built_c`` on ``backend`` using ``options``.

  Kept as its own function so that XLA compilation shows up as a separate
  entry in Python profiling results.
  """
  executable = backend.compile(built_c, compile_options=options)
  return executable
def _execute_compiled_primitive(prim, compiled, result_handler, *args):
  """Run a single-device compiled primitive on ``args`` and wrap its outputs.

  Token arguments are skipped rather than transferred. When
  ``FLAGS.jax_debug_nans`` is set, the outputs are checked for NaNs.
  """
  device, = compiled.local_devices()
  input_bufs = []
  for arg in args:
    if arg is token:
      continue
    input_bufs.extend(device_put(arg, device))
  out_bufs = compiled.execute(input_bufs)
  if FLAGS.jax_debug_nans:
    check_nans(prim, out_bufs)
  return result_handler(*out_bufs)
def _execute_replicated_primitive(prim, compiled, result_handler, *args):
  """Run a multi-replica compiled primitive and wrap the first replica's outputs.

  The same non-token arguments are transferred to every local device.
  ``prim`` is unused here; the signature matches
  ``_execute_compiled_primitive``.
  """
  real_args = [a for a in args if a is not token]
  per_device_bufs = []
  for device in compiled.local_devices():
    per_device_bufs.append(
        list(it.chain.from_iterable(device_put(a, device) for a in real_args)))
  first_replica_outs = compiled.execute_on_local_devices(per_device_bufs)[0]
  return result_handler(*first_replica_outs)
def check_nans(prim, bufs):
  """Raise FloatingPointError if any floating-point buffer in ``bufs`` has a NaN."""
  for buf in bufs:
    # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
    # the default.
    get_shape = getattr(buf, "xla_shape", buf.shape)
    _check_nans(prim.name, get_shape(), buf)


def _check_nans(name, xla_shape, buf):
  """Check one non-tuple buffer; non-inexact element types are skipped."""
  assert not xla_shape.is_tuple()
  if not dtypes.issubdtype(xla_shape.element_type(), np.inexact):
    return
  if np.any(np.isnan(buf.to_py())):
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
### compiling jaxprs

def prefetch(x):
  """Start an async device-to-host copy if ``x`` is a DeviceArray; return ``x``."""
  if not isinstance(x, DeviceArray):
    return x
  x.copy_to_host_async()
  return x
def jaxpr_literals(jaxpr):
  """Generates all the literals inside a jaxpr, including nested subjaxprs.

  Only equation inputs are inspected. Values of the top-level jaxpr's
  equations come first, then those of each subjaxpr (recursively).
  """
  for eqn in jaxpr.eqns:
    yield from (v.val for v in eqn.invars if type(v) is core.Literal)
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from jaxpr_literals(subjaxpr)
def _flatmap(func: Callable, vars: Sequence):
return list(it.chain.from_iterable(map(func, vars)))
def _partitionmap(func: Callable, vars: Sequence, nodes: Sequence):
return map(func, vars, _partition_outputs([v.aval for v in vars], nodes))
def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
  """Lower ``jaxpr`` into XLA builder ``c`` and return its output ops.

  Args:
    c: the XLA computation builder the ops are emitted into.
    jaxpr: the jaxpr to lower.
    backend: a platform name ('cpu'/'gpu'/'tpu') or anything accepted by
      ``xb.get_backend`` (possibly None). It selects backend-specific rules
      and is forwarded to initial-style and call rules.
    axis_env: ``AxisEnv`` forwarded to initial-style, parallel and call rules.
    consts: XLA ops for ``jaxpr.constvars``, flattened per buffer.
    name_stack: name-stack prefix recorded in each op's metadata.
    *args: XLA ops for ``jaxpr.invars``, flattened per buffer.

  Returns:
    A flat list of XLA ops for ``jaxpr.outvars`` (one per buffer).

  Raises:
    NotImplementedError: if no translation rule is registered for a primitive.
  """
  if backend not in ('cpu', 'gpu', 'tpu'):
    platform = xb.get_backend(backend).platform  # canonicalize
  else:
    platform = backend

  # env maps each variable to a *list* of XLA ops (one per buffer of its
  # aval); literals are materialized as constants on each read.
  def read(v):
    if type(v) is Literal:
      return [xb.constant(c, canonicalize_dtype(v.val))]
    else:
      return env[v]

  def aval(v):
    if type(v) is Literal:
      return abstractify(v.val)
    else:
      return v.aval

  def write(v, node):
    assert node is not None
    env[v] = node

  env = {}
  _partitionmap(write, [core.unitvar], [_make_unit_constant(c)])
  _partitionmap(write, jaxpr.constvars, consts)
  _partitionmap(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    # Tag every op emitted for this equation with the primitive name, the
    # name stack and the user's source location (when known).
    frame = source_info_util.user_frame(eqn.source_info)
    c.set_op_metadata(xc.OpMetadata(
        op_type=eqn.primitive.name,
        op_name=str(pp(name_stack) >> pp_eqn_compact(
            eqn.primitive.name, eqn.params)),
        source_file=frame.file_name if frame else None,
        source_line=frame.line_num if frame else None))
    in_nodes = _flatmap(read, eqn.invars)
    # Rule lookup order: backend-specific, plain, initial-style, parallel,
    # call. Each kind receives a different set of extra arguments.
    if eqn.primitive in backend_specific_translations[platform]:
      rule = backend_specific_translations[platform][eqn.primitive]
      ans = rule(c, *in_nodes, **eqn.params)
    elif eqn.primitive in translations:
      ans = translations[eqn.primitive](c, *in_nodes, **eqn.params)
    elif eqn.primitive in initial_style_translations:
      new_params = check_backend_params(eqn.params, backend)
      rule = initial_style_translations[eqn.primitive]
      ans = rule(c, axis_env, extend_name_stack(name_stack, eqn.primitive.name),
                 map(aval, eqn.invars), backend, *in_nodes, **new_params)
    elif eqn.primitive in parallel_translations:
      rule = parallel_translations[eqn.primitive]
      ans = rule(c, *in_nodes, axis_env=axis_env, platform=platform, **eqn.params)
    elif eqn.primitive in call_translations:
      new_params = check_backend_params(eqn.params, backend)
      rule = call_translations[eqn.primitive]
      ans = rule(c, axis_env, in_nodes,
                 name_stack, backend=backend, **new_params)
    else:
      raise NotImplementedError(
          f"XLA translation rule for primitive '{eqn.primitive.name}' not found")

    assert isinstance(ans, xe.XlaOp)
    c.get_shape(ans)  # force xla to do shape error checking
    # Tuple-valued results are split into their elements before being bound
    # to the equation's output variables.
    if eqn.primitive.multiple_results or any(v.aval._num_buffers > 1 for v in eqn.outvars):
      out_nodes = xla_destructure(c, ans)
    else:
      out_nodes = [ans]
    c.clear_op_metadata()
    _partitionmap(write, eqn.outvars, out_nodes)
  return _flatmap(read, jaxpr.outvars)
def xla_destructure(c, ans):
  """Return one GetTupleElement op per element of the tuple-shaped op ``ans``."""
  element_count = len(c.get_shape(ans).tuple_shapes())
  elements = []
  for index in range(element_count):
    elements.append(xops.GetTupleElement(ans, index))
  return elements
def check_backend_params(params, outer_backend):
  """Return ``params`` without its 'backend' entry, after validating it.

  For nested calls, the outermost call sets the backend for all inner calls;
  it's an error if the inner call has a conflicting explicit backend spec.
  A falsy inner 'backend' (missing, None, '') counts as unspecified.

  Raises:
    ValueError: if an explicit inner backend differs from ``outer_backend``.
  """
  inner_backend = params.get('backend', None)
  conflicting = bool(inner_backend) and inner_backend != outer_backend
  if conflicting:
    raise ValueError(
        f"Outer-jit backend specification {outer_backend} must match explicit "
        f"inner-jit backend specification {inner_backend}.")
  return {key: value for key, value in params.items() if key != 'backend'}
class AxisEnv(NamedTuple):
  """Represents a pmap mesh (only along the replica axes)."""
  nreps: int               # number of replicas
  names: Tuple[Any, ...]   # bound axis names, in binding order
  sizes: Tuple[int, ...]   # size of each axis, parallel to `names`


def extend_axis_env(env: AxisEnv, name, size: int):
  """Return a new AxisEnv with axis ``name`` of size ``size`` bound last."""
  return AxisEnv(nreps=env.nreps,
                 names=env.names + (name,),
                 sizes=env.sizes + (size,))


def axis_read(axis_env, axis_name):
  """Return the position of the last-bound axis called ``axis_name``.

  Raises:
    NameError: if ``axis_name`` is not bound in ``axis_env``.
  """
  try:
    return max(idx for idx, bound in enumerate(axis_env.names)
               if bound == axis_name)
  except ValueError:
    raise NameError(f"unbound axis name: {axis_name}") from None
def axis_groups(axis_env: AxisEnv, name):
  """Replica groups for a collective over axis ``name`` (or a list/tuple of names).

  Replicas not covered by the named axes (``nreps // prod(sizes)`` of them)
  form an extra trailing mesh dimension. ``nreps`` must divide evenly.
  """
  names = name if isinstance(name, (list, tuple)) else (name,)
  mesh_axes = tuple(axis_read(axis_env, n) for n in names)
  trailing_size, ragged = divmod(axis_env.nreps, prod(axis_env.sizes))
  assert not ragged
  return _axis_groups(axis_env.sizes + (trailing_size,), mesh_axes)


def _axis_groups(mesh_spec, mesh_axes):
  """Computes replica group ids for a collective performed over a subset of the mesh.

  Args:
    mesh_spec: A sequence of integers representing the mesh shape.
    mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
      indicating over which axes the collective is performed.

  Returns:
    A tuple of replica groups (i.e. tuples containing replica ids).
  """
  replica_ids = np.arange(prod(mesh_spec)).reshape(mesh_spec)
  # Move the collective axes to the front and flatten them into rows, so each
  # column of the result holds the replica ids of one group.
  front = np.moveaxis(replica_ids, mesh_axes, np.arange(len(mesh_axes)))
  group_size = prod(np.take(mesh_spec, mesh_axes))
  by_group = np.reshape(front, (group_size, -1)).T
  return tuple(tuple(group) for group in by_group)
def jaxpr_replicas(jaxpr):
  """The number of replicas needed for a jaxpr.

  For a eqn, multiply the `axis_size` with the `jaxpr_replicas` of the
  subjaxprs. For a list of eqns, take the maximum number of replicas
  (1 when there are no equations).
  """
  return max((eqn_replicas(eqn) for eqn in jaxpr.eqns), default=1)


# TODO(mattjj): this function assumes that only pmap has a parameter named
# axis_size, and that it corresponds to cross-replica mapping
def eqn_replicas(eqn):
  """Replicas needed by a single equation (1 for ordinary primitives)."""
  call_jaxpr = eqn.params.get("call_jaxpr")
  if call_jaxpr:
    axis_size = eqn.params.get('axis_size', 1)
    return axis_size * jaxpr_replicas(call_jaxpr)
  if eqn.primitive in initial_style_translations:
    return initial_style_primitive_replicas(eqn.params)
  return 1


def initial_style_primitive_replicas(params):
  """Max replicas over all jaxprs found in ``params`` (1 if there are none)."""
  counts = core.traverse_jaxpr_params(jaxpr_replicas, params)
  return max(counts, default=1)
# TODO(mattjj,skyewm): the functions here are utilities for checking if
# not-yet-supported features are used with multi-host programming
def jaxpr_has_pmap(jaxpr):
  """Whether there is an xla_pmap primitive anywhere inside a Jaxpr."""
  if any('xla_pmap' in eqn.primitive.name for eqn in jaxpr.eqns):
    return True
  return any(jaxpr_has_pmap(sub) for sub in core.subjaxprs(jaxpr))


def jaxpr_collectives(jaxpr):
  """Generates all the collective primitives anywhere inside a Jaxpr.

  A primitive counts as collective when it has an entry in
  ``parallel_translations``.
  """
  yield from (eqn.primitive for eqn in jaxpr.eqns
              if eqn.primitive in parallel_translations)
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from jaxpr_collectives(subjaxpr)
### xla_call underlying jit

def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
  """Impl rule for xla_call: compile ``fun`` (cached) and run it on ``args``.

  If the compiled function raises FloatingPointError (which can only happen
  with jax_debug_nans enabled), ``fun`` is re-run without compilation.
  """
  specs = [arg_spec(a) for a in args]
  compiled_fun = _xla_callable(fun, device, backend, name, donated_invars, *specs)
  try:
    return compiled_fun(*args)
  except FloatingPointError:
    assert FLAGS.jax_debug_nans  # compiled_fun can only raise in this case
    print("Invalid value encountered in the output of a jit function. "
          "Calling the de-optimized version.")
    # We want to run the wrapped function again (after _xla_callable already
    # ran it), but linear_util.WrappedFun instances are meant to be run only
    # once. Besides re-executing the Python code (usually undesirable, but
    # what FLAGS.jax_debug_nans opts into), we re-execute any
    # linear_util.py-style side effects, i.e. re-populate Stores created by
    # any transformation_with_aux's applied to fun. Since that is intentional
    # here, the stores are reset to empty to avoid "Store occupied" errors.
    for store in fun.stores:
      if store:
        store.reset()
    return fun.call_wrapped(*args)  # probably won't return
def flatten_shape(s: XlaShape) -> Sequence[Tuple[Sequence[int], XlaShape]]:
  """Expands a given shape tree into a flat list of indices to arrays.

  Given the following computation:

  >>> c = xc.XlaBuilder("example")
  >>> p0 = xb.parameter(c, 1, xc.shape_from_pyval(jnp.ones([1])))
  >>> p1 = xb.parameter(c, 2, xc.shape_from_pyval(jnp.ones([2])))
  >>> p2 = xb.parameter(c, 3, xc.shape_from_pyval(jnp.ones([3])))
  >>> o = xops.Tuple(c, [p0, p1, p2])

  We can query the arrays in the output tuple:

  >>> flatten_shape(c.GetShape(o))
  (((0,), f32[1]{0}),
   ((1,), f32[2]{0}),
   ((2,), f32[3]{0}))

  Or the arrays in one of the parameters (which is itself an array):

  >>> flatten_shape(c.GetShape(p0))
  (((), f32[1]{0}),)

  Args:
    s: The input shape.

  Returns:
    A tuple of ``(index, shape)`` pairs, one per leaf within the shape tree.
    Non-tuple leaves inside a tuple are yielded as-is.
  """
  def visit(shape, index):
    if shape.is_array():
      yield index, shape
      return
    assert shape.is_tuple()
    for i, child in enumerate(shape.tuple_shapes()):
      child_index = index + (i,)
      if child.is_tuple():
        yield from visit(child, child_index)
      else:
        yield child_index, child

  return tuple(visit(s, ()))
def _xla_consts(c, consts):
unique_consts = {id(const): const for const in consts}
xla_consts = {
id_: xb.constant(c, const) for id_, const in unique_consts.items()}
return [xla_consts[id(const)] for const in consts]
@lu.cache
def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *arg_specs):
  """Trace, lower and compile ``fun`` for the given argument specs.

  Memoized via ``@lu.cache``.

  Args:
    fun: the wrapped function to compile.
    device: optional explicit device.
    backend: optional explicit backend name.
    name: function name used in the XLA name stack (wrapped as 'jit').
    donated_invars: per-argument flags for inputs whose buffers may be
      aliased to outputs (aliasing is only set up on GPU/TPU).
    *arg_specs: one ``(aval, device_or_None)`` pair per argument.

  Returns:
    A callable taking the concrete arguments. For jaxprs with no equations
    this is ``_execute_trivial`` (nothing is compiled). Otherwise it runs the
    compiled computation, replicated when the jaxpr needs more than one
    replica.

  Raises:
    ValueError: if both ``device`` and ``backend`` are given, or the
      computation needs more replicas than available devices.
    core.UnexpectedTracerError: if tracing leaked tracers into the constants
      (omnistaging path).
    NotImplementedError: for jit-of-pmap with more than one host.
  """
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))

  abstract_args, arg_devices = unzip2(arg_specs)
  if config.omnistaging_enabled:
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
    if any(isinstance(c, core.Tracer) for c in consts):
      raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
  else:
    pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
    jaxpr, pvals, consts = pe.trace_to_jaxpr( # type: ignore
      fun, pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore
  # Kick off async host copies of DeviceArray constants and literals early
  # (see prefetch).
  map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))

  jaxpr = apply_outfeed_rewriter(jaxpr)

  nreps = jaxpr_replicas(jaxpr)
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = device.platform if device else backend
  if config.omnistaging_enabled:
    result_handlers = map(partial(aval_to_result_handler, device), out_avals)
  else:
    out_avals = [pval.get_aval() for pval in pvals]
    result_handlers = map(partial(_pval_to_result_handler, device), pvals) # type: ignore

  # Computations that only produce constants and/or only rearrange their inputs,
  # which are often produced from partial evaluation, don't need compilation,
  # and don't need to force their (potentially lazy) arguments.
  if not jaxpr.eqns:
    return partial(_execute_trivial, jaxpr, device, consts, out_avals, result_handlers)

  if not _on_exit:
    log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for args %s.", fun.__name__, abstract_args)

  if nreps > 1:
    warn(f"The jitted function {fun.__name__} includes a pmap. Using "
         "jit-of-pmap can lead to inefficient data movement, as the outer jit "
         "does not preserve sharded data representations and instead collects "
         "input and output arrays onto a single device. "
         "Consider removing the outer jit unless you know what you're doing. "
         "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation that requires {nreps} replicas, but only "
        f"{xb.device_count(backend)} XLA devices are available")

  if xb.host_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")

  tuple_args = len(abstract_args) > 100  # pass long arg lists as tuple for TPU

  # Lower the jaxpr into a fresh builder; outputs are returned as one tuple.
  c = xb.make_computation_builder("jit_{}".format(fun.__name__))
  xla_consts = _xla_consts(c, consts)
  xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args, donated_invars=donated_invars)
  out_nodes = jaxpr_subcomp(
      c, jaxpr, backend, AxisEnv(nreps, (), ()), xla_consts,
      extend_name_stack(wrap_name(name, 'jit')), *xla_args)
  out_tuple = xops.Tuple(c, out_nodes)
  backend = xb.get_backend(backend)
  # Buffer donation is only wired up on GPU/TPU; anything still marked as
  # donated afterwards could not be aliased and triggers a warning.
  if backend.platform in ("gpu", "tpu"):
    donated_invars = set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
  if any(donated_invars):
    # TODO(tomhennigan): At call time we should mark these buffers as deleted.
    unused_donations = [str(c.GetShape(a))
                        for a, d in zip(xla_args, donated_invars) if d]
    warn("Some donated buffers were not usable: {}".format(", ".join(unused_donations)))
  built = c.build(out_tuple)

  options = xb.get_compile_options(
      num_replicas=nreps,
      num_partitions=1,
      device_assignment=(device.id,) if device else None)
  options.parameter_is_tupled_arguments = tuple_args
  compiled = backend_compile(backend, built, options)
  if nreps == 1:
    return partial(_execute_compiled, compiled, out_avals, result_handlers)
  else:
    return partial(_execute_replicated, compiled, out_avals, result_handlers)
def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
  """Configures input/output "must" aliasing based on `donated_args`.

  Every buffer of a donated input is queued under its (dimensions, dtype)
  key, in argument order. Output buffers are then visited in order, and each
  one is aliased to the earliest queued input buffer with a matching key.

  Args:
    c: the XLA builder that owns ``xla_args`` and ``out_tuple``.
    xla_args: the computation's input ops.
    out_tuple: the computation's tuple output op.
    donated_args: per-``xla_args`` donation flags.
    tuple_args: whether inputs were passed as a single tuple parameter
      (parameter 0). In that case parameter indices are prefixed with the
      argument position.

  Returns:
    A tuple like ``donated_args`` but with ``False`` for every input that
    received an alias. Remaining ``True`` entries are unused donations.
  """
  # (dimensions, dtype) -> FIFO of (param_number, param_index, arg_index).
  pending = defaultdict(deque)
  for arg_index, arg in enumerate(xla_args):
    if not donated_args[arg_index]:
      continue
    for param_index, element in flatten_shape(c.GetShape(arg)):
      key = (element.dimensions(), element.numpy_dtype())
      if tuple_args:
        param_number = 0
        param_index = (arg_index,) + tuple(param_index)
      else:
        param_number = arg_index
      pending[key].append((param_number, param_index, arg_index))

  still_donated = list(donated_args)
  for output_index, element in flatten_shape(c.GetShape(out_tuple)):
    queue = pending.get((element.dimensions(), element.numpy_dtype()))
    if not queue:
      continue
    param_number, param_index, arg_index = queue.popleft()
    still_donated[arg_index] = False
    c.setup_alias(output_index, param_number, param_index)
  return tuple(still_donated)
def _xla_callable_device(nreps, backend, device, arg_devices):
  """Choose the device a jitted computation runs on.

  Multi-replica jaxprs get no single device (``None``), and neither
  ``device`` nor ``backend`` may be given for them. Otherwise an explicit
  ``device`` wins. An explicit ``backend`` yields its first default-assigned
  device. With neither, the arguments' common device (or ``None``) is used.

  Raises:
    ValueError: on device/backend given for a multi-replica computation, or
      on arguments spread over several devices.
  """
  if nreps > 1:
    if device is not None or backend is not None:
      raise ValueError(f"can't specify device or backend for jit-of-pmap, "
                       f"got device={device} and backend={backend}")
    return None
  if device is None:
    if backend is None:
      return _device_from_arg_devices(arg_devices)
    return xb.get_backend(backend).get_default_device_assignment(1)[0]
  if backend is None:
    return device
  assert False  # Unreachable given the error check in _xla_callable
# Used within _xla_callable_args and _xla_param to distinguish between None (no
# sharding annotation set) and replicated.
# It is a unique sentinel compared by identity (`is`); _xla_param maps it to
# with_sharding(builder, None, ...).
_replicated_param = object()
def _xla_callable_args(
c, avals, tuple_args, *,
replicated=None,
partitions=None,
partitions_proto: bool = False,
donated_invars=None):
assert partitions is None or len(partitions) == len(avals)
if not tuple_args:
if replicated is None:
replicated = [None] * len(avals)
if partitions is None:
parts: List[object] = [None] * len(avals)
elif partitions_proto:
parts = partitions
else:
parts = [_replicated_param if part is None else part
for part in partitions]
counts = it.count()
xla_args = [_xla_param(c, next(counts), xla_shape, r, p, partitions_proto)
if a is not abstract_token else xops.CreateToken(c)
for (a, r, p) in safe_zip(avals, replicated, parts)
for xla_shape in aval_to_xla_shapes(a)]
if donated_invars is not None:
donated_invars = [d
for (a, r, p, d) in safe_zip(avals, replicated, parts, donated_invars)
for xla_shape in aval_to_xla_shapes(a)]
return xla_args, donated_invars
else:
if replicated is not None:
replicated = [r for a, r in zip(avals, replicated)
if a is not abstract_token]
if partitions is None:
tuple_parts = None
elif partitions_proto:
tuple_parts = xb.tuple_sharding_proto(partitions)
else:
tuple_parts = tuple(partitions)
tuple_shape = xc.Shape.tuple_shape(
[shape for a in avals for shape in aval_to_xla_shapes(a) if a is not abstract_token])
tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts, partitions_proto)
xla_inputs = iter(xla_destructure(c, tuple_param))
xla_args = [next(xla_inputs) if a is not abstract_token else
xops.CreateToken(c) for a in avals]
assert next(xla_inputs, None) is None
return xla_args, donated_invars
def _xla_param(builder, param_num, xla_shape, replicated, partitions, parts_proto):
  """Adds parameter `param_num` to `builder`, optionally with a sharding.

  Args:
    builder: XLA builder receiving the parameter.
    param_num: parameter number.
    xla_shape: XLA shape of the parameter.
    replicated: replication flag(s) forwarded to `xb.parameter`.
    partitions: None for no sharding annotation, `_replicated_param` for an
      annotation with a None spec, or any other value used as the spec itself.
    parts_proto: selects `xb.with_sharding_proto` instead of
      `xb.with_sharding`.

  Returns:
    The parameter op, wrapped by the sharding helper when `partitions` is set.
  """
  make_param = partial(xb.parameter, builder, param_num, xla_shape,
                       replicated=replicated)
  annotate = xb.with_sharding_proto if parts_proto else xb.with_sharding
  if partitions is None:
    return make_param()
  spec = None if partitions is _replicated_param else partitions
  return annotate(builder, spec, make_param)
def _execute_compiled(compiled: XlaExecutable, avals, handlers, *args):
  """Runs a single-device executable and wraps its outputs.

  Non-token arguments are transferred with `device_put` to the executable's
  only local device and flattened into one buffer list. When
  `FLAGS.jax_debug_nans` is set, the outputs are passed to `check_nans`. The
  output buffers are grouped with `_partition_outputs(avals, ...)`, and each
  group is passed to the matching result handler.
  """
  device, = compiled.local_devices()
  input_bufs = []
  for arg in args:
    if arg is token:
      continue
    input_bufs.extend(device_put(arg, device))
  out_bufs = compiled.execute(input_bufs)
  if FLAGS.jax_debug_nans:
    check_nans(xla_call_p, out_bufs)
  grouped = _partition_outputs(avals, out_bufs)
  return [handler(*bufs) for handler, bufs in zip(handlers, grouped)]
def _execute_replicated(compiled: XlaExecutable, avals, handlers, *args):
  """Runs a replicated executable and wraps the first device's outputs.

  The non-token arguments are transferred with `device_put` to every local
  device of `compiled`. Only the outputs at index 0 of
  `execute_on_local_devices` are kept (presumably the replicas agree;
  confirm against the caller). NaN checking and result handling match
  `_execute_compiled`.
  """
  def flat_buffers_on(device):
    bufs = []
    for arg in args:
      if arg is not token:
        bufs.extend(device_put(arg, device))
    return bufs

  per_device_bufs = [flat_buffers_on(device)
                     for device in compiled.local_devices()]
  out_bufs = compiled.execute_on_local_devices(per_device_bufs)[0]
  if FLAGS.jax_debug_nans:
    check_nans(xla_call_p, out_bufs)
  grouped = _partition_outputs(avals, out_bufs)
  return [handler(*bufs) for handler, bufs in zip(handlers, grouped)]
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers, *args):
  """Produces a jaxpr's outputs without compiling anything.

  Each output is read directly from the inputs, constants, or literals of
  `jaxpr`. Equations are never evaluated, so this presumably only serves
  jaxprs whose outputs just forward values (confirm at the call site).
  Outputs that are already device arrays are copied to `device`. All other
  outputs are placed with `device_put` and wrapped by their handler.
  `avals` is unused.
  """
  env = {core.unitvar: core.unit}
  # NOTE(review): `map` is assumed to be the file's eager safe_map (a lazy
  # builtin map would make these no-ops). Kept verbatim for that reason.
  map(env.setdefault, jaxpr.invars, args)
  map(env.setdefault, jaxpr.constvars, consts)

  outs = []
  for var in jaxpr.outvars:
    if type(var) is Literal:
      outs.append(canonicalize_dtype(var.val))
    else:
      outs.append(env[var])

  results = []
  for handler, out in zip(handlers, outs):
    if type_is_device_array(out):
      results.append(_copy_device_array_to_device(out, device))
    else:
      results.append(handler(*device_put(out, device)))
  return results
# Call primitive for computations staged out to XLA. Its translation rule
# (`_xla_call_translation_rule`) lowers it to a nested `jit_<name>` computation.
xla_call_p = core.CallPrimitive('xla_call')
# Public alias: calling `xla_call(...)` binds the primitive.
xla_call = xla_call_p.bind
# Register `_xla_call_impl` as the primitive's implementation rule.
xla_call_p.def_impl(_xla_call_impl)
def _xla_call_partial_eval_update_params(params, in_unknowns):
call_jaxpr = params['call_jaxpr']
donated_invars = params['donated_invars']
if not in_unknowns and donated_invars:
# JaxprTrace.post_process_call creates a call with no input tracers
new_donated_invars = (False,) * len(call_jaxpr.invars)
else:
# JaxprTrace.process_call drops known input tracers
donated_invars = [d for d, uk in zip(donated_invars, in_unknowns) if uk]
new_donated_invars = ((False,) * (len(call_jaxpr.invars) - len(donated_invars))
+ tuple(donated_invars))
return dict(params, donated_invars=new_donated_invars)
# Hook used by the partial-evaluation module (`pe`) to recompute xla_call
# params, in particular `donated_invars`.
pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params
def _xla_call_jvp_update_params(params, nz_tangents):
donated_invars = params['donated_invars']
donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
new_donated_invars = (*donated_invars, *donated_tangents)
return dict(params, donated_invars=new_donated_invars)
# Hook used by the autodiff module (`ad`) to recompute xla_call params for JVPs.
ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params
def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
donated_invars = params['donated_invars']
donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
donated_cotangents = [False for nz in nonzero_cts if nz]
return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
# Hook used by the autodiff module (`ad`) to recompute xla_call params when
# transposing.
ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
def _xla_call_translation_rule(c, axis_env,
                               in_nodes, name_stack, backend, name,
                               call_jaxpr, donated_invars, device=None):
  """Lowers an xla_call into an XLA `Call` of a nested computation.

  A sub-builder named `jit_<name>` gets one parameter per input node, with
  the same shape. `call_jaxpr` is lowered into that sub-builder, its outputs
  are tupled, and the built computation is called from `c` on `in_nodes`.
  `device` and `donated_invars` are ignored here.
  """
  del device, donated_invars  # Ignored.
  sub_builder = xb.make_computation_builder(f"jit_{name}")
  sub_params = [xb.parameter(sub_builder, i, c.get_shape(node))
                for i, node in enumerate(in_nodes)]
  sub_name_stack = extend_name_stack(name_stack, wrap_name(name, 'jit'))
  sub_outs = jaxpr_subcomp(sub_builder, call_jaxpr, backend, axis_env, (),
                           sub_name_stack, *sub_params)
  computation = sub_builder.build(xops.Tuple(sub_builder, sub_outs))
  return xops.Call(c, computation, list(in_nodes))


# Transposition of xla_call is delegated to the generic call transpose.
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
### translation tables

# Registries of XLA lowering rules, keyed by primitive. The rules' calling
# conventions differ per table. For example, `translations` rules below take
# `(c, *operands)` and `call_translations` rules take the
# `_xla_call_translation_rule` signature. NOTE(review): the exact contract of
# `parallel_translations` and `initial_style_translations` is not visible in
# this chunk; confirm where they are consumed.
translations: Dict[core.Primitive, Callable] = {}
parallel_translations: Dict[core.Primitive, Callable] = {}
initial_style_translations: Dict[core.Primitive, Callable] = {}
call_translations: Dict[core.Primitive, Callable] = {}
# Per-backend overrides: backend name -> {primitive: rule}.
backend_specific_translations: Dict[str, Dict[core.Primitive, Callable]] = defaultdict(dict)
call_translations[xla_call_p] = _xla_call_translation_rule
def zeros_like_translation_rule(c, x):
  """Lowers zeros_like: broadcasts a scalar zero to the shape of `x`.

  The zero has `x`'s element type. `x` must not be a tuple.
  """
  x_shape = c.get_shape(x)
  assert not x_shape.is_tuple()
  scalar_zero = np.array(0, x_shape.element_type())
  return xops.Broadcast(xb.constant(c, scalar_zero), x_shape.dimensions())


translations[ad_util.zeros_like_p] = zeros_like_translation_rule
def add_jaxvals_translation_rule(c, x, y):
  """Lowers add_jaxvals to an elementwise XLA `Add`. `x` must not be a tuple."""
  x_shape = c.get_shape(x)
  assert not x_shape.is_tuple()
  return xops.Add(x, y)


translations[ad_util.add_jaxvals_p] = add_jaxvals_translation_rule
# stop_gradient lowers to the identity: its operand is returned as-is.
translations[ad_util.stop_gradient_p] = lambda c, x: x
@lu.transformation
def _tuple_output(*args, **kwargs):
  """Transformation that wraps the wrapped function's result in a 1-tuple."""
  result = yield args, kwargs
  yield (result,)
def lower_fun(fun, multiple_results, parallel=False):
  """Turns a JAX-traceable function into a translation rule `f(c, *xla_args, **params)`.

  The returned rule recovers avals from the XLA shapes of its operands,
  traces `fun` (wrapped with `lu.wrap_init(fun, params)`) at those avals, and
  lowers the jaxpr into `c`. With `parallel=True`, `axis_env` is popped from
  `params` and `platform` is deleted; otherwise a trivial
  `AxisEnv(1, (), ())` is used. With `multiple_results` the outputs are
  returned as an XLA tuple; otherwise `fun`'s single result is returned
  directly.
  """
  # This function can only be used to lower functions that take JAX array types
  # as arguments (and e.g. don't accept unit values), because it assumes it can
  # map from XLA types to JAX types. In general that mapping is not possible (as
  # the mapping from JAX types to XLA types is not invertible), but for now at
  # least we assume that the mapping from JAX *array* types to XLA array types
  # is invertible. This assumption is unchecked!
  # TODO(mattjj): remove assumption can map XLA array types to JAX array types
  def f(c, *xla_args, **params):
    # TODO(mattjj): revise this 'calling convention'
    avals = [_array_aval_from_xla_shape(c.get_shape(arg)) for arg in xla_args]
    if not parallel:
      axis_env = AxisEnv(1, (), ())
    else:
      axis_env = params.pop('axis_env')
      del params['platform']

    traced = lu.wrap_init(fun, params)
    if not multiple_results:
      # Normalize to a one-element output list so both cases lower alike.
      traced = _tuple_output(traced)

    if config.omnistaging_enabled:
      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traced, avals)
    else:
      pvals = [pe.PartialVal.unknown(aval) for aval in avals]
      jaxpr, _, consts = pe.trace_to_jaxpr(traced, pvals, instantiate=True,
                                           stage_out=True)  # type: ignore
    outs = jaxpr_subcomp(c, jaxpr, None, axis_env, _xla_consts(c, consts), '',
                         *xla_args)

    if not multiple_results:
      assert len(outs) == 1, outs
      return outs[0]
    return xops.Tuple(c, outs)
  return f
def _array_aval_from_xla_shape(xla_shape):
  """Builds a `ShapedArray` from a non-tuple XLA array shape.

  This encodes the (unchecked) assumption that XLA array types can be mapped
  back to JAX array types.
  """
  # TODO(mattjj): remove assumption can map XLA array types to JAX array types
  assert not xla_shape.is_tuple()
  dims = xla_shape.dimensions()
  dtype = xla_shape.numpy_dtype()
  return ShapedArray(dims, dtype)
def lower_fun_initial_style(fun):
  """Turns `fun` into a rule `f(c, axis_env, name_stack, avals, backend, *xla_args, **params)`.

  The rule traces `lu.wrap_init(fun, params)` at the given `avals`, lowers the
  jaxpr into `c` under `name_stack` and `axis_env` for `backend`, and returns
  the outputs as an XLA tuple.
  """
  def f(c, axis_env, name_stack, avals, backend, *xla_args, **params):
    traced = lu.wrap_init(fun, params)
    if config.omnistaging_enabled:
      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traced, avals)
    else:
      pvals = [pe.PartialVal.unknown(aval) for aval in avals]
      jaxpr, _, consts = pe.trace_to_jaxpr(
          traced, pvals, instantiate=True, stage_out=True)  # type: ignore
    outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts),
                         name_stack, *xla_args)
    return xops.Tuple(c, outs)
  return f
### device-persistent data

class Token:
  """Type of the singleton `token` value; its aval is `abstract_token`."""


token = Token()
pytype_aval_mappings[Token] = lambda _: abstract_token