diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td index a441fd82546e3..c9b4da44ffa01 100644 --- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td @@ -22,7 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" def ControlFlow_Dialect : Dialect { let name = "cf"; let cppNamespace = "::mlir::cf"; - let dependentDialects = ["arith::ArithDialect"]; + let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"]; let description = [{ This dialect contains low-level, i.e. non-region based, control flow constructs. These constructs generally represent control flow directly diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h index 21de5cb0c182a..02081e2d6d15f 100644 --- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h +++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h @@ -9,6 +9,10 @@ #ifndef MLIR_DIALECT_UB_IR_OPS_H #define MLIR_DIALECT_UB_IR_OPS_H +namespace mlir { +class PatternRewriter; +} + #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td index c400a2ef2cc7a..c1d74290ec174 100644 --- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td +++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td @@ -66,4 +66,25 @@ def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// UnreachableOp +//===----------------------------------------------------------------------===// + +def UnreachableOp : UB_Op<"unreachable", [Terminator]> { + let summary = "Unreachable operation."; + let description = [{ + The `unreachable` operation has no defined semantics. This operation + indicates that its enclosing basic block is not reachable. + + Example: + + ``` + ub.unreachable + ``` + }]; + + let assemblyFormat = "attr-dict"; + let hasCanonicalizeMethod = 1; +} + #endif // MLIR_DIALECT_UB_IR_UBOPS_TD diff --git a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp index 9921a06778dd7..feb04899cb33d 100644 --- a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp +++ b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp @@ -23,8 +23,11 @@ namespace mlir { using namespace mlir; -namespace { +//===----------------------------------------------------------------------===// +// PoisonOpLowering +//===----------------------------------------------------------------------===// +namespace { struct PoisonOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -32,13 +35,8 @@ struct PoisonOpLowering : public ConvertOpToLLVMPattern { matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; - } // namespace -//===----------------------------------------------------------------------===// -// PoisonOpLowering -//===----------------------------------------------------------------------===// - LogicalResult PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -60,6 +58,29 @@ PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, return success(); } +//===----------------------------------------------------------------------===// +// UnreachableOpLowering +//===----------------------------------------------------------------------===// + +namespace { +struct UnreachableOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(ub::UnreachableOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; +} // namespace +LogicalResult + +UnreachableOpLowering::matchAndRewrite( + ub::UnreachableOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp(op); + return success(); +} + //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -93,7 +114,7 @@ struct UBToLLVMConversionPass void mlir::ub::populateUBToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add(converter); + patterns.add(converter); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp index 244d214cba196..3831387816eaf 100644 --- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp +++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp @@ -40,6 +40,17 @@ struct PoisonOpLowering final : OpConversionPattern { } }; +struct UnreachableOpLowering final : OpConversionPattern { + using Base::Base; + + LogicalResult + matchAndRewrite(ub::UnreachableOp op, OpAdaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -75,5 +86,6 @@ struct UBToSPIRVConversionPass final void mlir::ub::populateUBToSPIRVConversionPatterns( const SPIRVTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add(converter, patterns.getContext()); + patterns.add(converter, + patterns.getContext()); } diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt index 58551bb435c86..05a787fa53ec3 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt @@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect MLIRControlFlowInterfaces MLIRIR MLIRSideEffectInterfaces + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp index f1da1a125e9ef..218758bc0aac5 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -445,6 +446,35 @@ struct CondBranchTruthPropagation : public OpRewritePattern { return success(replaced); } }; + +struct DropUnreachableCondBranch : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { + // If the "true" destination is unreachable, branch to the "false" + // destination. + Block *trueDest = condbr.getTrueDest(); + Block *falseDest = condbr.getFalseDest(); + if (llvm::hasSingleElement(*trueDest) && + isa(trueDest->getTerminator())) { + rewriter.replaceOpWithNewOp(condbr, falseDest, + condbr.getFalseOperands()); + return success(); + } + + // If the "false" destination is unreachable, branch to the "true" + // destination. + if (llvm::hasSingleElement(*falseDest) && + isa(falseDest->getTerminator())) { + rewriter.replaceOpWithNewOp(condbr, trueDest, + condbr.getTrueOperands()); + return success(); + } + + return failure(); + } +}; } // namespace void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -452,7 +482,7 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); + CondBranchTruthPropagation, DropUnreachableCondBranch>(context); } SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp index ee523f9522953..419e3f9d76fb2 100644 --- a/mlir/lib/Dialect/UB/IR/UBOps.cpp +++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/PatternMatch.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/UB/IR/UBOpsDialect.cpp.inc" @@ -57,8 +58,33 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value, return nullptr; } +//===----------------------------------------------------------------------===// +// PoisonOp +//===----------------------------------------------------------------------===// + OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); } +//===----------------------------------------------------------------------===// +// UnreachableOp +//===----------------------------------------------------------------------===// + +LogicalResult UnreachableOp::canonicalize(UnreachableOp unreachableOp, + PatternRewriter &rewriter) { + Block *block = unreachableOp->getBlock(); + if (llvm::hasSingleElement(*block)) + return rewriter.notifyMatchFailure( + unreachableOp, "unreachable op is the only operation in the block"); + + // Erase all other operations in the block. They must be dead. + for (Operation &op : llvm::make_early_inc_range(*block)) { + if (&op == unreachableOp.getOperation()) + continue; + op.dropAllUses(); + rewriter.eraseOp(&op); + } + return success(); +} + #include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc" #define GET_ATTRDEF_CLASSES diff --git a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir index 6c0b111d4c2c5..0fe63f5a3a89f 100644 --- a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir +++ b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir @@ -17,3 +17,9 @@ func.func @check_poison() { %3 = ub.poison : !llvm.ptr return } + +// CHECK-LABEL: @check_unrechable +func.func @check_unrechable() { +// CHECK: llvm.unreachable + ub.unreachable +} diff --git a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir index f497eb3bc552c..edbe8b8001bba 100644 --- a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir +++ b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir @@ -19,3 +19,18 @@ func.func @check_poison() { } } + +// ----- + +// No successful test because the dialect conversion framework does not convert +// unreachable blocks. + +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { +func.func @check_unrechable() { +// expected-error@+1{{cannot be used in reachable block}} + spirv.Unreachable +} +} diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir index 17f7d28ba59fb..75dec6dacde91 100644 --- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir +++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir @@ -634,3 +634,28 @@ func.func @unsimplified_cycle_2(%c : i1) { ^bb7: cf.br ^bb6 } + +// CHECK-LABEL: @drop_unreachable_branch_1 +// CHECK-NEXT: "test.foo"() : () -> () +// CHECK-NEXT: return +func.func @drop_unreachable_branch_1(%c: i1) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + "test.foo"() : () -> () + return +^bb2: + "test.bar"() : () -> () + ub.unreachable +} + +// CHECK-LABEL: @drop_unreachable_branch_2 +// CHECK-NEXT: ub.unreachable +func.func @drop_unreachable_branch_2(%c: i1) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + "test.foo"() : () -> () + ub.unreachable +^bb2: + "test.bar"() : () -> () + ub.unreachable +} diff --git a/mlir/test/Dialect/UB/canonicalize.mlir b/mlir/test/Dialect/UB/canonicalize.mlir index c3f286e49b09b..74ba9f1932384 100644 --- a/mlir/test/Dialect/UB/canonicalize.mlir +++ b/mlir/test/Dialect/UB/canonicalize.mlir @@ -9,3 +9,13 @@ func.func @merge_poison() -> (i32, i32) { %1 = ub.poison : i32 return %0, %1 : i32, i32 } + +// ----- + +// CHECK-LABEL: func @drop_ops_before_unreachable() +// CHECK-NEXT: ub.unreachable +func.func @drop_ops_before_unreachable() { + "test.foo"() : () -> () + "test.bar"() : () -> () + ub.unreachable +} diff --git a/mlir/test/Dialect/UB/ops.mlir b/mlir/test/Dialect/UB/ops.mlir index 724b6b4caac5d..730c1bd1380b8 100644 --- a/mlir/test/Dialect/UB/ops.mlir +++ b/mlir/test/Dialect/UB/ops.mlir @@ -38,3 +38,9 @@ func.func @poison_tensor() -> tensor<8x?xf64> { %0 = ub.poison : tensor<8x?xf64> return %0 : tensor<8x?xf64> } + +// CHECK-LABEL: func @unreachable() +// CHECK: ub.unreachable +func.func @unreachable() { + ub.unreachable +}