-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][UB] Add ub.unreachable canonicalization
#169873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][UB] Add ub.unreachable canonicalization
#169873
Conversation
|
@llvm/pr-subscribers-mlir-ub @llvm/pr-subscribers-mlir-cf Author: Matthias Springer (matthias-springer) ChangesBasic blocks with a Depends on #169872. Full diff: https://github.com/llvm/llvm-project/pull/169873.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index a441fd82546e3..c9b4da44ffa01 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -22,7 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def ControlFlow_Dialect : Dialect {
let name = "cf";
let cppNamespace = "::mlir::cf";
- let dependentDialects = ["arith::ArithDialect"];
+ let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
let description = [{
This dialect contains low-level, i.e. non-region based, control flow
constructs. These constructs generally represent control flow directly
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..02081e2d6d15f 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -9,6 +9,10 @@
#ifndef MLIR_DIALECT_UB_IR_OPS_H
#define MLIR_DIALECT_UB_IR_OPS_H
+namespace mlir {
+class PatternRewriter;
+}
+
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index 8a354da2db10c..c1d74290ec174 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -84,6 +84,7 @@ def UnreachableOp : UB_Op<"unreachable", [Terminator]> {
}];
let assemblyFormat = "attr-dict";
+ let hasCanonicalizeMethod = 1;
}
#endif // MLIR_DIALECT_UB_IR_UBOPS_TD
diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
index 58551bb435c86..05a787fa53ec3 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect
MLIRControlFlowInterfaces
MLIRIR
MLIRSideEffectInterfaces
+ MLIRUBDialect
)
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index f1da1a125e9ef..aabf8930cf78e 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -445,6 +446,35 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
return success(replaced);
}
};
+
+struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> {
+ using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CondBranchOp condbr,
+ PatternRewriter &rewriter) const override {
+ // If the "true" destination has unreachable an unreachable terminator,
+ // always branch to the "false" destination.
+ Block *trueDest = condbr.getTrueDest();
+ Block *falseDest = condbr.getFalseDest();
+ if (llvm::hasSingleElement(*trueDest) &&
+ isa<ub::UnreachableOp>(trueDest->getTerminator())) {
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest,
+ condbr.getFalseOperands());
+ return success();
+ }
+
+ // If the "false" destination has unreachable an unreachable terminator,
+ // always branch to the "true" destination.
+ if (llvm::hasSingleElement(*falseDest) &&
+ isa<ub::UnreachableOp>(falseDest->getTerminator())) {
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest,
+ condbr.getTrueOperands());
+ return success();
+ }
+
+ return failure();
+ }
+};
} // namespace
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -452,7 +482,7 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
SimplifyCondBranchIdenticalSuccessors,
SimplifyCondBranchFromCondBranchOnSameCondition,
- CondBranchTruthPropagation>(context);
+ CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
}
SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..419e3f9d76fb2 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/UB/IR/UBOpsDialect.cpp.inc"
@@ -57,8 +58,33 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// PoisonOp
+//===----------------------------------------------------------------------===//
+
OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
+//===----------------------------------------------------------------------===//
+// UnreachableOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UnreachableOp::canonicalize(UnreachableOp unreachableOp,
+ PatternRewriter &rewriter) {
+ Block *block = unreachableOp->getBlock();
+ if (llvm::hasSingleElement(*block))
+ return rewriter.notifyMatchFailure(
+ unreachableOp, "unreachable op is the only operation in the block");
+
+ // Erase all other operations in the block. They must be dead.
+ for (Operation &op : llvm::make_early_inc_range(*block)) {
+ if (&op == unreachableOp.getOperation())
+ continue;
+ op.dropAllUses();
+ rewriter.eraseOp(&op);
+ }
+ return success();
+}
+
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 17f7d28ba59fb..75dec6dacde91 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -634,3 +634,28 @@ func.func @unsimplified_cycle_2(%c : i1) {
^bb7:
cf.br ^bb6
}
+
+// CHECK-LABEL: @drop_unreachable_branch_1
+// CHECK-NEXT: "test.foo"() : () -> ()
+// CHECK-NEXT: return
+func.func @drop_unreachable_branch_1(%c: i1) {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ "test.foo"() : () -> ()
+ return
+^bb2:
+ "test.bar"() : () -> ()
+ ub.unreachable
+}
+
+// CHECK-LABEL: @drop_unreachable_branch_2
+// CHECK-NEXT: ub.unreachable
+func.func @drop_unreachable_branch_2(%c: i1) {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ "test.foo"() : () -> ()
+ ub.unreachable
+^bb2:
+ "test.bar"() : () -> ()
+ ub.unreachable
+}
diff --git a/mlir/test/Dialect/UB/canonicalize.mlir b/mlir/test/Dialect/UB/canonicalize.mlir
index c3f286e49b09b..74ba9f1932384 100644
--- a/mlir/test/Dialect/UB/canonicalize.mlir
+++ b/mlir/test/Dialect/UB/canonicalize.mlir
@@ -9,3 +9,13 @@ func.func @merge_poison() -> (i32, i32) {
%1 = ub.poison : i32
return %0, %1 : i32, i32
}
+
+// -----
+
+// CHECK-LABEL: func @drop_ops_before_unreachable()
+// CHECK-NEXT: ub.unreachable
+func.func @drop_ops_before_unreachable() {
+ "test.foo"() : () -> ()
+ "test.bar"() : () -> ()
+ ub.unreachable
+}
|
bc96208 to
561b6ca
Compare
| let name = "cf"; | ||
| let cppNamespace = "::mlir::cf"; | ||
| let dependentDialects = ["arith::ArithDialect"]; | ||
| let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this needed? Are you creating a ub dialect entity in the canonicalization? Just checking with a isa does not need it.
| } | ||
| }; | ||
|
|
||
| struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a small snippet/explanation summary of the transformation?
| continue; | ||
| op.dropAllUses(); | ||
| rewriter.eraseOp(&op); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should TODO that:
- this assumes we don't have calls that are "no return".
- this assumes that loops terminates.
- this assumes that nothing interrupts the control-flow (which is fine until early-exit is added).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should have a trait "AlwaysForwardProgress" on operations to allow this kind of transformations?
Basic blocks with a
ub.unreachableterminator are unreachable. This commit adds a canonicalization pattern that drops all preceding operations. This commit also adds a canonicalization pattern that folds tocf.cond_brtocf.brif one of the destination branches is unreachable.Depends on #169872.