diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 15fa39bde104b..0d7722aa07ee3 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -744,8 +744,8 @@ class ConversionPatternRewriter final : public PatternRewriter { /// PatternRewriter hook for updating the given operation in-place. /// Note: These methods only track updates to the given operation itself, - /// and not nested regions. Updates to regions will still require - /// notification through other more specific hooks above. + /// and not nested regions. Updates to regions will still require notification + /// through other more specific hooks above. void startOpModification(Operation *op) override; /// PatternRewriter hook for updating the given operation in-place. diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c58b856faefb6..84e7232d326a8 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -154,14 +154,12 @@ namespace { struct RewriterState { RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations, unsigned numReplacements, unsigned numArgReplacements, - unsigned numRewrites, unsigned numIgnoredOperations, - unsigned numRootUpdates) + unsigned numRewrites, unsigned numIgnoredOperations) : numCreatedOps(numCreatedOps), numUnresolvedMaterializations(numUnresolvedMaterializations), numReplacements(numReplacements), numArgReplacements(numArgReplacements), numRewrites(numRewrites), - numIgnoredOperations(numIgnoredOperations), - numRootUpdates(numRootUpdates) {} + numIgnoredOperations(numIgnoredOperations) {} /// The current number of created operations. unsigned numCreatedOps; @@ -180,44 +178,6 @@ struct RewriterState { /// The current number of ignored operations. unsigned numIgnoredOperations; - - /// The current number of operations that were updated in place. - unsigned numRootUpdates; -}; - -//===----------------------------------------------------------------------===// -// OperationTransactionState - -/// The state of an operation that was updated by a pattern in-place. This -/// contains all of the necessary information to reconstruct an operation that -/// was updated in place. -class OperationTransactionState { -public: - OperationTransactionState() = default; - OperationTransactionState(Operation *op) - : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()), - operands(op->operand_begin(), op->operand_end()), - successors(op->successor_begin(), op->successor_end()) {} - - /// Discard the transaction state and reset the state of the original - /// operation. - void resetOperation() const { - op->setLoc(loc); - op->setAttrs(attrs); - op->setOperands(operands); - for (const auto &it : llvm::enumerate(successors)) - op->setSuccessor(it.value(), it.index()); - } - - /// Return the original operation of this state. - Operation *getOperation() const { return op; } - -private: - Operation *op; - LocationAttr loc; - DictionaryAttr attrs; - SmallVector operands; - SmallVector successors; }; //===----------------------------------------------------------------------===// @@ -754,14 +714,19 @@ namespace { class IRRewrite { public: /// The kind of the rewrite. Rewrites can be undone if the conversion fails. + /// Enum values are ordered, so that they can be used in `classof`: first all + /// block rewrites, then all operation rewrites. enum class Kind { + // Block rewrites CreateBlock, EraseBlock, InlineBlock, MoveBlock, SplitBlock, BlockTypeConversion, - MoveOperation + // Operation rewrites + MoveOperation, + ModifyOperation }; virtual ~IRRewrite() = default; @@ -992,7 +957,7 @@ class OperationRewrite : public IRRewrite { static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() >= Kind::MoveOperation && - rewrite->getKind() <= Kind::MoveOperation; + rewrite->getKind() <= Kind::ModifyOperation; } protected: @@ -1031,8 +996,48 @@ class MoveOperationRewrite : public OperationRewrite { // this operation was the only operation in the region. Operation *insertBeforeOp; }; + +/// In-place modification of an op. This rewrite is immediately reflected in +/// the IR. The previous state of the operation is stored in this object. +class ModifyOperationRewrite : public OperationRewrite { +public: + ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Operation *op) + : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op), + loc(op->getLoc()), attrs(op->getAttrDictionary()), + operands(op->operand_begin(), op->operand_end()), + successors(op->successor_begin(), op->successor_end()) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::ModifyOperation; + } + + void rollback() override { + op->setLoc(loc); + op->setAttrs(attrs); + op->setOperands(operands); + for (const auto &it : llvm::enumerate(successors)) + op->setSuccessor(it.value(), it.index()); + } + +private: + LocationAttr loc; + DictionaryAttr attrs; + SmallVector operands; + SmallVector successors; +}; } // namespace +/// Return "true" if there is an operation rewrite that matches the specified +/// rewrite type and operation among the given rewrites. +template +static bool hasRewrite(R &&rewrites, Operation *op) { + return any_of(std::move(rewrites), [&](auto &rewrite) { + auto *rewriteTy = dyn_cast(rewrite.get()); + return rewriteTy && rewriteTy->getOperation() == op; + }); +} + //===----------------------------------------------------------------------===// // ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// @@ -1184,9 +1189,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// operation was ignored. SetVector ignoredOps; - /// A transaction state for each of operations that were updated in-place. - SmallVector rootUpdates; - /// A vector of indices into `replacements` of operations that were replaced /// with values with different result types than the original operation, e.g. /// 1->N conversion of some kind. @@ -1238,10 +1240,6 @@ static void detachNestedAndErase(Operation *op) { } void ConversionPatternRewriterImpl::discardRewrites() { - // Reset any operations that were updated in place. - for (auto &state : rootUpdates) - state.resetOperation(); - undoRewrites(); // Remove any newly created ops. @@ -1316,15 +1314,10 @@ void ConversionPatternRewriterImpl::applyRewrites() { RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(createdOps.size(), unresolvedMaterializations.size(), replacements.size(), argReplacements.size(), - rewrites.size(), ignoredOps.size(), rootUpdates.size()); + rewrites.size(), ignoredOps.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { - // Reset any operations that were updated in place. - for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i) - rootUpdates[i].resetOperation(); - rootUpdates.resize(state.numRootUpdates); - // Reset any replaced arguments. for (BlockArgument replacedArg : llvm::drop_begin(argReplacements, state.numArgReplacements)) @@ -1750,7 +1743,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) { #ifndef NDEBUG impl->pendingRootUpdates.insert(op); #endif - impl->rootUpdates.emplace_back(op); + impl->appendRewrite(op); } void ConversionPatternRewriter::finalizeOpModification(Operation *op) { @@ -1769,13 +1762,15 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) { "operation did not have a pending in-place update"); #endif // Erase the last update for this operation. - auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; }; - auto &rootUpdates = impl->rootUpdates; - auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp); - assert(it != rootUpdates.rend() && "no root update started on op"); - (*it).resetOperation(); - int updateIdx = std::prev(rootUpdates.rend()) - it; - rootUpdates.erase(rootUpdates.begin() + updateIdx); + auto it = llvm::find_if( + llvm::reverse(impl->rewrites), [&](std::unique_ptr &rewrite) { + auto *modifyRewrite = dyn_cast(rewrite.get()); + return modifyRewrite && modifyRewrite->getOperation() == op; + }); + assert(it != impl->rewrites.rend() && "no root update started on op"); + (*it)->rollback(); + int updateIdx = std::prev(impl->rewrites.rend()) - it; + impl->rewrites.erase(impl->rewrites.begin() + updateIdx); } detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { @@ -2059,6 +2054,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // Functor that cleans up the rewriter state after a pattern failed to match. RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { + assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); LLVM_DEBUG({ logFailure(rewriterImpl.logger, "pattern failed to match"); if (rewriterImpl.notifyCallback) { @@ -2076,6 +2072,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // Functor that performs additional legalization when a pattern is // successfully applied. auto onSuccess = [&](const Pattern &pattern) { + assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); auto result = legalizePatternResult(op, pattern, rewriter, curState); appliedPatterns.erase(&pattern); if (failed(result)) @@ -2118,7 +2115,6 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, #ifndef NDEBUG assert(impl.pendingRootUpdates.empty() && "dangling root updates"); -#endif // Check that the root was either replaced or updated in place. auto replacedRoot = [&] { @@ -2127,14 +2123,12 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, [op](auto &it) { return it.first == op; }); }; auto updatedRootInPlace = [&] { - return llvm::any_of( - llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates), - [op](auto &state) { return state.getOperation() == op; }); + return hasRewrite( + llvm::drop_begin(impl.rewrites, curState.numRewrites), op); }; - (void)replacedRoot; - (void)updatedRootInPlace; assert((replacedRoot() || updatedRootInPlace()) && "expected pattern to replace the root operation"); +#endif // NDEBUG // Legalize each of the actions registered during application. RewriterState newState = impl.getCurrentState(); @@ -2221,8 +2215,11 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations( LogicalResult OperationLegalizer::legalizePatternRootUpdates( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, RewriterState &state, RewriterState &newState) { - for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) { - Operation *op = impl.rootUpdates[i].getOperation(); + for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { + auto *rewrite = dyn_cast(impl.rewrites[i].get()); + if (!rewrite) + continue; + Operation *op = rewrite->getOperation(); if (failed(legalize(op, rewriter))) { LLVM_DEBUG(logFailure( impl.logger, "failed to legalize operation updated in-place '{0}'", @@ -3562,7 +3559,8 @@ mlir::applyPartialConversion(Operation *op, const ConversionTarget &target, // Full Conversion LogicalResult -mlir::applyFullConversion(ArrayRef ops, const ConversionTarget &target, +mlir::applyFullConversion(ArrayRef ops, + const ConversionTarget &target, const FrozenRewritePatternSet &patterns) { OperationConverter opConverter(target, patterns, OpConversionMode::Full); return opConverter.convertOperations(ops);