Skip to content

Commit

Permalink
[mlir][IR] Add rewriter API for moving operations (#78988)
Browse files Browse the repository at this point in the history
The pattern rewriter documentation states that "*all* IR mutations [...]
are required to be performed via the `PatternRewriter`." This commit
adds two functions that were missing from the rewriter API:
`moveOpBefore` and `moveOpAfter`.

After an operation was moved, the `notifyOperationInserted` callback is
triggered. This allows listeners such as the greedy pattern rewrite
driver to react to IR changes.

This commit narrows the discrepancy between the kind of IR modification
that can be performed and the kind of IR modifications that can be
listened to.
  • Loading branch information
matthias-springer committed Jan 25, 2024
1 parent 45fec0c commit 5cc0f76
Show file tree
Hide file tree
Showing 20 changed files with 138 additions and 41 deletions.
6 changes: 5 additions & 1 deletion flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
LLVM_DUMP_METHOD void dumpFunc();

/// FirOpBuilder hook for creating new operation.
void notifyOperationInserted(mlir::Operation *op) override {
void notifyOperationInserted(mlir::Operation *op,
mlir::OpBuilder::InsertPoint previous) override {
// We only care about newly created operations.
if (previous.isSet())
return;
setCommonAttributes(op);
}

Expand Down
7 changes: 4 additions & 3 deletions flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,9 +730,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
HLFIRListener(fir::FirOpBuilder &builder,
mlir::ConversionPatternRewriter &rewriter)
: builder{builder}, rewriter{rewriter} {}
void notifyOperationInserted(mlir::Operation *op) override {
builder.notifyOperationInserted(op);
rewriter.notifyOperationInserted(op);
void notifyOperationInserted(mlir::Operation *op,
mlir::OpBuilder::InsertPoint previous) override {
builder.notifyOperationInserted(op, previous);
rewriter.notifyOperationInserted(op, previous);
}
virtual void notifyBlockCreated(mlir::Block *block) override {
builder.notifyBlockCreated(block);
Expand Down
20 changes: 13 additions & 7 deletions mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class Builder {
/// automatically inserted at an insertion point. The builder is copyable.
class OpBuilder : public Builder {
public:
class InsertPoint;
struct Listener;

/// Create a builder with the given context.
Expand Down Expand Up @@ -285,12 +286,17 @@ class OpBuilder : public Builder {

virtual ~Listener() = default;

/// Notification handler for when an operation is inserted into the builder.
/// `op` is the operation that was inserted.
virtual void notifyOperationInserted(Operation *op) {}

/// Notification handler for when a block is created using the builder.
/// `block` is the block that was created.
/// Notify the listener that the specified operation was inserted.
///
/// * If the operation was moved, then `previous` is the previous location
/// of the op.
/// * If the operation was unlinked before it was inserted, then `previous`
/// is empty.
///
/// Note: Creating an (unlinked) op does not trigger this notification.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}

/// Notify the listener that the specified block was inserted.
virtual void notifyBlockCreated(Block *block) {}

protected:
Expand Down Expand Up @@ -517,7 +523,7 @@ class OpBuilder : public Builder {
if (succeeded(tryFold(op, results)))
op->erase();
else if (listener)
listener->notifyOperationInserted(op);
listener->notifyOperationInserted(op, /*previous=*/{});
}

/// Overload to create or fold a single result operation.
Expand Down
26 changes: 24 additions & 2 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,8 @@ class RewriterBase : public OpBuilder {

/// Notify the listener that the specified operation is about to be erased.
/// At this point, the operation has zero uses.
///
/// Note: This notification is not triggered when unlinking an operation.
virtual void notifyOperationRemoved(Operation *op) {}

/// Notify the listener that the pattern failed to match the given
Expand All @@ -450,8 +452,8 @@ class RewriterBase : public OpBuilder {
struct ForwardingListener : public RewriterBase::Listener {
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}

void notifyOperationInserted(Operation *op) override {
listener->notifyOperationInserted(op);
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
listener->notifyOperationInserted(op, previous);
}
void notifyBlockCreated(Block *block) override {
listener->notifyBlockCreated(block);
Expand Down Expand Up @@ -591,6 +593,26 @@ class RewriterBase : public OpBuilder {
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before);

/// Unlink this operation from its current block and insert it right before
/// `existingOp` which may be in the same or another block in the same
/// function.
void moveOpBefore(Operation *op, Operation *existingOp);

/// Unlink this operation from its current block and insert it right before
/// `iterator` in the specified block.
virtual void moveOpBefore(Operation *op, Block *block,
Block::iterator iterator);

/// Unlink this operation from its current block and insert it right after
/// `existingOp` which may be in the same or another block in the same
/// function.
void moveOpAfter(Operation *op, Operation *existingOp);

/// Unlink this operation from its current block and insert it right after
/// `iterator` in the specified block.
virtual void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator);

/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeOpModification` or
Expand Down
8 changes: 7 additions & 1 deletion mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ class ConversionPatternRewriter final : public PatternRewriter,
using PatternRewriter::cloneRegionBefore;

/// PatternRewriter hook for inserting a new operation.
void notifyOperationInserted(Operation *op) override;
void notifyOperationInserted(Operation *op, InsertPoint previous) override;

/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
Expand All @@ -761,9 +761,15 @@ class ConversionPatternRewriter final : public PatternRewriter,
detail::ConversionPatternRewriterImpl &getImpl();

private:
// Hide unsupported pattern rewriter API.
using OpBuilder::getListener;
using OpBuilder::setListener;

void moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) override;
void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) override;

std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
if (failed(applyOp->fold(constOperands, foldResults)) ||
foldResults.empty()) {
if (OpBuilder::Listener *listener = b.getListener())
listener->notifyOperationInserted(applyOp);
listener->notifyOperationInserted(applyOp, /*previous=*/{});
return applyOp.getResult();
}

Expand Down Expand Up @@ -1274,7 +1274,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
if (failed(minMaxOp->fold(constOperands, foldResults)) ||
foldResults.empty()) {
if (OpBuilder::Listener *listener = b.getListener())
listener->notifyOperationInserted(minMaxOp);
listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
return minMaxOp.getResult();
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ static ParallelComputeFunction createParallelComputeFunction(
// Insert function into the module symbol table and assign it unique name.
SymbolTable symbolTable(module);
symbolTable.insert(func);
rewriter.getListener()->notifyOperationInserted(func);
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

// Create function entry block.
Block *block =
Expand Down Expand Up @@ -489,7 +489,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
// Insert function into the module symbol table and assign it unique name.
SymbolTable symbolTable(module);
symbolTable.insert(func);
rewriter.getListener()->notifyOperationInserted(func);
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

// Create function entry block.
Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
Expand Down
6 changes: 5 additions & 1 deletion mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,11 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
toMemrefOps.erase(op);
}

void notifyOperationInserted(Operation *op) override {
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
// We only care about newly created ops.
if (previous.isSet())
return;

erasedOps.erase(op);

// Gather statistics about allocs.
Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,12 @@ class NewOpsListener : public RewriterBase::ForwardingListener {
}

private:
void notifyOperationInserted(Operation *op) override {
ForwardingListener::notifyOperationInserted(op);
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
ForwardingListener::notifyOperationInserted(op, previous);
// We only care about newly created ops.
if (previous.isSet())
return;
auto inserted = newOps.insert(op);
(void)inserted;
assert(inserted.second && "expected newly created op");
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {

// Inline for-loop body operations into 'after' region.
for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
arg.moveBefore(afterBlock, afterBlock->end());
rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());

// Add incremented IV to yield operations
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface
for (Operation *user : srcBuffer->getUsers()) {
if (hasEffect<MemoryEffects::Free>(user)) {
if (user->getBlock() == parallelCombiningParent->getBlock())
user->moveBefore(user->getBlock()->getTerminator());
rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
break;
}
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ Operation *OpBuilder::insert(Operation *op) {
block->getOperations().insert(insertPoint, op);

if (listener)
listener->notifyOperationInserted(op);
listener->notifyOperationInserted(op, /*previous=*/{});
return op;
}

Expand Down Expand Up @@ -530,7 +530,7 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
// about any ops that got inserted inside those regions as part of cloning.
if (listener) {
auto walkFn = [&](Operation *walkedOp) {
listener->notifyOperationInserted(walkedOp);
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
};
for (Region &region : newOp->getRegions())
region.walk<WalkOrder::PreOrder>(walkFn);
Expand Down
28 changes: 28 additions & 0 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,31 @@ void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}

void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
}

void RewriterBase::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
Block *currentBlock = op->getBlock();
Block::iterator currentIterator = op->getIterator();
op->moveBefore(block, iterator);
if (listener)
listener->notifyOperationInserted(
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
}

void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
}

void RewriterBase::moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) {
Block *currentBlock = op->getBlock();
Block::iterator currentIterator = op->getIterator();
op->moveAfter(block, iterator);
if (listener)
listener->notifyOperationInserted(
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
}
18 changes: 16 additions & 2 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1602,11 +1602,13 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,
Block *cloned = mapping.lookup(&b);
impl->notifyCreatedBlock(cloned);
cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
[&](Operation *op) { notifyOperationInserted(op); });
[&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
}
}

void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
InsertPoint previous) {
assert(!previous.isSet() && "expected newly created op");
LLVM_DEBUG({
impl->logger.startLine()
<< "** Insert : '" << op->getName() << "'(" << op << ")\n";
Expand Down Expand Up @@ -1651,6 +1653,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure(
return impl->notifyMatchFailure(loc, reasonCallback);
}

void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
llvm_unreachable(
"moving single ops is not supported in a dialect conversion");
}

void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) {
llvm_unreachable(
"moving single ops is not supported in a dialect conversion");
}

detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}
Expand Down
18 changes: 13 additions & 5 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,16 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
}
}

void notifyOperationInserted(Operation *op) override {
RewriterBase::ForwardingListener::notifyOperationInserted(op);
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
// Invalidate the finger print of the op that owns the block into which the
// op was inserted into.
invalidateFingerPrint(op->getParentOp());

// Also invalidate the finger print of the op that owns the block from which
// the op was moved from. (Only applicable if the op was moved.)
if (previous.isSet())
invalidateFingerPrint(previous.getBlock()->getParentOp());
}

void notifyOperationModified(Operation *op) override {
Expand Down Expand Up @@ -331,7 +338,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// Notify the driver that the specified operation was inserted. Update the
/// worklist as needed: The operation is enqueued depending on scope and
/// strict mode.
void notifyOperationInserted(Operation *op) override;
void notifyOperationInserted(Operation *op, InsertPoint previous) override;

/// Notify the driver that the specified operation was removed. Update the
/// worklist as needed: The operation and its children are removed from the
Expand Down Expand Up @@ -641,13 +648,14 @@ void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
config.listener->notifyBlockRemoved(block);
}

void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
InsertPoint previous) {
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
if (config.listener)
config.listener->notifyOperationInserted(op);
config.listener->notifyOperationInserted(op, previous);
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
strictModeFilteredOps.insert(op);
addToWorklist(op);
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
OpResult newLoopResult = loopLike.getLoopResults()->back();
extractionOp->moveBefore(loopLike);
insertionOp->moveAfter(loopLike);
rewriter.moveOpBefore(extractionOp, loopLike);
rewriter.moveOpAfter(insertionOp, loopLike);
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
insertionOp.getDestinationOperand().get());
extractionOp.getSourceOperand().set(
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ struct TestSCFPipeliningPass
auto ifOp =
rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
// True branch.
op->moveBefore(&ifOp.getThenRegion().front(),
ifOp.getThenRegion().front().begin());
rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
ifOp.getThenRegion().front().begin());
rewriter.setInsertionPointAfter(op);
if (op->getNumResults() > 0)
rewriter.create<scf::YieldOp>(loc, op->getResults());
Expand Down
4 changes: 1 addition & 3 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
return failure();
if (!toBeHoisted->hasAttr("eligible"))
return failure();
// Hoisting means removing an op from the enclosing op. I.e., the enclosing
// op is modified.
rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
rewriter.moveOpBefore(toBeHoisted, op);
return success();
}
};
Expand Down
3 changes: 2 additions & 1 deletion mlir/test/lib/IR/TestClone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ using namespace mlir;
namespace {

struct DumpNotifications : public OpBuilder::Listener {
void notifyOperationInserted(Operation *op) override {
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
llvm::outs() << "notifyOperationInserted: " << op->getName() << "\n";
}
};
Expand Down

0 comments on commit 5cc0f76

Please sign in to comment.