Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MLIR][SCF] Add for-to-while loop transformation pass
This pass transforms SCF.ForOp operations to SCF.WhileOp. The For loop condition is placed in the 'before' region of the while operation, and indctuion variable incrementation + the loop body in the 'after' region. The loop carried values of the while op are the induction variable (IV) of the for-loop + any iter_args specified for the for-loop. Any 'yield' ops in the for-loop are rewritten to additionally yield the (incremented) induction variable. This transformation is useful for passes where we want to consider structured control flow solely on the basis of a loop body and the computation of a loop condition. As an example, when doing high-level synthesis in CIRCT, the incrementation of an IV in a for-loop is "just another part" of a circuit datapath, and what we really care about is the distinction between our datapath and our control logic (the condition variable). Differential Revision: https://reviews.llvm.org/D108454
- Loading branch information
Showing
5 changed files
with
297 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
//===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Transforms SCF.ForOp's into SCF.WhileOp's. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "PassDetail.h" | ||
#include "mlir/Dialect/SCF/Passes.h" | ||
#include "mlir/Dialect/SCF/SCF.h" | ||
#include "mlir/Dialect/SCF/Transforms.h" | ||
#include "mlir/Dialect/StandardOps/IR/Ops.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
using namespace llvm; | ||
using namespace mlir; | ||
using scf::ForOp; | ||
using scf::WhileOp; | ||
|
||
namespace { | ||
|
||
struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> { | ||
using OpRewritePattern<ForOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(ForOp forOp, | ||
PatternRewriter &rewriter) const override { | ||
// Generate type signature for the loop-carried values. The induction | ||
// variable is placed first, followed by the forOp.iterArgs. | ||
SmallVector<Type, 8> lcvTypes; | ||
lcvTypes.push_back(forOp.getInductionVar().getType()); | ||
llvm::transform(forOp.initArgs(), std::back_inserter(lcvTypes), | ||
[&](auto v) { return v.getType(); }); | ||
|
||
// Build scf.WhileOp | ||
SmallVector<Value> initArgs; | ||
initArgs.push_back(forOp.lowerBound()); | ||
llvm::append_range(initArgs, forOp.initArgs()); | ||
auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs, | ||
forOp->getAttrs()); | ||
|
||
// 'before' region contains the loop condition and forwarding of iteration | ||
// arguments to the 'after' region. | ||
auto *beforeBlock = rewriter.createBlock( | ||
&whileOp.before(), whileOp.before().begin(), lcvTypes, {}); | ||
rewriter.setInsertionPointToStart(&whileOp.before().front()); | ||
auto cmpOp = rewriter.create<CmpIOp>(whileOp.getLoc(), CmpIPredicate::slt, | ||
beforeBlock->getArgument(0), | ||
forOp.upperBound()); | ||
rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(), | ||
beforeBlock->getArguments()); | ||
|
||
// Inline for-loop body into an executeRegion operation in the "after" | ||
// region. The return type of the execRegionOp does not contain the | ||
// iv - yields in the source for-loop contain only iterArgs. | ||
auto *afterBlock = rewriter.createBlock( | ||
&whileOp.after(), whileOp.after().begin(), lcvTypes, {}); | ||
|
||
// Add induction variable incrementation | ||
rewriter.setInsertionPointToEnd(afterBlock); | ||
auto ivIncOp = rewriter.create<AddIOp>( | ||
whileOp.getLoc(), afterBlock->getArgument(0), forOp.step()); | ||
|
||
// Rewrite uses of the for-loop block arguments to the new while-loop | ||
// "after" arguments | ||
for (auto barg : enumerate(forOp.getBody(0)->getArguments())) | ||
barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index())); | ||
|
||
// Inline for-loop body operations into 'after' region. | ||
for (auto &arg : llvm::make_early_inc_range(*forOp.getBody())) | ||
arg.moveBefore(afterBlock, afterBlock->end()); | ||
|
||
// Add incremented IV to yield operations | ||
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) { | ||
SmallVector<Value> yieldOperands = yieldOp.getOperands(); | ||
yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult()); | ||
yieldOp->setOperands(yieldOperands); | ||
} | ||
|
||
// We cannot do a direct replacement of the forOp since the while op returns | ||
// an extra value (the induction variable escapes the loop through being | ||
// carried in the set of iterargs). Instead, rewrite uses of the forOp | ||
// results. | ||
for (auto arg : llvm::enumerate(forOp.getResults())) | ||
arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1)); | ||
|
||
rewriter.eraseOp(forOp); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> { | ||
void runOnFunction() override { | ||
FuncOp funcOp = getFunction(); | ||
MLIRContext *ctx = funcOp.getContext(); | ||
RewritePatternSet patterns(ctx); | ||
patterns.add<ForLoopLoweringPattern>(ctx); | ||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); | ||
} | ||
}; | ||
} // namespace | ||
|
||
std::unique_ptr<Pass> mlir::createForToWhileLoopPass() { | ||
return std::make_unique<ForToWhileLoop>(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
// RUN: mlir-opt %s -pass-pipeline='builtin.func(scf-for-to-while)' -split-input-file | FileCheck %s | ||
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py | ||
|
||
// CHECK-LABEL: func @single_loop( | ||
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xi32>, | ||
// CHECK-SAME: %[[VAL_1:.*]]: index, | ||
// CHECK-SAME: %[[VAL_2:.*]]: i32) { | ||
// CHECK: %[[VAL_3:.*]] = constant 0 : index | ||
// CHECK: %[[VAL_4:.*]] = constant 1 : index | ||
// CHECK: %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index { | ||
// CHECK: %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index | ||
// CHECK: scf.condition(%[[VAL_7]]) %[[VAL_6]] : index | ||
// CHECK: } do { | ||
// CHECK: ^bb0(%[[VAL_8:.*]]: index): | ||
// CHECK: %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index | ||
// CHECK: %[[VAL_10:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32 | ||
// CHECK: memref.store %[[VAL_10]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi32> | ||
// CHECK: scf.yield %[[VAL_9]] : index | ||
// CHECK: } | ||
// CHECK: return | ||
// CHECK: } | ||
func @single_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) { | ||
%c0 = constant 0 : index | ||
%c1 = constant 1 : index | ||
scf.for %i = %c0 to %arg1 step %c1 { | ||
%0 = addi %arg2, %arg2 : i32 | ||
memref.store %0, %arg0[%i] : memref<?xi32> | ||
} | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @nested_loop( | ||
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xi32>, | ||
// CHECK-SAME: %[[VAL_1:.*]]: index, | ||
// CHECK-SAME: %[[VAL_2:.*]]: i32) { | ||
// CHECK: %[[VAL_3:.*]] = constant 0 : index | ||
// CHECK: %[[VAL_4:.*]] = constant 1 : index | ||
// CHECK: %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index { | ||
// CHECK: %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index | ||
// CHECK: scf.condition(%[[VAL_7]]) %[[VAL_6]] : index | ||
// CHECK: } do { | ||
// CHECK: ^bb0(%[[VAL_8:.*]]: index): | ||
// CHECK: %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index | ||
// CHECK: %[[VAL_10:.*]] = scf.while (%[[VAL_11:.*]] = %[[VAL_3]]) : (index) -> index { | ||
// CHECK: %[[VAL_12:.*]] = cmpi slt, %[[VAL_11]], %[[VAL_1]] : index | ||
// CHECK: scf.condition(%[[VAL_12]]) %[[VAL_11]] : index | ||
// CHECK: } do { | ||
// CHECK: ^bb0(%[[VAL_13:.*]]: index): | ||
// CHECK: %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_4]] : index | ||
// CHECK: %[[VAL_15:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32 | ||
// CHECK: memref.store %[[VAL_15]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi32> | ||
// CHECK: memref.store %[[VAL_15]], %[[VAL_0]]{{\[}}%[[VAL_13]]] : memref<?xi32> | ||
// CHECK: scf.yield %[[VAL_14]] : index | ||
// CHECK: } | ||
// CHECK: scf.yield %[[VAL_9]] : index | ||
// CHECK: } | ||
// CHECK: return | ||
// CHECK: } | ||
func @nested_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) { | ||
%c0 = constant 0 : index | ||
%c1 = constant 1 : index | ||
scf.for %i = %c0 to %arg1 step %c1 { | ||
scf.for %j = %c0 to %arg1 step %c1 { | ||
%0 = addi %arg2, %arg2 : i32 | ||
memref.store %0, %arg0[%i] : memref<?xi32> | ||
memref.store %0, %arg0[%j] : memref<?xi32> | ||
} | ||
} | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @for_iter_args( | ||
// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, | ||
// CHECK-SAME: %[[VAL_2:.*]]: index) -> f32 { | ||
// CHECK: %[[VAL_3:.*]] = constant 0.000000e+00 : f32 | ||
// CHECK: %[[VAL_4:.*]]:3 = scf.while (%[[VAL_5:.*]] = %[[VAL_0]], %[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_3]]) : (index, f32, f32) -> (index, f32, f32) { | ||
// CHECK: %[[VAL_8:.*]] = cmpi slt, %[[VAL_5]], %[[VAL_1]] : index | ||
// CHECK: scf.condition(%[[VAL_8]]) %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : index, f32, f32 | ||
// CHECK: } do { | ||
// CHECK: ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32): | ||
// CHECK: %[[VAL_12:.*]] = addi %[[VAL_9]], %[[VAL_2]] : index | ||
// CHECK: %[[VAL_13:.*]] = addf %[[VAL_10]], %[[VAL_11]] : f32 | ||
// CHECK: scf.yield %[[VAL_12]], %[[VAL_13]], %[[VAL_13]] : index, f32, f32 | ||
// CHECK: } | ||
// CHECK: return %[[VAL_14:.*]]#2 : f32 | ||
// CHECK: } | ||
func @for_iter_args(%arg0 : index, %arg1: index, %arg2: index) -> f32 { | ||
%s0 = constant 0.0 : f32 | ||
%result:2 = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iarg0 = %s0, %iarg1 = %s0) -> (f32, f32) { | ||
%sn = addf %iarg0, %iarg1 : f32 | ||
scf.yield %sn, %sn : f32, f32 | ||
} | ||
return %result#1 : f32 | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @exec_region_multiple_yields( | ||
// CHECK-SAME: %[[VAL_0:.*]]: i32, | ||
// CHECK-SAME: %[[VAL_1:.*]]: index, | ||
// CHECK-SAME: %[[VAL_2:.*]]: i32) -> i32 { | ||
// CHECK: %[[VAL_3:.*]] = constant 0 : index | ||
// CHECK: %[[VAL_4:.*]] = constant 1 : index | ||
// CHECK: %[[VAL_5:.*]]:2 = scf.while (%[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_0]]) : (index, i32) -> (index, i32) { | ||
// CHECK: %[[VAL_8:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index | ||
// CHECK: scf.condition(%[[VAL_8]]) %[[VAL_6]], %[[VAL_7]] : index, i32 | ||
// CHECK: } do { | ||
// CHECK: ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: i32): | ||
// CHECK: %[[VAL_11:.*]] = addi %[[VAL_9]], %[[VAL_4]] : index | ||
// CHECK: %[[VAL_12:.*]] = scf.execute_region -> i32 { | ||
// CHECK: %[[VAL_13:.*]] = cmpi slt, %[[VAL_9]], %[[VAL_4]] : index | ||
// CHECK: cond_br %[[VAL_13]], ^bb1, ^bb2 | ||
// CHECK: ^bb1: | ||
// CHECK: %[[VAL_14:.*]] = subi %[[VAL_10]], %[[VAL_0]] : i32 | ||
// CHECK: scf.yield %[[VAL_14]] : i32 | ||
// CHECK: ^bb2: | ||
// CHECK: %[[VAL_15:.*]] = muli %[[VAL_10]], %[[VAL_2]] : i32 | ||
// CHECK: scf.yield %[[VAL_15]] : i32 | ||
// CHECK: } | ||
// CHECK: scf.yield %[[VAL_11]], %[[VAL_16:.*]] : index, i32 | ||
// CHECK: } | ||
// CHECK: return %[[VAL_17:.*]]#1 : i32 | ||
// CHECK: } | ||
func @exec_region_multiple_yields(%arg0: i32, %arg1: index, %arg2: i32) -> i32 { | ||
%c1_i32 = constant 1 : i32 | ||
%c2_i32 = constant 2 : i32 | ||
%c0 = constant 0 : index | ||
%c1 = constant 1 : index | ||
%c5 = constant 5 : index | ||
%0 = scf.for %i = %c0 to %arg1 step %c1 iter_args(%iarg0 = %arg0) -> i32 { | ||
%2 = scf.execute_region -> i32 { | ||
%1 = cmpi slt, %i, %c1 : index | ||
cond_br %1, ^bb1, ^bb2 | ||
^bb1: | ||
%2 = subi %iarg0, %arg0 : i32 | ||
scf.yield %2 : i32 | ||
^bb2: | ||
%3 = muli %iarg0, %arg2 : i32 | ||
scf.yield %3 : i32 | ||
} | ||
scf.yield %2 : i32 | ||
} | ||
return %0 : i32 | ||
} |