Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Operation clone is currently faulty. Suppose you have a block like as follows: ``` (%x0 : i32) { %x1 = f(%x0) return %x1 } ``` The test case we have is that we want to "unroll" this, in which we want to change this to compute `f(f(x0))` instead of just `f(x0)`. We do so by making a copy of the body at the end of the block and set the uses of the argument in the copy operations with the value returned from the original block. This is implemented as follows: 1) map to the block arguments to the returned value (`map[x0] = x1`). 2) clone the body Now for this small example, this works as intended and we get the following. ``` (%x0 : i32) { %x1 = f(%x0) %x2 = f(%x1) return %x2 } ``` This is because the current logic to clone `x1 = f(x0)` first looks up the arguments in the map (which finds `x0` maps to `x1` from the initialization), and then sets the map of the result to the cloned result (`map[x1] = x2`). However, this fails if `x0` is not an argument to the op, but instead used inside the region, like below. ``` (%x0 : i32) { %x1 = f() { yield %x0 } return %x1 } ``` This is because cloning an op currently first looks up the args (none), sets the map of the result (`map[%x1] = %x2`), and then clones the regions. This results in the following, which is clearly illegal: ``` (%x0 : i32) { %x1 = f() { yield %x0 } %x2 = f() { yield %x2 } return %x2 } ``` Diving deeper, this is partially due to the ordering (how this PR fixes it), as well as how region cloning works. Namely it will first clone with the mapping, and then it will remap all operands. Since the ordering above now has a map of `x0 -> x1` and `x1 -> x2`, we end up with the incorrect behavior here. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D122531
- Loading branch information
Showing
6 changed files
with
101 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="func.func(test-clone)" -split-input-file | ||
|
||
module { | ||
func @fixpoint(%arg1 : i32) -> i32 { | ||
%r = "test.use"(%arg1) ({ | ||
"test.yield"(%arg1) : (i32) -> () | ||
}) : (i32) -> i32 | ||
return %r : i32 | ||
} | ||
} | ||
|
||
// CHECK: func @fixpoint(%[[arg0:.+]]: i32) -> i32 { | ||
// CHECK-NEXT: %[[i0:.+]] = "test.use"(%[[arg0]]) ({ | ||
// CHECK-NEXT: "test.yield"(%arg0) : (i32) -> () | ||
// CHECK-NEXT: }) : (i32) -> i32 | ||
// CHECK-NEXT: %[[i1:.+]] = "test.use"(%[[i0]]) ({ | ||
// CHECK-NEXT: "test.yield"(%[[i0]]) : (i32) -> () | ||
// CHECK-NEXT: }) : (i32) -> i32 | ||
// CHECK-NEXT: return %[[i1]] : i32 | ||
// CHECK-NEXT: } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
//===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "TestDialect.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/Pass/Pass.h" | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
|
||
/// This is a test pass which clones the body of a function. Specifically | ||
/// this pass replaces f(x) to instead return f(f(x)) in which the cloned body | ||
/// takes the result of the first operation return as an input. | ||
struct ClonePass | ||
: public PassWrapper<ClonePass, InterfacePass<FunctionOpInterface>> { | ||
StringRef getArgument() const final { return "test-clone"; } | ||
StringRef getDescription() const final { return "Test clone of op"; } | ||
void runOnOperation() override { | ||
FunctionOpInterface op = getOperation(); | ||
|
||
// Limit testing to ops with only one region. | ||
if (op->getNumRegions() != 1) | ||
return; | ||
|
||
Region ®ion = op->getRegion(0); | ||
if (!region.hasOneBlock()) | ||
return; | ||
|
||
Block ®ionEntry = region.front(); | ||
auto terminator = regionEntry.getTerminator(); | ||
|
||
// Only handle functions whose returns match the inputs. | ||
if (terminator->getNumOperands() != regionEntry.getNumArguments()) | ||
return; | ||
|
||
BlockAndValueMapping map; | ||
for (auto tup : | ||
llvm::zip(terminator->getOperands(), regionEntry.getArguments())) { | ||
if (std::get<0>(tup).getType() != std::get<1>(tup).getType()) | ||
return; | ||
map.map(std::get<1>(tup), std::get<0>(tup)); | ||
} | ||
|
||
OpBuilder B(op->getContext()); | ||
B.setInsertionPointToEnd(®ionEntry); | ||
SmallVector<Operation *> toClone; | ||
for (Operation &inst : regionEntry) | ||
toClone.push_back(&inst); | ||
for (Operation *inst : toClone) | ||
B.clone(*inst, map); | ||
terminator->erase(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
namespace mlir { | ||
void registerCloneTestPasses() { PassRegistration<ClonePass>(); } | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters