From c5f6b0f6d8d30ccc1cec53f7d580e1def8123dcb Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 3 Oct 2025 04:47:35 -0700 Subject: [PATCH] [MLIR][ODS] Add support for overloading interface methods This allows to define multiple interface methods with the same name but different arguments. --- .../include/flang/Optimizer/HLFIR/HLFIROps.td | 9 +++++- mlir/include/mlir/TableGen/Interfaces.h | 9 +++++- mlir/lib/TableGen/Interfaces.cpp | 21 ++++++++++-- mlir/test/lib/Dialect/Test/TestInterfaces.td | 10 ++++++ mlir/test/lib/Dialect/Test/TestTypes.cpp | 4 +++ mlir/test/lib/IR/TestInterfaces.cpp | 2 ++ mlir/test/mlir-tblgen/interfaces.mlir | 2 ++ mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 19 ++++++++++- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 27 +++++++++++++++- mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 32 +++++++++++-------- 10 files changed, 114 insertions(+), 21 deletions(-) diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td index 90512586a6520..218435a44c24f 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -269,6 +269,9 @@ def hlfir_DesignateOp : hlfir_Op<"designate", [AttrSizedOperandSegments, using Triplet = std::tuple; using Subscript = std::variant; using Subscripts = llvm::SmallVector; + void setFortranAttrs(fir::FortranVariableFlagsEnum flags) { + this->setFortranAttrs(std::optional(flags)); + } }]; let builders = [ @@ -319,7 +322,7 @@ def hlfir_ParentComponentOp : hlfir_Op<"parent_comp", [AttrSizedOperandSegments, // Implement FortranVariableInterface interface. Parent components have // no attributes (pointer, allocatable or contiguous can only be added // to regular components). - std::optional getFortranAttrs() const { + std::optional getFortranAttrs() { return std::nullopt; } void setFortranAttrs(fir::FortranVariableFlagsEnum flags) {} @@ -882,6 +885,10 @@ def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments, CArg<"llvm::ArrayRef", "{}">:$attributes)>]; let extraClassDeclaration = [{ + void setFortranAttrs(fir::FortranVariableFlagsEnum flags) { + this->setFortranAttrs(std::optional(flags)); + } + /// Override FortranVariableInterface default implementation mlir::Value getBase() { return getResult(0); diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h index 15f667e0ffce0..7c36cbc1192ac 100644 --- a/mlir/include/mlir/TableGen/Interfaces.h +++ b/mlir/include/mlir/TableGen/Interfaces.h @@ -32,7 +32,7 @@ class InterfaceMethod { StringRef name; }; - explicit InterfaceMethod(const llvm::Record *def); + explicit InterfaceMethod(const llvm::Record *def, std::string uniqueName); // Return the return type of this method. StringRef getReturnType() const; @@ -40,6 +40,9 @@ class InterfaceMethod { // Return the name of this method. StringRef getName() const; + // Return the dedup name of this method. + StringRef getUniqueName() const; + // Return if this method is static. bool isStatic() const; @@ -62,6 +65,10 @@ class InterfaceMethod { // The arguments of this method. SmallVector arguments; + + // The unique name of this method, to distinguish it from other methods with + // the same name (overloaded methods) + std::string uniqueName; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp index ec7adf3b02c21..b0ad3ee59a089 100644 --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -25,7 +25,8 @@ using llvm::StringInit; // InterfaceMethod //===----------------------------------------------------------------------===// -InterfaceMethod::InterfaceMethod(const Record *def) : def(def) { +InterfaceMethod::InterfaceMethod(const Record *def, std::string uniqueName) + : def(def), uniqueName(uniqueName) { const DagInit *args = def->getValueAsDag("arguments"); for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) { arguments.push_back({cast(args->getArg(i))->getValue(), @@ -42,6 +43,9 @@ StringRef InterfaceMethod::getName() const { return def->getValueAsString("name"); } +// Return the name of this method. +StringRef InterfaceMethod::getUniqueName() const { return uniqueName; } + // Return if this method is static. bool InterfaceMethod::isStatic() const { return def->isSubClassOf("StaticInterfaceMethod"); @@ -83,8 +87,19 @@ Interface::Interface(const Record *def) : def(def) { // Initialize the interface methods. auto *listInit = dyn_cast(def->getValueInit("methods")); - for (const Init *init : listInit->getElements()) - methods.emplace_back(cast(init)->getDef()); + // In case of overloaded methods, we need to find a unique name for each for + // the internal function pointer in the "vtable" we generate. This is an + // internal name, we could use a randomly generated name as long as there are + // no collisions. + StringSet<> uniqueNames; + for (const Init *init : listInit->getElements()) { + std::string name = + cast(init)->getDef()->getValueAsString("name").str(); + while (!uniqueNames.insert(name).second) { + name = name + "_" + std::to_string(uniqueNames.size()); + } + methods.emplace_back(cast(init)->getDef(), name); + } // Initialize the interface base classes. auto *basesInit = dyn_cast(def->getValueInit("baseInterfaces")); diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td index d3d96ea5a65a4..3697e38ac4c7d 100644 --- a/mlir/test/lib/Dialect/Test/TestInterfaces.td +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td @@ -44,6 +44,16 @@ def TestTypeInterface InterfaceMethod<"Prints the type name.", "void", "printTypeC", (ins "::mlir::Location":$loc) >, + // Check that we can have multiple method with the same name. + InterfaceMethod<"Prints the type name, with a value prefixed.", + "void", "printTypeC", (ins "::mlir::Location":$loc, "int":$value) + >, + InterfaceMethod<"Prints the type name, with a value prefixed.", + "void", "printTypeC", (ins "::mlir::Location":$loc, "float":$value), + [{}], /*defaultImplementation=*/[{ + emitRemark(loc) << $_type << " - " << value << " - Float TestC"; + }] + >, // It should be possible to use the interface type name as result type // as well as in the implementation. InterfaceMethod<"Prints the type name and returns the type as interface.", diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index bea043f56fe21..614121f1d43dd 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -245,6 +245,10 @@ void TestType::printTypeC(Location loc) const { emitRemark(loc) << *this << " - TestC"; } +void TestType::printTypeC(Location loc, int value) const { + emitRemark(loc) << *this << " - " << value << " - Int TestC"; +} + //===----------------------------------------------------------------------===// // TestTypeWithLayout //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp index 2dd3fe245e220..881019dbfd50d 100644 --- a/mlir/test/lib/IR/TestInterfaces.cpp +++ b/mlir/test/lib/IR/TestInterfaces.cpp @@ -31,6 +31,8 @@ struct TestTypeInterfaces testInterface.printTypeA(op->getLoc()); testInterface.printTypeB(op->getLoc()); testInterface.printTypeC(op->getLoc()); + testInterface.printTypeC(op->getLoc(), 42); + testInterface.printTypeC(op->getLoc(), 3.14f); testInterface.printTypeD(op->getLoc()); // Just check that we can assign the result to a variable of interface // type. diff --git a/mlir/test/mlir-tblgen/interfaces.mlir b/mlir/test/mlir-tblgen/interfaces.mlir index 5c1ec613b387a..b5d694f75734c 100644 --- a/mlir/test/mlir-tblgen/interfaces.mlir +++ b/mlir/test/mlir-tblgen/interfaces.mlir @@ -3,6 +3,8 @@ // expected-remark@below {{'!test.test_type' - TestA}} // expected-remark@below {{'!test.test_type' - TestB}} // expected-remark@below {{'!test.test_type' - TestC}} +// expected-remark@below {{'!test.test_type' - 42 - Int TestC}} +// expected-remark@below {{'!test.test_type' - 3.140000e+00 - Float TestC}} // expected-remark@below {{'!test.test_type' - TestD}} // expected-remark@below {{'!test.test_type' - TestRet}} // expected-remark@below {{'!test.test_type' - TestE}} diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index b9115657d6bf3..15b03b85727f6 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -130,6 +130,9 @@ class DefGen { void emitTraitMethods(const InterfaceTrait &trait); /// Emit a trait method. void emitTraitMethod(const InterfaceMethod &method); + /// Generate a using declaration for a trait method. + void genTraitMethodUsingDecl(const InterfaceTrait &trait, + const InterfaceMethod &method); //===--------------------------------------------------------------------===// // OpAsm{Type,Attr}Interface Default Method Emission @@ -176,6 +179,9 @@ class DefGen { StringRef valueType; /// The prefix/suffix of the TableGen def name, either "Attr" or "Type". StringRef defType; + + /// The set of using declarations for trait methods. + llvm::StringSet<> interfaceUsingNames; }; } // namespace @@ -632,8 +638,10 @@ void DefGen::emitTraitMethods(const InterfaceTrait &trait) { // Don't declare if the method has a body. Or if the method has a default // implementation and the def didn't request that it always be declared. if (method.getBody() || (method.getDefaultImplementation() && - !alwaysDeclared.count(method.getName()))) + !alwaysDeclared.count(method.getName()))) { + genTraitMethodUsingDecl(trait, method); continue; + } emitTraitMethod(method); } } @@ -649,6 +657,15 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) { std::move(params)); } +void DefGen::genTraitMethodUsingDecl(const InterfaceTrait &trait, + const InterfaceMethod &method) { + std::string name = (llvm::Twine(trait.getFullyQualifiedTraitName()) + "<" + + def.getCppClassName() + ">::" + method.getName()) + .str(); + if (interfaceUsingNames.insert(name).second) + defCls.declare(std::move(name)); +} + //===----------------------------------------------------------------------===// // OpAsm{Type,Attr}Interface Default Method Emission diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 7e8e559baf878..70c462bb667b2 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -789,6 +789,14 @@ class OpEmitter { Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method, bool declaration = true); + // Generate a `using` declaration for the op interface method to include + // the default implementation from the interface trait. + // This is needed when the interface defines multiple methods with the same + // name, but some have a default implementation and some don't. + UsingDeclaration * + genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait, + const tblgen::InterfaceMethod &method); + // Generate the side effect interface methods. void genSideEffectInterfaceMethods(); @@ -815,6 +823,10 @@ class OpEmitter { // Helper for emitting op code. OpOrAdaptorHelper emitHelper; + + // Keep track of the interface using declarations that have been generated to + // avoid duplicates. + llvm::StringSet<> interfaceUsingNames; }; } // namespace @@ -3672,8 +3684,10 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) { // Don't declare if the method has a default implementation and the op // didn't request that it always be declared. if (method.getDefaultImplementation() && - !alwaysDeclaredMethods.count(method.getName())) + !alwaysDeclaredMethods.count(method.getName())) { + genOpInterfaceMethodUsingDecl(opTrait, method); continue; + } // Interface methods are allowed to overlap with existing methods, so don't // check if pruned. (void)genOpInterfaceMethod(method); @@ -3692,6 +3706,17 @@ Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method, std::move(paramList)); } +UsingDeclaration * +OpEmitter::genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait, + const InterfaceMethod &method) { + std::string name = (llvm::Twine(opTrait->getFullyQualifiedTraitName()) + "<" + + op.getCppClassName() + ">::" + method.getName()) + .str(); + if (interfaceUsingNames.insert(name).second) + return opClass.declare(std::move(name)); + return nullptr; +} + void OpEmitter::genOpInterfaceMethods() { for (const auto &trait : op.getTraits()) { if (const auto *opTrait = dyn_cast(&trait)) diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 3cc1636ac3317..0510fcad35d61 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -42,10 +42,10 @@ static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { /// Emit the method name and argument list for the given method. If 'addThisArg' /// is true, then an argument is added to the beginning of the argument list for /// the concrete value. -static void emitMethodNameAndArgs(const InterfaceMethod &method, +static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name, raw_ostream &os, StringRef valueType, bool addThisArg, bool addConst) { - os << method.getName() << '('; + os << name << '('; if (addThisArg) { if (addConst) os << "const "; @@ -183,11 +183,13 @@ static void emitInterfaceDefMethods(StringRef interfaceQualName, emitInterfaceMethodDoc(method, os); emitCPPType(method.getReturnType(), os); os << interfaceQualName << "::"; - emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, + emitMethodNameAndArgs(method, method.getName(), os, valueType, + /*addThisArg=*/false, /*addConst=*/!isOpInterface); // Forward to the method on the concrete operation type. - os << " {\n return " << implValue << "->" << method.getName() << '('; + os << " {\n return " << implValue << "->" << method.getUniqueName() + << '('; if (!method.isStatic()) { os << implValue << ", "; os << (isOpInterface ? "getOperation()" : "*this"); @@ -239,7 +241,7 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) { for (auto &method : interface.getMethods()) { os << " "; emitCPPType(method.getReturnType(), os); - os << "(*" << method.getName() << ")("; + os << "(*" << method.getUniqueName() << ")("; if (!method.isStatic()) { os << "const Concept *impl, "; emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", "); @@ -289,13 +291,13 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) { os << " " << modelClass << "() : Concept{"; llvm::interleaveComma( interface.getMethods(), os, - [&](const InterfaceMethod &method) { os << method.getName(); }); + [&](const InterfaceMethod &method) { os << method.getUniqueName(); }); os << "} {}\n\n"; // Insert each of the virtual method overrides. for (auto &method : interface.getMethods()) { emitCPPType(method.getReturnType(), os << " static inline "); - emitMethodNameAndArgs(method, os, valueType, + emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType, /*addThisArg=*/!method.isStatic(), /*addConst=*/false); os << ";\n"; @@ -319,7 +321,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) { if (method.isStatic()) os << "static "; emitCPPType(method.getReturnType(), os); - os << method.getName() << "("; + os << method.getUniqueName() << "("; if (!method.isStatic()) { emitCPPType(valueType, os); os << "tablegen_opaque_val"; @@ -350,7 +352,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { emitCPPType(method.getReturnType(), os); os << "detail::" << interface.getName() << "InterfaceTraits::Model<" << valueTemplate << ">::"; - emitMethodNameAndArgs(method, os, valueType, + emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType, /*addThisArg=*/!method.isStatic(), /*addConst=*/false); os << " {\n "; @@ -384,7 +386,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { emitCPPType(method.getReturnType(), os); os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<" << valueTemplate << ">::"; - emitMethodNameAndArgs(method, os, valueType, + emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType, /*addThisArg=*/!method.isStatic(), /*addConst=*/false); os << " {\n "; @@ -396,7 +398,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { os << "return static_cast(impl)->"; // Add the arguments to the call. - os << method.getName() << '('; + os << method.getUniqueName() << '('; if (!method.isStatic()) os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", "); llvm::interleaveComma( @@ -416,7 +418,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { << "InterfaceTraits::ExternalModel::"; - os << method.getName() << "("; + os << method.getUniqueName() << "("; if (!method.isStatic()) { emitCPPType(valueType, os); os << "tablegen_opaque_val"; @@ -477,7 +479,8 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) { emitInterfaceMethodDoc(method, os, " "); os << " " << (method.isStatic() ? "static " : ""); emitCPPType(method.getReturnType(), os); - emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, + emitMethodNameAndArgs(method, method.getName(), os, valueType, + /*addThisArg=*/false, /*addConst=*/!isOpInterface && !method.isStatic()); os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt) << "\n }\n"; @@ -514,7 +517,8 @@ static void emitInterfaceDeclMethods(const Interface &interface, for (auto &method : interface.getMethods()) { emitInterfaceMethodDoc(method, os, " "); emitCPPType(method.getReturnType(), os << " "); - emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, + emitMethodNameAndArgs(method, method.getName(), os, valueType, + /*addThisArg=*/false, /*addConst=*/!isOpInterface); os << ";\n"; }