28 changes: 6 additions & 22 deletions mlir/include/mlir/IR/SymbolInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -174,28 +174,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
return success();
}];

let extraClassDeclaration = [{
/// Convenience version of `getNameAttr` that returns a StringRef.
StringRef getName() {
return getNameAttr().getValue();
}

/// Convenience version of `setName` that take a StringRef.
void setName(StringRef name) {
setName(StringAttr::get(this->getContext(), name));
}

/// Custom classof that handles the case where the symbol is optional.
static bool classof(Operation *op) {
auto *opConcept = getInterfaceFor(op);
if (!opConcept)
return false;
return !opConcept->isOptionalSymbol(opConcept, op) ||
op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
}
}];

let extraTraitClassDeclaration = [{
let extraSharedClassDeclaration = [{
using Visibility = mlir::SymbolTable::Visibility;

/// Convenience version of `getNameAttr` that returns a StringRef.
Expand All @@ -208,6 +187,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
setName(StringAttr::get($_op->getContext(), name));
}
}];

// Add additional classof checks to properly handle "optional" symbols.
let extraClassOf = [{
return $_op->hasAttr(::mlir::SymbolTable::getSymbolAttrName());
}];
}

//===----------------------------------------------------------------------===//
Expand Down
98 changes: 45 additions & 53 deletions mlir/include/mlir/Support/InterfaceSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ class Interface : public BaseType {
"expected value to provide interface instance");
}

/// Constructor for a known concept.
Interface(ValueT t, const Concept *conceptImpl)
: BaseType(t), conceptImpl(const_cast<Concept *>(conceptImpl)) {
assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl);
}

/// Constructor for DenseMapInfo's empty key and tombstone key.
Interface(ValueT t, std::nullptr_t) : BaseType(t), conceptImpl(nullptr) {}

Expand Down Expand Up @@ -146,25 +152,6 @@ struct count_if_t_impl<Pred, N, T, Us...>
template <template <class> class Pred, typename... Ts>
using count_if_t = count_if_t_impl<Pred, 0, Ts...>;

namespace {
/// Type trait indicating whether all template arguments are
/// trivially-destructible.
template <typename... Args>
struct all_trivially_destructible;

template <typename Arg, typename... Args>
struct all_trivially_destructible<Arg, Args...> {
static constexpr const bool value =
std::is_trivially_destructible<Arg>::value &&
all_trivially_destructible<Args...>::value;
};

template <>
struct all_trivially_destructible<> {
static constexpr const bool value = true;
};
} // namespace

/// This class provides an efficient mapping between a given `Interface` type,
/// and a particular implementation of its concept.
class InterfaceMap {
Expand All @@ -176,7 +163,16 @@ class InterfaceMap {
template <typename... Types>
using num_interface_types_t = count_if_t<detect_get_interface_id, Types...>;

/// Trait to check if T provides a 'initializeInterfaceConcept' method.
template <typename T, typename... Args>
using has_initialize_method =
decltype(std::declval<T>().initializeInterfaceConcept(
std::declval<InterfaceMap &>()));
template <typename T>
using detect_initialize_method = llvm::is_detected<has_initialize_method, T>;

public:
InterfaceMap() = default;
InterfaceMap(InterfaceMap &&) = default;
InterfaceMap &operator=(InterfaceMap &&rhs) {
for (auto &it : interfaces)
Expand All @@ -199,11 +195,9 @@ class InterfaceMap {
if constexpr (numInterfaces == 0)
return InterfaceMap();

std::array<std::pair<TypeID, void *>, numInterfaces> elements;
std::pair<TypeID, void *> *elementIt = elements.data();
(void)elementIt;
(addModelAndUpdateIterator<Types>(elementIt), ...);
return InterfaceMap(elements);
InterfaceMap map;
(map.insertPotentialInterface<Types>(), ...);
return map;
}

/// Returns an instance of the concept object for the given interface if it
Expand All @@ -216,42 +210,40 @@ class InterfaceMap {
/// Returns true if the interface map contains an interface for the given id.
bool contains(TypeID interfaceID) const { return lookup(interfaceID); }

/// Create an InterfaceMap given with the implementation of the interfaces.
/// The use of this constructor is in general discouraged in favor of
/// 'InterfaceMap::get<InterfaceA, ...>()'.
InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements);

/// Insert the given models as implementations of the corresponding interfaces
/// for the concrete attribute class.
/// Insert the given interface models.
template <typename... IfaceModels>
void insert() {
static_assert(all_trivially_destructible<IfaceModels...>::value,
"interface models must be trivially destructible");
std::pair<TypeID, void *> elements[] = {
std::make_pair(IfaceModels::Interface::getInterfaceID(),
new (malloc(sizeof(IfaceModels))) IfaceModels())...};
insert(elements);
void insertModels() {
(insertModel<IfaceModels>(), ...);
}

private:
InterfaceMap() = default;

/// Assign the interface model of the type to the given opaque element
/// iterator and increment it.
/// Insert the given interface type into the map, ignoring it if it doesn't
/// actually represent an interface.
template <typename T>
static inline std::enable_if_t<detect_get_interface_id<T>::value>
addModelAndUpdateIterator(std::pair<TypeID, void *> *&elementIt) {
*elementIt = {T::getInterfaceID(), new (malloc(sizeof(typename T::ModelT)))
typename T::ModelT()};
++elementIt;
inline void insertPotentialInterface() {
if constexpr (detect_get_interface_id<T>::value)
insertModel<typename T::ModelT>();
}
/// Overload when `T` isn't an interface.
template <typename T>
static inline std::enable_if_t<!detect_get_interface_id<T>::value>
addModelAndUpdateIterator(std::pair<TypeID, void *> *&) {}

/// Insert the given set of interface models into the interface map.
void insert(ArrayRef<std::pair<TypeID, void *>> elements);
/// Insert the given interface model into the map.
template <typename InterfaceModel>
void insertModel() {
// FIXME(#59975): Uncomment this when SPIRV no longer awkwardly reimplements
// interfaces in a way that isn't clean/compatible.
// static_assert(std::is_trivially_destructible_v<InterfaceModel>,
// "interface models must be trivially destructible");

// Build the interface model, optionally initializing if necessary.
InterfaceModel *model =
new (malloc(sizeof(InterfaceModel))) InterfaceModel();
if constexpr (detect_initialize_method<InterfaceModel>::value)
model->initializeInterfaceConcept(*this);

insert(InterfaceModel::Interface::getInterfaceID(), model);
}
/// Insert the given set of interface id and concept implementation into the
/// interface map.
void insert(TypeID interfaceId, void *conceptImpl);

/// Compare two TypeID instances by comparing the underlying pointer.
static bool compare(TypeID lhs, TypeID rhs) {
Expand Down
2 changes: 0 additions & 2 deletions mlir/include/mlir/TableGen/Format.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class FmtContext {
None,
Custom, // For custom placeholders
Builder, // For the $_builder placeholder
Op, // For the $_op placeholder
Self, // For the $_self placeholder
};

Expand All @@ -58,7 +57,6 @@ class FmtContext {

// Setters for builtin placeholders
FmtContext &withBuilder(Twine subst);
FmtContext &withOp(Twine subst);
FmtContext &withSelf(Twine subst);

std::optional<StringRef> getSubstFor(PHKind placeholder) const;
Expand Down
19 changes: 19 additions & 0 deletions mlir/include/mlir/TableGen/Interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator.h"

namespace llvm {
class Init;
Expand Down Expand Up @@ -72,10 +73,17 @@ class InterfaceMethod {
class Interface {
public:
explicit Interface(const llvm::Record *def);
Interface(const Interface &rhs) : def(rhs.def), methods(rhs.methods) {
for (auto &base : rhs.baseInterfaces)
baseInterfaces.push_back(std::make_unique<Interface>(*base));
}

// Return the name of this interface.
StringRef getName() const;

// Returns this interface's name prefixed with namespaces.
std::string getFullyQualifiedName() const;

// Return the C++ namespace of this interface.
StringRef getCppNamespace() const;

Expand All @@ -95,9 +103,17 @@ class Interface {
// trait classes.
std::optional<StringRef> getExtraSharedClassDeclaration() const;

// Return the extra classof method code.
std::optional<StringRef> getExtraClassOf() const;

// Return the verify method body if it has one.
std::optional<StringRef> getVerify() const;

// Return the base interfaces of this interface.
auto getBaseInterfaces() const {
return llvm::make_pointee_range(baseInterfaces);
}

// If there's a verify method, return if it needs to access the ops in the
// regions.
bool verifyWithRegions() const;
Expand All @@ -111,6 +127,9 @@ class Interface {

// The methods of this interface.
SmallVector<InterfaceMethod, 8> methods;

// The base interfaces of this interface.
SmallVector<std::unique_ptr<Interface>> baseInterfaces;
};

// An interface that is registered to an Attribute.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/IR/ExtensibleDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ DynamicOpDefinition::DynamicOpDefinition(
: Impl(StringAttr::get(dialect->getContext(),
(dialect->getNamespace() + "." + name).str()),
dialect, dialect->allocateTypeID(),
/*interfaceMap=*/detail::InterfaceMap(std::nullopt)),
/*interfaceMap=*/detail::InterfaceMap()),
verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
parseFn(std::move(parseFn)), printFn(std::move(printFn)),
foldHookFn(std::move(foldHookFn)),
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
auto nameAttr = StringAttr::get(context, name);
it.first->second = std::make_unique<UnregisteredOpModel>(
nameAttr, nameAttr.getReferencedDialect(), TypeID::get<void>(),
detail::InterfaceMap(std::nullopt));
detail::InterfaceMap());
}
impl = it.first->second.get();
}
Expand Down
31 changes: 10 additions & 21 deletions mlir/lib/Support/InterfaceSupport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,16 @@

using namespace mlir;

detail::InterfaceMap::InterfaceMap(
MutableArrayRef<std::pair<TypeID, void *>> elements)
: interfaces(elements.begin(), elements.end()) {
llvm::sort(interfaces, [](const auto &lhs, const auto &rhs) {
return compare(lhs.first, rhs.first);
});
}

void detail::InterfaceMap::insert(
ArrayRef<std::pair<TypeID, void *>> elements) {
void detail::InterfaceMap::insert(TypeID interfaceId, void *conceptImpl) {
// Insert directly into the right position to keep the interfaces sorted.
for (auto &element : elements) {
TypeID id = element.first;
auto *it = llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
return compare(it.first, id);
});
if (it != interfaces.end() && it->first == id) {
LLVM_DEBUG(llvm::dbgs() << "Ignoring repeated interface registration");
free(element.second);
continue;
}
interfaces.insert(it, element);
auto *it =
llvm::lower_bound(interfaces, interfaceId, [](const auto &it, TypeID id) {
return compare(it.first, id);
});
if (it != interfaces.end() && it->first == interfaceId) {
LLVM_DEBUG(llvm::dbgs() << "Ignoring repeated interface registration");
free(conceptImpl);
return;
}
interfaces.insert(it, {interfaceId, conceptImpl});
}
21 changes: 18 additions & 3 deletions mlir/lib/TableGen/AttrOrTypeDef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Dialect.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Error.h"
Expand Down Expand Up @@ -56,9 +57,23 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
if (auto *traitList = def->getValueAsListInit("traits")) {
SmallPtrSet<const llvm::Init *, 32> traitSet;
traits.reserve(traitSet.size());
for (auto *traitInit : *traitList)
if (traitSet.insert(traitInit).second)
traits.push_back(Trait::create(traitInit));
llvm::unique_function<void(llvm::ListInit *)> processTraitList =
[&](llvm::ListInit *traitList) {
for (auto *traitInit : *traitList) {
if (!traitSet.insert(traitInit).second)
continue;

// If this is an interface, add any bases to the trait list.
auto *traitDef = cast<llvm::DefInit>(traitInit)->getDef();
if (traitDef->isSubClassOf("Interface")) {
if (auto *bases = traitDef->getValueAsListInit("baseInterfaces"))
processTraitList(bases);
}

traits.push_back(Trait::create(traitInit));
}
};
processTraitList(traitList);
}

// Populate the parameters.
Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/TableGen/CodeGenHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ void StaticVerifierFunctionEmitter::emitConstraints(
const ConstraintMap &constraints, StringRef selfName,
const char *const codeTemplate) {
FmtContext ctx;
ctx.withOp("*op").withSelf(selfName);
ctx.addSubst("_op", "*op").withSelf(selfName);
for (auto &it : constraints) {
os << formatv(codeTemplate, it.second,
tgfmt(it.first.getConditionTemplate(), &ctx),
Expand All @@ -216,7 +216,7 @@ void StaticVerifierFunctionEmitter::emitRegionConstraints() {

void StaticVerifierFunctionEmitter::emitPatternConstraints() {
FmtContext ctx;
ctx.withOp("*op").withBuilder("rewriter").withSelf("type");
ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type");
for (auto &it : typeConstraints) {
os << formatv(patternAttrOrTypeConstraintCode, it.second,
tgfmt(it.first.getConditionTemplate(), &ctx),
Expand All @@ -240,9 +240,9 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
/// because ops use cached identifiers.
static bool canUniqueAttrConstraint(Attribute attr) {
FmtContext ctx;
auto test =
tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op"))
.str();
auto test = tgfmt(attr.getConditionTemplate(),
&ctx.withSelf("attr").addSubst("_op", "*op"))
.str();
return !StringRef(test).contains("<no-subst-found>");
}

Expand Down
6 changes: 0 additions & 6 deletions mlir/lib/TableGen/Format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,6 @@ FmtContext &FmtContext::withBuilder(Twine subst) {
return *this;
}

FmtContext &FmtContext::withOp(Twine subst) {
builtinSubstMap[PHKind::Op] = subst.str();
return *this;
}

FmtContext &FmtContext::withSelf(Twine subst) {
builtinSubstMap[PHKind::Self] = subst.str();
return *this;
Expand All @@ -69,7 +64,6 @@ std::optional<StringRef> FmtContext::getSubstFor(StringRef placeholder) const {
FmtContext::PHKind FmtContext::getPlaceHolderKind(StringRef str) {
return StringSwitch<FmtContext::PHKind>(str)
.Case("_builder", FmtContext::PHKind::Builder)
.Case("_op", FmtContext::PHKind::Op)
.Case("_self", FmtContext::PHKind::Self)
.Case("", FmtContext::PHKind::None)
.Default(FmtContext::PHKind::Custom);
Expand Down
31 changes: 31 additions & 0 deletions mlir/lib/TableGen/Interfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
Expand Down Expand Up @@ -74,16 +75,41 @@ Interface::Interface(const llvm::Record *def) : def(def) {
assert(def->isSubClassOf("Interface") &&
"must be subclass of TableGen 'Interface' class");

// Initialize the interface methods.
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
for (llvm::Init *init : listInit->getValues())
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());

// Initialize the interface base classes.
auto *basesInit =
dyn_cast<llvm::ListInit>(def->getValueInit("baseInterfaces"));
llvm::unique_function<void(Interface)> addBaseInterfaceFn =
[&](const Interface &baseInterface) {
// Inherit any base interfaces.
for (const auto &baseBaseInterface : baseInterface.getBaseInterfaces())
addBaseInterfaceFn(baseBaseInterface);

// Add the base interface.
baseInterfaces.push_back(std::make_unique<Interface>(baseInterface));
};
for (llvm::Init *init : basesInit->getValues())
addBaseInterfaceFn(Interface(cast<llvm::DefInit>(init)->getDef()));
}

// Return the name of this interface.
StringRef Interface::getName() const {
return def->getValueAsString("cppInterfaceName");
}

// Returns this interface's name prefixed with namespaces.
std::string Interface::getFullyQualifiedName() const {
StringRef cppNamespace = getCppNamespace();
StringRef name = getName();
if (cppNamespace.empty())
return name.str();
return (cppNamespace + "::" + name).str();
}

// Return the C++ namespace of this interface.
StringRef Interface::getCppNamespace() const {
return def->getValueAsString("cppNamespace");
Expand Down Expand Up @@ -116,6 +142,11 @@ std::optional<StringRef> Interface::getExtraSharedClassDeclaration() const {
return value.empty() ? std::optional<StringRef>() : value;
}

std::optional<StringRef> Interface::getExtraClassOf() const {
auto value = def->getValueAsString("extraClassOf");
return value.empty() ? std::optional<StringRef>() : value;
}

// Return the body for this method if it has one.
std::optional<StringRef> Interface::getVerify() const {
// Only OpInterface supports the verify method.
Expand Down
14 changes: 10 additions & 4 deletions mlir/lib/TableGen/Operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,13 +711,19 @@ void Operator::populateOpStructure() {
continue;
}

// Ignore duplicates.
if (!traitSet.insert(traitInit).second)
continue;

// If this is an interface with base classes, add the bases to the
// trait list.
if (def->isSubClassOf("Interface"))
insert(def->getValueAsListInit("baseInterfaces"));

// Verify if the trait has all the dependent traits declared before
// itself.
verifyTraitValidity(def);

// Keep traits in the same order while skipping over duplicates.
if (traitSet.insert(traitInit).second)
traits.push_back(Trait::create(traitInit));
traits.push_back(Trait::create(traitInit));
}
};
insert(traitList);
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Func/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,12 @@ func.func @$invalid_function_name()
// -----

// expected-error @+1 {{arguments may only have dialect attributes}}
func.func @invalid_func_arg_attr(i1 {non_dialect_attr = 10})
func.func private @invalid_func_arg_attr(i1 {non_dialect_attr = 10})

// -----

// expected-error @+1 {{results may only have dialect attributes}}
func.func @invalid_func_result_attr() -> (i1 {non_dialect_attr = 10})
func.func private @invalid_func_result_attr() -> (i1 {non_dialect_attr = 10})

// -----

Expand Down
22 changes: 18 additions & 4 deletions mlir/test/lib/Dialect/Test/TestInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,35 @@
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaceBase.td"

// A type interface used to test the ODS generation of type interfaces.
def TestTypeInterface : TypeInterface<"TestTypeInterface"> {
// A set of type interfaces used to test interface inheritance.
def TestBaseTypeInterfacePrintTypeA : TypeInterface<"TestBaseTypeInterfacePrintTypeA"> {
let cppNamespace = "::test";
let methods = [
InterfaceMethod<"Prints the type name.",
"void", "printTypeA", (ins "::mlir::Location":$loc), [{
emitRemark(loc) << $_type << " - TestA";
}]
>,
>
];
}
def TestBaseTypeInterfacePrintTypeB
: TypeInterface<"TestBaseTypeInterfacePrintTypeB", [TestBaseTypeInterfacePrintTypeA]> {
let cppNamespace = "::test";
let methods = [
InterfaceMethod<"Prints the type name.",
"void", "printTypeB", (ins "::mlir::Location":$loc),
[{}], /*defaultImplementation=*/[{
emitRemark(loc) << $_type << " - TestB";
}]
>,
>
];
}

// A type interface used to test the ODS generation of type interfaces.
def TestTypeInterface
: TypeInterface<"TestTypeInterface", [TestBaseTypeInterfacePrintTypeB]> {
let cppNamespace = "::test";
let methods = [
InterfaceMethod<"Prints the type name.",
"void", "printTypeC", (ins "::mlir::Location":$loc)
>,
Expand Down
77 changes: 77 additions & 0 deletions mlir/test/mlir-tblgen/op-interface.td
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
// RUN: mlir-tblgen -gen-op-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
// RUN: mlir-tblgen -gen-op-interface-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s --check-prefix=OP_DECL
// RUN: mlir-tblgen -gen-op-interface-docs -I %S/../../include %s | FileCheck %s --check-prefix=DOCS

include "mlir/IR/OpBase.td"

def ExtraClassOfInterface : OpInterface<"ExtraClassOfInterface"> {
let extraClassOf = "return $_op->someOtherMethod();";
}

// DECL: class ExtraClassOfInterface
// DECL: static bool classof(::mlir::Operation * base) {
// DECL-NEXT: if (!getInterfaceFor(base))
// DECL-NEXT: return false;
// DECL-NEXT: return base->someOtherMethod();
// DECL-NEXT: }

def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {
let extraSharedClassDeclaration = [{
bool sharedMethodDeclaration() {
Expand All @@ -22,6 +34,66 @@ def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {
// DECL-NEXT: return (*static_cast<ConcreteOp *>(this)).someOtherMethod();
// DECL-NEXT: }

def TestInheritanceBaseInterface : OpInterface<"TestInheritanceBaseInterface"> {
let methods = [
InterfaceMethod<
/*desc=*/[{some function comment}],
/*retTy=*/"int",
/*methodName=*/"foo",
/*args=*/(ins "int":$input)
>
];
}
def TestInheritanceMiddleBaseInterface
: OpInterface<"TestInheritanceMiddleBaseInterface", [TestInheritanceBaseInterface]> {
let methods = [
InterfaceMethod<
/*desc=*/[{some function comment}],
/*retTy=*/"int",
/*methodName=*/"bar",
/*args=*/(ins "int":$input)
>
];
}
def TestInheritanceZDerivedInterface
: OpInterface<"TestInheritanceZDerivedInterface", [TestInheritanceMiddleBaseInterface]>;

// DECL: class TestInheritanceZDerivedInterface
// DECL: struct Concept {
// DECL: const TestInheritanceBaseInterface::Concept *implTestInheritanceBaseInterface = nullptr;
// DECL: const TestInheritanceMiddleBaseInterface::Concept *implTestInheritanceMiddleBaseInterface = nullptr;

// DECL: void initializeInterfaceConcept(::mlir::detail::InterfaceMap &interfaceMap) {
// DECL: implTestInheritanceBaseInterface = interfaceMap.lookup<TestInheritanceBaseInterface>();
// DECL: assert(implTestInheritanceBaseInterface && "`TestInheritanceZDerivedInterface` expected its base interface `TestInheritanceBaseInterface` to be registered");
// DECL: implTestInheritanceMiddleBaseInterface = interfaceMap.lookup<TestInheritanceMiddleBaseInterface>();
// DECL: assert(implTestInheritanceMiddleBaseInterface
// DECL: }

// DECL: //===----------------------------------------------------------------===//
// DECL: // Inherited from TestInheritanceBaseInterface
// DECL: //===----------------------------------------------------------------===//
// DECL: operator TestInheritanceBaseInterface () const {
// DECL: return TestInheritanceBaseInterface(*this, getImpl()->implTestInheritanceBaseInterface);
// DECL: }
// DECL: /// some function comment
// DECL: int foo(int input);

// DECL: //===----------------------------------------------------------------===//
// DECL: // Inherited from TestInheritanceMiddleBaseInterface
// DECL: //===----------------------------------------------------------------===//
// DECL: operator TestInheritanceMiddleBaseInterface () const {
// DECL: return TestInheritanceMiddleBaseInterface(*this, getImpl()->implTestInheritanceMiddleBaseInterface);
// DECL: }
// DECL: /// some function comment
// DECL: int bar(int input);

// DEF: int TestInheritanceZDerivedInterface::foo(int input) {
// DEF-NEXT: getImpl()->implTestInheritanceBaseInterface->foo(getImpl()->implTestInheritanceBaseInterface, getOperation(), input);

// DEF: int TestInheritanceZDerivedInterface::bar(int input) {
// DEF-NEXT: return getImpl()->implTestInheritanceMiddleBaseInterface->bar(getImpl()->implTestInheritanceMiddleBaseInterface, getOperation(), input);

def TestOpInterface : OpInterface<"TestOpInterface"> {
let description = [{some op interface description}];

Expand Down Expand Up @@ -72,6 +144,8 @@ def TestDialect : Dialect {

def OpInterfaceOp : Op<TestDialect, "op_interface_op", [TestOpInterface]>;

def OpInterfaceInterfacesOp : Op<TestDialect, "op_inherit_interface_op", [TestInheritanceZDerivedInterface]>;

def DeclareMethodsOp : Op<TestDialect, "declare_methods_op",
[DeclareOpInterfaceMethods<TestOpInterface>]>;

Expand Down Expand Up @@ -102,6 +176,9 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
// OP_DECL: int foo(int input);
// OP_DECL: int default_foo(int input);

// OP_DECL: class OpInterfaceInterfacesOp :
// OP_DECL-SAME: TestInheritanceBaseInterface::Trait, TestInheritanceMiddleBaseInterface::Trait, TestInheritanceZDerivedInterface::Trait

// DOCS-LABEL: {{^}}## TestOpInterface (`TestOpInterface`)
// DOCS: some op interface description

Expand Down
2 changes: 1 addition & 1 deletion mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ OpEmitter::OpEmitter(const Operator &op,
formatExtraDefinitions(op)),
staticVerifierEmitter(staticVerifierEmitter),
emitHelper(op, /*emitForOp=*/true) {
verifyCtx.withOp("(*this->getOperation())");
verifyCtx.addSubst("_op", "(*this->getOperation())");
verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");

genTraits();
Expand Down
189 changes: 139 additions & 50 deletions mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,19 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method,
/// Get an array of all OpInterface definitions but exclude those subclassing
/// "DeclareOpInterfaceMethods".
static std::vector<llvm::Record *>
getAllOpInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper) {
getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper,
StringRef name) {
std::vector<llvm::Record *> defs =
recordKeeper.getAllDerivedDefinitions("OpInterface");

llvm::erase_if(defs, [](const llvm::Record *def) {
return def->isSubClassOf("DeclareOpInterfaceMethods");
recordKeeper.getAllDerivedDefinitions((name + "Interface").str());

std::string declareName = ("Declare" + name + "InterfaceMethods").str();
llvm::erase_if(defs, [&](const llvm::Record *def) {
// Ignore any "declare methods" interfaces.
if (def->isSubClassOf(declareName))
return true;
// Ignore interfaces defined outside of the top-level file.
return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
llvm::SrcMgr.getMainFileID();
});
return defs;
}
Expand Down Expand Up @@ -101,6 +108,8 @@ class InterfaceGenerator {
StringRef interfaceBaseType;
/// The name of the typename for the value template.
StringRef valueTemplate;
/// The name of the substituion variable for the value.
StringRef substVar;
/// The format context to use for methods.
tblgen::FmtContext nonStaticMethodFmt;
tblgen::FmtContext traitMethodFmt;
Expand All @@ -110,46 +119,47 @@ class InterfaceGenerator {
/// A specialized generator for attribute interfaces.
struct AttrInterfaceGenerator : public InterfaceGenerator {
AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(records.getAllDerivedDefinitions("AttrInterface"),
os) {
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
valueType = "::mlir::Attribute";
interfaceBaseType = "AttributeInterface";
valueTemplate = "ConcreteAttr";
substVar = "_attr";
StringRef castCode = "(tablegen_opaque_val.cast<ConcreteAttr>())";
nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode);
traitMethodFmt.addSubst("_attr",
nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
traitMethodFmt.addSubst(substVar,
"(*static_cast<const ConcreteAttr *>(this))");
extraDeclsFmt.addSubst("_attr", "(*this)");
extraDeclsFmt.addSubst(substVar, "(*this)");
}
};
/// A specialized generator for operation interfaces.
struct OpInterfaceGenerator : public InterfaceGenerator {
OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllOpInterfaceDefinitions(records), os) {
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
valueType = "::mlir::Operation *";
interfaceBaseType = "OpInterface";
valueTemplate = "ConcreteOp";
substVar = "_op";
StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
nonStaticMethodFmt.addSubst("_this", "impl")
.withOp(castCode)
.addSubst(substVar, castCode)
.withSelf(castCode);
traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))");
extraDeclsFmt.withOp("(*this)");
traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))");
extraDeclsFmt.addSubst(substVar, "(*this)");
}
};
/// A specialized generator for type interfaces.
struct TypeInterfaceGenerator : public InterfaceGenerator {
TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(records.getAllDerivedDefinitions("TypeInterface"),
os) {
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
valueType = "::mlir::Type";
interfaceBaseType = "TypeInterface";
valueTemplate = "ConcreteType";
substVar = "_type";
StringRef castCode = "(tablegen_opaque_val.cast<ConcreteType>())";
nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode);
traitMethodFmt.addSubst("_type",
nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
traitMethodFmt.addSubst(substVar,
"(*static_cast<const ConcreteType *>(this))");
extraDeclsFmt.addSubst("_type", "(*this)");
extraDeclsFmt.addSubst(substVar, "(*this)");
}
};
} // namespace
Expand All @@ -163,28 +173,21 @@ static void emitInterfaceMethodDoc(const InterfaceMethod &method,
if (std::optional<StringRef> description = method.getDescription())
tblgen::emitDescriptionComment(*description, os, prefix);
}

static void emitInterfaceDef(const Interface &interface, StringRef valueType,
raw_ostream &os) {
StringRef interfaceName = interface.getName();
StringRef cppNamespace = interface.getCppNamespace();
cppNamespace.consume_front("::");

// Insert the method definitions.
bool isOpInterface = isa<OpInterface>(interface);
static void emitInterfaceDefMethods(StringRef interfaceQualName,
const Interface &interface,
StringRef valueType, const Twine &implValue,
raw_ostream &os, bool isOpInterface) {
for (auto &method : interface.getMethods()) {
emitInterfaceMethodDoc(method, os);
emitCPPType(method.getReturnType(), os);
if (!cppNamespace.empty())
os << cppNamespace << "::";
os << interfaceName << "::";
os << interfaceQualName << "::";
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
/*addConst=*/!isOpInterface);

// Forward to the method on the concrete operation type.
os << " {\n return getImpl()->" << method.getName() << '(';
os << " {\n return " << implValue << "->" << method.getName() << '(';
if (!method.isStatic()) {
os << "getImpl(), ";
os << implValue << ", ";
os << (isOpInterface ? "getOperation()" : "*this");
os << (method.arg_empty() ? "" : ", ");
}
Expand All @@ -195,6 +198,25 @@ static void emitInterfaceDef(const Interface &interface, StringRef valueType,
}
}

static void emitInterfaceDef(const Interface &interface, StringRef valueType,
raw_ostream &os) {
std::string interfaceQualNameStr = interface.getFullyQualifiedName();
StringRef interfaceQualName = interfaceQualNameStr;
interfaceQualName.consume_front("::");

// Insert the method definitions.
bool isOpInterface = isa<OpInterface>(interface);
emitInterfaceDefMethods(interfaceQualName, interface, valueType, "getImpl()",
os, isOpInterface);

// Insert the method definitions for base classes.
for (auto &base : interface.getBaseInterfaces()) {
emitInterfaceDefMethods(interfaceQualName, base, valueType,
"getImpl()->impl" + base.getName(), os,
isOpInterface);
}
}

bool InterfaceGenerator::emitInterfaceDefs() {
llvm::emitSourceFileHeader("Interface Definitions", os);

Expand All @@ -211,6 +233,7 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
os << " struct Concept {\n";

// Insert each of the pure virtual concept methods.
os << " /// The methods defined by the interface.\n";
for (auto &method : interface.getMethods()) {
os << " ";
emitCPPType(method.getReturnType(), os);
Expand All @@ -224,6 +247,33 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
[&](const InterfaceMethod::Argument &arg) { os << arg.type; });
os << ");\n";
}

// Insert a field containing a concept for each of the base interfaces.
auto baseInterfaces = interface.getBaseInterfaces();
if (!baseInterfaces.empty()) {
os << " /// The base classes of this interface.\n";
for (const auto &base : interface.getBaseInterfaces()) {
os << " const " << base.getFullyQualifiedName() << "::Concept *impl"
<< base.getName() << " = nullptr;\n";
}

// Define an "initialize" method that allows for the initialization of the
// base class concepts.
os << "\n void initializeInterfaceConcept(::mlir::detail::InterfaceMap "
"&interfaceMap) {\n";
std::string interfaceQualName = interface.getFullyQualifiedName();
for (const auto &base : interface.getBaseInterfaces()) {
StringRef baseName = base.getName();
std::string baseQualName = base.getFullyQualifiedName();
os << " impl" << baseName << " = interfaceMap.lookup<"
<< baseQualName << ">();\n"
<< " assert(impl" << baseName << " && \"`" << interfaceQualName
<< "` expected its base interface `" << baseQualName
<< "` to be registered\");\n";
}
os << " }\n";
}

os << " };\n";
}

Expand All @@ -232,9 +282,8 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
for (const char *modelClass : {"Model", "FallbackModel"}) {
os << " template<typename " << valueTemplate << ">\n";
os << " class " << modelClass << " : public Concept {\n public:\n";
os << " using Interface = " << interface.getCppNamespace()
<< (interface.getCppNamespace().empty() ? "" : "::")
<< interface.getName() << ";\n";
os << " using Interface = " << interface.getFullyQualifiedName()
<< ";\n";
os << " " << modelClass << "() : Concept{";
llvm::interleaveComma(
interface.getMethods(), os,
Expand Down Expand Up @@ -429,7 +478,7 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");

tblgen::FmtContext verifyCtx;
verifyCtx.withOp("op");
verifyCtx.addSubst("_op", "op");
os << llvm::formatv(
" static ::mlir::LogicalResult {0}(::mlir::Operation *op) ",
(interface.verifyWithRegions() ? "verifyRegionTrait"
Expand All @@ -445,6 +494,27 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
os << " };\n";
}

static void emitInterfaceDeclMethods(const Interface &interface,
raw_ostream &os, StringRef valueType,
bool isOpInterface,
tblgen::FmtContext &extraDeclsFmt) {
for (auto &method : interface.getMethods()) {
emitInterfaceMethodDoc(method, os, " ");
emitCPPType(method.getReturnType(), os << " ");
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
/*addConst=*/!isOpInterface);
os << ";\n";
}

// Emit any extra declarations.
if (std::optional<StringRef> extraDecls =
interface.getExtraClassDeclaration())
os << extraDecls->rtrim() << "\n";
if (std::optional<StringRef> extraDecls =
interface.getExtraSharedClassDeclaration())
os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n";
}

void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
Expand Down Expand Up @@ -485,21 +555,40 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {

// Insert the method declarations.
bool isOpInterface = isa<OpInterface>(interface);
for (auto &method : interface.getMethods()) {
emitInterfaceMethodDoc(method, os, " ");
emitCPPType(method.getReturnType(), os << " ");
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
/*addConst=*/!isOpInterface);
os << ";\n";
emitInterfaceDeclMethods(interface, os, valueType, isOpInterface,
extraDeclsFmt);

// Insert the method declarations for base classes.
for (auto &base : interface.getBaseInterfaces()) {
std::string baseQualName = base.getFullyQualifiedName();
os << " //"
"===---------------------------------------------------------------"
"-===//\n"
<< " // Inherited from " << baseQualName << "\n"
<< " //"
"===---------------------------------------------------------------"
"-===//\n\n";

// Allow implicit conversion to the base interface.
os << " operator " << baseQualName << " () const {\n"
<< " return " << baseQualName << "(*this, getImpl()->impl"
<< base.getName() << ");\n"
<< " }\n\n";

// Inherit the base interface's methods.
emitInterfaceDeclMethods(base, os, valueType, isOpInterface, extraDeclsFmt);
}

// Emit any extra declarations.
if (std::optional<StringRef> extraDecls =
interface.getExtraClassDeclaration())
os << *extraDecls << "\n";
if (std::optional<StringRef> extraDecls =
interface.getExtraSharedClassDeclaration())
os << tblgen::tgfmt(*extraDecls, &extraDeclsFmt);
// Emit classof code if necessary.
if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
auto extraClassOfFmt = tblgen::FmtContext();
extraClassOfFmt.addSubst(substVar, "base");
os << " static bool classof(" << valueType << " base) {\n"
<< " if (!getInterfaceFor(base))\n"
" return false;\n"
<< " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
<< "\n }\n";
}

os << "};\n";

Expand Down
3 changes: 3 additions & 0 deletions mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ static void emitModelDecl(const Availability &availability, raw_ostream &os) {
os << " template<typename ConcreteOp>\n";
os << " class " << modelClass << " : public Concept {\n"
<< " public:\n"
<< " using Interface = " << availability.getInterfaceClassName()
<< ";\n"
<< " " << availability.getQueryFnRetType() << " "
<< availability.getQueryFnName()
<< "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n"
Expand All @@ -258,6 +260,7 @@ static void emitInterfaceDecl(const Availability &availability,

StringRef cppNamespace = availability.getInterfaceClassNamespace();
NamespaceEmitter nsEmitter(os, cppNamespace);
os << "class " << interfaceName << ";\n\n";

// Emit the traits struct containing the concept and model declarations.
os << "namespace detail {\n"
Expand Down
6 changes: 0 additions & 6 deletions mlir/unittests/TableGen/FormatTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,6 @@ TEST(FormatTest, PlaceHolderFmtStrWithBuilder) {
EXPECT_THAT(result, StrEq("bbb"));
}

TEST(FormatTest, PlaceHolderFmtStrWithOp) {
FmtContext ctx;
std::string result = std::string(tgfmt("$_op", &ctx.withOp("ooo")));
EXPECT_THAT(result, StrEq("ooo"));
}

TEST(FormatTest, PlaceHolderMissingCtx) {
std::string result = std::string(tgfmt("$_op", nullptr));
EXPECT_THAT(result, StrEq("$_op<no-subst-found>"));
Expand Down