# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lowering and execution path that converts jaxprs into MLIR.
from __future__ import annotations
import collections
import dataclasses
import functools
from functools import partial
import io
import itertools
import re
import typing
from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
Protocol, Sequence, Set, Tuple, Type, Union)
import warnings
import numpy as np
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
from jax._src import effects as effects_lib
from jax._src import linear_util as lu
from jax._src import pickle_util
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.config import config
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import dialects
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.sharding_impls import XLACompatibleSharding
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
T = typing.TypeVar("T")
Value = Any # = ir.Value
# mypy implicitly sets this variable to true when type checking.
MYPY = False
lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects
# IR Helpers
def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
a = np.packbits(np.array(xs, np.bool_), bitorder='little')
# TODO(b/209005197): Work around for MLIR crash for non-splat single element
# buffers.
if len(xs) == 1:
a = np.array(0 if a.item() == 0 else 0xff, np.uint8)
return ir.DenseElementsAttr.get(
a, type=ir.IntegerType.get_signless(1), shape=[len(xs)])
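# Illustrative usage (requires an active ir.Context; the printed attribute
# forms below are indicative, not verbatim output):
#   dense_int_elements([2, 3])   # dense<[2, 3]> : tensor<2xi64>
#   dense_bool_elements([True])  # dense<true> : tensor<1xi1>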
def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
def shape_tensor(sizes: Sequence[Union[int, ir.Value]]
                 ) -> ir.Value:
int1d = aval_to_ir_type(core.ShapedArray((1,), np.int32))
i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
def lower_dim(d):
if type(d) is int:
return ir_constant(np.array([d], np.int32))
else:
if d.type != i32_type:
d = hlo.ConvertOp(i32_type, d)
return hlo.ReshapeOp(int1d, d).result
ds = map(lower_dim, sizes)
if not ds:
return ir_constant(np.array([], np.int32))
elif len(ds) == 1:
return ds[0]
else:
return hlo.ConcatenateOp(ds, i64_attr(0)).result
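# Usage sketch: `shape_tensor` accepts a mix of static Python ints and dynamic
# scalar ir.Values and yields a rank-1 i32 tensor of dimension sizes; e.g.
# shape_tensor([2, d]) concatenates a constant [2] with `d` converted to i32
# and reshaped to a 1-element vector.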
def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs):
"""Side-effects on `ctx`"""
ctx_new = ctx.replace(**ctx_override_kwargs)
out = lowering_fun(ctx_new, *args)
ctx.set_tokens_out(ctx_new.tokens_out)
return out
# IR Types
# Non-canonicalized dtype to IR type mapping.
_dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
np.dtype(dtypes.float0): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
np.dtype(dtypes.float8_e4m3b11fnuz): ir.Float8E4M3B11FNUZType.get,
np.dtype(dtypes.float8_e4m3fn): ir.Float8E4M3FNType.get,
np.dtype(dtypes.float8_e5m2): ir.Float8E5M2Type.get,
np.dtype(dtypes.bfloat16): ir.BF16Type.get,
np.dtype(np.float16): ir.F16Type.get,
np.dtype(np.float32): ir.F32Type.get,
np.dtype(np.float64): ir.F64Type.get,
np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}
if dtypes.int4 is not None:
_dtype_to_ir_type.update({
np.dtype(dtypes.int4): partial(ir.IntegerType.get_signless, 4),
np.dtype(dtypes.uint4): partial(ir.IntegerType.get_unsigned, 4),
})
def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
assert isinstance(dtype, (np.dtype, np.generic)), type(dtype)
dtype = np.dtype(dtype)
try:
ir_type_factory = _dtype_to_ir_type[dtype]
except KeyError as err:
raise TypeError(
f"No dtype_to_ir_type handler for dtype: {dtype}") from err
return ir_type_factory()
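# For example (with an active ir.Context):
#   dtype_to_ir_type(np.dtype(np.float32))  # -> f32
#   dtype_to_ir_type(np.dtype(np.uint8))    # -> ui8 (unsigned, not signless)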
def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
) -> Sequence[ir.Type]:
aval = core.physical_aval(aval) # type: ignore
if not core.is_constant_shape(aval.shape):
return _dynamic_array_ir_types(aval) # type: ignore
return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
dyn_size = ir.ShapedType.get_dynamic_size()
shape = [d if type(d) is int else dyn_size for d in aval.shape]
return (ir.RankedTensorType.get(shape, dtype_to_ir_type(aval.dtype)),)
ir_type_handlers: Dict[Type[core.AbstractValue],
Callable[[Any], Sequence[ir.Type]]] = {}
def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
"""Converts a JAX aval to zero or more MLIR IR types.
In general, a JAX value may be represented by multiple IR values, so this
function returns multiple types."""
try:
return ir_type_handlers[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err
ir_type_handlers[core.ShapedArray] = _array_ir_types
ir_type_handlers[core.ConcreteArray] = _array_ir_types
ir_type_handlers[core.AbstractToken] = lambda _: [hlo.TokenType.get()]
ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types
def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
"""Convenience wrapper around aval_to_ir_types for single types.
For some common cases, e.g. dense arrays, we know JAX values are represented
by a single IR value."""
types = aval_to_ir_types(aval)
if len(types) != 1:
raise TypeError(f"aval_to_ir_type called on {aval} which corresponds to "
f"multiple IR types {types}")
return types[0]
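# For example, a dense array aval corresponds to a single ranked tensor type:
#   aval_to_ir_type(core.ShapedArray((2, 3), np.dtype(np.float32)))
#   # -> tensor<2x3xf32>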
# Constants
class ConstantHandler(Protocol):
def __call__(self, val: Any, canonicalize_types: bool) -> Sequence[ir.Value]:
"""Builds an IR representation for a constant `val`.
A JAX value is represented by zero or more IR values."""
_constant_handlers : Dict[type, ConstantHandler] = {}
def register_constant_handler(type_: type, handler_fun: ConstantHandler):
_constant_handlers[type_] = handler_fun
def get_constant_handler(type_: type) -> ConstantHandler:
return _constant_handlers[type_]
def ir_constants(val: Any,
canonicalize_types: bool = True) -> Sequence[ir.Value]:
"""Translate a Python `val` to an IR constant, canonicalizing its dtype.
Args:
val: a Python value to be translated to a constant.
Returns:
A representation of the constant as a list of IR values.
"""
for t in type(val).__mro__:
handler = _constant_handlers.get(t)
if handler:
out = handler(val, canonicalize_types)
assert all(isinstance(v, ir.Value) for v in out), (type(val), out)
return out
if hasattr(val, '__jax_array__'):
return ir_constants(val.__jax_array__(), canonicalize_types)
raise TypeError(f"No constant handler for type: {type(val)}")
def ir_constant(val: Any, canonicalize_types: bool = True) -> ir.Value:
"""Convenience wrapper around ir_constants for singleton values."""
values = ir_constants(val, canonicalize_types=canonicalize_types)
if len(values) != 1:
raise TypeError(f"ir_constant called on {val} which corresponds to "
f"multiple IR values {values}")
return values[0]
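# Usage sketch (within an active ir.Context, ir.Location and insertion point):
#   ir_constant(np.float32(1.0))  # a single ir.Value wrapping an hlo constant
#   ir_constants(np.arange(4))    # the general form: a sequence of ir.Values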
def _numpy_array_constant(x: np.ndarray, canonicalize_types
) -> Sequence[ir.Value]:
if canonicalize_types:
x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
element_type = dtype_to_ir_type(x.dtype)
shape = x.shape
if x.dtype == np.bool_:
nelems = x.size
x = np.packbits(x, bitorder='little')
# TODO(b/209005197): Work around for MLIR crash for non-splat single element
# buffers.
if nelems == 1:
x = np.array(0 if x.item() == 0 else 0xff, np.uint8)
elif x.dtype == dtypes.bfloat16:
x = x.view(np.uint16)
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
return (hlo.ConstantOp(attr).result,)
def _masked_array_constant_handler(*args, **kwargs):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
register_constant_handler(np.ma.MaskedArray, _masked_array_constant_handler)
def _ndarray_constant_handler(val: np.ndarray, canonicalize_types
) -> Sequence[ir.Value]:
"""Constant handler for ndarray literals, handling zero-size strides.
In most cases this function calls _numpy_array_constant(val) except it has
special handling of arrays with any strides of size zero: for those, it
generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
to avoid staging in large literals that might arise from np.zeros or np.ones
or the output of lax.broadcast (which uses np.broadcast_to which in turn
uses size-zero strides).
Args:
val: an ndarray.
Returns:
An XLA ComputationDataHandle / XlaOp representing the constant ndarray
staged into the XLA Computation.
"""
if dtypes.result_type(val) == dtypes.float0:
return _numpy_array_constant(np.zeros(val.shape, dtype=np.bool_),
canonicalize_types=False)
elif np.any(np.equal(0, val.strides)) and val.size > 0:
zero_stride_axes, = np.where(np.equal(0, val.strides))
other_axes, = np.where(np.not_equal(0, val.strides))
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) # type: ignore
for ax in range(val.ndim))] # type: ignore
if canonicalize_types:
collapsed_val = np.asarray(
collapsed_val, dtypes.canonicalize_dtype(collapsed_val.dtype))
out = hlo.BroadcastInDimOp(
ir.RankedTensorType.get(
val.shape, dtype_to_ir_type(collapsed_val.dtype)),
_numpy_array_constant(collapsed_val, canonicalize_types=False)[0],
dense_int_elements(other_axes)).result
return (out,)
else:
return _numpy_array_constant(val, canonicalize_types)
register_constant_handler(np.ndarray, _ndarray_constant_handler)
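# As a concrete sketch of the zero-stride path above:
# np.broadcast_to(np.float32(1.), (1000, 1000)) has strides (0, 0), so the
# handler emits a scalar constant plus a broadcast_in_dim rather than
# materializing a million-element literal in the IR.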
for _scalar_type in [np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64,
np.complex64, np.complex128,
np.bool_, np.longlong, dtypes.bfloat16]:
register_constant_handler(_scalar_type, _ndarray_constant_handler) # type: ignore
def _python_scalar_handler(dtype, val, canonicalize_dtypes):
return _numpy_array_constant(np.array(val, dtype), canonicalize_dtypes)
for ptype, dtype in dtypes.python_scalar_dtypes.items():
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
def _token_constant_handler(val, canonicalize_types):
return [hlo.CreateTokenOp().result]
register_constant_handler(core.Token, _token_constant_handler)
# Source locations
def _traceback_to_location(tb: xc.Traceback) -> ir.Location:
"""Converts a full traceback to a callsite() MLIR location."""
frame_locs = []
for code, lasti in zip(*tb.raw_frames()):
frame = source_info_util.raw_frame_to_frame(code, lasti)
frame_locs.append(ir.Location.file(xla.get_canonical_source_file(frame),
frame.start_line, frame.start_column))
if len(frame_locs) == 0:
return ir.Location.unknown()
else:
return ir.Location.callsite(frame_locs[-1], frame_locs[-2::-1])
def _source_info_to_location(
primitive: core.Primitive, params: Dict,
source_info: source_info_util.SourceInfo,
name_stack: source_info_util.NameStack) -> ir.Location:
eqn_str = (f'{str(source_info.name_stack)}/'
f'{core.str_eqn_compact(primitive.name, params)}')
if config.jax_include_full_tracebacks_in_locations:
if source_info.traceback is None:
loc = ir.Location.unknown()
else:
loc = _traceback_to_location(source_info.traceback)
else:
frame = source_info_util.user_frame(source_info)
if frame is None:
loc = ir.Location.unknown()
else:
loc = ir.Location.file(xla.get_canonical_source_file(frame),
frame.start_line, frame.start_column)
loc = ir.Location.name(eqn_str, childLoc=loc)
# TODO(phawkins): also include primitive.name as the operator type.
return loc
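# The resulting location nests the equation name over the user frame, printing
# roughly as loc("jit(main)/add"("f.py":10:4)); with full tracebacks enabled,
# the child location is instead a callsite() chain over all frames.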
# Translation rules
def make_ir_context() -> ir.Context:
"""Creates an MLIR context suitable for JAX IR."""
context = ir.Context()
# If threading is enabled, each MLIR context will keep alive a thread pool.
# Since we cache MLIR modules (and hence contexts), this means we might keep
# several threads alive for each cache entry. This is a terrible idea. However
# we don't do any heavy computation on MLIR modules from Python anyway, so we
# just disable threading.
context.enable_multithreading(False)
dialects.mhlo.register_mhlo_dialect(context)
dialects.chlo.register_dialect(context)
dialects.hlo.register_dialect(context)
return context
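# Minimal sketch of standalone use (mirrors what ModuleContext does below):
#   ctx = make_ir_context()
#   with ctx, ir.Location.unknown(ctx):
#     module = ir.Module.create()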
AxisContext = Union[
sharding_impls.SPMDAxisContext,
sharding_impls.ReplicaAxisContext,
sharding_impls.ShardingContext,
]
class ShapePolyLoweringState:
# The names of the dimension variables, sorted by name. This is the order in
# which they are passed to the IR functions that need them. This is only
# used for native serialization with polymorphic shapes when
# --jax_dynamic_shapes is off.
dim_vars: Sequence[str]
# Whether the module uses dimension variables, either in its inputs or
# from an inner call to a polymorphic Exported.
uses_dim_vars: bool
def __init__(self, dim_vars: Sequence[str]):
self.dim_vars = dim_vars
self.uses_dim_vars = (len(dim_vars) > 0)
@dataclasses.dataclass
class ModuleContext:
"""Module-wide context information for MLIR lowering."""
context: ir.Context
module: ir.Module
ip: ir.InsertionPoint
symbol_table: ir.SymbolTable
backend_or_name: Optional[Union[str, xb.XlaBackend]]
platform: str
axis_context: AxisContext
name_stack: source_info_util.NameStack
keepalives: List[Any]
channel_iterator: Iterator[int]
host_callbacks: List[Any]
# Keep state for the lowering of shape polymorphism
shape_poly_state: ShapePolyLoweringState
# Cached primitive lowerings.
cached_primitive_lowerings: Dict[Any, func_dialect.FuncOp]
cached_call_jaxpr_lowerings: Dict[Any, func_dialect.FuncOp]
@property
def axis_env(self) -> sharding_impls.AxisEnv:
return self.axis_context.axis_env
def __init__(
self,
backend_or_name: Optional[Union[str, xb.XlaBackend]],
platform: str,
axis_context: AxisContext,
name_stack: source_info_util.NameStack,
keepalives: List[Any],
channel_iterator: Iterator[int],
host_callbacks: List[Any],
context: Optional[ir.Context] = None,
module: Optional[ir.Module] = None,
ip: Optional[ir.InsertionPoint] = None,
symbol_table: Optional[ir.SymbolTable] = None,
cached_primitive_lowerings: Optional[Dict[Any,
func_dialect.FuncOp]] = None,
cached_call_jaxpr_lowerings: Optional[Dict[Any,
func_dialect.FuncOp]] = None,
shape_poly_state = None):
assert platform is not None
self.context = context or make_ir_context()
self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
self.ip = ip or ir.InsertionPoint(self.module.body)
self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
self.backend_or_name = backend_or_name
self.platform = platform
self.axis_context = axis_context
self.name_stack = name_stack
self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
else cached_primitive_lowerings)
self.channel_iterator = channel_iterator
self.keepalives = keepalives
self.host_callbacks = host_callbacks
self.cached_call_jaxpr_lowerings = ({}
if cached_call_jaxpr_lowerings is None
else cached_call_jaxpr_lowerings)
self.shape_poly_state = shape_poly_state or ShapePolyLoweringState(())
@property
def backend(self) -> xb.XlaBackend:
if self.backend_or_name is None or isinstance(self.backend_or_name, str):
return xb.get_backend(self.backend_or_name)
return self.backend_or_name
def new_channel(self) -> int:
return next(self.channel_iterator)
def add_host_callback(self, host_callback: Any) -> None:
self.host_callbacks.append(host_callback)
def add_keepalive(self, keepalive: Any) -> None:
self.keepalives.append(keepalive)
def replace(self, **kw): return dataclasses.replace(self, **kw)
@dataclasses.dataclass
class LoweringRuleContext:
"""Per-rule context information for MLIR lowering."""
module_context: ModuleContext
primitive: Optional[core.Primitive]
avals_in: Sequence[core.AbstractValue]
avals_out: Any # Usually Sequence[core.AbstractValue], but sometimes None.
tokens_in: TokenSet
tokens_out: Optional[TokenSet] # Mutable store for output containers
axis_size_env: Optional[Dict[core.Var, ir.Value]] = None # Dynamic axis sizes
dim_var_values: Sequence[ir.Value] = () # The values for the dimension variables
# in same order as module_context.shape_poly_state.dim_vars
def set_tokens_out(self, tokens_out: TokenSet):
assert self.tokens_out is None, 'Should only set `tokens_out` once.'
self.tokens_out = tokens_out
def replace(self, **kw): return dataclasses.replace(self, **kw)
if not MYPY:
class LoweringRule(Protocol):
def __call__(self, ctx: LoweringRuleContext,
*args: Union[ir.Value, Sequence[ir.Value]],
**kw) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
"""Converts a JAX primitive invocation into MLIR."""
else:
LoweringRule = Any
_lowerings: Dict[core.Primitive, LoweringRule] = {}
_platform_specific_lowerings: Dict[str, Dict[core.Primitive, LoweringRule]]
_platform_specific_lowerings = collections.defaultdict(dict)
def register_lowering(prim: core.Primitive, rule: LoweringRule,
platform: Optional[str] = None):
if platform is None:
_lowerings[prim] = rule
else:
# For backward compatibility reasons, we allow rules to be registered
# under "gpu" even though the platforms are now called "cuda" and "rocm".
# TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove
# this expansion.
for p in xb.expand_platform_alias(platform):
_platform_specific_lowerings[p][prim] = rule
return rule
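# A sketch of registering a rule for a hypothetical primitive `foo_p` that
# lowers to an elementwise add of its argument with itself (`foo_p` and the
# rule are illustrative, not part of this module):
#   def _foo_lowering(ctx: LoweringRuleContext, x: ir.Value
#                     ) -> Sequence[ir.Value]:
#     return [hlo.AddOp(x, x).result]
#   register_lowering(foo_p, _foo_lowering)         # all platforms
#   register_lowering(foo_p, _foo_lowering, "cpu")  # CPU only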
def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x
def wrap_singleton_ir_values(x: Union[ir.Value, Sequence[ir.Value]]
) -> Sequence[ir.Value]:
"""Adds a consistent tuples to a mixture of tupled and untuple values."""
return (x,) if isinstance(x, ir.Value) else tuple(x)
def flatten_lowering_ir_args(
xs: Sequence[Union[ir.Value, Sequence[ir.Value]]]
) -> Sequence[Sequence[ir.Value]]:
return util.flatten(map(wrap_singleton_ir_values, xs))
_module_name_regex = re.compile(r"[^\w.-]")
def sharded_aval(aval: core.AbstractValue,
sharding: Optional[XLACompatibleSharding]) -> core.AbstractValue:
"""Returns the new aval sharded based on sharding proto."""
if sharding is None:
return aval
if isinstance(aval, core.AbstractToken):
return aval
if not isinstance(aval, core.ShapedArray):
raise NotImplementedError
return aval.update(sharding.shard_shape(aval.shape))
def eval_dynamic_shape(ctx: LoweringRuleContext,
shape: core.Shape) -> Tuple[Union[int, Value], ...]:
if config.jax_dynamic_shapes:
return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore
else:
ctx = ctx.replace(
primitive="eval_dynamic_shape",
avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars))
res = lower_fun(
partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars),
multiple_results=True)(ctx, *ctx.dim_var_values)
return util.flatten(res) # type: ignore
def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
shape: core.Shape) -> Tuple[Value, ...]:
"""Evaluates the dynamic shapes as int32 values."""
def convert_dim(d: Union[int, Value]):
if type(d) is int:
return ir_constant(np.array(d, dtype=np.int32))
else:
i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
if d.type != i32_type: # type: ignore
return hlo.ConvertOp(i32_type, d).result
else:
return d
return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))
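# Sketch: for a shape (2, d) with dimension variable d, this returns a pair of
# i32 scalars (a constant 2, and the value bound to d), inserting hlo.convert
# for any dynamic dimension value that is not already i32.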
class LoweringResult(NamedTuple):
module: ir.Module
keepalive: Optional[Any]
host_callbacks: List[Any]
shape_poly_state: ShapePolyLoweringState
_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
def _to_logical_op_sharding(
aval: core.AbstractValue, sharding: Optional[XLACompatibleSharding],
) -> Optional[xc.HloSharding]:
if sharding is None:
return None
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
assert isinstance(aval, core.ShapedArray)
return sharding._to_xla_hlo_sharding(aval.ndim)
def lower_jaxpr_to_module(
module_name: str,
jaxpr: core.ClosedJaxpr,
ordered_effects: List[core.Effect],
backend_or_name: Optional[Union[str, xb.XlaBackend]],
platform: str,
axis_context: AxisContext,
name_stack: source_info_util.NameStack,
donated_args: Sequence[bool],
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[XLACompatibleSharding]]] = None,
result_shardings: Optional[Sequence[Optional[XLACompatibleSharding]]] = None,
arg_names: Optional[Sequence[Optional[str]]] = None,
result_names: Optional[Sequence[Optional[str]]] = None,
num_replicas: int = 1,
num_partitions: int = 1,
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.
Handles the quirks of the argument/return value passing conventions of the
runtime.
"""
platform = xb.canonicalize_platform(platform)
if not xb.is_known_platform(platform):
raise ValueError(f"Unknown platform {platform}")
input_output_aliases = None
in_avals = (jaxpr.in_avals if arg_shardings is None else
map(sharded_aval, jaxpr.in_avals, arg_shardings))
out_avals = (jaxpr.out_avals if result_shardings is None else
map(sharded_aval, jaxpr.out_avals, result_shardings))
if platform in _platforms_with_donation:
input_output_aliases, donated_args = _set_up_aliases(
in_avals, out_avals, donated_args)
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
if unlowerable_effects:
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
if any(donated_args):
unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
if platform not in _platforms_with_donation:
msg = f"Donation is not implemented for {platform}.\n{msg}"
warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
# HLO channels need to start at 1
channel_iter = itertools.count(1)
# Create a keepalives list that will be mutated during the lowering.
keepalives: List[Any] = []
host_callbacks: List[Any] = []
dim_vars: Sequence[str]
if not config.jax_dynamic_shapes:
# Find the dimension variables
all_dim_poly = [d for aval in jaxpr.in_avals if hasattr(aval, "shape")
for d in aval.shape if not core.is_constant_dim(d)]
dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new.get_vars()),
all_dim_poly, set())))
else:
dim_vars = ()
arg_op_shardings = (
map(_to_logical_op_sharding, jaxpr.in_avals, arg_shardings)
if arg_shardings is not None else arg_shardings)
result_op_shardings = (
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
if result_shardings is not None else result_shardings)
ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
keepalives, channel_iter, host_callbacks,
shape_poly_state=ShapePolyLoweringState(dim_vars))
with ctx.context, ir.Location.unknown(ctx.context):
    # Remove module name characters that XLA would alter. This ensures that
    # the XLA computation preserves the module name.
attrs = ctx.module.operation.attributes
module_name = _module_name_regex.sub("_", module_name)
attrs["sym_name"] = ir.StringAttr.get(module_name)
attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
lower_jaxpr_to_fun(
ctx, "main", jaxpr, ordered_effects, public=True, create_tokens=True,
replace_tokens_with_dummy=True,
num_output_tokens=0,
replicated_args=replicated_args,
arg_shardings=arg_op_shardings,
result_shardings=result_op_shardings,
input_output_aliases=input_output_aliases,
arg_names=arg_names,
result_names=result_names)
try:
if not ctx.module.operation.verify():
module_string = module_to_string(ctx.module)
raise ValueError(
f"Cannot lower jaxpr with verifier errors: {module_string}")
except ir.MLIRError as e:
module_string = module_to_string(ctx.module)
raise ValueError(
f"Cannot lower jaxpr with verifier errors: {module_string}") from e
return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
ctx.shape_poly_state)
def module_to_string(module: ir.Module) -> str:
output = io.StringIO()
module.operation.print(file=output, enable_debug_info=True,
print_generic_op_form=False)
return output.getvalue()
def module_to_bytecode(module: ir.Module) -> bytes:
output = io.BytesIO()
module.operation.write_bytecode(file=output)
return output.getvalue()
def _set_up_aliases(avals_in, avals_out, donated_args):
input_output_aliases = [None] * len(avals_in)
# To match-up in-avals to out-avals we only care about the number of
# bytes, so we strip off unrelated aval metadata (eg. the named shape)
strip_metadata = lambda a: a.strip_named_shape().strip_weak_type()
avals_in = map(strip_metadata, avals_in)
avals_out = map(strip_metadata, avals_out)
donations = collections.defaultdict(collections.deque)
for i, (aval, donated) in enumerate(zip(avals_in, donated_args)):
if donated:
donations[aval].append(i)
out_donated_args = list(donated_args)
for i, aval in enumerate(avals_out):
if donations.get(aval, ()):
input_id = donations[aval].popleft()
input_output_aliases[input_id] = i
out_donated_args[input_id] = False
return input_output_aliases, out_donated_args
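# Worked example: with avals_in = [f32[4], f32[4]], avals_out = [f32[4]] and
# donated_args = [True, True], the first matching donated input is aliased to
# the output, giving input_output_aliases = [0, None] and out_donated_args =
# [False, True]; the still-donated second input is what triggers the "not
# usable" warning in lower_jaxpr_to_module.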
Token = Sequence[ir.Value]
def token_type() -> Sequence[ir.Type]:
return [hlo.TokenType.get()]
def create_token() -> Token:
return wrap_singleton_ir_values(hlo.CreateTokenOp().result)
class TokenSet:
"""An immutable container of tokens to be used to lower effectful jaxprs. When lowering
effectful jaxprs, we need to thread HLO tokens to sequence them. Each effect
will need its own token that will be threaded in and out of the effectful
primitives. A `TokenSet` encapsulates a set of HLO tokens that will be
used by the lowering rules.
"""
_tokens: typing.OrderedDict[core.Effect, Token]
def __init__(self, *args, **kwargs):
self._tokens = collections.OrderedDict(*args, **kwargs)
def __len__(self):
return len(self._tokens)
def get(self, effect: core.Effect) -> Token:
return self._tokens[effect]
@classmethod
def create(cls, effects: Sequence[core.Effect]) -> TokenSet:
"""Creates a `TokenSet` corresponding to a list of `core.Effect`s."""
tokens = [create_token() for _ in effects]
return TokenSet(zip(effects, tokens))
def items(self) -> Sequence[Tuple[core.Effect, Token]]:
return tuple(self._tokens.items())
def effects(self) -> set[core.Effect]:
return set(self._tokens.keys())
def subset(self, effects: Sequence[core.Effect]) -> TokenSet:
"""Return a subset of the `TokenSet` restricted to a set of `core.Effect`s."""
return TokenSet((eff, self._tokens[eff]) for eff in effects)
def update_tokens(self, tokens: TokenSet) -> TokenSet:
"""Returns a new `TokenSet` with tokens replaced with ones from the input `TokenSet`."""
new_tokens = []
for eff in self.effects():
if eff in tokens._tokens:
new_tokens.append((eff, tokens._tokens[eff]))
else:
new_tokens.append((eff, self._tokens[eff]))
return TokenSet(new_tokens)
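# Usage sketch (`foo_eff` and `bar_eff` are hypothetical core.Effect values):
#   tokens_in = TokenSet.create([foo_eff, bar_eff])  # fresh hlo tokens
#   sub = tokens_in.subset([foo_eff])                # thread into a rule
#   tokens_out = tokens_in.update_tokens(sub)        # merge updated tokens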
def dummy_token_type() -> Sequence[ir.Type]:
return aval_to_ir_types(core.ShapedArray((0,), np.bool_))
def dummy_token() -> Sequence[ir.Value]:
return ir_constants(np.zeros(0, np.bool_))
def lower_jaxpr_to_fun(
ctx: ModuleContext,
name: str,
jaxpr: core.ClosedJaxpr,
effects: Sequence[core.Effect],
*,
create_tokens: bool = False,
public: bool = False,
replace_tokens_with_dummy: bool = False,
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[xc.HloSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.HloSharding]]] = None,
use_sharding_annotations: bool = True,
input_output_aliases: Optional[Sequence[Optional[int]]] = None,
num_output_tokens: int = 0,
api_name: str = "jit",
arg_names: Optional[Sequence[Optional[str]]] = None,
result_names: Optional[Sequence[Optional[str]]] = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
Assumes that an MLIR context, location, and insertion point are set.
Args:
ctx: the lowering context.
name: the function name. The name will be uniquified by the symbol table,
so it is ok to use the same name multiple times.
jaxpr: the jaxpr to lower.
effects: a sequence of `core.Effect`s corresponding to an ordering of tokens
that will be created in or used by the lowered function.
create_tokens: if true, the HLO will create tokens and ignore dummy input tokens.
public: if true, the function's visibility is set to "public".
replace_tokens_with_dummy: if true, token arguments/return values are
replaced with bool arrays of size [0].
replicated_args: if present, annotates arguments as replicated.
arg_shardings: sharding annotations for each argument (optional).
result_shardings: sharding annotations for each result (optional).
use_sharding_annotations: if True, use "mhlo.sharding" annotations on
parameters and return values to express sharding. If False, use
hlo.custom_call operators with sharding annotations.
TODO(b/228598865): remove this option when "mhlo.sharding" annotations are
propagated on non-entry functions during MLIR->HLO conversion.
input_output_aliases: optional sequence that maps argument numbers to the
corresponding output that should alias them.
    num_output_tokens: the number of dummy output tokens to add to the results
      when `create_tokens` is true.
    api_name: The name of the higher level primitive which should show up in
      the name stack.
    arg_names: optional argument names, attached as "jax.arg_info" attributes.
    result_names: optional result names, attached as "jax.result_info"
      attributes.
Returns:
MLIR func op
"""
def aval_to_types(aval):
if replace_tokens_with_dummy and aval is core.abstract_token:
aval = core.ShapedArray((), np.dtype(np.bool_))
return aval_to_ir_types(aval)
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
dim_var_types = map(aval_to_types, dim_var_avals)
# Function inputs: *dim_var_values, *tokens, *actual_inputs
input_types = map(aval_to_types, jaxpr.in_avals)
output_types = map(aval_to_types, jaxpr.out_avals)
num_tokens = len(effects)
if create_tokens:
# If we create the tokens they won't be inputs to the MLIR function.
token_types = [dummy_token_type() for _ in effects]
output_token_types = [dummy_token_type() for _ in range(num_output_tokens)]
else:
# If we aren't creating tokens they will be the initial inputs to the
# MLIR function.
output_token_types = []
token_types = [token_type() for _ in effects]
token_avals = [core.AbstractToken] * num_tokens
input_avals = dim_var_avals + token_avals + jaxpr.in_avals
input_types = [*dim_var_types, *token_types, *input_types]
output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
output_types = [*output_token_types, *token_types, *output_types]
if input_output_aliases is not None:
token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
input_output_aliases = [*token_input_output_aliases, *input_output_aliases]
# Update the existing aliases to account for the new output values
input_output_aliases = [None if a is None
else a + num_output_tokens + num_tokens
for a in input_output_aliases]
if arg_shardings is not None:
token_shardings = [None] * (num_dim_vars + num_tokens)
arg_shardings = [*token_shardings, *arg_shardings]
if result_shardings is not None:
token_shardings = [None] * (num_tokens + num_output_tokens)
result_shardings = [*token_shardings, *result_shardings]
if replicated_args is not None:
token_replicated_args = [False] * (num_dim_vars + num_tokens)
replicated_args = [*token_replicated_args, *replicated_args]
flat_input_types = util.flatten(input_types)
flat_output_types = util.flatten(output_types)
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip)
func_op.attributes["sym_visibility"] = ir.StringAttr.get(
"public" if public else "private")
ctx.symbol_table.insert(func_op)
ir_arg_shardings = None
if arg_shardings is not None:
in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
ir_arg_shardings = util.flatten(
[[_to_physical_op_sharding(a, s)] * len(types)
for a, s, types in zip(in_avals, arg_shardings, input_types)])
del in_avals
ir_result_shardings = None
if result_shardings is not None:
out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
ir_result_shardings = util.flatten(
[[_to_physical_op_sharding(a, s)] * len(types)
for a, s, types in zip(out_avals, result_shardings, output_types)])
del out_avals
if (
replicated_args is not None
or ir_arg_shardings is not None
or input_output_aliases is not None
or arg_names is not None
or num_tokens > 0
):
arg_attrs: List[Dict[str, ir.Attribute]] = [
{} for _ in range(len(flat_input_types))]
if replicated_args is not None:
replicated_ir_args = [[replicated] * len(types) for replicated, types
in zip(replicated_args, input_types)]
for attrs, replicated in zip(arg_attrs, util.flatten(replicated_ir_args)):
if replicated:
attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()
if use_sharding_annotations and ir_arg_shardings is not None:
for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
if sharding is not None:
attrs["mhlo.sharding"] = get_sharding_attr(sharding)
if input_output_aliases is not None:
output_ids = util.unflatten(list(range(len(flat_output_types))),
map(len, output_types))
aliases: List[Optional[int]] = []
for types, alias in zip(input_types, input_output_aliases):
if alias is None:
aliases.extend([None] * len(types))
else:
aliases.extend(output_ids[alias])
for attrs, alias in zip(arg_attrs, aliases):
if alias is not None:
attrs["tf.aliasing_output"] = i32_attr(alias)
if num_tokens > 0:
      token_arg_attrs = arg_attrs[num_dim_vars:num_dim_vars + num_tokens]
for attrs in token_arg_attrs:
attrs["jax.token"] = ir.BoolAttr.get(True)
if arg_names:
named_arg_attrs = arg_attrs[num_dim_vars + num_tokens:]
for attrs, name_ in zip(named_arg_attrs, arg_names):
if name_:
attrs['jax.arg_info'] = ir.StringAttr.get(name_)
func_op.arg_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in arg_attrs])
result_attrs: List[Dict[str, ir.Attribute]] = [
{} for _ in range(len(flat_output_types))]
if num_tokens > 0:
token_result_attrs = result_attrs[:num_tokens]
for attrs in token_result_attrs:
attrs["jax.token"] = ir.BoolAttr.get(True)
if result_names:
named_result_attrs = result_attrs[num_tokens:]
if len(named_result_attrs) == len(result_names):
for attrs, name_ in zip(named_result_attrs, result_names):
attrs['jax.result_info'] = ir.StringAttr.get(name_)
if use_sharding_annotations and ir_result_shardings is not None:
for attrs, sharding in zip(result_attrs, ir_result_shardings):
if sharding is not None:
attrs['mhlo.sharding'] = get_sharding_attr(sharding)
func_op.result_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in result_attrs])
entry_block = func_op.add_entry_block()
with ir.InsertionPoint(entry_block):
flat_args = entry_block.arguments
# We separate out the dimension variable inputs, the token inputs and
# the regular inputs. The dimension variables and token inputs
# will be passed to `jaxpr_subcomp` separately from the `args`.
dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens])
# A lowering context just for function body entry/exit code.
entry_lowering_ctx = LoweringRuleContext(
ctx, None, [], None, TokenSet.create([]), None, None, dim_var_values)
if not use_sharding_annotations and ir_arg_shardings is not None:
flat_args = [
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]
_, token_args, unflattened_args = util.split_list(util.unflatten(flat_args, map(len, input_types)),
[num_dim_vars, num_tokens])
if create_tokens:
tokens_in = TokenSet.create(effects)