Skip to content

Commit

Permalink
[mlir] notify insertion of parent op first when cloning (#73806)
Browse files Browse the repository at this point in the history
When cloning an operation with a region, the builder was currently
notifying about the insertion of the cloned operations inside the region
before the cloned operation itself.

When using cloning inside rewrite pass, this could cause issues if a
pattern is expected to be applied on a cloned parent operation before
trying to apply patterns on the cloned operations it contains (the
patterns are attempted in order of notifications for the cloned
operations).
  • Loading branch information
jeanPerier committed Dec 1, 2023
1 parent d55692d commit 5a4ca51
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 7 deletions.
7 changes: 4 additions & 3 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,17 +527,18 @@ LogicalResult OpBuilder::tryFold(Operation *op,

Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
Operation *newOp = op.clone(mapper);
// The `insert` call below handles the notification for inserting `newOp`
newOp = insert(newOp);
// The `insert` call above handles the notification for inserting `newOp`
// itself. But if `newOp` has any regions, we need to notify the listener
// about any ops that got inserted inside those regions as part of cloning.
if (listener) {
auto walkFn = [&](Operation *walkedOp) {
listener->notifyOperationInserted(walkedOp);
};
for (Region &region : newOp->getRegions())
region.walk(walkFn);
region.walk<WalkOrder::PreOrder>(walkFn);
}
return insert(newOp);
return newOp;
}

Operation *OpBuilder::clone(Operation &op) {
Expand Down
23 changes: 19 additions & 4 deletions mlir/test/IR/test-clone.mlir
Original file line number Diff line number Diff line change
@@ -1,20 +1,35 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(test-clone))" -split-input-file
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(test-clone))" | FileCheck %s

module {
func.func @fixpoint(%arg1 : i32) -> i32 {
%r = "test.use"(%arg1) ({
"test.yield"(%arg1) : (i32) -> ()
%r2 = "test.use2"(%arg1) ({
"test.yield2"(%arg1) : (i32) -> ()
}) : (i32) -> i32
"test.yield"(%r2) : (i32) -> ()
}) : (i32) -> i32
return %r : i32
}
}

// CHECK: notifyOperationInserted: test.use
// CHECK-NEXT: notifyOperationInserted: test.use2
// CHECK-NEXT: notifyOperationInserted: test.yield2
// CHECK-NEXT: notifyOperationInserted: test.yield
// CHECK-NEXT: notifyOperationInserted: func.return

// CHECK: func @fixpoint(%[[arg0:.+]]: i32) -> i32 {
// CHECK-NEXT: %[[i0:.+]] = "test.use"(%[[arg0]]) ({
// CHECK-NEXT: "test.yield"(%arg0) : (i32) -> ()
// CHECK-NEXT: %[[r2:.+]] = "test.use2"(%[[arg0]]) ({
// CHECK-NEXT: "test.yield2"(%[[arg0]]) : (i32) -> ()
// CHECK-NEXT: }) : (i32) -> i32
// CHECK-NEXT: "test.yield"(%[[r2]]) : (i32) -> ()
// CHECK-NEXT: }) : (i32) -> i32
// CHECK-NEXT: %[[i1:.+]] = "test.use"(%[[i0]]) ({
// CHECK-NEXT: "test.yield"(%[[i0]]) : (i32) -> ()
// CHECK-NEXT: %[[r2:.+]] = "test.use2"(%[[i0]]) ({
// CHECK-NEXT: "test.yield2"(%[[i0]]) : (i32) -> ()
// CHECK-NEXT: }) : (i32) -> i32
// CHECK-NEXT: "test.yield"(%[[r2]]) : (i32) -> ()
// CHECK-NEXT: }) : (i32) -> i32
// CHECK-NEXT: return %[[i1]] : i32
// CHECK-NEXT: }
8 changes: 8 additions & 0 deletions mlir/test/lib/IR/TestClone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ using namespace mlir;

namespace {

struct DumpNotifications : public OpBuilder::Listener {
void notifyOperationInserted(Operation *op) override {
llvm::outs() << "notifyOperationInserted: " << op->getName() << "\n";
}
};

/// This is a test pass which clones the body of a function. Specifically
/// this pass replaces f(x) to instead return f(f(x)) in which the cloned body
/// takes the result of the first operation return as an input.
Expand Down Expand Up @@ -50,6 +56,8 @@ struct ClonePass
}

OpBuilder builder(op->getContext());
DumpNotifications dumpNotifications;
builder.setListener(&dumpNotifications);
builder.setInsertionPointToEnd(&regionEntry);
SmallVector<Operation *> toClone;
for (Operation &inst : regionEntry)
Expand Down

0 comments on commit 5a4ca51

Please sign in to comment.