Skip to content

Commit

Permalink
[mlir][openacc] Use new private representation in acc.parallel
Browse files Browse the repository at this point in the history
Update acc.parallel private operands list to use the new design
introduced in D150622.

Test in flang/test/Lower/OpenACC/acc-parallel.f90 and
flang/test/Lower/OpenACC/acc-parallel-loop.f90 are temporarly
disabled and will be enabled with updated lowering in the follow-up
patch.

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D150971
  • Loading branch information
clementval committed May 22, 2023
1 parent c606fef commit c067c6e
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 23 deletions.
25 changes: 13 additions & 12 deletions flang/test/Lower/OpenACC/acc-parallel-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -442,18 +442,19 @@ subroutine acc_parallel_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}

!$acc parallel loop private(a) firstprivate(b)
DO i = 1, n
a(i) = b(i)
END DO

! CHECK: acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10xf32>>) private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
! CHECK: acc.loop private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
! CHECK: fir.do_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! TODO: will be updated after lowering change in privatization to MLIR
! !$acc parallel loop private(a) firstprivate(b)
! DO i = 1, n
! a(i) = b(i)
! END DO

! TODO: acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10xf32>>) private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
! TODO: acc.loop private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
! TODO: fir.do_loop
! TODO: acc.yield
! TODO-NEXT: }{{$}}
! TODO: acc.yield
! TODO-NEXT: }{{$}}

!$acc parallel loop seq
DO i = 1, n
Expand Down
11 changes: 6 additions & 5 deletions flang/test/Lower/OpenACC/acc-parallel.f90
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,12 @@ subroutine acc_parallel
!CHECK: acc.detach accPtr(%[[ATTACH_D]] : !fir.ptr<f32>) {dataClause = 10 : i64, name = "d"}
!CHECK: acc.detach accPtr(%[[ATTACH_E]] : !fir.ptr<f32>) {dataClause = 10 : i64, name = "e"}

!$acc parallel private(a) firstprivate(b) private(c)
!$acc end parallel
! TODO: will be updated after lowering change in privatization to MLIR
! !$acc parallel private(a) firstprivate(b) private(c)
! !$acc end parallel

!CHECK: acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10x10xf32>>) private(%[[A]], %[[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) {
!CHECK: acc.yield
!CHECK-NEXT: }{{$}}
!TODO: acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10x10xf32>>) private(%[[A]], %[[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) {
!TODO: acc.yield
!TODO-NEXT: }{{$}}

end subroutine acc_parallel
7 changes: 5 additions & 2 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
UnitAttr:$selfAttr,
OptionalAttr<OpenACC_ReductionOperatorAttr>:$reductionOp,
Variadic<AnyType>:$reductionOperands,
Variadic<AnyType>:$gangPrivateOperands,
Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
OptionalAttr<SymbolRefArrayAttr>:$privatizations,
Variadic<AnyType>:$gangFirstPrivateOperands,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
OptionalAttr<DefaultValueAttr>:$defaultAttr);
Expand All @@ -659,7 +660,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
type($gangFirstPrivateOperands) `)`
| `num_gangs` `(` $numGangs `:` type($numGangs) `)`
| `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
| `private` `(` $gangPrivateOperands `:` type($gangPrivateOperands) `)`
| `private` `(` custom<PrivatizationList>(
$gangPrivateOperands, type($gangPrivateOperands), $privatizations)
`)`
| `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
| `wait` `(` $waitOperands `:` type($waitOperands) `)`
| `self` `(` $selfCond `)`
Expand Down
79 changes: 79 additions & 0 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,43 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() {
return success();
}

//===----------------------------------------------------------------------===//
// Custom parser and printer verifier for private clause
//===----------------------------------------------------------------------===//

static ParseResult parsePrivatizationList(
mlir::OpAsmParser &parser,
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &privatizationSymbols) {
llvm::SmallVector<SymbolRefAttr> privatizationVec;
if (failed(parser.parseCommaSeparatedList([&]() {
if (parser.parseAttribute(privatizationVec.emplace_back()) ||
parser.parseArrow() ||
parser.parseOperand(operands.emplace_back()) ||
parser.parseColonType(types.emplace_back()))
return failure();
return success();
})))
return failure();
llvm::SmallVector<mlir::Attribute> privatizations(privatizationVec.begin(),
privatizationVec.end());
privatizationSymbols = ArrayAttr::get(parser.getContext(), privatizations);
return success();
}

static void
printPrivatizationList(mlir::OpAsmPrinter &p, mlir::Operation *op,
mlir::OperandRange privateOperands,
mlir::TypeRange privateTypes,
std::optional<mlir::ArrayAttr> privatizations) {
for (unsigned i = 0, e = privatizations->size(); i < e; ++i) {
if (i != 0)
p << ", ";
p << (*privatizations)[i] << " -> " << privateOperands[i] << " : "
<< privateOperands[i].getType();
}
}

//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//
Expand All @@ -455,6 +492,45 @@ static LogicalResult checkDataOperands(Op op,
return success();
}

static LogicalResult
checkPrivatizationList(Operation *op,
std::optional<mlir::ArrayAttr> privatizations,
mlir::OperandRange privateOperands) {
if (!privateOperands.empty()) {
if (!privatizations || privatizations->size() != privateOperands.size())
return op->emitOpError() << "expected as many privatizations symbol "
"reference as private operands";
} else {
if (privatizations)
return op->emitOpError() << "unexpected privatizations symbol reference";
return success();
}

llvm::DenseSet<Value> privates;
for (auto args : llvm::zip(privateOperands, *privatizations)) {
mlir::Value privateOperand = std::get<0>(args);

if (!privates.insert(privateOperand).second)
return op->emitOpError() << "private operand appears more than once";

mlir::Type varType = privateOperand.getType();
auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
auto decl =
SymbolTable::lookupNearestSymbolFrom<PrivateRecipeOp>(op, symbolRef);
if (!decl)
return op->emitOpError() << "expected symbol reference " << symbolRef
<< " to point to a private declaration";

if (decl.getType() && decl.getType() != varType)
return op->emitOpError()
<< "expected private (" << varType
<< ") to be the same type as private declaration ("
<< decl.getType() << ")";
}

return success();
}

unsigned ParallelOp::getNumDataOperands() {
return getReductionOperands().size() + getGangPrivateOperands().size() +
getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
Expand All @@ -471,6 +547,9 @@ Value ParallelOp::getDataOperand(unsigned i) {
}

LogicalResult acc::ParallelOp::verify() {
if (failed(checkPrivatizationList(*this, getPrivatizations(),
getGangPrivateOperands())))
return failure();
return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}

Expand Down
38 changes: 34 additions & 4 deletions mlir/test/Dialect/OpenACC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,16 @@ func.func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x

// -----

acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
^bb0(%arg0: memref<10xf32>):
%0 = memref.alloc() : memref<10xf32>
acc.yield %0 : memref<10xf32>
} destroy {
^bb0(%arg0: memref<10xf32>):
memref.dealloc %arg0 : memref<10xf32>
acc.terminator
}

func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) -> memref<10xf32> {
%lb = arith.constant 0 : index
%st = arith.constant 1 : index
Expand All @@ -126,7 +136,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
%pc = acc.present varPtr(%c : memref<10xf32>) -> memref<10xf32>
%pd = acc.present varPtr(%d : memref<10xf32>) -> memref<10xf32>
acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(%c : memref<10xf32>) {
acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(@privatization_memref_10_f32 -> %c : memref<10xf32>) {
acc.loop gang {
scf.for %x = %lb to %c10 step %st {
acc.loop worker {
Expand Down Expand Up @@ -168,7 +178,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
// CHECK-NEXT: [[NUMGANG:%.*]] = arith.constant 10 : i64
// CHECK-NEXT: [[NUMWORKERS:%.*]] = arith.constant 10 : i64
// CHECK: acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private([[ARG2]] : memref<10xf32>) {
// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private(@privatization_memref_10_f32 -> [[ARG2]] : memref<10xf32>) {
// CHECK-NEXT: acc.loop gang {
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
// CHECK-NEXT: acc.loop worker {
Expand Down Expand Up @@ -358,6 +368,26 @@ func.func @acc_loop_multiple_block() {

// -----

acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
^bb0(%arg0: memref<10xf32>):
%0 = memref.alloc() : memref<10xf32>
acc.yield %0 : memref<10xf32>
} destroy {
^bb0(%arg0: memref<10xf32>):
memref.dealloc %arg0 : memref<10xf32>
acc.terminator
}

acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init {
^bb0(%arg0: memref<10x10xf32>):
%0 = memref.alloc() : memref<10x10xf32>
acc.yield %0 : memref<10x10xf32>
} destroy {
^bb0(%arg0: memref<10x10xf32>):
memref.dealloc %arg0 : memref<10x10xf32>
acc.terminator
}

func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
%i64value = arith.constant 1 : i64
%i32value = arith.constant 1 : i32
Expand Down Expand Up @@ -394,7 +424,7 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
}
acc.parallel vector_length(%idxValue: index) {
}
acc.parallel private(%a, %c : memref<10xf32>, memref<10x10xf32>) firstprivate(%b: memref<10xf32>) {
acc.parallel private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(%b: memref<10xf32>) {
}
acc.parallel {
} attributes {defaultAttr = #acc<defaultvalue none>}
Expand Down Expand Up @@ -445,7 +475,7 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
// CHECK-NEXT: }
// CHECK: acc.parallel vector_length([[IDXVALUE]] : index) {
// CHECK-NEXT: }
// CHECK: acc.parallel firstprivate([[ARGB]] : memref<10xf32>) private([[ARGA]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) {
// CHECK: acc.parallel firstprivate([[ARGB]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) {
// CHECK-NEXT: }
// CHECK: acc.parallel {
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
Expand Down

0 comments on commit c067c6e

Please sign in to comment.