diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp index f06ad2db90d55..9a0071a6c6ae6 100644 --- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp +++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp @@ -87,13 +87,48 @@ struct DoLoopConversion : public OpRewritePattern { return success(); } }; + +struct IfConversion : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(fir::IfOp ifOp, + PatternRewriter &rewriter) const override { + mlir::Location loc = ifOp.getLoc(); + mlir::detail::TypedValue condition = ifOp.getCondition(); + ValueTypeRange resultTypes = ifOp.getResultTypes(); + mlir::scf::IfOp scfIfOp = rewriter.create( + loc, resultTypes, condition, !ifOp.getElseRegion().empty()); + // then region + scfIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + Block &scfThenBlock = scfIfOp.getThenRegion().front(); + Operation *scfThenTerminator = scfThenBlock.getTerminator(); + // fir.result->scf.yield + rewriter.setInsertionPointToEnd(&scfThenBlock); + rewriter.replaceOpWithNewOp(scfThenTerminator, + scfThenTerminator->getOperands()); + + // else region + if (!ifOp.getElseRegion().empty()) { + scfIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + mlir::Block &elseBlock = scfIfOp.getElseRegion().front(); + mlir::Operation *elseTerminator = elseBlock.getTerminator(); + + rewriter.setInsertionPointToEnd(&elseBlock); + rewriter.replaceOpWithNewOp(elseTerminator, + elseTerminator->getOperands()); + } + + scfIfOp->setAttrs(ifOp->getAttrs()); + rewriter.replaceOp(ifOp, scfIfOp); + return success(); + } +}; } // namespace void FIRToSCFPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); ConversionTarget target(getContext()); - target.addIllegalOp(); + target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/flang/test/Fir/FirToSCF/if.fir b/flang/test/Fir/FirToSCF/if.fir new file mode 100644 index 0000000000000..9e43cf1cd11d0 --- /dev/null +++ b/flang/test/Fir/FirToSCF/if.fir @@ -0,0 +1,56 @@ +// RUN: fir-opt %s --fir-to-scf | FileCheck %s + +// CHECK: func.func @test_only(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32) { +// CHECK: scf.if %[[ARG0:.*]] { +// CHECK: %[[VAL_1:.*]] = arith.addi %[[ARG1:.*]], %[[ARG1:.*]] : i32 +// CHECK: } +// CHECK: return +// CHECK: } +func.func @test_only(%arg0 : i1, %arg1 : i32) { + fir.if %arg0 { + %0 = arith.addi %arg1, %arg1 : i32 + } + return +} + +// CHECK: func.func @test_else() { +// CHECK: %[[VAL_1:.*]] = arith.constant false +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : i32 +// CHECK: scf.if %[[VAL_1:.*]] { +// CHECK: %[[VAL_3:.*]] = arith.constant 3 : i32 +// CHECK: } else { +// CHECK: %[[VAL_3:.*]] = arith.constant 3 : i32 +// CHECK: } +// CHECK: return +// CHECK: } +func.func @test_else() { + %false = arith.constant false + %1 = arith.constant 2 : i32 + fir.if %false { + %2 = arith.constant 3 : i32 + } else { + %3 = arith.constant 3 : i32 + } + return +} + +// CHECK-LABEL: func.func @test_two_result() { +// CHECK: %[[VAL_1:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[VAL_2:.*]] = arith.constant false +// CHECK: %[[RES:[0-9]+]]:2 = scf.if %[[VAL_2:.*]] -> (f32, f32) { +// CHECK: scf.yield %[[VAL_1:.*]], %[[VAL_1:.*]] : f32, f32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_1:.*]], %[[VAL_1:.*]] : f32, f32 +// CHECK: } +// CHECK: return +// CHECK: } +func.func @test_two_result() { + %1 = arith.constant 2.0 : f32 + %cmp = arith.constant false + %x, %y = fir.if %cmp -> (f32, f32) { + fir.result %1, %1 : f32, f32 + } else { + fir.result %1, %1 : f32, f32 + } + return +}