Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms][NFC] Decouple ConversionPatternRewriterImpl from ConversionPatternRewriter #82333

Merged

Conversation

matthias-springer
Copy link
Member

ConversionPatternRewriterImpl no longer maintains a reference to the respective ConversionPatternRewriter. An MLIRContext is sufficient. This commit simplifies the internal state of ConversionPatternRewriterImpl.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Feb 20, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 20, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

ConversionPatternRewriterImpl no longer maintains a reference to the respective ConversionPatternRewriter. An MLIRContext is sufficient. This commit simplifies the internal state of ConversionPatternRewriterImpl.


Full diff: https://github.com/llvm/llvm-project/pull/82333.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+21-23)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 30fc2298b3deb3..ec97a4247658b8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -725,10 +725,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
 namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
-  explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter,
+  explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
-        config(config) {}
+      : eraseRewriter(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -823,8 +822,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        Type origOutputType,
                                        const TypeConverter *converter);
 
-  Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
-                                               Location loc, ValueRange inputs,
+  Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
+                                               ValueRange inputs,
                                                Type origOutputType,
                                                Type outputType,
                                                const TypeConverter *converter);
@@ -903,8 +902,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // State
   //===--------------------------------------------------------------------===//
 
-  PatternRewriter &rewriter;
-
   /// This rewriter must be used for erasing ops/blocks.
   SingleEraseRewriter eraseRewriter;
 
@@ -1008,8 +1005,12 @@ void BlockTypeConversionRewrite::rollback() {
 
 LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
     function_ref<Operation *(Value)> findLiveUser) {
+  auto builder = OpBuilder::atBlockBegin(block, /*listener=*/&rewriterImpl);
+
   // Process the remapping for each of the original arguments.
   for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
+    OpBuilder::InsertionGuard g(builder);
+
     // If the type of this argument changed and the argument is still live, we
     // need to materialize a conversion.
     BlockArgument origArg = origBlock->getArgument(i);
@@ -1021,14 +1022,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
 
     Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
     bool isDroppedArg = replacementValue == origArg;
-    if (isDroppedArg)
-      rewriterImpl.rewriter.setInsertionPointToStart(getBlock());
-    else
-      rewriterImpl.rewriter.setInsertionPointAfterValue(replacementValue);
+    if (!isDroppedArg)
+      builder.setInsertionPointAfterValue(replacementValue);
     Value newArg;
     if (converter) {
       newArg = converter->materializeSourceConversion(
-          rewriterImpl.rewriter, origArg.getLoc(), origArg.getType(),
+          builder, origArg.getLoc(), origArg.getType(),
           isDroppedArg ? ValueRange() : ValueRange(replacementValue));
       assert((!newArg || newArg.getType() == origArg.getType()) &&
              "materialization hook did not provide a value of the expected "
@@ -1293,6 +1292,8 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
     Block *block, const TypeConverter *converter,
     TypeConverter::SignatureConversion &signatureConversion) {
+  MLIRContext *ctx = block->getParentOp()->getContext();
+
   // If no arguments are being changed or added, there is nothing to do.
   unsigned origArgCount = block->getNumArguments();
   auto convertedTypes = signatureConversion.getConvertedTypes();
@@ -1309,7 +1310,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 
   // Map all new arguments to the location of the argument they originate from.
   SmallVector<Location> newLocs(convertedTypes.size(),
-                                rewriter.getUnknownLoc());
+                                Builder(ctx).getUnknownLoc());
   for (unsigned i = 0; i < origArgCount; ++i) {
     auto inputMap = signatureConversion.getInputMapping(i);
     if (!inputMap || inputMap->replacementValue)
@@ -1328,8 +1329,6 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
   argInfo.resize(origArgCount);
 
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(newBlock);
   for (unsigned i = 0; i != origArgCount; ++i) {
     auto inputMap = signatureConversion.getInputMapping(i);
     if (!inputMap)
@@ -1372,7 +1371,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
         outputType = legalOutputType;
 
       newArg = buildUnresolvedArgumentMaterialization(
-          rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
+          newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
           converter);
     }
 
@@ -1410,12 +1409,11 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   return convertOp.getResult(0);
 }
 Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
-    PatternRewriter &rewriter, Location loc, ValueRange inputs,
-    Type origOutputType, Type outputType, const TypeConverter *converter) {
-  return buildUnresolvedMaterialization(
-      MaterializationKind::Argument, rewriter.getInsertionBlock(),
-      rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
-      converter);
+    Block *block, Location loc, ValueRange inputs, Type origOutputType,
+    Type outputType, const TypeConverter *converter) {
+  return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
+                                        block->begin(), loc, inputs, outputType,
+                                        origOutputType, converter);
 }
 Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
     Location loc, Value input, Type outputType,
@@ -1527,7 +1525,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
 ConversionPatternRewriter::ConversionPatternRewriter(
     MLIRContext *ctx, const ConversionConfig &config)
     : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
+      impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
   setListener(impl.get());
 }
 

@matthias-springer matthias-springer force-pushed the users/matthias-springer/conversion_config branch 2 times, most recently from 577b5eb to 58e2c18 Compare February 23, 2024 09:37
Base automatically changed from users/matthias-springer/conversion_config to main February 23, 2024 10:28
… `ConversionPatternRewriter`

`ConversionPatternRewriterImpl` no longer maintains a reference to the respective `ConversionPatternRewriter`. An `MLIRContext` is sufficient. This commit simplifies the internal state of `ConversionPatternRewriterImpl`.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/decouple_conversion_rewriter branch from 362813d to 70920e8 Compare February 23, 2024 10:41
@matthias-springer matthias-springer merged commit 7bb08ee into main Feb 23, 2024
3 of 4 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/decouple_conversion_rewriter branch February 23, 2024 10:55
matthias-springer added a commit to matthias-springer/llvm-project that referenced this pull request Feb 23, 2024
This is a follow-up to llvm#82333. It possible that the target block of a
`BlockTypeConversionRewrite` is detached, so the `MLIRContext` cannot be
taken from the block.
matthias-springer added a commit to matthias-springer/llvm-project that referenced this pull request Feb 23, 2024
This is a follow-up to llvm#82333. It possible that the target block of a
`BlockTypeConversionRewrite` is detached, so the `MLIRContext` cannot be
taken from the block.
matthias-springer added a commit that referenced this pull request Feb 23, 2024
This is a follow-up to #82333. It is possible that the target block of a
`BlockTypeConversionRewrite` is detached, so the `MLIRContext` cannot be
taken from the block.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants