Skip to content

Commit

Permalink
Refactor OperationName to use virtual tables for dispatch (NFC)
Browse files Browse the repository at this point in the history
This streamlines the implementation and makes it so that the virtual tables are in the binary instead of dynamically assembled during initialization.
The dynamic allocation size of op registration is also smaller with this
change.

Differential Revision: https://reviews.llvm.org/D141492
  • Loading branch information
joker-eph committed Jan 14, 2023
1 parent f72601a commit e055aad
Show file tree
Hide file tree
Showing 12 changed files with 328 additions and 279 deletions.
58 changes: 37 additions & 21 deletions mlir/include/mlir/IR/ExtensibleDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,15 @@ class DynamicType

/// The definition of a dynamic op. A dynamic op is an op that is defined at
/// runtime, and that can be registered at runtime by an extensible dialect (a
/// dialect inheriting ExtensibleDialect). This class stores the functions that
/// are in the OperationName class, and in addition defines the TypeID of the op
/// that will be defined.
/// Each dynamic operation definition refers to one instance of this class.
class DynamicOpDefinition {
/// dialect inheriting ExtensibleDialect). This class implements the method
/// exposed by the OperationName class, and in addition defines the TypeID of
/// the op that will be defined. Each dynamic operation definition refers to one
/// instance of this class.
class DynamicOpDefinition : public OperationName::Impl {
public:
using GetCanonicalizationPatternsFn =
llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;

/// Create a new op at runtime. The op is registered only after passing it to
/// the dialect using registerDynamicOp.
static std::unique_ptr<DynamicOpDefinition>
Expand All @@ -361,8 +364,7 @@ class DynamicOpDefinition {
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
OperationName::GetCanonicalizationPatternsFn
&&getCanonicalizationPatternsFn,
GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);

/// Returns the op typeID.
Expand Down Expand Up @@ -400,9 +402,8 @@ class DynamicOpDefinition {

/// Set the hook returning any canonicalization pattern rewrites that the op
/// supports, for use by the canonicalization pass.
void
setGetCanonicalizationPatternsFn(OperationName::GetCanonicalizationPatternsFn
&&getCanonicalizationPatterns) {
void setGetCanonicalizationPatternsFn(
GetCanonicalizationPatternsFn &&getCanonicalizationPatterns) {
getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
}

Expand All @@ -412,6 +413,29 @@ class DynamicOpDefinition {
populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}

LogicalResult foldHook(Operation *op, ArrayRef<Attribute> attrs,
SmallVectorImpl<OpFoldResult> &results) final {
return foldHookFn(op, attrs, results);
}
void getCanonicalizationPatterns(RewritePatternSet &set,
MLIRContext *context) final {
getCanonicalizationPatternsFn(set, context);
}
bool hasTrait(TypeID id) final { return false; }
OperationName::ParseAssemblyFn getParseAssemblyFn() final { return parseFn; }
void populateDefaultAttrs(const OperationName &name,
NamedAttrList &attrs) final {
populateDefaultAttrsFn(name, attrs);
}
void printAssembly(Operation *op, OpAsmPrinter &printer,
StringRef name) final {
printFn(op, printer, name);
}
LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); }
LogicalResult verifyRegionInvariants(Operation *op) final {
return verifyRegionFn(op);
}

private:
DynamicOpDefinition(
StringRef name, ExtensibleDialect *dialect,
Expand All @@ -420,26 +444,18 @@ class DynamicOpDefinition {
OperationName::ParseAssemblyFn &&parseFn,
OperationName::PrintAssemblyFn &&printFn,
OperationName::FoldHookFn &&foldHookFn,
OperationName::GetCanonicalizationPatternsFn
&&getCanonicalizationPatternsFn,
GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);

/// Unique identifier for this operation.
TypeID typeID;

/// Name of the operation.
/// The name is prefixed with the dialect name.
std::string name;

/// Dialect defining this operation.
ExtensibleDialect *dialect;
ExtensibleDialect *getdialect();

OperationName::VerifyInvariantsFn verifyFn;
OperationName::VerifyRegionInvariantsFn verifyRegionFn;
OperationName::ParseAssemblyFn parseFn;
OperationName::PrintAssemblyFn printFn;
OperationName::FoldHookFn foldHookFn;
OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn;

friend ExtensibleDialect;
Expand Down
12 changes: 1 addition & 11 deletions mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ class OpState {
MLIRContext *context) {}

/// This hook populates any unset default attrs.
static void populateDefaultAttrs(const RegisteredOperationName &,
NamedAttrList &) {}
static void populateDefaultAttrs(const OperationName &, NamedAttrList &) {}

protected:
/// If the concrete type didn't implement a custom verifier hook, just fall
Expand Down Expand Up @@ -1831,20 +1830,11 @@ class Op : public OpState, public Traits<ConcreteType>... {
return result;
}

/// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
static OperationName::GetCanonicalizationPatternsFn
getGetCanonicalizationPatternsFn() {
return &ConcreteType::getCanonicalizationPatterns;
}
/// Implementation of `GetHasTraitFn`
static OperationName::HasTraitFn getHasTraitFn() {
return
[](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
}
/// Implementation of `ParseAssemblyFn` OperationName hook.
static OperationName::ParseAssemblyFn getParseAssemblyFn() {
return &ConcreteType::parse;
}
/// Implementation of `PrintAssemblyFn` OperationName hook.
static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
if constexpr (detect_has_print<ConcreteType>::value)
Expand Down
4 changes: 1 addition & 3 deletions mlir/include/mlir/IR/Operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,9 @@ class alignas(8) Operation final

/// Sets default attributes on unset attributes.
void populateDefaultAttrs() {
if (auto registered = getRegisteredInfo()) {
NamedAttrList attrs(getAttrDictionary());
registered->populateDefaultAttrs(attrs);
name.populateDefaultAttrs(attrs);
setAttrs(attrs.getDictionary(getContext()));
}
}

//===--------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit e055aad

Please sign in to comment.