Skip to content

Commit

Permalink
[mlir] clean up transform dialect definitions, NFC
Browse files Browse the repository at this point in the history
Refactor the definition of the Transform dialect to move non-trivial
method implementations out of the .td file, and detemplatize functions
when possible while moving their implementations to a .cpp.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D135165
  • Loading branch information
ftynse committed Oct 11, 2022
1 parent bba85eb commit b586d56
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 47 deletions.
43 changes: 27 additions & 16 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H

#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
Expand All @@ -23,21 +22,7 @@ namespace detail {
/// Asserts that the operations provided as template arguments implement the
/// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic
/// assertion since interface implementations may be registered at runtime.
template <typename OpTy>
static inline void checkImplementsTransformInterface(MLIRContext *context) {
// Since the operation is being inserted into the Transform dialect and the
// dialect does not implement the interface fallback, only check for the op
// itself having the interface implementation.
RegisteredOperationName opName =
*RegisteredOperationName::lookup(OpTy::getOperationName(), context);
assert((opName.hasInterface<TransformOpInterface>() ||
opName.hasTrait<OpTrait::IsTerminator>()) &&
"non-terminator ops injected into the transform dialect must "
"implement TransformOpInterface");
assert(opName.hasInterface<MemoryEffectOpInterface>() &&
"ops injected into the transform dialect must implement "
"MemoryEffectsOpInterface");
}
void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context);

/// Asserts that the type provided as template argument implements the
/// TransformTypeInterface. This must be a dynamic assertion since interface
Expand Down Expand Up @@ -200,6 +185,25 @@ class TransformDialectExtension
bool buildOnly;
};

template <typename OpTy>
void TransformDialect::addOperationIfNotRegistered() {
StringRef name = OpTy::getOperationName();
Optional<RegisteredOperationName> opName =
RegisteredOperationName::lookup(name, getContext());
if (!opName) {
addOperations<OpTy>();
#ifndef NDEBUG
detail::checkImplementsTransformOpInterface(name, getContext());
#endif // NDEBUG
return;
}

if (opName->getTypeID() == TypeID::get<OpTy>())
return;

reportDuplicateOpRegistration(name);
}

template <typename Type>
void TransformDialect::addTypeIfNotRegistered() {
// Use the address of the parse method as a proxy for identifying whether we
Expand All @@ -210,13 +214,20 @@ void TransformDialect::addTypeIfNotRegistered() {
const ExtensionTypeParsingHook &parsingHook = it->getValue();
if (*parsingHook.target<mlir::Type (*)(AsmParser &)>() != &Type::parse)
reportDuplicateTypeRegistration(mnemonic);
else
return;
}
typePrintingHooks.try_emplace(
TypeID::get<Type>(), +[](mlir::Type type, AsmPrinter &printer) {
printer << Type::getMnemonic();
cast<Type>(type).print(printer);
});
addTypes<Type>();

#ifndef NDEBUG
detail::checkImplementsTransformTypeInterface(TypeID::get<Type>(),
getContext());
#endif // NDEBUG
}

/// A wrapper for transform dialect extensions that forces them to be
Expand Down
36 changes: 6 additions & 30 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -333,49 +333,25 @@ def Transform_Dialect : Dialect {
std::function<void (::mlir::Type, ::mlir::AsmPrinter &)>;

private:
template <typename OpTy>
void addOperationIfNotRegistered() {
Optional<RegisteredOperationName> opName =
RegisteredOperationName::lookup(OpTy::getOperationName(),
getContext());
if (!opName)
return addOperations<OpTy>();

if (opName->getTypeID() == TypeID::get<OpTy>())
return;

llvm::errs() << "error: extensible dialect operation '"
<< OpTy::getOperationName()
<< "' is already registered with a mismatching TypeID";
abort();
}

/// Registers operations specified as template parameters with this
/// dialect. Checks that they implement the required interfaces.
template <typename... OpTys>
void addOperationsChecked() {
(addOperationIfNotRegistered<OpTys>(),...);

#ifndef NDEBUG
(detail::checkImplementsTransformInterface<OpTys>(getContext()),...);
#endif // NDEBUG
(addOperationIfNotRegistered<OpTys>(), ...);
}
template <typename OpTy>
void addOperationIfNotRegistered();

/// Reports a repeated registration error of an op with the given name.
[[noreturn]] void reportDuplicateOpRegistration(StringRef opName);

/// Registers the types specified as template parameters with the
/// Transform dialect. Checks that they meet the requirements for
/// Transform IR types.
template <typename... TypeTys>
void addTypesChecked() {
(addTypeIfNotRegistered<TypeTys>(), ...);

#ifndef NDEBUG
(detail::checkImplementsTransformTypeInterface(
TypeID::get<TypeTys>(), getContext()), ...);
#endif // NDEBUG
}

/// Implementation of the type registration for a single type, should
/// not be called directly, use addTypesChecked instead.
template <typename Type>
void addTypeIfNotRegistered();

Expand Down
29 changes: 28 additions & 1 deletion mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/DialectImplementation.h"
Expand All @@ -18,6 +19,22 @@ using namespace mlir;
#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"

#ifndef NDEBUG
void transform::detail::checkImplementsTransformOpInterface(
StringRef name, MLIRContext *context) {
// Since the operation is being inserted into the Transform dialect and the
// dialect does not implement the interface fallback, only check for the op
// itself having the interface implementation.
RegisteredOperationName opName =
*RegisteredOperationName::lookup(name, context);
assert((opName.hasInterface<TransformOpInterface>() ||
opName.hasTrait<OpTrait::IsTerminator>()) &&
"non-terminator ops injected into the transform dialect must "
"implement TransformOpInterface");
assert(opName.hasInterface<MemoryEffectOpInterface>() &&
"ops injected into the transform dialect must implement "
"MemoryEffectsOpInterface");
}

void transform::detail::checkImplementsTransformTypeInterface(
TypeID typeID, MLIRContext *context) {
const auto &abstractType = AbstractType::lookup(typeID, context);
Expand Down Expand Up @@ -76,10 +93,20 @@ void transform::TransformDialect::reportDuplicateTypeRegistration(
StringRef mnemonic) {
std::string buffer;
llvm::raw_string_ostream msg(buffer);
msg << "error: extensible dialect type '" << mnemonic
msg << "extensible dialect type '" << mnemonic
<< "' is already registered with a different implementation";
msg.flush();
llvm::report_fatal_error(StringRef(buffer));
}

void transform::TransformDialect::reportDuplicateOpRegistration(
StringRef opName) {
std::string buffer;
llvm::raw_string_ostream msg(buffer);
msg << "extensible dialect operation '" << opName
<< "' is already registered with a mismatching TypeID";
msg.flush();
llvm::report_fatal_error(StringRef(buffer));
}

#include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc"

0 comments on commit b586d56

Please sign in to comment.