Skip to content

Commit

Permalink
[flang][openacc] Convert rhs expr to the lhs type on atomic read/write (
Browse files Browse the repository at this point in the history
#70377)

In some cases the rhs expression scalar type is not the same as the lhs
type. A convert op is needed before the acc.atomic.read or
acc.atomic.write operation to fit with the requirements of the
operations.
  • Loading branch information
clementval committed Oct 27, 2023
1 parent 3333096 commit 8e463b3
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 19 deletions.
34 changes: 15 additions & 19 deletions flang/lib/Lower/DirectivesCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,11 @@ void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
mlir::Value toAddress = fir::getBase(converter.genExprAddr(
*Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
mlir::Location loc = converter.getCurrentLocation();
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
if (fromAddress.getType() != toAddress.getType())
fromAddress =
builder.create<fir::ConvertOp>(loc, toAddress.getType(), fromAddress);
genOmpAccAtomicCaptureStatement(converter, fromAddress, toAddress,
leftHandClauseList, rightHandClauseList,
elementType);
Expand Down Expand Up @@ -427,10 +432,12 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,

const Fortran::parser::AssignmentStmt &stmt1 =
std::get<typename AtomicT::Stmt1>(atomicCapture.t).v.statement;
const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v;
const auto &stmt1Var{std::get<Fortran::parser::Variable>(stmt1.t)};
const auto &stmt1Expr{std::get<Fortran::parser::Expr>(stmt1.t)};
const Fortran::parser::AssignmentStmt &stmt2 =
std::get<typename AtomicT::Stmt2>(atomicCapture.t).v.statement;
const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v;
const auto &stmt2Var{std::get<Fortran::parser::Variable>(stmt2.t)};
const auto &stmt2Expr{std::get<Fortran::parser::Expr>(stmt2.t)};

Expand All @@ -442,36 +449,25 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
mlir::Value stmt1LHSArg, stmt1RHSArg, stmt2LHSArg, stmt2RHSArg;
mlir::Type elementType;
// LHS evaluations are common to all combinations of `atomic.capture`
stmt1LHSArg = fir::getBase(
converter.genExprAddr(*Fortran::semantics::GetExpr(stmt1Var), stmtCtx));
stmt2LHSArg = fir::getBase(
converter.genExprAddr(*Fortran::semantics::GetExpr(stmt2Var), stmtCtx));
stmt1LHSArg = fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx));
stmt2LHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));

// Operation specific RHS evaluations
if (checkForSingleVariableOnRHS(stmt1)) {
// Atomic capture construct is of the form [capture-stmt, update-stmt] or
// of the form [capture-stmt, write-stmt]
stmt1RHSArg = fir::getBase(converter.genExprAddr(
*Fortran::semantics::GetExpr(stmt1Expr), stmtCtx));
stmt2RHSArg = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(stmt2Expr), stmtCtx));

stmt1RHSArg = fir::getBase(converter.genExprAddr(assign1.rhs, stmtCtx));
stmt2RHSArg = fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx));
} else {
// Atomic capture construct is of the form [update-stmt, capture-stmt]
stmt1RHSArg = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(stmt1Expr), stmtCtx));
stmt2RHSArg = fir::getBase(converter.genExprAddr(
*Fortran::semantics::GetExpr(stmt2Expr), stmtCtx));
stmt1RHSArg = fir::getBase(converter.genExprValue(assign1.rhs, stmtCtx));
stmt2RHSArg = fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx));
}
// Type information used in generation of `atomic.update` operation
mlir::Type stmt1VarType =
fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(stmt1Var), stmtCtx))
.getType();
fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType();
mlir::Type stmt2VarType =
fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(stmt2Var), stmtCtx))
.getType();
fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();

mlir::Operation *atomicCaptureOp = nullptr;
if constexpr (std::is_same<AtomicListT,
Expand Down
25 changes: 25 additions & 0 deletions flang/test/Lower/OpenACC/acc-atomic-capture.f90
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,28 @@ subroutine pointers_in_atomic_capture()
b = a
!$acc end atomic
end subroutine


subroutine capture_with_convert_f32_to_i32()
implicit none
integer :: k, v, i

k = 1
v = 0

!$acc atomic capture
v = k
k = (i + 1) * 3.14
!$acc end atomic
end subroutine

! CHECK-LABEL: func.func @_QPcapture_with_convert_f32_to_i32()
! CHECK: %[[K:.*]] = fir.alloca i32 {bindc_name = "k", uniq_name = "_QFcapture_with_convert_f32_to_i32Ek"}
! CHECK: %[[V:.*]] = fir.alloca i32 {bindc_name = "v", uniq_name = "_QFcapture_with_convert_f32_to_i32Ev"}
! CHECK: %[[CST:.*]] = arith.constant 3.140000e+00 : f32
! CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %[[CST]] fastmath<contract> : f32
! CHECK: %[[CONV:.*]] = fir.convert %[[MUL]] : (f32) -> i32
! CHECK: acc.atomic.capture {
! CHECK: acc.atomic.read %[[V]] = %[[K]] : !fir.ref<i32>, i32
! CHECK: acc.atomic.write %[[K]] = %[[CONV]] : !fir.ref<i32>, i32
! CHECK: }
14 changes: 14 additions & 0 deletions flang/test/Lower/OpenACC/acc-atomic-read.f90
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,17 @@ subroutine atomic_read_pointer()

x = y
end

subroutine atomic_read_with_convert()
integer(4) :: x
integer(8) :: y

!$acc atomic read
y = x
end

! CHECK-LABEL: func.func @_QPatomic_read_with_convert() {
! CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFatomic_read_with_convertEx"}
! CHECK: %[[Y:.*]] = fir.alloca i64 {bindc_name = "y", uniq_name = "_QFatomic_read_with_convertEy"}
! CHECK: %[[CONV:.*]] = fir.convert %[[X]] : (!fir.ref<i32>) -> !fir.ref<i64>
! CHECK: acc.atomic.read %[[Y]] = %[[CONV]] : !fir.ref<i64>, i32

0 comments on commit 8e463b3

Please sign in to comment.