Skip to content

Commit

Permalink
[CIR][Lowering][Bugfix] Lower nested breaks in switch statements (#357)
Browse files Browse the repository at this point in the history
This PR fixes lowering of the next code: 
```
void foo(int x, int y) {
    switch (x) {
        case 0: 
            if (y) 
                break;
            break;
    }
}
```
i.e. when some sub statement contains `break` as well. Previously, we
did this trick for `loop`: process nested `break`/`continue` statements
while `LoopOp` lowering if they don't belong to another `LoopOp` or
`SwitchOp`. This is why there is some refactoring here as well, but the
idea is stiil the same: we need to process nested operations and emit
branches to the proper blocks.

This is quite frequent bug in `llvm-test-suite`
  • Loading branch information
gitoleg authored and lanza committed Apr 29, 2024
1 parent 7c9c457 commit 6d49675
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 58 deletions.
100 changes: 42 additions & 58 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,35 @@ mlir::LLVM::Linkage convertLinkage(mlir::cir::GlobalLinkageKind linkage) {
};
}

static void lowerNestedYield(mlir::cir::YieldOpKind targetKind,
mlir::ConversionPatternRewriter &rewriter,
mlir::Region &body,
mlir::Block *dst) {
// top-level yields are lowered in matchAndRewrite of the parent operations
auto isNested = [&](mlir::Operation *op) {
return op->getParentRegion() != &body;
};

body.walk<mlir::WalkOrder::PreOrder>(
[&](mlir::Operation *op) {
if (!isNested(op))
return mlir::WalkResult::advance();

// don't process breaks/continues in nested loops and switches
if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(*op))
return mlir::WalkResult::skip();

auto yield = dyn_cast<mlir::cir::YieldOp>(*op);
if (yield && yield.getKind() == targetKind) {
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, yield.getArgs(), dst);
}

return mlir::WalkResult::advance();
});
}


class CIRCopyOpLowering : public mlir::OpConversionPattern<mlir::cir::CopyOp> {
public:
using mlir::OpConversionPattern<mlir::cir::CopyOp>::OpConversionPattern;
Expand Down Expand Up @@ -398,57 +427,6 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
return mlir::success();
}

void makeYieldIf(mlir::cir::YieldOpKind kind, mlir::cir::YieldOp &op,
mlir::Block *to,
mlir::ConversionPatternRewriter &rewriter) const {
if (op.getKind() == kind) {
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, op.getArgs(), to);
}
}

void
lowerNestedBreakContinue(mlir::Region &loopBody, mlir::Block *exitBlock,
mlir::Block *continueBlock,
mlir::ConversionPatternRewriter &rewriter) const {
// top-level yields are lowered in matchAndRewrite
auto isNested = [&](mlir::Operation *op) {
return op->getParentRegion() != &loopBody;
};

auto processBreak = [&](mlir::Operation *op) {
if (!isNested(op))
return mlir::WalkResult::advance();

if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(
*op)) // don't process breaks in nested loops and switches
return mlir::WalkResult::skip();

if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
makeYieldIf(mlir::cir::YieldOpKind::Break, yield, exitBlock, rewriter);

return mlir::WalkResult::advance();
};

auto processContinue = [&](mlir::Operation *op) {
if (!isNested(op))
return mlir::WalkResult::advance();

if (isa<mlir::cir::LoopOp>(
*op)) // don't process continues in nested loops
return mlir::WalkResult::skip();

if (auto yield = dyn_cast<mlir::cir::YieldOp>(*op))
makeYieldIf(mlir::cir::YieldOpKind::Continue, yield, continueBlock,
rewriter);

return mlir::WalkResult::advance();
};

loopBody.walk<mlir::WalkOrder::PreOrder>(processBreak);
loopBody.walk<mlir::WalkOrder::PreOrder>(processContinue);
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::LoopOp loopOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -478,7 +456,10 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
dyn_cast<mlir::cir::YieldOp>(stepRegion.back().getTerminator());
auto &stepBlock = (kind == LoopKind::For ? stepFrontBlock : condFrontBlock);

lowerNestedBreakContinue(bodyRegion, continueBlock, &stepBlock, rewriter);
lowerNestedYield(mlir::cir::YieldOpKind::Break,
rewriter, bodyRegion, continueBlock);
lowerNestedYield(mlir::cir::YieldOpKind::Continue,
rewriter, bodyRegion, &stepBlock);

// Move loop op region contents to current CFG.
rewriter.inlineRegionBefore(condRegion, continueBlock);
Expand Down Expand Up @@ -713,7 +694,7 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
}
};

static bool isLoopYield(mlir::cir::YieldOp &op) {
static bool isBreakOrContinue(mlir::cir::YieldOp &op) {
return op.getKind() == mlir::cir::YieldOpKind::Break ||
op.getKind() == mlir::cir::YieldOpKind::Continue;
}
Expand Down Expand Up @@ -746,8 +727,8 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
rewriter.setInsertionPointToEnd(thenAfterBody);
if (auto thenYieldOp =
dyn_cast<mlir::cir::YieldOp>(thenAfterBody->getTerminator())) {
if (!isLoopYield(thenYieldOp)) // lowering of parent loop yields is
// deferred to loop lowering
if (!isBreakOrContinue(thenYieldOp)) // lowering of parent loop yields is
// deferred to loop lowering
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
thenYieldOp, thenYieldOp.getArgs(), continueBlock);
} else if (!dyn_cast<mlir::cir::ReturnOp>(thenAfterBody->getTerminator())) {
Expand Down Expand Up @@ -777,8 +758,8 @@ class CIRIfLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
rewriter.setInsertionPointToEnd(elseAfterBody);
if (auto elseYieldOp =
dyn_cast<mlir::cir::YieldOp>(elseAfterBody->getTerminator())) {
if (!isLoopYield(elseYieldOp)) // lowering of parent loop yields is
// deferred to loop lowering
if (!isBreakOrContinue(elseYieldOp)) // lowering of parent loop yields is
// deferred to loop lowering
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
elseYieldOp, elseYieldOp.getArgs(), continueBlock);
} else if (!dyn_cast<mlir::cir::ReturnOp>(
Expand Down Expand Up @@ -839,7 +820,7 @@ class CIRScopeOpLowering
rewriter.setInsertionPointToEnd(afterBody);
auto yieldOp = cast<mlir::cir::YieldOp>(afterBody->getTerminator());

if (!isLoopYield(yieldOp)) {
if (!isBreakOrContinue(yieldOp)) {
auto branchOp = rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
yieldOp, yieldOp.getArgs(), continueBlock);

Expand Down Expand Up @@ -1411,6 +1392,9 @@ class CIRSwitchOpLowering
}
}

lowerNestedYield(mlir::cir::YieldOpKind::Break,
rewriter, region, exitBlock);

// Extract region contents before erasing the switch op.
rewriter.inlineRegionBefore(region, exitBlock);
}
Expand Down
46 changes: 46 additions & 0 deletions clang/test/CIR/Lowering/switch.cir
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,50 @@ module {
// CHECK: llvm.return
// CHECK: }

cir.func @shouldLowerNestedBreak(%arg0: !s32i, %arg1: !s32i) -> !s32i {
%0 = cir.alloca !s32i, cir.ptr <!s32i>, ["x", init] {alignment = 4 : i64}
%1 = cir.alloca !s32i, cir.ptr <!s32i>, ["y", init] {alignment = 4 : i64}
%2 = cir.alloca !s32i, cir.ptr <!s32i>, ["__retval"] {alignment = 4 : i64}
cir.store %arg0, %0 : !s32i, cir.ptr <!s32i>
cir.store %arg1, %1 : !s32i, cir.ptr <!s32i>
cir.scope {
%5 = cir.load %0 : cir.ptr <!s32i>, !s32i
cir.switch (%5 : !s32i) [
case (equal, 0) {
cir.scope {
%6 = cir.load %1 : cir.ptr <!s32i>, !s32i
%7 = cir.const(#cir.int<0> : !s32i) : !s32i
%8 = cir.cmp(ge, %6, %7) : !s32i, !s32i
%9 = cir.cast(int_to_bool, %8 : !s32i), !cir.bool
cir.if %9 {
cir.yield break
}
}
cir.yield break
}
]
}
%3 = cir.const(#cir.int<3> : !s32i) : !s32i
cir.store %3, %2 : !s32i, cir.ptr <!s32i>
%4 = cir.load %2 : cir.ptr <!s32i>, !s32i
cir.return %4 : !s32i
}
// CHECK: llvm.func @shouldLowerNestedBreak
// CHECK: llvm.switch %6 : i32, ^bb7 [
// CHECK: 0: ^bb2
// CHECK: ]
// CHECK: ^bb2: // pred: ^bb1
// CHECK: llvm.br ^bb3
// CHECK: ^bb3: // pred: ^bb2
// CHECK: llvm.cond_br %14, ^bb4, ^bb5
// CHECK: ^bb4: // pred: ^bb3
// CHECK: llvm.br ^bb7
// CHECK: ^bb5: // pred: ^bb3
// CHECK: llvm.br ^bb6
// CHECK: ^bb6: // pred: ^bb5
// CHECK: llvm.br ^bb7
// CHECK: ^bb7: // 3 preds: ^bb1, ^bb4, ^bb6
// CHECK: llvm.br ^bb8
// CHECK: ^bb8: // pred: ^bb7
// CHECK: llvm.return
}

0 comments on commit 6d49675

Please sign in to comment.