Skip to content

Commit

Permalink
[mlir][scf] More WhileOp canonicalizations
Browse files Browse the repository at this point in the history
Remove duplicated ConditonOp args, remove unused init/yield args.

Differential Revision: https://reviews.llvm.org/D146252
  • Loading branch information
Hardcode84 committed Apr 12, 2023
1 parent 63c3895 commit e78d341
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 1 deletion.
118 changes: 117 additions & 1 deletion mlir/lib/Dialect/SCF/IR/SCF.cpp
Expand Up @@ -3785,13 +3785,129 @@ struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
return success();
}
};

/// Remove duplicated ConditionOp args.
struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
Block &beforeBlock = op.getBefore().front();
Block &afterBlock = op.getAfter().front();

auto condOp = cast<ConditionOp>(beforeBlock.getTerminator());
ValueRange condOpArgs = condOp.getArgs();
llvm::SmallDenseMap<Value, unsigned> argsMap;
SmallVector<Value> newArgs;
for (auto arg : condOpArgs) {
if (!argsMap.count(arg)) {
auto pos = static_cast<unsigned>(argsMap.size());
argsMap.insert({arg, pos});
newArgs.emplace_back(arg);
}
}

if (argsMap.size() == condOpArgs.size())
return rewriter.notifyMatchFailure(op, "No results to remove");

ValueRange argsRange(newArgs);
auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {
// Nothing
};

Location loc = op.getLoc();
auto newWhileOp = rewriter.create<scf::WhileOp>(
loc, argsRange.getTypes(), op.getInits(), emptyBuilder, emptyBuilder);
Block &newBeforeBlock = newWhileOp.getBefore().front();
Block &newAfterBlock = newWhileOp.getAfter().front();

SmallVector<Value> afterArgsMapping;
SmallVector<Value> resultsMapping;
for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
auto it = argsMap.find(arg);
assert(it != argsMap.end());
auto pos = it->second;
afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
resultsMapping.emplace_back(newWhileOp->getResult(pos));
}

OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(condOp);
rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
argsRange);

rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
newBeforeBlock.getArguments());
rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
rewriter.replaceOp(op, resultsMapping);
return success();
}
};

/// Remove unused init/yield args.
struct WhileRemoveUnusedArgs : public mlir::OpRewritePattern<WhileOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(WhileOp op,
PatternRewriter &rewriter) const override {
Block &beforeBlock = op.getBefore().front();
Block &afterBlock = op.getAfter().front();

auto yield = cast<YieldOp>(afterBlock.getTerminator());

llvm::BitVector argsToRemove;
SmallVector<Value> newInits;
SmallVector<Value> newYieldArgs;

bool changed = false;
for (auto &&[arg, init, yieldArg] : llvm::zip(
beforeBlock.getArguments(), op.getInits(), yield.getResults())) {
bool empty = arg.use_empty();
argsToRemove.push_back(empty);
if (empty) {
changed = true;
continue;
}

newInits.emplace_back(init);
newYieldArgs.emplace_back(yieldArg);
}

if (!changed)
return rewriter.notifyMatchFailure(op, "No args to remove");

beforeBlock.eraseArguments(argsToRemove);

auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {
// Nothing
};

Location loc = op.getLoc();
auto newWhileOp = rewriter.create<WhileOp>(
loc, op->getResultTypes(), newInits, emptyBuilder, emptyBuilder);
Block &newBeforeBlock = newWhileOp.getBefore().front();
Block &newAfterBlock = newWhileOp.getAfter().front();

OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(yield);
rewriter.replaceOpWithNewOp<YieldOp>(yield, newYieldArgs);

rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
newBeforeBlock.getArguments());
rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
newAfterBlock.getArguments());
rewriter.replaceOp(op, newWhileOp.getResults());
return success();
}
};
} // namespace

void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
RemoveLoopInvariantValueYielded, WhileConditionTruth,
WhileCmpCond, WhileUnusedResult>(context);
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
WhileRemoveUnusedArgs>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
54 changes: 54 additions & 0 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Expand Up @@ -1195,6 +1195,60 @@ func.func @while_cmp_rhs(%arg0 : i32) {
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }

// -----

// CHECK-LABEL: @while_duplicated_res
func.func @while_duplicated_res() -> (i32, i32) {
%0:2 = scf.while () : () -> (i32, i32) {
%val = "test.val"() : () -> i32
%condition = "test.condition"() : () -> i1
scf.condition(%condition) %val, %val : i32, i32
} do {
^bb0(%val2: i32, %val3: i32):
"test.use"(%val2, %val3) : (i32, i32) -> ()
scf.yield
}
return %0#0, %0#1: i32, i32
}
// CHECK: %[[RES:.*]] = scf.while : () -> i32 {
// CHECK: %[[VAL:.*]] = "test.val"() : () -> i32
// CHECK: %[[COND:.*]] = "test.condition"() : () -> i1
// CHECK: scf.condition(%[[COND]]) %[[VAL]] : i32
// CHECK: } do {
// CHECK: ^bb0(%[[ARG:.*]]: i32):
// CHECK: "test.use"(%[[ARG]], %[[ARG]]) : (i32, i32) -> ()
// CHECK: scf.yield
// CHECK: }
// CHECK: return %[[RES]], %[[RES]] : i32, i32

// -----

// CHECK-LABEL: @while_unused_arg
func.func @while_unused_arg(%val0: i32) -> i32 {
%0 = scf.while (%val1 = %val0) : (i32) -> i32 {
%val = "test.val"() : () -> i32
%condition = "test.condition"() : () -> i1
scf.condition(%condition) %val: i32
} do {
^bb0(%val2: i32):
"test.use"(%val2) : (i32) -> ()
%val1 = "test.val1"() : () -> i32
scf.yield %val1 : i32
}
return %0 : i32
}
// CHECK: %[[RES:.*]] = scf.while : () -> i32 {
// CHECK: %[[VAL:.*]] = "test.val"() : () -> i32
// CHECK: %[[COND:.*]] = "test.condition"() : () -> i1
// CHECK: scf.condition(%[[COND]]) %[[VAL]] : i32
// CHECK: } do {
// CHECK: ^bb0(%[[ARG:.*]]: i32):
// CHECK: "test.use"(%[[ARG]]) : (i32) -> ()
// CHECK: scf.yield
// CHECK: }
// CHECK: return %[[RES]] : i32


// -----

// CHECK-LABEL: @combineIfs
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Expand Up @@ -518,6 +518,8 @@ func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>,
%arg2: index) {
scf.while (%arg3 = %arg1) : (tensor<5xi1>) -> () {
%0 = tensor.extract %arg0[%arg2] : tensor<5xi1>
%1 = tensor.extract %arg3[%arg2] : tensor<5xi1>
"dummy.use"(%1) : (i1) -> ()
scf.condition(%0)
} do {
%0 = "dummy.some_op"() : () -> index
Expand Down

0 comments on commit e78d341

Please sign in to comment.