-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
lowering.py
1691 lines (1394 loc) · 57.1 KB
/
lowering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for lowering JAX to Mosaic-compatible MLIR dialects."""
from __future__ import annotations
import dataclasses
import functools
from typing import Any, Callable, Sequence
from jax import core as jax_core
from jax import lax
from jax import tree_util
from jax._src import custom_derivatives
from jax._src import debugging
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import pjit
from jax._src import source_info_util
from jax._src import state
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import for_loop
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import math
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax._src.pallas import core
from jax._src.pallas import indexing
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as state_primitives
from jax._src.util import safe_map
from jax._src.util import safe_zip
from jax._src.util import split_list
from jax._src.util import unzip2
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
import numpy as np
# TODO(sharadmv): enable type checking
# mypy: ignore-errors

# Short module-level aliases used throughout the lowering rules below.
NDIndexer = indexing.NDIndexer
TPUMemorySpace = tpu_core.TPUMemorySpace
VMEM = tpu_core.TPUMemorySpace.VMEM
SMEM = tpu_core.TPUMemorySpace.SMEM
partial = functools.partial

# Shadow the builtins with the "safe" variants from jax._src.util (assumed to
# check that all inputs have equal length -- see that module). The originals
# remain reachable as unsafe_map / unsafe_zip. Code in this module relies on
# `map` returning a list (e.g. `arg_types.append` in lower_jaxpr_to_func).
map, unsafe_map = safe_map, map  # pylint: disable=redefined-builtin
zip, unsafe_zip = safe_zip, zip  # pylint: disable=redefined-builtin
@dataclasses.dataclass
class MeshContext:
  """Lowering state for kernels that use named mesh axes.

  Built in ``lower_jaxpr_to_func`` from the ``mesh_info`` tuple produced by
  ``lower_jaxpr_to_module``.
  """
  # Kernel argument holding the logical-device -> mesh-coordinate table
  # (one row per device, one column per entry of ``axis_names``).
  logical_to_mesh: ir.Value
  # Mesh axis names used by the kernel, in table-column order.
  axis_names: tuple[str, ...]
  # Strides (from pallas_utils.strides_from_shape over the axis sizes) used to
  # linearize mesh coordinates -- presumably row-major; confirm in pallas_utils.
  mesh_strides: tuple[int, ...]
@dataclasses.dataclass
class LoweringContext:
  """State shared by every equation lowered within one MLIR function body.

  Attributes:
    ir_context: The MLIR context being lowered into.
    grid_mapping: The kernel's grid mapping, or ``None`` for functions lowered
      without one (e.g. the ``transform_*`` index-map functions).
    grid_indices: MLIR values of the non-mapped grid indices, or ``None``.
    block_shapes: Block shape of each input of the jaxpr being lowered
      (``core.mapped`` marks a dimension that is squeezed to size 1).
    name_stack: Name stack prefixed onto each equation's source info.
    mesh_context: Named-mesh-axis lowering state, or ``None`` when the kernel
      does not use named axes.
  """
  ir_context: ir.Context
  grid_mapping: core.GridMapping | None
  grid_indices: Sequence[ir.Value] | None
  block_shapes: list[tuple[int | core.Mapped, ...]]
  name_stack: source_info_util.NameStack
  # Defaults to None so callers that do not deal with meshes may omit it (the
  # no-grid path of lower_jaxpr_to_func constructs the context without it).
  mesh_context: MeshContext | None = None

  replace = dataclasses.replace
@dataclasses.dataclass
class LoweringRuleContext:
  """Per-equation context passed as the first argument to every lowering rule."""
  # Context of the function body the equation is lowered into.
  lowering_context: LoweringContext
  # Abstract values of the equation's inputs and outputs, in operand order.
  avals_in: Sequence[jax_core.AbstractValue]
  avals_out: Sequence[jax_core.AbstractValue]
  # Block shape per input (None for literals and for values whose block shape
  # is not tracked, i.e. anything other than the jaxpr's own inputs).
  block_shapes: list[tuple[int | core.Mapped, ...]] | None
  replace = dataclasses.replace
def _memory_space_to_tpu_memspace(memory_space: TPUMemorySpace | None
                                  ) -> ir.Attribute:
  """Returns the ``#tpu.memory_space`` attribute for ``memory_space``.

  A ``None`` memory space is treated as VMEM.
  """
  space = VMEM if memory_space is None else memory_space
  return ir.Attribute.parse(f"#tpu.memory_space<{space}>")
def aval_to_ir_type(aval, shape=None, memory_space: TPUMemorySpace | None = None):
  """Maps a JAX abstract value to the MLIR type used by the Mosaic lowering.

  * Semaphores become ``!tpu.dma_semaphore`` / ``!tpu.semaphore``.
  * Refs become memrefs in ``memory_space`` (VMEM when ``None``).
  * Arrays become vectors, or a bare element type when the shape is empty.

  Args:
    aval: The abstract value to convert.
    shape: Optional shape overriding ``aval.shape`` (refs and arrays only).
    memory_space: Memory space for refs.

  Raises:
    NotImplementedError: For unsupported semaphore kinds or aval types.
  """
  if isinstance(aval, tpu_core.AbstractSemaphore):
    if aval.sem_type == tpu_core.SemaphoreType.DMA:
      return ir.Type.parse("!tpu.dma_semaphore")
    if aval.sem_type == tpu_core.SemaphoreType.REGULAR:
      return ir.Type.parse("!tpu.semaphore")
    raise NotImplementedError(aval.sem_type)

  if isinstance(aval, state.AbstractRef):
    dims = aval.shape if shape is None else shape
    memspace_attr = _memory_space_to_tpu_memspace(memory_space)
    element_type = mlir.dtype_to_ir_type(aval.dtype)
    return ir.MemRefType.get(dims, element_type, memory_space=memspace_attr)

  if isinstance(aval, jax_core.ShapedArray):
    dims = aval.shape if shape is None else shape
    element_type = mlir.dtype_to_ir_type(aval.dtype)
    # Rank-0 values are plain scalars, not 0-d vectors.
    return ir.VectorType.get(dims, element_type) if dims else element_type

  raise NotImplementedError(aval)
def ir_constant(x, mlir_type=None):
  """Materializes a scalar as the result of a new ``arith.constant`` op.

  Args:
    x: A Python ``int``/``float``/``bool`` or a NumPy scalar / 0-d array.
      Python ints become int32 and Python floats become float32 (Python bools,
      being ints, become int32 as well).
    mlir_type: Optional MLIR type for the constant; defaults to the MLIR type
      of ``x``'s dtype.

  Returns:
    The ``ir.Value`` produced by the new ``arith.constant`` op.

  Raises:
    NotImplementedError: If ``x`` is of an unsupported Python type or dtype.
  """
  if not hasattr(x, "dtype"):
    if isinstance(x, int):
      x = np.array(x, np.int32)
    elif isinstance(x, float):
      x = np.array(x, np.float32)
    else:
      raise NotImplementedError(f"Unsupported constant of type {type(x)}")
  if not mlir_type:
    mlir_type = mlir.dtype_to_ir_type(x.dtype)
  dtype = x.dtype
  if dtype == jnp.bool_:
    return arith.ConstantOp(mlir_type, ir.BoolAttr.get(bool(x))).result
  if np.issubdtype(dtype, np.integer):
    # Any width and signedness, not only (u)int32.
    return arith.ConstantOp(
        mlir_type, ir.IntegerAttr.get(mlir_type, int(x))
    ).result
  if np.issubdtype(dtype, np.floating) or dtype == jnp.bfloat16:
    # bfloat16 is not a NumPy floating subtype, so it is checked explicitly.
    return arith.ConstantOp(
        mlir_type, ir.FloatAttr.get(mlir_type, float(x))
    ).result
  raise NotImplementedError(x.dtype)
# Registry mapping each JAX primitive to its Mosaic lowering rule.
lowering_rules = dict()
# Primitives whose rules receive literal operands unconverted: jaxpr_subcomp
# skips _ensure_mlir_value for their inputs.
skip_mlir_conversions = set()
def lower_jaxpr_to_module(
    ctx: ir.Context,
    grid_mapping: core.GridMapping,
    jaxpr: jax_core.Jaxpr,
    dimension_semantics: tuple[str | None, ...] | None,
    mesh: mesh_lib.Mesh | None = None
) -> tuple[ir.Module, tuple[Any, ...]]:
  """Lowers a Pallas TPU kernel jaxpr to a Mosaic MLIR module.

  The module holds a ``main`` function for the kernel body and, when a grid is
  present, one ``transform_<i>`` function per block mapping (its index map).
  Grid and operand metadata are attached to ``main`` as attributes.

  Args:
    ctx: MLIR context to lower into.
    grid_mapping: Grid and block mappings of the kernel.
    jaxpr: The kernel body.
    dimension_semantics: Optional semantics (e.g. ``"parallel"``) for the grid
      dimensions not listed in ``grid_mapping.mapped_dims``, in order. A
      ``None`` entry (or a ``None`` tuple) means "arbitrary".
    mesh: Device mesh; required when the kernel uses named axes.

  Returns:
    ``(module, extra_args)``. ``extra_args`` is ``()`` or a 1-tuple holding the
    int32 logical -> mesh-coordinate table that must be passed to the kernel as
    an extra scalar-prefetch (SMEM) operand.

  Raises:
    ValueError: If named axes are used without a mesh, or
      ``dimension_semantics`` has too few entries.
    NotImplementedError: If a grid is given but a block mapping or index map
      is missing, or an index map has consts.
  """
  m = ir.Module.create()
  sym_tab = ir.SymbolTable(m.operation)
  used_axis_names = jax_core.used_axis_names_jaxpr(jaxpr)
  mesh_info = None
  if used_axis_names:
    if mesh is None:
      raise ValueError("Cannot use axis names in pallas_call without shard_map."
                       )
    # Order the axes as the mesh does. Iterating the set directly yields an
    # order that depends on (per-process randomized) string hashing, which
    # makes the lowering nondeterministic and can mislabel the table columns,
    # which np.ndindex below fills in mesh-axis order.
    axis_names = sorted(used_axis_names, key=mesh.axis_names.index)
    # We need mesh <-> logical translation tables. Since the logical IDs are
    # just linearized versions of the mesh IDs, we create those tables.
    mesh_strides = pallas_utils.strides_from_shape(tuple(
        mesh.shape[a] for a in axis_names
    ))
    logical_to_mesh = np.empty((mesh.size, len(axis_names)), dtype=np.int32)
    for i, idx in enumerate(np.ndindex(*mesh.device_ids.shape)):
      logical_to_mesh[i] = np.array(idx)
    mesh_info = (logical_to_mesh, tuple(axis_names), mesh_strides)
  extra_args = mesh_info[:1] if mesh_info else ()
  func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
                                name="main", mesh_info=mesh_info)
  m.body.append(func_op)
  sym_tab.insert(func_op)
  num_smem_inputs = grid_mapping.num_index_operands
  if mesh_info:
    # We are now passing in the logical -> mesh index mapping
    # TODO(sharadmv,apaszke): avoid stalling pipeline by marking the index
    # mapping as scalar prefetch and instead just mark it as an SMEM operand.
    num_smem_inputs += 1
  window_params = []
  grid = grid_mapping.grid
  if grid:
    for i, bm in enumerate(grid_mapping.block_mappings):
      # TODO(sharadmv): generate default block mapping if left as no_block_spec
      if bm is None:
        raise NotImplementedError("Please specify block mappings if "
                                  "grid is specified.")
      if bm.index_map_jaxpr is None:
        raise NotImplementedError("Please specify index_maps if "
                                  "grid is specified.")
      func_name = f"transform_{i}"
      if bm.index_map_jaxpr.consts:
        raise NotImplementedError("Index map jaxpr with consts not supported.")
      # Index maps take one scalar per grid dim plus the SMEM operands.
      mlir_func = lower_jaxpr_to_transform_func(
          ctx,
          bm.index_map_jaxpr.jaxpr,
          [*[None] * len(grid), *[SMEM] * num_smem_inputs],
          name=func_name)
      assert mlir_func.verify(), mlir_func
      block_shape = [
          1 if b is core.mapped else b for b in bm.block_shape
      ]
      window_shape = ir.DenseI64ArrayAttr.get(block_shape)
      window_params.append(
          ir.DictAttr.get(
              dict(
                  window_bounds=window_shape,
                  transform_indices=ir.FlatSymbolRefAttr.get(func_name),
              )
          )
      )
      m.body.append(mlir_func)
      sym_tab.insert(mlir_func)
    func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
    func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(grid)
  num_scratch_inputs = grid_mapping.num_scratch_operands
  func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
      ir.IntegerType.get_signless(64), num_smem_inputs)
  func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
      ir.IntegerType.get_signless(64), num_scratch_inputs)

  def _get_semantics(s: str | None) -> str:
    if s is None:
      return "#tpu.dimension_semantics<arbitrary>"
    return f"#tpu.dimension_semantics<{s}>"

  # Dimensions in grid_mapping.mapped_dims are always parallel; the others
  # consume the user-provided semantics in order (or default to arbitrary).
  user_semantics = (
      None if dimension_semantics is None else iter(dimension_semantics))
  func_dimension_semantics = []
  for i in range(len(grid_mapping.grid)):
    if i in grid_mapping.mapped_dims:
      semantics = "parallel"
    elif user_semantics is None:
      semantics = None
    else:
      try:
        semantics = next(user_semantics)
      except StopIteration:
        raise ValueError(
            "dimension_semantics has fewer entries than the number of "
            f"non-mapped grid dimensions (grid={grid_mapping.grid})."
        ) from None
    func_dimension_semantics.append(_get_semantics(semantics))
  func_op.attributes["dimension_semantics"] = ir.ArrayAttr.get(
      map(ir.Attribute.parse, func_dimension_semantics)
  )
  return m, extra_args
def lower_jaxpr_to_transform_func(
    ctx: ir.Context, jaxpr: jax_core.Jaxpr, memspaces: Sequence[Any],
    *, name: str) -> func.FuncOp:
  """Lowers an index-map jaxpr to a standalone ``func.FuncOp``.

  Every input is typed from its aval at full shape, in the matching entry of
  ``memspaces``. The body is lowered with no grid mapping, no grid indices and
  no mesh context.
  """
  in_avals = [invar.aval for invar in jaxpr.invars]
  block_shapes = [aval.shape for aval in in_avals]
  arg_types = list(map(aval_to_ir_type, in_avals, block_shapes, memspaces))
  lowering_context = LoweringContext(
      ctx, None, None, block_shapes, source_info_util.NameStack(),
      mesh_context=None)

  def body_func(*args):
    return jaxpr_subcomp(lowering_context, jaxpr, *args)

  body_func.__name__ = name
  body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
  body.func_op.verify()
  return body.func_op
def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
  """Turns a JAX-traceable ``fun`` into a lowering rule.

  The returned rule traces ``fun`` on the equation's input avals and lowers
  the resulting jaxpr inline via ``jaxpr_subcomp``. Traced constants are not
  supported.

  Args:
    fun: The function implementing the primitive.
    multiple_results: Whether ``fun`` returns a sequence of outputs rather
      than a single one.
  """

  def _single_as_tuple(*args, **kw):
    return (fun(*args, **kw),)

  traced_fun = fun if multiple_results else _single_as_tuple

  def f_lowered(ctx: LoweringRuleContext, *args, **params):
    wrapped_fun = lu.wrap_init(traced_fun, params)
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
    if consts:
      raise NotImplementedError
    jaxpr = pe.convert_constvars_jaxpr(jaxpr)
    sub_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes)
    results = jaxpr_subcomp(sub_context, jaxpr, *consts, *args)
    return results if multiple_results else results[0]

  return f_lowered
def lower_jaxpr_to_func(
    ctx: ir.Context,
    jaxpr: jax_core.Jaxpr,
    *,
    grid_mapping: core.GridMapping | None,
    name: str,
    mesh_info: Any,
) -> func.FuncOp:
  """Lowers a kernel jaxpr to a ``func.FuncOp`` named ``name``.

  Argument order of the generated function:
    1. one i32 scalar per grid dimension (only when ``grid_mapping`` is set);
    2. the logical -> mesh table in SMEM (only when ``mesh_info`` is set);
    3. one argument per ``jaxpr`` input: scalar-prefetch operands (SMEM),
       block-mapped operands, then scratch operands.

  Args:
    ctx: MLIR context to lower into.
    jaxpr: The kernel body.
    grid_mapping: Grid and block mappings, or ``None`` for no grid (every
      input then uses its full shape and the default memory space).
    name: Symbol name of the generated function.
    mesh_info: ``None`` or a ``(logical_to_mesh, axis_names, mesh_strides)``
      tuple as built by ``lower_jaxpr_to_module``.
  """
  num_invars = len(jaxpr.invars)
  if grid_mapping is None:
    arg_types = []
    block_mappings = [None] * num_invars
    memory_spaces = [None] * num_invars
  else:
    # One scalar int32 argument per grid dimension (the grid indices).
    arg_types = [
        aval_to_ir_type(jax_core.ShapedArray((), jnp.int32))
        for _ in grid_mapping.grid
    ]
    scalar_prefetch = grid_mapping.num_index_operands
    num_scratch = grid_mapping.num_scratch_operands
    block_mappings = [
        *[None] * scalar_prefetch,
        *grid_mapping.block_mappings,
        *[None] * num_scratch,
    ]
    memory_spaces = [
        *[SMEM] * scalar_prefetch,
        *[None if bm is None else bm.memory_space
          for bm in grid_mapping.block_mappings],
        *[None] * num_scratch,
    ]
  assert len(memory_spaces) == len(jaxpr.invars), (
      "Must have as many memory spaces as inputs and outputs.")

  def _get_arg_type(aval, block_mapping: core.BlockMapping | None,
                    memory_space: tpu_core.TPUMemorySpace | None):
    # Returns (MLIR argument type, block shape used for the argument).
    if isinstance(aval, tpu_core.AbstractMemoryRef):
      assert memory_space is None
      memory_space = aval.memory_space
    if isinstance(aval, tpu_core.AbstractSemaphore):
      return aval_to_ir_type(aval), None
    if block_mapping is None:
      return aval_to_ir_type(aval, memory_space=memory_space), aval.shape
    shape = tuple(
        1 if b is core.mapped else b for b in block_mapping.block_shape
    )
    return (aval_to_ir_type(aval, shape=shape, memory_space=memory_space),
            block_mapping.block_shape)

  if mesh_info is not None:
    l_to_m, axis_names, mesh_strides = mesh_info
    l_to_m_aval = state.AbstractRef(
        jax_core.raise_to_shaped(jax_core.get_aval(l_to_m)))
    arg_types.append(_get_arg_type(l_to_m_aval, None, SMEM)[0])

  invar_arg_types, block_shapes = unzip2(
      map(_get_arg_type, [invar.aval for invar in jaxpr.invars], block_mappings,
          memory_spaces)
  )
  arg_types = [*arg_types, *invar_arg_types]

  def body_func(*args):
    # Peel off the leading grid-index and mesh-table arguments (if any) so
    # that the remaining arguments line up with jaxpr.invars.
    if grid_mapping is not None:
      grid_indices, args = split_list(args, [len(grid_mapping.grid)])
      grid_indices = tuple(
          g
          for i, g in enumerate(grid_indices)
          if i not in grid_mapping.mapped_dims
      )
    else:
      grid_indices = None
    if mesh_info is not None:
      (l_to_m,), args = split_list(args, [1])
      mesh_context = MeshContext(l_to_m, axis_names, mesh_strides)
    else:
      mesh_context = None
    lowering_context = LoweringContext(
        ctx,
        grid_mapping,
        grid_indices,
        block_shapes,
        source_info_util.NameStack(),
        mesh_context=mesh_context,
    )
    return jaxpr_subcomp(lowering_context, jaxpr, *args)

  body_func.__name__ = name
  body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
  body.func_op.verify()
  return body.func_op
class LoweringException(Exception):
  """Raised when a lowering rule fails.

  ``jaxpr_subcomp`` wraps the innermost failing rule's exception in this type,
  adding the equation, rule context and operand shapes/types to the message.
  """
def jaxpr_subcomp(
    ctx: LoweringContext, jaxpr: jax_core.Jaxpr, *args: ir.Value
) -> Sequence[ir.Value]:
  """Lowers the equations of ``jaxpr`` at the current MLIR insertion point.

  Args:
    ctx: Lowering context; ``ctx.block_shapes`` supplies one block shape per
      ``jaxpr`` input.
    jaxpr: A jaxpr without constvars.
    *args: MLIR values bound to ``jaxpr.invars``.

  Returns:
    MLIR values for ``jaxpr.outvars``; literal outputs are materialized with
    ``ir_constant``.

  Raises:
    NotImplementedError: If a primitive has no registered lowering rule.
    LoweringException: If a lowering rule raises; the rule's exception is
      chained as the cause (nested LoweringExceptions are re-raised as is).
  """
  assert not jaxpr.constvars
  env = {}
  block_shape_env = {}

  def read_block_shape(atom: jax_core.Atom):
    # Block shapes are only recorded for the jaxpr's inputs; literals and
    # intermediate values yield None.
    if isinstance(atom, jax_core.Literal):
      return None
    return block_shape_env.get(atom, None)

  def read_env(atom: jax_core.Atom):
    # Literals are returned as their raw Python/NumPy value.
    return atom.val if isinstance(atom, jax_core.Literal) else env[atom]

  def write_env(var: jax_core.Var, val):
    assert isinstance(val, ir.Value), type(val)
    env[var] = val

  for invar, bs in zip(jaxpr.invars, ctx.block_shapes):
    block_shape_env[invar] = bs
  map(write_env, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    source_info = eqn.source_info.replace(
        name_stack=ctx.name_stack + eqn.source_info.name_stack
    )
    loc = mlir._source_info_to_location(
        eqn.primitive, eqn.params, source_info, ctx.name_stack
    )
    with source_info_util.user_context(eqn.source_info.traceback), loc:
      if eqn.primitive in lowering_rules:
        # Rules opt out of literal -> MLIR constant conversion by being
        # listed in skip_mlir_conversions.
        if eqn.primitive not in skip_mlir_conversions:
          invals = [_ensure_mlir_value(x, v.aval)
                    for x, v in zip(invals, eqn.invars)]
        block_shapes = map(read_block_shape, eqn.invars)
        rule_context = LoweringRuleContext(
            ctx,
            [v.aval for v in eqn.invars],
            [v.aval for v in eqn.outvars],
            block_shapes,
        )
        try:
          ans = lowering_rules[eqn.primitive](
              rule_context, *invals, **eqn.params
          )
        except LoweringException:
          raise  # We only add the extra info to the innermost exception.
        except Exception as e:
          raise LoweringException(
              f"Exception while lowering eqn:\n {eqn}\nWith context:\n "
              f" {rule_context}\nWith inval"
              f" shapes={map(lambda t: getattr(t, 'shape', None), invals)}\nWith"
              " inval"
              f" types={map(lambda t: getattr(t, 'type', None), invals)}\nIn"
              f" jaxpr:\n{jaxpr}"
          ) from e
      else:
        raise NotImplementedError(
            "Unimplemented primitive in Pallas TPU lowering: "
            f"{eqn.primitive.name}. "
            "Please file an issue on https://github.com/google/jax/issues.")
      if eqn.primitive.multiple_results:
        map(write_env, eqn.outvars, ans)
      else:
        write_env(eqn.outvars[0], ans)
  outvals = map(read_env, jaxpr.outvars)
  outvals = [
      ir_constant(x) if isinstance(var, jax_core.Literal) else x
      for x, var in zip(outvals, jaxpr.outvars)
  ]
  return outvals
def _ensure_mlir_value(val, aval):
  """Returns ``val`` as an ``ir.Value``, materializing constants if needed.

  NumPy and Python numeric constants become ``arith.constant`` ops typed after
  ``aval.dtype``; any other non-``ir.Value`` input is rejected.
  """
  if isinstance(val, ir.Value):
    return val
  if not isinstance(val, (np.generic, np.ndarray, int, float)):
    raise RuntimeError(
        f"Unsupported argument to a JAX primitive of type: {type(val)}"
    )
  return ir_constant(val, mlir.dtype_to_ir_type(aval.dtype))
def _convert_flat_indexing_to_indexer(ref_aval, non_slice_idx,
                                      non_slice_idx_avals, indexed_dims):
  """Builds matching ``NDIndexer``s over index values and over their avals.

  ``indexed_dims[d]`` tells whether ref dimension ``d`` takes the next entry of
  ``non_slice_idx`` (paired with ``non_slice_idx_avals``). Every other
  dimension gets the full slice ``Slice(0, ref_aval.shape[d])``.

  Returns:
    ``(indexer, indexer_avals)`` with identical structure.
  """
  pending = iter(zip(non_slice_idx, non_slice_idx_avals))

  def _dim_entry(size, indexed):
    # (value, aval) pair for one ref dimension.
    if indexed:
      return next(pending)
    return primitives.Slice(0, size), primitives.Slice(0, size)

  entries = tuple(
      _dim_entry(size, indexed)
      for size, indexed in zip(ref_aval.shape, indexed_dims)
  )
  idx, idx_avals = unzip2(entries)
  int_indexer_shape = ()
  if non_slice_idx:
    # All integer indexers must share one shape.
    (int_indexer_shape,) = {
        aval.shape for aval in idx_avals
        if not isinstance(aval, primitives.Slice)
    }
  indexer = NDIndexer(idx, ref_aval.shape, int_indexer_shape)
  indexer_avals = NDIndexer(idx_avals, ref_aval.shape, int_indexer_shape)
  return indexer, indexer_avals
def _get_lowering_rule(
    ctx: LoweringRuleContext, ref, *non_slice_idx, indexed_dims: Sequence[bool]
):
  """Lowers ``state.get`` by delegating to the more general load rule.

  The flat integer indices are packed into an ``NDIndexer`` (full slices on the
  non-indexed dimensions). Operands and avals are then re-flattened in the
  ``(ref, indexer, mask=None, None)`` layout ``_load_lowering_rule`` expects.
  """
  ref_aval, *idx_avals = ctx.avals_in
  indexer, indexer_avals = _convert_flat_indexing_to_indexer(
      ref_aval, non_slice_idx, idx_avals, indexed_dims)
  flat_args, tree = tree_util.tree_flatten((ref, indexer, None, None))
  flat_avals = tree_util.tree_leaves((ref_aval, indexer_avals, None, None))
  return _load_lowering_rule(
      ctx.replace(avals_in=flat_avals), *flat_args, args_tree=tree)


lowering_rules[state_primitives.get_p] = _get_lowering_rule
skip_mlir_conversions.add(state_primitives.get_p)
def _swap_lowering_rule(
    ctx: LoweringRuleContext,
    ref,
    val,
    *non_slice_idx,
    indexed_dims: Sequence[bool],
):
  """Lowers ``state.swap`` by delegating to the more general masked-swap rule.

  Operands and avals are re-flattened in the ``(ref, indexer, val, mask=None)``
  layout ``_masked_swap_lowering_rule`` expects.
  """
  ref_aval, val_aval, *idx_avals = ctx.avals_in
  indexer, indexer_avals = _convert_flat_indexing_to_indexer(
      ref_aval, non_slice_idx, idx_avals, indexed_dims)
  flat_args, tree = tree_util.tree_flatten((ref, indexer, val, None))
  flat_avals = tree_util.tree_leaves(
      (ref_aval, indexer_avals, val_aval, None))
  return _masked_swap_lowering_rule(
      ctx.replace(avals_in=flat_avals), *flat_args, args_tree=tree)


lowering_rules[state_primitives.swap_p] = _swap_lowering_rule
skip_mlir_conversions.add(state_primitives.swap_p)
def _make_index(s):
  """Returns ``s`` as an MLIR value of ``index`` type.

  Python ints, NumPy integer scalars and NumPy arrays become ``arith.constant``
  ops. MLIR values are returned unchanged if already ``index``-typed and are
  ``arith.index_cast`` otherwise.
  """
  # NumPy integer scalars (e.g. literal index values) have no `.type`
  # attribute, so they must be materialized here as well.
  if isinstance(s, (int, np.integer, np.ndarray)):
    return ir_constant(s, ir.IndexType.get())
  if s.type == ir.IndexType.get():
    return s
  return arith.IndexCastOp(ir.IndexType.get(), s).result
def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
  """Lowers ``load_p`` (also used for ``state.get`` via _get_lowering_rule).

  ``args_flat`` unflattens through ``args_tree`` into ``(ref, indexer, mask,
  _)``. SMEM refs are read with ``memref.load`` and must yield a scalar. Other
  refs are read with ``vector.load`` of a slice shaped like the ref's block
  (size 1 on integer-indexed and mapped dimensions) and then shape-cast to
  the output shape when that differs.

  Raises:
    NotImplementedError: If a mask is given, or a non-SMEM ref has an empty
      block shape.
    ValueError: If an integer index is not a scalar, or an SMEM load would
      produce a non-scalar.
  """
  ref, idx, mask, _ = args_tree.unflatten(args_flat)
  _, idx_aval, _, _ = args_tree.unflatten(ctx.avals_in)
  if mask is not None:
    raise NotImplementedError
  ref_type = ir.MemRefType(ref.type)
  is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
  ref_aval, *_ = ctx.avals_in
  (aval_out,) = ctx.avals_out
  ref_block_shape, *_ = ctx.block_shapes
  if not is_smem_load and not ref_block_shape:
    raise NotImplementedError(
        "Indexing into a ()-shaped Ref not yet supported on TPU.")
  # Only scalar integer indices (plus slices) are supported.
  if any(
      (not isinstance(a, primitives.Slice) and a.shape)
      for a in idx_aval.indices
  ):
    raise ValueError("Cannot do int indexing on TPU")
  # Each dimension's start offset: the slice start, or the integer index.
  starts = tuple(
      i.start if isinstance(i, primitives.Slice) else i for i in idx.indices
  )
  # NOTE(review): `starts` already replaced slices by their start, so the
  # isinstance check below appears to always be False.
  mlir_indices = [
      s if isinstance(s, primitives.Slice) else _make_index(s) for s in starts
  ]
  # Need to now insert indexing the 0-th element for mapped dimensions
  idx_iter = iter(mlir_indices)
  mlir_indices = [
      _make_index(0) if b is core.mapped else next(idx_iter)
      for b in ref_block_shape
  ]
  assert len(mlir_indices) == len(ref_block_shape)
  # Shape actually read from memory: the output shape with size-1 dims
  # re-inserted for integer-indexed and mapped dimensions.
  load_shape = list(aval_out.shape)
  for i, a in enumerate(idx_aval.indices):
    if not isinstance(a, primitives.Slice):
      load_shape.insert(i, 1)
  assert len(load_shape) == len(ref_aval.shape)
  load_shape_iter = iter(load_shape)
  load_shape = [
      1 if b is core.mapped else next(load_shape_iter) for b in ref_block_shape
  ]
  load_aval = aval_out.update(shape=tuple(load_shape))
  if is_smem_load:
    if ctx.avals_out[0].shape:
      raise ValueError("Can only load scalars from SMEM")
    return memref.LoadOp(ref, mlir_indices).result
  else:
    load_val = vector.LoadOp(aval_to_ir_type(load_aval), ref, mlir_indices).result
  if load_aval == aval_out:
    return load_val
  # Drop the size-1 dimensions introduced above.
  vec_type = ir.VectorType.get(aval_out.shape,
                               mlir.dtype_to_ir_type(aval_out.dtype))
  return vector.ShapeCastOp(vec_type, load_val).result


lowering_rules[primitives.load_p] = _load_lowering_rule
skip_mlir_conversions.add(primitives.load_p)
def _masked_swap_lowering_rule(
    ctx: LoweringRuleContext, *args_flat, args_tree, **_
):
  """Lowers ``swap_p`` (also used for ``state.swap`` via _swap_lowering_rule).

  Stores ``val`` into ``ref`` at the indexed location and returns the value
  previously stored there. SMEM refs use ``memref.load``/``memref.store`` and
  only accept scalars. Other refs use ``vector.load``/``vector.store`` on a
  slice shaped like the ref's block, with size-1 integer-indexed and mapped
  dimensions.

  Raises:
    NotImplementedError: If a mask is given, or a non-SMEM ref has an empty
      block shape.
    ValueError: If an integer index is not a scalar, or a non-scalar value is
      stored to SMEM.
  """
  ref, idx, val, mask = args_tree.unflatten(args_flat)
  # Unflatten the avals with the same tree so each aval lines up with its
  # operand. The value's aval cannot be taken positionally from the flat list,
  # because any indexer leaves sit between the ref and the value.
  _, idx_aval, val_aval, _ = args_tree.unflatten(ctx.avals_in)
  if mask is not None:
    raise NotImplementedError
  ref_type = ir.MemRefType(ref.type)
  is_smem_store = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
  ref_block_shape, *_ = ctx.block_shapes
  (aval_out,) = ctx.avals_out
  if not isinstance(val, ir.Value):
    val = ir_constant(val, mlir_type=mlir.dtype_to_ir_type(val_aval.dtype))
  # Only scalar integer indices (plus slices) are supported.
  if any(
      (not isinstance(a, primitives.Slice) and a.shape)
      for a in idx_aval.indices
  ):
    raise ValueError("Cannot do int indexing on TPU")
  if not is_smem_store and not ref_block_shape:
    raise NotImplementedError(
        "Indexing into a ()-shaped Ref not yet supported on TPU.")
  starts = tuple(
      i.start if isinstance(i, primitives.Slice) else i for i in idx.indices
  )
  mlir_indices = [
      s if isinstance(s, primitives.Slice) else _make_index(s) for s in starts
  ]
  # Need to now insert indexing the 0-th element for mapped dimensions
  idx_iter = iter(mlir_indices)
  mlir_indices = [
      _make_index(0) if b is core.mapped else next(idx_iter)
      for b in ref_block_shape
  ]
  assert len(mlir_indices) == len(ref_block_shape)
  if is_smem_store:
    if val_aval.shape:
      raise ValueError("Can only store scalars to SMEM")
    result = memref.LoadOp(ref, mlir_indices).result
    memref.StoreOp(val, ref, mlir_indices)
    return result
  # Shape of the memory slice: the output shape with size-1 dims re-inserted
  # for integer-indexed and mapped dimensions.
  mem_slice_shape = list(aval_out.shape)
  for i, a in enumerate(idx_aval.indices):
    if not isinstance(a, primitives.Slice):
      mem_slice_shape.insert(i, 1)
  mem_slice_shape_iter = iter(mem_slice_shape)
  mem_slice_shape = [
      1 if b is core.mapped else next(mem_slice_shape_iter)
      for b in ref_block_shape
  ]
  mem_aval = aval_out.update(shape=tuple(mem_slice_shape))
  mem_aval_vec_type = ir.VectorType.get(mem_aval.shape,
                                        mlir.dtype_to_ir_type(mem_aval.dtype))
  result = vector.LoadOp(mem_aval_vec_type, ref, mlir_indices).result
  if mem_aval != aval_out:
    # We are slicing a scalar so provided dummy 1 indices
    result_vec_type = ir.VectorType.get(aval_out.shape,
                                        mlir.dtype_to_ir_type(aval_out.dtype))
    result = vector.ShapeCastOp(result_vec_type, result).result
    val_vec_type = ir.VectorType.get(mem_aval.shape,
                                     mlir.dtype_to_ir_type(mem_aval.dtype))
    val = vector.ShapeCastOp(val_vec_type, val).result
  vector.StoreOp(val, ref, mlir_indices)
  return result


lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule
skip_mlir_conversions.add(primitives.swap_p)
def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values):
  """Lowers ``multiple_of`` by forwarding its operand unchanged.

  The divisibility hint emits no IR in this lowering.
  """
  del values, ctx  # Unused: the hint has no runtime effect here.
  return val


lowering_rules[primitives.multiple_of_p] = _multiple_of_lowering_rule
def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
  """Lowers ``lax.reduce_max`` to ``vector.multi_reduction``.

  Only floating-point inputs are supported; the accumulator starts at -inf.

  Raises:
    NotImplementedError: For non-floating dtypes.
  """
  (x_aval,) = ctx.avals_in
  (out_aval,) = ctx.avals_out
  if not jnp.issubdtype(x_aval.dtype, jnp.floating):
    # Integer (maxsi/maxui) and other reductions are not implemented yet.
    raise NotImplementedError(x_aval.dtype)
  out_type = aval_to_ir_type(out_aval)
  kind = ir.Attribute.parse("#vector.kind<maxf>")
  # The splat element's type must equal the vector's element type exactly
  # (e.g. bf16), so it is derived from the output dtype rather than fixed f32.
  # NOTE(review): a full reduction yields a scalar out_type, which get_splat
  # does not accept -- confirm whether that case needs handling.
  val = ir.FloatAttr.get(mlir.dtype_to_ir_type(out_aval.dtype), float("-inf"))
  identity = ir.DenseElementsAttr.get_splat(out_type, val)
  acc = arith.ConstantOp(out_type, identity)
  op = vector.MultiDimReductionOp(
      kind,
      x,
      acc,
      ir.ArrayAttr.get(
          [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes]
      ),
  )
  return op.result


lowering_rules[lax.reduce_max_p] = _reduce_max_lowering_rule
def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
  """Lowers ``lax.reduce_sum`` to ``vector.multi_reduction``.

  Only floating-point inputs are supported; the accumulator starts at 0.

  Raises:
    NotImplementedError: For non-floating dtypes.
  """
  (x_aval,) = ctx.avals_in
  (out_aval,) = ctx.avals_out
  if not jnp.issubdtype(x_aval.dtype, jnp.floating):
    # Integer and other reductions are not implemented yet.
    raise NotImplementedError(x_aval.dtype)
  out_type = aval_to_ir_type(out_aval)
  kind = ir.Attribute.parse("#vector.kind<add>")
  # The splat element's type must equal the vector's element type exactly
  # (e.g. bf16), so it is derived from the output dtype rather than fixed f32.
  # NOTE(review): a full reduction yields a scalar out_type, which get_splat
  # does not accept -- confirm whether that case needs handling.
  val = ir.FloatAttr.get(mlir.dtype_to_ir_type(out_aval.dtype), 0.0)
  identity = ir.DenseElementsAttr.get_splat(out_type, val)
  acc = arith.ConstantOp(out_type, identity)
  op = vector.MultiDimReductionOp(
      kind,
      x,
      acc,
      ir.ArrayAttr.get(
          [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes]
      ),
  )
  return op.result


lowering_rules[lax.reduce_sum_p] = _reduce_sum_lowering_rule
def _broadcast_in_dim_lowering_rule(
    ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions
):
  """Lowers ``lax.broadcast_in_dim`` via ``vector.shape_cast`` + ``broadcast``.

  A non-scalar operand is first shape-cast to the output rank, placing its
  dimensions at ``broadcast_dimensions`` and 1 elsewhere. It is then broadcast
  to the output shape unless it already has it.
  """
  (aval_in,) = ctx.avals_in
  (aval_out,) = ctx.avals_out
  if broadcast_dimensions:
    padded = [1] * len(shape)
    for out_dim, size in zip(broadcast_dimensions, aval_in.shape):
      padded[out_dim] = size
    padded_shape = tuple(padded)
    padded_type = ir.VectorType.get(
        padded_shape, mlir.dtype_to_ir_type(aval_out.dtype)
    )
    val = vector.ShapeCastOp(padded_type, val).result
    if padded_shape == aval_out.shape:
      return val
  full_type = ir.VectorType.get(
      aval_out.shape, mlir.dtype_to_ir_type(aval_out.dtype)
  )
  return vector.BroadcastOp(full_type, val).result


lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule
def _dot_general_lowering_rule(
    ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_
):
  """Lowers a 2-D ``lax.dot_general`` with f32 or f16 output.

  * ``(m, k) . (1, k)^T`` contracting dim 1 of both operands is lowered as a
    multiply followed by an add-reduction over dim 1.
  * Everything else becomes ``tpu.matmul`` with ``transpose_lhs`` /
    ``transpose_rhs`` derived from the contracting dimensions.

  Raises:
    NotImplementedError: For other output dtypes, non-2-D operands,
      unsupported contracting dims, per-operand precision, or precisions other
      than DEFAULT/HIGHEST.
  """
  (lhs_dims, rhs_dims), _ = dimension_numbers
  (aval_out,) = ctx.avals_out
  out_type = aval_to_ir_type(aval_out)
  if aval_out.dtype == jnp.float32:
    val = ir.FloatAttr.get(ir.F32Type.get(), 0.0)
  elif aval_out.dtype == jnp.float16:
    val = ir.FloatAttr.get(ir.F16Type.get(), 0.0)
  else:
    raise NotImplementedError(aval_out.dtype)
  if any(len(a.shape) != 2 for a in ctx.avals_in):
    raise NotImplementedError(ctx.avals_in)
  lhs_aval, rhs_aval = ctx.avals_in
  # This is really a matrix-vector product. It only looks like matrix-matrix.
  if lhs_dims == (1,) and rhs_dims == (1,) and rhs_aval.shape[0] == 1:
    if lhs_aval.shape != rhs_aval.shape:
      bcast_shape = jnp.broadcast_shapes(lhs_aval.shape, aval_out.shape)
      bcast_type = ir.VectorType.get(
          list(bcast_shape), mlir.dtype_to_ir_type(aval_out.dtype)
      )
      # Compare shape tuples with shape tuples; comparing against the
      # VectorType would always be unequal.
      if lhs_aval.shape != bcast_shape:
        x = vector.BroadcastOp(bcast_type, x).result
      if rhs_aval.shape != bcast_shape:
        y = vector.BroadcastOp(bcast_type, y).result
    red_type = aval_to_ir_type(lhs_aval.update(shape=(lhs_aval.shape[0],)))
    acc = arith.ConstantOp(
        red_type, ir.DenseElementsAttr.get_splat(red_type, val)
    )
    red = vector.MultiDimReductionOp(
        ir.Attribute.parse("#vector.kind<add>"),
        arith.MulFOp(x, y).result,
        acc,
        ir.ArrayAttr.get(
            [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 1)]
        ),
    )
    return vector.ShapeCastOp(out_type, red.result).result

  if lhs_dims == (1,):
    transpose_lhs = False
  elif lhs_dims == (0,):
    transpose_lhs = True
  else:
    raise NotImplementedError
  if rhs_dims == (0,):
    transpose_rhs = False
  elif rhs_dims == (1,):
    transpose_rhs = True
  else:
    raise NotImplementedError
  if precision is not None:
    if precision[0] != precision[1]:
      raise NotImplementedError("Per-operand dot precision unsupported")
    precision = precision[0]
  if precision is None or precision == lax.Precision.DEFAULT:
    precision_attr = None  # That's the default in Mosaic.
  elif precision == lax.Precision.HIGHEST:
    precision_attr = ir.Attribute.parse(
        "#tpu.contract_precision<fp32>"
    )
  else:
    raise NotImplementedError(f"Unsupported dot precision: {precision}")
  out_tile = arith.ConstantOp(
      out_type, ir.DenseElementsAttr.get_splat(out_type, val)
  )
  op = tpu.MatmulOp(
      out_type, x, y, out_tile,
      transpose_lhs=transpose_lhs, transpose_rhs=transpose_rhs,
      precision=precision_attr
  )
  return op.result


lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule
_INT_DTYPES = {
8: np.dtype(np.int8),
16: np.dtype(np.int16),
32: np.dtype(np.int32),
}
def _convert_element_type_lowering_rule(
    ctx: LoweringRuleContext, x, *, new_dtype, weak_type
):
  """Lowers ``lax.convert_element_type`` to ``arith`` cast ops.

  Supported: float -> float, bool -> integer, signed int -> float,
  signed int -> signed int, and float -> signed int.

  Raises:
    NotImplementedError: For any other pair of dtypes.
  """
  del weak_type  # Does not influence the emitted IR.
  out_aval = ctx.avals_out[0]
  old_dtype = ctx.avals_in[0].dtype
  out_type = aval_to_ir_type(out_aval)
  if old_dtype == new_dtype:
    return x

  src_float = jnp.issubdtype(old_dtype, jnp.floating)
  dst_float = jnp.issubdtype(new_dtype, jnp.floating)
  src_sint = jnp.issubdtype(old_dtype, jnp.signedinteger)
  dst_sint = jnp.issubdtype(new_dtype, jnp.signedinteger)

  if src_float and dst_float:
    fcast = (arith.ExtFOp if old_dtype.itemsize < new_dtype.itemsize
             else arith.TruncFOp)
    return fcast(out_type, x).result
  if old_dtype == jnp.bool_ and jnp.issubdtype(new_dtype, jnp.integer):
    return arith.ExtUIOp(out_type, x).result
  if src_sint and dst_float:
    # TODO(sharadmv,apaszke): remove this when Mosaic handles SIToFP with
    # differing element bitwidths
    if old_dtype.itemsize != new_dtype.itemsize:
      int_dtype = _INT_DTYPES[new_dtype.itemsize * 8]
      int_type = aval_to_ir_type(out_aval.update(dtype=int_dtype))
      resize = (arith.ExtSIOp if old_dtype.itemsize < new_dtype.itemsize
                else arith.TruncIOp)
      x = resize(int_type, x).result
    return arith.SIToFPOp(out_type, x).result
  if src_sint and dst_sint:
    resize = (arith.ExtSIOp if old_dtype.itemsize < new_dtype.itemsize
              else arith.TruncIOp)
    return resize(out_type, x).result
  if src_float and dst_sint:
    return arith.FPToSIOp(out_type, x).result
  raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}")


lowering_rules[lax.convert_element_type_p] = _convert_element_type_lowering_rule
def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions):
  """Lowers ``lax.reshape``.

  A scalar input is broadcast to the output vector type; other inputs are
  shape-cast. Reshapes with ``dimensions`` or unknown (``None``) sizes are not
  supported.
  """
  if dimensions is not None or any(d is None for d in new_sizes):
    raise NotImplementedError
  out_type = aval_to_ir_type(ctx.avals_out[0])
  op_cls = vector.ShapeCastOp if ctx.avals_in[0].shape else vector.BroadcastOp
  return op_cls(out_type, x).result


lowering_rules[lax.reshape_p] = _reshape_lowering_rule
def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions):
  """Lowers ``lax.squeeze``.

  A scalar result is extracted at position ``[0, ..., 0]``; otherwise the
  operand is shape-cast to the output type.
  """
  del dimensions  # Implied by the input/output avals.
  (aval_in,) = ctx.avals_in
  (aval_out,) = ctx.avals_out
  if aval_out.shape:
    return vector.ShapeCastOp(aval_to_ir_type(aval_out), x).result
  origin = [0] * len(aval_in.shape)
  return vector.ExtractOp(x, [], origin).result


lowering_rules[lax.squeeze_p] = _squeeze_lowering_rule
def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension):
  """Lowers ``lax.concatenate`` to ``tpu.concatenate`` along ``dimension``."""
  result_type = aval_to_ir_type(ctx.avals_out[0])
  concat = tpu.ConcatenateOp(result_type, xs, dimension=dimension)
  return concat.result


lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule
def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension):
  """Lowers ``lax.iota`` to ``tpu.iota``; dtype and shape come from the output aval."""
  del dtype, shape  # Already encoded in ctx.avals_out[0].
  iota = tpu.IotaOp(aval_to_ir_type(ctx.avals_out[0]), dimension=dimension)
  return iota.result


lowering_rules[lax.iota_p] = _iota_lowering_rule
def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation):
  """Lowers a 2-D ``lax.transpose`` (permutation ``(1, 0)`` only)."""
  if permutation != (1, 0):
    raise NotImplementedError
  result_type = aval_to_ir_type(ctx.avals_out[0])
  i64 = ir.IntegerType.get_signless(64)
  perm_attr = ir.ArrayAttr.get([ir.IntegerAttr.get(i64, p) for p in permutation])
  return vector.TransposeOp(result_type, x, perm_attr).result


lowering_rules[lax.transpose_p] = _transpose_lowering_rule
def _bcast(x, y, x_aval, y_aval, out_aval):
  """Prepares the two operands of a binary op.

  A NumPy/Python constant operand is materialized with ``ir_constant``. It
  takes the ``index`` type if the other operand is ``index``-typed, and the
  MLIR type of its own aval's dtype otherwise. Each operand whose aval shape
  differs from ``out_aval.shape`` is then ``vector.broadcast`` to that shape.

  Returns:
    The ``(x, y)`` pair (broadcast operands are returned as op views).
  """
  constant_types = (np.ndarray, np.number, int, float)

  def _materialize(const, other, const_aval):
    if hasattr(other, "type") and other.type == ir.IndexType.get():
      return ir_constant(const, other.type)
    return ir_constant(const, mlir.dtype_to_ir_type(const_aval.dtype))

  # Order matters: y's check sees x after x has been materialized.
  if isinstance(x, constant_types):
    x = _materialize(x, y, x_aval)
  if isinstance(y, constant_types):
    y = _materialize(y, x, y_aval)

  target_shape = list(out_aval.shape)
  if x_aval.shape != out_aval.shape:
    x = vector.BroadcastOp(
        ir.VectorType.get(target_shape, mlir.dtype_to_ir_type(x_aval.dtype)), x)
  if y_aval.shape != out_aval.shape:
    y = vector.BroadcastOp(
        ir.VectorType.get(target_shape, mlir.dtype_to_ir_type(y_aval.dtype)), y)
  return x, y
def _add_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers ``lax.add`` to ``arith.addi`` / ``arith.addf`` by output dtype."""
  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  for dtype_kind, add_op in ((jnp.integer, arith.AddIOp),
                             (jnp.floating, arith.AddFOp)):
    if jnp.issubdtype(aval_out.dtype, dtype_kind):
      return add_op(x, y).result
  raise NotImplementedError(aval_out.dtype)


lowering_rules[lax.add_p] = _add_lowering_rule
skip_mlir_conversions.add(lax.add_p)