Skip to content

Commit

Permalink
[CIR][Lowering] Fixed break/continue lowering for loops (#211)
Browse files Browse the repository at this point in the history
This PR fixes lowering for `break/continue`  in loops.
The idea is to replace `cir.yield break` and `cir.yield continue` with
the branch operations to the corresponding blocks. Note, that we need to
ignore nesting loops and don't touch `break` in switch operations. Also,
`yield` from `if` need to be considered only when it's not the loop
`yield` and `continue` in switch is ignored since it's processed in the
loops lowering.

Fixes #160
  • Loading branch information
gitoleg committed Aug 9, 2023
1 parent afadbd0 commit 1ebed61
Show file tree
Hide file tree
Showing 3 changed files with 702 additions and 8 deletions.
70 changes: 62 additions & 8 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,47 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
return mlir::success();
}

void makeYieldIf(mlir::cir::YieldOpKind kind, mlir::cir::YieldOp &op,
mlir::Block *to,
mlir::ConversionPatternRewriter &rewriter) const {
if (op.getKind() == kind) {
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, op.getArgs(), to);
}
}

void
lowerNestedBreakContinue(mlir::Region &loopBody, mlir::Block *exitBlock,
mlir::Block *continueBlock,
mlir::ConversionPatternRewriter &rewriter) const {

auto processBreak = [&](mlir::Operation *op) {
if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(
*op)) // don't process breaks in nested loops and switches
return mlir::WalkResult::skip();

if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
makeYieldIf(mlir::cir::YieldOpKind::Break, yield, exitBlock, rewriter);

return mlir::WalkResult::advance();
};

auto processContinue = [&](mlir::Operation *op) {
if (isa<mlir::cir::LoopOp>(
*op)) // don't process continues in nested loops
return mlir::WalkResult::skip();

if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
makeYieldIf(mlir::cir::YieldOpKind::Continue, yield, continueBlock,
rewriter);

return mlir::WalkResult::advance();
};

loopBody.walk<mlir::WalkOrder::PreOrder>(processBreak);
loopBody.walk<mlir::WalkOrder::PreOrder>(processContinue);
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -265,6 +306,9 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
auto &stepFrontBlock = stepRegion.front();
auto stepYield =
dyn_cast<mlir::cir::YieldOp>(stepRegion.back().getTerminator());
auto &stepBlock = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);

lowerNestedBreakContinue(bodyRegion, continueBlock, &stepBlock, rewriter);

// Move loop op region contents to current CFG.
rewriter.inlineRegionBefore(condRegion, continueBlock);
Expand All @@ -287,8 +331,7 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {

// Branch from body to condition or to step on for-loop cases.
rewriter.setInsertionPoint(bodyYield);
auto &bodyExit = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &bodyExit);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &stepBlock);

// Is a for loop: branch from step to condition.
if (kind == LoopKind::For) {
Expand Down Expand Up @@ -480,6 +523,11 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
}
};

static bool isLoopYield(mlir::cir::YieldOp &op) {
return op.getKind() == mlir::cir::YieldOpKind::Break ||
op.getKind() == mlir::cir::YieldOpKind::Continue;
}

class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
public:
using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
Expand Down Expand Up @@ -508,8 +556,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
rewriter.setInsertionPointToEnd(thenAfterBody);
if (auto thenYieldOp =
dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator())) {
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
thenYieldOp, thenYieldOp.getArgs(), continueBlock);
if (!isLoopYield(thenYieldOp)) // lowering of parent loop yields is
// deferred to loop lowering
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
thenYieldOp, thenYieldOp.getArgs(), continueBlock);
} else if (!dyn_cast<mlir::cir::ReturnOp>(thenAfterBody->getTerminator())) {
llvm_unreachable("what are we terminating with?");
}
Expand Down Expand Up @@ -537,8 +587,10 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
rewriter.setInsertionPointToEnd(elseAfterBody);
if (auto elseYieldOp =
dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator())) {
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
if (!isLoopYield(elseYieldOp)) // lowering of parent loop yields is
// deferred to loop lowering
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
} else if (!dyn_cast<mlir::cir::ReturnOp>(
elseAfterBody->getTerminator())) {
llvm_unreachable("what are we terminating with?");
Expand Down Expand Up @@ -1097,6 +1149,9 @@ class CIRSwitchOpLowering
case mlir::cir::YieldOpKind::Break:
rewriteYieldOp(rewriter, yieldOp, exitBlock);
break;
case mlir::cir::YieldOpKind::Continue: // Continue is handled only in
// loop lowering
break;
default:
return op->emitError("invalid yield kind in case statement");
}
Expand Down Expand Up @@ -1676,8 +1731,7 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering,
CIRVAArgLowering, CIRBrOpLowering, CIRTernaryOpLowering,
CIRStructElementAddrOpLowering, CIRSwitchOpLowering,
CIRPtrDiffOpLowering>(
converter, patterns.getContext());
CIRPtrDiffOpLowering>(converter, patterns.getContext());
}

namespace {
Expand Down
Loading

0 comments on commit 1ebed61

Please sign in to comment.