Skip to content

Conversation

joker-eph
Copy link
Collaborator

This allows to define multiple interface methods with the same name but different arguments.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Oct 3, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2025

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

Changes

This allows to define multiple interface methods with the same name but different arguments.


Full diff: https://github.com/llvm/llvm-project/pull/161828.diff

5 Files Affected:

  • (modified) mlir/test/lib/Dialect/Test/TestTypes.cpp (+4)
  • (modified) mlir/test/lib/IR/TestInterfaces.cpp (+1)
  • (modified) mlir/test/mlir-tblgen/interfaces.mlir (+1)
  • (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+26-1)
  • (modified) mlir/tools/mlir-tblgen/OpInterfacesGen.cpp (+18-14)
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index bea043f56fe21..9076c7e54d7bf 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -245,6 +245,10 @@ void TestType::printTypeC(Location loc) const {
   emitRemark(loc) << *this << " - TestC";
 }
 
+void TestType::printTypeC(Location loc, int value) const {
+  emitRemark(loc) << *this << " - " << value << " - TestC";
+}
+
 //===----------------------------------------------------------------------===//
 // TestTypeWithLayout
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp
index 2dd3fe245e220..e021f78e1142d 100644
--- a/mlir/test/lib/IR/TestInterfaces.cpp
+++ b/mlir/test/lib/IR/TestInterfaces.cpp
@@ -31,6 +31,7 @@ struct TestTypeInterfaces
           testInterface.printTypeA(op->getLoc());
           testInterface.printTypeB(op->getLoc());
           testInterface.printTypeC(op->getLoc());
+          testInterface.printTypeC(op->getLoc(), 42);
           testInterface.printTypeD(op->getLoc());
           // Just check that we can assign the result to a variable of interface
           // type.
diff --git a/mlir/test/mlir-tblgen/interfaces.mlir b/mlir/test/mlir-tblgen/interfaces.mlir
index 5c1ec613b387a..927cfd728bcd4 100644
--- a/mlir/test/mlir-tblgen/interfaces.mlir
+++ b/mlir/test/mlir-tblgen/interfaces.mlir
@@ -3,6 +3,7 @@
 // expected-remark@below {{'!test.test_type' - TestA}}
 // expected-remark@below {{'!test.test_type' - TestB}}
 // expected-remark@below {{'!test.test_type' - TestC}}
+// expected-remark@below {{'!test.test_type' - 42 - TestC}}
 // expected-remark@below {{'!test.test_type' - TestD}}
 // expected-remark@below {{'!test.test_type' - TestRet}}
 // expected-remark@below {{'!test.test_type' - TestE}}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 7e8e559baf878..4c6519cd2f7bf 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -789,6 +789,14 @@ class OpEmitter {
   Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
                                bool declaration = true);
 
+  // Generate a `using` declaration for the op interface method to include
+  // the default implementation from the interface trait.
+  // This is needed when the interface defines multiple methods with the same
+  // name, but some have a default implementation and some don't.
+  UsingDeclaration *
+  genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
+                                const tblgen::InterfaceMethod &method);
+
   // Generate the side effect interface methods.
   void genSideEffectInterfaceMethods();
 
@@ -815,6 +823,10 @@ class OpEmitter {
 
   // Helper for emitting op code.
   OpOrAdaptorHelper emitHelper;
+
+  // Keep track of the interface using declarations that have been generated to
+  // avoid duplicates.
+  llvm::StringSet<> interfaceUsingNames;
 };
 
 } // namespace
@@ -3672,8 +3684,10 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
     // Don't declare if the method has a default implementation and the op
     // didn't request that it always be declared.
     if (method.getDefaultImplementation() &&
-        !alwaysDeclaredMethods.count(method.getName()))
+        !alwaysDeclaredMethods.count(method.getName())) {
+      genOpInterfaceMethodUsingDecl(opTrait, method);
       continue;
+    }
     // Interface methods are allowed to overlap with existing methods, so don't
     // check if pruned.
     (void)genOpInterfaceMethod(method);
@@ -3692,6 +3706,17 @@ Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
                            std::move(paramList));
 }
 
+UsingDeclaration *
+OpEmitter::genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
+                                         const InterfaceMethod &method) {
+  std::string name = (llvm::Twine(opTrait->getFullyQualifiedTraitName()) + "<" +
+                      op.getQualCppClassName() + ">::" + method.getName())
+                         .str();
+  if (interfaceUsingNames.insert(name).second)
+    return opClass.declare<UsingDeclaration>(std::move(name));
+  return nullptr;
+}
+
 void OpEmitter::genOpInterfaceMethods() {
   for (const auto &trait : op.getTraits()) {
     if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 3cc1636ac3317..9dedd55005f87 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -42,10 +42,10 @@ static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
 /// Emit the method name and argument list for the given method. If 'addThisArg'
 /// is true, then an argument is added to the beginning of the argument list for
 /// the concrete value.
-static void emitMethodNameAndArgs(const InterfaceMethod &method,
+static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
                                   raw_ostream &os, StringRef valueType,
                                   bool addThisArg, bool addConst) {
-  os << method.getName() << '(';
+  os << name << '(';
   if (addThisArg) {
     if (addConst)
       os << "const ";
@@ -183,11 +183,13 @@ static void emitInterfaceDefMethods(StringRef interfaceQualName,
     emitInterfaceMethodDoc(method, os);
     emitCPPType(method.getReturnType(), os);
     os << interfaceQualName << "::";
-    emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+    emitMethodNameAndArgs(method, method.getName(), os, valueType,
+                          /*addThisArg=*/false,
                           /*addConst=*/!isOpInterface);
 
     // Forward to the method on the concrete operation type.
-    os << " {\n      return " << implValue << "->" << method.getName() << '(';
+    os << " {\n      return " << implValue << "->" << method.getDedupName()
+       << '(';
     if (!method.isStatic()) {
       os << implValue << ", ";
       os << (isOpInterface ? "getOperation()" : "*this");
@@ -239,7 +241,7 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
   for (auto &method : interface.getMethods()) {
     os << "    ";
     emitCPPType(method.getReturnType(), os);
-    os << "(*" << method.getName() << ")(";
+    os << "(*" << method.getDedupName() << ")(";
     if (!method.isStatic()) {
       os << "const Concept *impl, ";
       emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
@@ -289,13 +291,13 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
     os << "    " << modelClass << "() : Concept{";
     llvm::interleaveComma(
         interface.getMethods(), os,
-        [&](const InterfaceMethod &method) { os << method.getName(); });
+        [&](const InterfaceMethod &method) { os << method.getDedupName(); });
     os << "} {}\n\n";
 
     // Insert each of the virtual method overrides.
     for (auto &method : interface.getMethods()) {
       emitCPPType(method.getReturnType(), os << "    static inline ");
-      emitMethodNameAndArgs(method, os, valueType,
+      emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
                             /*addThisArg=*/!method.isStatic(),
                             /*addConst=*/false);
       os << ";\n";
@@ -319,7 +321,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
     if (method.isStatic())
       os << "static ";
     emitCPPType(method.getReturnType(), os);
-    os << method.getName() << "(";
+    os << method.getDedupName() << "(";
     if (!method.isStatic()) {
       emitCPPType(valueType, os);
       os << "tablegen_opaque_val";
@@ -350,7 +352,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
     emitCPPType(method.getReturnType(), os);
     os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
        << valueTemplate << ">::";
-    emitMethodNameAndArgs(method, os, valueType,
+    emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
                           /*addThisArg=*/!method.isStatic(),
                           /*addConst=*/false);
     os << " {\n  ";
@@ -384,7 +386,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
     emitCPPType(method.getReturnType(), os);
     os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
        << valueTemplate << ">::";
-    emitMethodNameAndArgs(method, os, valueType,
+    emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
                           /*addThisArg=*/!method.isStatic(),
                           /*addConst=*/false);
     os << " {\n  ";
@@ -396,7 +398,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
       os << "return static_cast<const " << valueTemplate << " *>(impl)->";
 
     // Add the arguments to the call.
-    os << method.getName() << '(';
+    os << method.getDedupName() << '(';
     if (!method.isStatic())
       os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
     llvm::interleaveComma(
@@ -416,7 +418,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
        << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
        << ">::";
 
-    os << method.getName() << "(";
+    os << method.getDedupName() << "(";
     if (!method.isStatic()) {
       emitCPPType(valueType, os);
       os << "tablegen_opaque_val";
@@ -477,7 +479,8 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
     emitInterfaceMethodDoc(method, os, "    ");
     os << "    " << (method.isStatic() ? "static " : "");
     emitCPPType(method.getReturnType(), os);
-    emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+    emitMethodNameAndArgs(method, method.getName(), os, valueType,
+                          /*addThisArg=*/false,
                           /*addConst=*/!isOpInterface && !method.isStatic());
     os << " {\n      " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
        << "\n    }\n";
@@ -514,7 +517,8 @@ static void emitInterfaceDeclMethods(const Interface &interface,
   for (auto &method : interface.getMethods()) {
     emitInterfaceMethodDoc(method, os, "  ");
     emitCPPType(method.getReturnType(), os << "  ");
-    emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+    emitMethodNameAndArgs(method, method.getName(), os, valueType,
+                          /*addThisArg=*/false,
                           /*addConst=*/!isOpInterface);
     os << ";\n";
   }

@joker-eph joker-eph force-pushed the interface_overloading branch 4 times, most recently from 9ddecbc to d603c93 Compare October 3, 2025 14:48
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Oct 3, 2025
std::string name =
cast<DefInit>(init)->getDef()->getValueAsString("name").str();
while (!dedupNames.insert(name).second) {
name = name + "_" + std::to_string(dedupNames.size());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come a single counter is sufficient here? I expected name mangling scheme that makes the argument types part of the function name. When the interface method is called, how do you know which dedup method to call?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the interface method is called, how do you know which dedup method to call?

The actual deduce method name could be a randomly generated name, as long as the mapping is consistent.
So we can initialize InterfaceMethod with any dedup name as long as it is uniqued.
This is all just internal to the "vtables" we generate. The public interface/trait will expose the overloaded methods with their public name, and try to call the method on the Op with the public name as well.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some comments about this, PTAL.

This allows to define multiple interface methods with the same name
but different arguments.
@joker-eph joker-eph force-pushed the interface_overloading branch from d603c93 to c5f6b0f Compare October 6, 2025 15:03
Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a useful feature. I had to work around this limitation in the past: there is BufferizableOpInterface::bufferizesToMemoryWrite and BufferizableOpInterface::resultBufferizesToMemoryWrite.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants