Skip to content

Commit

Permalink
[flang][hlfir] Lower WHERE to HLFIR
Browse files Browse the repository at this point in the history
Lower WHERE to the newly added hlfir.where and hlfir.elsewhere
operations.

Differential Revision: https://reviews.llvm.org/D149950
  • Loading branch information
jeanPerier committed May 9, 2023
1 parent b87e655 commit 54c88fc
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 14 deletions.
108 changes: 94 additions & 14 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3154,7 +3154,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Gather some information about the assignment that will impact how it is
// lowered.
const bool isWholeAllocatableAssignment =
!userDefinedAssignment &&
!userDefinedAssignment && !isInsideHlfirWhere() &&
Fortran::lower::isWholeAllocatable(assign.lhs);
std::optional<Fortran::evaluate::DynamicType> lhsType =
assign.lhs.GetType();
Expand Down Expand Up @@ -3243,8 +3243,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
void genAssignment(const Fortran::evaluate::Assignment &assign) {
mlir::Location loc = toLocation();
if (lowerToHighLevelFIR()) {
if (!implicitIterSpace.empty())
TODO(loc, "HLFIR assignment inside WHERE");
std::visit(
Fortran::common::visitors{
[&](const Fortran::evaluate::Assignment::Intrinsic &) {
Expand Down Expand Up @@ -3452,23 +3450,47 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::lower::createArrayMergeStores(*this, explicitIterSpace);
}

bool isInsideHlfirForallOrWhere() const {
// Is the insertion point of the builder directly or indirectly set
// inside any operation of type "Op"?
template <typename... Op>
bool isInsideOp() const {
mlir::Block *block = builder->getInsertionBlock();
mlir::Operation *op = block ? block->getParentOp() : nullptr;
while (op) {
if (mlir::isa<hlfir::ForallOp, hlfir::WhereOp>(op))
if (mlir::isa<Op...>(op))
return true;
op = op->getParentOp();
}
return false;
}
bool isInsideHlfirForallOrWhere() const {
return isInsideOp<hlfir::ForallOp, hlfir::WhereOp>();
}
bool isInsideHlfirWhere() const { return isInsideOp<hlfir::WhereOp>(); }

void genFIR(const Fortran::parser::WhereConstruct &c) {
implicitIterSpace.growStack();
mlir::Location loc = getCurrentLocation();
hlfir::WhereOp whereOp;

if (!lowerToHighLevelFIR()) {
implicitIterSpace.growStack();
} else {
whereOp = builder->create<hlfir::WhereOp>(loc);
builder->createBlock(&whereOp.getMaskRegion());
}

// Lower the where mask. For HLFIR, this is done in the hlfir.where mask
// region.
genNestedStatement(
std::get<
Fortran::parser::Statement<Fortran::parser::WhereConstructStmt>>(
c.t));

// Lower WHERE body. For HLFIR, this is done in the hlfir.where body
// region.
if (whereOp)
builder->createBlock(&whereOp.getBody());

for (const auto &body :
std::get<std::list<Fortran::parser::WhereBodyConstruct>>(c.t))
genFIR(body);
Expand All @@ -3484,6 +3506,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
genNestedStatement(
std::get<Fortran::parser::Statement<Fortran::parser::EndWhereStmt>>(
c.t));

if (whereOp) {
// For HLFIR, create fir.end terminator in the last hlfir.elsewhere, or
// in the hlfir.where if it had no elsewhere.
builder->create<fir::FirEndOp>(loc);
builder->setInsertionPointAfter(whereOp);
}
}
void genFIR(const Fortran::parser::WhereBodyConstruct &body) {
std::visit(
Expand All @@ -3499,24 +3528,61 @@ class FirConverter : public Fortran::lower::AbstractConverter {
},
body.u);
}

/// Lower a Where or Elsewhere mask into an hlfir mask region.
void lowerWhereMaskToHlfir(mlir::Location loc,
const Fortran::semantics::SomeExpr *maskExpr) {
assert(maskExpr && "mask semantic analysis failed");
Fortran::lower::StatementContext maskContext;
hlfir::Entity mask = Fortran::lower::convertExprToHLFIR(
loc, *this, *maskExpr, localSymbols, maskContext);
mask = hlfir::loadTrivialScalar(loc, *builder, mask);
auto yieldOp = builder->create<hlfir::YieldOp>(loc, mask);
genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(), maskContext);
}
void genFIR(const Fortran::parser::WhereConstructStmt &stmt) {
implicitIterSpace.append(Fortran::semantics::GetExpr(
std::get<Fortran::parser::LogicalExpr>(stmt.t)));
const Fortran::semantics::SomeExpr *maskExpr = Fortran::semantics::GetExpr(
std::get<Fortran::parser::LogicalExpr>(stmt.t));
if (lowerToHighLevelFIR())
lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr);
else
implicitIterSpace.append(maskExpr);
}
void genFIR(const Fortran::parser::WhereConstruct::MaskedElsewhere &ew) {
mlir::Location loc = getCurrentLocation();
hlfir::ElseWhereOp elsewhereOp;
if (lowerToHighLevelFIR()) {
elsewhereOp = builder->create<hlfir::ElseWhereOp>(loc);
// Lower mask in the mask region.
builder->createBlock(&elsewhereOp.getMaskRegion());
}
genNestedStatement(
std::get<
Fortran::parser::Statement<Fortran::parser::MaskedElsewhereStmt>>(
ew.t));

// For HLFIR, lower the body in the hlfir.elsewhere body region.
if (elsewhereOp)
builder->createBlock(&elsewhereOp.getBody());

for (const auto &body :
std::get<std::list<Fortran::parser::WhereBodyConstruct>>(ew.t))
genFIR(body);
}
void genFIR(const Fortran::parser::MaskedElsewhereStmt &stmt) {
implicitIterSpace.append(Fortran::semantics::GetExpr(
std::get<Fortran::parser::LogicalExpr>(stmt.t)));
const auto *maskExpr = Fortran::semantics::GetExpr(
std::get<Fortran::parser::LogicalExpr>(stmt.t));
if (lowerToHighLevelFIR())
lowerWhereMaskToHlfir(getCurrentLocation(), maskExpr);
else
implicitIterSpace.append(maskExpr);
}
void genFIR(const Fortran::parser::WhereConstruct::Elsewhere &ew) {
if (lowerToHighLevelFIR()) {
auto elsewhereOp =
builder->create<hlfir::ElseWhereOp>(getCurrentLocation());
builder->createBlock(&elsewhereOp.getBody());
}
genNestedStatement(
std::get<Fortran::parser::Statement<Fortran::parser::ElsewhereStmt>>(
ew.t));
Expand All @@ -3525,18 +3591,32 @@ class FirConverter : public Fortran::lower::AbstractConverter {
genFIR(body);
}
void genFIR(const Fortran::parser::ElsewhereStmt &stmt) {
implicitIterSpace.append(nullptr);
if (!lowerToHighLevelFIR())
implicitIterSpace.append(nullptr);
}
void genFIR(const Fortran::parser::EndWhereStmt &) {
implicitIterSpace.shrinkStack();
if (!lowerToHighLevelFIR())
implicitIterSpace.shrinkStack();
}

void genFIR(const Fortran::parser::WhereStmt &stmt) {
Fortran::lower::StatementContext stmtCtx;
const auto &assign = std::get<Fortran::parser::AssignmentStmt>(stmt.t);
const auto *mask = Fortran::semantics::GetExpr(
std::get<Fortran::parser::LogicalExpr>(stmt.t));
if (lowerToHighLevelFIR()) {
mlir::Location loc = getCurrentLocation();
auto whereOp = builder->create<hlfir::WhereOp>(loc);
builder->createBlock(&whereOp.getMaskRegion());
lowerWhereMaskToHlfir(loc, mask);
builder->createBlock(&whereOp.getBody());
genAssignment(*assign.typedAssignment->v);
builder->create<fir::FirEndOp>(loc);
builder->setInsertionPointAfter(whereOp);
return;
}
implicitIterSpace.growStack();
implicitIterSpace.append(Fortran::semantics::GetExpr(
std::get<Fortran::parser::LogicalExpr>(stmt.t)));
implicitIterSpace.append(mask);
genAssignment(*assign.typedAssignment->v);
implicitIterSpace.shrinkStack();
}
Expand Down
170 changes: 170 additions & 0 deletions flang/test/Lower/HLFIR/where.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
! Test lowering of WHERE construct and statements to HLFIR.
! RUN: bbc --hlfir -emit-fir -o - %s | FileCheck %s

module where_defs
logical :: mask(10)
real :: x(10), y(10)
real, allocatable :: a(:), b(:)
interface
function return_temporary_mask()
logical, allocatable :: return_temporary_mask(:)
end function
function return_temporary_array()
real, allocatable :: return_temporary_array(:)
end function
end interface
end module

subroutine simple_where()
use where_defs, only: mask, x, y
where (mask) x = y
end subroutine
! CHECK-LABEL: func.func @_QPsimple_where() {
! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare {{.*}}Emask
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Ex
! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ey
! CHECK: hlfir.where {
! CHECK: hlfir.yield %[[VAL_3]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
! CHECK: } do {
! CHECK: hlfir.region_assign {
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: } to {
! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: }
! CHECK: }
! CHECK: return
! CHECK:}

subroutine where_construct()
use where_defs
where (mask)
x = y
a = b
end where
end subroutine
! CHECK-LABEL: func.func @_QPwhere_construct() {
! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMwhere_defsEa"}
! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMwhere_defsEb"}
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask
! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
! CHECK: hlfir.where {
! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
! CHECK: } do {
! CHECK: hlfir.region_assign {
! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: } to {
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: }
! CHECK: hlfir.region_assign {
! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
! CHECK: hlfir.yield %[[VAL_16]] : !fir.box<!fir.heap<!fir.array<?xf32>>>
! CHECK: } to {
! CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
! CHECK: hlfir.yield %[[VAL_17]] : !fir.box<!fir.heap<!fir.array<?xf32>>>
! CHECK: }
! CHECK: }
! CHECK: return
! CHECK:}

subroutine where_cleanup()
use where_defs, only: x, return_temporary_mask, return_temporary_array
where (return_temporary_mask()) x = return_temporary_array()
end subroutine
! CHECK-LABEL: func.func @_QPwhere_cleanup() {
! CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = ".result"}
! CHECK: %[[VAL_1:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> {bindc_name = ".result"}
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Ex
! CHECK: hlfir.where {
! CHECK: %[[VAL_6:.*]] = fir.call @_QPreturn_temporary_mask() fastmath<contract> : () -> !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
! CHECK: fir.save_result %[[VAL_6]] to %[[VAL_1]] : !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>)
! CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>>
! CHECK: hlfir.yield %[[VAL_8]] : !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>> cleanup {
! CHECK: fir.freemem
! CHECK: }
! CHECK: } do {
! CHECK: hlfir.region_assign {
! CHECK: %[[VAL_14:.*]] = fir.call @_QPreturn_temporary_array() fastmath<contract> : () -> !fir.box<!fir.heap<!fir.array<?xf32>>>
! CHECK: fir.save_result %[[VAL_14]] to %[[VAL_0]] : !fir.box<!fir.heap<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
! CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_15]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
! CHECK: hlfir.yield %[[VAL_16]] : !fir.box<!fir.heap<!fir.array<?xf32>>> cleanup {
! CHECK: fir.freemem
! CHECK: }
! CHECK: } to {
! CHECK: hlfir.yield %[[VAL_5]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: }
! CHECK: }

subroutine simple_elsewhere()
use where_defs
where (mask)
x = y
elsewhere
y = x
end where
end subroutine
! CHECK-LABEL: func.func @_QPsimple_elsewhere() {
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Emask
! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
! CHECK: hlfir.where {
! CHECK: hlfir.yield %[[VAL_7]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
! CHECK: } do {
! CHECK: hlfir.region_assign {
! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: } to {
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: }
! CHECK: hlfir.elsewhere do {
! CHECK: hlfir.region_assign {
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: } to {
! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: }
! CHECK: }
! CHECK: }

subroutine elsewhere_2(mask2)
use where_defs, only : mask, x, y
logical :: mask2(:)
where (mask)
x = y
elsewhere(mask2)
y = x
elsewhere
x = foo()
end where
end subroutine
! CHECK-LABEL: func.func @_QPelsewhere_2(
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Emask
! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare {{.*}}Emask2
! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare {{.*}}Ex
! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare {{.*}}Ey
! CHECK: hlfir.where {
! CHECK: hlfir.yield %[[VAL_5]]#0 : !fir.ref<!fir.array<10x!fir.logical<4>>>
! CHECK: } do {
! CHECK: hlfir.region_assign {
! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: } to {
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: }
! CHECK: hlfir.elsewhere mask {
! CHECK: hlfir.yield %[[VAL_6]]#0 : !fir.box<!fir.array<?x!fir.logical<4>>>
! CHECK: } do {
! CHECK: hlfir.region_assign {
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: } to {
! CHECK: hlfir.yield %[[VAL_15]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: }
! CHECK: hlfir.elsewhere do {
! CHECK: hlfir.region_assign {
! CHECK: %[[VAL_16:.*]] = fir.call @_QPfoo() fastmath<contract> : () -> f32
! CHECK: hlfir.yield %[[VAL_16]] : f32
! CHECK: } to {
! CHECK: hlfir.yield %[[VAL_11]]#0 : !fir.ref<!fir.array<10xf32>>
! CHECK: }
! CHECK: }
! CHECK: }
! CHECK: }

0 comments on commit 54c88fc

Please sign in to comment.