Skip to content

Commit

Permalink
[flang][openacc] Lower reduction for compute constructs
Browse files Browse the repository at this point in the history
Parallel and serial constructs support reduction clause. Extend
recent D151564 loop reduction clause support to also include these
compute constructs.

Reviewed By: clementval, vzakhari

Differential Revision: https://reviews.llvm.org/D151955
  • Loading branch information
Razvan Lupusoru committed Jun 7, 2023
1 parent ea63b39 commit 5e3faa0
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 5 deletions.
12 changes: 9 additions & 3 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,

llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
firstprivateOperands;
llvm::SmallVector<mlir::Attribute> privatizations;
llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;

// Async, wait and self clause have optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
Expand Down Expand Up @@ -1151,8 +1151,11 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
&clause.u)) {
genObjectList(firstprivateClause->v, converter, semanticsContext, stmtCtx,
firstprivateOperands);
} else if (std::get_if<Fortran::parser::AccClause::Reduction>(&clause.u)) {
TODO(clauseLocation, "compute construct reduction clause lowering");
} else if (const auto *reductionClause =
std::get_if<Fortran::parser::AccClause::Reduction>(
&clause.u)) {
genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
reductionOperands, reductionRecipes);
}
}

Expand Down Expand Up @@ -1194,6 +1197,9 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
if (!privatizations.empty())
computeOp.setPrivatizationsAttr(
mlir::ArrayAttr::get(builder.getContext(), privatizations));
if (!reductionRecipes.empty())
computeOp.setReductionRecipesAttr(
mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
}

auto insPt = builder.saveInsertionPoint();
Expand Down
16 changes: 16 additions & 0 deletions flang/test/Lower/OpenACC/acc-kernels-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ subroutine acc_kernels_loop
real, dimension(n) :: a, b, c
real, dimension(n, n) :: d, e
real, pointer :: f, g
integer :: reduction_i
real :: reduction_r

integer :: gangNum = 8
integer :: gangStatic = 8
Expand Down Expand Up @@ -709,6 +711,20 @@ subroutine acc_kernels_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}

!$acc kernels loop reduction(+:reduction_r) reduction(*:reduction_i)
do i = 1, n
reduction_r = reduction_r + a(i)
reduction_i = 1
end do

! CHECK: acc.kernels {
! CHECK: acc.loop reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
! CHECK: fir.do_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}

end subroutine
13 changes: 13 additions & 0 deletions flang/test/Lower/OpenACC/acc-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ program acc_loop
integer :: gangStatic = 8
integer :: vectorLength = 128
integer, parameter :: tileSize = 2
integer :: reduction_i
real :: reduction_r


!$acc loop
Expand Down Expand Up @@ -270,4 +272,15 @@ program acc_loop
!CHECK: acc.yield
!CHECK-NEXT: }{{$}}

!$acc loop reduction(+:reduction_r) reduction(*:reduction_i)
do i = 1, n
reduction_r = reduction_r + a(i)
reduction_i = 1
end do

! CHECK: acc.loop reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
! CHECK: fir.do_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}

end program
16 changes: 16 additions & 0 deletions flang/test/Lower/OpenACC/acc-parallel-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ subroutine acc_parallel_loop
real, dimension(n) :: a, b, c
real, dimension(n, n) :: d, e
real, pointer :: f, g
integer :: reduction_i
real :: reduction_r

integer :: gangNum = 8
integer :: gangStatic = 8
Expand Down Expand Up @@ -729,6 +731,20 @@ subroutine acc_parallel_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}

!$acc parallel loop reduction(+:reduction_r) reduction(*:reduction_i)
do i = 1, n
reduction_r = reduction_r + a(i)
reduction_i = 1
end do

! CHECK: acc.parallel reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
! CHECK: acc.loop reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
! CHECK: fir.do_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}

end subroutine acc_parallel_loop
9 changes: 9 additions & 0 deletions flang/test/Lower/OpenACC/acc-parallel.f90
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ subroutine acc_parallel
logical :: ifCondition = .TRUE.
real, dimension(10, 10) :: a, b, c
real, pointer :: d, e
integer :: reduction_i
real :: reduction_r

!CHECK: %[[A:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ea"}
!CHECK: %[[B:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Eb"}
Expand Down Expand Up @@ -302,4 +304,11 @@ subroutine acc_parallel
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}

!$acc parallel reduction(+:reduction_r) reduction(*:reduction_i)
!$acc end parallel

! CHECK: acc.parallel reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}

end subroutine acc_parallel
16 changes: 16 additions & 0 deletions flang/test/Lower/OpenACC/acc-serial-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ subroutine acc_serial_loop
real, dimension(n) :: a, b, c
real, dimension(n, n) :: d, e
real, pointer :: f, g
integer :: reduction_i
real :: reduction_r

integer :: gangNum = 8
integer :: gangStatic = 8
Expand Down Expand Up @@ -645,6 +647,20 @@ subroutine acc_serial_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}

!$acc serial loop reduction(+:reduction_r) reduction(*:reduction_i)
do i = 1, n
reduction_r = reduction_r + a(i)
reduction_i = 1
end do

! CHECK: acc.serial reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
! CHECK: acc.loop reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
! CHECK: fir.do_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}

end subroutine acc_serial_loop
9 changes: 9 additions & 0 deletions flang/test/Lower/OpenACC/acc-serial.f90
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ subroutine acc_serial
logical :: ifCondition = .TRUE.
real, dimension(10, 10) :: a, b, c
real, pointer :: d, e
integer :: reduction_i
real :: reduction_r

! CHECK: %[[A:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ea"}
! CHECK: %[[B:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Eb"}
Expand Down Expand Up @@ -245,4 +247,11 @@ subroutine acc_serial
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}

!$acc serial reduction(+:reduction_r) reduction(*:reduction_i)
!$acc end serial

! CHECK: acc.serial reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}

end subroutine
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ LogicalResult acc::ParallelOp::verify() {
return failure();
if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
*this, getReductionRecipes(), getReductionOperands(), "reduction",
"reductions")))
"reductions", false)))
return failure();
return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
Expand Down Expand Up @@ -586,7 +586,7 @@ LogicalResult acc::SerialOp::verify() {
return failure();
if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
*this, getReductionRecipes(), getReductionOperands(), "reduction",
"reductions")))
"reductions", false)))
return failure();
return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
}
Expand Down

0 comments on commit 5e3faa0

Please sign in to comment.