Skip to content

Commit

Permalink
[MLIR][SCF] Add for-to-while loop transformation pass
Browse files Browse the repository at this point in the history
This pass transforms SCF.ForOp operations to SCF.WhileOp. The For loop condition is placed in the 'before' region of the while operation, and indctuion variable incrementation + the loop body in the 'after' region. The loop carried values of the while op are the induction variable (IV) of the for-loop + any iter_args specified for the for-loop.
Any 'yield' ops in the for-loop are rewritten to additionally yield the (incremented) induction variable.

This transformation is useful for passes where we want to consider structured control flow solely on the basis of a loop body and the computation of a loop condition. As an example, when doing high-level synthesis in CIRCT, the incrementation of an IV in a for-loop is "just another part" of a circuit datapath, and what we really care about is the distinction between our datapath and our control logic (the condition variable).

Differential Revision: https://reviews.llvm.org/D108454
  • Loading branch information
mortbopet committed Sep 21, 2021
1 parent 791b6eb commit 032cb16
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 0 deletions.
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Passes.h
Expand Up @@ -52,6 +52,9 @@ createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
/// loop range.
std::unique_ptr<Pass> createForLoopRangeFoldingPass();

// Creates a pass which lowers for loops into while loops.
std::unique_ptr<Pass> createForToWhileLoopPass();

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
Expand Down
35 changes: 35 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Passes.td
Expand Up @@ -78,4 +78,39 @@ def SCFForLoopRangeFolding
let constructor = "mlir::createForLoopRangeFoldingPass()";
}

def SCFForToWhileLoop
: FunctionPass<"scf-for-to-while"> {
let summary = "Convert SCF for loops to SCF while loops";
let constructor = "mlir::createForToWhileLoopPass()";
let description = [{
This pass transforms SCF.ForOp operations to SCF.WhileOp. The For loop
condition is placed in the 'before' region of the while operation, and the
induction variable incrementation and loop body in the 'after' region.
The loop carried values of the while op are the induction variable (IV) of
the for-loop + any iter_args specified for the for-loop.
Any 'yield' ops in the for-loop are rewritten to additionally yield the
(incremented) induction variable.

```mlir
# Before:
scf.for %i = %c0 to %arg1 step %c1 {
%0 = addi %arg2, %arg2 : i32
memref.store %0, %arg0[%i] : memref<?xi32>
}

# After:
%0 = scf.while (%i = %c0) : (index) -> index {
%1 = cmpi slt, %i, %arg1 : index
scf.condition(%1) %i : index
} do {
^bb0(%i: index): // no predecessors
%1 = addi %i, %c1 : index
%2 = addi %arg2, %arg2 : i32
memref.store %2, %arg0[%i] : memref<?xi32>
scf.yield %1 : index
}
```
}];
}

#endif // MLIR_DIALECT_SCF_PASSES
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRSCFTransforms
Bufferize.cpp
ForToWhile.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
LoopRangeFolding.cpp
Expand Down
110 changes: 110 additions & 0 deletions mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -0,0 +1,110 @@
//===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms SCF.ForOp's into SCF.WhileOp's.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace llvm;
using namespace mlir;
using scf::ForOp;
using scf::WhileOp;

namespace {

struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
using OpRewritePattern<ForOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override {
// Generate type signature for the loop-carried values. The induction
// variable is placed first, followed by the forOp.iterArgs.
SmallVector<Type, 8> lcvTypes;
lcvTypes.push_back(forOp.getInductionVar().getType());
llvm::transform(forOp.initArgs(), std::back_inserter(lcvTypes),
[&](auto v) { return v.getType(); });

// Build scf.WhileOp
SmallVector<Value> initArgs;
initArgs.push_back(forOp.lowerBound());
llvm::append_range(initArgs, forOp.initArgs());
auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
forOp->getAttrs());

// 'before' region contains the loop condition and forwarding of iteration
// arguments to the 'after' region.
auto *beforeBlock = rewriter.createBlock(
&whileOp.before(), whileOp.before().begin(), lcvTypes, {});
rewriter.setInsertionPointToStart(&whileOp.before().front());
auto cmpOp = rewriter.create<CmpIOp>(whileOp.getLoc(), CmpIPredicate::slt,
beforeBlock->getArgument(0),
forOp.upperBound());
rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
beforeBlock->getArguments());

// Inline for-loop body into an executeRegion operation in the "after"
// region. The return type of the execRegionOp does not contain the
// iv - yields in the source for-loop contain only iterArgs.
auto *afterBlock = rewriter.createBlock(
&whileOp.after(), whileOp.after().begin(), lcvTypes, {});

// Add induction variable incrementation
rewriter.setInsertionPointToEnd(afterBlock);
auto ivIncOp = rewriter.create<AddIOp>(
whileOp.getLoc(), afterBlock->getArgument(0), forOp.step());

// Rewrite uses of the for-loop block arguments to the new while-loop
// "after" arguments
for (auto barg : enumerate(forOp.getBody(0)->getArguments()))
barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));

// Inline for-loop body operations into 'after' region.
for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
arg.moveBefore(afterBlock, afterBlock->end());

// Add incremented IV to yield operations
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
SmallVector<Value> yieldOperands = yieldOp.getOperands();
yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
yieldOp->setOperands(yieldOperands);
}

// We cannot do a direct replacement of the forOp since the while op returns
// an extra value (the induction variable escapes the loop through being
// carried in the set of iterargs). Instead, rewrite uses of the forOp
// results.
for (auto arg : llvm::enumerate(forOp.getResults()))
arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));

rewriter.eraseOp(forOp);
return success();
}
};

struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
void runOnFunction() override {
FuncOp funcOp = getFunction();
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet patterns(ctx);
patterns.add<ForLoopLoweringPattern>(ctx);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
};
} // namespace

std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
return std::make_unique<ForToWhileLoop>();
}
148 changes: 148 additions & 0 deletions mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir
@@ -0,0 +1,148 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.func(scf-for-to-while)' -split-input-file | FileCheck %s
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py

// CHECK-LABEL: func @single_loop(
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xi32>,
// CHECK-SAME: %[[VAL_1:.*]]: index,
// CHECK-SAME: %[[VAL_2:.*]]: i32) {
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 1 : index
// CHECK: %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index {
// CHECK: %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
// CHECK: scf.condition(%[[VAL_7]]) %[[VAL_6]] : index
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_8:.*]]: index):
// CHECK: %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index
// CHECK: %[[VAL_10:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32
// CHECK: memref.store %[[VAL_10]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
// CHECK: scf.yield %[[VAL_9]] : index
// CHECK: }
// CHECK: return
// CHECK: }
func @single_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
%c0 = constant 0 : index
%c1 = constant 1 : index
scf.for %i = %c0 to %arg1 step %c1 {
%0 = addi %arg2, %arg2 : i32
memref.store %0, %arg0[%i] : memref<?xi32>
}
return
}

// -----

// CHECK-LABEL: func @nested_loop(
// CHECK-SAME: %[[VAL_0:.*]]: memref<?xi32>,
// CHECK-SAME: %[[VAL_1:.*]]: index,
// CHECK-SAME: %[[VAL_2:.*]]: i32) {
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 1 : index
// CHECK: %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index {
// CHECK: %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
// CHECK: scf.condition(%[[VAL_7]]) %[[VAL_6]] : index
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_8:.*]]: index):
// CHECK: %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index
// CHECK: %[[VAL_10:.*]] = scf.while (%[[VAL_11:.*]] = %[[VAL_3]]) : (index) -> index {
// CHECK: %[[VAL_12:.*]] = cmpi slt, %[[VAL_11]], %[[VAL_1]] : index
// CHECK: scf.condition(%[[VAL_12]]) %[[VAL_11]] : index
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_13:.*]]: index):
// CHECK: %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_4]] : index
// CHECK: %[[VAL_15:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32
// CHECK: memref.store %[[VAL_15]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
// CHECK: memref.store %[[VAL_15]], %[[VAL_0]]{{\[}}%[[VAL_13]]] : memref<?xi32>
// CHECK: scf.yield %[[VAL_14]] : index
// CHECK: }
// CHECK: scf.yield %[[VAL_9]] : index
// CHECK: }
// CHECK: return
// CHECK: }
func @nested_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
%c0 = constant 0 : index
%c1 = constant 1 : index
scf.for %i = %c0 to %arg1 step %c1 {
scf.for %j = %c0 to %arg1 step %c1 {
%0 = addi %arg2, %arg2 : i32
memref.store %0, %arg0[%i] : memref<?xi32>
memref.store %0, %arg0[%j] : memref<?xi32>
}
}
return
}

// -----

// CHECK-LABEL: func @for_iter_args(
// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index,
// CHECK-SAME: %[[VAL_2:.*]]: index) -> f32 {
// CHECK: %[[VAL_3:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[VAL_4:.*]]:3 = scf.while (%[[VAL_5:.*]] = %[[VAL_0]], %[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_3]]) : (index, f32, f32) -> (index, f32, f32) {
// CHECK: %[[VAL_8:.*]] = cmpi slt, %[[VAL_5]], %[[VAL_1]] : index
// CHECK: scf.condition(%[[VAL_8]]) %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : index, f32, f32
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
// CHECK: %[[VAL_12:.*]] = addi %[[VAL_9]], %[[VAL_2]] : index
// CHECK: %[[VAL_13:.*]] = addf %[[VAL_10]], %[[VAL_11]] : f32
// CHECK: scf.yield %[[VAL_12]], %[[VAL_13]], %[[VAL_13]] : index, f32, f32
// CHECK: }
// CHECK: return %[[VAL_14:.*]]#2 : f32
// CHECK: }
func @for_iter_args(%arg0 : index, %arg1: index, %arg2: index) -> f32 {
%s0 = constant 0.0 : f32
%result:2 = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iarg0 = %s0, %iarg1 = %s0) -> (f32, f32) {
%sn = addf %iarg0, %iarg1 : f32
scf.yield %sn, %sn : f32, f32
}
return %result#1 : f32
}

// -----

// CHECK-LABEL: func @exec_region_multiple_yields(
// CHECK-SAME: %[[VAL_0:.*]]: i32,
// CHECK-SAME: %[[VAL_1:.*]]: index,
// CHECK-SAME: %[[VAL_2:.*]]: i32) -> i32 {
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 1 : index
// CHECK: %[[VAL_5:.*]]:2 = scf.while (%[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_0]]) : (index, i32) -> (index, i32) {
// CHECK: %[[VAL_8:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
// CHECK: scf.condition(%[[VAL_8]]) %[[VAL_6]], %[[VAL_7]] : index, i32
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: i32):
// CHECK: %[[VAL_11:.*]] = addi %[[VAL_9]], %[[VAL_4]] : index
// CHECK: %[[VAL_12:.*]] = scf.execute_region -> i32 {
// CHECK: %[[VAL_13:.*]] = cmpi slt, %[[VAL_9]], %[[VAL_4]] : index
// CHECK: cond_br %[[VAL_13]], ^bb1, ^bb2
// CHECK: ^bb1:
// CHECK: %[[VAL_14:.*]] = subi %[[VAL_10]], %[[VAL_0]] : i32
// CHECK: scf.yield %[[VAL_14]] : i32
// CHECK: ^bb2:
// CHECK: %[[VAL_15:.*]] = muli %[[VAL_10]], %[[VAL_2]] : i32
// CHECK: scf.yield %[[VAL_15]] : i32
// CHECK: }
// CHECK: scf.yield %[[VAL_11]], %[[VAL_16:.*]] : index, i32
// CHECK: }
// CHECK: return %[[VAL_17:.*]]#1 : i32
// CHECK: }
func @exec_region_multiple_yields(%arg0: i32, %arg1: index, %arg2: i32) -> i32 {
%c1_i32 = constant 1 : i32
%c2_i32 = constant 2 : i32
%c0 = constant 0 : index
%c1 = constant 1 : index
%c5 = constant 5 : index
%0 = scf.for %i = %c0 to %arg1 step %c1 iter_args(%iarg0 = %arg0) -> i32 {
%2 = scf.execute_region -> i32 {
%1 = cmpi slt, %i, %c1 : index
cond_br %1, ^bb1, ^bb2
^bb1:
%2 = subi %iarg0, %arg0 : i32
scf.yield %2 : i32
^bb2:
%3 = muli %iarg0, %arg2 : i32
scf.yield %3 : i32
}
scf.yield %2 : i32
}
return %0 : i32
}

0 comments on commit 032cb16

Please sign in to comment.