Skip to content

Commit

Permalink
[mlir][scf] WhileOp patterns cleanup
Browse files Browse the repository at this point in the history
Fix review comments from https://reviews.llvm.org/D146252
Merge `WhileRemoveUnusedArgs` pattern with (unused) `WhileUnusedArg`,
use `getConditionOp`, use `SmallPtrSet` and early check, move tests

Differential Revision: https://reviews.llvm.org/D148256
  • Loading branch information
Hardcode84 committed Apr 14, 2023
1 parent 684914f commit c97b7bc
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 106 deletions.
125 changes: 45 additions & 80 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
Expand Down Expand Up @@ -3738,50 +3739,61 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
}
};

struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
/// Remove unused init/yield args.
struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;

LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {

if (!llvm::any_of(op.getBeforeArguments(),
[](Value arg) { return arg.use_empty(); }))
return failure();
return rewriter.notifyMatchFailure(op, "No args to remove");

YieldOp yield = op.getYieldOp();

// Collect results mapping, new terminator args and new result types.
SmallVector<Value> newYields;
SmallVector<Value> newInits;
llvm::BitVector argsToErase(op.getBeforeArguments().size());
for (const auto &it : llvm::enumerate(llvm::zip(
op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
Value beforeArg = std::get<0>(it.value());
Value yieldValue = std::get<1>(it.value());
Value initValue = std::get<2>(it.value());
llvm::BitVector argsToErase;

size_t argsCount = op.getBeforeArguments().size();
newYields.reserve(argsCount);
newInits.reserve(argsCount);
argsToErase.reserve(argsCount);
for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
if (beforeArg.use_empty()) {
argsToErase.set(it.index());
argsToErase.push_back(true);
} else {
argsToErase.push_back(false);
newYields.emplace_back(yieldValue);
newInits.emplace_back(initValue);
}
}

if (argsToErase.none())
return failure();
Block &beforeBlock = op.getBefore().front();
Block &afterBlock = op.getAfter().front();

rewriter.startRootUpdate(op);
op.getBefore().front().eraseArguments(argsToErase);
rewriter.finalizeRootUpdate(op);
beforeBlock.eraseArguments(argsToErase);

WhileOp replacement =
rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
replacement.getBefore().takeBody(op.getBefore());
replacement.getAfter().takeBody(op.getAfter());
rewriter.replaceOp(op, replacement.getResults());
Location loc = op.getLoc();
auto newWhileOp =
rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
/*beforeBody*/ nullptr, /*afterBody*/ nullptr);
Block &newBeforeBlock = newWhileOp.getBefore().front();
Block &newAfterBlock = newWhileOp.getAfter().front();

OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(yield);
rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);

rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
newBeforeBlock.getArguments());
rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
newAfterBlock.getArguments());

rewriter.replaceOp(op, newWhileOp.getResults());
return success();
}
};
Expand All @@ -3792,24 +3804,28 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {

LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
Block &beforeBlock = op.getBefore().front();
Block &afterBlock = op.getAfter().front();

auto condOp = cast<ConditionOp>(beforeBlock.getTerminator());
ConditionOp condOp = op.getConditionOp();
ValueRange condOpArgs = condOp.getArgs();

llvm::SmallPtrSet<Value, 8> argsSet;
for (Value arg : condOpArgs)
argsSet.insert(arg);

if (argsSet.size() == condOpArgs.size())
return rewriter.notifyMatchFailure(op, "No results to remove");

llvm::SmallDenseMap<Value, unsigned> argsMap;
SmallVector<Value> newArgs;
for (auto arg : condOpArgs) {
argsMap.reserve(condOpArgs.size());
newArgs.reserve(condOpArgs.size());
for (Value arg : condOpArgs) {
if (!argsMap.count(arg)) {
auto pos = static_cast<unsigned>(argsMap.size());
argsMap.insert({arg, pos});
newArgs.emplace_back(arg);
}
}

if (argsMap.size() == condOpArgs.size())
return rewriter.notifyMatchFailure(op, "No results to remove");

ValueRange argsRange(newArgs);

Location loc = op.getLoc();
Expand All @@ -3834,64 +3850,13 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
argsRange);

rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
newBeforeBlock.getArguments());
rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
rewriter.replaceOp(op, resultsMapping);
return success();
}
};

/// Remove unused init/yield args.
struct WhileRemoveUnusedArgs : public mlir::OpRewritePattern<WhileOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
Block &beforeBlock = op.getBefore().front();
Block &afterBlock = op.getAfter().front();

auto yield = cast<YieldOp>(afterBlock.getTerminator());

llvm::BitVector argsToRemove;
SmallVector<Value> newInits;
SmallVector<Value> newYieldArgs;

bool changed = false;
for (auto &&[arg, init, yieldArg] : llvm::zip(
beforeBlock.getArguments(), op.getInits(), yield.getResults())) {
bool empty = arg.use_empty();
argsToRemove.push_back(empty);
if (empty) {
changed = true;
continue;
}

newInits.emplace_back(init);
newYieldArgs.emplace_back(yieldArg);
}

if (!changed)
return rewriter.notifyMatchFailure(op, "No args to remove");

beforeBlock.eraseArguments(argsToRemove);

Location loc = op.getLoc();
auto newWhileOp =
rewriter.create<WhileOp>(loc, op->getResultTypes(), newInits,
/*beforeBody*/ nullptr, /*afterBody*/ nullptr);
Block &newBeforeBlock = newWhileOp.getBefore().front();
Block &newAfterBlock = newWhileOp.getAfter().front();

OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(yield);
rewriter.replaceOpWithNewOp<YieldOp>(yield, newYieldArgs);

rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
newBeforeBlock.getArguments());
rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
newAfterBlock.getArguments());
rewriter.replaceOp(op, newWhileOp.getResults());
rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
rewriter.replaceOp(op, resultsMapping);
return success();
}
};
Expand Down
54 changes: 28 additions & 26 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Expand Up @@ -1019,30 +1019,6 @@ func.func @while_cond_true() -> i1 {

// -----

// CHECK-LABEL: @while_unused_arg
func.func @while_unused_arg(%x : i32, %y : f64) -> i32 {
%0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) {
%condition = "test.condition"(%arg1) : (i32) -> i1
scf.condition(%condition) %arg1 : i32
} do {
^bb0(%arg1: i32):
%next = "test.use"(%arg1) : (i32) -> (i32)
scf.yield %next, %y : i32, f64
}
return %0 : i32
}
// CHECK-NEXT: %[[res:.*]] = scf.while (%[[arg2:.+]] = %{{.*}}) : (i32) -> i32 {
// CHECK-NEXT: %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[arg2]] : i32
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%[[post:.+]]: i32):
// CHECK-NEXT: %[[next:.+]] = "test.use"(%[[post]]) : (i32) -> i32
// CHECK-NEXT: scf.yield %[[next]] : i32
// CHECK-NEXT: }
// CHECK-NEXT: return %[[res]] : i32

// -----

// CHECK-LABEL: @invariant_loop_args_in_same_order
// CHECK-SAME: (%[[FUNC_ARG0:.*]]: tensor<i32>)
func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
Expand Down Expand Up @@ -1221,10 +1197,36 @@ func.func @while_duplicated_res() -> (i32, i32) {
// CHECK: }
// CHECK: return %[[RES]], %[[RES]] : i32, i32


// -----

// CHECK-LABEL: @while_unused_arg1
func.func @while_unused_arg1(%x : i32, %y : f64) -> i32 {
%0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) {
%condition = "test.condition"(%arg1) : (i32) -> i1
scf.condition(%condition) %arg1 : i32
} do {
^bb0(%arg1: i32):
%next = "test.use"(%arg1) : (i32) -> (i32)
scf.yield %next, %y : i32, f64
}
return %0 : i32
}
// CHECK-NEXT: %[[res:.*]] = scf.while (%[[arg2:.*]] = %{{.*}}) : (i32) -> i32 {
// CHECK-NEXT: %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[arg2]] : i32
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%[[post:.*]]: i32):
// CHECK-NEXT: %[[next:.*]] = "test.use"(%[[post]]) : (i32) -> i32
// CHECK-NEXT: scf.yield %[[next]] : i32
// CHECK-NEXT: }
// CHECK-NEXT: return %[[res]] : i32


// -----

// CHECK-LABEL: @while_unused_arg
func.func @while_unused_arg(%val0: i32) -> i32 {
// CHECK-LABEL: @while_unused_arg2
func.func @while_unused_arg2(%val0: i32) -> i32 {
%0 = scf.while (%val1 = %val0) : (i32) -> i32 {
%val = "test.val"() : () -> i32
%condition = "test.condition"() : () -> i1
Expand Down

0 comments on commit c97b7bc

Please sign in to comment.