Skip to content

Commit

Permalink
[flang][openacc] Add support for complex add reduction
Browse files Browse the repository at this point in the history
Add lowering support for reduction with the add operator
on complex type.

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D155007
  • Loading branch information
clementval committed Jul 12, 2023
1 parent 787a5ef commit 119c512
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
22 changes: 16 additions & 6 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
#include "flang/Lower/OpenACC.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/Support/Utils.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/Todo.h"
Expand Down Expand Up @@ -712,11 +714,17 @@ static mlir::Value genReductionInitValue(fir::FirOpBuilder &builder,
loc, ty,
builder.getFloatAttr(ty,
getReductionInitValue<llvm::APFloat>(op, ty)));
} else {
if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty))
return builder.create<mlir::arith::ConstantOp>(
loc, ty,
builder.getFloatAttr(ty, getReductionInitValue<int64_t>(op, ty)));
} else if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty)) {
return builder.create<mlir::arith::ConstantOp>(
loc, ty,
builder.getFloatAttr(ty, getReductionInitValue<int64_t>(op, ty)));
} else if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(ty)) {
mlir::Type floatTy =
Fortran::lower::convertReal(builder.getContext(), cmplxTy.getFKind());
mlir::Value init = builder.createRealConstant(
loc, floatTy, getReductionInitValue<int64_t>(op, cmplxTy));
return fir::factory::Complex{builder, loc}.createComplex(cmplxTy.getFKind(),
init, init);
}
if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(ty)) {
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(refTy.getEleTy())) {
Expand All @@ -738,7 +746,7 @@ static mlir::Value genReductionInitValue(fir::FirOpBuilder &builder,
}
}

TODO(loc, "reduction type");
llvm::report_fatal_error("Unsupported OpenACC reduction type");
}

template <typename Op>
Expand Down Expand Up @@ -808,6 +816,8 @@ static mlir::Value genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
if (mlir::isa<mlir::FloatType>(ty))
return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(ty))
return builder.create<fir::AddcOp>(loc, value1, value2);
TODO(loc, "reduction add type");
}

Expand Down
23 changes: 23 additions & 0 deletions flang/test/Lower/OpenACC/acc-reduction.f90
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@

! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s

! CHECK-LABEL: acc.reduction.recipe @reduction_add_z32 : !fir.complex<4> reduction_operator <add> init {
! CHECK: ^bb0(%{{.*}}: !fir.complex<4>):
! CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
! CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<4>
! CHECK: %[[UNDEF1:.*]] = fir.insert_value %[[UNDEF]], %[[CST]], [0 : index] : (!fir.complex<4>, f32) -> !fir.complex<4>
! CHECK: %[[UNDEF2:.*]] = fir.insert_value %[[UNDEF1]], %[[CST]], [1 : index] : (!fir.complex<4>, f32) -> !fir.complex<4>
! CHECK: acc.yield %[[UNDEF2]] : !fir.complex<4>
! CHECK: } combiner {
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<4>, %[[ARG1:.*]]: !fir.complex<4>):
! CHECK: %[[COMBINED:.*]] = fir.addc %[[ARG0]], %[[ARG1]] : !fir.complex<4>
! CHECK: acc.yield %[[COMBINED]] : !fir.complex<4>
! CHECK: }

! CHECK-LABEL: acc.reduction.recipe @reduction_neqv_l32 : !fir.logical<4> reduction_operator <neqv> init {
! CHECK: ^bb0(%{{.*}}: !fir.logical<4>):
! CHECK: %[[CST:.*]] = arith.constant false
Expand Down Expand Up @@ -729,3 +742,13 @@ subroutine acc_reduction_neqv()
! CHECK-LABEL: func.func @_QPacc_reduction_neqv()
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) -> !fir.ref<!fir.logical<4>> {name = "l"}
! CHECK: acc.parallel reduction(@reduction_neqv_l32 -> %[[RED]] : !fir.ref<!fir.logical<4>>)

subroutine acc_reduction_add_cmplx()
complex :: c
!$acc parallel reduction(+:c)
!$acc end parallel
end subroutine

! CHECK-LABEL: func.func @_QPacc_reduction_add_cmplx()
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.complex<4>>) -> !fir.ref<!fir.complex<4>> {name = "c"}
! CHECK: acc.parallel reduction(@reduction_add_z32 -> %[[RED]] : !fir.ref<!fir.complex<4>>)

0 comments on commit 119c512

Please sign in to comment.