Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms][NFC] Turn op/block arg replacements into IRRewrites #81757

Merged
merged 1 commit into from
Feb 23, 2024

Conversation

matthias-springer
Copy link
Member

This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).

Until now, op replacements and block argument replacements were kept track in separate data structures inside the dialect conversion. This commit turns them into IRRewrites, so that they can be committed or rolled back just like any other rewrite. This simplifies the internal state of the dialect conversion.

Overview of changes:

  • Add two new rewrite classes: ReplaceBlockArgRewrite and ReplaceOperationRewrite. Remove the OpReplacement helper class; it is now part of ReplaceOperationRewrite.
  • Simplify RewriterState: numReplacements and numArgReplacements are no longer needed. (Now being kept track of by numRewrites.)
  • Add IRRewrite::cleanup. Operations should not be erased in commit because they may still be referenced in other internal state of the dialect conversion (mapping). Detaching operations is fine.
  • trackedOps are now updated during the "commit" phase instead of after applying all rewrites.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Feb 14, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 14, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).

Until now, op replacements and block argument replacements were kept track in separate data structures inside the dialect conversion. This commit turns them into IRRewrites, so that they can be committed or rolled back just like any other rewrite. This simplifies the internal state of the dialect conversion.

Overview of changes:

  • Add two new rewrite classes: ReplaceBlockArgRewrite and ReplaceOperationRewrite. Remove the OpReplacement helper class; it is now part of ReplaceOperationRewrite.
  • Simplify RewriterState: numReplacements and numArgReplacements are no longer needed. (Now being kept track of by numRewrites.)
  • Add IRRewrite::cleanup. Operations should not be erased in commit because they may still be referenced in other internal state of the dialect conversion (mapping). Detaching operations is fine.
  • trackedOps are now updated during the "commit" phase instead of after applying all rewrites.

Patch is 21.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81757.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+159-138)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b2baa88879b6e9..a07c8a56822de5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,14 +153,12 @@ namespace {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
-                unsigned numReplacements, unsigned numArgReplacements,
                 unsigned numRewrites, unsigned numIgnoredOperations,
                 unsigned numErased)
       : numCreatedOps(numCreatedOps),
         numUnresolvedMaterializations(numUnresolvedMaterializations),
-        numReplacements(numReplacements),
-        numArgReplacements(numArgReplacements), numRewrites(numRewrites),
-        numIgnoredOperations(numIgnoredOperations), numErased(numErased) {}
+        numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
+        numErased(numErased) {}
 
   /// The current number of created operations.
   unsigned numCreatedOps;
@@ -168,12 +166,6 @@ struct RewriterState {
   /// The current number of unresolved materializations.
   unsigned numUnresolvedMaterializations;
 
-  /// The current number of replacements queued.
-  unsigned numReplacements;
-
-  /// The current number of argument replacements queued.
-  unsigned numArgReplacements;
-
   /// The current number of rewrites performed.
   unsigned numRewrites;
 
@@ -184,20 +176,6 @@ struct RewriterState {
   unsigned numErased;
 };
 
-//===----------------------------------------------------------------------===//
-// OpReplacement
-
-/// This class represents one requested operation replacement via 'replaceOp' or
-/// 'eraseOp`.
-struct OpReplacement {
-  OpReplacement(const TypeConverter *converter = nullptr)
-      : converter(converter) {}
-
-  /// An optional type converter that can be used to materialize conversions
-  /// between the new and old values if necessary.
-  const TypeConverter *converter;
-};
-
 //===----------------------------------------------------------------------===//
 // UnresolvedMaterialization
 
@@ -318,8 +296,10 @@ class IRRewrite {
     MoveBlock,
     SplitBlock,
     BlockTypeConversion,
+    ReplaceBlockArg,
     MoveOperation,
-    ModifyOperation
+    ModifyOperation,
+    ReplaceOperation
   };
 
   virtual ~IRRewrite() = default;
@@ -330,6 +310,12 @@ class IRRewrite {
   /// Commit the rewrite.
   virtual void commit() {}
 
+  /// Cleanup operations. Operations may be unlinked from their blocks during
+  /// the commit/rollback phase, but they must not be erased yet. This is
+  /// because internal dialect conversion state (such as `mapping`) may still
+  /// be using them. Operations must be erased during cleanup.
+  virtual void cleanup() {}
+
   /// Erase the given op (unless it was already erased).
   void eraseOp(Operation *op);
 
@@ -356,7 +342,7 @@ class BlockRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::CreateBlock &&
-           rewrite->getKind() <= Kind::BlockTypeConversion;
+           rewrite->getKind() <= Kind::ReplaceBlockArg;
   }
 
 protected:
@@ -424,6 +410,8 @@ class EraseBlockRewrite : public BlockRewrite {
   void commit() override {
     // Erase the block.
     assert(block && "expected block");
+    assert(block->empty() && "expected empty block");
+    block->dropAllDefinedValueUses();
     delete block;
     block = nullptr;
   }
@@ -585,6 +573,27 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   const TypeConverter *converter;
 };
 
+/// Replacing a block argument. This rewrite is not immediately reflected in the
+/// IR. An internal IR mapping is updated, but the actual replacement is delayed
+/// until the rewrite is committed.
+class ReplaceBlockArgRewrite : public BlockRewrite {
+public:
+  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                         Block *block, BlockArgument arg)
+      : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::ReplaceBlockArg;
+  }
+
+  void commit() override;
+
+  void rollback() override;
+
+private:
+  BlockArgument arg;
+};
+
 /// An operation rewrite.
 class OperationRewrite : public IRRewrite {
 public:
@@ -593,7 +602,7 @@ class OperationRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::MoveOperation &&
-           rewrite->getKind() <= Kind::ModifyOperation;
+           rewrite->getKind() <= Kind::ReplaceOperation;
   }
 
 protected:
@@ -664,6 +673,41 @@ class ModifyOperationRewrite : public OperationRewrite {
   SmallVector<Value, 8> operands;
   SmallVector<Block *, 2> successors;
 };
+
+/// Replacing an operation. Erasing an operation is treated as a special case
+/// with "null" replacements. This rewrite is not immediately reflected in the
+/// IR. An internal IR mapping is updated, but values are not replaced and the
+/// original op is not erased until the rewrite is committed.
+class ReplaceOperationRewrite : public OperationRewrite {
+public:
+  ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                          Operation *op, const TypeConverter *converter,
+                          bool changedResults)
+      : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
+        converter(converter), changedResults(changedResults) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::ReplaceOperation;
+  }
+
+  void commit() override;
+
+  void rollback() override;
+
+  void cleanup() override;
+
+private:
+  friend struct OperationConverter;
+
+  /// An optional type converter that can be used to materialize conversions
+  /// between the new and old values if necessary.
+  const TypeConverter *converter;
+
+  /// A vector of indices into `replacements` of operations that were replaced
+  /// with values with different result types than the original operation, e.g.
+  /// 1->N conversion of some kind.
+  bool changedResults;
+};
 } // namespace
 
 /// Return "true" if there is an operation rewrite that matches the specified
@@ -856,6 +900,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
     void eraseBlock(Block *block) override {
       if (erased.contains(block))
         return;
+      assert(block->empty() && "expected empty block");
       block->dropAllDefinedValueUses();
       RewriterBase::eraseBlock(block);
     }
@@ -887,12 +932,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// conversion.
   SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
 
-  /// Ordered map of requested operation replacements.
-  llvm::MapVector<Operation *, OpReplacement> replacements;
-
-  /// Ordered vector of any requested block argument replacements.
-  SmallVector<BlockArgument, 4> argReplacements;
-
   /// Ordered list of block operations (creations, splits, motions).
   SmallVector<std::unique_ptr<IRRewrite>> rewrites;
 
@@ -907,11 +946,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// operation was ignored.
   SetVector<Operation *> ignoredOps;
 
-  /// A vector of indices into `replacements` of operations that were replaced
-  /// with values with different result types than the original operation, e.g.
-  /// 1->N conversion of some kind.
-  SmallVector<unsigned, 4> operationsWithChangedResults;
-
   /// The current type converter, or nullptr if no type converter is currently
   /// active.
   const TypeConverter *currentTypeConverter = nullptr;
@@ -923,6 +957,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// This allows the user to collect the match failure message.
   function_ref<void(Diagnostic &)> notifyCallback;
 
+  DenseSet<Operation *> *trackedOps = nullptr;
+
 #ifndef NDEBUG
   /// A set of operations that have pending updates. This tracking isn't
   /// strictly necessary, and is thus only active during debug builds for extra
@@ -969,6 +1005,8 @@ void BlockTypeConversionRewrite::commit() {
     }
   }
 
+  assert(origBlock->empty() && "expected empty block");
+  origBlock->dropAllDefinedValueUses();
   delete origBlock;
   origBlock = nullptr;
 }
@@ -1031,6 +1069,47 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
   return success();
 }
 
+void ReplaceBlockArgRewrite::commit() {
+  Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
+  if (!repl)
+    return;
+
+  if (isa<BlockArgument>(repl)) {
+    arg.replaceAllUsesWith(repl);
+    return;
+  }
+
+  // If the replacement value is an operation, we check to make sure that we
+  // don't replace uses that are within the parent operation of the
+  // replacement value.
+  Operation *replOp = cast<OpResult>(repl).getOwner();
+  Block *replBlock = replOp->getBlock();
+  arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+    Operation *user = operand.getOwner();
+    return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+  });
+}
+
+void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+
+void ReplaceOperationRewrite::commit() {
+  for (OpResult result : op->getResults())
+    if (Value newValue =
+            rewriterImpl.mapping.lookupOrNull(result, result.getType()))
+      result.replaceAllUsesWith(newValue);
+  if (rewriterImpl.trackedOps)
+    rewriterImpl.trackedOps->erase(op);
+  // Do not erase the operation yet. It may still be referenced in `mapping`.
+  op->getBlock()->getOperations().remove(op);
+}
+
+void ReplaceOperationRewrite::rollback() {
+  for (auto result : op->getResults())
+    rewriterImpl.mapping.erase(result);
+}
+
+void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+
 void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
   for (Region &region : op->getRegions()) {
     for (Block &block : region.getBlocks()) {
@@ -1053,51 +1132,16 @@ void ConversionPatternRewriterImpl::discardRewrites() {
 }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
-  // Apply all of the rewrites replacements requested during conversion.
-  for (auto &repl : replacements) {
-    for (OpResult result : repl.first->getResults())
-      if (Value newValue = mapping.lookupOrNull(result, result.getType()))
-        result.replaceAllUsesWith(newValue);
-  }
-
-  // Apply all of the requested argument replacements.
-  for (BlockArgument arg : argReplacements) {
-    Value repl = mapping.lookupOrNull(arg, arg.getType());
-    if (!repl)
-      continue;
-
-    if (isa<BlockArgument>(repl)) {
-      arg.replaceAllUsesWith(repl);
-      continue;
-    }
-
-    // If the replacement value is an operation, we check to make sure that we
-    // don't replace uses that are within the parent operation of the
-    // replacement value.
-    Operation *replOp = cast<OpResult>(repl).getOwner();
-    Block *replBlock = replOp->getBlock();
-    arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
-      Operation *user = operand.getOwner();
-      return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
-    });
-  }
+  // Commit all rewrites.
+  for (auto &rewrite : rewrites)
+    rewrite->commit();
+  for (auto &rewrite : rewrites)
+    rewrite->cleanup();
 
   // Drop all of the unresolved materialization operations created during
   // conversion.
   for (auto &mat : unresolvedMaterializations)
     eraseRewriter.eraseOp(mat.getOp());
-
-  // In a second pass, erase all of the replaced operations in reverse. This
-  // allows processing nested operations before their parent region is
-  // destroyed. Because we process in reverse order, producers may be deleted
-  // before their users (a pattern deleting a producer and then the consumer)
-  // so we first drop all uses explicitly.
-  for (auto &repl : llvm::reverse(replacements))
-    eraseRewriter.eraseOp(repl.first);
-
-  // Commit all rewrites.
-  for (auto &rewrite : rewrites)
-    rewrite->commit();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1105,28 +1149,14 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
-                       replacements.size(), argReplacements.size(),
                        rewrites.size(), ignoredOps.size(),
                        eraseRewriter.erased.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
-  // Reset any replaced arguments.
-  for (BlockArgument replacedArg :
-       llvm::drop_begin(argReplacements, state.numArgReplacements))
-    mapping.erase(replacedArg);
-  argReplacements.resize(state.numArgReplacements);
-
   // Undo any rewrites.
   undoRewrites(state.numRewrites);
 
-  // Reset any replaced operations and undo any saved mappings.
-  for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
-    for (auto result : repl.first->getResults())
-      mapping.erase(result);
-  while (replacements.size() != state.numReplacements)
-    replacements.pop_back();
-
   // Pop all of the newly inserted materializations.
   while (unresolvedMaterializations.size() !=
          state.numUnresolvedMaterializations) {
@@ -1151,11 +1181,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
 
-  // Reset operations with changed results.
-  while (!operationsWithChangedResults.empty() &&
-         operationsWithChangedResults.back() >= state.numReplacements)
-    operationsWithChangedResults.pop_back();
-
   while (eraseRewriter.erased.size() != state.numErased)
     eraseRewriter.erased.pop_back();
 }
@@ -1224,7 +1249,14 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
 
 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
   // Check to see if this operation was replaced or its parent ignored.
-  return replacements.count(op) || ignoredOps.count(op->getParentOp());
+  return ignoredOps.count(op->getParentOp()) ||
+         llvm::any_of(rewrites, [&](auto &rewrite) {
+           auto *opReplacement =
+               dyn_cast<ReplaceOperationRewrite>(rewrite.get());
+           if (!opReplacement)
+             return false;
+           return opReplacement->getOperation() == op;
+         });
 }
 
 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
@@ -1374,7 +1406,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
       mapping.map(origArg, inputMap->replacementValue);
-      argReplacements.push_back(origArg);
+      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
       continue;
     }
 
@@ -1408,7 +1440,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     }
 
     mapping.map(origArg, newArg);
-    argReplacements.push_back(origArg);
+    appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
     argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
@@ -1440,7 +1472,12 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
                                                      ValueRange newValues) {
   assert(newValues.size() == op->getNumResults());
-  assert(!replacements.count(op) && "operation was already replaced");
+#ifndef NDEBUG
+  for (auto &rewrite : rewrites)
+    if (auto *opReplacement = dyn_cast<ReplaceOperationRewrite>(rewrite.get()))
+      assert(opReplacement->getOperation() != op &&
+             "operation was already replaced");
+#endif // NDEBUG
 
   // Track if any of the results changed, e.g. erased and replaced with null.
   bool resultChanged = false;
@@ -1455,11 +1492,9 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
     mapping.map(result, newValue);
     resultChanged |= (newValue.getType() != result.getType());
   }
-  if (resultChanged)
-    operationsWithChangedResults.push_back(replacements.size());
 
-  // Record the requested operation replacement.
-  replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter)));
+  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
+                                         resultChanged);
 
   // Mark this operation as recursively ignored so that we don't need to
   // convert any nested operations.
@@ -1554,8 +1589,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
 }
 
 void ConversionPatternRewriter::eraseBlock(Block *block) {
-  impl->notifyBlockIsBeingErased(block);
-
   // Mark all ops for erasure.
   for (Operation &op : *block)
     eraseOp(&op);
@@ -1564,6 +1597,7 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
   // object and will be actually destroyed when rewrites are applied. This
   // allows us to keep the operations in the block live and undo the removal by
   // re-inserting the block.
+  impl->notifyBlockIsBeingErased(block);
   block->getParent()->getBlocks().remove(block);
 }
 
@@ -1593,7 +1627,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
                              << "'(in region of '" << parentOp->getName()
                              << "'(" << from.getOwner()->getParentOp() << ")\n";
   });
-  impl->argReplacements.push_back(from);
+  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
 }
 
@@ -2015,16 +2049,13 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
 
 #ifndef NDEBUG
   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
-
   // Check that the root was either replaced or updated in place.
+  auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
   auto replacedRoot = [&] {
-    return llvm::any_of(
-        llvm::drop_begin(impl.replacements, curState.numReplacements),
-        [op](auto &it) { return it.first == op; });
+    return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
   };
   auto updatedRootInPlace = [&] {
-    return hasRewrite<ModifyOperationRewrite>(
-        llvm::drop_begin(impl.rewrites, curState.numRewrites), op);
+    return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
   };
   assert((replacedRoot() || updatedRootInPlace()) &&
          "expected pattern to replace the root operation");
@@ -2057,7 +2088,8 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     if (!rewrite)
       continue;
     Block *block = rewrite->getBlock();
-    if (isa<BlockTypeConversionRewrite, EraseBlockRewrite>(rewrite))
+    if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
+            ReplaceBlockArgRewrite>(rewrite))
       continue;
     // Only check blocks outside of the current operation.
     Operation *parentOp = block->getParentOp();
@@ -2452,6 +2484,7 @@ LogicalResult OperationConverter::convertOperations(
   ConversionPatternRewriter rewriter(ops.front()->getContext());
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
   rewriterImpl.notifyCallback = notifyCallback;
+  rewriterImpl.trackedOps = trackedOps;
 
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
@@ -2469,13 +2502,6 @@ LogicalResult OperationConverter::convertOperations(
     rewriterImpl.discardRewrites();
   } else {
     rewriterImpl.applyRewrites();
-
-    // It is possible for a later pattern to erase an op that was originally
-    // identified as illegal and added to the trackedOps, remove it now after
-    // replacements have been computed.
-    if (trackedOps)
-      for (auto &repl : rewriterImpl.replacements)
-        trackedOps->erase(repl.first);
   }
   return success();
 }
@@ -2489,21 +2515,20 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
       failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
     return failure();
 
-  if (rewriterImpl.operationsWithChangedResults.empty())
-    return success();
-
   // Process requ...
[truncated]

@matthias-springer matthias-springer force-pushed the users/matthias-springer/block_type_conversion branch from ae08e91 to dcd13b8 Compare February 14, 2024 16:42
@matthias-springer matthias-springer force-pushed the users/matthias-springer/op_replacements branch from 8405efc to 613a616 Compare February 14, 2024 16:43
@matthias-springer matthias-springer force-pushed the users/matthias-springer/block_type_conversion branch from dcd13b8 to 61e82f6 Compare February 16, 2024 15:17
@matthias-springer matthias-springer force-pushed the users/matthias-springer/op_replacements branch from 613a616 to b8d4cbd Compare February 16, 2024 15:20
matthias-springer added a commit that referenced this pull request Feb 21, 2024
When a `ModifyOperationRewrite` is committed, the operation may already
have been erased, so `OperationName` must be cached in the rewrite
object.

Note: This will no longer be needed with #81757, which adds a "cleanup"
method to `IRRewrite`.
matthias-springer added a commit that referenced this pull request Feb 21, 2024
When a `ModifyOperationRewrite` is committed, the operation may already
have been erased, so `OperationName` must be cached in the rewrite
object.

Note: This will no longer be needed with #81757, which adds a "cleanup"
method to `IRRewrite`.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/block_type_conversion branch 2 times, most recently from 68cb259 to 04eb7bd Compare February 22, 2024 08:57
Base automatically changed from users/matthias-springer/block_type_conversion to main February 22, 2024 09:22
@matthias-springer matthias-springer force-pushed the users/matthias-springer/op_replacements branch from b8d4cbd to 886f558 Compare February 22, 2024 09:59
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be commited (upon success) or rolled back (upon failure).

Until now, op replacements and block argument replacements were kept track in separate data structures inside the dialect conversion. This commit turns them into `IRRewrite`s, so that they can be committed or rolled back just like any other rewrite. This simplifies the internal state of the dialect conversion.

Overview of changes:
* Add two new rewrite classes: `ReplaceBlockArgRewrite` and `ReplaceOperationRewrite`. Remove the `OpReplacement` helper class; it is now part of `ReplaceOperationRewrite`.
* Simplify `RewriterState`: `numReplacements` and `numArgReplacements` are no longer needed. (Now being kept track of by `numRewrites`.)
* Add `IRRewrite::cleanup`. Operations should not be erased in `commit` because they may still be referenced in other internal state of the dialect conversion (`mapping`). Detaching operations is fine.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/op_replacements branch from 886f558 to 6f7d3e7 Compare February 22, 2024 10:06
Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks

@matthias-springer matthias-springer merged commit d68d295 into main Feb 23, 2024
4 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/op_replacements branch February 23, 2024 08:48
@@ -1462,7 +1490,12 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
assert(!replacements.count(op) && "operation was already replaced");
#ifndef NDEBUG
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this check present before this PR? I have multiple Flang end-to-end tests spending very long time in compilation. I tried to find the guilty PR, but I was using debug compiler builds. This PR increases this test compilation from 26 seconds to 137 seconds:

  Character(1),Parameter :: c717(2,3,4,5,6,7,8) = Reshape([('a',i=1,Size(c717))], Shape(c717))
End

I will have to confirm if this check is causing it, but you may know it right away. If it is indeed an expensive check, should it be only enabled under LLVM_ENABLE_EXPENSIVE_CHECKS?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, false alarm. Disabling this only brings the time to 107 seconds. I will look further, and also try with the release compiler.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I reproduce this locally? Debug builds should not make the compile time that much slower. But you are right, this is likely a performance regression of this change or one of the other dialect conversion changes that I submitted recently.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can confirm that this commit slows down the above FIR to MLIR LLVM IR dialect conversion pass for the Fortran program that Slava gave (*) by 3.5x in release mode too, and that #81759 adds another 1.5x slowdown on that (from 1.8s to 8.8s on my machine in release mode in with the two patch).

Using perf, it seems the slowdown is caused by changes in mlir::detail::ConversionPatternRewriterImpl::isOpIgnored (probably related to the data structure changes).

After the two patches in release mode:

Samples: 158K of event 'cycles', Event count (approx.): 21633186643
  Children      Self  Command  Shared Object         Symbol
+   65.37%    64.59%  fir-opt  fir-opt               [.] mlir::detail::ConversionPatternRewriterImpl::isOpIgnored
+   65.22%     0.00%  fir-opt  [unknown]             [.] 0x0000000000000001
+   65.18%     0.00%  fir-opt  [unknown]             [.] 0x480029bc01058d48
+   65.18%     0.00%  fir-opt  fir-opt               [.] mlir::detail::ConversionPatternRewriterImpl::~ConversionPatternRewriterImpl
+   24.25%     0.00%  fir-opt  [unknown]             [.] 0x0000000100000001                                                
+   24.19%    23.65%  fir-opt  fir-opt               [.] mlir::detail::ConversionPatternRewriterImpl::notifyOpReplaced
+   24.05%     0.00%  fir-opt  [unknown]             [.] 0x26ee058d48fb8948                                
+   24.05%     0.00%  fir-opt  fir-opt               [.] mlir::RegisteredOperationName::Model<fir::InsertValueOp>::~Model              
+   24.05%     0.00%  fir-opt  [unknown]             [.] 0x000055a74703d488
+    0.82%     0.00%  fir-opt  [unknown]             [k] 0000000000000000                                                          
     0.78%     0.47%  fir-opt  fir-opt               [.] mlir::Lexer::lexBareIdentifierOrKeyword
....

Before the patches, mlir::detail::ConversionPatternRewriterImpl::isOpIgnored was nowhere to see in the perf report (MLIR IO dominated the run).

You can reproduce if you have flang builds enabled, and with repro.f90 that is the Fortran source from Slava above with:

# First phase of compilation not impacted by patch
bin/bbc -emit-fir repro.f90 -o - | bin/fir-opt --cg-rewrite -o -input.fir
# FIR to LLVM dialect conversion pass impacted by pass
time bin/fir-opt --fir-to-llvm-ir -input.fir -o output.mlir

I will try to see if I can come with a pure MLIR reproducer.

(*) about the Fortran program: this program generates a global that is a 7d array of 40320 chars with an initial value.

So far, flang generates a chain of insert_value for character types, so the operation that is impacted by the slow-down is a fir.global where the body contains a chain of 40320 fir.insert_value + 3 other ops (the value being inserted, and the terminator). We are planning to move away from this and use attribute for global initializer as much as possible. However, the slow down will likely kick in for every functions with more than a few thousands ops, and this is quite easily reached with big Fortran programs. Global constants are just an easy way to create a lot of IR with a few lines of Fortran to reproduce the issue.

 fir.global internal @_QFECc717 constant : !fir.array<2x3x4x5x6x7x8x!fir.char<1>> {
    %0 = fir.undefined !fir.array<2x3x4x5x6x7x8x!fir.char<1>>
    %1 = fir.string_lit "a"(1) : !fir.char<1>
    %2 = fir.insert_value %0, %1, [0 : index, 0 : index, 0 : index, 0 : index, 0 : index, 0 : index, 0 : index] : (!fir.array<2x3x4x5x6x7x8x!fir.char<1>>, !fir.char<1>) -> !fir.array<2x3x4x5x6x7x8x!fir.char<1>>
   // ....
   %40321 = fir.insert_value %40320, %1, [1 : index, 2 : index, 3 : index, 4 : index, 5 : index, 6 : index, 7 : index] : (!fir.array<2x3x4x5x6x7x8x!fir.char<1>>, !fir.char<1>) -> !fir.array<2x3x4x5x6x7x8x!fir.char<1>>
    fir.has_value %40321 : !fir.array<2x3x4x5x6x7x8x!fir.char<1>>
  }

This is being rewritten to very similar LLVM dialect IR:

 llvm.mlir.global internal constant @_QFECc717() {addr_space = 0 : i32} : !llvm.array<8 x array<7 x array<6 x array<5 x array<4 x array<3 x array<2 x array<1 x i8>>>>>>>> {
    %0 = llvm.mlir.undef : !llvm.array<8 x array<7 x array<6 x array<5 x array<4 x array<3 x array<2 x array<1 x i8>>>>>>>>
    %1 = llvm.mlir.constant("a") : !llvm.array<1 x i8>
    %2 = llvm.insertvalue %1, %0[0, 0, 0, 0, 0, 0, 0] : !llvm.array<8 x array<7 x array<6 x array<5 x array<4 x array<3 x array<2 x array<1 x i8>>>>>>>>
    // ...
    %40321 = llvm.insertvalue %1, %40320[7, 6, 5, 4, 3, 2, 1] : !llvm.array<8 x array<7 x array<6 x array<5 x array<4 x array<3 x array<2 x array<1 x i8>>>>>>>>
    llvm.return %40321 : !llvm.array<8 x array<7 x array<6 x array<5 x array<4 x array<3 x array<2 x array<1 x i8>>>>>>>>
  }

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I traced the issue back to the same function (isOpIgnored) on some other tests.

Thx for the detailed instructions. Looks like it fixes the test case:

time build/bin/fir-opt --fir-to-llvm-ir input.fir -o output.mlir                     
build/bin/fir-opt --fir-to-llvm-ir input.fir -o output.mlir  0.73s user 0.08s system 100% cpu 0.816 total

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still waiting for a CI run on another project that had a regression, to make sure that the issue is fixed there as well...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the quick fix! I confirm this fixes all the compilation time slowdown on the Fortran program.

matthias-springer added a commit that referenced this pull request Feb 26, 2024
The dialect conversion does not directly erase ops that are replaced/erased with a rewriter. Instead, the op stays in place and is erased at the end if the dialect conversion succeeds. However, ops that were replaced/erased are ignored from that point on.

#81757 introduced a compile time regression that made the check whether an op is ignored or not more expensive. Whether an op is ignored or not is queried many times throughout a dialect conversion, so the check must be fast.

After this change, replaced ops are stored in the `ignoredOps` set. This also simplifies the dialect conversion a bit.
matthias-springer added a commit that referenced this pull request Feb 26, 2024
…83023)

The dialect conversion does not directly erase ops that are
replaced/erased with a rewriter. Instead, the op stays in place and is
erased at the end if the dialect conversion succeeds. However, ops that
were replaced/erased are ignored from that point on.

#81757 introduced a compile time regression that made the check whether
an op is ignored or not more expensive. Whether an op is ignored or not
is queried many times throughout a dialect conversion, so the check must
be fast.

After this change, replaced ops are stored in the `ignoredOps` set. This
also simplifies the dialect conversion a bit.
qedawkins pushed a commit to iree-org/llvm-project that referenced this pull request Feb 26, 2024
…lvm#83023)

The dialect conversion does not directly erase ops that are
replaced/erased with a rewriter. Instead, the op stays in place and is
erased at the end if the dialect conversion succeeds. However, ops that
were replaced/erased are ignored from that point on.

llvm#81757 introduced a compile time regression that made the check whether
an op is ignored or not more expensive. Whether an op is ignored or not
is queried many times throughout a dialect conversion, so the check must
be fast.

After this change, replaced ops are stored in the `ignoredOps` set. This
also simplifies the dialect conversion a bit.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants