diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp index 187caa6043ac8..e7e71a9f0bcb3 100644 --- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp +++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp @@ -126,6 +126,7 @@ struct IterWhileConversion : public mlir::OpRewritePattern { mlir::Value okInit = iterWhileOp.getIterateIn(); mlir::ValueRange iterArgs = iterWhileOp.getInitArgs(); + bool hasFinalValue = iterWhileOp.getFinalValue().has_value(); mlir::SmallVector initVals; initVals.push_back(lowerBound); @@ -164,17 +165,22 @@ struct IterWhileConversion : public mlir::OpRewritePattern { auto *afterBody = scfWhileOp.getAfterBody(); auto resultOp = mlir::cast(afterBody->getTerminator()); - mlir::SmallVector results(resultOp->getOperands()); - mlir::Value ivInAfter = scfWhileOp.getAfterArguments()[0]; + mlir::SmallVector results; + mlir::Value iv = scfWhileOp.getAfterArguments()[0]; rewriter.setInsertionPointToStart(afterBody); - results[0] = mlir::arith::AddIOp::create(rewriter, loc, ivInAfter, step); + results.push_back(mlir::arith::AddIOp::create(rewriter, loc, iv, step)); + llvm::append_range(results, hasFinalValue + ? resultOp->getOperands().drop_front() + : resultOp->getOperands()); rewriter.setInsertionPointToEnd(afterBody); rewriter.replaceOpWithNewOp(resultOp, results); scfWhileOp->setAttrs(iterWhileOp->getAttrs()); - rewriter.replaceOp(iterWhileOp, scfWhileOp); + rewriter.replaceOp(iterWhileOp, + hasFinalValue ? scfWhileOp->getResults() + : scfWhileOp->getResults().drop_front()); return mlir::success(); } }; diff --git a/flang/test/Fir/FirToSCF/iter-while.fir b/flang/test/Fir/FirToSCF/iter-while.fir index 0de7aabed120e..19cfaac1cc460 100644 --- a/flang/test/Fir/FirToSCF/iter-while.fir +++ b/flang/test/Fir/FirToSCF/iter-while.fir @@ -1,4 +1,4 @@ -// RUN: fir-opt %s --fir-to-scf | FileCheck %s +// RUN: fir-opt %s --fir-to-scf --allow-unregistered-dialect | FileCheck %s // CHECK-LABEL: func.func @test_simple_iterate_while_1() -> (index, i1, i16, i32) { // CHECK: %[[VAL_0:.*]] = arith.constant 11 : index @@ -97,3 +97,30 @@ func.func @test_zero_iterations() -> (index, i1, i8) { return %res#0, %res#1, %res#2 : index, i1, i8 } + +// CHECK-LABEL: func.func @test_without_final_value( +// CHECK-SAME: %[[ARG0:.*]]: index, +// CHECK-SAME: %[[ARG1:.*]]: index) -> i1 { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant true +// CHECK: %[[WHILE_0:.*]]:2 = scf.while (%[[VAL_0:.*]] = %[[ARG0]], %[[VAL_1:.*]] = %[[CONSTANT_1]]) : (index, i1) -> (index, i1) { +// CHECK: %[[CMPI_0:.*]] = arith.cmpi sle, %[[VAL_0]], %[[ARG1]] : index +// CHECK: %[[ANDI_0:.*]] = arith.andi %[[CMPI_0]], %[[VAL_1]] : i1 +// CHECK: scf.condition(%[[ANDI_0]]) %[[VAL_0]], %[[VAL_1]] : index, i1 +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: i1): +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_2]], %[[CONSTANT_0]] : index +// CHECK: %[[VAL_4:.*]] = "test.get_some_value"() : () -> i1 +// CHECK: scf.yield %[[ADDI_0]], %[[VAL_4]] : index, i1 +// CHECK: } +// CHECK: return %[[VAL_5:.*]]#1 : i1 +// CHECK: } +func.func @test_without_final_value(%lo : index, %up : index) -> i1 { + %c1 = arith.constant 1 : index + %ok1 = arith.constant true + %ok2 = fir.iterate_while (%i = %lo to %up step %c1) and (%j = %ok1) { + %ok = "test.get_some_value"() : () -> i1 + fir.result %ok : i1 + } + return %ok2 : i1 +}