Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][IR] Change notifyBlockCreated to notifyBlockInserted #79472

Conversation

matthias-springer
Copy link
Member

This change makes the callback consistent with notifyOperationInserted: both now notify about IR insertion, not IR creation. See also #78988.

This change also simplifies the dialect conversion: it is no longer necessary to override the inlineRegionBefore method. All information that is necessary for rollback is provided with the notifyBlockInserted callback.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jan 25, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 25, 2024

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This change makes the callback consistent with notifyOperationInserted: both now notify about IR insertion, not IR creation. See also #78988.

This change also simplifies the dialect conversion: it is no longer necessary to override the inlineRegionBefore method. All information that is necessary for rollback is provided with the notifyBlockInserted callback.


Full diff: https://github.com/llvm/llvm-project/pull/79472.diff

7 Files Affected:

  • (modified) mlir/include/mlir/IR/Builders.h (+9-1)
  • (modified) mlir/include/mlir/IR/PatternMatch.h (+5-4)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+2-6)
  • (modified) mlir/lib/IR/Builders.cpp (+1-1)
  • (modified) mlir/lib/IR/PatternMatch.cpp (+12-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+25-41)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+6-4)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 6b95be7c6d372f8..8c25a1aa2fad14a 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -297,7 +297,15 @@ class OpBuilder : public Builder {
     virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}
 
     /// Notify the listener that the specified block was inserted.
-    virtual void notifyBlockCreated(Block *block) {}
+    ///
+    /// * If the block was moved, then `previous` and `previousIt` are the
+    ///   previous location of the block.
+    /// * If the block was unlinked before it was inserted, then `previous`
+    ///   is "nullptr".
+    ///
+    /// Note: Creating an (unlinked) block does not trigger this notification.
+    virtual void notifyBlockInserted(Block *block, Region *previous,
+                                     Region::iterator previousIt) {}
 
   protected:
     Listener(Kind kind) : ListenerBase(kind) {}
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 7f233cd3f4d4b3c..8eb129206b95ef6 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -455,8 +455,9 @@ class RewriterBase : public OpBuilder {
     void notifyOperationInserted(Operation *op, InsertPoint previous) override {
       listener->notifyOperationInserted(op, previous);
     }
-    void notifyBlockCreated(Block *block) override {
-      listener->notifyBlockCreated(block);
+    void notifyBlockInserted(Block *block, Region *previous,
+                             Region::iterator previousIt) override {
+      listener->notifyBlockInserted(block, previous, previousIt);
     }
     void notifyBlockRemoved(Block *block) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
@@ -495,8 +496,8 @@ class RewriterBase : public OpBuilder {
   /// another region "parent". The two regions must be different. The caller
   /// is responsible for creating or updating the operation transferring flow
   /// of control to the region and passing it the correct block arguments.
-  virtual void inlineRegionBefore(Region &region, Region &parent,
-                                  Region::iterator before);
+  void inlineRegionBefore(Region &region, Region &parent,
+                          Region::iterator before);
   void inlineRegionBefore(Region &region, Block *before);
 
   /// Clone the blocks that belong to "region" before the given position in
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 32c5937d014e9ef..d9470de9ceb9f56 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -713,7 +713,8 @@ class ConversionPatternRewriter final : public PatternRewriter,
   void eraseBlock(Block *block) override;
 
   /// PatternRewriter hook creating a new block.
-  void notifyBlockCreated(Block *block) override;
+  void notifyBlockInserted(Block *block, Region *previous,
+                           Region::iterator previousIt) override;
 
   /// PatternRewriter hook for splitting a block into two parts.
   Block *splitBlock(Block *block, Block::iterator before) override;
@@ -723,11 +724,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
                          ValueRange argValues = std::nullopt) override;
   using PatternRewriter::inlineBlockBefore;
 
-  /// PatternRewriter hook for moving blocks out of a region.
-  void inlineRegionBefore(Region &region, Region &parent,
-                          Region::iterator before) override;
-  using PatternRewriter::inlineRegionBefore;
-
   /// PatternRewriter hook for cloning blocks of one region into another. The
   /// given region to clone *must* not have been modified as part of conversion
   /// yet, i.e. it must be within an operation that is either in the process of
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index a319afcdc6a9a23..7acef1073c6de20 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -429,7 +429,7 @@ Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
   setInsertionPointToEnd(b);
 
   if (listener)
-    listener->notifyBlockCreated(b);
+    listener->notifyBlockInserted(b, /*previous=*/nullptr, /*previousIt=*/{});
   return b;
 }
 
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index affb8898fa07544..817bbb363e0d585 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -343,7 +343,18 @@ Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
 /// region and pass it the correct block arguments.
 void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
                                       Region::iterator before) {
-  parent.getBlocks().splice(before, region.getBlocks());
+  // Fast path: If no listener is attached, move all blocks at once.
+  if (!listener) {
+    parent.getBlocks().splice(before, region.getBlocks());
+    return;
+  }
+
+  // Move blocks from the beginning of the region one-by-one.
+  while (!region.empty()) {
+    Block *block = &region.front();
+    parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
+    listener->notifyBlockInserted(block, &region, region.begin());
+  }
 }
 void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
   inlineRegionBefore(region, *before->getParent(), before->getIterator());
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f5bede2b94f9cb2..a79e9076fc28faf 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -250,10 +250,10 @@ enum class BlockActionKind {
 };
 
 /// Original position of the given block in its parent region. During undo
-/// actions, the block needs to be placed after `insertAfterBlock`.
+/// actions, the block needs to be placed before `insertBeforeBlock`.
 struct BlockPosition {
   Region *region;
-  Block *insertAfterBlock;
+  Block *insertBeforeBlock;
 };
 
 /// Information needed to undo inlining actions.
@@ -910,7 +910,8 @@ struct ConversionPatternRewriterImpl {
   void notifyBlockIsBeingErased(Block *block);
 
   /// Notifies that a block was created.
-  void notifyCreatedBlock(Block *block);
+  void notifyInsertedBlock(Block *block, Region *previous,
+                           Region::iterator previousIt);
 
   /// Notifies that a block was split.
   void notifySplitBlock(Block *block, Block *continuation);
@@ -919,10 +920,6 @@ struct ConversionPatternRewriterImpl {
   void notifyBlockBeingInlined(Block *block, Block *srcBlock,
                                Block::iterator before);
 
-  /// Notifies that the blocks of a region are about to be moved.
-  void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
-                                        Region::iterator before);
-
   /// Notifies that a pattern match failed for the given reason.
   LogicalResult
   notifyMatchFailure(Location loc,
@@ -1173,10 +1170,9 @@ void ConversionPatternRewriterImpl::undoBlockActions(
     // Put the block (owned by action) back into its original position.
     case BlockActionKind::Erase: {
       auto &blockList = action.originalPosition.region->getBlocks();
-      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
-      blockList.insert((insertAfterBlock
-                            ? std::next(Region::iterator(insertAfterBlock))
-                            : blockList.begin()),
+      Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
+      blockList.insert((insertBeforeBlock ? Region::iterator(insertBeforeBlock)
+                                          : blockList.end()),
                        action.block);
       break;
     }
@@ -1196,10 +1192,10 @@ void ConversionPatternRewriterImpl::undoBlockActions(
     // Move the block back to its original position.
     case BlockActionKind::Move: {
       Region *originalRegion = action.originalPosition.region;
-      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
+      Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
       originalRegion->getBlocks().splice(
-          (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
-                            : originalRegion->end()),
+          (insertBeforeBlock ? Region::iterator(insertBeforeBlock)
+                             : originalRegion->end()),
           action.block->getParent()->getBlocks(), action.block);
       break;
     }
@@ -1398,12 +1394,19 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
 
 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
   Region *region = block->getParent();
-  Block *origPrevBlock = block->getPrevNode();
-  blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
+  Block *origNextBlock = block->getNextNode();
+  blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
 }
 
-void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
-  blockActions.push_back(BlockAction::getCreate(block));
+void ConversionPatternRewriterImpl::notifyInsertedBlock(
+    Block *block, Region *previous, Region::iterator previousIt) {
+  if (!previous) {
+    // This is a newly created block.
+    blockActions.push_back(BlockAction::getCreate(block));
+    return;
+  }
+  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
+  blockActions.push_back(BlockAction::getMove(block, {previous, prevBlock}));
 }
 
 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
@@ -1416,19 +1419,6 @@ void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
   blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
 }
 
-void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
-    Region &region, Region &parent, Region::iterator before) {
-  if (region.empty())
-    return;
-  Block *laterBlock = &region.back();
-  for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
-    blockActions.push_back(
-        BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
-    laterBlock = &earlierBlock;
-  }
-  blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
-}
-
 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
@@ -1551,8 +1541,9 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
                            results);
 }
 
-void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
-  impl->notifyCreatedBlock(block);
+void ConversionPatternRewriter::notifyBlockInserted(
+    Block *block, Region *previous, Region::iterator previousIt) {
+  impl->notifyInsertedBlock(block, previous, previousIt);
 }
 
 Block *ConversionPatternRewriter::splitBlock(Block *block,
@@ -1582,13 +1573,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
   eraseBlock(source);
 }
 
-void ConversionPatternRewriter::inlineRegionBefore(Region &region,
-                                                   Region &parent,
-                                                   Region::iterator before) {
-  impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
-  PatternRewriter::inlineRegionBefore(region, parent, before);
-}
-
 void ConversionPatternRewriter::cloneRegionBefore(Region &region,
                                                   Region &parent,
                                                   Region::iterator before,
@@ -1600,7 +1584,7 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,
 
   for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
     Block *cloned = mapping.lookup(&b);
-    impl->notifyCreatedBlock(cloned);
+    impl->notifyInsertedBlock(cloned, /*previous=*/nullptr, /*previousIt=*/{});
     cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
         [&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
   }
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index c27fee7a738eba0..543dab0f309136f 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -377,8 +377,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// simplifications.
   void addOperandsToWorklist(ValueRange operands);
 
-  /// Notify the driver that the given block was created.
-  void notifyBlockCreated(Block *block) override;
+  /// Notify the driver that the given block was inserted.
+  void notifyBlockInserted(Block *block, Region *previous,
+                           Region::iterator previousIt) override;
 
   /// Notify the driver that the given block is about to be removed.
   void notifyBlockRemoved(Block *block) override;
@@ -638,9 +639,10 @@ void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
     worklist.push(op);
 }
 
-void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
+void GreedyPatternRewriteDriver::notifyBlockInserted(
+    Block *block, Region *previous, Region::iterator previousIt) {
   if (config.listener)
-    config.listener->notifyBlockCreated(block);
+    config.listener->notifyBlockInserted(block, previous, previousIt);
 }
 
 void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Jan 26, 2024
This change makes the callback consistent with `notifyOperationInserted`: both now notify about IR insertion, not IR creation. See also llvm#78988.

This change also simplifies the dialect conversion: it is no longer necessary to override the `inlineRegionBefore` method. All information that is necessary for rollback is provided with the `notifyBlockInserted` callback.
@matthias-springer matthias-springer merged commit 3ed98cb into llvm:main Jan 26, 2024
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants