Skip to content

Commit bc96208

Browse files
[mlir][UB] Add ub.unreachable canonicalization
1 parent a47b28c commit bc96208

File tree

8 files changed

+99
-2
lines changed

8 files changed

+99
-2
lines changed

mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
2222
def ControlFlow_Dialect : Dialect {
2323
let name = "cf";
2424
let cppNamespace = "::mlir::cf";
25-
let dependentDialects = ["arith::ArithDialect"];
25+
let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
2626
let description = [{
2727
This dialect contains low-level, i.e. non-region based, control flow
2828
constructs. These constructs generally represent control flow directly

mlir/include/mlir/Dialect/UB/IR/UBOps.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
#ifndef MLIR_DIALECT_UB_IR_OPS_H
1010
#define MLIR_DIALECT_UB_IR_OPS_H
1111

12+
namespace mlir {
13+
class PatternRewriter;
14+
}
15+
1216
#include "mlir/Bytecode/BytecodeOpInterface.h"
1317
#include "mlir/IR/Dialect.h"
1418
#include "mlir/IR/OpImplementation.h"

mlir/include/mlir/Dialect/UB/IR/UBOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def UnreachableOp : UB_Op<"unreachable", [Terminator]> {
8484
}];
8585

8686
let assemblyFormat = "attr-dict";
87+
let hasCanonicalizeMethod = 1;
8788
}
8889

8990
#endif // MLIR_DIALECT_UB_IR_UBOPS_TD

mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect
1212
MLIRControlFlowInterfaces
1313
MLIRIR
1414
MLIRSideEffectInterfaces
15+
MLIRUBDialect
1516
)

mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Arith/IR/Arith.h"
1313
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
1414
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15+
#include "mlir/Dialect/UB/IR/UBOps.h"
1516
#include "mlir/IR/AffineExpr.h"
1617
#include "mlir/IR/AffineMap.h"
1718
#include "mlir/IR/Builders.h"
@@ -445,14 +446,43 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
445446
return success(replaced);
446447
}
447448
};
449+
450+
struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> {
451+
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
452+
453+
LogicalResult matchAndRewrite(CondBranchOp condbr,
454+
PatternRewriter &rewriter) const override {
455+
// If the "true" destination has unreachable an unreachable terminator,
456+
// always branch to the "false" destination.
457+
Block *trueDest = condbr.getTrueDest();
458+
Block *falseDest = condbr.getFalseDest();
459+
if (llvm::hasSingleElement(*trueDest) &&
460+
isa<ub::UnreachableOp>(trueDest->getTerminator())) {
461+
rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest,
462+
condbr.getFalseOperands());
463+
return success();
464+
}
465+
466+
// If the "false" destination has unreachable an unreachable terminator,
467+
// always branch to the "true" destination.
468+
if (llvm::hasSingleElement(*falseDest) &&
469+
isa<ub::UnreachableOp>(falseDest->getTerminator())) {
470+
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest,
471+
condbr.getTrueOperands());
472+
return success();
473+
}
474+
475+
return failure();
476+
}
477+
};
448478
} // namespace
449479

450480
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
451481
MLIRContext *context) {
452482
results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
453483
SimplifyCondBranchIdenticalSuccessors,
454484
SimplifyCondBranchFromCondBranchOnSameCondition,
455-
CondBranchTruthPropagation>(context);
485+
CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
456486
}
457487

458488
SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {

mlir/lib/Dialect/UB/IR/UBOps.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/IR/Builders.h"
1414
#include "mlir/IR/DialectImplementation.h"
15+
#include "mlir/IR/PatternMatch.h"
1516
#include "llvm/ADT/TypeSwitch.h"
1617

1718
#include "mlir/Dialect/UB/IR/UBOpsDialect.cpp.inc"
@@ -57,8 +58,33 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
5758
return nullptr;
5859
}
5960

61+
//===----------------------------------------------------------------------===//
62+
// PoisonOp
63+
//===----------------------------------------------------------------------===//
64+
6065
OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
6166

67+
//===----------------------------------------------------------------------===//
68+
// UnreachableOp
69+
//===----------------------------------------------------------------------===//
70+
71+
LogicalResult UnreachableOp::canonicalize(UnreachableOp unreachableOp,
72+
PatternRewriter &rewriter) {
73+
Block *block = unreachableOp->getBlock();
74+
if (llvm::hasSingleElement(*block))
75+
return rewriter.notifyMatchFailure(
76+
unreachableOp, "unreachable op is the only operation in the block");
77+
78+
// Erase all other operations in the block. They must be dead.
79+
for (Operation &op : llvm::make_early_inc_range(*block)) {
80+
if (&op == unreachableOp.getOperation())
81+
continue;
82+
op.dropAllUses();
83+
rewriter.eraseOp(&op);
84+
}
85+
return success();
86+
}
87+
6288
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
6389

6490
#define GET_ATTRDEF_CLASSES

mlir/test/Dialect/ControlFlow/canonicalize.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,3 +634,28 @@ func.func @unsimplified_cycle_2(%c : i1) {
634634
^bb7:
635635
cf.br ^bb6
636636
}
637+
638+
// CHECK-LABEL: @drop_unreachable_branch_1
639+
// CHECK-NEXT: "test.foo"() : () -> ()
640+
// CHECK-NEXT: return
641+
func.func @drop_unreachable_branch_1(%c: i1) {
642+
cf.cond_br %c, ^bb1, ^bb2
643+
^bb1:
644+
"test.foo"() : () -> ()
645+
return
646+
^bb2:
647+
"test.bar"() : () -> ()
648+
ub.unreachable
649+
}
650+
651+
// CHECK-LABEL: @drop_unreachable_branch_2
652+
// CHECK-NEXT: ub.unreachable
653+
func.func @drop_unreachable_branch_2(%c: i1) {
654+
cf.cond_br %c, ^bb1, ^bb2
655+
^bb1:
656+
"test.foo"() : () -> ()
657+
ub.unreachable
658+
^bb2:
659+
"test.bar"() : () -> ()
660+
ub.unreachable
661+
}

mlir/test/Dialect/UB/canonicalize.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,13 @@ func.func @merge_poison() -> (i32, i32) {
99
%1 = ub.poison : i32
1010
return %0, %1 : i32, i32
1111
}
12+
13+
// -----
14+
15+
// CHECK-LABEL: func @drop_ops_before_unreachable()
16+
// CHECK-NEXT: ub.unreachable
17+
func.func @drop_ops_before_unreachable() {
18+
"test.foo"() : () -> ()
19+
"test.bar"() : () -> ()
20+
ub.unreachable
21+
}

0 commit comments

Comments
 (0)