Skip to content

Commit 8c3273c

Browse files
committed
Add support for regions in irdl-to-cpp
1 parent ad9d551 commit 8c3273c

File tree

3 files changed

+157
-28
lines changed

3 files changed

+157
-28
lines changed

mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ struct OpStrings {
5454
std::string opCppName;
5555
SmallVector<std::string> opResultNames;
5656
SmallVector<std::string> opOperandNames;
57+
SmallVector<std::string> opRegionNames;
5758
};
5859

5960
static std::string joinNameList(llvm::ArrayRef<std::string> names) {
@@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
8788
/// Generates OpStrings from an OperatioOp
8889
static OpStrings getStrings(irdl::OperationOp op) {
8990
auto operandOp = op.getOp<irdl::OperandsOp>();
90-
9191
auto resultOp = op.getOp<irdl::ResultsOp>();
92+
auto regionsOp = op.getOp<irdl::RegionsOp>();
9293

9394
OpStrings strings;
9495
strings.opName = op.getSymName();
@@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
108109
}));
109110
}
110111

112+
if (regionsOp) {
113+
strings.opRegionNames = SmallVector<std::string>(
114+
llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
115+
return llvm::formatv("{0}", cast<StringAttr>(attr));
116+
}));
117+
}
118+
111119
return strings;
112120
}
113121

@@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
122130
static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
123131
const auto operandCount = strings.opOperandNames.size();
124132
const auto resultCount = strings.opResultNames.size();
133+
const auto regionCount = strings.opRegionNames.size();
125134

126135
dict["OP_NAME"] = strings.opName;
127136
dict["OP_CPP_NAME"] = strings.opCppName;
@@ -131,6 +140,14 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
131140
operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
132141
dict["OP_RESULT_INITIALIZER_LIST"] =
133142
resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
143+
dict["OP_REGION_COUNT"] = std::to_string(regionCount);
144+
dict["OP_ADD_REGIONS"] =
145+
regionCount ? std::string(llvm::formatv(
146+
R"(for (unsigned i = 0; i != __OP_REGION_COUNT__; ++i)
147+
(void)odsState.addRegion();
148+
)",
149+
regionCount))
150+
: "";
134151
}
135152

136153
/// Fills a dictionary with values from DialectStrings
@@ -179,6 +196,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
179196
const OpStrings &opStrings) {
180197
auto opGetters = std::string{};
181198
auto resGetters = std::string{};
199+
auto regionGetters = std::string{};
200+
auto regionAdaptorGetters = std::string{};
182201

183202
for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
184203
const auto op =
@@ -196,8 +215,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
196215
op, i);
197216
}
198217

218+
for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
219+
const auto op =
220+
llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
221+
regionAdaptorGetters += llvm::formatv(
222+
R"(::mlir::Region &get{0}() { return getRegions()[{1}]; }
223+
)",
224+
op, i);
225+
regionGetters += llvm::formatv(
226+
R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
227+
)",
228+
op, i);
229+
}
230+
199231
dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
200232
dict["OP_RESULT_GETTER_DECLS"] = resGetters;
233+
dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
234+
dict["OP_REGION_GETTER_DECLS"] = regionGetters;
201235
}
202236

203237
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
@@ -301,6 +335,79 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
301335
return success();
302336
}
303337

338+
// add traits to the dictionary, return true if any were added
339+
static std::vector<std::string> generateTraits(irdl::detail::dictionary &dict,
340+
irdl::OperationOp op,
341+
const OpStrings &strings) {
342+
std::vector<std::string> cppTraitNames;
343+
if (!strings.opRegionNames.empty()) {
344+
cppTraitNames.push_back(
345+
llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
346+
strings.opRegionNames.size())
347+
.str());
348+
cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
349+
}
350+
return cppTraitNames;
351+
}
352+
353+
static void generateVerifiers(irdl::detail::dictionary &dict,
354+
irdl::OperationOp op, const OpStrings &strings) {
355+
std::vector<std::string> verifierHelpers;
356+
std::vector<std::string> verifierCalls;
357+
if (strings.opRegionNames.empty()) {
358+
// Currently IRDL regions are the only reason to generate a verifier,
359+
// though this will likely change as
360+
// https://github.com/llvm/llvm-project/issues/158040 is implemented
361+
return;
362+
}
363+
364+
for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
365+
std::string regionName = strings.opRegionNames[i];
366+
std::string helperFnName =
367+
llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
368+
strings.opCppName, regionName)
369+
.str();
370+
std::string condition = "true"; // FIXME: get from irdl region op
371+
std::string textualConditionName =
372+
"any region"; // FIXME: can use textual irdl for this region?
373+
374+
verifierHelpers.push_back(llvm::formatv(
375+
R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
376+
if (!{1}) {{
377+
return op->emitOpError("region #") << regionIndex
378+
<< (regionName.empty() ? " " : " ('" + regionName + "')
379+
<< "failed to verify constraint: {2}";
380+
}}
381+
return ::mlir::success();
382+
}})",
383+
helperFnName, condition));
384+
385+
verifierCalls.push_back(llvm::formatv(R"(
386+
if (::mlir::failed({0}(*this, (*this)->getRegion({1}), {2}, {1})))
387+
return ::mlir::failure();)",
388+
helperFnName, i, regionName)
389+
.str());
390+
}
391+
392+
// Add an overall verifier that sequences the helper calls
393+
std::string verifierDef =
394+
llvm::formatv(R"(
395+
::llvm::LogicalResult {0}::verifyInvariantsImpl() {
396+
{1}
397+
return ::mlir::success();
398+
}
399+
400+
::llvm::LogicalResult {0}::verifyInvariants() {
401+
if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify()))
402+
return ::mlir::success();
403+
return ::mlir::failure();
404+
})",
405+
strings.opCppName, llvm::join(verifierCalls, "\n"));
406+
407+
dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
408+
dict["OP_VERIFIER"] = verifierDef;
409+
}
410+
304411
static std::string generateOpDefinition(irdl::detail::dictionary &dict,
305412
irdl::OperationOp op) {
306413
static const auto perOpDefTemplate = mlir::irdl::detail::Template{
@@ -370,6 +477,15 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {
370477

371478
dict["OP_BUILD_DEFS"] = buildDefinition;
372479

480+
std::vector<std::string> traitNames = generateTraits(dict, op, opStrings);
481+
if (traitNames.empty())
482+
dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
483+
else
484+
dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
485+
llvm::join(traitNames, ", "));
486+
487+
generateVerifiers(dict, op, opStrings);
488+
373489
std::string str;
374490
llvm::raw_string_ostream stream{str};
375491
perOpDefTemplate.render(stream, dict);
@@ -427,7 +543,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
427543
dict["TYPE_PARSER"] = llvm::formatv(
428544
R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
429545
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
430-
{0}
546+
{0}
431547
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
432548
*mnemonic = keyword;
433549
return std::nullopt;
@@ -520,6 +636,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
520636
"IRDL C++ translation does not yet support variadic results");
521637
}))
522638
.Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
639+
.Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
640+
.Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
523641
.Default([](mlir::Operation *op) -> LogicalResult {
524642
return op->emitError("IRDL C++ translation does not yet support "
525643
"translation of ")

mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@ public:
1212
struct Properties {
1313
};
1414
public:
15-
__OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
16-
: odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
17-
odsRegions(op->getRegions())
15+
__OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
16+
: odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
17+
odsRegions(op->getRegions())
1818
{}
1919

2020
/// Return the unstructured operand index of a structured operand along with
2121
// the amount of unstructured operands it contains.
2222
std::pair<unsigned, unsigned>
23-
getStructuredOperandIndexAndLength (unsigned index,
23+
getStructuredOperandIndexAndLength (unsigned index,
2424
unsigned odsOperandsSize) {
2525
return {index, 1};
2626
}
@@ -32,6 +32,12 @@ public:
3232
::mlir::DictionaryAttr getAttributes() {
3333
return odsAttrs;
3434
}
35+
36+
__OP_REGION_ADAPTER_GETTER_DECLS__
37+
38+
::mlir::RegionRange getRegions() {
39+
return odsRegions;
40+
}
3541
protected:
3642
::mlir::DictionaryAttr odsAttrs;
3743
::std::optional<::mlir::OperationName> odsOpName;
@@ -42,28 +48,28 @@ protected:
4248
} // namespace detail
4349

4450
template <typename RangeT>
45-
class __OP_CPP_NAME__GenericAdaptor
51+
class __OP_CPP_NAME__GenericAdaptor
4652
: public detail::__OP_CPP_NAME__GenericAdaptorBase {
4753
using ValueT = ::llvm::detail::ValueOfRange<RangeT>;
4854
using Base = detail::__OP_CPP_NAME__GenericAdaptorBase;
4955
public:
5056
__OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs,
51-
::mlir::OpaqueProperties properties,
52-
::mlir::RegionRange regions = {})
53-
: __OP_CPP_NAME__GenericAdaptor(values, attrs,
54-
(properties ? *properties.as<::mlir::EmptyProperties *>()
57+
::mlir::OpaqueProperties properties,
58+
::mlir::RegionRange regions = {})
59+
: __OP_CPP_NAME__GenericAdaptor(values, attrs,
60+
(properties ? *properties.as<::mlir::EmptyProperties *>()
5561
: ::mlir::EmptyProperties{}), regions) {}
5662

57-
__OP_CPP_NAME__GenericAdaptor(RangeT values,
63+
__OP_CPP_NAME__GenericAdaptor(RangeT values,
5864
const __OP_CPP_NAME__GenericAdaptorBase &base)
5965
: Base(base), odsOperands(values) {}
6066

61-
// This template parameter allows using __OP_CPP_NAME__ which is declared
67+
// This template parameter allows using __OP_CPP_NAME__ which is declared
6268
// later.
6369
template <typename LateInst = __OP_CPP_NAME__,
6470
typename = std::enable_if_t<
6571
std::is_same_v<LateInst, __OP_CPP_NAME__>>>
66-
__OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
72+
__OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
6773
: Base(op), odsOperands(values) {}
6874

6975
/// Return the unstructured operand index of a structured operand along with
@@ -77,7 +83,7 @@ public:
7783
RangeT getStructuredOperands(unsigned index) {
7884
auto valueRange = getStructuredOperandIndexAndLength(index);
7985
return {std::next(odsOperands.begin(), valueRange.first),
80-
std::next(odsOperands.begin(),
86+
std::next(odsOperands.begin(),
8187
valueRange.first + valueRange.second)};
8288
}
8389

@@ -91,7 +97,7 @@ private:
9197
RangeT odsOperands;
9298
};
9399

94-
class __OP_CPP_NAME__Adaptor
100+
class __OP_CPP_NAME__Adaptor
95101
: public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> {
96102
public:
97103
using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor;
@@ -100,7 +106,7 @@ public:
100106
::llvm::LogicalResult verify(::mlir::Location loc);
101107
};
102108

103-
class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> {
109+
class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> {
104110
public:
105111
using Op::Op;
106112
using Op::print;
@@ -147,7 +153,7 @@ public:
147153
::mlir::Operation::operand_range getStructuredOperands(unsigned index) {
148154
auto valueRange = getStructuredOperandIndexAndLength(index);
149155
return {std::next(getOperation()->operand_begin(), valueRange.first),
150-
std::next(getOperation()->operand_begin(),
156+
std::next(getOperation()->operand_begin(),
151157
valueRange.first + valueRange.second)};
152158
}
153159

@@ -162,18 +168,19 @@ public:
162168
::mlir::Operation::result_range getStructuredResults(unsigned index) {
163169
auto valueRange = getStructuredResultIndexAndLength(index);
164170
return {std::next(getOperation()->result_begin(), valueRange.first),
165-
std::next(getOperation()->result_begin(),
171+
std::next(getOperation()->result_begin(),
166172
valueRange.first + valueRange.second)};
167173
}
168174

169175
__OP_OPERAND_GETTER_DECLS__
170176
__OP_RESULT_GETTER_DECLS__
171-
177+
__OP_REGION_GETTER_DECLS__
178+
172179
__OP_BUILD_DECLS__
173-
static void build(::mlir::OpBuilder &odsBuilder,
174-
::mlir::OperationState &odsState,
175-
::mlir::TypeRange resultTypes,
176-
::mlir::ValueRange operands,
180+
static void build(::mlir::OpBuilder &odsBuilder,
181+
::mlir::OperationState &odsState,
182+
::mlir::TypeRange resultTypes,
183+
::mlir::ValueRange operands,
177184
::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
178185

179186
static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder,

mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,22 @@ R"(
66

77
__NAMESPACE_OPEN__
88

9+
__OP_VERIFIER_HELPERS__
10+
911
__OP_BUILD_DEFS__
1012

11-
void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
12-
::mlir::OperationState &odsState,
13-
::mlir::TypeRange resultTypes,
14-
::mlir::ValueRange operands,
13+
void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
14+
::mlir::OperationState &odsState,
15+
::mlir::TypeRange resultTypes,
16+
::mlir::ValueRange operands,
1517
::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
1618
{
1719
assert(operands.size() == __OP_OPERAND_COUNT__);
1820
assert(resultTypes.size() == __OP_RESULT_COUNT__);
1921
odsState.addOperands(operands);
2022
odsState.addAttributes(attributes);
2123
odsState.addTypes(resultTypes);
24+
__OP_ADD_REGIONS__
2225
}
2326

2427
__OP_CPP_NAME__
@@ -44,6 +47,7 @@ __OP_CPP_NAME__::create(::mlir::ImplicitLocOpBuilder &odsBuilder,
4447
return create(odsBuilder, odsBuilder.getLoc(), resultTypes, operands, attributes);
4548
}
4649

50+
__OP_VERIFIER__
4751

4852
__NAMESPACE_CLOSE__
4953

0 commit comments

Comments
 (0)