-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
lowering.py
1565 lines (1366 loc) · 49.3 KB
/
lowering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for lowering JAX primitives to Triton IR."""
from __future__ import annotations
import dataclasses
import functools
import operator
from typing import Any, Dict, Sequence, Tuple
import zlib
import jax
from jax import lax
from jax import tree_util
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import state
from jax._src import util
from jax._src.lax.control_flow import for_loop
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib import hlo_helpers
from jax._src.lib import version
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas import indexing
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.state import AbstractRef
from jax._src.state import discharge
from jax._src.state import primitives as sp
from jax._src.util import merge_lists
from jax._src.util import partition_list
from jax._src.util import split_list
from jax._src.util import weakref_lru_cache
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
import jax.numpy as jnp
from jax_triton import triton_lib
from jax_triton.triton_lib import compile_ttir_to_ptx_inplace
from jax_triton.triton_lib import get_triton_type
import numpy as np
from triton._C.libtriton.triton import ir as tl_ir
from triton.compiler import code_generator as code_gen
import triton.language as tl
# TODO(sharadmv): enable type checking
# mypy: ignore-errors
# From here on `map` and `zip` refer to `util.safe_map` / `util.safe_zip`; the
# Python builtins remain reachable as `unsafe_map` / `unsafe_zip`.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
# Short module-local aliases for frequently used names.
partial = functools.partial
# A launch grid: one integer extent per grid axis.
Grid = Tuple[int, ...]
NDIndexer = indexing.NDIndexer
GridMapping = pallas_core.GridMapping
BlockMapping = pallas_core.BlockMapping
# # General lowering logic
@dataclasses.dataclass
class TritonModuleContext:
  """Module-wide state shared by all lowering rules of one kernel."""
  # Kernel name; also the name given to the generated Triton function.
  name: str
  # Triton IR context the module is built in.
  ir_context: tl_ir.context
  # IR builder used to emit every operation.
  builder: tl_ir.builder
  # Triton module that receives the kernel function.
  module: tl_ir.module
  # Grid/block description of the `pallas_call` being lowered.
  grid_mapping: GridMapping
  # Program-id tensors indexed by `axis` in the `program_id` lowering rule.
  # `lower_jaxpr_to_triton_module` fills this with the ids of the grid axes
  # that are not listed in `grid_mapping.mapped_dims`.
  program_ids: Sequence[tl.tensor]
@dataclasses.dataclass
class BlockInfo:
  """Describes which block of a full array a kernel argument refers to."""
  # Shape and dtype of the whole (un-blocked) array.
  full_shape_dtype: jax.ShapeDtypeStruct
  # Per-dimension start offsets of the block, as produced by
  # `_eval_index_map` (block index scaled by the block size, except for
  # `pallas_core.mapped` dimensions where the index is used unscaled).
  start_indices: Sequence[Any]
  # Block extents; an entry may be the `pallas_core.mapped` sentinel.
  block_shape: Tuple[int, ...]
@dataclasses.dataclass
class TritonLoweringRuleContext:
  """Per-equation context handed to every Triton lowering rule.

  Attributes:
    context: The module-wide lowering context.
    avals_in: Abstract values of the equation's inputs.
    avals_out: Abstract values of the equation's outputs.
    block_infos: One entry per input; `None` when no block info is known.
    builder: Set after construction to `context.builder` for convenience.
  """
  context: TritonModuleContext
  avals_in: Any
  avals_out: Any
  block_infos: Sequence[BlockInfo | None]

  def __post_init__(self):
    # Expose the IR builder directly so rules can write `ctx.builder`.
    self.builder = self.context.builder

  def replace(self, **changes):
    """Returns a copy of this context with the given fields replaced."""
    return dataclasses.replace(self, **changes)
@dataclasses.dataclass
class TritonLoweringResult:
  """Keeps pybind11 objects alive."""
  # Triton IR context the module was built in.
  ir_context: tl_ir.context
  # The lowered Triton module.
  module: tl_ir.module
  # Builder that emitted the module.
  builder: tl_ir.builder
  # Launch grid as returned by `_process_grid_to_3d_grid`.
  grid: Tuple[int, ...]
@dataclasses.dataclass
class TritonCompilationResult:
  """Bundle describing one compiled Triton kernel."""
  # NOTE(review): the visible part of this file never constructs this class;
  # the field notes below are read off the names and types only — confirm.
  # Name of the compiled kernel.
  kernel_name: str
  # Presumably the Triton IR of the kernel as text.
  ttir: str
  # Presumably the PTX generated from `ttir`.
  ptx: str
  # Presumably the shared memory the kernel needs, in bytes.
  shared_mem_bytes: int
  # Presumably the compute capability the kernel was compiled for.
  compute_capability: int
  # Lowering output this compilation result was produced from.
  lowering_result: TritonLoweringResult
class TritonLoweringException(Exception):
  """Raised when lowering a jaxpr equation to Triton IR fails.

  `lower_jaxpr_to_triton_ir` wraps any other exception raised by a lowering
  rule in this type and attaches the failing equation and context to the
  message.
  """
def _eval_index_map(
ctx: TritonModuleContext, idx, block_mapping: BlockMapping | None
):
if block_mapping is None:
return None
block_indices = tuple(
lower_jaxpr_to_triton_ir(
ctx, block_mapping.index_map_jaxpr.jaxpr, None, *idx
)
)
return tuple(
i if b is pallas_core.mapped else i.__mul__(b, _builder=ctx.builder)
for i, b in zip(block_indices, block_mapping.block_shape)
)
def _convert_dtype(dtype: jnp.dtype) -> tl.dtype:
if dtype == jnp.float32:
return tl.float32
elif dtype == jnp.float64:
return tl.float64
elif dtype == jnp.float16:
return tl.float16
elif dtype == jnp.bfloat16:
return tl.bfloat16
elif dtype == jnp.int32:
return tl.int32
elif dtype == jnp.int64:
return tl.int64
raise ValueError(f"Unhandled dtype: {dtype}")
# Registry of lowering rules, keyed by JAX primitive.
# `lower_jaxpr_to_triton_ir` looks up each equation's primitive here and calls
# the rule as `rule(rule_ctx, *invals, **eqn.params)`.
triton_lowering_rules = {}
def _process_grid_to_3d_grid(builder, grid_mapping: GridMapping):
  """Fits an N-dimensional grid into at most three launch axes.

  Grids with up to three axes are returned unchanged together with one
  `tl.program_id` per axis. For larger grids all but the last two axes are
  collapsed into launch axis 0, and the per-axis indices are recovered from
  program id 0 by repeated floor-division and modulo.

  Returns:
    `(grid, program_ids)` where `program_ids` has one entry per axis of
    `grid_mapping.grid`.
  """
  grid = grid_mapping.grid
  num_axes = len(grid)
  if num_axes <= 3:
    program_ids = [
        tl.program_id(axis=axis, _builder=builder) for axis in range(num_axes)
    ]
    return grid, program_ids
  leading, trailing = grid[:-2], grid[-2:]
  collapsed_grid = (np.prod(leading), *trailing)
  leading_ids = [0] * len(leading)
  remainder = tl.program_id(0, _builder=builder)
  # Peel off the collapsed axes from the innermost (last) one outwards.
  for pos in range(len(leading) - 1, -1, -1):
    size = leading[pos]
    quotient = remainder.__floordiv__(size, _builder=builder)
    leading_ids[pos] = remainder.__mod__(size, _builder=builder)
    remainder = quotient
  all_ids = leading_ids + [
      tl.program_id(1, _builder=builder),
      tl.program_id(2, _builder=builder),
  ]
  assert len(all_ids) == len(grid_mapping.grid)
  return collapsed_grid, all_ids
def lower_jaxpr_to_triton_module(
    jaxpr: jax_core.Jaxpr, in_shapes, grid_mapping: GridMapping, name: str
) -> TritonLoweringResult:
  """Lowers a kernel jaxpr to a Triton module containing one function.

  Args:
    jaxpr: Kernel body. It must have no outputs (asserted below).
    in_shapes: One object with `.shape` and `.dtype` per block mapping; used
      to record the full array shape of each blocked argument.
    grid_mapping: Grid and per-argument block mappings of the kernel.
    name: Name of the generated Triton function.

  Returns:
    A `TritonLoweringResult` holding the IR context, module and builder
    together with the launch grid returned by `_process_grid_to_3d_grid`.

  Raises:
    NotImplementedError: If `grid_mapping.num_index_operands` is non-zero.
  """
  jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True)
  ir_context = tl_ir.context()
  ir_context.load_triton()
  builder = tl_ir.builder(ir_context)
  # TODO(sharadmv): handle multiple devices, right now we assume device 0
  # which is fine when we have multiple of the same GPU but this won't work in
  # general.
  device = 0
  builder.arch = triton_kernel_call_lib.get_compute_capability(device)
  module = builder.create_module()
  in_avals = [var.aval for var in jaxpr.invars]
  triton_types = [get_triton_type(x) for x in in_avals]
  arg_types = [code_gen.str_to_ty(arg) for arg in triton_types]
  assert len(jaxpr.outvars) == 0
  # The kernel function returns nothing and takes one argument per jaxpr input.
  prototype = tl.function_type([], arg_types)
  out = prototype.to_ir(builder)
  fn = builder.get_or_insert_function(module, name, out, "public", False)
  module.push_back(fn)
  entry = fn.add_entry_block()
  args = []
  for i in range(len(in_avals)):
    # Every argument is tagged with `tt.divisibility = 16`.
    fn.set_arg_attr(i, "tt.divisibility", 16)
    ptr = tl.tensor(fn.args(i), prototype.param_types[i])
    args.append(ptr)
  builder.set_insertion_point_to_start(entry)
  new_grid, program_ids = _process_grid_to_3d_grid(builder, grid_mapping)
  # Only the ids of non-mapped grid axes are exposed to `program_id`.
  local_program_ids = [
      pid for i, pid in enumerate(program_ids) if i not in grid_mapping.mapped_dims
  ]
  ctx = TritonModuleContext(
      name, ir_context, builder, module, grid_mapping, local_program_ids
  )
  if grid_mapping.num_index_operands:
    raise NotImplementedError(
        "Scalar prefetch not supported in Triton lowering.")
  # Index maps are evaluated with the ids of *all* grid axes.
  start_indices = map(
      partial(_eval_index_map, ctx, program_ids), grid_mapping.block_mappings
  )
  block_infos = [
      BlockInfo(
          jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
          start_idx,
          block_mapping.block_shape,
      )
      if block_mapping is not None
      else None
      for shape_dtype, block_mapping, start_idx in zip(
          in_shapes, grid_mapping.block_mappings, start_indices
      )
  ]
  # The kernel has no outputs, so lowering must return an empty sequence.
  () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args)
  module.context = ir_context
  ctx.builder.ret([])
  return TritonLoweringResult(ir_context, module, builder, new_grid)
def lower_jaxpr_to_triton_ir(
    ctx: TritonModuleContext,
    jaxpr: jax_core.Jaxpr,
    block_infos: Sequence[BlockInfo | None] | None,
    *args
) -> Sequence[Any]:
  """Emits Triton IR for every equation of `jaxpr`.

  Args:
    ctx: Module context providing the IR builder.
    jaxpr: Jaxpr to lower.
    block_infos: Optional block info per jaxpr input; `None` means no input
      has block info.
    *args: Lowered values bound to `jaxpr.invars`.

  Returns:
    The lowered values of `jaxpr.outvars`.

  Raises:
    NotImplementedError: If a primitive has no registered lowering rule.
    TritonLoweringException: If a lowering rule raises any other exception.
  """
  values = {}
  var_block_infos = {}

  def lookup(atom: jax_core.Atom):
    if type(atom) is not jax_core.Literal:
      return values[atom]
    literal = tl.core._to_tensor(
        np.array(atom.val).tolist(), builder=ctx.builder
    )
    wanted_ty = code_gen.str_to_ty(get_triton_type(atom.aval)).element_ty
    if literal.type.scalar != wanted_ty:
      # _to_tensor(np.array(var.val).tolist()) can be lossy e.g. np.float64
      # comes out of .tolist() as list[float], which then comes out of
      # _to_tensor as a block of f32.
      literal = tl.semantic.cast(literal, wanted_ty, ctx.builder)
    return literal

  def lookup_block_info(atom: jax_core.Atom):
    # Literals never carry block info.
    return None if type(atom) is jax_core.Literal else var_block_infos.get(atom)

  def bind(var: jax_core.Var, val):
    values[var] = val

  if block_infos is None:
    block_infos = [None] * len(jaxpr.invars)
  var_block_infos.update(zip(jaxpr.invars, block_infos))
  map(bind, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    invals = map(lookup, eqn.invars)
    rule = triton_lowering_rules.get(eqn.primitive)
    if rule is None:
      raise NotImplementedError(
          "Unimplemented primitive in Pallas GPU lowering: "
          f"{eqn.primitive.name}. "
          "Please file an issue on https://github.com/google/jax/issues.")
    rule_ctx = TritonLoweringRuleContext(
        ctx,
        [v.aval for v in eqn.invars],
        [v.aval for v in eqn.outvars],
        map(lookup_block_info, eqn.invars),
    )
    try:
      outvals = rule(rule_ctx, *invals, **eqn.params)
    except TritonLoweringException:
      raise  # We only add the extra info to the innermost exception.
    except Exception as e:
      raise TritonLoweringException(
          f"Exception while lowering eqn:\n  {eqn}\n"
          f"With context:\n  {rule_ctx}\n"
          f"With inval shapes={map(lambda t: t.shape, invals)}\n"
          f"With inval types={map(lambda t: t.type, invals)}\n"
          f"In jaxpr:\n{jaxpr}") from e
    if not eqn.primitive.multiple_results:
      bind(eqn.outvars[0], outvals)
    else:
      map(bind, eqn.outvars, outvals)
  return map(lookup, jaxpr.outvars)
# # Primitive lowering rules
# ## Programming model primitives
def _program_id_lowering_rule(ctx: TritonLoweringRuleContext, *, axis):
  """Returns the program-id value stored for grid axis `axis`."""
  program_ids = ctx.context.program_ids
  return program_ids[axis]


triton_lowering_rules[primitives.program_id_p] = _program_id_lowering_rule
# ## Atomic op primitives
# Maps each Pallas atomic op type to the Triton atomic function implementing
# it; `_atomic_lowering_rule` looks its `atomic_type` up here.
_ATOMIC_OP_MAPPING = {
    primitives.AtomicOpType.XCHG: tl.core.atomic_xchg,
    primitives.AtomicOpType.ADD: tl.core.atomic_add,
    primitives.AtomicOpType.MAX: tl.core.atomic_max,
    primitives.AtomicOpType.MIN: tl.core.atomic_min,
    primitives.AtomicOpType.AND: tl.core.atomic_and,
    primitives.AtomicOpType.OR: tl.core.atomic_or,
    primitives.AtomicOpType.XOR: tl.core.atomic_xor,
}
def _atomic_lowering_rule(
    ctx: TritonLoweringRuleContext,
    *args_flat,
    args_tree,
    atomic_type: primitives.AtomicOpType,
):
  """Lowers an atomic read-modify-write on a ref.

  `args_flat`/`args_tree` unflatten to `(ptr, idx, value, mask)`.

  Raises:
    NotImplementedError: If `atomic_type` has no entry in
      `_ATOMIC_OP_MAPPING`.
  """
  ref_ptr, idx, value, mask = args_tree.unflatten(args_flat)
  ref_shape = ctx.avals_in[0].shape
  element_ptrs = _compute_pointers_from_indices(
      ref_ptr, ctx.block_infos[0], idx, ref_shape, ctx.builder
  )
  atomic_fn = _ATOMIC_OP_MAPPING.get(atomic_type)
  if atomic_fn is not None:
    return atomic_fn(element_ptrs, value, mask=mask, _builder=ctx.builder)
  raise NotImplementedError(atomic_type)


triton_lowering_rules[primitives.atomic_rmw_p] = _atomic_lowering_rule
# Primitives whose lowering is a direct call to a single Triton function.
# The loop over `_TRITON_FN_MAPPING.items()` in this module wraps every entry
# in a lowering rule and registers it in `triton_lowering_rules`.
_TRITON_FN_MAPPING = {
    # Unary ops.
    lax.neg_p: tl.semantic.minus,
    lax.abs_p: tl.abs,
    lax.ceil_p: tl.math.ceil,
    lax.floor_p: tl.math.floor,
    lax.exp_p: tl.exp,
    lax.exp2_p: tl.math.exp2,
    lax.expm1_p: tl.math.expm1,
    lax.log_p: tl.log,
    lax.log1p_p: tl.math.log1p,
    lax.sqrt_p: tl.sqrt,
    lax.cbrt_p: tl.math.cbrt,
    lax.rsqrt_p: tl.math.rsqrt,
    lax.sin_p: tl.sin,
    lax.cos_p: tl.cos,
    lax.tan_p: tl.math.tan,
    lax.asin_p: tl.math.asin,
    lax.acos_p: tl.math.acos,
    lax.atan_p: tl.math.atan,
    lax.atan2_p: tl.math.atan2,
    lax.sinh_p: tl.math.sinh,
    lax.cosh_p: tl.math.cosh,
    lax.tanh_p: tl.math.tanh,
    lax.asinh_p: tl.math.asinh,
    lax.acosh_p: tl.math.acosh,
    lax.atanh_p: tl.math.atanh,
    lax.not_p: tl.semantic.invert,
    lax.population_count_p: tl.math.popc,
    lax.clz_p: tl.math.clz,
    # Binary ops.
    lax.add_p: tl.semantic.add,
    lax.sub_p: tl.semantic.sub,
    lax.mul_p: tl.semantic.mul,
    lax.pow_p: tl.math.pow,
    lax.rem_p: tl.semantic.mod,
    lax.and_p: tl.semantic.and_,
    lax.or_p: tl.semantic.or_,
    lax.xor_p: tl.semantic.xor_,
    lax.eq_p: tl.semantic.equal,
    lax.ne_p: tl.semantic.not_equal,
    lax.gt_p: tl.semantic.greater_than,
    lax.ge_p: tl.semantic.greater_equal,
    lax.lt_p: tl.semantic.less_than,
    lax.le_p: tl.semantic.less_equal,
    lax.max_p: tl.math.max,
    lax.min_p: tl.math.min,
    lax.shift_left_p: tl.semantic.shl,
    lax.shift_right_arithmetic_p: tl.semantic.ashr,
    lax.shift_right_logical_p: tl.semantic.lshr,
    lax.nextafter_p: tl.math.nextafter,
    ad_util.add_any_p: tl.semantic.add,
    # Other ops.
    indexing.broadcast_to_p: tl.broadcast_to,
    primitives.atomic_cas_p: tl.atomic_cas,
    primitives.max_contiguous_p: tl.max_contiguous,
    primitives.multiple_of_p: tl.multiple_of,
}
# Wrap every entry of `_TRITON_FN_MAPPING` in a lowering rule. `fn` is bound
# as a default argument so each rule keeps its own function.
for primitive, fn in _TRITON_FN_MAPPING.items():
  if not tl.core.is_builtin(fn):
    # Non-builtin functions are called with the builder as the trailing
    # positional argument and accept no primitive parameters.
    def rule(ctx, *args, fn=fn):
      return fn(*args, ctx.builder)
  else:
    # Builtins get the primitive parameters wrapped in `tl.constexpr` and the
    # builder passed as `_builder`.
    def rule(ctx, *args, fn=fn, **kwargs):
      constexpr_kwargs = tree_util.tree_map(tl.constexpr, kwargs)
      return fn(*args, **constexpr_kwargs, _builder=ctx.builder)
  triton_lowering_rules[primitive] = rule
def _clamp_lowering_rule(ctx: TritonLoweringRuleContext, min, operand, max):
  """Lowers `clamp` as `min(max(operand, min), max)`."""
  builder = ctx.builder
  lower_bounded = tl.math.max(operand, min, _builder=builder)
  return tl.math.min(lower_bounded, max, _builder=builder)


triton_lowering_rules[lax.clamp_p] = _clamp_lowering_rule
def _logistic_lowering_rule(ctx: TritonLoweringRuleContext, a):
  """Lowers the logistic function as `1 / (1 + exp(-a))`."""
  builder = ctx.builder
  one = tl.core._to_tensor(1.0, builder)
  negated = a.__neg__(_builder=builder)
  denominator = tl.exp(negated, _builder=builder).__add__(one, _builder=builder)
  return one.__truediv__(denominator, _builder=builder)


triton_lowering_rules[lax.logistic_p] = _logistic_lowering_rule
def _div_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
  """Lowers division: true division if either operand is floating point,
  floor division otherwise."""
  any_floating = a.dtype.is_floating() or b.dtype.is_floating()
  divide = a.__truediv__ if any_floating else a.__floordiv__
  return divide(b, _builder=ctx.builder)


triton_lowering_rules[lax.div_p] = _div_lowering_rule
def _iota_lowering_rule(
    ctx: TritonLoweringRuleContext, *, dtype, shape, dimension
):
  """Lowers `iota` along dimension 0 to `tl.arange(0, shape[0])`.

  NOTE(review): `dtype` and any entries of `shape` past the first are not
  used here — confirm callers only request 1D iotas of the arange's dtype.

  Raises:
    NotImplementedError: If `dimension` is not 0.
  """
  if dimension == 0:
    return tl.arange(0, shape[0], _builder=ctx.builder)
  raise NotImplementedError()


triton_lowering_rules[lax.iota_p] = _iota_lowering_rule
def _integer_pow_lowering_rule(ctx: TritonLoweringRuleContext, a, *, y):
  """Lowers `a ** y` for a static integer exponent `y`.

  Exponents 2, 3 and -2 are expanded into multiplications (and one division);
  every other exponent falls back to `tl.math.pow`.
  """
  builder = ctx.builder
  if y == 2:
    return a.__mul__(a, _builder=builder)
  if y == 3:
    a_squared = a.__mul__(a, _builder=builder)
    return a.__mul__(a_squared, _builder=builder)
  if y == -2:
    one = tl.core._to_tensor(1.0, builder)
    a_squared = a.__mul__(a, _builder=builder)
    return one.__truediv__(a_squared, _builder=builder)
  return tl.math.pow(a, y, _builder=builder)


triton_lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
def _convert_element_type_lowering_rule(
    ctx: TritonLoweringRuleContext, a, *, new_dtype, weak_type
):
  """Casts `a` to `new_dtype`; a no-op when the input already has that dtype.

  `weak_type` is accepted but not used.
  """
  current_dtype = ctx.avals_in[0].dtype
  if new_dtype != current_dtype:
    return tl.semantic.cast(a, _convert_dtype(new_dtype), ctx.builder)
  return a


triton_lowering_rules[lax.convert_element_type_p] = (
    _convert_element_type_lowering_rule
)
def select_n_lowering_rule(ctx: TritonLoweringRuleContext, pred, a, b):
  """Lowers a two-case `select_n`: yields `b` where `pred` holds, else `a`."""
  when_false, when_true = a, b
  return tl.semantic.where(pred, when_true, when_false, ctx.builder)


triton_lowering_rules[lax.select_n_p] = select_n_lowering_rule
def _broadcast_in_dim_lowering_rule(
    ctx: TritonLoweringRuleContext, a, *, broadcast_dimensions, shape
):
  """Broadcasts `a` to `shape`.

  For block operands a unit dimension is first inserted at every output
  position that is not listed in `broadcast_dimensions`; scalars are
  broadcast directly.
  """
  shape = map(tl.constexpr, shape)
  if a.type.is_block():
    for dim in range(len(shape)):
      if dim not in broadcast_dimensions:
        a = tl.semantic.expand_dims(a, dim, ctx.builder)
  return tl.broadcast_to(a, shape, _builder=ctx.builder)


triton_lowering_rules[jax.lax.broadcast_in_dim_p] = (
    _broadcast_in_dim_lowering_rule
)
def _squeeze_lowering_rule(ctx: TritonLoweringRuleContext, a, *, dimensions):
  """Lowers `squeeze` by delegating to the reshape rule.

  The reshape rule reads the target shape from `ctx.avals_out`, so the
  squeezed `dimensions` are not needed here.
  """
  del dimensions  # Unused.
  return _reshape_lowering_rule(ctx, a, dimensions=None, new_sizes=None)


triton_lowering_rules[lax.squeeze_p] = _squeeze_lowering_rule
def _reshape_lowering_rule(
ctx: TritonLoweringRuleContext, a, *, new_sizes, dimensions
):
del new_sizes # Unused.
if dimensions is not None:
return ValueError("`dimensions` is not supported.")
dst_shape = map(tl.constexpr, ctx.avals_out[0].shape)
if not a.type.is_block():
assert all(dim_size.value == 1 for dim_size in dst_shape)
return tl.broadcast_to(a, dst_shape, _builder=ctx.builder)
# Expand-dims or reduce-sum to handle singleton dims as `tl.reshape` is not
# currently implemented.
i = 0
while a.shape != dst_shape:
dim_size = a.shape[i].value if i < len(a.shape) else None
dst_dim_size = dst_shape[i].value if i < len(dst_shape) else None
if dim_size == dst_dim_size:
i += 1
elif dst_dim_size == 1:
a = tl.expand_dims(a, axis=i, _builder=ctx.builder)
i += 1
elif dim_size == 1:
in_shape = tuple(d.value for d in a.shape)
out_shape = tuple(d.value for di, d in enumerate(a.shape) if di != i)
reduce_ctx = ctx.replace(
avals_in=[ctx.avals_in[0].update(shape=in_shape)],
avals_out=[ctx.avals_in[0].update(shape=out_shape)],
)
a = _reduce_lowering(jnp.add, reduce_ctx, a, axes=(i,))
else: # We expect this to fail.
return tl.reshape(a, dst_shape, _builder=ctx.builder)
return a
# Register the reshape rule for `lax.reshape_p`.
triton_lowering_rules[jax.lax.reshape_p] = _reshape_lowering_rule
def _compute_pointers_from_indices(
    root_ptr: tl.core.tensor,
    block_info: BlockInfo | None,
    nd_indexer: NDIndexer,
    array_shape: Tuple[int, ...],
    builder: tl_ir.builder,
) -> tl.core.tensor:
  # Builds the pointers addressed by `nd_indexer` relative to `root_ptr`.
  #
  # For every array dimension an offset is derived from the matching indexer
  # entry, shaped so it broadcasts against the other dimensions, shifted by
  # the block start offset (if any) and scaled by the dimension's stride. The
  # per-dimension offsets are then broadcast to the indexer shape and added to
  # `root_ptr`.
  #
  # Args:
  #   root_ptr: Base pointer of the array.
  #   block_info: Block description of the array, or None when the whole
  #     array of shape `array_shape` is addressed.
  #   nd_indexer: Indexer with one entry per non-mapped dimension.
  #   array_shape: Shape used for strides and block sizes when `block_info`
  #     is None.
  #   builder: IR builder used to emit the address arithmetic.
  #
  # Returns:
  #   `root_ptr` plus the per-dimension offsets.
  #
  # Raises:
  #   NotImplementedError: If an entry is a Python `slice` other than
  #     `slice(None)`.
  if block_info is None:
    full_shape = array_shape
    num_mapped_dims = 0
    block_shape = array_shape
  else:
    full_shape = block_info.full_shape_dtype.shape
    num_mapped_dims = sum(
        b is pallas_core.mapped for b in block_info.block_shape
    )
    block_shape = block_info.block_shape
  # Strides are always taken from the full array, not from the block.
  strides = pallas_utils.strides_from_shape(full_shape)
  indexer_shape = nd_indexer.get_indexer_shape()
  int_indexer_shape = nd_indexer.int_indexer_shape
  indices = nd_indexer.indices
  # Part of the indexer shape that follows the integer-indexer dimensions;
  # slice entries are placed into it one after another via `other_shape_idx`.
  other_shape = indexer_shape[len(int_indexer_shape) :]
  bcast_indices = []
  other_shape_idx = 0
  if block_info is None:
    start_index_offsets = [None] * len(indices)
  else:
    start_index_offsets = block_info.start_indices
  # Mapped dimensions have no indexer entry of their own.
  assert len(indices) + num_mapped_dims == len(full_shape)
  assert len(start_index_offsets) == len(full_shape)
  indexer_iter = iter(indices)
  for dim_stride, dim_block_size, start_offset in zip(
      strides, block_shape, start_index_offsets
  ):
    if dim_block_size is pallas_core.mapped:
      # Mapped dimensions use index 0 and do not consume an indexer entry.
      index = tl.core._to_tensor(0, builder)
    else:
      index = next(indexer_iter)
    if isinstance(index, primitives.Slice):
      # Handle slices with static and dynamic indices and static sizes
      if isinstance(index.start, int):
        ptr_dim_offset = tl.arange(
            index.start, index.start + index.size, _builder=builder
        )
      else:
        ptr_dim_offset = index.start.__add__(
            tl.arange(0, index.size, _builder=builder), _builder=builder
        )
      # We need to add broadcastable dimensions for the advanced int indexing
      # and for previous slices
      num_left_expand_dims = len(int_indexer_shape) + other_shape_idx
      num_right_expand_dims = len(other_shape) - other_shape_idx - 1
      other_shape_idx += 1
    elif isinstance(index, slice):
      if index != slice(None):
        raise NotImplementedError("Only `slice(None)` allowed.")
      # A full slice covers the whole block extent of this dimension.
      ptr_dim_offset = tl.arange(0, dim_block_size, _builder=builder)
      num_left_expand_dims = len(int_indexer_shape) + other_shape_idx
      num_right_expand_dims = len(other_shape) - other_shape_idx - 1
      other_shape_idx += 1
    else:
      # indexer is either a *scalar* or an array of size `int_indexer_shape`
      ptr_dim_offset = index
      num_left_expand_dims = 0
      num_right_expand_dims = len(other_shape)
      if not ptr_dim_offset.type.is_block():
        num_left_expand_dims = max(len(indexer_shape) - 1, 0)
      else:
        # Same value as assigned a few lines above.
        num_right_expand_dims = len(other_shape)
    if not ptr_dim_offset.type.is_block() and indexer_shape:
      # Scalars are turned into an all-ones-shaped block of the indexer's rank.
      ptr_dim_offset = tl.broadcast_to(
          ptr_dim_offset,
          [tl.constexpr(1)] * len(indexer_shape),
          _builder=builder,
      )
    else:
      for _ in range(num_left_expand_dims):
        ptr_dim_offset = tl.semantic.expand_dims(ptr_dim_offset, 0, builder)
      for _ in range(num_right_expand_dims):
        ndim = len(ptr_dim_offset.shape)
        ptr_dim_offset = tl.semantic.expand_dims(ptr_dim_offset, ndim, builder)
    if start_offset is not None:
      ptr_dim_offset = ptr_dim_offset.__add__(start_offset, _builder=builder)
    stride_size = tl.core._to_tensor(int(dim_stride), builder)
    bcast_indices.append(ptr_dim_offset.__mul__(stride_size, _builder=builder))
  # Shape of each per-dimension offset; `()` for scalars.
  block_shapes = [
      () if not index.type.is_block() else tuple(index.type.get_block_shapes())
      for index in bcast_indices
  ]
  # Broadcast only the offsets whose shape differs from the indexer shape.
  bcast_indices = [
      tl.core.broadcast_to(
          index, map(tl.constexpr, indexer_shape), _builder=builder
      )
      if indexer_shape != block_shape
      else index
      for index, block_shape in zip(bcast_indices, block_shapes)
  ]
  ptr = root_ptr
  for bcast_idx in bcast_indices:
    ptr = ptr.__add__(bcast_idx, _builder=builder)
  return ptr
def _pack_indices(non_slice_idx, indexed_dims):
non_slice_idx_iter = iter(non_slice_idx)
return tuple(
next(non_slice_idx_iter) if indexed else slice(None)
for indexed in indexed_dims
)
def _get_lowering_rule(
    ctx: TritonLoweringRuleContext, ptr, *non_slice_idx, indexed_dims
):
  """Lowers a ref read (`get`) to an unmasked load.

  A `ptr` that is not of pointer type is returned unchanged (no indices are
  allowed in that case). Otherwise the indices are packed into an `NDIndexer`
  and the load is emitted by `_masked_load_lowering_rule` without mask,
  eviction policy or cache modifier.
  """
  if not isinstance(ptr.type, tl.pointer_type):
    assert not non_slice_idx
    return ptr
  ref_aval, *idx_avals = ctx.avals_in
  packed_avals = _pack_indices(idx_avals, indexed_dims)
  int_indexer_shape = ()
  if non_slice_idx:
    # All non-slice indices must share a single shape.
    index_shapes = {
        aval.shape for aval in packed_avals if not isinstance(aval, slice)
    }
    (int_indexer_shape,) = index_shapes
  packed_idx = _pack_indices(non_slice_idx, indexed_dims)
  full_idx = tuple(
      entry
      if not isinstance(entry, slice)
      else primitives.Slice.from_slice(entry, dim_size)
      for dim_size, entry in zip(ref_aval.shape, packed_idx)
  )
  nd_indexer = NDIndexer(full_idx, ref_aval.shape, int_indexer_shape)
  args_flat, args_tree = tree_util.tree_flatten((ptr, nd_indexer, None, None))
  return _masked_load_lowering_rule(
      ctx,
      *args_flat,
      args_tree=args_tree,
      eviction_policy=None,
      cache_modifier=None,
      is_volatile=False,
  )


triton_lowering_rules[sp.get_p] = _get_lowering_rule
def _masked_load_lowering_rule(
    ctx: TritonLoweringRuleContext,
    *args_flat,
    args_tree,
    eviction_policy,
    cache_modifier,
    is_volatile,
):
  """Lowers a (possibly masked) load from a ref.

  `args_flat`/`args_tree` unflatten to `(ptr, idx, mask, other)`. A `ptr`
  that is not of pointer type is returned unchanged.
  """
  ref_ptr, idx, mask, other = args_tree.unflatten(args_flat)
  if not isinstance(ref_ptr.type, tl.pointer_type):
    assert len(ctx.avals_in) == 1
    return ref_ptr
  builder = ctx.builder
  element_ptrs = _compute_pointers_from_indices(
      ref_ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape, builder
  )
  load_options = dict(
      mask=mask,
      other=other,
      cache_modifier=cache_modifier,
      volatile=is_volatile,
      eviction_policy=eviction_policy,
  )
  loaded = tl.load(element_ptrs, **load_options, _builder=builder)
  # `tl.load` of a `*int1` returns a tensor with type `int8`, so fix the type.
  return loaded.to(element_ptrs.dtype.element_ty, _builder=builder)


triton_lowering_rules[primitives.load_p] = _masked_load_lowering_rule
def _swap_lowering_rule(
    ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims
):
  """Lowers a ref `swap` to an unmasked masked-swap.

  The indices are packed into an `NDIndexer` and the work is done by
  `_masked_swap_lowering_rule` with no mask and no eviction policy.
  """
  ref_aval, _, *idx_avals = ctx.avals_in
  packed_avals = _pack_indices(idx_avals, indexed_dims)
  int_indexer_shape = ()
  if non_slice_idx:
    # All non-slice indices must share a single shape.
    index_shapes = {
        aval.shape for aval in packed_avals if not isinstance(aval, slice)
    }
    (int_indexer_shape,) = index_shapes
  packed_idx = _pack_indices(non_slice_idx, indexed_dims)
  full_idx = tuple(
      entry
      if not isinstance(entry, slice)
      else primitives.Slice.from_slice(entry, dim_size)
      for dim_size, entry in zip(ref_aval.shape, packed_idx)
  )
  nd_indexer = NDIndexer(full_idx, ref_aval.shape, int_indexer_shape)
  args_flat, args_tree = tree_util.tree_flatten((ptr, nd_indexer, value, None))
  return _masked_swap_lowering_rule(
      ctx, *args_flat, args_tree=args_tree, eviction_policy=None
  )


triton_lowering_rules[sp.swap_p] = _swap_lowering_rule
def _masked_swap_lowering_rule(
    ctx: TritonLoweringRuleContext, *args_flat, args_tree, eviction_policy
):
  """Lowers a (possibly masked) swap: load the old value, store the new one.

  `args_flat`/`args_tree` unflatten to `(ptr, idx, value, mask)`. The element
  type of `ptr` must equal that of `value` (asserted).

  Returns:
    The value loaded before the store. When a mask is given, `value` is used
    as the `other` operand of the load.
  """
  ref_ptr, idx, value, mask = args_tree.unflatten(args_flat)
  if ref_ptr.type.is_block():
    ptr_type = ref_ptr.type.element_ty.element_ty
  else:
    ptr_type = ref_ptr.type.element_ty
  if value.type.is_block():
    value_type = value.type.element_ty
  else:
    value_type = value.type
  assert ptr_type == value_type, (ptr_type, value_type)
  builder = ctx.builder
  element_ptrs = _compute_pointers_from_indices(
      ref_ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape, builder
  )
  fill = value if mask is not None else None
  previous = tl.load(element_ptrs, mask=mask, other=fill, _builder=builder)
  tl.store(
      element_ptrs,
      value,
      mask=mask,
      eviction_policy=eviction_policy,
      _builder=builder,
  )
  return previous


triton_lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule
def _addupdate_lowering_rule(
    ctx: TritonLoweringRuleContext, ptr, value, *non_slice_idx, indexed_dims
):
  """Lowers a ref `addupdate` to `tl.atomic_add` on the indexed elements.

  Returns:
    An empty list (the primitive produces no values).
  """
  ref_block_info, *_ = ctx.block_infos
  ref_shape = ctx.avals_in[0].shape
  packed_idx = _pack_indices(non_slice_idx, indexed_dims)
  int_indexer_shape = ()
  if non_slice_idx:
    # Shapes are read from the lowered index values; all must agree.
    index_shapes = {
        tuple(dim.value for dim in index.shape) for index in non_slice_idx
    }
    (int_indexer_shape,) = index_shapes
  full_idx = tuple(
      entry
      if not isinstance(entry, slice)
      else primitives.Slice.from_slice(entry, dim_size)
      for dim_size, entry in zip(ref_shape, packed_idx)
  )
  nd_indexer = primitives.NDIndexer(full_idx, ref_shape, int_indexer_shape)
  element_ptrs = _compute_pointers_from_indices(
      ptr, ref_block_info, nd_indexer, ref_shape, ctx.builder
  )
  tl.atomic_add(element_ptrs, value, _builder=ctx.builder)
  return []


triton_lowering_rules[sp.addupdate_p] = _addupdate_lowering_rule
def _transpose_lowering(ctx: TritonLoweringRuleContext, a, *, permutation):
  """Lowers a transpose with permutation `(1, 0)` to `tl.trans`.

  Raises:
    NotImplementedError: For any other permutation.
  """
  if permutation == (1, 0):
    return tl.trans(a, _builder=ctx.builder)
  raise NotImplementedError(permutation)


triton_lowering_rules[lax.transpose_p] = _transpose_lowering
# Precisions for which `tl.dot` is allowed to use TF32.
_TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT)


def _dot_general_lowering(
    ctx: TritonLoweringRuleContext,
    a,
    b,
    *,
    dimension_numbers,
    precision,
    preferred_element_type
):
  """Lowers a `dot_general` without batch dimensions to `tl.dot`.

  Operands are transposed first when `a` contracts over dimension 0 or `b`
  over dimension 1. Accumulation uses the output dtype if that is int32 or
  float16 and float32 otherwise; the result is cast to the output dtype.
  """
  del preferred_element_type  # Unused.
  contracting_dims, batch_dims = dimension_numbers
  (a_contract_dim,), (b_contract_dim,) = contracting_dims
  assert batch_dims == ((), ())
  builder = ctx.builder
  if a_contract_dim == 0:
    a = tl.trans(a, _builder=builder)
  if b_contract_dim == 1:
    b = tl.trans(b, _builder=builder)
  if precision is not None:
    prec_a, prec_b = precision
    allow_tf32 = prec_a in _TF32_PRECISIONS or prec_b in _TF32_PRECISIONS
  else:
    allow_tf32 = True
  out_dtype = _convert_dtype(ctx.avals_out[0].dtype)
  acc_dtype = out_dtype if out_dtype in (tl.int32, tl.float16) else tl.float32
  product = tl.dot(
      a,
      b,
      allow_tf32=allow_tf32,
      out_dtype=acc_dtype,
      _builder=builder,
  )
  return product.to(out_dtype, _builder=builder)


triton_lowering_rules[lax.dot_general_p] = _dot_general_lowering
def _reduction_lowering(body, ctx: TritonLoweringRuleContext, args, axes):
  # Emits a Triton reduce op over a single axis whose combine region is the
  # lowering of `body`.
  #
  # Args:
  #   body: Combine function; it is traced with two copies of `args`' tree
  #     structure as its two arguments.
  #   ctx: Lowering context; `avals_in` describe the leaves of `args` and
  #     `avals_out[0].shape` gives the shape of the results.
  #   args: Pytree of tensors to reduce.
  #   axes: Tuple holding exactly one axis.
  #
  # Returns:
  #   A list with one reduced tensor per leaf of `args`.
  #
  # Raises:
  #   NotImplementedError: If tracing `body` produces constants.
  flat_args = tree_util.tree_leaves(args)
  (axis,) = axes
  # Avals of the combine function's operands: inputs with `axis` removed.
  mapped_avals = [
      jax_core.mapped_aval(aval.shape[axis], axis, aval)
      for aval in ctx.avals_in
  ]
  in_tree = tree_util.tree_structure((args, args))
  flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(body), in_tree
  )
  combine_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      flat_fun, [*mapped_avals, *mapped_avals]
  )
  out_tree = out_tree_thunk()
  del out_tree  # Not needed
  if consts:
    raise NotImplementedError("Reductions with constants not supported.")
  element_types = [arg.type.scalar for arg in flat_args]
  builder = ctx.builder
  reduce_op = builder.create_reduce([t.handle for t in flat_args], axis)
  region = reduce_op.get_region(0)
  # Remember where we were so emission can resume after the combine region.
  old_ip = builder.get_insertion_point()
  # The combine block takes the scalar element types twice: once for each of
  # the two operands of `body`.
  param_types = element_types * 2
  ir_param_types = [ty.to_ir(builder) for ty in param_types]
  block = builder.create_block_with_parent(region, ir_param_types)
  combine_args = [
      tl.core.tensor(block.arg(i), ty) for i, ty in enumerate(param_types)
  ]
  results = lower_jaxpr_to_triton_ir(
      ctx.context, combine_jaxpr, None, *combine_args
  )
  handles = [r.handle for r in results]
  builder.create_reduce_ret(*handles)
  builder.restore_insertion_point(old_ip)
  reduce_op.verify()

  def wrap_tensor(x, scalar_ty):
    # Every result is typed with the shape of `ctx.avals_out[0]`.
    if ctx.avals_out[0].shape:
      res_ty = tl.block_type(scalar_ty, ctx.avals_out[0].shape)
    else:
      # 0d-tensor -> scalar
      res_ty = scalar_ty
    return tl.core.tensor(x, res_ty)

  results = [
      wrap_tensor(reduce_op.get_result(i), ty)
      for i, ty in enumerate(element_types)
  ]
  return results
def _reduce_lowering(body, ctx: TritonLoweringRuleContext, a, *, axes):
  """Reduces `a` over `axes` with the binary combine function `body`.

  Several axes are reduced one at a time, largest axis first, updating the
  context's avals after each step. An empty `axes` returns `a` unchanged.
  """
  assert isinstance(axes, tuple)
  if not axes:
    return a
  pending_axes = axes
  while len(pending_axes) > 1:
    axis = max(pending_axes)
    reduced_avals = tuple(
        aval.update(shape=aval.shape[:axis] + aval.shape[axis + 1:])
        for aval in ctx.avals_in
    )
    step_ctx = ctx.replace(avals_out=reduced_avals)
    a = _reduce_lowering(body, step_ctx, a, axes=(axis,))
    # Adding an intervening -(-reduce(.)) introduces a convert_layout between
    # reduces, which seems necessary for correctness.
    # TODO(bjp): Get rid of the double negation.
    # https://github.com/openai/triton/issues/1776
    a = a.__neg__(_builder=ctx.builder).__neg__(_builder=ctx.builder)
    ctx = ctx.replace(avals_in=reduced_avals)
    pending_axes = tuple(ax for ax in pending_axes if ax != axis)
  return _reduction_lowering(body, ctx, a, axes=pending_axes)[0]


triton_lowering_rules[lax.reduce_max_p] = partial(_reduce_lowering, jnp.maximum)
triton_lowering_rules[lax.reduce_min_p] = partial(_reduce_lowering, jnp.minimum)
triton_lowering_rules[lax.reduce_sum_p] = partial(_reduce_lowering, jnp.add)
def _argreduce_lowering(
body, ctx: TritonLoweringRuleContext, a, *, axes, index_dtype
):
if index_dtype != jnp.int32:
raise ValueError("`index_type` must be f32.")
if len(axes) != 1:
raise ValueError("`pallas` reduce operations only support one reduce axis.")
(axis,) = axes
n = ctx.avals_in[0].shape[axis]
index = tl.arange(0, n, _builder=ctx.builder)
if len(a.shape) > 1:
# Broadcast index across the non-reduced axes
expand_dims_index = [tl.constexpr(None)] * len(a.shape)
expand_dims_index[axis] = slice(None)
index = index.__getitem__(expand_dims_index, _builder=ctx.builder)
index = tl.core.broadcast_to(index, a.shape, _builder=ctx.builder)
ctx = ctx.replace(
avals_in=[
ctx.avals_in[0],
ctx.avals_in[0].update(dtype=jnp.dtype("int32")),
]
)
_, indices = _reduction_lowering(body, ctx, (a, index), axes=axes)
return indices
def _reduce_argmax_combine(left, right):
value1, index1 = left
value2, index2 = right
gt = value1 > value2
lt = value1 < value2
index_min = jnp.minimum(index1, index2)
index_ret = jnp.where(gt, index1, jnp.where(lt, index2, index_min))
value_ret = jnp.maximum(value1, value2)
return value_ret, index_ret
triton_lowering_rules[lax.argmax_p] = functools.partial(