From 9b1890c7a2fe43388b7f7f4ec25aed0a45f20d3d Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Wed, 20 Nov 2024 17:36:19 +0000 Subject: [PATCH 1/4] [mlir][Affine] Split off delinearize parts that depend on last component If we have %0 = affine.linearize_index disjoint [%a, %b] by (A, B) %1:3 = affine.delinearize_index %0 into (A, B1, B2) where B = B1 * B2 (or some more complex product), we can simplify this to %0 = affine.linearize_index disjoint [%a] by (A) %1a:1 = affine.delinearize_index %0 into (A) %1b:2 = affine.delinearize_index %b into (B1, B2) This, and more complex cases, prevent us from adding terms together only to divide them away from each other. --- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 87 +++++++++++++++++++++- mlir/test/Dialect/Affine/canonicalize.mlir | 66 ++++++++++++++++ 2 files changed, 151 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 4cf07bc167eab..b13331abc32ad 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -4694,13 +4694,96 @@ struct CancelDelinearizeOfLinearizeDisjointExact return success(); } }; + +/// If the input to a delinearization is a disjoint linearization, and the +/// last k > 1 components of the delinearization basis multiply to the +/// last component of the linearization basis, break the linearization and +/// delinearization into two parts, peeling off the last input to linearization. +/// +/// For example: +/// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index +/// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ... 
+/// becomes +/// %0 = affine.linearize_index [%z, %y] by (3, 2) : index +/// %1:2 = affine.delinearize_index %0 by (2, 3) : index +/// %2:2 = affine.delinearize_index %x by (8, 4) : index +/// where the original %1:4 is replaced by %1:2 ++ %2:2 +struct SplitDelinearizeSpanningLastLinearizeArg final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp, + PatternRewriter &rewriter) const override { + auto linearizeOp = delinearizeOp.getLinearIndex() + .getDefiningOp(); + if (!linearizeOp) + return rewriter.notifyMatchFailure(delinearizeOp, + "index doesn't come from linearize"); + + if (!linearizeOp.getDisjoint()) + return rewriter.notifyMatchFailure(linearizeOp, + "linearize isn't disjoint"); + + int64_t target = linearizeOp.getStaticBasis().back(); + if (ShapedType::isDynamic(target)) + return rewriter.notifyMatchFailure( + linearizeOp, "linearize ends with dynamic basis value"); + + int64_t sizeToSplit = 1; + size_t elemsToSplit = 0; + ArrayRef basis = delinearizeOp.getStaticBasis(); + for (int64_t basisElem : llvm::reverse(basis)) { + if (ShapedType::isDynamic(basisElem)) + return rewriter.notifyMatchFailure( + delinearizeOp, "dynamic basis element while scanning for split"); + sizeToSplit *= basisElem; + elemsToSplit += 1; + + if (sizeToSplit > target) + return rewriter.notifyMatchFailure(delinearizeOp, + "overshot last argument size"); + if (sizeToSplit == target) + break; + } + + if (sizeToSplit < target) + return rewriter.notifyMatchFailure( + delinearizeOp, "product of known basis elements doesn't exceed last " + "linearize argument"); + + if (elemsToSplit < 2) + return rewriter.notifyMatchFailure( + delinearizeOp, "don't have a non-trivial basis product"); + + Value linearizeWithoutBack = + rewriter.create( + linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(), + linearizeOp.getDynamicBasis(), + linearizeOp.getStaticBasis().drop_back(), + 
linearizeOp.getDisjoint()); + auto delinearizeWithoutSplitPart = + rewriter.create( + delinearizeOp.getLoc(), linearizeWithoutBack, + delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit), + delinearizeOp.hasOuterBound()); + auto delinearizeBack = rewriter.create( + delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(), + basis.take_back(elemsToSplit), /*hasOuterBound=*/true); + SmallVector results = llvm::to_vector( + llvm::concat(delinearizeWithoutSplitPart.getResults(), + delinearizeBack.getResults())); + rewriter.replaceOp(delinearizeOp, results); + + return success(); + } +}; } // namespace void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns - .insert( - context); + .insert(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index b54a13cffe777..efeea7eb2af53 100644 --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -1777,6 +1777,72 @@ func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1: // ----- +// CHECK-LABEL: func @split_delinearize_spanning_final_part +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index) +// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 4) +// CHECK: %[[DELIN1:.+]]:2 = affine.delinearize_index %[[LIN]] into (2) +// CHECK: %[[DELIN2:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8) +// CHECK: return %[[DELIN1]]#0, %[[DELIN1]]#1, %[[DELIN2]]#0, %[[DELIN2]]#1 +func.func @split_delinearize_spanning_final_part(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) { + %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index + %1:4 = affine.delinearize_index %0 into (2, 
8, 8) + : index, index, index, index + return %1#0, %1#1, %1#2, %1#3 : index, index, index, index +} + +// ----- + +// CHECK-LABEL: func @split_delinearize_spanning_final_part_and_cancel +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index) +// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8) +// CHECK: return %[[ARG0]], %[[ARG1]], %[[DELIN]]#0, %[[DELIN]]#1 +func.func @split_delinearize_spanning_final_part_and_cancel(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) { + %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index + %1:4 = affine.delinearize_index %0 into (2, 4, 8, 8) + : index, index, index, index + return %1#0, %1#1, %1#2, %1#3 : index, index, index, index +} + +// ----- + +// The delinearize basis doesn't match the last basis element before +// overshooting it, don't simplify. +// CHECK-LABEL: func @dont_split_delinearize_overshooting_target +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index) +// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 4, 64) +// CHECK: %[[DELIN:.+]]:4 = affine.delinearize_index %[[LIN]] into (2, 16, 8) +// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2, %[[DELIN]]#3 +func.func @dont_split_delinearize_overshooting_target(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) { + %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index + %1:4 = affine.delinearize_index %0 into (2, 16, 8) + : index, index, index, index + return %1#0, %1#1, %1#2, %1#3 : index, index, index, index +} + +// ----- + +// The delinearize basis doesn't fully multiply to the final basis element. 
+// CHECK-LABEL: func @dont_split_delinearize_undershooting_target +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 64) +// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (4, 8) +// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1 +func.func @dont_split_delinearize_undershooting_target(%arg0: index, %arg1: index) -> (index, index, index) { + %0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 64) : index + %1:3 = affine.delinearize_index %0 into (4, 8) + : index, index, index + return %1#0, %1#1, %1#2 : index, index, index +} + +// ----- + // CHECK-LABEL: @linearize_unit_basis_disjoint // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index) // CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index From c90f3f3bd7381ff846046d6216231100708ef4c6 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Fri, 22 Nov 2024 18:23:54 +0000 Subject: [PATCH 2/4] clang-format --- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index bb9f1d72e611c..28d27b0b2810f 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -4816,8 +4816,10 @@ struct SplitDelinearizeSpanningLastLinearizeArg final void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { - patterns.insert(context); + patterns + .insert( + context); } //===----------------------------------------------------------------------===// From 808e0a4bc26809da9ad400b747a5e67bf9ed844f Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Mon, 25 Nov 2024 13:41:05 -0600 Subject: [PATCH 3/4] Update debug message wording Co-authored-by: 
Abhishek Varma --- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 28d27b0b2810f..ba259876ec18c 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -4788,7 +4788,7 @@ struct SplitDelinearizeSpanningLastLinearizeArg final if (elemsToSplit < 2) return rewriter.notifyMatchFailure( - delinearizeOp, "don't have a non-trivial basis product"); + delinearizeOp, "need at least two elements to form the basis product"); Value linearizeWithoutBack = rewriter.create( From 62f24e096017465b3e771304887a7a65c79091b1 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Mon, 25 Nov 2024 19:42:52 +0000 Subject: [PATCH 4/4] Clang-format of suggestion --- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index ba259876ec18c..1c5466730a558 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -4788,7 +4788,8 @@ struct SplitDelinearizeSpanningLastLinearizeArg final if (elemsToSplit < 2) return rewriter.notifyMatchFailure( - delinearizeOp, "need at least two elements to form the basis product"); + delinearizeOp, + "need at least two elements to form the basis product"); Value linearizeWithoutBack = rewriter.create(