-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][ODS] Add support for overloading interface methods #161828
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir Author: Mehdi Amini (joker-eph) ChangesThis allows to define multiple interface methods with the same name but different arguments. Full diff: https://github.com/llvm/llvm-project/pull/161828.diff 5 Files Affected:
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index bea043f56fe21..9076c7e54d7bf 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -245,6 +245,10 @@ void TestType::printTypeC(Location loc) const {
emitRemark(loc) << *this << " - TestC";
}
+void TestType::printTypeC(Location loc, int value) const {
+ emitRemark(loc) << *this << " - " << value << " - TestC";
+}
+
//===----------------------------------------------------------------------===//
// TestTypeWithLayout
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp
index 2dd3fe245e220..e021f78e1142d 100644
--- a/mlir/test/lib/IR/TestInterfaces.cpp
+++ b/mlir/test/lib/IR/TestInterfaces.cpp
@@ -31,6 +31,7 @@ struct TestTypeInterfaces
testInterface.printTypeA(op->getLoc());
testInterface.printTypeB(op->getLoc());
testInterface.printTypeC(op->getLoc());
+ testInterface.printTypeC(op->getLoc(), 42);
testInterface.printTypeD(op->getLoc());
// Just check that we can assign the result to a variable of interface
// type.
diff --git a/mlir/test/mlir-tblgen/interfaces.mlir b/mlir/test/mlir-tblgen/interfaces.mlir
index 5c1ec613b387a..927cfd728bcd4 100644
--- a/mlir/test/mlir-tblgen/interfaces.mlir
+++ b/mlir/test/mlir-tblgen/interfaces.mlir
@@ -3,6 +3,7 @@
// expected-remark@below {{'!test.test_type' - TestA}}
// expected-remark@below {{'!test.test_type' - TestB}}
// expected-remark@below {{'!test.test_type' - TestC}}
+// expected-remark@below {{'!test.test_type' - 42 - TestC}}
// expected-remark@below {{'!test.test_type' - TestD}}
// expected-remark@below {{'!test.test_type' - TestRet}}
// expected-remark@below {{'!test.test_type' - TestE}}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 7e8e559baf878..4c6519cd2f7bf 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -789,6 +789,14 @@ class OpEmitter {
Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
bool declaration = true);
+ // Generate a `using` declaration for the op interface method to include
+ // the default implementation from the interface trait.
+ // This is needed when the interface defines multiple methods with the same
+ // name, but some have a default implementation and some don't.
+ UsingDeclaration *
+ genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
+ const tblgen::InterfaceMethod &method);
+
// Generate the side effect interface methods.
void genSideEffectInterfaceMethods();
@@ -815,6 +823,10 @@ class OpEmitter {
// Helper for emitting op code.
OpOrAdaptorHelper emitHelper;
+
+ // Keep track of the interface using declarations that have been generated to
+ // avoid duplicates.
+ llvm::StringSet<> interfaceUsingNames;
};
} // namespace
@@ -3672,8 +3684,10 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
// Don't declare if the method has a default implementation and the op
// didn't request that it always be declared.
if (method.getDefaultImplementation() &&
- !alwaysDeclaredMethods.count(method.getName()))
+ !alwaysDeclaredMethods.count(method.getName())) {
+ genOpInterfaceMethodUsingDecl(opTrait, method);
continue;
+ }
// Interface methods are allowed to overlap with existing methods, so don't
// check if pruned.
(void)genOpInterfaceMethod(method);
@@ -3692,6 +3706,17 @@ Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
std::move(paramList));
}
+UsingDeclaration *
+OpEmitter::genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
+ const InterfaceMethod &method) {
+ std::string name = (llvm::Twine(opTrait->getFullyQualifiedTraitName()) + "<" +
+ op.getQualCppClassName() + ">::" + method.getName())
+ .str();
+ if (interfaceUsingNames.insert(name).second)
+ return opClass.declare<UsingDeclaration>(std::move(name));
+ return nullptr;
+}
+
void OpEmitter::genOpInterfaceMethods() {
for (const auto &trait : op.getTraits()) {
if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 3cc1636ac3317..9dedd55005f87 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -42,10 +42,10 @@ static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
/// Emit the method name and argument list for the given method. If 'addThisArg'
/// is true, then an argument is added to the beginning of the argument list for
/// the concrete value.
-static void emitMethodNameAndArgs(const InterfaceMethod &method,
+static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
raw_ostream &os, StringRef valueType,
bool addThisArg, bool addConst) {
- os << method.getName() << '(';
+ os << name << '(';
if (addThisArg) {
if (addConst)
os << "const ";
@@ -183,11 +183,13 @@ static void emitInterfaceDefMethods(StringRef interfaceQualName,
emitInterfaceMethodDoc(method, os);
emitCPPType(method.getReturnType(), os);
os << interfaceQualName << "::";
- emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+ emitMethodNameAndArgs(method, method.getName(), os, valueType,
+ /*addThisArg=*/false,
/*addConst=*/!isOpInterface);
// Forward to the method on the concrete operation type.
- os << " {\n return " << implValue << "->" << method.getName() << '(';
+ os << " {\n return " << implValue << "->" << method.getDedupName()
+ << '(';
if (!method.isStatic()) {
os << implValue << ", ";
os << (isOpInterface ? "getOperation()" : "*this");
@@ -239,7 +241,7 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
for (auto &method : interface.getMethods()) {
os << " ";
emitCPPType(method.getReturnType(), os);
- os << "(*" << method.getName() << ")(";
+ os << "(*" << method.getDedupName() << ")(";
if (!method.isStatic()) {
os << "const Concept *impl, ";
emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
@@ -289,13 +291,13 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
os << " " << modelClass << "() : Concept{";
llvm::interleaveComma(
interface.getMethods(), os,
- [&](const InterfaceMethod &method) { os << method.getName(); });
+ [&](const InterfaceMethod &method) { os << method.getDedupName(); });
os << "} {}\n\n";
// Insert each of the virtual method overrides.
for (auto &method : interface.getMethods()) {
emitCPPType(method.getReturnType(), os << " static inline ");
- emitMethodNameAndArgs(method, os, valueType,
+ emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
/*addThisArg=*/!method.isStatic(),
/*addConst=*/false);
os << ";\n";
@@ -319,7 +321,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
if (method.isStatic())
os << "static ";
emitCPPType(method.getReturnType(), os);
- os << method.getName() << "(";
+ os << method.getDedupName() << "(";
if (!method.isStatic()) {
emitCPPType(valueType, os);
os << "tablegen_opaque_val";
@@ -350,7 +352,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
emitCPPType(method.getReturnType(), os);
os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
<< valueTemplate << ">::";
- emitMethodNameAndArgs(method, os, valueType,
+ emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
/*addThisArg=*/!method.isStatic(),
/*addConst=*/false);
os << " {\n ";
@@ -384,7 +386,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
emitCPPType(method.getReturnType(), os);
os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
<< valueTemplate << ">::";
- emitMethodNameAndArgs(method, os, valueType,
+ emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
/*addThisArg=*/!method.isStatic(),
/*addConst=*/false);
os << " {\n ";
@@ -396,7 +398,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
os << "return static_cast<const " << valueTemplate << " *>(impl)->";
// Add the arguments to the call.
- os << method.getName() << '(';
+ os << method.getDedupName() << '(';
if (!method.isStatic())
os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
llvm::interleaveComma(
@@ -416,7 +418,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
<< "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
<< ">::";
- os << method.getName() << "(";
+ os << method.getDedupName() << "(";
if (!method.isStatic()) {
emitCPPType(valueType, os);
os << "tablegen_opaque_val";
@@ -477,7 +479,8 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
emitInterfaceMethodDoc(method, os, " ");
os << " " << (method.isStatic() ? "static " : "");
emitCPPType(method.getReturnType(), os);
- emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+ emitMethodNameAndArgs(method, method.getName(), os, valueType,
+ /*addThisArg=*/false,
/*addConst=*/!isOpInterface && !method.isStatic());
os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
<< "\n }\n";
@@ -514,7 +517,8 @@ static void emitInterfaceDeclMethods(const Interface &interface,
for (auto &method : interface.getMethods()) {
emitInterfaceMethodDoc(method, os, " ");
emitCPPType(method.getReturnType(), os << " ");
- emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+ emitMethodNameAndArgs(method, method.getName(), os, valueType,
+ /*addThisArg=*/false,
/*addConst=*/!isOpInterface);
os << ";\n";
}
|
9ddecbc
to
d603c93
Compare
mlir/lib/TableGen/Interfaces.cpp
Outdated
std::string name = | ||
cast<DefInit>(init)->getDef()->getValueAsString("name").str(); | ||
while (!dedupNames.insert(name).second) { | ||
name = name + "_" + std::to_string(dedupNames.size()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How come a single counter is sufficient here? I expected name mangling scheme that makes the argument types part of the function name. When the interface method is called, how do you know which dedup method to call?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When the interface method is called, how do you know which dedup method to call?
The actual deduce method name could be a randomly generated name, as long as the mapping is consistent.
So we can initialize InterfaceMethod
with any dedup name as long as it is uniqued.
This is all just internal to the "vtables" we generate. The public interface/trait will expose the overloaded methods with their public name, and try to call the method on the Op with the public name as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added some comments about this, PTAL.
This allows to define multiple interface methods with the same name but different arguments.
d603c93
to
c5f6b0f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like a useful feature. I had to work around this limitation in the past: there is BufferizableOpInterface::bufferizesToMemoryWrite
and BufferizableOpInterface::resultBufferizesToMemoryWrite
.
This allows to define multiple interface methods with the same name but different arguments.