Skip to content

Commit

Permalink
Harden MLIR detection of misconfiguration when missing dialect regist…
Browse files Browse the repository at this point in the history
…ration

This changes will catch error where C++ op are used without being
registered, either through creation with the OpBuilder or when trying to
cast to the C++ op.

Differential Revision: https://reviews.llvm.org/D80651
  • Loading branch information
joker-eph committed May 28, 2020
1 parent 4b94cee commit 213c6cd
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 3 deletions.
8 changes: 8 additions & 0 deletions mlir/include/mlir/IR/Builders.h
Expand Up @@ -374,6 +374,10 @@ class OpBuilder : public Builder {
template <typename OpTy, typename... Args>
OpTy create(Location location, Args &&... args) {
OperationState state(location, OpTy::getOperationName());
if (!state.name.getAbstractOperation())
llvm::report_fatal_error("Building op `" +
state.name.getStringRef().str() +
"` but it isn't registered in this MLIRContext");
OpTy::build(*this, state, std::forward<Args>(args)...);
auto *op = createOperation(state);
auto result = dyn_cast<OpTy>(op);
Expand All @@ -390,6 +394,10 @@ class OpBuilder : public Builder {
// Create the operation without using 'createOperation' as we don't want to
// insert it yet.
OperationState state(location, OpTy::getOperationName());
if (!state.name.getAbstractOperation())
llvm::report_fatal_error("Building op `" +
state.name.getStringRef().str() +
"` but it isn't registered in this MLIRContext");
OpTy::build(*this, state, std::forward<Args>(args)...);
Operation *op = Operation::create(state);

Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/MLIRContext.h
Expand Up @@ -85,6 +85,9 @@ class MLIRContext {
/// directly.
std::vector<AbstractOperation *> getRegisteredOperations();

/// Return true if this operation name is registered in this context.
bool isOperationRegistered(StringRef name);

// This is effectively private given that only MLIRContext.cpp can see the
// MLIRContextImpl type.
MLIRContextImpl &getImpl() { return *impl; }
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/IR/OpDefinition.h
Expand Up @@ -1235,7 +1235,10 @@ class Op : public OpState,
static bool classof(Operation *op) {
if (auto *abstractOp = op->getAbstractOperation())
return TypeID::get<ConcreteType>() == abstractOp->typeID;
return op->getName().getStringRef() == ConcreteType::getOperationName();
assert(op->getContext()->isOperationRegistered(
ConcreteType::getOperationName()) &&
"Casting attempt to an unregistered operation");
return false;
}

/// This is the hook used by the AsmParser to parse the custom form of this
Expand Down
12 changes: 10 additions & 2 deletions mlir/lib/IR/MLIRContext.cpp
Expand Up @@ -543,6 +543,13 @@ std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
return result;
}

bool MLIRContext::isOperationRegistered(StringRef name) {
// Lock access to the context registry.
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);

return impl->registeredOperations.count(name);
}

void Dialect::addOperation(AbstractOperation opInfo) {
assert((getNamespace().empty() ||
opInfo.name.split('.').first == getNamespace()) &&
Expand Down Expand Up @@ -621,8 +628,9 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
static Dialect &lookupDialectForSymbol(MLIRContext *ctx, TypeID typeID) {
auto &impl = ctx->getImpl();
auto it = impl.registeredDialectSymbols.find(typeID);
assert(it != impl.registeredDialectSymbols.end() &&
"symbol is not registered.");
if (it == impl.registeredDialectSymbols.end())
llvm::report_fatal_error(
"Trying to create a type that was not registered in this MLIRContext.");
return *it->second;
}

Expand Down

0 comments on commit 213c6cd

Please sign in to comment.