[OpenMP][MLIR] Add private clause to omp.target #91202

Merged on May 10, 2024 (7 commits)

ergawy (Member) commented May 6, 2024

Starts the effort to support delayed privatization for omp.target. This PR extends the omp.target MLIR op with a private clause similar to what we currently have for omp.parallel in order to model privatized variables.

llvmbot (Collaborator) commented May 6, 2024

@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openmp

Author: Kareem Ergawy (ergawy)

Full diff: https://github.com/llvm/llvm-project/pull/91202.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+5-1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+51-4)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+39-1)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a40676d071e620..a641588eaa8d42 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1787,7 +1787,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
                        UnitAttr:$nowait,
                        Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
                        Variadic<OpenMP_PointerLikeType>:$has_device_addr,
-                       Variadic<AnyType>:$map_operands);
+                       Variadic<AnyType>:$map_operands,
+                       Variadic<AnyType>:$private_vars,
+                       OptionalAttr<SymbolRefArrayAttr>:$privatizers);
+
   let regions = (region AnyRegion:$region);
 
   let builders = [
@@ -1802,6 +1805,7 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
     | `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
     | `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
     | `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
+    | `private` `(` custom<PrivateList>($private_vars, type($private_vars), $privatizers) `)`
     | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
     ) $region attr-dict
   }];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 0799090cdea981..cedcc40864d663 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -469,14 +469,18 @@ ParseResult parseClauseWithRegionArgs(
 static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
                                       ValueRange argsSubrange,
                                       StringRef clauseName, ValueRange operands,
-                                      TypeRange types, ArrayAttr symbols) {
-  p << clauseName << "(";
+                                      TypeRange types, ArrayAttr symbols,
+                                      bool printPrefixSuffix = true) {
+  if (printPrefixSuffix)
+    p << clauseName << "(";
+
   llvm::interleaveComma(
       llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
         auto [sym, op, arg, type] = t;
         p << sym << " " << op << " -> " << arg << " : " << type;
       });
-  p << ") ";
+  if (printPrefixSuffix)
+    p << ") ";
 }
 
 static ParseResult parseParallelRegion(
@@ -1048,6 +1052,48 @@ static void printMapEntries(OpAsmPrinter &p, Operation *op,
   }
 }
 
+static ParseResult parsePrivateList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateOperands,
+    SmallVectorImpl<Type> &privateOperandTypes, ArrayAttr &privatizerSymbols) {
+  SmallVector<SymbolRefAttr> privateSymRefs;
+  SmallVector<OpAsmParser::Argument> regionPrivateArgs;
+
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseAttribute(privateSymRefs.emplace_back()) ||
+            parser.parseOperand(privateOperands.emplace_back()) ||
+            parser.parseArrow() ||
+            parser.parseArgument(regionPrivateArgs.emplace_back()) ||
+            parser.parseColonType(privateOperandTypes.emplace_back()))
+          return failure();
+        return success();
+      })))
+    return failure();
+
+  SmallVector<Attribute> privateSymAttrs(privateSymRefs.begin(),
+                                         privateSymRefs.end());
+  privatizerSymbols = ArrayAttr::get(parser.getContext(), privateSymAttrs);
+
+  return success();
+}
+
+static void printPrivateList(OpAsmPrinter &p, Operation *op,
+                             ValueRange privateVarOperands,
+                             TypeRange privateVarTypes,
+                             ArrayAttr privatizerSymbols) {
+  auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
+  assert(targetOp);
+
+  auto &region = op->getRegion(0);
+  auto *argsBegin = region.front().getArguments().begin();
+  MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(),
+                               argsBegin + targetOp.getMapOperands().size() +
+                                   privateVarTypes.size());
+  printClauseWithRegionArgs(p, op, argsSubrange, "private", privateVarOperands,
+                            privateVarTypes, privatizerSymbols,
+                            /*printPrefixSuffix=*/false);
+}
+
 static void printCaptureType(OpAsmPrinter &p, Operation *op,
                              VariableCaptureKindAttr mapCaptureType) {
   std::string typeCapStr;
@@ -1262,7 +1308,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
       builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
       makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
       clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
-      clauses.mapVars);
+      clauses.mapVars, clauses.privateVars,
+      ArrayAttr::get(builder.getContext(), clauses.privatizers));
 }
 
 LogicalResult TargetOp::verify() {
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 511e7d396c6875..138c2c9d418dc3 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2087,7 +2087,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
   // expected-error @below {{op expected as many depend values as depend variables}}
     "omp.target"(%data_var) ({
       "omp.terminator"() : () -> ()
-    }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
+    }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0, 0>} : (memref<i32>) -> ()
    "func.return"() : () -> ()
 }
 
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 60fc10f9d64b73..f0b76c117a4568 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -737,7 +737,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %devic
     "omp.target"(%if_cond, %device, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : ( i1, si32, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0, 0>} : ( i1, si32, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2550,3 +2550,41 @@ func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !
   }
   return
 }
+
+// CHECK-LABEL: omp_target_private
+func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+  %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
+  %mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // CHECK: omp.target
+  // CHECK-SAME: private(
+  // CHECK-SAME:   @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : !llvm.ptr
+  // CHECK-SAME: )
+  omp.target private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
+  // CHECK: ^bb0(%[[PRIV_ARG]]: !llvm.ptr):
+  ^bb0(%priv_arg: !llvm.ptr):
+    omp.terminator
+  }
+
+  // CHECK: omp.target
+
+  // CHECK-SAME: map_entries(
+  // CHECK-SAME:   %[[MAP1_VAR:[^[:space:]]+]] -> %[[MAP1_ARG:[^[:space:]]+]],
+  // CHECK-SAME:   %[[MAP2_VAR:[^[:space:]]+]] -> %[[MAP2_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : memref<?xi32>, memref<?xi32>
+  // CHECK-SAME: )
+
+  // CHECK-SAME: private(
+  // CHECK-SAME:   @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : !llvm.ptr
+  // CHECK-SAME: )
+  omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
+  // CHECK: ^bb0(%[[MAP1_ARG]]: memref<?xi32>, %[[MAP2_ARG]]: memref<?xi32>
+  // CHECK-SAME: , %[[PRIV_ARG]]: !llvm.ptr):
+  ^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>, %priv_arg: !llvm.ptr):
+    omp.terminator
+  }
+
+  return
+}
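
The entries in the `private` clause tested above have the shape `@privatizer %operand -> %block_arg : type`. In the diff, `printPrivateList` passes `printPrefixSuffix=false`, presumably because the op's assembly format already emits the `private` `(` ... `)` literals itself. The following is a minimal standalone sketch of that printing logic, assuming plain strings in place of MLIR values; `PrivateEntry` and `printClause` are hypothetical names and are not part of the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for one private-clause entry; not an MLIR type.
struct PrivateEntry {
  std::string privatizer; // e.g. "@x.privatizer"
  std::string operand;    // e.g. "%priv_var"
  std::string blockArg;   // e.g. "%priv_arg"
  std::string type;       // e.g. "!llvm.ptr"
};

// Mirrors the shape of printClauseWithRegionArgs from the diff: entries are
// comma-separated, and the "name(" / ") " wrapper is emitted only on request.
std::string printClause(const std::string &clauseName,
                        const std::vector<PrivateEntry> &entries,
                        bool printPrefixSuffix = true) {
  std::ostringstream os;
  if (printPrefixSuffix)
    os << clauseName << "(";
  for (std::size_t i = 0; i < entries.size(); ++i) {
    if (i != 0)
      os << ", ";
    const PrivateEntry &e = entries[i];
    os << e.privatizer << " " << e.operand << " -> " << e.blockArg << " : "
       << e.type;
  }
  if (printPrefixSuffix)
    os << ") ";
  return os.str();
}
```

Calling `printClause("private", entries, /*printPrefixSuffix=*/false)` yields only the comma-separated entries, which is the part a custom directive nested inside an already parenthesized clause would need to produce.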

skatrak (Contributor) left a comment

Thank you Kareem! I think this generally looks good, I just have a couple of small comments.

auto &region = op->getRegion(0);
auto *argsBegin = region.front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(),
Contributor

It looks like we should have some common method to query OpenMP operations to return the index where certain groups of block arguments start. Because having operation-specific logic here doesn't look very scalable. I'm thinking something like this perhaps:

auto iface = cast<OpenMPOpWithEntryBlockArgs>(op);
int startIdx = iface.getPrivateArgsIndex();
if (startIdx < 0) { /* This instance doesn't have any privatization-related block arguments */ }
MutableArrayRef<BlockArgument> privateArgs(argsBegin + startIdx, argsBegin + startIdx + privateVarTypes.size());
...

That interface should be extensible to also allow queries for getReductionArgsIndex() and possibly others in the future.

Thinking about it a bit more, maybe the get<Clause>BlockArgsIndex() and get<Clause>BlockArgsSize() functions would be part of the ops themselves and the interface could define get<Clause>Args() functions to return references to the entry block arguments, based on calls to these two. Any operation implementing the interface must also implement these functions.

auto iface = cast<OpenMPOpWithEntryBlockArgs>(op);
ArrayRef<BlockArgument> privateArgs = iface.getPrivateArgs();
...
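
The second variant of this suggestion can be sketched without MLIR as follows. This is only a toy model under stated assumptions: `OpWithEntryBlockArgs`, `ToyTargetOp`, and the string-based `BlockArg` are hypothetical stand-ins, and no such interface is added by this PR. Each op supplies an index and a size, and the shared `getPrivateArgs()` is derived from the two. A size of zero plays the role of "no privatization-related block arguments".

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::BlockArgument.
using BlockArg = std::string;

// Toy version of the suggested interface; names are illustrative only.
class OpWithEntryBlockArgs {
public:
  virtual ~OpWithEntryBlockArgs() = default;
  virtual const std::vector<BlockArg> &getEntryBlockArgs() const = 0;
  // Implemented per operation, as proposed in the comment above.
  virtual std::size_t getPrivateBlockArgsIndex() const = 0;
  virtual std::size_t getPrivateBlockArgsSize() const = 0;

  // Shared logic: slice the entry block arguments using the two queries.
  std::vector<BlockArg> getPrivateArgs() const {
    const std::vector<BlockArg> &args = getEntryBlockArgs();
    std::size_t start = getPrivateBlockArgsIndex();
    std::size_t size = getPrivateBlockArgsSize();
    assert(start + size <= args.size() && "private args out of range");
    auto first = args.begin() + static_cast<std::ptrdiff_t>(start);
    return std::vector<BlockArg>(first,
                                 first + static_cast<std::ptrdiff_t>(size));
  }
};

// Toy target op: entry block arguments are map args followed by private args.
class ToyTargetOp : public OpWithEntryBlockArgs {
public:
  ToyTargetOp(const std::vector<BlockArg> &mapArgs,
              const std::vector<BlockArg> &privateArgs)
      : numMapArgs(mapArgs.size()), numPrivateArgs(privateArgs.size()),
        entryArgs(mapArgs) {
    entryArgs.insert(entryArgs.end(), privateArgs.begin(), privateArgs.end());
  }

  const std::vector<BlockArg> &getEntryBlockArgs() const override {
    return entryArgs;
  }
  std::size_t getPrivateBlockArgsIndex() const override { return numMapArgs; }
  std::size_t getPrivateBlockArgsSize() const override {
    return numPrivateArgs;
  }

private:
  std::size_t numMapArgs;
  std::size_t numPrivateArgs;
  std::vector<BlockArg> entryArgs;
};
```

With this split, printer code would no longer need to know that a target op places its map arguments first; it would only call `getPrivateArgs()` on whatever op implements the interface.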

Contributor

All that is best suited for a follow-up patch, so not a blocking comment but just a suggestion to maybe discuss. We should be able to eventually use {parse, print}{Private,Reduction}List in place of the {parse, print}{ParallelRegion,Wsloop} functions as well.

Member Author

That definitely needs clean-up and restructuring now that omp.target has entered the picture. I can try out your suggestion and we can discuss it in a follow-up PR.

Contributor

Sounds good, maybe it'd be nice to add a "TODO: Remove target-specific logic from this function" or similar comment there to avoid forgetting.

Contributor

Is the order of the block arguments guaranteed to be the same as the order of the operands of the op? This line of code expects the private arguments to appear after the map arguments, which is certainly true for the op. But is that guaranteed as the order of the block arguments also?

Member Author

> is that guaranteed as the order of the block arguments also?

The arguments to the op's region are the same as the arguments to the entry block of the region, see this. So I think if you guarantee the op args, the block args are also guaranteed to be correct consequently.
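
The ordering assumption discussed here (map arguments first, then private arguments) can be written down as an explicit check. The sketch below is a hypothetical helper and not something this PR adds. It models types as plain strings and assumes, as in the `omp_target_private` test, that each entry block argument has the same type as the corresponding map or private operand.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper, not part of the PR: returns true when the entry block
// argument types are exactly the map operand types followed by the private
// variable types, in that order.
bool entryBlockArgsFollowClauseOrder(
    const std::vector<std::string> &blockArgTypes,
    const std::vector<std::string> &mapTypes,
    const std::vector<std::string> &privateTypes) {
  if (blockArgTypes.size() != mapTypes.size() + privateTypes.size())
    return false;
  // The leading block arguments must line up with the map operands.
  for (std::size_t i = 0; i < mapTypes.size(); ++i)
    if (blockArgTypes[i] != mapTypes[i])
      return false;
  // The remaining block arguments must line up with the private variables.
  for (std::size_t i = 0; i < privateTypes.size(); ++i)
    if (blockArgTypes[mapTypes.size() + i] != privateTypes[i])
      return false;
  return true;
}
```

For the second `omp.target` in the new test, the block argument types `memref<?xi32>, memref<?xi32>, !llvm.ptr` satisfy this check against two map operands and one private variable. Swapping the private argument to the front would fail it.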

bhandarkar-pranav (Contributor) left a comment

Thanks for the PR, @ergawy. Some minor comments.

// CHECK: omp.target

// CHECK-SAME: map_entries(
Contributor

Do you need to check the map clause? Your check could be
CHECK: omp.target {{.*}} private(@x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]] : !llvm.ptr)
The subsequent block arguments check could be
CHECK: ^bb0({{.*}}, %[[PRIV_ARG]]: !llvm.ptr):

I don't think you need a string substitution block for PRIV_VAR though.

Member Author

I reduced the noise a little, but kept the map clause check to make sure the order of the block arguments is correct: map info args first, then the private variables.

Contributor

Aah makes sense.

skatrak (Contributor) left a comment

Thank you Kareem for addressing my comments, this LGTM.



bhandarkar-pranav (Contributor) left a comment

LGTM, I just left a comment/question to aid my own understanding.



@ergawy ergawy merged commit 427beff into llvm:main May 10, 2024
4 checks passed
hawkinsw added a commit to hawkinsw/llvm-project that referenced this pull request May 10, 2024
Correct mistake with libcxx/include/version from rebase.