diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md index 633e4924ebfd6..8e75146b2c17b 100644 --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -384,6 +384,9 @@ comprised of the following components: - Additional C++ code that is generated in the declaration of the interface class. This allows for defining methods and more on the user facing interface class, that do not need to hook into the IR entity. + These declarations are _not_ implicitly visible in default + implementations of interface methods, but static declarations may be + accessed with full name qualification. `OpInterface` classes may additionally contain the following: @@ -430,6 +433,8 @@ Interface methods are comprised of the following components: - `ConcreteAttr`/`ConcreteOp`/`ConcreteType` is an implicitly defined `typename` that can be used to refer to the type of the derived IR entity currently being operated on. + - This may refer to static fields of the interface class using the + qualified name, e.g., `TestOpInterface::staticMethod()`. ODS also allows for generating declarations for the `InterfaceMethod`s of an operation if the operation specifies the interface with diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index 3c568a0e776f4..bb69ed8deef41 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -78,7 +78,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> { ]; let verify = [{ - auto concreteOp = cast($_op); + auto concreteOp = cast($_op); for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) { Optional operands = concreteOp.getSuccessorOperands(i); if (failed(detail::verifyBranchSuccessorOperands($_op, i, operands))) @@ -154,7 +154,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { ]; let verify = [{ - static_assert(!ConcreteOpType::template hasTrait(), + static_assert(!ConcreteOp::template hasTrait(), "expected operation to have non-zero regions"); return success(); }]; diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index d4b4cd123e9db..a365b42621cdc 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -241,17 +241,10 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) { os << " };\n"; } + // Emit the template for the external model. os << " template\n"; os << " class ExternalModel : public FallbackModel {\n"; - - // Emit the template for the external model if there are no extra class - // declarations. - if (interface.getExtraClassDeclaration()) { - os << " };\n"; - return; - } - os << " public:\n"; // Emit declarations for methods that have default implementations. Other @@ -345,9 +338,6 @@ void InterfaceGenerator::emitModelMethodsDef(Interface &interface) { } // Emit default implementations for the external model. - if (interface.getExtraClassDeclaration()) - return; - for (auto &method : interface.getMethods()) { if (!method.getDefaultImplementation()) continue; @@ -427,11 +417,6 @@ void InterfaceGenerator::emitTraitDecl(Interface &interface, os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; os << " };\n"; - - // Emit a utility wrapper trait class. - os << llvm::formatv(" template \n" - " struct Trait : public {0}Trait<{1}> {{};\n", - interfaceName, valueTemplate); } void InterfaceGenerator::emitInterfaceDecl(Interface interface) { @@ -452,7 +437,13 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) { << "struct " << interfaceTraitsName << " {\n"; emitConceptDecl(interface); emitModelDecl(interface); - os << "};\n} // end namespace detail\n"; + os << "};"; + + // Emit the derived trait for the interface. + os << "template \n"; + os << "struct " << interface.getName() << "Trait;\n"; + + os << "\n} // end namespace detail\n"; // Emit the main interface class declaration. os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" @@ -461,8 +452,10 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) { interfaceName, interfaceName, interfaceTraitsName, interfaceBaseType); - // Emit the derived trait for the interface. - emitTraitDecl(interface, interfaceName, interfaceTraitsName); + // Emit a utility wrapper trait class. + os << llvm::formatv(" template \n" + " struct Trait : public detail::{0}Trait<{1}> {{};\n", + interfaceName, valueTemplate); // Insert the method declarations. bool isOpInterface = isa(interface); @@ -479,6 +472,10 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) { os << "};\n"; + os << "namespace detail {\n"; + emitTraitDecl(interface, interfaceName, interfaceTraitsName); + os << "}// namespace detail\n"; + emitModelMethodsDef(interface); for (StringRef ns : llvm::reverse(namespaces))