Skip to content

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Aug 29, 2025

Many internal functions take a ConversionPatternRewriter & or ConversionPatternRewriterImpl & as a parameter. There's only a single instance of these classes, so it's better to store the reference in a field. This commit is in preparation of another PR that will require access to ConversionPatternRewriter in additional helper functions.

Note: Public API does not change.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Aug 29, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 29, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Many internal functions take a ConversionPatternRewriter & or ConversionPatternRewriterImpl & as a parameter. There's only a single instance of these classes, so it's better to store the reference in a field. This commit is in preparation of another PR that will require access to ConversionPatternRewriter in additional helper functions.


Patch is 26.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155997.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+87-102)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b6a216adfdd25..c0685f54731d5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,9 +848,10 @@ static bool hasRewrite(R &&rewrites, Block *block) {
 namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
-  explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
+  explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
                                          const ConversionConfig &config)
-      : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {}
+      : rewriter(rewriter), config(config),
+        notifyingRewriter(rewriter.getContext(), config.listener) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -887,8 +888,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// is the tag used when describing a value within a diagnostic, e.g.
   /// "operand".
   LogicalResult remapValues(StringRef valueDiagTag,
-                            std::optional<Location> inputLoc,
-                            PatternRewriter &rewriter, ValueRange values,
+                            std::optional<Location> inputLoc, ValueRange values,
                             SmallVector<ValueVector> &remapped);
 
   /// Return "true" if the given operation is ignored, and does not need to be
@@ -918,8 +918,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
   /// Convert the types of block arguments within the given region.
   FailureOr<Block *>
-  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
-                     const TypeConverter &converter,
+  convertRegionTypes(Region *region, const TypeConverter &converter,
                      TypeConverter::SignatureConversion *entryConversion);
 
   /// Apply the given signature conversion on the given block. The new block
@@ -929,8 +928,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// translate between the origin argument types and those specified in the
   /// signature conversion.
   Block *applySignatureConversion(
-      ConversionPatternRewriter &rewriter, Block *block,
-      const TypeConverter *converter,
+      Block *block, const TypeConverter *converter,
       TypeConverter::SignatureConversion &signatureConversion);
 
   /// Replace the results of the given operation with the given values and
@@ -1060,8 +1058,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // State
   //===--------------------------------------------------------------------===//
 
-  /// MLIR context.
-  MLIRContext *context;
+  /// The rewriter that is used to perform the conversion.
+  ConversionPatternRewriter &rewriter;
 
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
@@ -1258,16 +1256,17 @@ void UnresolvedMaterializationRewrite::rollback() {
 }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
-  // Commit all rewrites.
-  IRRewriter rewriter(context, config.listener);
+  // Commit all rewrites. Use a new rewriter, so the modifications are not
+  // tracked for rollback purposes etc.
+  IRRewriter irRewriter(rewriter.getContext(), config.listener);
   // Note: New rewrites may be added during the "commit" phase and the
   // `rewrites` vector may reallocate.
   for (size_t i = 0; i < rewrites.size(); ++i)
-    rewrites[i]->commit(rewriter);
+    rewrites[i]->commit(irRewriter);
 
   // Clean up all rewrites.
   SingleEraseRewriter eraseRewriter(
-      context, /*opErasedCallback=*/[&](Operation *op) {
+      rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
         if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
           unresolvedMaterializations.erase(castOp);
       });
@@ -1412,8 +1411,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
 }
 
 LogicalResult ConversionPatternRewriterImpl::remapValues(
-    StringRef valueDiagTag, std::optional<Location> inputLoc,
-    PatternRewriter &rewriter, ValueRange values,
+    StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
     SmallVector<ValueVector> &remapped) {
   remapped.reserve(llvm::size(values));
 
@@ -1484,8 +1482,7 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
 //===----------------------------------------------------------------------===//
 
 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
-    ConversionPatternRewriter &rewriter, Region *region,
-    const TypeConverter &converter,
+    Region *region, const TypeConverter &converter,
     TypeConverter::SignatureConversion *entryConversion) {
   regionToConverter[region] = &converter;
   if (region->empty())
@@ -1500,25 +1497,23 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
     if (!conversion)
       return failure();
     // Convert the block with the computed signature.
-    applySignatureConversion(rewriter, &block, &converter, *conversion);
+    applySignatureConversion(&block, &converter, *conversion);
   }
 
   // Convert the entry block. If an entry signature conversion was provided,
   // use that one. Otherwise, compute the signature with the type converter.
   if (entryConversion)
-    return applySignatureConversion(rewriter, &region->front(), &converter,
+    return applySignatureConversion(&region->front(), &converter,
                                     *entryConversion);
   std::optional<TypeConverter::SignatureConversion> conversion =
       converter.convertBlockSignature(&region->front());
   if (!conversion)
     return failure();
-  return applySignatureConversion(rewriter, &region->front(), &converter,
-                                  *conversion);
+  return applySignatureConversion(&region->front(), &converter, *conversion);
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
-    ConversionPatternRewriter &rewriter, Block *block,
-    const TypeConverter *converter,
+    Block *block, const TypeConverter *converter,
     TypeConverter::SignatureConversion &signatureConversion) {
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   // A block cannot be converted multiple times.
@@ -2023,7 +2018,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
 ConversionPatternRewriter::ConversionPatternRewriter(
     MLIRContext *ctx, const ConversionConfig &config)
     : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+      impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
   setListener(impl.get());
 }
 
@@ -2100,7 +2095,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
   assert(!impl->wasOpReplaced(block->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->applySignatureConversion(*this, block, converter, conversion);
+  return impl->applySignatureConversion(block, converter, conversion);
 }
 
 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -2109,7 +2104,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   assert(!impl->wasOpReplaced(region->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->convertRegionTypes(*this, region, converter, entryConversion);
+  return impl->convertRegionTypes(region, converter, entryConversion);
 }
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2128,7 +2123,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
   SmallVector<ValueVector> remappedValues;
-  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
+  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
                                remappedValues)))
     return nullptr;
   assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
@@ -2141,7 +2136,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
   if (keys.empty())
     return success();
   SmallVector<ValueVector> remapped;
-  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
                                remapped)))
     return failure();
   for (const auto &values : remapped) {
@@ -2288,7 +2283,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
 
   // Remap the operands of the operation.
   SmallVector<ValueVector> remapped;
-  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
+  if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
                                       op->getOperands(), remapped))) {
     return failure();
   }
@@ -2310,7 +2305,8 @@ class OperationLegalizer {
 public:
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
-  OperationLegalizer(const ConversionTarget &targetInfo,
+  OperationLegalizer(ConversionPatternRewriter &rewriter,
+                     const ConversionTarget &targetInfo,
                      const FrozenRewritePatternSet &patterns);
 
   /// Returns true if the given operation is known to be illegal on the target.
@@ -2318,29 +2314,25 @@ class OperationLegalizer {
 
   /// Attempt to legalize the given operation. Returns success if the operation
   /// was legalized, failure otherwise.
-  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
+  LogicalResult legalize(Operation *op);
 
   /// Returns the conversion target in use by the legalizer.
   const ConversionTarget &getTarget() { return target; }
 
 private:
   /// Attempt to legalize the given operation by folding it.
-  LogicalResult legalizeWithFold(Operation *op,
-                                 ConversionPatternRewriter &rewriter);
+  LogicalResult legalizeWithFold(Operation *op);
 
   /// Attempt to legalize the given operation by applying a pattern. Returns
   /// success if the operation was legalized, failure otherwise.
-  LogicalResult legalizeWithPattern(Operation *op,
-                                    ConversionPatternRewriter &rewriter);
+  LogicalResult legalizeWithPattern(Operation *op);
 
   /// Return true if the given pattern may be applied to the given operation,
   /// false otherwise.
-  bool canApplyPattern(Operation *op, const Pattern &pattern,
-                       ConversionPatternRewriter &rewriter);
+  bool canApplyPattern(Operation *op, const Pattern &pattern);
 
   /// Legalize the resultant IR after successfully applying the given pattern.
   LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
-                                      ConversionPatternRewriter &rewriter,
                                       const RewriterState &curState,
                                       const SetVector<Operation *> &newOps,
                                       const SetVector<Operation *> &modifiedOps,
@@ -2349,18 +2341,12 @@ class OperationLegalizer {
   /// Legalizes the actions registered during the execution of a pattern.
   LogicalResult
   legalizePatternBlockRewrites(Operation *op,
-                               ConversionPatternRewriter &rewriter,
-                               ConversionPatternRewriterImpl &impl,
                                const SetVector<Block *> &insertedBlocks,
                                const SetVector<Operation *> &newOps);
   LogicalResult
-  legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
-                                   ConversionPatternRewriterImpl &impl,
-                                   const SetVector<Operation *> &newOps);
+  legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
   LogicalResult
-  legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
-                             ConversionPatternRewriterImpl &impl,
-                             const SetVector<Operation *> &modifiedOps);
+  legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
 
   //===--------------------------------------------------------------------===//
   // Cost Model
@@ -2403,6 +2389,9 @@ class OperationLegalizer {
   /// The current set of patterns that have been applied.
   SmallPtrSet<const Pattern *, 8> appliedPatterns;
 
+  /// The rewriter to use when converting operations.
+  ConversionPatternRewriter &rewriter;
+
   /// The legalization information provided by the target.
   const ConversionTarget &target;
 
@@ -2411,9 +2400,10 @@ class OperationLegalizer {
 };
 } // namespace
 
-OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
+OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
+                                       const ConversionTarget &targetInfo,
                                        const FrozenRewritePatternSet &patterns)
-    : target(targetInfo), applicator(patterns) {
+    : rewriter(rewriter), target(targetInfo), applicator(patterns) {
   // The set of patterns that can be applied to illegal operations to transform
   // them into legal ones.
   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2427,9 +2417,7 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
   return target.isIllegal(op);
 }
 
-LogicalResult
-OperationLegalizer::legalize(Operation *op,
-                             ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalize(Operation *op) {
 #ifndef NDEBUG
   const char *logLineComment =
       "//===-------------------------------------------===//\n";
@@ -2495,7 +2483,7 @@ OperationLegalizer::legalize(Operation *op,
   // is 'BeforePatterns'. 'Never' will skip this.
   const ConversionConfig &config = rewriter.getConfig();
   if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
-    if (succeeded(legalizeWithFold(op, rewriter))) {
+    if (succeeded(legalizeWithFold(op))) {
       LLVM_DEBUG({
         logSuccess(logger, "operation was folded");
         logger.startLine() << logLineComment;
@@ -2505,7 +2493,7 @@ OperationLegalizer::legalize(Operation *op,
   }
 
   // Otherwise, we need to apply a legalization pattern to this operation.
-  if (succeeded(legalizeWithPattern(op, rewriter))) {
+  if (succeeded(legalizeWithPattern(op))) {
     LLVM_DEBUG({
       logSuccess(logger, "");
       logger.startLine() << logLineComment;
@@ -2516,7 +2504,7 @@ OperationLegalizer::legalize(Operation *op,
   // If the operation can't be legalized via patterns, try to fold it in-place
   // if the folding mode is 'AfterPatterns'.
   if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
-    if (succeeded(legalizeWithFold(op, rewriter))) {
+    if (succeeded(legalizeWithFold(op))) {
       LLVM_DEBUG({
         logSuccess(logger, "operation was folded");
         logger.startLine() << logLineComment;
@@ -2541,9 +2529,7 @@ static T moveAndReset(T &obj) {
   return result;
 }
 
-LogicalResult
-OperationLegalizer::legalizeWithFold(Operation *op,
-                                     ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
   auto &rewriterImpl = rewriter.getImpl();
   LLVM_DEBUG({
     rewriterImpl.logger.startLine() << "* Fold {\n";
@@ -2577,14 +2563,14 @@ OperationLegalizer::legalizeWithFold(Operation *op,
   // An empty list of replacement values indicates that the fold was in-place.
   // As the operation changed, a new legalization needs to be attempted.
   if (replacementValues.empty())
-    return legalize(op, rewriter);
+    return legalize(op);
 
   // Insert a replacement for 'op' with the folded replacement values.
   rewriter.replaceOp(op, replacementValues);
 
   // Recursively legalize any new constant operations.
   for (Operation *newOp : newOps) {
-    if (failed(legalize(newOp, rewriter))) {
+    if (failed(legalize(newOp))) {
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to legalize generated constant '{0}'",
                             newOp->getName()));
@@ -2629,9 +2615,7 @@ reportNewIrLegalizationFatalError(const Pattern &pattern,
       llvm::join(insertedBlockNames, ", ") + "}");
 }
 
-LogicalResult
-OperationLegalizer::legalizeWithPattern(Operation *op,
-                                        ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
   auto &rewriterImpl = rewriter.getImpl();
   const ConversionConfig &config = rewriter.getConfig();
 
@@ -2663,7 +2647,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
 
   // Functor that returns if the given pattern may be applied.
   auto canApply = [&](const Pattern &pattern) {
-    bool canApply = canApplyPattern(op, pattern, rewriter);
+    bool canApply = canApplyPattern(op, pattern);
     if (canApply && config.listener)
       config.listener->notifyPatternBegin(pattern, op);
     return canApply;
@@ -2728,7 +2712,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
         moveAndReset(rewriterImpl.patternModifiedOps);
     SetVector<Block *> insertedBlocks =
         moveAndReset(rewriterImpl.patternInsertedBlocks);
-    auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
+    auto result = legalizePatternResult(op, pattern, curState, newOps,
                                         modifiedOps, insertedBlocks);
     appliedPatterns.erase(&pattern);
     if (failed(result)) {
@@ -2747,8 +2731,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
                                     onSuccess);
 }
 
-bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
-                                         ConversionPatternRewriter &rewriter) {
+bool OperationLegalizer::canApplyPattern(Operation *op,
+                                         const Pattern &pattern) {
   LLVM_DEBUG({
     auto &os = rewriter.getImpl().logger;
     os.getOStream() << "\n";
@@ -2770,8 +2754,8 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
 }
 
 LogicalResult OperationLegalizer::legalizePatternResult(
-    Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
-    const RewriterState &curState, const SetVector<Operation *> &newOps,
+    Operation *op, const Pattern &pattern, const RewriterState &curState,
+    const SetVector<Operation *> &newOps,
     const SetVector<Operation *> &modifiedOps,
     const SetVector<Block *> &insertedBlocks) {
   auto &impl = rewriter.getImpl();
@@ -2792,10 +2776,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Legalize each of the actions registered during application.
-  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
-                                          newOps)) ||
-      failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
-      failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
+  if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
+      failed(legalizePatternRootUpdates(modifiedOps)) ||
+      failed(legalizePatternCreatedOperations(newOps))) {
     return failure();
   }
 
@@ -2804,10 +2787,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
 }
 
 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
-    Operation *op, ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &impl,
-    const SetVector<Block *> &insertedBlocks,
+    Operation *op, const SetVector<Block *> &insertedBlocks,
     const SetVector<Operation *> &newOps) {
+  ConversionPatternRewriterImpl &impl = rewriter.getImpl();
   SmallPtrSet<Operation *, 16> alreadyLegalized;
 
   // If the pattern moved or created any blocks, make sure the types of block
@@ -2831,7 +2813,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
                                            "block"));
         return failure()...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 29, 2025

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Many internal functions take a ConversionPatternRewriter &amp; or ConversionPatternRewriterImpl &amp; as a parameter. There's only a single instance of these classes, so it's better to store the reference in a field. This commit is in preparation of another PR that will require access to ConversionPatternRewriter in additional helper functions.


Patch is 26.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155997.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+87-102)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b6a216adfdd25..c0685f54731d5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,9 +848,10 @@ static bool hasRewrite(R &&rewrites, Block *block) {
 namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
-  explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
+  explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
                                          const ConversionConfig &config)
-      : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {}
+      : rewriter(rewriter), config(config),
+        notifyingRewriter(rewriter.getContext(), config.listener) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -887,8 +888,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// is the tag used when describing a value within a diagnostic, e.g.
   /// "operand".
   LogicalResult remapValues(StringRef valueDiagTag,
-                            std::optional<Location> inputLoc,
-                            PatternRewriter &rewriter, ValueRange values,
+                            std::optional<Location> inputLoc, ValueRange values,
                             SmallVector<ValueVector> &remapped);
 
   /// Return "true" if the given operation is ignored, and does not need to be
@@ -918,8 +918,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
   /// Convert the types of block arguments within the given region.
   FailureOr<Block *>
-  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
-                     const TypeConverter &converter,
+  convertRegionTypes(Region *region, const TypeConverter &converter,
                      TypeConverter::SignatureConversion *entryConversion);
 
   /// Apply the given signature conversion on the given block. The new block
@@ -929,8 +928,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// translate between the origin argument types and those specified in the
   /// signature conversion.
   Block *applySignatureConversion(
-      ConversionPatternRewriter &rewriter, Block *block,
-      const TypeConverter *converter,
+      Block *block, const TypeConverter *converter,
       TypeConverter::SignatureConversion &signatureConversion);
 
   /// Replace the results of the given operation with the given values and
@@ -1060,8 +1058,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // State
   //===--------------------------------------------------------------------===//
 
-  /// MLIR context.
-  MLIRContext *context;
+  /// The rewriter that is used to perform the conversion.
+  ConversionPatternRewriter &rewriter;
 
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
@@ -1258,16 +1256,17 @@ void UnresolvedMaterializationRewrite::rollback() {
 }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
-  // Commit all rewrites.
-  IRRewriter rewriter(context, config.listener);
+  // Commit all rewrites. Use a new rewriter, so the modifications are not
+  // tracked for rollback purposes etc.
+  IRRewriter irRewriter(rewriter.getContext(), config.listener);
   // Note: New rewrites may be added during the "commit" phase and the
   // `rewrites` vector may reallocate.
   for (size_t i = 0; i < rewrites.size(); ++i)
-    rewrites[i]->commit(rewriter);
+    rewrites[i]->commit(irRewriter);
 
   // Clean up all rewrites.
   SingleEraseRewriter eraseRewriter(
-      context, /*opErasedCallback=*/[&](Operation *op) {
+      rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
         if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
           unresolvedMaterializations.erase(castOp);
       });
@@ -1412,8 +1411,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
 }
 
 LogicalResult ConversionPatternRewriterImpl::remapValues(
-    StringRef valueDiagTag, std::optional<Location> inputLoc,
-    PatternRewriter &rewriter, ValueRange values,
+    StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
     SmallVector<ValueVector> &remapped) {
   remapped.reserve(llvm::size(values));
 
@@ -1484,8 +1482,7 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
 //===----------------------------------------------------------------------===//
 
 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
-    ConversionPatternRewriter &rewriter, Region *region,
-    const TypeConverter &converter,
+    Region *region, const TypeConverter &converter,
     TypeConverter::SignatureConversion *entryConversion) {
   regionToConverter[region] = &converter;
   if (region->empty())
@@ -1500,25 +1497,23 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
     if (!conversion)
       return failure();
     // Convert the block with the computed signature.
-    applySignatureConversion(rewriter, &block, &converter, *conversion);
+    applySignatureConversion(&block, &converter, *conversion);
   }
 
   // Convert the entry block. If an entry signature conversion was provided,
   // use that one. Otherwise, compute the signature with the type converter.
   if (entryConversion)
-    return applySignatureConversion(rewriter, &region->front(), &converter,
+    return applySignatureConversion(&region->front(), &converter,
                                     *entryConversion);
   std::optional<TypeConverter::SignatureConversion> conversion =
       converter.convertBlockSignature(&region->front());
   if (!conversion)
     return failure();
-  return applySignatureConversion(rewriter, &region->front(), &converter,
-                                  *conversion);
+  return applySignatureConversion(&region->front(), &converter, *conversion);
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
-    ConversionPatternRewriter &rewriter, Block *block,
-    const TypeConverter *converter,
+    Block *block, const TypeConverter *converter,
     TypeConverter::SignatureConversion &signatureConversion) {
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   // A block cannot be converted multiple times.
@@ -2023,7 +2018,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
 ConversionPatternRewriter::ConversionPatternRewriter(
     MLIRContext *ctx, const ConversionConfig &config)
     : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+      impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
   setListener(impl.get());
 }
 
@@ -2100,7 +2095,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
   assert(!impl->wasOpReplaced(block->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->applySignatureConversion(*this, block, converter, conversion);
+  return impl->applySignatureConversion(block, converter, conversion);
 }
 
 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -2109,7 +2104,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   assert(!impl->wasOpReplaced(region->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->convertRegionTypes(*this, region, converter, entryConversion);
+  return impl->convertRegionTypes(region, converter, entryConversion);
 }
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2128,7 +2123,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
   SmallVector<ValueVector> remappedValues;
-  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
+  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
                                remappedValues)))
     return nullptr;
   assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
@@ -2141,7 +2136,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
   if (keys.empty())
     return success();
   SmallVector<ValueVector> remapped;
-  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
                                remapped)))
     return failure();
   for (const auto &values : remapped) {
@@ -2288,7 +2283,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
 
   // Remap the operands of the operation.
   SmallVector<ValueVector> remapped;
-  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
+  if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
                                       op->getOperands(), remapped))) {
     return failure();
   }
@@ -2310,7 +2305,8 @@ class OperationLegalizer {
 public:
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
-  OperationLegalizer(const ConversionTarget &targetInfo,
+  OperationLegalizer(ConversionPatternRewriter &rewriter,
+                     const ConversionTarget &targetInfo,
                      const FrozenRewritePatternSet &patterns);
 
   /// Returns true if the given operation is known to be illegal on the target.
@@ -2318,29 +2314,25 @@ class OperationLegalizer {
 
   /// Attempt to legalize the given operation. Returns success if the operation
   /// was legalized, failure otherwise.
-  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
+  LogicalResult legalize(Operation *op);
 
   /// Returns the conversion target in use by the legalizer.
   const ConversionTarget &getTarget() { return target; }
 
 private:
   /// Attempt to legalize the given operation by folding it.
-  LogicalResult legalizeWithFold(Operation *op,
-                                 ConversionPatternRewriter &rewriter);
+  LogicalResult legalizeWithFold(Operation *op);
 
   /// Attempt to legalize the given operation by applying a pattern. Returns
   /// success if the operation was legalized, failure otherwise.
-  LogicalResult legalizeWithPattern(Operation *op,
-                                    ConversionPatternRewriter &rewriter);
+  LogicalResult legalizeWithPattern(Operation *op);
 
   /// Return true if the given pattern may be applied to the given operation,
   /// false otherwise.
-  bool canApplyPattern(Operation *op, const Pattern &pattern,
-                       ConversionPatternRewriter &rewriter);
+  bool canApplyPattern(Operation *op, const Pattern &pattern);
 
   /// Legalize the resultant IR after successfully applying the given pattern.
   LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
-                                      ConversionPatternRewriter &rewriter,
                                       const RewriterState &curState,
                                       const SetVector<Operation *> &newOps,
                                       const SetVector<Operation *> &modifiedOps,
@@ -2349,18 +2341,12 @@ class OperationLegalizer {
   /// Legalizes the actions registered during the execution of a pattern.
   LogicalResult
   legalizePatternBlockRewrites(Operation *op,
-                               ConversionPatternRewriter &rewriter,
-                               ConversionPatternRewriterImpl &impl,
                                const SetVector<Block *> &insertedBlocks,
                                const SetVector<Operation *> &newOps);
   LogicalResult
-  legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
-                                   ConversionPatternRewriterImpl &impl,
-                                   const SetVector<Operation *> &newOps);
+  legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
   LogicalResult
-  legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
-                             ConversionPatternRewriterImpl &impl,
-                             const SetVector<Operation *> &modifiedOps);
+  legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
 
   //===--------------------------------------------------------------------===//
   // Cost Model
@@ -2403,6 +2389,9 @@ class OperationLegalizer {
   /// The current set of patterns that have been applied.
   SmallPtrSet<const Pattern *, 8> appliedPatterns;
 
+  /// The rewriter to use when converting operations.
+  ConversionPatternRewriter &rewriter;
+
   /// The legalization information provided by the target.
   const ConversionTarget &target;
 
@@ -2411,9 +2400,10 @@ class OperationLegalizer {
 };
 } // namespace
 
-OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
+OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
+                                       const ConversionTarget &targetInfo,
                                        const FrozenRewritePatternSet &patterns)
-    : target(targetInfo), applicator(patterns) {
+    : rewriter(rewriter), target(targetInfo), applicator(patterns) {
   // The set of patterns that can be applied to illegal operations to transform
   // them into legal ones.
   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2427,9 +2417,7 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
   return target.isIllegal(op);
 }
 
-LogicalResult
-OperationLegalizer::legalize(Operation *op,
-                             ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalize(Operation *op) {
 #ifndef NDEBUG
   const char *logLineComment =
       "//===-------------------------------------------===//\n";
@@ -2495,7 +2483,7 @@ OperationLegalizer::legalize(Operation *op,
   // is 'BeforePatterns'. 'Never' will skip this.
   const ConversionConfig &config = rewriter.getConfig();
   if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
-    if (succeeded(legalizeWithFold(op, rewriter))) {
+    if (succeeded(legalizeWithFold(op))) {
       LLVM_DEBUG({
         logSuccess(logger, "operation was folded");
         logger.startLine() << logLineComment;
@@ -2505,7 +2493,7 @@ OperationLegalizer::legalize(Operation *op,
   }
 
   // Otherwise, we need to apply a legalization pattern to this operation.
-  if (succeeded(legalizeWithPattern(op, rewriter))) {
+  if (succeeded(legalizeWithPattern(op))) {
     LLVM_DEBUG({
       logSuccess(logger, "");
       logger.startLine() << logLineComment;
@@ -2516,7 +2504,7 @@ OperationLegalizer::legalize(Operation *op,
   // If the operation can't be legalized via patterns, try to fold it in-place
   // if the folding mode is 'AfterPatterns'.
   if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
-    if (succeeded(legalizeWithFold(op, rewriter))) {
+    if (succeeded(legalizeWithFold(op))) {
       LLVM_DEBUG({
         logSuccess(logger, "operation was folded");
         logger.startLine() << logLineComment;
@@ -2541,9 +2529,7 @@ static T moveAndReset(T &obj) {
   return result;
 }
 
-LogicalResult
-OperationLegalizer::legalizeWithFold(Operation *op,
-                                     ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
   auto &rewriterImpl = rewriter.getImpl();
   LLVM_DEBUG({
     rewriterImpl.logger.startLine() << "* Fold {\n";
@@ -2577,14 +2563,14 @@ OperationLegalizer::legalizeWithFold(Operation *op,
   // An empty list of replacement values indicates that the fold was in-place.
   // As the operation changed, a new legalization needs to be attempted.
   if (replacementValues.empty())
-    return legalize(op, rewriter);
+    return legalize(op);
 
   // Insert a replacement for 'op' with the folded replacement values.
   rewriter.replaceOp(op, replacementValues);
 
   // Recursively legalize any new constant operations.
   for (Operation *newOp : newOps) {
-    if (failed(legalize(newOp, rewriter))) {
+    if (failed(legalize(newOp))) {
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to legalize generated constant '{0}'",
                             newOp->getName()));
@@ -2629,9 +2615,7 @@ reportNewIrLegalizationFatalError(const Pattern &pattern,
       llvm::join(insertedBlockNames, ", ") + "}");
 }
 
-LogicalResult
-OperationLegalizer::legalizeWithPattern(Operation *op,
-                                        ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
   auto &rewriterImpl = rewriter.getImpl();
   const ConversionConfig &config = rewriter.getConfig();
 
@@ -2663,7 +2647,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
 
   // Functor that returns if the given pattern may be applied.
   auto canApply = [&](const Pattern &pattern) {
-    bool canApply = canApplyPattern(op, pattern, rewriter);
+    bool canApply = canApplyPattern(op, pattern);
     if (canApply && config.listener)
       config.listener->notifyPatternBegin(pattern, op);
     return canApply;
@@ -2728,7 +2712,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
         moveAndReset(rewriterImpl.patternModifiedOps);
     SetVector<Block *> insertedBlocks =
         moveAndReset(rewriterImpl.patternInsertedBlocks);
-    auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
+    auto result = legalizePatternResult(op, pattern, curState, newOps,
                                         modifiedOps, insertedBlocks);
     appliedPatterns.erase(&pattern);
     if (failed(result)) {
@@ -2747,8 +2731,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
                                     onSuccess);
 }
 
-bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
-                                         ConversionPatternRewriter &rewriter) {
+bool OperationLegalizer::canApplyPattern(Operation *op,
+                                         const Pattern &pattern) {
   LLVM_DEBUG({
     auto &os = rewriter.getImpl().logger;
     os.getOStream() << "\n";
@@ -2770,8 +2754,8 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
 }
 
 LogicalResult OperationLegalizer::legalizePatternResult(
-    Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
-    const RewriterState &curState, const SetVector<Operation *> &newOps,
+    Operation *op, const Pattern &pattern, const RewriterState &curState,
+    const SetVector<Operation *> &newOps,
     const SetVector<Operation *> &modifiedOps,
     const SetVector<Block *> &insertedBlocks) {
   auto &impl = rewriter.getImpl();
@@ -2792,10 +2776,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Legalize each of the actions registered during application.
-  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
-                                          newOps)) ||
-      failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
-      failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
+  if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
+      failed(legalizePatternRootUpdates(modifiedOps)) ||
+      failed(legalizePatternCreatedOperations(newOps))) {
     return failure();
   }
 
@@ -2804,10 +2787,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
 }
 
 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
-    Operation *op, ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &impl,
-    const SetVector<Block *> &insertedBlocks,
+    Operation *op, const SetVector<Block *> &insertedBlocks,
     const SetVector<Operation *> &newOps) {
+  ConversionPatternRewriterImpl &impl = rewriter.getImpl();
   SmallPtrSet<Operation *, 16> alreadyLegalized;
 
   // If the pattern moved or created any blocks, make sure the types of block
@@ -2831,7 +2813,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
                                            "block"));
         return failure()...
[truncated]

@matthias-springer matthias-springer merged commit 49f39b3 into main Aug 29, 2025
12 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/rewriter_params branch August 29, 2025 13:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants