Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms][NFC] Turn in-place op modification into IRRewrite #81245

Conversation

matthias-springer
Copy link
Member

This commit simplifies the internal state of the dialect conversion. A separate field for the previous state of in-place op modifications is no longer needed.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Feb 9, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 9, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit simplifies the internal state of the dialect conversion. A separate field for the previous state of in-place op modifications is no longer needed.


Full diff: https://github.com/llvm/llvm-project/pull/81245.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+58-70)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ffdb069f6e9b8..d0114a148cd37 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,15 +154,13 @@ namespace {
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                 unsigned numReplacements, unsigned numArgReplacements,
-                unsigned numRewriteActions, unsigned numIgnoredOperations,
-                unsigned numRootUpdates)
+                unsigned numRewriteActions, unsigned numIgnoredOperations)
       : numCreatedOps(numCreatedOps),
         numUnresolvedMaterializations(numUnresolvedMaterializations),
         numReplacements(numReplacements),
         numArgReplacements(numArgReplacements),
         numRewriteActions(numRewriteActions),
-        numIgnoredOperations(numIgnoredOperations),
-        numRootUpdates(numRootUpdates) {}
+        numIgnoredOperations(numIgnoredOperations) {}
 
   /// The current number of created operations.
   unsigned numCreatedOps;
@@ -181,44 +179,6 @@ struct RewriterState {
 
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
-
-  /// The current number of operations that were updated in place.
-  unsigned numRootUpdates;
-};
-
-//===----------------------------------------------------------------------===//
-// OperationTransactionState
-
-/// The state of an operation that was updated by a pattern in-place. This
-/// contains all of the necessary information to reconstruct an operation that
-/// was updated in place.
-class OperationTransactionState {
-public:
-  OperationTransactionState() = default;
-  OperationTransactionState(Operation *op)
-      : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
-        operands(op->operand_begin(), op->operand_end()),
-        successors(op->successor_begin(), op->successor_end()) {}
-
-  /// Discard the transaction state and reset the state of the original
-  /// operation.
-  void resetOperation() const {
-    op->setLoc(loc);
-    op->setAttrs(attrs);
-    op->setOperands(operands);
-    for (const auto &it : llvm::enumerate(successors))
-      op->setSuccessor(it.value(), it.index());
-  }
-
-  /// Return the original operation of this state.
-  Operation *getOperation() const { return op; }
-
-private:
-  Operation *op;
-  LocationAttr loc;
-  DictionaryAttr attrs;
-  SmallVector<Value, 8> operands;
-  SmallVector<Block *, 2> successors;
 };
 
 //===----------------------------------------------------------------------===//
@@ -758,7 +718,8 @@ class RewriteAction {
     MoveBlock,
     SplitBlock,
     BlockTypeConversion,
-    MoveOperation
+    MoveOperation,
+    ModifyOperation
   };
 
   virtual ~RewriteAction() = default;
@@ -980,7 +941,7 @@ class OperationAction : public RewriteAction {
 
   static bool classof(const RewriteAction *action) {
     return action->getKind() >= Kind::MoveOperation &&
-           action->getKind() <= Kind::MoveOperation;
+           action->getKind() <= Kind::ModifyOperation;
   }
 
 protected:
@@ -1019,6 +980,34 @@ class MoveOperationAction : public OperationAction {
   // this operation was the only operation in the region.
   Operation *insertBeforeOp;
 };
+
+/// Rewrite action that represents the in-place modification of an operation.
+/// The previous state of the operation is stored in this action.
+class ModifyOperationAction : public OperationAction {
+public:
+  ModifyOperationAction(ConversionPatternRewriterImpl &rewriterImpl,
+                        Operation *op)
+      : OperationAction(Kind::ModifyOperation, rewriterImpl, op),
+        loc(op->getLoc()), attrs(op->getAttrDictionary()),
+        operands(op->operand_begin(), op->operand_end()),
+        successors(op->successor_begin(), op->successor_end()) {}
+
+  /// Discard the transaction state and reset the state of the original
+  /// operation.
+  void rollback() override {
+    op->setLoc(loc);
+    op->setAttrs(attrs);
+    op->setOperands(operands);
+    for (const auto &it : llvm::enumerate(successors))
+      op->setSuccessor(it.value(), it.index());
+  }
+
+private:
+  LocationAttr loc;
+  DictionaryAttr attrs;
+  SmallVector<Value, 8> operands;
+  SmallVector<Block *, 2> successors;
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1172,9 +1161,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// operation was ignored.
   SetVector<Operation *> ignoredOps;
 
-  /// A transaction state for each of operations that were updated in-place.
-  SmallVector<OperationTransactionState, 4> rootUpdates;
-
   /// A vector of indices into `replacements` of operations that were replaced
   /// with values with different result types than the original operation, e.g.
   /// 1->N conversion of some kind.
@@ -1226,10 +1212,6 @@ static void detachNestedAndErase(Operation *op) {
 }
 
 void ConversionPatternRewriterImpl::discardRewrites() {
-  // Reset any operations that were updated in place.
-  for (auto &state : rootUpdates)
-    state.resetOperation();
-
   undoRewriteActions();
 
   // Remove any newly created ops.
@@ -1304,16 +1286,10 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
                        replacements.size(), argReplacements.size(),
-                       rewriteActions.size(), ignoredOps.size(),
-                       rootUpdates.size());
+                       rewriteActions.size(), ignoredOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
-  // Reset any operations that were updated in place.
-  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
-    rootUpdates[i].resetOperation();
-  rootUpdates.resize(state.numRootUpdates);
-
   // Reset any replaced arguments.
   for (BlockArgument replacedArg :
        llvm::drop_begin(argReplacements, state.numArgReplacements))
@@ -1740,7 +1716,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
 #ifndef NDEBUG
   impl->pendingRootUpdates.insert(op);
 #endif
-  impl->rootUpdates.emplace_back(op);
+  impl->appendRewriteAction<ModifyOperationAction>(op);
 }
 
 void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
@@ -1759,13 +1735,17 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
          "operation did not have a pending in-place update");
 #endif
   // Erase the last update for this operation.
-  auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
-  auto &rootUpdates = impl->rootUpdates;
-  auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
-  assert(it != rootUpdates.rend() && "no root update started on op");
-  (*it).resetOperation();
-  int updateIdx = std::prev(rootUpdates.rend()) - it;
-  rootUpdates.erase(rootUpdates.begin() + updateIdx);
+  auto it =
+      llvm::find_if(llvm::reverse(impl->rewriteActions),
+                    [&](std::unique_ptr<RewriteAction> &action) {
+                      auto *modifyAction =
+                          dynamic_cast<ModifyOperationAction *>(action.get());
+                      return modifyAction && modifyAction->getOperation() == op;
+                    });
+  assert(it != impl->rewriteActions.rend() && "no root update started on op");
+  (*it)->rollback();
+  int updateIdx = std::prev(impl->rewriteActions.rend()) - it;
+  impl->rewriteActions.erase(impl->rewriteActions.begin() + updateIdx);
 }
 
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
@@ -2118,8 +2098,11 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
   };
   auto updatedRootInPlace = [&] {
     return llvm::any_of(
-        llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
-        [op](auto &state) { return state.getOperation() == op; });
+        llvm::drop_begin(impl.rewriteActions, curState.numRewriteActions),
+        [op](auto &action) {
+          auto *modifyAction = dyn_cast<ModifyOperationAction>(action.get());
+          return modifyAction && modifyAction->getOperation() == op;
+        });
   };
   (void)replacedRoot;
   (void)updatedRootInPlace;
@@ -2213,8 +2196,13 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
     RewriterState &state, RewriterState &newState) {
-  for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
-    Operation *op = impl.rootUpdates[i].getOperation();
+  for (int i = state.numRewriteActions, e = newState.numRewriteActions; i != e;
+       ++i) {
+    auto *action =
+        dyn_cast<ModifyOperationAction>(impl.rewriteActions[i].get());
+    if (!action)
+      continue;
+    Operation *op = action->getOperation();
     if (failed(legalize(op, rewriter))) {
       LLVM_DEBUG(logFailure(
           impl.logger, "failed to legalize operation updated in-place '{0}'",

@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_move_op_before branch from 7503c0c to ebfaca6 Compare February 12, 2024 09:09
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_modify_op_inplace branch from cdbd927 to f7010ea Compare February 12, 2024 09:12
@matthias-springer matthias-springer changed the title [mlir][Transforms][NFC] Turn in-place op modifications into RewriteActions [mlir][Transforms][NFC] Turn in-place op modification into IRRewrite Feb 12, 2024
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_move_op_before branch from ebfaca6 to 1d17c76 Compare February 14, 2024 16:12
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_modify_op_inplace branch from f7010ea to 4bb6521 Compare February 14, 2024 16:18
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_move_op_before branch from 1d17c76 to 5e261de Compare February 14, 2024 16:33
Base automatically changed from users/matthias-springer/dialect_conversion_move_op_before to main February 14, 2024 16:40
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_modify_op_inplace branch 2 times, most recently from 820bcdd to 1c69f42 Compare February 16, 2024 15:11

private:
LocationAttr loc;
DictionaryAttr attrs;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is properties needed too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it a separate PR (#82474), so that this PR can stay NFC.

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG for keeping this a pure refactor

@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_modify_op_inplace branch from 1c69f42 to 8a8a79d Compare February 21, 2024 15:19
…ction`s

This commit simplifies the internal state of the dialect conversion. A separate field for the previous state of in-place op modifications is no longer needed.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_modify_op_inplace branch from 8a8a79d to 5cabd6c Compare February 21, 2024 15:20
@matthias-springer matthias-springer merged commit e214f00 into main Feb 21, 2024
3 of 4 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/dialect_conversion_modify_op_inplace branch February 21, 2024 15:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants