-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[OpenMP][MLIR] Add `private` clause to `omp.target` #91202
Conversation
Starts the effort to support delayed privatization for `omp.target`. This PR extends the `omp.target` MLIR op with a `private` clause similar to what we currently have for `omp.parallel` in order to model privatized variables.
@llvm/pr-subscribers-flang-openmp @llvm/pr-subscribers-mlir-openmp Author: Kareem Ergawy (ergawy) Changes: Starts the effort to support delayed privatization for `omp.target`. Full diff: https://github.com/llvm/llvm-project/pull/91202.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a40676d071e620..a641588eaa8d42 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1787,7 +1787,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
UnitAttr:$nowait,
Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
Variadic<OpenMP_PointerLikeType>:$has_device_addr,
- Variadic<AnyType>:$map_operands);
+ Variadic<AnyType>:$map_operands,
+ Variadic<AnyType>:$private_vars,
+ OptionalAttr<SymbolRefArrayAttr>:$privatizers);
+
let regions = (region AnyRegion:$region);
let builders = [
@@ -1802,6 +1805,7 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
| `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
| `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
+ | `private` `(` custom<PrivateList>($private_vars, type($private_vars), $privatizers) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) $region attr-dict
}];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 0799090cdea981..cedcc40864d663 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -469,14 +469,18 @@ ParseResult parseClauseWithRegionArgs(
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
ValueRange argsSubrange,
StringRef clauseName, ValueRange operands,
- TypeRange types, ArrayAttr symbols) {
- p << clauseName << "(";
+ TypeRange types, ArrayAttr symbols,
+ bool printPrefixSuffix = true) { // when false, the leading `<clauseName>(` and the trailing `) ` are not printed
+ if (printPrefixSuffix)
+ p << clauseName << "(";
+ // Print one comma-separated entry per clause variable: symbol, operand, arrow, block argument, colon, type.
llvm::interleaveComma(
llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
auto [sym, op, arg, type] = t;
p << sym << " " << op << " -> " << arg << " : " << type;
});
- p << ") ";
+ if (printPrefixSuffix)
+ p << ") ";
}
static ParseResult parseParallelRegion(
@@ -1048,6 +1052,48 @@ static void printMapEntries(OpAsmPrinter &p, Operation *op,
}
}
+static ParseResult parsePrivateList( // Custom parser for the contents of a `private(...)` clause.
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateOperands,
+ SmallVectorImpl<Type> &privateOperandTypes, ArrayAttr &privatizerSymbols) {
+ SmallVector<SymbolRefAttr> privateSymRefs;
+ SmallVector<OpAsmParser::Argument> regionPrivateArgs; // filled while parsing; not read again in this function
+ // Each comma-separated entry is parsed as: symbol attribute, operand, arrow, region argument, then a colon and a type.
+ if (failed(parser.parseCommaSeparatedList([&]() {
+ if (parser.parseAttribute(privateSymRefs.emplace_back()) ||
+ parser.parseOperand(privateOperands.emplace_back()) ||
+ parser.parseArrow() ||
+ parser.parseArgument(regionPrivateArgs.emplace_back()) ||
+ parser.parseColonType(privateOperandTypes.emplace_back()))
+ return failure();
+ return success();
+ })))
+ return failure();
+ // Repackage the parsed symbol references as one ArrayAttr and hand it back through `privatizerSymbols`.
+ SmallVector<Attribute> privateSymAttrs(privateSymRefs.begin(),
+ privateSymRefs.end());
+ privatizerSymbols = ArrayAttr::get(parser.getContext(), privateSymAttrs);
+
+ return success();
+}
+
+// Custom printer for the contents of the `private(...)` clause. The enclosing
+// `private(` ... `)` text is emitted by the op's assembly format, so the shared
+// helper is called with its prefix/suffix printing disabled.
+// TODO: Remove target-specific logic from this function.
+static void printPrivateList(OpAsmPrinter &p, Operation *op,
+ ValueRange privateVarOperands,
+ TypeRange privateVarTypes,
+ ArrayAttr privatizerSymbols) {
+ auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
+ assert(targetOp);
+
+ auto &region = op->getRegion(0);
+ auto *argsBegin = region.front().getArguments().begin();
+ // The private entry block arguments are taken to start right after the map
+ // arguments, so the subrange is offset by the number of map operands.
+ MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(),
+ argsBegin + targetOp.getMapOperands().size() +
+ privateVarTypes.size());
+ printClauseWithRegionArgs(p, op, argsSubrange, "private", privateVarOperands,
+ privateVarTypes, privatizerSymbols,
+ /*printPrefixSuffix=*/false);
+}
+
static void printCaptureType(OpAsmPrinter &p, Operation *op,
VariableCaptureKindAttr mapCaptureType) {
std::string typeCapStr;
@@ -1262,7 +1308,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
- clauses.mapVars);
+ clauses.mapVars, clauses.privateVars,
+ ArrayAttr::get(builder.getContext(), clauses.privatizers));
}
LogicalResult TargetOp::verify() {
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 511e7d396c6875..138c2c9d418dc3 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2087,7 +2087,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
// expected-error @below {{op expected as many depend values as depend variables}}
"omp.target"(%data_var) ({
"omp.terminator"() : () -> ()
- }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
+ }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0, 0>} : (memref<i32>) -> ()
"func.return"() : () -> ()
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 60fc10f9d64b73..f0b76c117a4568 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -737,7 +737,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic
"omp.target"(%if_cond, %device, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : ( i1, si32, i32 ) -> ()
+ }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0, 0>} : ( i1, si32, i32 ) -> ()
// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2550,3 +2550,41 @@ func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !
}
return
}
+
+// CHECK-LABEL: omp_target_private
+func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+ %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
+ %mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
+ // Case 1: `private` clause only; the entry block takes just the private argument.
+ // CHECK: omp.target
+ // CHECK-SAME: private(
+ // CHECK-SAME: @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]]
+ // CHECK-SAME: : !llvm.ptr
+ // CHECK-SAME: )
+ omp.target private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
+ // CHECK: ^bb0(%[[PRIV_ARG]]: !llvm.ptr):
+ ^bb0(%priv_arg: !llvm.ptr):
+ omp.terminator
+ }
+ // Case 2: `map_entries` combined with `private`; the entry block lists both map arguments before the private one.
+ // CHECK: omp.target
+
+ // CHECK-SAME: map_entries(
+ // CHECK-SAME: %[[MAP1_VAR:[^[:space:]]+]] -> %[[MAP1_ARG:[^[:space:]]+]],
+ // CHECK-SAME: %[[MAP2_VAR:[^[:space:]]+]] -> %[[MAP2_ARG:[^[:space:]]+]]
+ // CHECK-SAME: : memref<?xi32>, memref<?xi32>
+ // CHECK-SAME: )
+
+ // CHECK-SAME: private(
+ // CHECK-SAME: @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]]
+ // CHECK-SAME: : !llvm.ptr
+ // CHECK-SAME: )
+ omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
+ // CHECK: ^bb0(%[[MAP1_ARG]]: memref<?xi32>, %[[MAP2_ARG]]: memref<?xi32>
+ // CHECK-SAME: , %[[PRIV_ARG]]: !llvm.ptr):
+ ^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>, %priv_arg: !llvm.ptr):
+ omp.terminator
+ }
+
+ return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you Kareem! I think this generally looks good, I just have a couple of small comments.
|
||
auto &region = op->getRegion(0); | ||
auto *argsBegin = region.front().getArguments().begin(); | ||
MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like we should have some common method to query OpenMP operations to return the index where certain groups of block arguments start. Because having operation-specific logic here doesn't look very scalable. I'm thinking something like this perhaps:
auto iface = cast<OpenMPOpWithEntryBlockArgs>(op);
int startIdx = iface.getPrivateArgsIndex();
if (startIdx < 0) { /* This instance doesn't have any privatization-related block arguments */ }
MutableArrayRef<BlockArgument> privateArgs(argsBegin + startIdx, argsBegin + startIdx + privateVarTypes.size());
...
That interface should be extensible to also allow queries for `getReductionArgsIndex()` and possibly others in the future.
Thinking about it a bit more, maybe the `get<Clause>BlockArgsIndex()` and `get<Clause>BlockArgsSize()` functions would be part of the ops themselves and the interface could define `get<Clause>Args()` functions to return references to the entry block arguments, based on calls to these two. Any operation implementing the interface must also implement these functions.
auto iface = cast<OpenMPOpWithEntryBlockArgs>(op);
ArrayRef<BlockArgument> privateArgs = iface.getPrivateArgs();
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All that is best suited for a follow-up patch, so not a blocking comment but just a suggestion to maybe discuss. We should be able to eventually use `{parse, print}{Private,Reduction}List` in place of the `{parse, print}{ParallelRegion,Wsloop}` functions as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That definitely needs clean-up and restructuring since `omp.target` entered the picture. I can try out your suggestion and discuss it in a follow-up PR indeed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, maybe it'd be nice to add a "TODO: Remove target-specific logic from this function" or similar comment there to avoid forgetting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the order of the block arguments guaranteed to be the same as the order of the operands of the op? This line of code expects the private arguments to appear after the map arguments, which is certainly true for the op. But is that also guaranteed for the order of the block arguments?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is that guaranteed as the order of the block arguments also?
The arguments to the op's region are the same as the arguments to the entry block of the region, see this. So I think if you guarantee the op args, the block args are also guaranteed to be correct consequently.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR, @ergawy. Some minor comments.
|
||
// CHECK: omp.target | ||
|
||
// CHECK-SAME: map_entries( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need to check the `map` clause? Your check could be
CHECK: omp.target {{.*}} private(@x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]] : !llvm.ptr)
The subsequent block arguments check could be
CHECK: ^bb0({{.*}}, %[[PRIV_ARG]]: !llvm.ptr):
I don't think you need a string substitution block for `PRIV_VAR`, though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I reduced the noise a little bit. But I kept the `map` clause check to make sure the order of the block arguments is correct: `map` info args first and then the private variables.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aah makes sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you Kareem for addressing my comments, this LGTM.
|
||
auto &region = op->getRegion(0); | ||
auto *argsBegin = region.front().getArguments().begin(); | ||
MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, maybe it'd be nice to add a "TODO: Remove target-specific logic from this function" or similar comment there to avoid forgetting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, I just left a comment/question to aid my own understanding.
|
||
auto &region = op->getRegion(0); | ||
auto *argsBegin = region.front().getArguments().begin(); | ||
MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the order of the block arguments guaranteed to be the same as the order of the operands of the op? This line of code expects the private arguments to appear after the map arguments, which is certainly true for the op. But is that also guaranteed for the order of the block arguments?
|
||
// CHECK: omp.target | ||
|
||
// CHECK-SAME: map_entries( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aah makes sense.
Correct mistake with libcxx/include/version from rebase.
Starts the effort to support delayed privatization for `omp.target`. This PR extends the `omp.target` MLIR op with a `private` clause similar to what we currently have for `omp.parallel` in order to model privatized variables.