13
13
// #include "mlir/IR/Dialect.h"
14
14
#include " mlir/IR/Region.h"
15
15
16
+ #include " mlir/Dialect/SCF/IR/SCF.h"
16
17
#include " mlir/IR/BuiltinTypes.h"
17
18
#include " mlir/IR/DialectImplementation.h"
18
19
#include " mlir/Interfaces/InferTypeOpInterface.h"
@@ -54,16 +55,34 @@ struct TestOpConversion : public OpConversionPattern<test_irdl_to_cpp::BeefOp> {
54
55
}
55
56
};
56
57
58
+ struct TestRegionConversion
59
+ : public OpConversionPattern<test_irdl_to_cpp::ConditionalOp> {
60
+ using OpConversionPattern::OpConversionPattern;
61
+
62
+ LogicalResult
63
+ matchAndRewrite (mlir::test_irdl_to_cpp::ConditionalOp op, OpAdaptor adaptor,
64
+ ConversionPatternRewriter &rewriter) const override {
65
+ // Just exercising the C++ API even though these are not enforced in the
66
+ // dialect definition
67
+ assert (op.getThen ().getBlocks ().size () == 1 );
68
+ assert (adaptor.getElse ().getBlocks ().size () == 1 );
69
+ auto ifOp = scf::IfOp::create (rewriter, op.getLoc (), op.getInput ());
70
+ rewriter.replaceOp (op, ifOp);
71
+ return success ();
72
+ }
73
+ };
74
+
57
75
struct ConvertTestDialectToSomethingPass
58
76
: PassWrapper<ConvertTestDialectToSomethingPass, OperationPass<ModuleOp>> {
59
77
void runOnOperation () override {
60
78
MLIRContext *ctx = &getContext ();
61
79
RewritePatternSet patterns (ctx);
62
- patterns.add <TestOpConversion>(ctx);
80
+ patterns.add <TestOpConversion, TestRegionConversion >(ctx);
63
81
ConversionTarget target (getContext ());
64
- target.addIllegalOp <test_irdl_to_cpp::BeefOp>();
65
- target.addLegalOp <test_irdl_to_cpp::BarOp>();
66
- target.addLegalOp <test_irdl_to_cpp::HashOp>();
82
+ target.addIllegalOp <test_irdl_to_cpp::BeefOp,
83
+ test_irdl_to_cpp::ConditionalOp>();
84
+ target.addLegalOp <test_irdl_to_cpp::BarOp, test_irdl_to_cpp::HashOp,
85
+ scf::IfOp, scf::YieldOp>();
67
86
if (failed (applyPartialConversion (getOperation (), target,
68
87
std::move (patterns))))
69
88
signalPassFailure ();
@@ -73,6 +92,10 @@ struct ConvertTestDialectToSomethingPass
73
92
StringRef getDescription () const final {
74
93
return " Checks the convertability of an irdl dialect" ;
75
94
}
95
+
96
+ void getDependentDialects (DialectRegistry ®istry) const override {
97
+ registry.insert <scf::SCFDialect>();
98
+ }
76
99
};
77
100
78
101
void registerIrdlTestDialect (mlir::DialectRegistry ®istry) {
0 commit comments