Skip to content

Commit

Permalink
[flang] Lower select case statement
Browse files Browse the repository at this point in the history
This patch adds lowering for the `select case`
statement.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D122007

Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
Co-authored-by: V Donaldson <vdonaldson@nvidia.com>
  • Loading branch information
4 people committed Mar 18, 2022
1 parent 1b7ef6a commit 308fc3f
Show file tree
Hide file tree
Showing 2 changed files with 405 additions and 11 deletions.
205 changes: 194 additions & 11 deletions flang/lib/Lower/Bridge.cpp
Expand Up @@ -29,6 +29,7 @@
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/MutableBox.h"
#include "flang/Optimizer/Builder/Runtime/Character.h"
#include "flang/Optimizer/Builder/Runtime/Ragged.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Support/FIRContext.h"
Expand Down Expand Up @@ -811,13 +812,24 @@ class FirConverter : public Fortran::lower::AbstractConverter {
cat == Fortran::common::TypeCategory::Complex ||
cat == Fortran::common::TypeCategory::Logical;
}
static bool isLogicalCategory(Fortran::common::TypeCategory cat) {
return cat == Fortran::common::TypeCategory::Logical;
}
bool isCharacterCategory(Fortran::common::TypeCategory cat) {
return cat == Fortran::common::TypeCategory::Character;
}
bool isDerivedCategory(Fortran::common::TypeCategory cat) {
return cat == Fortran::common::TypeCategory::Derived;
}

/// Insert a new block before \p block. Leave the insertion point unchanged.
mlir::Block *insertBlock(mlir::Block *block) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
mlir::Block *newBlock = builder->createBlock(block);
builder->restoreInsertionPoint(insertPt);
return newBlock;
}

mlir::Block *blockOfLabel(Fortran::lower::pft::Evaluation &eval,
Fortran::parser::Label label) {
const Fortran::lower::pft::LabelEvalMap &labelEvaluationMap =
Expand Down Expand Up @@ -1399,7 +1411,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}

void genFIR(const Fortran::parser::CaseConstruct &) {
TODO(toLocation(), "CaseConstruct lowering");
for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
genFIR(e);
}

template <typename A>
Expand Down Expand Up @@ -1630,8 +1643,170 @@ class FirConverter : public Fortran::lower::AbstractConverter {
TODO(toLocation(), "OpenMPDeclarativeConstruct lowering");
}

void genFIR(const Fortran::parser::SelectCaseStmt &) {
TODO(toLocation(), "SelectCaseStmt lowering");
/// Generate FIR for a SELECT CASE statement.
/// The type may be CHARACTER, INTEGER, or LOGICAL.
void genFIR(const Fortran::parser::SelectCaseStmt &stmt) {
Fortran::lower::pft::Evaluation &eval = getEval();
MLIRContext *context = builder->getContext();
mlir::Location loc = toLocation();
Fortran::lower::StatementContext stmtCtx;
const Fortran::lower::SomeExpr *expr = Fortran::semantics::GetExpr(
std::get<Fortran::parser::Scalar<Fortran::parser::Expr>>(stmt.t));
bool isCharSelector = isCharacterCategory(expr->GetType()->category());
bool isLogicalSelector = isLogicalCategory(expr->GetType()->category());
auto charValue = [&](const Fortran::lower::SomeExpr *expr) {
fir::ExtendedValue exv = genExprAddr(*expr, stmtCtx, &loc);
return exv.match(
[&](const fir::CharBoxValue &cbv) {
return fir::factory::CharacterExprHelper{*builder, loc}
.createEmboxChar(cbv.getAddr(), cbv.getLen());
},
[&](auto) {
fir::emitFatalError(loc, "not a character");
return mlir::Value{};
});
};
mlir::Value selector;
if (isCharSelector) {
selector = charValue(expr);
} else {
selector = createFIRExpr(loc, expr, stmtCtx);
if (isLogicalSelector)
selector = builder->createConvert(loc, builder->getI1Type(), selector);
}
mlir::Type selectType = selector.getType();
llvm::SmallVector<mlir::Attribute> attrList;
llvm::SmallVector<mlir::Value> valueList;
llvm::SmallVector<mlir::Block *> blockList;
mlir::Block *defaultBlock = eval.parentConstruct->constructExit->block;
using CaseValue = Fortran::parser::Scalar<Fortran::parser::ConstantExpr>;
auto addValue = [&](const CaseValue &caseValue) {
const Fortran::lower::SomeExpr *expr =
Fortran::semantics::GetExpr(caseValue.thing);
if (isCharSelector)
valueList.push_back(charValue(expr));
else if (isLogicalSelector)
valueList.push_back(builder->createConvert(
loc, selectType, createFIRExpr(toLocation(), expr, stmtCtx)));
else
valueList.push_back(builder->createIntegerConstant(
loc, selectType, *Fortran::evaluate::ToInt64(*expr)));
};
for (Fortran::lower::pft::Evaluation *e = eval.controlSuccessor; e;
e = e->controlSuccessor) {
const auto &caseStmt = e->getIf<Fortran::parser::CaseStmt>();
assert(e->block && "missing CaseStmt block");
const auto &caseSelector =
std::get<Fortran::parser::CaseSelector>(caseStmt->t);
const auto *caseValueRangeList =
std::get_if<std::list<Fortran::parser::CaseValueRange>>(
&caseSelector.u);
if (!caseValueRangeList) {
defaultBlock = e->block;
continue;
}
for (const Fortran::parser::CaseValueRange &caseValueRange :
*caseValueRangeList) {
blockList.push_back(e->block);
if (const auto *caseValue = std::get_if<CaseValue>(&caseValueRange.u)) {
attrList.push_back(fir::PointIntervalAttr::get(context));
addValue(*caseValue);
continue;
}
const auto &caseRange =
std::get<Fortran::parser::CaseValueRange::Range>(caseValueRange.u);
if (caseRange.lower && caseRange.upper) {
attrList.push_back(fir::ClosedIntervalAttr::get(context));
addValue(*caseRange.lower);
addValue(*caseRange.upper);
} else if (caseRange.lower) {
attrList.push_back(fir::LowerBoundAttr::get(context));
addValue(*caseRange.lower);
} else {
attrList.push_back(fir::UpperBoundAttr::get(context));
addValue(*caseRange.upper);
}
}
}
// Skip a logical default block that can never be referenced.
if (isLogicalSelector && attrList.size() == 2)
defaultBlock = eval.parentConstruct->constructExit->block;
attrList.push_back(mlir::UnitAttr::get(context));
blockList.push_back(defaultBlock);

// Generate a fir::SelectCaseOp.
// Explicit branch code is better for the LOGICAL type. The CHARACTER type
// does not yet have downstream support, and also uses explicit branch code.
// The -no-structured-fir option can be used to force generation of INTEGER
// type branch code.
if (!isLogicalSelector && !isCharSelector && eval.lowerAsStructured()) {
// Numeric selector is a ssa register, all temps that may have
// been generated while evaluating it can be cleaned-up before the
// fir.select_case.
stmtCtx.finalize();
builder->create<fir::SelectCaseOp>(loc, selector, attrList, valueList,
blockList);
return;
}

// Generate a sequence of case value comparisons and branches.
auto caseValue = valueList.begin();
auto caseBlock = blockList.begin();
for (mlir::Attribute attr : attrList) {
if (attr.isa<mlir::UnitAttr>()) {
genFIRBranch(*caseBlock++);
break;
}
auto genCond = [&](mlir::Value rhs,
mlir::arith::CmpIPredicate pred) -> mlir::Value {
if (!isCharSelector)
return builder->create<mlir::arith::CmpIOp>(loc, pred, selector, rhs);
fir::factory::CharacterExprHelper charHelper{*builder, loc};
std::pair<mlir::Value, mlir::Value> lhsVal =
charHelper.createUnboxChar(selector);
mlir::Value &lhsAddr = lhsVal.first;
mlir::Value &lhsLen = lhsVal.second;
std::pair<mlir::Value, mlir::Value> rhsVal =
charHelper.createUnboxChar(rhs);
mlir::Value &rhsAddr = rhsVal.first;
mlir::Value &rhsLen = rhsVal.second;
return fir::runtime::genCharCompare(*builder, loc, pred, lhsAddr,
lhsLen, rhsAddr, rhsLen);
};
mlir::Block *newBlock = insertBlock(*caseBlock);
if (attr.isa<fir::ClosedIntervalAttr>()) {
mlir::Block *newBlock2 = insertBlock(*caseBlock);
mlir::Value cond =
genCond(*caseValue++, mlir::arith::CmpIPredicate::sge);
genFIRConditionalBranch(cond, newBlock, newBlock2);
builder->setInsertionPointToEnd(newBlock);
mlir::Value cond2 =
genCond(*caseValue++, mlir::arith::CmpIPredicate::sle);
genFIRConditionalBranch(cond2, *caseBlock++, newBlock2);
builder->setInsertionPointToEnd(newBlock2);
continue;
}
mlir::arith::CmpIPredicate pred;
if (attr.isa<fir::PointIntervalAttr>()) {
pred = mlir::arith::CmpIPredicate::eq;
} else if (attr.isa<fir::LowerBoundAttr>()) {
pred = mlir::arith::CmpIPredicate::sge;
} else {
assert(attr.isa<fir::UpperBoundAttr>() && "unexpected predicate");
pred = mlir::arith::CmpIPredicate::sle;
}
mlir::Value cond = genCond(*caseValue++, pred);
genFIRConditionalBranch(cond, *caseBlock++, newBlock);
builder->setInsertionPointToEnd(newBlock);
}
assert(caseValue == valueList.end() && caseBlock == blockList.end() &&
"select case list mismatch");
// Clean-up the selector at the end of the construct if it is a temporary
// (which is possible with characters).
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
builder->setInsertionPointToEnd(eval.parentConstruct->constructExit->block);
stmtCtx.finalize();
builder->restoreInsertionPoint(insertPt);
}

fir::ExtendedValue
Expand Down Expand Up @@ -2115,10 +2290,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
genFIRBranch(getEval().controlSuccessor->block);
}

void genFIR(const Fortran::parser::CaseStmt &) {
TODO(toLocation(), "CaseStmt lowering");
}

void genFIR(const Fortran::parser::ElseIfStmt &) {
TODO(toLocation(), "ElseIfStmt lowering");
}
Expand All @@ -2135,16 +2306,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
TODO(toLocation(), "EndMpSubprogramStmt lowering");
}

void genFIR(const Fortran::parser::EndSelectStmt &) {
TODO(toLocation(), "EndSelectStmt lowering");
}

// Nop statements - No code, or code is generated at the construct level.
void genFIR(const Fortran::parser::AssociateStmt &) {} // nop
void genFIR(const Fortran::parser::CaseStmt &) {} // nop
void genFIR(const Fortran::parser::ContinueStmt &) {} // nop
void genFIR(const Fortran::parser::EndAssociateStmt &) {} // nop
void genFIR(const Fortran::parser::EndFunctionStmt &) {} // nop
void genFIR(const Fortran::parser::EndIfStmt &) {} // nop
void genFIR(const Fortran::parser::EndSelectStmt &) {} // nop
void genFIR(const Fortran::parser::EndSubroutineStmt &) {} // nop
void genFIR(const Fortran::parser::EntryStmt &) {} // nop

Expand All @@ -2168,6 +2337,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
TODO(toLocation(), "NamelistStmt lowering");
}

/// Generate FIR for the Evaluation `eval`.
void genFIR(Fortran::lower::pft::Evaluation &eval,
bool unstructuredContext = true) {
if (unstructuredContext) {
Expand All @@ -2181,6 +2351,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
setCurrentEval(eval);
setCurrentPosition(eval.position);
eval.visit([&](const auto &stmt) { genFIR(stmt); });

if (unstructuredContext && blockIsUnterminated()) {
// Exit from an unstructured IF or SELECT construct block.
Fortran::lower::pft::Evaluation *successor{};
if (eval.isActionStmt())
successor = eval.controlSuccessor;
else if (eval.isConstruct() &&
eval.getLastNestedEvaluation()
.lexicalSuccessor->isIntermediateConstructStmt())
successor = eval.constructExit;
if (successor && successor->block)
genFIRBranch(successor->block);
}
}

//===--------------------------------------------------------------------===//
Expand Down

0 comments on commit 308fc3f

Please sign in to comment.