Skip to content

Commit

Permalink
[CIR][Lowering] Fix loop lowering for top-level break/continue (#349)
Browse files Browse the repository at this point in the history
This PR fixes a couple of corner cases connected with the `YieldOp`
lowering in loops.
Previously, in #211 we introduced `lowerNestedBreakContinue` but we
didn't check that `YieldOp` may belong to the same region, i.e. it is
not nested, e.g.
```
while(1) {
   break;
}
```
Hence the error `op already replaced`. 

Next, we fix `yield` lowering for `ifOp` and `switchOp` but didn't cover
`scopeOp`, and the same error occurred. This PR fixes this as well.

I added two tests - with no checks actually, just to make sure no more
crashes happen.

fixes #324
  • Loading branch information
gitoleg authored and lanza committed Dec 20, 2023
1 parent f1c620f commit 6a9391a
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 5 deletions.
27 changes: 22 additions & 5 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,15 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
lowerNestedBreakContinue(mlir::Region &loopBody, mlir::Block *exitBlock,
mlir::Block *continueBlock,
mlir::ConversionPatternRewriter &rewriter) const {
// top-level yields are lowered in matchAndRewrite
auto isNested = [&](mlir::Operation *op) {
return op->getParentRegion() != &loopBody;
};

auto processBreak = [&](mlir::Operation *op) {
if (!isNested(op))
return mlir::WalkResult::advance();

if (isa<mlir::cir::LoopOp, mlir::cir::SwitchOp>(
*op)) // don't process breaks in nested loops and switches
return mlir::WalkResult::skip();
Expand All @@ -421,6 +428,9 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
};

auto processContinue = [&](mlir::Operation *op) {
if (!isNested(op))
return mlir::WalkResult::advance();

if (isa<mlir::cir::LoopOp>(
*op)) // don't process continues in nested loops
return mlir::WalkResult::skip();
Expand Down Expand Up @@ -490,7 +500,10 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {

// Branch from body to condition or to step on for-loop cases.
rewriter.setInsertionPoint(bodyYield);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, &stepBlock);
auto bodyYieldDest = bodyYield.getKind() == mlir::cir::YieldOpKind::Break
? continueBlock
: &stepBlock;
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(bodyYield, bodyYieldDest);

// Is a for loop: branch from step to condition.
if (kind == LoopKind::For) {
Expand Down Expand Up @@ -822,11 +835,15 @@ class CIRScopeOpLowering
// Stack restore before leaving the body region.
rewriter.setInsertionPointToEnd(afterBody);
auto yieldOp = cast<mlir::cir::YieldOp>(afterBody->getTerminator());
auto branchOp = rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
yieldOp, yieldOp.getArgs(), continueBlock);

// // Insert stack restore before jumping out of the body of the region.
rewriter.setInsertionPoint(branchOp);
if (!isLoopYield(yieldOp)) {
auto branchOp = rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(
yieldOp, yieldOp.getArgs(), continueBlock);

// // Insert stack restore before jumping out of the body of the region.
rewriter.setInsertionPoint(branchOp);
}

// TODO(CIR): stackrestore?
// rewriter.create<mlir::LLVM::StackRestoreOp>(loc, stackSaveOp);

Expand Down
78 changes: 78 additions & 0 deletions clang/test/CIR/Lowering/loop.cir
Original file line number Diff line number Diff line change
Expand Up @@ -217,4 +217,82 @@ module {
// MLIR-NEXT: llvm.br ^bb6
// MLIR-NEXT: ^bb6:
// MLIR-NEXT: llvm.return

// test corner case
// while (1) {
// break;
// }
cir.func @whileCornerCase() {
cir.scope {
cir.loop while(cond : {
%0 = cir.const(#cir.int<1> : !s32i) : !s32i
%1 = cir.cast(int_to_bool, %0 : !s32i), !cir.bool
cir.brcond %1 ^bb1, ^bb2
^bb1: // pred: ^bb0
cir.yield continue
^bb2: // pred: ^bb0
cir.yield
}, step : {
cir.yield
}) {
cir.yield break
}
}
cir.return
}
// MLIR: llvm.func @whileCornerCase()
// MLIR: %0 = llvm.mlir.constant(1 : i32) : i32
// MLIR-NEXT: %1 = llvm.mlir.constant(0 : i32) : i32
// MLIR-NEXT: %2 = llvm.icmp "ne" %0, %1 : i32
// MLIR-NEXT: %3 = llvm.zext %2 : i1 to i8
// MLIR-NEXT: %4 = llvm.trunc %3 : i8 to i
// MLIR-NEXT: llvm.cond_br %4, ^bb3, ^bb4
// MLIR-NEXT: ^bb3: // pred: ^bb2
// MLIR-NEXT: llvm.br ^bb5
// MLIR-NEXT: ^bb4: // pred: ^bb2
// MLIR-NEXT: llvm.br ^bb6
// MLIR-NEXT: ^bb5: // pred: ^bb3
// MLIR-NEXT: llvm.br ^bb6
// MLIR-NEXT: ^bb6: // 2 preds: ^bb4, ^bb5
// MLIR-NEXT: llvm.br ^bb7
// MLIR-NEXT: ^bb7: // pred: ^bb6
// MLIR-NEXT: llvm.return

// test corner case - no fails during the lowering
// for (;;) {
// break;
// }
cir.func @forCornerCase() {
cir.scope {
cir.loop for(cond : {
cir.yield continue
}, step : {
cir.yield
}) {
cir.scope {
cir.yield break
}
cir.yield
}
}
cir.return
}
// MLIR: llvm.func @forCornerCase()
// MLIR: llvm.br ^bb1
// MLIR-NEXT: ^bb1: // pred: ^bb0
// MLIR-NEXT: llvm.br ^bb2
// MLIR-NEXT: ^bb2: // 2 preds: ^bb1, ^bb6
// MLIR-NEXT: llvm.br ^bb3
// MLIR-NEXT: ^bb3: // pred: ^bb2
// MLIR-NEXT: llvm.br ^bb4
// MLIR-NEXT: ^bb4: // pred: ^bb3
// MLIR-NEXT: llvm.br ^bb7
// MLIR-NEXT: ^bb5: // no predecessors
// MLIR-NEXT: llvm.br ^bb6
// MLIR-NEXT: ^bb6: // pred: ^bb5
// MLIR-NEXT: llvm.br ^bb2
// MLIR-NEXT: ^bb7: // pred: ^bb4
// MLIR-NEXT: llvm.br ^bb8
// MLIR-NEXT: ^bb8: // pred: ^bb7
// MLIR-NEXT: llvm.return
}

0 comments on commit 6a9391a

Please sign in to comment.