Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms] Support replaceAllUsesWith in dialect conversion #84725

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

matthias-springer
Copy link
Member

This commit adds support for RewriterBase::replaceAllUsesWith to the dialect conversion. Uses are not immediately replaced, but in a delayed fashion during the "commit" phase. No type conversions are performed; this is consistent with ConversionPatternRewriter::replaceUsesOfBlockArgument.

  • RewriterBase::replaceAllUsesWith is now virtual, so that it can be overridden in the dialect conversion. Note: RewriterBase::replaceOp can now be turned into a non-virtual function in a follow-up commit.
  • ConversionPatternRewriter::replaceUsesOfBlockArgument is generalized to ConversionPatternRewriter::replaceAllUsesWith, following the same implementation strategy.
  • A new kind of "IR rewrite" is added: ValueRewrite with ReplaceAllUsesRewrite (replacing ReplaceBlockArgRewrite) as the only value rewrite for now.
  • replacedOps is renamed to erasedOps to better capture its meaning.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 11, 2024

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit adds support for RewriterBase::replaceAllUsesWith to the dialect conversion. Uses are not immediately replaced, but in a delayed fashion during the "commit" phase. No type conversions are performed; this is consistent with ConversionPatternRewriter::replaceUsesOfBlockArgument.

  • RewriterBase::replaceAllUsesWith is now virtual, so that it can be overridden in the dialect conversion. Note: RewriterBase::replaceOp can now be turned into a non-virtual function in a follow-up commit.
  • ConversionPatternRewriter::replaceUsesOfBlockArgument is generalized to ConversionPatternRewriter::replaceAllUsesWith, following the same implementation strategy.
  • A new kind of "IR rewrite" is added: ValueRewrite with ReplaceAllUsesRewrite (replacing ReplaceBlockArgRewrite) as the only value rewrite for now.
  • replacedOps is renamed to erasedOps to better capture its meaning.

Patch is 25.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84725.diff

8 Files Affected:

  • (modified) mlir/include/mlir/IR/PatternMatch.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+5-3)
  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+1-1)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+121-88)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+18)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+1)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+24-3)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2be1e2e2b40276..3e11e00b9d4b40 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -614,7 +614,7 @@ class RewriterBase : public OpBuilder {
 
   /// Find uses of `from` and replace them with `to`. Also notify the listener
   /// about every in-place op modification (for every use that was replaced).
-  void replaceAllUsesWith(Value from, Value to) {
+  virtual void replaceAllUsesWith(Value from, Value to) {
     for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
       Operation *op = operand.getOwner();
       modifyOpInPlace(op, [&]() { operand.set(to); });
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 83198c9b0db545..1797ee0876e437 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -697,9 +697,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
       Region *region, const TypeConverter &converter,
       ArrayRef<TypeConverter::SignatureConversion> blockConversions);
 
-  /// Replace all the uses of the block argument `from` with value `to`.
-  void replaceUsesOfBlockArgument(BlockArgument from, Value to);
-
   /// Return the converted value of 'key' with a type defined by the type
   /// converter of the currently executing pattern. Return nullptr in the case
   /// of failure, the remapped value otherwise.
@@ -720,6 +717,11 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// patterns even if a failure is encountered during the rewrite step.
   bool canRecoverFromRewriteFailure() const override { return true; }
 
+  /// Find uses of `from` and replace them with `to`.
+  ///
+  /// Note: This function does not convert types.
+  void replaceAllUsesWith(Value from, Value to) override;
+
   /// PatternRewriter hook for replacing an operation.
   void replaceOp(Operation *op, ValueRange newValues) override;
 
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 53b44aa3241bb1..d7ed9a196e8938 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -310,7 +310,7 @@ static void modifyFuncOpToUseBarePtrCallingConv(
     Location loc = funcOp.getLoc();
     auto placeholder = rewriter.create<LLVM::UndefOp>(
         loc, typeConverter.convertType(memrefTy));
-    rewriter.replaceUsesOfBlockArgument(arg, placeholder);
+    rewriter.replaceAllUsesWith(arg, placeholder);
 
     Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
                                                    memrefTy, arg);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 73d418cb841327..c6d2ddac9dbb19 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -201,7 +201,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
           llvmFuncOp.getBody().getArgument(remapping->inputNo);
       auto placeholder = rewriter.create<LLVM::UndefOp>(
           loc, getTypeConverter()->convertType(memrefTy));
-      rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
+      rewriter.replaceAllUsesWith(newArg, placeholder);
       Value desc = MemRefDescriptor::fromStaticShape(
           rewriter, loc, *getTypeConverter(), memrefTy, newArg);
       rewriter.replaceOp(placeholder, {desc});
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c1a261eab8487d..e4a022b7a0288b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,9 +153,9 @@ namespace {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numReplacedOps)
+                unsigned numErasedOps)
       : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
-        numReplacedOps(numReplacedOps) {}
+        numErasedOps(numErasedOps) {}
 
   /// The current number of rewrites performed.
   unsigned numRewrites;
@@ -163,8 +163,8 @@ struct RewriterState {
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
 
-  /// The current number of replaced ops that are scheduled for erasure.
-  unsigned numReplacedOps;
+  /// The current number of ops that are scheduled for erasure.
+  unsigned numErasedOps;
 };
 
 //===----------------------------------------------------------------------===//
@@ -190,13 +190,14 @@ class IRRewrite {
     InlineBlock,
     MoveBlock,
     BlockTypeConversion,
-    ReplaceBlockArg,
     // Operation rewrites
     MoveOperation,
     ModifyOperation,
     ReplaceOperation,
     CreateOperation,
-    UnresolvedMaterialization
+    UnresolvedMaterialization,
+    // Value rewrites
+    ReplaceAllUses
   };
 
   virtual ~IRRewrite() = default;
@@ -243,7 +244,7 @@ class BlockRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::CreateBlock &&
-           rewrite->getKind() <= Kind::ReplaceBlockArg;
+           rewrite->getKind() <= Kind::BlockTypeConversion;
   }
 
 protected:
@@ -487,27 +488,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   const TypeConverter *converter;
 };
 
-/// Replacing a block argument. This rewrite is not immediately reflected in the
-/// IR. An internal IR mapping is updated, but the actual replacement is delayed
-/// until the rewrite is committed.
-class ReplaceBlockArgRewrite : public BlockRewrite {
-public:
-  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                         Block *block, BlockArgument arg)
-      : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
-
-  static bool classof(const IRRewrite *rewrite) {
-    return rewrite->getKind() == Kind::ReplaceBlockArg;
-  }
-
-  void commit(RewriterBase &rewriter) override;
-
-  void rollback() override;
-
-private:
-  BlockArgument arg;
-};
-
 /// An operation rewrite.
 class OperationRewrite : public IRRewrite {
 public:
@@ -751,6 +731,44 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   /// The original output type. This is only used for argument conversions.
   Type origOutputType;
 };
+
+/// A value rewrite.
+class ValueRewrite : public IRRewrite {
+public:
+  /// Return the operation that this rewrite operates on.
+  Value getValue() const { return value; }
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() >= Kind::ReplaceAllUses &&
+           rewrite->getKind() <= Kind::ReplaceAllUses;
+  }
+
+protected:
+  ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+               Value value)
+      : IRRewrite(kind, rewriterImpl), value(value) {}
+
+  // The value that this rewrite operates on.
+  Value value;
+};
+
+/// Replacing a value. This rewrite is not immediately reflected in the IR. An
+/// internal IR mapping is updated, but the actual replacement is delayed until
+/// the rewrite is committed.
+class ReplaceAllUsesRewrite : public ValueRewrite {
+public:
+  ReplaceAllUsesRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                        Value value)
+      : ValueRewrite(Kind::ReplaceAllUses, rewriterImpl, value) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::ReplaceAllUses;
+  }
+
+  void commit(RewriterBase &rewriter) override;
+
+  void rollback() override;
+};
 } // namespace
 
 /// Return "true" if there is an operation rewrite that matches the specified
@@ -832,8 +850,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// converted.
   bool isOpIgnored(Operation *op) const;
 
-  /// Return "true" if the given operation was replaced or erased.
-  bool wasOpReplaced(Operation *op) const;
+  /// Return "true" if the given operation is scheduled for erasure. (It may
+  /// still be visible in the IR, but should not be accessed.)
+  bool wasOpErased(Operation *op) const;
 
   //===--------------------------------------------------------------------===//
   // Type Conversion
@@ -982,11 +1001,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// tracked separately.
   SetVector<Operation *> ignoredOps;
 
-  /// A set of operations that were replaced/erased. Such ops are not erased
-  /// immediately but only when the dialect conversion succeeds. In the mean
-  /// time, they should no longer be considered for legalization and any attempt
-  /// to modify/access them is invalid rewriter API usage.
-  SetVector<Operation *> replacedOps;
+  /// A set of operations that were erased. Such ops are not erased immediately
+  /// but only when the dialect conversion succeeds. In the mean time, they
+  /// should no longer be considered for legalization and any attempt to
+  /// modify/access them is invalid rewriter API usage.
+  SetVector<Operation *> erasedOps;
 
   /// The current type converter, or nullptr if no type converter is currently
   /// active.
@@ -1099,13 +1118,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
   return success();
 }
 
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
-  if (!repl)
-    return;
+void ReplaceAllUsesRewrite::commit(RewriterBase &rewriter) {
+  Value repl = rewriterImpl.mapping.lookupOrNull(value);
+  assert(repl && "expected that value is mapped");
 
   if (isa<BlockArgument>(repl)) {
-    rewriter.replaceAllUsesWith(arg, repl);
+    rewriter.replaceAllUsesWith(value, repl);
     return;
   }
 
@@ -1114,13 +1132,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   // replacement value.
   Operation *replOp = cast<OpResult>(repl).getOwner();
   Block *replBlock = replOp->getBlock();
-  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
+  rewriter.replaceUsesWithIf(value, repl, [&](OpOperand &operand) {
     Operation *user = operand.getOwner();
     return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
   });
 }
 
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+void ReplaceAllUsesRewrite::rollback() { rewriterImpl.mapping.erase(value); }
 
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>(
@@ -1205,7 +1223,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
+  return RewriterState(rewrites.size(), ignoredOps.size(), erasedOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1216,8 +1234,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
 
-  while (replacedOps.size() != state.numReplacedOps)
-    replacedOps.pop_back();
+  while (erasedOps.size() != state.numErasedOps)
+    erasedOps.pop_back();
 }
 
 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
@@ -1282,13 +1300,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
 }
 
 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
-  // Check to see if this operation is ignored or was replaced.
-  return replacedOps.count(op) || ignoredOps.count(op);
+  // Check to see if this operation is ignored or was erased.
+  return erasedOps.count(op) || ignoredOps.count(op);
 }
 
-bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
-  // Check to see if this operation was replaced.
-  return replacedOps.count(op);
+bool ConversionPatternRewriterImpl::wasOpErased(Operation *op) const {
+  // Check to see if this operation was scheduled for erasure.
+  return erasedOps.count(op);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1434,7 +1452,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
       mapping.map(origArg, inputMap->replacementValue);
-      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+      appendRewrite<ReplaceAllUsesRewrite>(origArg);
       continue;
     }
 
@@ -1469,7 +1487,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     }
 
     mapping.map(origArg, newArg);
-    appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+    appendRewrite<ReplaceAllUsesRewrite>(origArg);
     argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
@@ -1535,8 +1553,8 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
-  assert(!wasOpReplaced(op->getParentOp()) &&
-         "attempting to insert into a block within a replaced/erased op");
+  assert(!wasOpErased(op->getParentOp()) &&
+         "attempting to insert into a block within an erased op");
 
   if (!previous.isSet()) {
     // This is a newly created op.
@@ -1571,8 +1589,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
                                          resultChanged);
 
-  // Mark this operation and all nested ops as replaced.
-  op->walk([&](Operation *op) { replacedOps.insert(op); });
+  // Mark this operation and all nested ops as erased.
+  op->walk([&](Operation *op) { erasedOps.insert(op); });
 }
 
 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
@@ -1583,8 +1601,8 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
 
 void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
-  assert(!wasOpReplaced(block->getParentOp()) &&
-         "attempting to insert into a region within a replaced/erased op");
+  assert(!wasOpErased(block->getParentOp()) &&
+         "attempting to insert into a region within an erased op");
   LLVM_DEBUG(
       {
         Operation *parent = block->getParentOp();
@@ -1660,8 +1678,8 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
 }
 
 void ConversionPatternRewriter::eraseBlock(Block *block) {
-  assert(!impl->wasOpReplaced(block->getParentOp()) &&
-         "attempting to erase a block within a replaced/erased op");
+  assert(!impl->wasOpErased(block->getParentOp()) &&
+         "attempting to erase a block within an erased op");
 
   // Mark all ops for erasure.
   for (Operation &op : *block)
@@ -1678,41 +1696,59 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
 Block *ConversionPatternRewriter::applySignatureConversion(
     Region *region, TypeConverter::SignatureConversion &conversion,
     const TypeConverter *converter) {
-  assert(!impl->wasOpReplaced(region->getParentOp()) &&
-         "attempting to apply a signature conversion to a block within a "
-         "replaced/erased op");
+  assert(!impl->wasOpErased(region->getParentOp()) &&
+         "attempting to apply a signature conversion to a block within an "
+         "erased op");
   return impl->applySignatureConversion(*this, region, conversion, converter);
 }
 
 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
     Region *region, const TypeConverter &converter,
     TypeConverter::SignatureConversion *entryConversion) {
-  assert(!impl->wasOpReplaced(region->getParentOp()) &&
-         "attempting to apply a signature conversion to a block within a "
-         "replaced/erased op");
+  assert(!impl->wasOpErased(region->getParentOp()) &&
+         "attempting to apply a signature conversion to a block within an "
+         "erased op");
   return impl->convertRegionTypes(*this, region, converter, entryConversion);
 }
 
 LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
     Region *region, const TypeConverter &converter,
     ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
-  assert(!impl->wasOpReplaced(region->getParentOp()) &&
-         "attempting to apply a signature conversion to a block within a "
-         "replaced/erased op");
+  assert(!impl->wasOpErased(region->getParentOp()) &&
+         "attempting to apply a signature conversion to a block within an "
+         "erased op");
   return impl->convertNonEntryRegionTypes(*this, region, converter,
                                           blockConversions);
 }
 
-void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
-                                                           Value to) {
+void ConversionPatternRewriter::replaceAllUsesWith(Value from, Value to) {
+#ifndef NDEBUG
   LLVM_DEBUG({
-    Operation *parentOp = from.getOwner()->getParentOp();
-    impl->logger.startLine() << "** Replace Argument : '" << from
-                             << "'(in region of '" << parentOp->getName()
-                             << "'(" << from.getOwner()->getParentOp() << ")\n";
+    Block *parentBlock = from.getParentBlock();
+    Operation *parentOp = parentBlock ? parentBlock->getParentOp() : nullptr;
+    impl->logger.startLine() << "** Replace value : '" << from;
+    if (parentOp) {
+      impl->logger.getOStream() << "' (in region of '" << parentOp->getName()
+                                << "'(" << parentOp << ")\n";
+    } else {
+      impl->logger.getOStream() << "' (detached)\n";
+    }
   });
-  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
-  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+  if (OpResult opResult = dyn_cast<OpResult>(from)) {
+    assert(!impl->wasOpErased(opResult.getDefiningOp()) &&
+           "attempting to replace an OpResult defined by an erased op");
+  }
+  if (OpResult opResult = dyn_cast<OpResult>(to)) {
+    assert(!impl->wasOpErased(opResult.getDefiningOp()) &&
+           "attempting to replace with an OpResult defined by an erased op");
+  }
+  // A value cannot be replaced multiple times. That would likely require a more
+  // fine-grained tracking of replacements (i.e., each use must be tracked).
+  assert(!impl->mapping.lookupOrNull(from) && "value was already replaced");
+#endif // NDEBUG
+
+  impl->appendRewrite<ReplaceAllUsesRewrite>(from);
+  impl->mapping.map(from, to);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -1738,10 +1774,10 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
 #ifndef NDEBUG
   assert(argValues.size() == source->getNumArguments() &&
          "incorrect # of argument replacement values");
-  assert(!impl->wasOpReplaced(source->getParentOp()) &&
-         "attempting to inline a block from a replaced/erased op");
-  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
-         "attempting to inline a block into a replaced/erased op");
+  assert(!impl->wasOpErased(source->getParentOp()) &&
+         "attempting to inline a block from an erased op");
+  assert(!impl->wasOpErased(dest->getParentOp()) &&
+         "attempting to inline a block into an erased op");
   auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
   // The source block will be deleted, so it should not have any users (i.e.,
   // there should be no predecessors).
@@ -1762,7 +1798,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
 
   // Replace all uses of block arguments.
   for (auto it : llvm::zip(source->getArguments(), argValues))
-    replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
   if (fastPath) {
     // Move all ops at once.
@@ -1778,8 +1814,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
 }
 
 void ConversionPatternRewriter::startOpModification(Operation *op) {
-  assert(...
[truncated]

void commit(RewriterBase &rewriter) override;

void rollback() override;
};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole flow of Dialect conversion tracking of these changes is too complicated for me to know whether the commit/rollback logic is safe and complete here :(
It's likely that only testing will tell, but that's unfortunate!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, the dialect conversion is too complicated. I'm working on this for over a month now and there are still parts that I do not understand.

Thinking longer term, what would make the design much simpler: no automatic rollback, materialize all IR changes immediately (in particular, materialize unrealized_conversion_casts eagerly and no more adaptors). Maybe as part of a dialect conversion v2 built on top of the existing RewritePattern, ConversionPolicy, TypeConverter...

@matthias-springer matthias-springer force-pushed the users/matthias-springer/replace_all_uses_with_nontemplate branch from e70c754 to 48279f3 Compare April 2, 2024 01:55
Base automatically changed from users/matthias-springer/replace_all_uses_with_nontemplate to main April 2, 2024 02:03
This commit adds support for `RewriterBase::replaceAllUsesWith` to the dialect conversion. Uses are not immediately replaced, but in a delayed fashion during the "commit" phase. No type conversions are performed; this is consistent with `ConversionPatternRewriter::replaceUsesOfBlockArgument`.

- `RewriterBase::replaceAllUsesWith` is now virtual, so that it can be overridden in the dialect conversion. Note: `RewriterBase::replaceOp` can now be turned into a non-virtual function in a follow-up commit.
- `ConversionPatternRewriter::replaceUsesOfBlockArgument` is generalized to `ConversionPatternRewriter::replaceAllUsesWith`, following the same implementation strategy.
- A new kind of "IR rewrite" is added: `ValueRewrite` with `ReplaceAllUsesRewrite` (replacing `ReplaceBlockArgRewrite`) as the only value rewrite for now.
- `replacedOps` is renamed to `erasedOps` to better capture its meaning.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:gpu mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants