-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir] [irdl] Add support for regions in irdl-to-cpp #158540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-irdl @llvm/pr-subscribers-mlir Author: Jeremy Kun (j2kun) ChangesFixes #158034 Very early draft, still working on it Full diff: https://github.com/llvm/llvm-project/pull/158540.diff 3 Files Affected:
diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
index d6b8a8a1df426..7d4a8d861a410 100644
--- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
+++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
@@ -54,6 +54,7 @@ struct OpStrings {
std::string opCppName;
SmallVector<std::string> opResultNames;
SmallVector<std::string> opOperandNames;
+ SmallVector<std::string> opRegionNames;
};
static std::string joinNameList(llvm::ArrayRef<std::string> names) {
@@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
/// Generates OpStrings from an OperatioOp
static OpStrings getStrings(irdl::OperationOp op) {
auto operandOp = op.getOp<irdl::OperandsOp>();
-
auto resultOp = op.getOp<irdl::ResultsOp>();
+ auto regionsOp = op.getOp<irdl::RegionsOp>();
OpStrings strings;
strings.opName = op.getSymName();
@@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
}));
}
+ if (regionsOp) {
+ strings.opRegionNames = SmallVector<std::string>(
+ llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
+ return llvm::formatv("{0}", cast<StringAttr>(attr));
+ }));
+ }
+
return strings;
}
@@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
const auto operandCount = strings.opOperandNames.size();
const auto resultCount = strings.opResultNames.size();
+ const auto regionCount = strings.opRegionNames.size();
dict["OP_NAME"] = strings.opName;
dict["OP_CPP_NAME"] = strings.opCppName;
@@ -131,6 +140,14 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
dict["OP_RESULT_INITIALIZER_LIST"] =
resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
+ dict["OP_REGION_COUNT"] = std::to_string(regionCount);
+ dict["OP_ADD_REGIONS"] =
+ regionCount ? std::string(llvm::formatv(
+ R"(for (unsigned i = 0; i != __OP_REGION_COUNT__; ++i)
+ (void)odsState.addRegion();
+)",
+ regionCount))
+ : "";
}
/// Fills a dictionary with values from DialectStrings
@@ -179,6 +196,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
const OpStrings &opStrings) {
auto opGetters = std::string{};
auto resGetters = std::string{};
+ auto regionGetters = std::string{};
+ auto regionAdaptorGetters = std::string{};
for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
const auto op =
@@ -196,8 +215,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
op, i);
}
+ for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
+ const auto op =
+ llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
+ regionAdaptorGetters += llvm::formatv(
+ R"(::mlir::Region get{0}() { return getRegions()[{1}]; }
+ )",
+ op, i);
+ regionGetters += llvm::formatv(
+ R"(::mlir::Region get{0}() { return (*this)->getRegion({1}); }
+ )",
+ op, i);
+ }
+
dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
dict["OP_RESULT_GETTER_DECLS"] = resGetters;
+ dict["OP_REGION_ADAPTORBASE_GETTER_DECLS"] = regionAdaptorGetters;
+ dict["OP_REGION_GETTER_DECLS"] = regionGetters;
}
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
@@ -301,6 +335,79 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
return success();
}
+// add traits to the dictionary, return true if any were added
+static std::vector<std::string> generateTraits(irdl::detail::dictionary &dict,
+ irdl::OperationOp op,
+ const OpStrings &strings) {
+ std::vector<std::string> cppTraitNames;
+ if (!strings.opRegionNames.empty()) {
+ cppTraitNames.push_back(
+ llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
+ strings.opRegionNames.size())
+ .str());
+ cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
+ }
+ return cppTraitNames;
+}
+
+static void generateVerifiers(irdl::detail::dictionary &dict,
+ irdl::OperationOp op, const OpStrings &strings) {
+ std::vector<std::string> verifierHelpers;
+ std::vector<std::string> verifierCalls;
+ if (strings.opRegionNames.empty()) {
+ // Currently IRDL regions are the only reason to generate a verifier,
+ // though this will likely change as
+ // https://github.com/llvm/llvm-project/issues/158040 is implemented
+ return;
+ }
+
+ for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
+ std::string regionName = strings.opRegionNames[i];
+ std::string helperFnName =
+ llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
+ strings.opCppName, regionName)
+ .str();
+ std::string condition = "true"; // FIXME: get from irdl region op
+ std::string textualConditionName =
+ "any region"; // FIXME: can use textual irdl for this region?
+
+ verifierHelpers.push_back(llvm::formatv(
+ R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, unsigned regionIndex) {{
+ if (!{1}) {{
+ return op->emitOpError("region #") << regionIndex
+ << (regionName.empty() ? " " : " ('" + regionName + "')
+ << "failed to verify constraint: {2}";
+ }}
+ return ::mlir::success();
+ }})",
+ helperFnName, condition));
+
+ verifierCalls.push_back(llvm::formatv(R"(
+ if (::mlir::failed({0}(*this, (*this)->getRegion({1}), {2}, {1})))
+ return ::mlir::failure();)",
+ helperFnName, i, regionName)
+ .str());
+ }
+
+ // Add an overall verifier that sequences the helper calls
+ std::string verifierDef =
+ llvm::formatv(R"(
+::llvm::LogicalResult {0}::verifyInvariantsImpl() {
+ {1}
+ return ::mlir::success();
+ }
+
+::llvm::LogicalResult {0}::verifyInvariants() {
+ if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify()))
+ return ::mlir::success();
+ return ::mlir::failure();
+})",
+ strings.opCppName, llvm::join(verifierCalls, "\n"));
+
+ dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
+ dict["OP_VERIFIER"] = verifierDef;
+}
+
static std::string generateOpDefinition(irdl::detail::dictionary &dict,
irdl::OperationOp op) {
static const auto perOpDefTemplate = mlir::irdl::detail::Template{
@@ -370,6 +477,15 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {
dict["OP_BUILD_DEFS"] = buildDefinition;
+ std::vector<std::string> traitNames = generateTraits(dict, op, opStrings);
+ if (traitNames.empty())
+ dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
+ else
+ dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
+ llvm::join(traitNames, ", "));
+
+ generateVerifiers(dict, op, opStrings);
+
std::string str;
llvm::raw_string_ostream stream{str};
perOpDefTemplate.render(stream, dict);
@@ -427,7 +543,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
dict["TYPE_PARSER"] = llvm::formatv(
R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
- {0}
+ {0}
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
*mnemonic = keyword;
return std::nullopt;
@@ -520,6 +636,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
"IRDL C++ translation does not yet support variadic results");
}))
.Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
+ .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
+ .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
.Default([](mlir::Operation *op) -> LogicalResult {
return op->emitError("IRDL C++ translation does not yet support "
"translation of ")
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
index e9068e9488f99..ae900bb818dd3 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
@@ -12,15 +12,15 @@ public:
struct Properties {
};
public:
- __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
- : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
- odsRegions(op->getRegions())
+ __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
+ : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
+ odsRegions(op->getRegions())
{}
/// Return the unstructured operand index of a structured operand along with
// the amount of unstructured operands it contains.
std::pair<unsigned, unsigned>
- getStructuredOperandIndexAndLength (unsigned index,
+ getStructuredOperandIndexAndLength (unsigned index,
unsigned odsOperandsSize) {
return {index, 1};
}
@@ -32,6 +32,12 @@ public:
::mlir::DictionaryAttr getAttributes() {
return odsAttrs;
}
+
+ __OP_REGION_ADAPTER_GETTER_DECLS__
+
+ ::mlir::RegionRange getRegions() {
+ return odsRegions;
+ }
protected:
::mlir::DictionaryAttr odsAttrs;
::std::optional<::mlir::OperationName> odsOpName;
@@ -42,28 +48,28 @@ protected:
} // namespace detail
template <typename RangeT>
-class __OP_CPP_NAME__GenericAdaptor
+class __OP_CPP_NAME__GenericAdaptor
: public detail::__OP_CPP_NAME__GenericAdaptorBase {
using ValueT = ::llvm::detail::ValueOfRange<RangeT>;
using Base = detail::__OP_CPP_NAME__GenericAdaptorBase;
public:
__OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs,
- ::mlir::OpaqueProperties properties,
- ::mlir::RegionRange regions = {})
- : __OP_CPP_NAME__GenericAdaptor(values, attrs,
- (properties ? *properties.as<::mlir::EmptyProperties *>()
+ ::mlir::OpaqueProperties properties,
+ ::mlir::RegionRange regions = {})
+ : __OP_CPP_NAME__GenericAdaptor(values, attrs,
+ (properties ? *properties.as<::mlir::EmptyProperties *>()
: ::mlir::EmptyProperties{}), regions) {}
- __OP_CPP_NAME__GenericAdaptor(RangeT values,
+ __OP_CPP_NAME__GenericAdaptor(RangeT values,
const __OP_CPP_NAME__GenericAdaptorBase &base)
: Base(base), odsOperands(values) {}
- // This template parameter allows using __OP_CPP_NAME__ which is declared
+ // This template parameter allows using __OP_CPP_NAME__ which is declared
// later.
template <typename LateInst = __OP_CPP_NAME__,
typename = std::enable_if_t<
std::is_same_v<LateInst, __OP_CPP_NAME__>>>
- __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
+ __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
: Base(op), odsOperands(values) {}
/// Return the unstructured operand index of a structured operand along with
@@ -77,7 +83,7 @@ public:
RangeT getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(odsOperands.begin(), valueRange.first),
- std::next(odsOperands.begin(),
+ std::next(odsOperands.begin(),
valueRange.first + valueRange.second)};
}
@@ -91,7 +97,7 @@ private:
RangeT odsOperands;
};
-class __OP_CPP_NAME__Adaptor
+class __OP_CPP_NAME__Adaptor
: public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> {
public:
using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor;
@@ -100,7 +106,7 @@ public:
::llvm::LogicalResult verify(::mlir::Location loc);
};
-class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> {
+class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> {
public:
using Op::Op;
using Op::print;
@@ -147,7 +153,7 @@ public:
::mlir::Operation::operand_range getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(getOperation()->operand_begin(), valueRange.first),
- std::next(getOperation()->operand_begin(),
+ std::next(getOperation()->operand_begin(),
valueRange.first + valueRange.second)};
}
@@ -162,18 +168,19 @@ public:
::mlir::Operation::result_range getStructuredResults(unsigned index) {
auto valueRange = getStructuredResultIndexAndLength(index);
return {std::next(getOperation()->result_begin(), valueRange.first),
- std::next(getOperation()->result_begin(),
+ std::next(getOperation()->result_begin(),
valueRange.first + valueRange.second)};
}
__OP_OPERAND_GETTER_DECLS__
__OP_RESULT_GETTER_DECLS__
-
+ __OP_REGION_GETTER_DECLS__
+
__OP_BUILD_DECLS__
- static void build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState,
- ::mlir::TypeRange resultTypes,
- ::mlir::ValueRange operands,
+ static void build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::mlir::TypeRange resultTypes,
+ ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder,
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
index 30ca420d77448..840e497560e6b 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
@@ -6,12 +6,14 @@ R"(
__NAMESPACE_OPEN__
+__OP_VERIFIER_HELPERS__
+
__OP_BUILD_DEFS__
-void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState,
- ::mlir::TypeRange resultTypes,
- ::mlir::ValueRange operands,
+void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::mlir::TypeRange resultTypes,
+ ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
{
assert(operands.size() == __OP_OPERAND_COUNT__);
@@ -19,6 +21,7 @@ void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
odsState.addOperands(operands);
odsState.addAttributes(attributes);
odsState.addTypes(resultTypes);
+ __OP_ADD_REGIONS__
}
__OP_CPP_NAME__
@@ -44,6 +47,7 @@ __OP_CPP_NAME__::create(::mlir::ImplicitLocOpBuilder &odsBuilder,
return create(odsBuilder, odsBuilder.getLoc(), resultTypes, operands, attributes);
}
+__OP_VERIFIER__
__NAMESPACE_CLOSE__
|
|
4984b28
to
8c3273c
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
10d6418
to
3dcceb4
Compare
Thank you very much for the PR. Will review as soon as possible! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you very much for this!
For the interface for verification: my intuition is that the non-impl version could be a public interface? But to be quite frank I am not familiar enough with ODS internals to know for sure. This could be a good question to ask on MLIR Discord!
For testing: could you also add something that uses the conditional op and its regions in TestIRDLToCppDialect.cpp
and test that separately in test_conversion.testd.mlir
? Maybe a lowering to scf.if next to the other conversion rewrites, why not?
I added a proper C++ test in fd49ca73f840, though to make this work I needed a terminator op; as far as I know irdl can't do this yet? So I had to allow unregistered dialects and used an unknown op, which MLIR treats as a terminator by default. |
fd49ca7
to
55693e8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For me this looks good! I still don't know what the "right" way to register verifiers is, but this appears to work as expected and makes sense to me, so I'm happy to approve it. Feel free to merge it now or later if you want to get some time to hear back about the verifier registration. Thanks again!
Can always follow up, esp since there will be more verifier work to do. Thanks! |
Fixes llvm#158034 For the input ```mlir irdl.dialect @conditional_dialect { // A conditional operation with regions irdl.operation @conditional { // Create region constraints %r0 = irdl.region // Unconstrained region %r1 = irdl.region() // Region with no entry block arguments %v0 = irdl.any %r2 = irdl.region(%v0) // Region with one i1 entry block argument irdl.regions(cond: %r2, then: %r0, else: %r1) } } ``` This produces the following cpp: https://gist.github.com/j2kun/d2095f108efbd8d403576d5c460e0c00 Summary of changes: - The op class and adaptor get named accessors to the regions `Region &get<RegionName>()` and `getRegions()` - The op now gets `OpTrait::NRegions<3>` and `OpInvariants` to trigger the region verification - Support for region block argument constraints is added, but not working for all constraints until codegen for `irdl.is` is added (filed llvm#161018 and left a TODO). - Helper functions for the individual verification steps are added, following mlir-tblgen's format (in the above gist, `__mlir_irdl_local_region_constraint_ConditionalOp_cond` and similar), and `verifyInvariantsImpl` that calls them. - Regions are added in the builder ## Questions for the reviewer ### What is the "correct" interface for verification? I used `mlir-tblgen` on an analogous version of the example `ConditionalOp` in this PR, and I see an `::mlir::OpTrait::OpInvariants` trait as well as ```cpp ::llvm::LogicalResult ConditionalOp::verifyInvariantsImpl() { { unsigned index = 0; (void)index; for (auto ®ion : ::llvm::MutableArrayRef((*this)->getRegion(0))) if (::mlir::failed(__mlir_ods_local_region_constraint_test1(*this, region, "cond", index++))) return ::mlir::failure(); for (auto ®ion : ::llvm::MutableArrayRef((*this)->getRegion(1))) if (::mlir::failed(__mlir_ods_local_region_constraint_test1(*this, region, "then", index++))) return ::mlir::failure(); for (auto ®ion : ::llvm::MutableArrayRef((*this)->getRegion(2))) if (::mlir::failed(__mlir_ods_local_region_constraint_test1(*this, region, "else", index++))) return ::mlir::failure(); } return ::mlir::success(); } ::llvm::LogicalResult ConditionalOp::verifyInvariants() { if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify())) return ::mlir::success(); return ::mlir::failure(); } ``` However, `OpInvariants` only seems to need `verifyInvariantsImpl`, so it's not clear to me what is the purpose of the `verifyInvariants` function, or, if I leave out `verifyInvariants`, whether I need to call `verify()` in my implementation of `verifyInvariantsImpl`. In this PR, I omitted `verifyInvariants` and generated `verifyInvariantsImpl`. ### Is testing sufficient? I am not certain I implemented the builders properly, and it's unclear to me to what extent the existing tests check this (which look like they compile the generated cpp, but don't actually use it). Did I omit some standard function or overload? --------- Co-authored-by: Jeremy Kun <j2kun@users.noreply.github.com>
Fixes #158034
For the input
This produces the following cpp: https://gist.github.com/j2kun/d2095f108efbd8d403576d5c460e0c00
Summary of changes:
Region &get<RegionName>()
andgetRegions()
OpTrait::NRegions<3>
andOpInvariants
to trigger the region verificationirdl.is
is added (filed [MLIR][IRDL][irdl-to-cpp] add support forirdl.is
#161018 and left a TODO).__mlir_irdl_local_region_constraint_ConditionalOp_cond
and similar), andverifyInvariantsImpl
that calls them.Questions for the reviewer
What is the "correct" interface for verification?
I used
mlir-tblgen
on an analogous version of the exampleConditionalOp
in this PR, and I see an::mlir::OpTrait::OpInvariants
trait as well asHowever,
OpInvariants
only seems to needverifyInvariantsImpl
, so it's not clear to me what is the purpose of theverifyInvariants
function, or, if I leave outverifyInvariants
, whether I need to callverify()
in my implementation ofverifyInvariantsImpl
. In this PR, I omittedverifyInvariants
and generatedverifyInvariantsImpl
.Is testing sufficient?
I am not certain I implemented the builders properly, and it's unclear to me to what extent the existing tests check this (which look like they compile the generated cpp, but don't actually use it). Did I omit some standard function or overload?