Skip to content

Commit

Permalink
[mlir][scf] Add scf-to-cf lowering for scf.index_switch
Browse files Browse the repository at this point in the history
This patch adds lowering from `scf.index_switch` to `cf.switch.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D136883
  • Loading branch information
Mogball committed Oct 31, 2022
1 parent 144e38f commit 91effec
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 2 deletions.
68 changes: 67 additions & 1 deletion mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,14 @@ struct DoWhileLowering : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp whileOp,
PatternRewriter &rewriter) const override;
};

/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(IndexSwitchOp op,
PatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
Expand Down Expand Up @@ -615,10 +623,68 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
return success();
}

LogicalResult
IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
PatternRewriter &rewriter) const {
// Split the block at the op.
Block *condBlock = rewriter.getInsertionBlock();
Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));

// Create the arguments on the continue block with which to replace the
// results of the op.
SmallVector<Value> results;
results.reserve(op.getNumResults());
for (Type resultType : op.getResultTypes())
results.push_back(continueBlock->addArgument(resultType, op.getLoc()));

// Handle the regions.
auto convertRegion = [&](Region &region) -> FailureOr<Block *> {
Block *block = &region.front();

// Convert the yield terminator to a branch to the continue block.
auto yield = cast<scf::YieldOp>(block->getTerminator());
rewriter.setInsertionPoint(yield);
rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
yield.getOperands());

// Inline the region.
rewriter.inlineRegionBefore(region, continueBlock);
return block;
};

// Convert the case regions.
SmallVector<Block *> caseSuccessors;
SmallVector<int32_t> caseValues;
caseSuccessors.reserve(op.getCases().size());
caseValues.reserve(op.getCases().size());
for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
FailureOr<Block *> block = convertRegion(region);
if (failed(block))
return failure();
caseSuccessors.push_back(*block);
caseValues.push_back(value);
}

// Convert the default region.
FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
if (failed(defaultBlock))
return failure();

// Create the switch.
rewriter.setInsertionPointToEnd(condBlock);
SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
rewriter.create<cf::SwitchOp>(
op.getLoc(), op.getArg(), *defaultBlock, ValueRange(),
rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
rewriter.replaceOp(op, continueBlock->getArguments());
return success();
}

void mlir::populateSCFToControlFlowConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
ExecuteRegionLowering>(patterns.getContext());
ExecuteRegionLowering, IndexSwitchLowering>(
patterns.getContext());
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}

Expand Down
29 changes: 28 additions & 1 deletion mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ func.func @while_values(%arg0: i32, %arg1: f32) {
scf.condition(%0) %2, %3 : i64, f64
} do {
// CHECK: ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
^bb0(%arg2: i64, %arg3: f64):
^bb0(%arg2: i64, %arg3: f64):
// CHECK: cf.br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
scf.yield %c0_i32, %cst : i32, f32
}
Expand Down Expand Up @@ -620,3 +620,30 @@ func.func @func_execute_region_elim_multi_yield() {
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
// CHECK: "test.bar"(%[[z]])
// CHECK: return

// SWITCH-LABEL: @index_switch
func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
// SWITCH: cf.switch %arg0 : index
// SWITCH-NEXT: default: ^bb3
// SWITCH-NEXT: 0: ^bb1
// SWITCH-NEXT: 1: ^bb2
%0 = scf.index_switch %i -> i32
// SWITCH: ^bb1:
case 0 {
// SWITCH-NEXT: llvm.br ^bb4(%arg1
scf.yield %a : i32
}
// SWITCH: ^bb2:
case 1 {
// SWITCH-NEXT: llvm.br ^bb4(%arg2
scf.yield %b : i32
}
// SWITCH: ^bb3:
default {
// SWITCH-NEXT: llvm.br ^bb4(%arg3
scf.yield %c : i32
}
// SWITCH: ^bb4(%[[V:.*]]: i32
// SWITCH-NEXT: return %[[V]]
return %0 : i32
}

0 comments on commit 91effec

Please sign in to comment.