diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 424eb980cd33a..0dcf4daf656fd 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -374,6 +374,10 @@ class OpBuilder : public Builder { template OpTy create(Location location, Args &&... args) { OperationState state(location, OpTy::getOperationName()); + if (!state.name.getAbstractOperation()) + llvm::report_fatal_error("Building op `" + + state.name.getStringRef().str() + + "` but it isn't registered in this MLIRContext"); OpTy::build(*this, state, std::forward(args)...); auto *op = createOperation(state); auto result = dyn_cast(op); @@ -390,6 +394,10 @@ class OpBuilder : public Builder { // Create the operation without using 'createOperation' as we don't want to // insert it yet. OperationState state(location, OpTy::getOperationName()); + if (!state.name.getAbstractOperation()) + llvm::report_fatal_error("Building op `" + + state.name.getStringRef().str() + + "` but it isn't registered in this MLIRContext"); OpTy::build(*this, state, std::forward(args)...); Operation *op = Operation::create(state); diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index da0b0bd826ced..8e75bb6244493 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -85,6 +85,9 @@ class MLIRContext { /// directly. std::vector getRegisteredOperations(); + /// Return true if this operation name is registered in this context. + bool isOperationRegistered(StringRef name); + // This is effectively private given that only MLIRContext.cpp can see the // MLIRContextImpl type. MLIRContextImpl &getImpl() { return *impl; } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index bf5bd70c2b7fe..e92d54ec84f9b 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1235,7 +1235,10 @@ class Op : public OpState, static bool classof(Operation *op) { if (auto *abstractOp = op->getAbstractOperation()) return TypeID::get() == abstractOp->typeID; - return op->getName().getStringRef() == ConcreteType::getOperationName(); + assert(op->getContext()->isOperationRegistered( + ConcreteType::getOperationName()) && + "Casting attempt to an unregistered operation"); + return false; } /// This is the hook used by the AsmParser to parse the custom form of this diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 0728f294be861..da607a2319bfc 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -543,6 +543,13 @@ std::vector MLIRContext::getRegisteredOperations() { return result; } +bool MLIRContext::isOperationRegistered(StringRef name) { + // Lock access to the context registry. + ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled); + + return impl->registeredOperations.count(name); +} + void Dialect::addOperation(AbstractOperation opInfo) { assert((getNamespace().empty() || opInfo.name.split('.').first == getNamespace()) && @@ -621,8 +628,9 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) { static Dialect &lookupDialectForSymbol(MLIRContext *ctx, TypeID typeID) { auto &impl = ctx->getImpl(); auto it = impl.registeredDialectSymbols.find(typeID); - assert(it != impl.registeredDialectSymbols.end() && - "symbol is not registered."); + if (it == impl.registeredDialectSymbols.end()) + llvm::report_fatal_error( + "Trying to create a type that was not registered in this MLIRContext."); return *it->second; }