-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
dispatch.py
1080 lines (913 loc) · 42.2 KB
/
dispatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Primitive dispatch and jit dispatch.
from __future__ import annotations
import atexit
import contextlib
from functools import partial
import itertools
import time
from typing import (
Any, Callable, Dict, Optional, Sequence, Set, Tuple, List, Type, Union)
from typing_extensions import Protocol
import os
import re
import threading
import warnings
from absl import logging
import numpy as np
import jax
from jax import core
from jax import linear_util as lu
from jax.errors import UnexpectedTracerError
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.masking as masking
import jax.interpreters.mlir as mlir
import jax.interpreters.xla as xla
import jax.interpreters.partial_eval as pe
from jax._src import device_array
from jax._src import dtypes
from jax._src import profiler
from jax._src import stages
from jax._src import traceback_util
from jax._src.abstract_arrays import array_types
from jax._src.config import config, flags
from jax._src.lib.mlir import ir
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
import jax._src.util as util
from jax._src.util import flatten, unflatten
from etils import epath
FLAGS = flags.FLAGS
flags.DEFINE_string(
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
help="Path to which HLO/MHLO IR that is emitted by JAX as input to the "
"compiler should be dumped as text files. Optional. If omitted, JAX "
"will not dump IR.")
traceback_util.register_exclusion(__file__)
MYPY = False # Are we currently type checking with mypy?
xe = xc._xla
Backend = xe.Client
Device = xc.Device
Buffer = xe.Buffer
XlaExecutable = xc.Executable
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
# This flag is set on exit; no logging should be attempted
_on_exit = False
### op-by-op execution
ArgSpec = Tuple[core.AbstractValue, Optional[Device]]
def arg_spec(x: Any) -> ArgSpec:
aval = xla.abstractify(x)
try:
return aval, x._device
except:
return aval, None
def apply_primitive(prim, *args, **params):
"""Impl rule that compiles and runs a single primitive 'prim' using XLA."""
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
**params)
return compiled_fun(*args)
# TODO(phawkins): update code referring to xla.apply_primitive to point here.
xla.apply_primitive = apply_primitive
RuntimeToken = Any
class RuntimeTokenSet(threading.local):
tokens: Dict[core.Effect, Tuple[RuntimeToken, Device]]
output_tokens: Dict[Device, RuntimeToken]
def __init__(self):
self.tokens = {}
self.output_tokens = {}
def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
if eff not in self.tokens:
self.tokens[eff] = device_put(np.zeros(0, np.bool_), device), device
elif self.tokens[eff][1] != device:
(old_token,), _ = self.tokens[eff]
old_token.aval = core.ShapedArray((0,), np.bool_)
self.tokens[eff] = device_put(old_token, device), device
return self.tokens[eff][0]
def update_token(self, eff: core.Effect, token: RuntimeToken):
self.tokens[eff] = token, self.tokens[eff][1]
def set_output_token(self, device: Device, token: RuntimeToken):
# We're free to clobber the previous output token because on each
# device we have a total ordering of computations. Only the token
# from the latest computation matters. If this weren't the case
# we'd need to store a set of output tokens.
self.output_tokens[device] = token
def clear(self):
self.tokens = {}
self.output_tokens = {}
def block_until_ready(self):
for t, _ in self.tokens.values():
t[0].block_until_ready()
# TODO(sharadmv): use a runtime mechanism to block on computations instead
# of using output tokens.
for t in self.output_tokens.values():
t[0].block_until_ready()
runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()
@atexit.register
def wait_for_tokens():
runtime_tokens.block_until_ready()
@util.cache()
def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
avals, arg_devices = util.unzip2(arg_specs)
donated_invars = (False,) * len(arg_specs)
device = _device_from_arg_devices(arg_devices)
def prim_fun(*args):
out = prim.bind(*args, **params)
if prim.multiple_results:
return out
else:
return out,
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
prim.name, donated_invars, False, *arg_specs)
if not prim.multiple_results:
return lambda *args, **kw: compiled(*args, **kw)[0]
else:
return compiled
def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[Device]:
"""Given devices of inputs, determine where to perform a computation.
Args:
devices: list where each element is a either a `Device` instance or `None`.
Returns:
A `Device` instance or None.
Raises:
ValueError if input devices are inconsistent.
"""
try:
device, = {d for d in devices if d is not None} or (None,)
return device
except ValueError as err:
msg = "primitive arguments must be colocated on the same device, got {}"
raise ValueError(msg.format(", ".join(map(str, devices)))) from err
# JIT execution
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline, keep_unused: bool):
del inline # Only used at tracing time
arg_specs = unsafe_map(arg_spec, args)
compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
keep_unused, *arg_specs)
try:
return compiled_fun(*args)
except FloatingPointError:
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
print("Invalid value encountered in the output of a jit-decorated function. "
"Calling the de-optimized version.")
# We want to run the wrapped function again (after xla_callable already ran
# it), but linear_util.WrappedFun instances are meant to be run only once.
# In addition to re-executing the Python code, which is usually undesirable
# but which config.jax_debug_nans is meant to opt into, we'll be
# re-executing any linear_util.py-style side effects, i.e. re-populating
# Stores created by any transformation_with_aux's applied to fun. Since this
# is intentional here, to avoid "Store occupied" errors we clone the
# WrappedFun with empty stores.
stores = [lu.Store() for _ in fun.stores]
clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params,
fun.in_type)
with core.new_sublevel():
_ = clone.call_wrapped(*args) # may raise, not return
# If control reaches this line, we got a NaN on the output of `compiled_fun`
# but not `clone.call_wrapped` on the same arguments. Let's tell the user.
fun_info = pe.fun_sourceinfo(fun.f)
msg = ("An invalid value was encountered in the output of the "
f"`jit`-decorated function {fun_info}. Because "
"config.jax_debug_nans and/or config.jax_debug_infs is set, the "
"de-optimized function (i.e., the function as if the `jit` "
"decorator were removed) was called in an attempt to get a more "
"precise error message. However, the de-optimized function did not "
"produce invalid values during its execution. This behavior can "
"result from `jit` optimizations causing the invalud value to be "
"produced. It may also arise from having nan/inf constants as "
"outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
"\n\n"
"It may be possible to avoid the invalid value by removing the "
"`jit` decorator, at the cost of losing optimizations. "
"\n\n"
"If you see this error, consider opening a bug report at "
"https://github.com/google/jax.")
raise FloatingPointError(msg)
xla.xla_call_p.def_impl(_xla_call_impl)
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, keep_unused, *arg_specs):
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
keep_unused, *arg_specs).compile().unsafe_call
xla_callable = lu.cache(_xla_callable_uncached)
@contextlib.contextmanager
def log_elapsed_time(fmt: str):
if _on_exit:
yield
else:
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
start_time = time.time()
yield
elapsed_time = time.time() - start_time
logging.log(log_priority, fmt.format(elapsed_time=elapsed_time))
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, always_lower: bool, keep_unused: bool,
*arg_specs):
"""Lower into XLA.
Args:
always_lower: If `True`, even trivial programs (not doing any computation
such as lambda x: x) will be lowered into an XLA program.
keep_unused: If `False` (the default), arguments that JAX determines to be
unused by `fun` *may* be dropped from resulting compiled XLA executables.
Such arguments will not be transferred to the device nor provided to the
underlying executable. If `True`, unused arguments will not be pruned.
"""
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
if fun.in_type is None:
# Add an annotation inferred from the arguments; no dynamic axes here.
in_type = tuple(unsafe_zip(abstract_args, itertools.repeat(True)))
fun = lu.annotate(fun, in_type)
else:
# Check that the provided abstract_args are consistent with in_type by first
# collecting values of axis size arguments, then substituting them in for
# DBIdx occurrences.
axis_sizes: Dict[core.DBIdx, int] = {}
abstract_args_iter = iter(abstract_args)
for expected_type, explicit in fun.in_type:
if explicit:
aval = next(abstract_args_iter)
if isinstance(expected_type, core.DShapedArray):
# Check the value for any DBIdx variables is consistent.
assert all(axis_sizes.setdefault(d1, d2) == d2
for d1, d2 in zip(expected_type.shape, aval.shape)
if type(d1) is core.DBIdx)
# Check the type matches after substitution.
expected_shape = [axis_sizes.get(d, d) for d in expected_type.shape] # type: ignore
expected_aval = core.ShapedArray(
shape=tuple(expected_shape), dtype=expected_type.dtype,
weak_type=expected_type.weak_type)
assert core.typematch(expected_aval, aval)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
fun, pe.debug_info_final(fun, "jit"))
out_avals, kept_outputs = util.unzip2(out_type)
if any(isinstance(c, core.Tracer) for c in consts):
raise UnexpectedTracerError("Encountered an unexpected tracer.")
if config.jax_dynamic_shapes:
keep_unused = True
has_outfeed = False
else:
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
jaxpr = apply_outfeed_rewriter(jaxpr)
if not keep_unused:
jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
abstract_args, arg_devices = util.unzip2(
[a for i, a in enumerate(arg_specs) if i in kept_var_idx])
donated_invars = [x for i, x in enumerate(donated_invars)
if i in kept_var_idx]
del kept_const_idx
else:
kept_var_idx = set(range(len(abstract_args)))
nreps = jaxpr_replicas(jaxpr)
device = _xla_callable_device(nreps, backend, device, arg_devices)
backend = xb.get_device_backend(device) if device else xb.get_backend(backend)
if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr) and
not _backend_supports_unbounded_dynamic_shapes(backend)):
jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)
map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if (not always_lower and not (jaxpr.effects or has_outfeed) and
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars)):
return XlaComputation(
name, None, True, None, None, None, jaxpr=jaxpr, consts=consts,
device=device, in_avals=abstract_args, out_avals=out_avals,
has_unordered_effects=False, ordered_effects=[],
kept_var_idx=kept_var_idx, keepalive=None)
if not _on_exit:
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
if len(abstract_args) > 10:
msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
else:
msg = f"Compiling {fun.__name__} ({id(fun)} for args {abstract_args}."
logging.log(log_priority, msg)
if nreps > 1:
warnings.warn(
f"The jitted function {name} includes a pmap. Using "
"jit-of-pmap can lead to inefficient data movement, as the outer jit "
"does not preserve sharded data representations and instead collects "
"input and output arrays onto a single device. "
"Consider removing the outer jit unless you know what you're doing. "
"See https://github.com/google/jax/issues/2926.")
if nreps > xb.device_count(backend):
raise ValueError(
f"compiling computation `{name}` that requires {nreps} replicas, but "
f"only {xb.device_count(backend)} XLA devices are available.")
if xb.process_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)):
raise NotImplementedError(
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")
# pass long arg lists as tuple for TPU
tuple_args = len(abstract_args) > 100
axis_env = xla.AxisEnv(nreps, (), ())
name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module_name = f"jit_{fun.__name__}"
unordered_effects = [eff for eff in closed_jaxpr.effects
if eff not in core.ordered_effects]
ordered_effects = [eff for eff in closed_jaxpr.effects
if eff in core.ordered_effects]
module, keepalive = mlir.lower_jaxpr_to_module(
module_name, closed_jaxpr, unordered_effects, ordered_effects,
backend.platform, mlir.ReplicaAxisContext(axis_env), name_stack,
donated_invars)
return XlaComputation(
name, module, False, donated_invars, fun.in_type, out_type, nreps=nreps,
device=device, backend=backend, tuple_args=tuple_args,
in_avals=abstract_args, out_avals=out_avals,
has_unordered_effects=bool(unordered_effects),
ordered_effects=ordered_effects, kept_var_idx=kept_var_idx,
keepalive=keepalive)
def _backend_supports_unbounded_dynamic_shapes(backend: Backend) -> bool:
return backend.platform == 'iree'
def prefetch(x):
if isinstance(x, device_array.DeviceArray):
x.copy_to_host_async()
return x
def jaxpr_literals(jaxpr):
"""Generates all the literals inside a jaxpr, including nested subjaxprs."""
for eqn in jaxpr.eqns:
for v in eqn.invars:
if type(v) is core.Literal:
yield v.val
for subjaxpr in core.subjaxprs(jaxpr):
yield from jaxpr_literals(subjaxpr)
def jaxpr_has_pmap(jaxpr):
"""Whether there is an xla_pmap primitive anywhere inside a Jaxpr."""
for eqn in jaxpr.eqns:
if 'xla_pmap' in eqn.primitive.name:
return True
for subjaxpr in core.subjaxprs(jaxpr):
if jaxpr_has_pmap(subjaxpr):
return True
return False
def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
return (any(type(d) is core.Var for v in jaxpr.invars
if type(v.aval) is core.DShapedArray for d in v.aval.shape) or
any(type(d) is core.Var
for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr))
for e in j.eqns for v in itertools.chain(e.invars, e.outvars)
if type(v.aval) is core.DShapedArray for d in v.aval.shape))
def _prune_unused_inputs(
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
# TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
# applications that do not produce used outputs. Must handle side-effecting
# primitives and nested jaxpr.
used.update(
v for eqn in jaxpr.eqns for v in eqn.invars if isinstance(v, core.Var))
kept_const_idx, new_constvars = util.unzip2(
(i, v) for i, v in enumerate(jaxpr.constvars) if v in used)
kept_var_idx, new_invars = util.unzip2(
(i, v) for i, v in enumerate(jaxpr.invars) if v in used)
new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns,
jaxpr.effects)
return new_jaxpr, set(kept_const_idx), set(kept_var_idx)
# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap, we can
# remove it once we bring the id_tap implementation into the core.
outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
if outfeed_rewriter is not None:
return outfeed_rewriter(jaxpr)
else:
return jaxpr
def jaxpr_replicas(jaxpr) -> int:
"""The number of replicas needed for a jaxpr.
For a eqn, multiply the `axis_size` with the `jaxpr_replicas` of the
subjaxprs. For a list of eqns, take the maximum number of replicas.
"""
if isinstance(jaxpr, core.ClosedJaxpr):
jaxpr = jaxpr.jaxpr
return max(unsafe_map(eqn_replicas, jaxpr.eqns), default=1)
# TODO(mattjj): this function assumes that only pmap has a parameter named
# axis_size, and that it corresponds to cross-replica mapping
def eqn_replicas(eqn):
call_jaxpr = eqn.params.get("call_jaxpr")
if call_jaxpr:
return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
elif eqn.primitive in xla._initial_style_primitives:
return initial_style_primitive_replicas(eqn.params)
else:
return 1
def initial_style_primitive_replicas(params):
return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(), default=1)
def _xla_callable_device(nreps, backend, device,
arg_devices) -> Optional[Device]:
if nreps > 1:
if device is not None or backend is not None:
raise ValueError(f"can't specify device or backend for jit-of-pmap, "
f"got device={device} and backend={backend}")
return None
else:
# TODO(skye): dedup with C++ jit logic for determining jit device?
if device is not None:
assert backend is None
return device
if backend is not None:
return xb.get_backend(backend).get_default_device_assignment(1)[0]
arg_device = _device_from_arg_devices(arg_devices)
if arg_device is not None:
return arg_device
return config.jax_default_device
# Argument and result handlers
num_buffers_handlers: Dict[Type[core.AbstractValue],
Callable[[core.AbstractValue], int]] = {}
def aval_to_num_buffers(aval: core.AbstractValue) -> int:
"""Returns the number of buffers in the runtime representation of `aval`.
In general this may differ from the number of buffers in the compiler-IR
representation of the same value.
"""
try:
return num_buffers_handlers[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No num_buffers handler for type: {type(aval)}") from err
num_buffers_handlers[core.AbstractToken] = lambda _: 1
num_buffers_handlers[core.ShapedArray] = lambda _: 1
num_buffers_handlers[core.DShapedArray] = lambda _: 1
num_buffers_handlers[core.ConcreteArray] = lambda _: 1
def _input_handler(backend: Backend,
in_type: Optional[pe.InputType],
out_type: Optional[pe.OutputType],
) -> Optional[Callable]:
if in_type is None:
assert out_type is None
return None
in_avals, which_explicit = util.unzip2(in_type)
# Check whether we actually need an input_handler.
needs_implicit = which_explicit and not all(which_explicit)
needs_out_handling = any(type(d) is core.InDBIdx for a in out_type or []
if type(a) is core.DShapedArray for d in a.shape)
if not needs_implicit and not needs_out_handling:
return None
assert config.jax_dynamic_shapes
# Precompute how to grab implicit inputs from explicit inputs' axis sizes.
which_explicit = which_explicit or [True] * len(in_avals)
implicit_idxs = {i for i, ex in enumerate(which_explicit) if not ex}
implicit_args_from_axes: List[Tuple[int, int, int]] = []
for arg_idx, aval in enumerate(in_avals):
if isinstance(aval, core.DShapedArray):
for axis_idx, d in enumerate(aval.shape):
if isinstance(d, core.DBIdx) and d.val in implicit_idxs:
implicit_args_from_axes.append((d.val, arg_idx, axis_idx))
assert {i for i, _, _ in implicit_args_from_axes} == implicit_idxs
# Precompute which input values are needed for output types.
inputs_needed_for_out_types = out_type and [
d.val for aval in out_type if type(aval) is core.DShapedArray # type: ignore
for d in aval.shape if type(d) is core.InDBIdx]
def elaborate(explicit_args: Sequence[Any]) -> Tuple[Tuple, Optional[Tuple]]:
if needs_implicit:
# Build full argument list, leaving Nones for implicit arguments.
explicit_args_ = iter(explicit_args)
args = [next(explicit_args_) if ex else None for ex in which_explicit]
assert next(explicit_args_, None) is None
# Populate implicit arguments.
for i, j, k in implicit_args_from_axes:
if args[i] is None:
args[i] = args[j].shape[k] # type: ignore
else:
if args[i] != args[j].shape[k]:
raise Exception("inconsistent argument axis sizes for type")
else:
args = list(explicit_args)
if needs_out_handling:
# Make a list of inputs needed by output types, leaving unneeded as None.
out_type_env = [None] * len(args)
for i in inputs_needed_for_out_types or []:
out_type_env[i] = args[i]
else:
out_type_env = None # type: ignore
return tuple(args), out_type_env and tuple(out_type_env) # type: ignore
return elaborate
def _result_handler(backend: Backend,
sticky_device: Optional[Device],
out_type: Optional[pe.OutputType]
) -> Callable:
out_avals, kept_outputs = util.unzip2(out_type)
handlers = map(partial(aval_to_result_handler, sticky_device), out_avals)
dyn_outs = any(type(aval) is core.DShapedArray and
any(type(d) in (core.InDBIdx, core.OutDBIdx) for d in aval.shape)
for aval in out_avals)
if not dyn_outs:
return SimpleResultHandler(handlers)
assert config.jax_dynamic_shapes
def result_handler(input_env, lists_of_bufs):
results = []
for handler, bufs in unsafe_zip(handlers, lists_of_bufs):
results.append(handler((input_env, results), *bufs))
return [r for r, keep in unsafe_zip(results, kept_outputs) if keep]
return result_handler
class SimpleResultHandler:
handlers: Sequence[ResultHandler]
def __init__(self, handlers): self.handlers = handlers
def __iter__(self): return iter(self.handlers)
def __len__(self): return len(self.handlers)
def __call__(self, env, lists_of_bufs):
return tuple(h(env, *bs) for h, bs in zip(self.handlers, lists_of_bufs))
def _maybe_create_array_from_da(buf, aval, device):
if config.jax_array:
from jax.experimental.array import Array
from jax.experimental.sharding import SingleDeviceSharding
return Array(buf.shape, SingleDeviceSharding(buf.device()), [buf],
committed=(device is not None))
else:
return device_array.make_device_array(aval, device, buf)
if MYPY:
ResultHandler = Any
else:
class ResultHandler(Protocol):
def __call__(self, env: Optional[Sequence[Any]], *args: xla.Buffer) -> Any:
"""Boxes raw buffers into their user-facing representation."""
def aval_to_result_handler(sticky_device: Optional[Device],
aval: core.AbstractValue) -> ResultHandler:
try:
return result_handlers[type(aval)](sticky_device, aval)
except KeyError as err:
raise TypeError(f"No result handler for type: {type(aval)}") from err
def array_result_handler(sticky_device: Optional[Device],
aval: core.ShapedArray):
if aval.dtype == dtypes.float0:
return lambda _, __: np.zeros(aval.shape, dtypes.float0)
aval = core.raise_to_shaped(aval)
handler = lambda _, b: _maybe_create_array_from_da(b, aval, sticky_device)
handler.args = aval, sticky_device # for C++ dispatch path in api.py
return handler
def dynamic_array_result_handler(sticky_device: Optional[Device],
aval: core.DShapedArray):
if aval.dtype == dtypes.float0:
return lambda _: np.zeros(aval.shape, dtypes.float0) # type: ignore
else:
return partial(_dynamic_array_result_handler, sticky_device, aval)
def _dynamic_array_result_handler(sticky_device, aval, env, buf):
if all(type(d) is int for d in aval.shape):
del env
return _maybe_create_array_from_da(buf, aval, sticky_device)
else:
assert env is not None
in_env, out_env = env
shape = [in_env[d.val] if type(d) is core.InDBIdx else
out_env[d.val] if type(d) is core.OutDBIdx else d
for d in aval.shape]
aval = core.ShapedArray(tuple(shape), aval.dtype)
return _maybe_create_array_from_da(buf, aval, sticky_device)
result_handlers: Dict[
Type[core.AbstractValue],
Callable[[Optional[Device], Any], ResultHandler]] = {}
result_handlers[core.AbstractToken] = lambda _, __: lambda _, __: core.token
result_handlers[core.ShapedArray] = array_result_handler
result_handlers[core.DShapedArray] = dynamic_array_result_handler
result_handlers[core.ConcreteArray] = array_result_handler
def needs_check_special():
return config.jax_debug_infs or config.jax_debug_nans
def check_special(name, bufs):
if needs_check_special():
for buf in bufs:
_check_special(name, buf.xla_shape(), buf)
def _check_special(name, xla_shape, buf):
assert not xla_shape.is_tuple()
if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
if config.jax_debug_nans and np.any(np.isnan(buf.to_py())):
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
if config.jax_debug_infs and np.any(np.isinf(buf.to_py())):
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
device: Device, input_bufs):
tokens = [runtime_tokens.get_token(eff, device) for eff in ordered_effects]
tokens_flat = flatten(tokens)
input_bufs = [*tokens_flat, *input_bufs]
def _remove_tokens(output_bufs):
token_bufs, output_bufs = util.split_list(
output_bufs, [has_unordered_effects + len(ordered_effects)])
if has_unordered_effects:
output_token_buf, *token_bufs = token_bufs
runtime_tokens.set_output_token(device, output_token_buf)
for eff, token_buf in zip(ordered_effects, token_bufs):
runtime_tokens.update_token(eff, token_buf)
return output_bufs
return input_bufs, _remove_tokens
def _execute_compiled(name: str, compiled: XlaExecutable,
input_handler: Optional[Callable],
output_buffer_counts: Sequence[int],
result_handler: Callable,
has_unordered_effects: bool,
ordered_effects: List[core.Effect],
kept_var_idx, *args):
device, = compiled.local_devices()
args, env = input_handler(args) if input_handler else (args, None)
in_flat = flatten(device_put(x, device) for i, x in enumerate(args)
if i in kept_var_idx)
if has_unordered_effects or ordered_effects:
in_flat, token_handler = _add_tokens(has_unordered_effects, ordered_effects,
device, in_flat)
out_flat = compiled.execute(in_flat)
check_special(name, out_flat)
out_bufs = unflatten(out_flat, output_buffer_counts)
if ordered_effects or has_unordered_effects:
out_bufs = token_handler(out_bufs)
return result_handler(env, out_bufs)
def _execute_replicated(name: str, compiled: XlaExecutable,
input_handler: Optional[Callable],
output_buffer_counts: Sequence[int],
result_handler: Callable,
has_unordered_effects: bool,
ordered_effects: List[core.Effect],
kept_var_idx, *args):
if has_unordered_effects or ordered_effects:
# TODO(sharadmv): support jit-of-pmap with effects
raise NotImplementedError(
"Cannot execute replicated computation with effects.")
if input_handler: raise NotImplementedError # TODO(mattjj, dougalm)
input_bufs = [flatten(device_put(x, device) for i, x in enumerate(args)
if i in kept_var_idx)
for device in compiled.local_devices()]
input_bufs_flip = list(unsafe_zip(*input_bufs))
out_bufs_flat_rep = compiled.execute_sharded_on_local_devices(input_bufs_flip)
out_flat = [bufs[0] for bufs in out_bufs_flat_rep]
check_special(name, out_flat)
out_bufs = unflatten(out_flat, output_buffer_counts)
return result_handler(None, out_bufs)
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
has_unordered_effects: bool,
ordered_effects: List[core.Effect], kept_var_idx, *args):
env: Dict[core.Var, Any] = {}
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
map(env.setdefault, jaxpr.invars, pruned_args)
map(env.setdefault, jaxpr.constvars, consts)
outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
for v in jaxpr.outvars]
return [_copy_device_array_to_device(x, device) if device_array.type_is_device_array(x)
else h(None, *device_put(x, device)) for h, x in zip(handlers, outs)]
class XlaComputation(stages.XlaLowering):
name: str
_is_trivial: bool
_executable: Optional[XlaCompiledComputation]
_donated_invars: Optional[Sequence[bool]]
def __init__(self, name: str, hlo, is_trivial: bool,
donated_invars: Optional[Sequence[bool]],
in_type: Optional[pe.InputType],
out_type: Optional[pe.OutputType],
**compile_args):
self.name = name
self._hlo = hlo
self._is_trivial = is_trivial
self._donated_invars = donated_invars
self._in_type = in_type
self._out_type = out_type
self._executable = None
self.compile_args = compile_args
def is_trivial(self):
return self._is_trivial
# -- stages.XlaLowering overrides
def hlo(self) -> xc.XlaComputation:
if self.is_trivial():
raise ValueError("A trivial computation has no HLO")
if isinstance(self._hlo, xc.XlaComputation):
return self._hlo
return xe.mlir.mlir_module_to_xla_computation(
mlir.module_to_string(self._hlo),
use_tuple_args=self.compile_args["tuple_args"])
def mhlo(self) -> ir.Module:
if self.is_trivial():
raise ValueError("A trivial computation has no MHLO")
if isinstance(self._hlo, xc.XlaComputation):
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
with mlir.make_ir_context():
return ir.Module.parse(module_str)
return self._hlo
def compile(self) -> XlaCompiledComputation:
if self._executable is None:
if self.is_trivial():
self._executable = XlaCompiledComputation.from_trivial_jaxpr(
**self.compile_args)
else:
self._executable = XlaCompiledComputation.from_xla_computation(
self.name, self._hlo, self._in_type, self._out_type,
**self.compile_args)
return self._executable
@profiler.annotate_function
def backend_compile(backend, built_c, options):
# we use a separate function call to ensure that XLA compilation appears
# separately in Python profiling results
return backend.compile(built_c, compile_options=options)
# TODO(phawkins): update users.
xla.backend_compile = backend_compile
_ir_dump_counter = itertools.count()
def _make_string_safe_for_filename(s: str) -> str:
return re.sub(r'[^\w.)( -]', '', s)
def _dump_ir_to_file(name: str, ir: str):
id = next(_ir_dump_counter)
name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
name = epath.Path(FLAGS.jax_dump_ir_to) / name
name.write_text(ir)
def compile_or_get_cached(backend, computation, compile_options):
# Avoid import cycle between jax and jax.experimental
from jax.experimental.compilation_cache import compilation_cache as cc
if isinstance(computation, ir.Module):
sym_name = computation.operation.attributes['sym_name']
module_name = ir.StringAttr(sym_name).value
# Convert ir.Module to str representation (the default), unless the
# back-end expliclity flags the ability to handle a module directly
# (avoiding the overhead of back and forth conversions)
if getattr(backend, "needs_str_ir", True):
computation = mlir.module_to_string(computation)
else:
module_name = computation.name()
# Persistent compilation cache only implemented on TPU.
# TODO(skye): add warning when initializing cache on unsupported default platform
if cc.is_initialized() and backend.platform == 'tpu':
cached_executable = cc.get_executable(computation, compile_options, backend)
if cached_executable is not None:
logging.info('Persistent compilation cache hit for %s.', module_name)
return cached_executable
else:
compiled = backend_compile(backend, computation, compile_options)
cc.put_executable(module_name, computation, compile_options, compiled,
backend)
return compiled
if FLAGS.jax_dump_ir_to:
if isinstance(computation, xc.XlaComputation):
ir_str = computation.as_hlo_text()
elif isinstance(computation, ir.Module):
ir_str = mlir.module_to_string(computation)
else:
assert isinstance(computation, str)
ir_str = computation
_dump_ir_to_file(module_name, ir_str)
return backend_compile(backend, computation, compile_options)
class XlaCompiledComputation(stages.XlaExecutable):
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call,
keepalive: Any):
self._xla_executable = xla_executable
self.in_avals = in_avals
self._kept_var_idx = kept_var_idx
self.unsafe_call = unsafe_call
# Only the `unsafe_call` function is cached, so to avoid the `keepalive`
# being garbage collected we attach it to `unsafe_call`.
self.unsafe_call.keepalive = keepalive
@staticmethod
def from_xla_computation(
name: str,
xla_computation: Optional[ir.Module],
in_type: Optional[pe.InputType],
out_type: Optional[pe.OutputType],
nreps: int,
device: Optional[Device],
backend: Backend,
tuple_args: bool,
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue],
has_unordered_effects: bool,
ordered_effects: List[core.Effect],
kept_var_idx: Set[int],
keepalive: Optional[Any]) -> XlaCompiledComputation:
sticky_device = device
input_handler = _input_handler(backend, in_type, out_type)
result_handler = _result_handler(backend, sticky_device, out_type)
options = xb.get_compile_options(
num_replicas=nreps, num_partitions=1,
device_assignment=(sticky_device,) if sticky_device else None)
options.parameter_is_tupled_arguments = tuple_args
with log_elapsed_time(f"Finished XLA compilation of {name} "
"in {elapsed_time} sec"):
compiled = compile_or_get_cached(backend, xla_computation, options)
buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]
if ordered_effects or has_unordered_effects:
num_output_tokens = len(ordered_effects) + has_unordered_effects
buffer_counts = ([1] * num_output_tokens) + buffer_counts
execute = _execute_compiled if nreps == 1 else _execute_replicated
unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts, # type: ignore # noqa: F811
result_handler, has_unordered_effects,
ordered_effects, kept_var_idx)
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call,
keepalive)
def is_trivial(self):
return self._xla_executable == None
@property
def xla_executable(self):
# TODO(frostig): remove in favor of runtime_executable?
if self.is_trivial():
raise ValueError("A trivial compiled computation has no XLA executable")
return self._xla_executable
@staticmethod
def from_trivial_jaxpr(
jaxpr, consts, device, in_avals, out_avals, has_unordered_effects,
ordered_effects, kept_var_idx, keepalive: Optional[Any]
) -> XlaCompiledComputation:
assert keepalive is None
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
unsafe_call = partial(_execute_trivial, jaxpr, device, consts,
out_avals, result_handlers, has_unordered_effects,
ordered_effects, kept_var_idx)
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
keepalive)
# -- stages.XlaExecutable overrides
def xla_extension_executable(self):
return self.xla_executable
def call(self, *args):
arg_specs = unsafe_map(arg_spec, args)
arg_avals = [spec[0] for i, spec in enumerate(arg_specs)
if i in self._kept_var_idx]
check_arg_avals_for_call(self.in_avals, arg_avals)
return self.unsafe_call(*args)
def check_arg_avals_for_call(ref_avals, arg_avals):
if len(ref_avals) != len(arg_avals):
raise TypeError(
f"Computation compiled for {len(ref_avals)} inputs "
f"but called with {len(arg_avals)}")
for ref_aval, arg_aval in zip(ref_avals, arg_avals):
if not core.typematch(ref_aval, arg_aval):
ref_avals_fmt = ', '.join(str(a) for a in ref_avals)
arg_avals_fmt = ', '.join(str(a) for a in arg_avals)
raise TypeError(
f"Computation compiled for input types:\n {ref_avals_fmt}\n"
f"called with:\n {arg_avals_fmt}")
def device_put(x, device: Optional[Device] = None) -> Tuple[Any]:
x = xla.canonicalize_dtype(x)
try:
return device_put_handlers[type(x)](x, device)
except KeyError as err:
raise TypeError(f"No device_put handler for type: {type(x)}") from err