diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 33d962d01598a..42eccde488919 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1679,8 +1679,8 @@ class Op : public OpState, public Traits... { reinterpret_cast(const_cast(pointer))); } - /// Attach the given models as implementations of the corresponding interfaces - /// for the concrete operation. + /// Attach the given models as implementations of the corresponding + /// interfaces for the concrete operation. template static void attachInterface(MLIRContext &context) { Optional info = RegisteredOperationName::lookup( @@ -1689,6 +1689,7 @@ class Op : public OpState, public Traits... { llvm::report_fatal_error( "Attempting to attach an interface to an unregistered operation " + ConcreteType::getOperationName() + "."); + (void)std::initializer_list{(checkInterfaceTarget(), 0)...}; info->attachInterface(); } @@ -1714,6 +1715,32 @@ class Op : public OpState, public Traits... { template using detect_has_print = llvm::is_detected; + /// Trait to check if T provides a 'ConcreteEntity' type alias. + template + using has_concrete_entity_t = typename T::ConcreteEntity; + + /// A struct-wrapped type alias to T::ConcreteEntity if provided and to + /// ConcreteType otherwise. This is akin to std::conditional but doesn't fail + /// on the missing typedef. Useful for checking if the interface is targeting + /// the right class. + template ::value> + struct InterfaceTargetOrOpT { + using type = typename T::ConcreteEntity; + }; + template struct InterfaceTargetOrOpT { + using type = ConcreteType; + }; + + /// A hook for static assertion that the external interface model T is + /// targeting the concrete type of this op. The model can also be a fallback + /// model that works for every op. + template static void checkInterfaceTarget() { + static_assert(std::is_same::type, + ConcreteType>::value, + "attaching an interface to the wrong op kind"); + } + /// Returns an interface map containing the interfaces registered to this /// operation. static detail::InterfaceMap getInterfaceMap() { diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h index 8cd159e6f0438..6d854f66f6ff7 100644 --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -127,6 +127,8 @@ class StorageUserBase : public BaseT, public Traits... { if (!abstract) llvm::report_fatal_error("Registering an interface for an attribute/type " "that is not itself registered."); + (void)std::initializer_list{ + (checkInterfaceTarget(), 0)...}; abstract->interfaceMap.template insert(); } @@ -182,6 +184,35 @@ class StorageUserBase : public BaseT, public Traits... { /// Utility for easy access to the storage instance. ImplType *getImpl() const { return static_cast(this->impl); } + +private: + /// Trait to check if T provides a 'ConcreteEntity' type alias. + template + using has_concrete_entity_t = typename T::ConcreteEntity; + + /// A struct-wrapped type alias to T::ConcreteEntity if provided and to + /// ConcreteT otherwise. This is akin to std::conditional but doesn't fail on + /// the missing typedef. Useful for checking if the interface is targeting the + /// right class. + template ::value> + struct IfaceTargetOrConcreteT { + using type = typename T::ConcreteEntity; + }; + template + struct IfaceTargetOrConcreteT { + using type = ConcreteT; + }; + + /// A hook for static assertion that the external interface model T is + /// targeting a base class of the concrete attribute/type. The model can also + /// be a fallback model that works for every attribute/type. + template + static void checkInterfaceTarget() { + static_assert(std::is_base_of::type, + ConcreteT>::value, + "attaching an interface to the wrong attribute/type kind"); + } }; } // namespace detail } // namespace mlir diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 76a5164eba3c8..a6d3cccdd8770 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -250,6 +250,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) { << ">\n"; os << " class ExternalModel : public FallbackModel {\n"; os << " public:\n"; + os << " using ConcreteEntity = " << valueTemplate << ";\n"; // Emit declarations for methods that have default implementations. Other // methods are expected to be implemented by the concrete derived model.