Skip to content

Commit edc6c0e

Browse files
committed
[mlir] Refactor AbstractOperation and OperationName
The current implementation is quite clunky; OperationName stores either an Identifier or an AbstractOperation that corresponds to an operation. This has several problems: * OperationNames created before and after an operation are registered are different * Accessing the identifier name/dialect/etc. from an OperationName are overly branchy - they need to dyn_cast a PointerUnion to check the state This commit refactors this such that we create a single information struct for every operation name, even operations that aren't registered yet. When an OperationName is created for an unregistered operation, we only populate the name field. When the operation is registered, we populate the remaining fields. With this we now have two new classes: OperationName and RegisteredOperationName. These both point to the same underlying operation information struct, but only RegisteredOperationName can assume that the operation is actually registered. This leads to a much cleaner API, and we can also move some AbstractOperation functionality directly to OperationName. Differential Revision: https://reviews.llvm.org/D114049
1 parent 286094a commit edc6c0e

File tree

20 files changed

+455
-432
lines changed

20 files changed

+455
-432
lines changed

mlir/include/mlir/IR/Builders.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,8 @@ class OpBuilder : public Builder {
408408

409409
private:
410410
/// Helper for sanity checking preconditions for create* methods below.
411-
void checkHasAbstractOperation(const OperationName &name) {
412-
if (LLVM_UNLIKELY(!name.getAbstractOperation()))
411+
void checkHasRegisteredInfo(const OperationName &name) {
412+
if (LLVM_UNLIKELY(!name.isRegistered()))
413413
llvm::report_fatal_error(
414414
"Building op `" + name.getStringRef() +
415415
"` but it isn't registered in this MLIRContext: the dialect may not "
@@ -423,7 +423,7 @@ class OpBuilder : public Builder {
423423
template <typename OpTy, typename... Args>
424424
OpTy create(Location location, Args &&...args) {
425425
OperationState state(location, OpTy::getOperationName());
426-
checkHasAbstractOperation(state.name);
426+
checkHasRegisteredInfo(state.name);
427427
OpTy::build(*this, state, std::forward<Args>(args)...);
428428
auto *op = createOperation(state);
429429
auto result = dyn_cast<OpTy>(op);
@@ -440,7 +440,7 @@ class OpBuilder : public Builder {
440440
// Create the operation without using 'createOperation' as we don't want to
441441
// insert it yet.
442442
OperationState state(location, OpTy::getOperationName());
443-
checkHasAbstractOperation(state.name);
443+
checkHasRegisteredInfo(state.name);
444444
OpTy::build(*this, state, std::forward<Args>(args)...);
445445
Operation *op = Operation::create(state);
446446

mlir/include/mlir/IR/Dialect.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class Dialect {
114114

115115
/// Return the hook to parse an operation registered to this dialect, if any.
116116
/// By default this will lookup for registered operations and return the
117-
/// `parse()` method registered on the AbstractOperation. Dialects can
117+
/// `parse()` method registered on the RegisteredOperationName. Dialects can
118118
/// override this behavior and handle unregistered operations as well.
119119
virtual Optional<ParseOpHook> getParseOperationHook(StringRef opName) const;
120120

@@ -194,7 +194,7 @@ class Dialect {
194194
///
195195
template <typename... Args> void addOperations() {
196196
(void)std::initializer_list<int>{
197-
0, (AbstractOperation::insert<Args>(*this), 0)...};
197+
0, (RegisteredOperationName::insert<Args>(*this), 0)...};
198198
}
199199

200200
/// Register a set of type classes with this dialect.

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ class ThreadPool;
2020
} // end namespace llvm
2121

2222
namespace mlir {
23-
class AbstractOperation;
2423
class DebugActionManager;
2524
class DiagnosticEngine;
2625
class Dialect;
2726
class DialectRegistry;
2827
class InFlightDiagnostic;
2928
class Location;
3029
class MLIRContextImpl;
30+
class RegisteredOperationName;
3131
class StorageUniquer;
3232

3333
/// MLIRContext is the top-level object for a collection of MLIR operations. It
@@ -172,7 +172,7 @@ class MLIRContext {
172172
/// Return information about all registered operations. This isn't very
173173
/// efficient: typically you should ask the operations about their properties
174174
/// directly.
175-
std::vector<AbstractOperation *> getRegisteredOperations();
175+
std::vector<RegisteredOperationName> getRegisteredOperations();
176176

177177
/// Return true if this operation name is registered in this context.
178178
bool isOperationRegistered(StringRef name);

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ class OpState {
191191
Operation *state;
192192

193193
/// Allow access to internal hook implementation methods.
194-
friend AbstractOperation;
194+
friend RegisteredOperationName;
195195
};
196196

197197
// Allow comparing operators.
@@ -1585,8 +1585,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
15851585

15861586
/// Return true if this "op class" can match against the specified operation.
15871587
static bool classof(Operation *op) {
1588-
if (auto *abstractOp = op->getAbstractOperation())
1589-
return TypeID::get<ConcreteType>() == abstractOp->typeID;
1588+
if (auto info = op->getRegisteredInfo())
1589+
return TypeID::get<ConcreteType>() == info->getTypeID();
15901590
#ifndef NDEBUG
15911591
if (op->getName().getStringRef() == ConcreteType::getOperationName())
15921592
llvm::report_fatal_error(
@@ -1628,13 +1628,13 @@ class Op : public OpState, public Traits<ConcreteType>... {
16281628
/// for the concrete operation.
16291629
template <typename... Models>
16301630
static void attachInterface(MLIRContext &context) {
1631-
AbstractOperation *abstract = AbstractOperation::lookupMutable(
1631+
Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
16321632
ConcreteType::getOperationName(), &context);
1633-
if (!abstract)
1633+
if (!info)
16341634
llvm::report_fatal_error(
16351635
"Attempting to attach an interface to an unregistered operation " +
16361636
ConcreteType::getOperationName() + ".");
1637-
abstract->interfaceMap.insert<Models...>();
1637+
info->attachInterface<Models...>();
16381638
}
16391639

16401640
private:
@@ -1673,10 +1673,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
16731673
return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
16741674
}
16751675

1676-
/// Return the internal implementations of each of the AbstractOperation
1676+
/// Return the internal implementations of each of the OperationName
16771677
/// hooks.
1678-
/// Implementation of `FoldHookFn` AbstractOperation hook.
1679-
static AbstractOperation::FoldHookFn getFoldHookFn() {
1678+
/// Implementation of `FoldHookFn` OperationName hook.
1679+
static OperationName::FoldHookFn getFoldHookFn() {
16801680
return getFoldHookFnImpl<ConcreteType>();
16811681
}
16821682
/// The internal implementation of `getFoldHookFn` above that is invoked if
@@ -1685,7 +1685,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
16851685
static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
16861686
Traits<ConcreteOpT>...>::value &&
16871687
detect_has_single_result_fold<ConcreteOpT>::value,
1688-
AbstractOperation::FoldHookFn>
1688+
OperationName::FoldHookFn>
16891689
getFoldHookFnImpl() {
16901690
return [](Operation *op, ArrayRef<Attribute> operands,
16911691
SmallVectorImpl<OpFoldResult> &results) {
@@ -1698,7 +1698,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
16981698
static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
16991699
Traits<ConcreteOpT>...>::value &&
17001700
detect_has_fold<ConcreteOpT>::value,
1701-
AbstractOperation::FoldHookFn>
1701+
OperationName::FoldHookFn>
17021702
getFoldHookFnImpl() {
17031703
return [](Operation *op, ArrayRef<Attribute> operands,
17041704
SmallVectorImpl<OpFoldResult> &results) {
@@ -1710,7 +1710,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
17101710
template <typename ConcreteOpT>
17111711
static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
17121712
!detect_has_fold<ConcreteOpT>::value,
1713-
AbstractOperation::FoldHookFn>
1713+
OperationName::FoldHookFn>
17141714
getFoldHookFnImpl() {
17151715
return [](Operation *op, ArrayRef<Attribute> operands,
17161716
SmallVectorImpl<OpFoldResult> &results) {
@@ -1754,29 +1754,29 @@ class Op : public OpState, public Traits<ConcreteType>... {
17541754
return result;
17551755
}
17561756

1757-
/// Implementation of `GetCanonicalizationPatternsFn` AbstractOperation hook.
1758-
static AbstractOperation::GetCanonicalizationPatternsFn
1757+
/// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
1758+
static OperationName::GetCanonicalizationPatternsFn
17591759
getGetCanonicalizationPatternsFn() {
17601760
return &ConcreteType::getCanonicalizationPatterns;
17611761
}
17621762
/// Implementation of `GetHasTraitFn`
1763-
static AbstractOperation::HasTraitFn getHasTraitFn() {
1763+
static OperationName::HasTraitFn getHasTraitFn() {
17641764
return
17651765
[](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
17661766
}
1767-
/// Implementation of `ParseAssemblyFn` AbstractOperation hook.
1768-
static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() {
1767+
/// Implementation of `ParseAssemblyFn` OperationName hook.
1768+
static OperationName::ParseAssemblyFn getParseAssemblyFn() {
17691769
return &ConcreteType::parse;
17701770
}
1771-
/// Implementation of `PrintAssemblyFn` AbstractOperation hook.
1772-
static AbstractOperation::PrintAssemblyFn getPrintAssemblyFn() {
1771+
/// Implementation of `PrintAssemblyFn` OperationName hook.
1772+
static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
17731773
return getPrintAssemblyFnImpl<ConcreteType>();
17741774
}
17751775
/// The internal implementation of `getPrintAssemblyFn` that is invoked when
17761776
/// the concrete operation does not define a `print` method.
17771777
template <typename ConcreteOpT>
17781778
static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
1779-
AbstractOperation::PrintAssemblyFn>
1779+
OperationName::PrintAssemblyFn>
17801780
getPrintAssemblyFnImpl() {
17811781
return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
17821782
return OpState::print(op, printer);
@@ -1786,7 +1786,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
17861786
/// the concrete operation defines a `print` method.
17871787
template <typename ConcreteOpT>
17881788
static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
1789-
AbstractOperation::PrintAssemblyFn>
1789+
OperationName::PrintAssemblyFn>
17901790
getPrintAssemblyFnImpl() {
17911791
return &printAssembly;
17921792
}
@@ -1795,8 +1795,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
17951795
OpState::printOpName(op, p, defaultDialect);
17961796
return cast<ConcreteType>(op).print(p);
17971797
}
1798-
/// Implementation of `VerifyInvariantsFn` AbstractOperation hook.
1799-
static AbstractOperation::VerifyInvariantsFn getVerifyInvariantsFn() {
1798+
/// Implementation of `VerifyInvariantsFn` OperationName hook.
1799+
static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
18001800
return &verifyInvariants;
18011801
}
18021802

@@ -1816,7 +1816,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
18161816
}
18171817

18181818
/// Allow access to internal implementation methods.
1819-
friend AbstractOperation;
1819+
friend RegisteredOperationName;
18201820
};
18211821

18221822
/// This class represents the base of an operation interface. See the definition
@@ -1836,22 +1836,22 @@ class OpInterface
18361836
protected:
18371837
/// Returns the impl interface instance for the given operation.
18381838
static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
1839-
// Access the raw interface from the abstract operation.
1840-
auto *abstractOp = op->getAbstractOperation();
1841-
if (abstractOp) {
1842-
if (auto *opIface = abstractOp->getInterface<ConcreteType>())
1839+
OperationName name = op->getName();
1840+
1841+
// Access the raw interface from the operation info.
1842+
if (Optional<RegisteredOperationName> rInfo = name.getRegisteredInfo()) {
1843+
if (auto *opIface = rInfo->getInterface<ConcreteType>())
18431844
return opIface;
18441845
// Fallback to the dialect to provide it with a chance to implement this
18451846
// interface for this operation.
1846-
return abstractOp->dialect.getRegisteredInterfaceForOp<ConcreteType>(
1847+
return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>(
18471848
op->getName());
18481849
}
18491850
// Fallback to the dialect to provide it with a chance to implement this
18501851
// interface for this operation.
1851-
Dialect *dialect = op->getName().getDialect();
1852-
return dialect ? dialect->getRegisteredInterfaceForOp<ConcreteType>(
1853-
op->getName())
1854-
: nullptr;
1852+
if (Dialect *dialect = name.getDialect())
1853+
return dialect->getRegisteredInterfaceForOp<ConcreteType>(name);
1854+
return nullptr;
18551855
}
18561856

18571857
/// Allow access to `getInterfaceFor`.

mlir/include/mlir/IR/Operation.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,14 @@ class alignas(8) Operation final
5757
OperationName getName() { return name; }
5858

5959
/// If this operation has a registered operation description, return it.
60-
/// Otherwise return null.
61-
const AbstractOperation *getAbstractOperation() {
62-
return getName().getAbstractOperation();
60+
/// Otherwise return None.
61+
Optional<RegisteredOperationName> getRegisteredInfo() {
62+
return getName().getRegisteredInfo();
6363
}
6464

6565
/// Returns true if this operation has a registered operation description,
6666
/// otherwise false.
67-
bool isRegistered() { return getAbstractOperation(); }
67+
bool isRegistered() { return getName().isRegistered(); }
6868

6969
/// Remove this operation from its parent block and delete it.
7070
void erase();
@@ -468,16 +468,14 @@ class alignas(8) Operation final
468468
/// Returns true if the operation was registered with a particular trait, e.g.
469469
/// hasTrait<OperandsAreSignlessIntegerLike>().
470470
template <template <typename T> class Trait> bool hasTrait() {
471-
const AbstractOperation *abstractOp = getAbstractOperation();
472-
return abstractOp ? abstractOp->hasTrait<Trait>() : false;
471+
return name.hasTrait<Trait>();
473472
}
474473

475-
/// Returns true if the operation is *might* have the provided trait. This
474+
/// Returns true if the operation *might* have the provided trait. This
476475
/// means that either the operation is unregistered, or it was registered with
477476
/// the provide trait.
478477
template <template <typename T> class Trait> bool mightHaveTrait() {
479-
const AbstractOperation *abstractOp = getAbstractOperation();
480-
return abstractOp ? abstractOp->hasTrait<Trait>() : true;
478+
return name.mightHaveTrait<Trait>();
481479
}
482480

483481
//===--------------------------------------------------------------------===//

0 commit comments

Comments
 (0)