@@ -54,6 +54,7 @@ struct OpStrings {
54
54
std::string opCppName;
55
55
SmallVector<std::string> opResultNames;
56
56
SmallVector<std::string> opOperandNames;
57
+ SmallVector<std::string> opRegionNames;
57
58
};
58
59
59
60
static std::string joinNameList (llvm::ArrayRef<std::string> names) {
@@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
87
88
// / Generates OpStrings from an OperatioOp
88
89
static OpStrings getStrings (irdl::OperationOp op) {
89
90
auto operandOp = op.getOp <irdl::OperandsOp>();
90
-
91
91
auto resultOp = op.getOp <irdl::ResultsOp>();
92
+ auto regionsOp = op.getOp <irdl::RegionsOp>();
92
93
93
94
OpStrings strings;
94
95
strings.opName = op.getSymName ();
@@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
108
109
}));
109
110
}
110
111
112
+ if (regionsOp) {
113
+ strings.opRegionNames = SmallVector<std::string>(
114
+ llvm::map_range (regionsOp->getNames (), [](Attribute attr) {
115
+ return llvm::formatv (" {0}" , cast<StringAttr>(attr));
116
+ }));
117
+ }
118
+
111
119
return strings;
112
120
}
113
121
@@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
122
130
static void fillDict (irdl::detail::dictionary &dict, const OpStrings &strings) {
123
131
const auto operandCount = strings.opOperandNames .size ();
124
132
const auto resultCount = strings.opResultNames .size ();
133
+ const auto regionCount = strings.opRegionNames .size ();
125
134
126
135
dict[" OP_NAME" ] = strings.opName ;
127
136
dict[" OP_CPP_NAME" ] = strings.opCppName ;
@@ -131,6 +140,14 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
131
140
operandCount ? joinNameList (strings.opOperandNames ) : " {\"\" }" ;
132
141
dict[" OP_RESULT_INITIALIZER_LIST" ] =
133
142
resultCount ? joinNameList (strings.opResultNames ) : " {\"\" }" ;
143
+ dict[" OP_REGION_COUNT" ] = std::to_string (regionCount);
144
+ dict[" OP_ADD_REGIONS" ] =
145
+ regionCount ? std::string (llvm::formatv (
146
+ R"( for (unsigned i = 0; i != __OP_REGION_COUNT__; ++i)
147
+ (void)odsState.addRegion();
148
+ )" ,
149
+ regionCount))
150
+ : " " ;
134
151
}
135
152
136
153
// / Fills a dictionary with values from DialectStrings
@@ -179,6 +196,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
179
196
const OpStrings &opStrings) {
180
197
auto opGetters = std::string{};
181
198
auto resGetters = std::string{};
199
+ auto regionGetters = std::string{};
200
+ auto regionAdaptorGetters = std::string{};
182
201
183
202
for (size_t i = 0 , end = opStrings.opOperandNames .size (); i < end; ++i) {
184
203
const auto op =
@@ -196,8 +215,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
196
215
op, i);
197
216
}
198
217
218
+ for (size_t i = 0 , end = opStrings.opRegionNames .size (); i < end; ++i) {
219
+ const auto op =
220
+ llvm::convertToCamelFromSnakeCase (opStrings.opRegionNames [i], true );
221
+ regionAdaptorGetters += llvm::formatv (
222
+ R"( ::mlir::Region &get{0}() { return getRegions()[{1}]; }
223
+ )" ,
224
+ op, i);
225
+ regionGetters += llvm::formatv (
226
+ R"( ::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
227
+ )" ,
228
+ op, i);
229
+ }
230
+
199
231
dict[" OP_OPERAND_GETTER_DECLS" ] = opGetters;
200
232
dict[" OP_RESULT_GETTER_DECLS" ] = resGetters;
233
+ dict[" OP_REGION_ADAPTER_GETTER_DECLS" ] = regionAdaptorGetters;
234
+ dict[" OP_REGION_GETTER_DECLS" ] = regionGetters;
201
235
}
202
236
203
237
static void generateOpBuilderDeclarations (irdl::detail::dictionary &dict,
@@ -301,6 +335,79 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
301
335
return success ();
302
336
}
303
337
338
+ // add traits to the dictionary, return true if any were added
339
+ static std::vector<std::string> generateTraits (irdl::detail::dictionary &dict,
340
+ irdl::OperationOp op,
341
+ const OpStrings &strings) {
342
+ std::vector<std::string> cppTraitNames;
343
+ if (!strings.opRegionNames .empty ()) {
344
+ cppTraitNames.push_back (
345
+ llvm::formatv (" ::mlir::OpTrait::NRegions<{0}>::Impl" ,
346
+ strings.opRegionNames .size ())
347
+ .str ());
348
+ cppTraitNames.emplace_back (" ::mlir::OpTrait::OpInvariants" );
349
+ }
350
+ return cppTraitNames;
351
+ }
352
+
353
+ static void generateVerifiers (irdl::detail::dictionary &dict,
354
+ irdl::OperationOp op, const OpStrings &strings) {
355
+ std::vector<std::string> verifierHelpers;
356
+ std::vector<std::string> verifierCalls;
357
+ if (strings.opRegionNames .empty ()) {
358
+ // Currently IRDL regions are the only reason to generate a verifier,
359
+ // though this will likely change as
360
+ // https://github.com/llvm/llvm-project/issues/158040 is implemented
361
+ return ;
362
+ }
363
+
364
+ for (size_t i = 0 ; i < strings.opRegionNames .size (); ++i) {
365
+ std::string regionName = strings.opRegionNames [i];
366
+ std::string helperFnName =
367
+ llvm::formatv (" __mlir_irdl_local_region_constraint_{0}_{1}" ,
368
+ strings.opCppName , regionName)
369
+ .str ();
370
+ std::string condition = " true" ; // FIXME: get from irdl region op
371
+ std::string textualConditionName =
372
+ " any region" ; // FIXME: can use textual irdl for this region?
373
+
374
+ verifierHelpers.push_back (llvm::formatv (
375
+ R"( static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, unsigned regionIndex) {{
376
+ if (!{1}) {{
377
+ return op->emitOpError("region #") << regionIndex
378
+ << (regionName.empty() ? " " : " ('" + regionName + "')
379
+ << "failed to verify constraint: {2}";
380
+ }}
381
+ return ::mlir::success();
382
+ }})" ,
383
+ helperFnName, condition));
384
+
385
+ verifierCalls.push_back (llvm::formatv (R"(
386
+ if (::mlir::failed({0}(*this, (*this)->getRegion({1}), {2}, {1})))
387
+ return ::mlir::failure();)" ,
388
+ helperFnName, i, regionName)
389
+ .str ());
390
+ }
391
+
392
+ // Add an overall verifier that sequences the helper calls
393
+ std::string verifierDef =
394
+ llvm::formatv (R"(
395
+ ::llvm::LogicalResult {0}::verifyInvariantsImpl() {
396
+ {1}
397
+ return ::mlir::success();
398
+ }
399
+
400
+ ::llvm::LogicalResult {0}::verifyInvariants() {
401
+ if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify()))
402
+ return ::mlir::success();
403
+ return ::mlir::failure();
404
+ })" ,
405
+ strings.opCppName , llvm::join (verifierCalls, " \n " ));
406
+
407
+ dict[" OP_VERIFIER_HELPERS" ] = llvm::join (verifierHelpers, " \n " );
408
+ dict[" OP_VERIFIER" ] = verifierDef;
409
+ }
410
+
304
411
static std::string generateOpDefinition (irdl::detail::dictionary &dict,
305
412
irdl::OperationOp op) {
306
413
static const auto perOpDefTemplate = mlir::irdl::detail::Template{
@@ -370,6 +477,15 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {
370
477
371
478
dict[" OP_BUILD_DEFS" ] = buildDefinition;
372
479
480
+ std::vector<std::string> traitNames = generateTraits (dict, op, opStrings);
481
+ if (traitNames.empty ())
482
+ dict[" OP_TEMPLATE_ARGS" ] = opStrings.opCppName ;
483
+ else
484
+ dict[" OP_TEMPLATE_ARGS" ] = llvm::formatv (" {0}, {1}" , opStrings.opCppName ,
485
+ llvm::join (traitNames, " , " ));
486
+
487
+ generateVerifiers (dict, op, opStrings);
488
+
373
489
std::string str;
374
490
llvm::raw_string_ostream stream{str};
375
491
perOpDefTemplate.render (stream, dict);
@@ -427,7 +543,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
427
543
dict[" TYPE_PARSER" ] = llvm::formatv (
428
544
R"( static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
429
545
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
430
- {0}
546
+ {0}
431
547
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
432
548
*mnemonic = keyword;
433
549
return std::nullopt;
@@ -520,6 +636,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
520
636
" IRDL C++ translation does not yet support variadic results" );
521
637
}))
522
638
.Case <irdl::AnyOp>(([](irdl::AnyOp) { return success (); }))
639
+ .Case <irdl::RegionOp>(([](irdl::RegionOp) { return success (); }))
640
+ .Case <irdl::RegionsOp>(([](irdl::RegionsOp) { return success (); }))
523
641
.Default ([](mlir::Operation *op) -> LogicalResult {
524
642
return op->emitError (" IRDL C++ translation does not yet support "
525
643
" translation of " )
0 commit comments