diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 3f354b7868a3b..a4185a47318c7 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -29,6 +29,7 @@ #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/Character.h" #include "flang/Optimizer/Builder/MutableBox.h" +#include "flang/Optimizer/Builder/Runtime/Character.h" #include "flang/Optimizer/Builder/Runtime/Ragged.h" #include "flang/Optimizer/Dialect/FIRAttr.h" #include "flang/Optimizer/Support/FIRContext.h" @@ -811,6 +812,9 @@ class FirConverter : public Fortran::lower::AbstractConverter { cat == Fortran::common::TypeCategory::Complex || cat == Fortran::common::TypeCategory::Logical; } + static bool isLogicalCategory(Fortran::common::TypeCategory cat) { + return cat == Fortran::common::TypeCategory::Logical; + } bool isCharacterCategory(Fortran::common::TypeCategory cat) { return cat == Fortran::common::TypeCategory::Character; } @@ -818,6 +822,14 @@ class FirConverter : public Fortran::lower::AbstractConverter { return cat == Fortran::common::TypeCategory::Derived; } + /// Insert a new block before \p block. Leave the insertion point unchanged. + mlir::Block *insertBlock(mlir::Block *block) { + mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); + mlir::Block *newBlock = builder->createBlock(block); + builder->restoreInsertionPoint(insertPt); + return newBlock; + } + mlir::Block *blockOfLabel(Fortran::lower::pft::Evaluation &eval, Fortran::parser::Label label) { const Fortran::lower::pft::LabelEvalMap &labelEvaluationMap = @@ -1399,7 +1411,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { } void genFIR(const Fortran::parser::CaseConstruct &) { - TODO(toLocation(), "CaseConstruct lowering"); + for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) + genFIR(e); } template @@ -1630,8 +1643,170 @@ class FirConverter : public Fortran::lower::AbstractConverter { TODO(toLocation(), "OpenMPDeclarativeConstruct lowering"); } - void genFIR(const Fortran::parser::SelectCaseStmt &) { - TODO(toLocation(), "SelectCaseStmt lowering"); + /// Generate FIR for a SELECT CASE statement. + /// The type may be CHARACTER, INTEGER, or LOGICAL. + void genFIR(const Fortran::parser::SelectCaseStmt &stmt) { + Fortran::lower::pft::Evaluation &eval = getEval(); + MLIRContext *context = builder->getContext(); + mlir::Location loc = toLocation(); + Fortran::lower::StatementContext stmtCtx; + const Fortran::lower::SomeExpr *expr = Fortran::semantics::GetExpr( + std::get>(stmt.t)); + bool isCharSelector = isCharacterCategory(expr->GetType()->category()); + bool isLogicalSelector = isLogicalCategory(expr->GetType()->category()); + auto charValue = [&](const Fortran::lower::SomeExpr *expr) { + fir::ExtendedValue exv = genExprAddr(*expr, stmtCtx, &loc); + return exv.match( + [&](const fir::CharBoxValue &cbv) { + return fir::factory::CharacterExprHelper{*builder, loc} + .createEmboxChar(cbv.getAddr(), cbv.getLen()); + }, + [&](auto) { + fir::emitFatalError(loc, "not a character"); + return mlir::Value{}; + }); + }; + mlir::Value selector; + if (isCharSelector) { + selector = charValue(expr); + } else { + selector = createFIRExpr(loc, expr, stmtCtx); + if (isLogicalSelector) + selector = builder->createConvert(loc, builder->getI1Type(), selector); + } + mlir::Type selectType = selector.getType(); + llvm::SmallVector attrList; + llvm::SmallVector valueList; + llvm::SmallVector blockList; + mlir::Block *defaultBlock = eval.parentConstruct->constructExit->block; + using CaseValue = Fortran::parser::Scalar; + auto addValue = [&](const CaseValue &caseValue) { + const Fortran::lower::SomeExpr *expr = + Fortran::semantics::GetExpr(caseValue.thing); + if (isCharSelector) + valueList.push_back(charValue(expr)); + else if (isLogicalSelector) + valueList.push_back(builder->createConvert( + loc, selectType, createFIRExpr(toLocation(), expr, stmtCtx))); + else + valueList.push_back(builder->createIntegerConstant( + loc, selectType, *Fortran::evaluate::ToInt64(*expr))); + }; + for (Fortran::lower::pft::Evaluation *e = eval.controlSuccessor; e; + e = e->controlSuccessor) { + const auto &caseStmt = e->getIf(); + assert(e->block && "missing CaseStmt block"); + const auto &caseSelector = + std::get(caseStmt->t); + const auto *caseValueRangeList = + std::get_if>( + &caseSelector.u); + if (!caseValueRangeList) { + defaultBlock = e->block; + continue; + } + for (const Fortran::parser::CaseValueRange &caseValueRange : + *caseValueRangeList) { + blockList.push_back(e->block); + if (const auto *caseValue = std::get_if(&caseValueRange.u)) { + attrList.push_back(fir::PointIntervalAttr::get(context)); + addValue(*caseValue); + continue; + } + const auto &caseRange = + std::get(caseValueRange.u); + if (caseRange.lower && caseRange.upper) { + attrList.push_back(fir::ClosedIntervalAttr::get(context)); + addValue(*caseRange.lower); + addValue(*caseRange.upper); + } else if (caseRange.lower) { + attrList.push_back(fir::LowerBoundAttr::get(context)); + addValue(*caseRange.lower); + } else { + attrList.push_back(fir::UpperBoundAttr::get(context)); + addValue(*caseRange.upper); + } + } + } + // Skip a logical default block that can never be referenced. + if (isLogicalSelector && attrList.size() == 2) + defaultBlock = eval.parentConstruct->constructExit->block; + attrList.push_back(mlir::UnitAttr::get(context)); + blockList.push_back(defaultBlock); + + // Generate a fir::SelectCaseOp. + // Explicit branch code is better for the LOGICAL type. The CHARACTER type + // does not yet have downstream support, and also uses explicit branch code. + // The -no-structured-fir option can be used to force generation of INTEGER + // type branch code. + if (!isLogicalSelector && !isCharSelector && eval.lowerAsStructured()) { + // Numeric selector is a ssa register, all temps that may have + // been generated while evaluating it can be cleaned-up before the + // fir.select_case. + stmtCtx.finalize(); + builder->create(loc, selector, attrList, valueList, + blockList); + return; + } + + // Generate a sequence of case value comparisons and branches. + auto caseValue = valueList.begin(); + auto caseBlock = blockList.begin(); + for (mlir::Attribute attr : attrList) { + if (attr.isa()) { + genFIRBranch(*caseBlock++); + break; + } + auto genCond = [&](mlir::Value rhs, + mlir::arith::CmpIPredicate pred) -> mlir::Value { + if (!isCharSelector) + return builder->create(loc, pred, selector, rhs); + fir::factory::CharacterExprHelper charHelper{*builder, loc}; + std::pair lhsVal = + charHelper.createUnboxChar(selector); + mlir::Value &lhsAddr = lhsVal.first; + mlir::Value &lhsLen = lhsVal.second; + std::pair rhsVal = + charHelper.createUnboxChar(rhs); + mlir::Value &rhsAddr = rhsVal.first; + mlir::Value &rhsLen = rhsVal.second; + return fir::runtime::genCharCompare(*builder, loc, pred, lhsAddr, + lhsLen, rhsAddr, rhsLen); + }; + mlir::Block *newBlock = insertBlock(*caseBlock); + if (attr.isa()) { + mlir::Block *newBlock2 = insertBlock(*caseBlock); + mlir::Value cond = + genCond(*caseValue++, mlir::arith::CmpIPredicate::sge); + genFIRConditionalBranch(cond, newBlock, newBlock2); + builder->setInsertionPointToEnd(newBlock); + mlir::Value cond2 = + genCond(*caseValue++, mlir::arith::CmpIPredicate::sle); + genFIRConditionalBranch(cond2, *caseBlock++, newBlock2); + builder->setInsertionPointToEnd(newBlock2); + continue; + } + mlir::arith::CmpIPredicate pred; + if (attr.isa()) { + pred = mlir::arith::CmpIPredicate::eq; + } else if (attr.isa()) { + pred = mlir::arith::CmpIPredicate::sge; + } else { + assert(attr.isa() && "unexpected predicate"); + pred = mlir::arith::CmpIPredicate::sle; + } + mlir::Value cond = genCond(*caseValue++, pred); + genFIRConditionalBranch(cond, *caseBlock++, newBlock); + builder->setInsertionPointToEnd(newBlock); + } + assert(caseValue == valueList.end() && caseBlock == blockList.end() && + "select case list mismatch"); + // Clean-up the selector at the end of the construct if it is a temporary + // (which is possible with characters). + mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); + builder->setInsertionPointToEnd(eval.parentConstruct->constructExit->block); + stmtCtx.finalize(); + builder->restoreInsertionPoint(insertPt); } fir::ExtendedValue @@ -2115,10 +2290,6 @@ class FirConverter : public Fortran::lower::AbstractConverter { genFIRBranch(getEval().controlSuccessor->block); } - void genFIR(const Fortran::parser::CaseStmt &) { - TODO(toLocation(), "CaseStmt lowering"); - } - void genFIR(const Fortran::parser::ElseIfStmt &) { TODO(toLocation(), "ElseIfStmt lowering"); } @@ -2135,16 +2306,14 @@ class FirConverter : public Fortran::lower::AbstractConverter { TODO(toLocation(), "EndMpSubprogramStmt lowering"); } - void genFIR(const Fortran::parser::EndSelectStmt &) { - TODO(toLocation(), "EndSelectStmt lowering"); - } - // Nop statements - No code, or code is generated at the construct level. void genFIR(const Fortran::parser::AssociateStmt &) {} // nop + void genFIR(const Fortran::parser::CaseStmt &) {} // nop void genFIR(const Fortran::parser::ContinueStmt &) {} // nop void genFIR(const Fortran::parser::EndAssociateStmt &) {} // nop void genFIR(const Fortran::parser::EndFunctionStmt &) {} // nop void genFIR(const Fortran::parser::EndIfStmt &) {} // nop + void genFIR(const Fortran::parser::EndSelectStmt &) {} // nop void genFIR(const Fortran::parser::EndSubroutineStmt &) {} // nop void genFIR(const Fortran::parser::EntryStmt &) {} // nop @@ -2168,6 +2337,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { TODO(toLocation(), "NamelistStmt lowering"); } + /// Generate FIR for the Evaluation `eval`. void genFIR(Fortran::lower::pft::Evaluation &eval, bool unstructuredContext = true) { if (unstructuredContext) { @@ -2181,6 +2351,19 @@ class FirConverter : public Fortran::lower::AbstractConverter { setCurrentEval(eval); setCurrentPosition(eval.position); eval.visit([&](const auto &stmt) { genFIR(stmt); }); + + if (unstructuredContext && blockIsUnterminated()) { + // Exit from an unstructured IF or SELECT construct block. + Fortran::lower::pft::Evaluation *successor{}; + if (eval.isActionStmt()) + successor = eval.controlSuccessor; + else if (eval.isConstruct() && + eval.getLastNestedEvaluation() + .lexicalSuccessor->isIntermediateConstructStmt()) + successor = eval.constructExit; + if (successor && successor->block) + genFIRBranch(successor->block); + } } //===--------------------------------------------------------------------===// diff --git a/flang/test/Lower/select-case-statement.f90 b/flang/test/Lower/select-case-statement.f90 new file mode 100644 index 0000000000000..2efbcc036dd67 --- /dev/null +++ b/flang/test/Lower/select-case-statement.f90 @@ -0,0 +1,211 @@ +! RUN: bbc -emit-fir -o - %s | FileCheck %s + + ! CHECK-LABEL: sinteger + function sinteger(n) + integer sinteger + nn = -88 + ! CHECK: fir.select_case {{.*}} : i32 + ! CHECK-SAME: upper, %c1 + ! CHECK-SAME: point, %c2 + ! CHECK-SAME: point, %c3 + ! CHECK-SAME: interval, %c4{{.*}} %c5 + ! CHECK-SAME: point, %c6 + ! CHECK-SAME: point, %c7 + ! CHECK-SAME: interval, %c8{{.*}} %c15 + ! CHECK-SAME: lower, %c21 + ! CHECK-SAME: unit + select case(n) + case (:1) + nn = 1 + case (2) + nn = 2 + case default + nn = 0 + case (3) + nn = 3 + case (4:5+1-1) + nn = 4 + case (6) + nn = 6 + case (7,8:15,21:) + nn = 7 + end select + sinteger = nn + end + + ! CHECK-LABEL: slogical + subroutine slogical(L) + logical :: L + n1 = 0 + n2 = 0 + n3 = 0 + n4 = 0 + n5 = 0 + n6 = 0 + n7 = 0 + n8 = 0 + + select case (L) + end select + + select case (L) + ! CHECK: cmpi eq, {{.*}} %false + ! CHECK: cond_br + case (.false.) + n2 = 1 + end select + + select case (L) + ! CHECK: cmpi eq, {{.*}} %true + ! CHECK: cond_br + case (.true.) + n3 = 2 + end select + + select case (L) + case default + n4 = 3 + end select + + select case (L) + ! CHECK: cmpi eq, {{.*}} %false + ! CHECK: cond_br + case (.false.) + n5 = 1 + ! CHECK: cmpi eq, {{.*}} %true + ! CHECK: cond_br + case (.true.) + n5 = 2 + end select + + select case (L) + ! CHECK: cmpi eq, {{.*}} %false + ! CHECK: cond_br + case (.false.) + n6 = 1 + case default + n6 = 3 + end select + + select case (L) + ! CHECK: cmpi eq, {{.*}} %true + ! CHECK: cond_br + case (.true.) + n7 = 2 + case default + n7 = 3 + end select + + select case (L) + ! CHECK: cmpi eq, {{.*}} %false + ! CHECK: cond_br + case (.false.) + n8 = 1 + ! CHECK: cmpi eq, {{.*}} %true + ! CHECK: cond_br + case (.true.) + n8 = 2 + ! CHECK-NOT: constant 888 + case default ! dead + n8 = 888 + end select + + print*, n1, n2, n3, n4, n5, n6, n7, n8 + end + + ! CHECK-LABEL: scharacter + subroutine scharacter(c) + character(*) :: c + nn = 0 + select case (c) + case default + nn = -1 + ! CHECK: CharacterCompareScalar1 + ! CHECK-NEXT: constant 0 + ! CHECK-NEXT: cmpi sle, {{.*}} %c0 + ! CHECK-NEXT: cond_br + case (:'d') + nn = 10 + ! CHECK: CharacterCompareScalar1 + ! CHECK-NEXT: constant 0 + ! CHECK-NEXT: cmpi sge, {{.*}} %c0 + ! CHECK-NEXT: cond_br + ! CHECK: CharacterCompareScalar1 + ! CHECK-NEXT: constant 0 + ! CHECK-NEXT: cmpi sle, {{.*}} %c0 + ! CHECK-NEXT: cond_br + case ('ff':'ffff') + nn = 20 + ! CHECK: CharacterCompareScalar1 + ! CHECK-NEXT: constant 0 + ! CHECK-NEXT: cmpi eq, {{.*}} %c0 + ! CHECK-NEXT: cond_br + case ('m') + nn = 30 + ! CHECK: CharacterCompareScalar1 + ! CHECK-NEXT: constant 0 + ! CHECK-NEXT: cmpi eq, {{.*}} %c0 + ! CHECK-NEXT: cond_br + case ('qq') + nn = 40 + ! CHECK: CharacterCompareScalar1 + ! CHECK-NEXT: constant 0 + ! CHECK-NEXT: cmpi sge, {{.*}} %c0 + ! CHECK-NEXT: cond_br + case ('x':) + nn = 50 + end select + print*, nn + end + + ! CHECK-LABEL: func @_QPtest_char_temp_selector + subroutine test_char_temp_selector() + ! Test that character selector that are temps are deallocated + ! only after they have been used in the select case comparisons. + interface + function gen_char_temp_selector() + character(:), allocatable :: gen_char_temp_selector + end function + end interface + select case (gen_char_temp_selector()) + case ('case1') + call foo1() + case ('case2') + call foo2() + case ('case3') + call foo3() + case default + call foo_default() + end select + ! CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box>> {bindc_name = ".result"} + ! CHECK: %[[VAL_1:.*]] = fir.call @_QPgen_char_temp_selector() : () -> !fir.box>> + ! CHECK: fir.save_result %[[VAL_1]] to %[[VAL_0]] : !fir.box>>, !fir.ref>>> + ! CHECK: cond_br %{{.*}}, ^bb2, ^bb1 + ! CHECK: ^bb1: + ! CHECK: cond_br %{{.*}}, ^bb4, ^bb3 + ! CHECK: ^bb2: + ! CHECK: fir.call @_QPfoo1() : () -> () + ! CHECK: br ^bb8 + ! CHECK: ^bb3: + ! CHECK: cond_br %{{.*}}, ^bb6, ^bb5 + ! CHECK: ^bb4: + ! CHECK: fir.call @_QPfoo2() : () -> () + ! CHECK: br ^bb8 + ! CHECK: ^bb5: + ! CHECK: br ^bb7 + ! CHECK: ^bb6: + ! CHECK: fir.call @_QPfoo3() : () -> () + ! CHECK: br ^bb8 + ! CHECK: ^bb7: + ! CHECK: fir.call @_QPfoo_default() : () -> () + ! CHECK: br ^bb8 + ! CHECK: ^bb8: + ! CHECK: %[[VAL_36:.*]] = fir.load %[[VAL_0]] : !fir.ref>>> + ! CHECK: %[[VAL_37:.*]] = fir.box_addr %[[VAL_36]] : (!fir.box>>) -> !fir.heap> + ! CHECK: %[[VAL_38:.*]] = fir.convert %[[VAL_37]] : (!fir.heap>) -> i64 + ! CHECK: %[[VAL_39:.*]] = arith.constant 0 : i64 + ! CHECK: %[[VAL_40:.*]] = arith.cmpi ne, %[[VAL_38]], %[[VAL_39]] : i64 + ! CHECK: fir.if %[[VAL_40]] { + ! CHECK: fir.freemem %[[VAL_37]] + ! CHECK: } + end subroutine