Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial port of HECO auto-SIMD passes to HEIR
- InsertRotate: insert rotations and apply target slot selection rules - TensorExtCanonicalization: canonicalization patterns to enable cse to remove unnecessary rotations - CollapseInsertionChains: identify extract/insert chains that can be converted to rotations Additional followup issues identified in #471 for improvements.
- Loading branch information
Showing
27 changed files
with
898 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_ | ||
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_ | ||
|
||
include "TensorExtOps.td" | ||
include "mlir/Dialect/Arith/IR/ArithOps.td" | ||
include "mlir/Dialect/Tensor/IR/TensorOps.td" | ||
include "mlir/IR/PatternBase.td" | ||
|
||
// TODO(#515): refactor these helpers to a common file with InsertRotate.td | ||
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">; | ||
|
||
def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">; | ||
|
||
def CreateIndexCastOp : NativeCodeCall< | ||
"$_builder.create<arith::IndexCastOp>($0.getLoc(), $1.getType(), $0)">; | ||
|
||
def IsZero : | ||
Constraint< | ||
CPred<"llvm::cast<mlir::IntegerAttr>($0).getValue().isZero()">>; | ||
|
||
def OutOfBoundsOfTensorDim : | ||
Constraint< | ||
CPred< | ||
"llvm::cast<mlir::IntegerAttr>($0).getValue().getSExtValue() < 0 " | ||
"|| llvm::cast<mlir::IntegerAttr>($0).getValue().getSExtValue() > " | ||
"llvm::cast<mlir::ShapedType>($1.getType()).getShape()[0]" | ||
> | ||
>; | ||
|
||
// rotate %t, 0 -> %t | ||
def DropZeroRotation : Pat< | ||
(TensorExt_RotateOp $tensor, (ConstantLikeMatcher APIntAttr:$c0)), | ||
(replaceWithValue $tensor), | ||
[(IsZero $c0)] | ||
>; | ||
|
||
// rotate %t, x -> rotate %t, x mod size | ||
def NormalizeRotationIndex : Pat< | ||
(TensorExt_RotateOp $tensor, (Arith_ConstantOp:$shiftOp APIntAttr:$shiftAmount)), | ||
(TensorExt_RotateOp $tensor, | ||
(Arith_RemUIOp | ||
$shiftOp, | ||
// Only works for 1D tensors: index is taken modulo the tensor length, | ||
// i.e., dim 0 | ||
(CreateIndexCastOp | ||
(Tensor_DimOp $tensor, (Arith_ConstantOp ConstantAttr<IndexAttr, "0">)), | ||
$shiftOp)) | ||
), | ||
[(OutOfBoundsOfTensorDim $shiftAmount, $tensor)] | ||
>; | ||
|
||
// %0 = rotate %t, x | ||
// %1 = rotate %0, y | ||
// ---> rotate %t (x+y) | ||
def CombineSequentialRotates : Pat< | ||
(TensorExt_RotateOp | ||
(TensorExt_RotateOp $tensor, (Arith_ConstantOp:$xOp APIntAttr:$x)), | ||
(Arith_ConstantOp:$yOp APIntAttr:$y)), | ||
(TensorExt_RotateOp $tensor, (Arith_AddIOp $xOp, $yOp, DefOverflow)), | ||
[] | ||
>; | ||
|
||
// A rotation followed by extraction can be extracted directly from the | ||
// original tensor. | ||
def RotatePlusExtractToIndexedExtract : Pat< | ||
(Tensor_ExtractOp | ||
(TensorExt_RotateOp $tensor, $shift), | ||
(variadic $index)), | ||
(Tensor_ExtractOp | ||
$tensor, | ||
(MakeSingleResultVariadic (Arith_AddIOp $shift, $index, DefOverflow))) | ||
>; | ||
|
||
// Rotating two tensors by the same amount can be converted to a single | ||
// post-rotation. This can result in eliminating either the rotation (because | ||
// it can be combined with a later rotation) or the arith op itself, if it is | ||
// is identical to an existing arith op applied before the rotation. | ||
foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in { | ||
def FactorParallelRotationsThroughOp_#ArithOp : Pat< | ||
(ArithOp | ||
(TensorExt_RotateOp $t1, $i), | ||
(TensorExt_RotateOp $t2, $i), | ||
$ovf), | ||
(TensorExt_RotateOp (ArithOp $t1, $t2, $ovf), $i) | ||
>; | ||
} | ||
|
||
#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTCANONICALIZATION_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# InsertRotate tablegen and headers. | ||
|
||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=TensorExt", | ||
], | ||
"Passes.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"TensorExtPasses.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "Passes.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "insert_rotate_inc_gen", | ||
tbl_outs = [ | ||
( | ||
["-gen-rewriters"], | ||
"InsertRotate.cpp.inc", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "InsertRotate.td", | ||
deps = [ | ||
"@heir//include/Dialect/TensorExt/IR:ops_inc_gen", | ||
"@heir//include/Dialect/TensorExt/IR:td_files", | ||
"@llvm-project//mlir:ArithOpsTdFiles", | ||
"@llvm-project//mlir:TensorOpsTdFiles", | ||
], | ||
) | ||
|
||
exports_files([ | ||
"Passes.h", | ||
"CollapseInsertionChains.h", | ||
"InsertRotate.h", | ||
]) |
17 changes: 17 additions & 0 deletions
17
include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_COLLAPSEINSERTIONCHAINS_H_ | ||
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_COLLAPSEINSERTIONCHAINS_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace tensor_ext { | ||
|
||
#define GEN_PASS_DECL_COLLAPSEINSERTIONCHAINS | ||
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" | ||
|
||
} // namespace tensor_ext | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_COLLAPSEINSERTIONCHAINS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_ | ||
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace tensor_ext { | ||
|
||
#define GEN_PASS_DECL_INSERTROTATE | ||
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" | ||
|
||
} // namespace tensor_ext | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_INSERTROTATE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
#ifndef INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_ | ||
#define INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_ | ||
|
||
include "include/Dialect/TensorExt/IR/TensorExtOps.td" | ||
include "mlir/Dialect/Arith/IR/ArithOps.td" | ||
include "mlir/Dialect/Tensor/IR/TensorOps.td" | ||
include "mlir/IR/PatternBase.td" | ||
|
||
// TODO(#512): Support target slot selection when the downstream op is an insert. | ||
|
||
// The patterns in this file are intended to align with the automatic-SIMD | ||
// batching heuristics from the HECO project. See section 4.4 of | ||
// https://arxiv.org/abs/2202.01649 and the hir2hir passes in | ||
// https://github.com/MarbleHE/HECO/blob/main/src/Passes/hir2hir/ | ||
|
||
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">; | ||
|
||
// To understand why this is needed, see | ||
// https://discourse.llvm.org/t/compilation-failure-with-drr-generated-pattern/77385 | ||
def MakeSingleResultVariadic: NativeCodeCall<"{ $0 }">; | ||
|
||
// Match an arith op that extracts scalar values from two tensors, and replace | ||
// it with rotations to align slots and apply the same op in SIMD. Other | ||
// patterns in this file will find better alignment of adjacent rotations, and | ||
// canonicalization patterns will remove duplicated rotations. | ||
foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in { | ||
def InsertRotations_#ArithOp : Pattern< | ||
(ArithOp | ||
(Tensor_ExtractOp $t1, (variadic $i1)), | ||
(Tensor_ExtractOp $t2, (variadic $i2)), | ||
$overflow), | ||
[ | ||
(TensorExt_RotateOp:$r1 $t1, $i1), | ||
(TensorExt_RotateOp:$r2 $t2, $i2), | ||
(ArithOp:$opResult $r1, $r2, $overflow), | ||
(Tensor_ExtractOp | ||
$opResult, | ||
(MakeSingleResultVariadic (Arith_ConstantOp ConstantAttr<IndexAttr, "0">))), | ||
] | ||
>; | ||
} | ||
|
||
// Pre-align the first op's operands to the index that the result is | ||
// used for in a subsequent op. | ||
// TODO(#514): handle OuterOp with two different InnerOps on the LHS and RHS | ||
foreach InnerOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in { | ||
foreach OuterOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in { | ||
// Left associated grouping handles (add (add (rotate t1 i1) (rotate t2 i2)) (rotate t3 i3)) | ||
def AlignRotations_LeftAssociated_Inner_#InnerOp#_Outer_#OuterOp : Pattern< | ||
(OuterOp | ||
(InnerOp (TensorExt_RotateOp $t1, $i1), (TensorExt_RotateOp $t2, $i2), $ovf1), | ||
(TensorExt_RotateOp $t3, $i3), | ||
$ovf2), | ||
[ | ||
(TensorExt_RotateOp:$r1 $t1, (Arith_SubIOp $i1, $i3, DefOverflow)), | ||
(TensorExt_RotateOp:$r2 $t2, (Arith_SubIOp $i2, $i3, DefOverflow)), | ||
(InnerOp:$addResult $r1, $r2, $ovf1), | ||
(OuterOp:$output $addResult, $t3, $ovf2), | ||
// Downstream ops are not updated by this pass, so we need to preserve the original | ||
// rotation and then clean it up in a separate canonicalization pattern. | ||
(TensorExt_RotateOp $output, $i3), | ||
] | ||
>; | ||
|
||
// Right associated grouping handles (add (rotate t1 i1) (add (rotate t2 i2) (rotate t3 i3))) | ||
def AlignRotations_RightAssociated_Inner_#InnerOp#_Outer_#OuterOp : Pattern< | ||
(OuterOp | ||
(TensorExt_RotateOp $t3, $i3), | ||
(InnerOp (TensorExt_RotateOp $t1, $i1), (TensorExt_RotateOp $t2, $i2), $ovf1), | ||
$ovf2), | ||
[ | ||
(TensorExt_RotateOp:$r1 $t1, (Arith_SubIOp $i1, $i3, DefOverflow)), | ||
(TensorExt_RotateOp:$r2 $t2, (Arith_SubIOp $i2, $i3, DefOverflow)), | ||
(InnerOp:$addResult $r1, $r2, $ovf1), | ||
(OuterOp:$output $addResult, $t3, $ovf2), | ||
// Downstream ops are not updated by this pass, so we need to preserve the original | ||
// rotation and then clean it up in a separate canonicalization pattern. | ||
(TensorExt_RotateOp $output, $i3), | ||
] | ||
>; | ||
} | ||
} | ||
|
||
#endif // INCLUDE_DIALECT_TENSOREXT_IR_TENSOREXTPATTERNS_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#ifndef INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_ | ||
#define INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_ | ||
|
||
#include "include/Dialect/TensorExt/IR/TensorExtDialect.h" | ||
#include "include/Dialect/TensorExt/Transforms/CollapseInsertionChains.h" | ||
#include "include/Dialect/TensorExt/Transforms/InsertRotate.h" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace tensor_ext { | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "include/Dialect/TensorExt/Transforms/Passes.h.inc" | ||
|
||
} // namespace tensor_ext | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_DIALECT_TENSOREXT_TRANSFORMS_PASSES_H_ |
Oops, something went wrong.