Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][flang][openacc] Add device_type support for update op #78764

Merged
merged 2 commits into from
Jan 25, 2024

Conversation

clementval
Copy link
Contributor

Add support for device_type information on the acc.update operation and update lowering from Flang.

@llvmbot llvmbot added mlir flang Flang issues not falling into any other category mlir:openacc flang:fir-hlfir openacc labels Jan 19, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 19, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-mlir-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Add support for device_type information on the acc.update operation and update lowering from Flang.


Patch is 29.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78764.diff

6 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+43-37)
  • (modified) flang/test/Lower/OpenACC/acc-update.f90 (+8-13)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+39-10)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+189-40)
  • (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+2-2)
  • (modified) mlir/test/Dialect/OpenACC/ops.mlir (+6-6)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 682ca06cabd6f6..541ea2e114324f 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2840,27 +2840,42 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
   }
 }
 
+static inline mlir::ArrayAttr
+getArrayAttr(fir::FirOpBuilder &b,
+             llvm::SmallVector<mlir::Attribute> &attributes) {
+  return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
+}
+
+static inline mlir::DenseI32ArrayAttr
+getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
+                     llvm::SmallVector<int32_t> &values) {
+  return values.empty() ? nullptr : builder.getDenseI32ArrayAttr(values);
+}
+
 static void
 genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
                mlir::Location currentLocation,
                Fortran::semantics::SemanticsContext &semanticsContext,
                Fortran::lower::StatementContext &stmtCtx,
                const Fortran::parser::AccClauseList &accClauseList) {
-  mlir::Value ifCond, async, waitDevnum;
+  mlir::Value ifCond, waitDevnum;
   llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
-      waitOperands, deviceTypeOperands;
-  llvm::SmallVector<mlir::Attribute> deviceTypes;
-
-  // Async and wait clause have optional values but can be present with
-  // no value as well. When there is no value, the op has an attribute to
-  // represent the clause.
-  bool addAsyncAttr = false;
-  bool addWaitAttr = false;
-  bool addIfPresentAttr = false;
+      waitOperands, deviceTypeOperands, asyncOperands;
+  llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
+      asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+  llvm::SmallVector<int32_t> waitOperandsSegments;
 
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
-  // Lower clauses values mapped to operands.
+  // device_type attribute is set to `none` until a device_type clause is
+  // encountered.
+  llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+  crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None));
+
+  bool ifPresent = false;
+
+  // Lower clauses values mapped to operands and array attributes.
   // Keep track of each group of operands separately as clauses can appear
   // more than once.
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -2870,15 +2885,19 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
       genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
     } else if (const auto *asyncClause =
                    std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
-      genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+      genAsyncClause(converter, asyncClause, asyncOperands,
+                     asyncOperandsDeviceTypes, asyncOnlyDeviceTypes,
+                     crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
-      genWaitClause(converter, waitClause, waitOperands, waitDevnum,
-                    addWaitAttr, stmtCtx);
+      genWaitClause(converter, waitClause, waitOperands,
+                    waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
+      crtDeviceTypes.clear();
+      gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
     } else if (const auto *hostClause =
                    std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
       genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -2892,7 +2911,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
           dataClauseOperands, mlir::acc::DataClause::acc_update_device, false,
           /*implicit=*/false);
     } else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
-      addIfPresentAttr = true;
+      ifPresent = true;
     } else if (const auto *selfClause =
                    std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
       const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
@@ -2909,30 +2928,17 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
 
   dataClauseOperands.append(updateHostOperands);
 
-  // Prepare the operand segment size attribute and the operands value range.
-  llvm::SmallVector<mlir::Value> operands;
-  llvm::SmallVector<int32_t> operandSegments;
-  addOperand(operands, operandSegments, ifCond);
-  addOperand(operands, operandSegments, async);
-  addOperand(operands, operandSegments, waitDevnum);
-  addOperands(operands, operandSegments, waitOperands);
-  addOperands(operands, operandSegments, dataClauseOperands);
-
-  mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
-      builder, currentLocation, operands, operandSegments);
-  if (!deviceTypes.empty())
-    updateOp.setDeviceTypesAttr(
-        mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
+  builder.create<mlir::acc::UpdateOp>(
+      currentLocation, ifCond, asyncOperands,
+      getArrayAttr(builder, asyncOperandsDeviceTypes),
+      getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
+      getDenseI32ArrayAttr(builder, waitOperandsSegments),
+      getArrayAttr(builder, waitOperandsDeviceTypes),
+      getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
+      ifPresent);
 
   genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
       builder, updateHostOperands, /*structured=*/false);
-
-  if (addAsyncAttr)
-    updateOp.setAsyncAttr(builder.getUnitAttr());
-  if (addWaitAttr)
-    updateOp.setWaitAttr(builder.getUnitAttr());
-  if (addIfPresentAttr)
-    updateOp.setIfPresentAttr(builder.getUnitAttr());
 }
 
 static void
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index d2b15f8bd258e7..ac7a56c56b1f20 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -61,17 +61,17 @@ subroutine acc_update
 
   !$acc update host(a) async
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
+! CHECK: acc.update async dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) wait
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
+! CHECK: acc.update wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) async wait
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
+! CHECK: acc.update async wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) async(1)
@@ -89,14 +89,14 @@ subroutine acc_update
   !$acc update host(a) wait(1)
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
 ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait([[WAIT1]] : i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({[[WAIT1]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) wait(queues: 1, 2)
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
 ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.update wait([[WAIT2]], [[WAIT3]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) wait(devnum: 1: queues: 1, 2)
@@ -104,17 +104,12 @@ subroutine acc_update
 ! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
 ! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
 ! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
-  !$acc update host(a) device_type(default, host)
+  !$acc update host(a) device_type(host, nvidia) async
 ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<default>, #acc.device_type<host>]} 
-! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
-
-  !$acc update host(a) device_type(*)
-! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<star>]} 
+! CHECK: acc.update async([#acc.device_type<host>, #acc.device_type<nvidia>]) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
 ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
 end subroutine acc_update
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 7344ab2852b9ce..5b678e84b93ee4 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -2187,14 +2187,16 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
   }];
 
   let arguments = (ins Optional<I1>:$ifCond,
-                       Optional<IntOrIndex>:$asyncOperand,
-                       Optional<IntOrIndex>:$waitDevnum,
-                       Variadic<IntOrIndex>:$waitOperands,
-                       UnitAttr:$async,
-                       UnitAttr:$wait,
-                       OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
-                       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-                       UnitAttr:$ifPresent);
+      Variadic<IntOrIndex>:$asyncOperands,
+      OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$async,
+      Optional<IntOrIndex>:$waitDevnum,
+      Variadic<IntOrIndex>:$waitOperands,
+      OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+      OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$wait,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+      UnitAttr:$ifPresent);
 
   let extraClassDeclaration = [{
     /// The number of data operands.
@@ -2202,14 +2204,41 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
 
     /// The i-th data operand passed.
     Value getDataOperand(unsigned i);
+
+    /// Return true if the op has the async attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasAsync();
+    /// Return true if the op has the async attribute for the given device_type.
+    bool hasAsync(mlir::acc::DeviceType deviceType);
+    /// Return the value of the async clause if present.
+    mlir::Value getAsyncValue();
+    /// Return the value of the async clause for the given device_type if
+    /// present.
+    mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+    /// Return true if the op has the wait attribute for the
+    /// mlir::acc::DeviceType::None device_type.
+    bool hasWait();
+    /// Return true if the op has the wait attribute for the given device_type.
+    bool hasWait(mlir::acc::DeviceType deviceType);
+    /// Return the values of the wait clause if present.
+    mlir::Operation::operand_range getWaitValues();
+    /// Return the values of the wait clause for the given device_type if
+    /// present.
+    mlir::Operation::operand_range
+    getWaitValues(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     oilist(
         `if` `(` $ifCond `)`
-      | `async` `(` $asyncOperand `:` type($asyncOperand) `)`
+      | `async` `` custom<DeviceTypeOperandsWithKeywordOnly>(
+            $asyncOperands, type($asyncOperands),
+            $asyncOperandsDeviceType, $async)
       | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
-      | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+      | `wait` `` custom<WaitClause>($waitOperands,
+            type($waitOperands), $waitOperandsDeviceType, 
+            $waitOperandsSegments, $wait)
       | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
     )
     attr-dict-with-keyword
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bc03adbcae64df..4e31f7b163b9dc 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -936,6 +936,138 @@ static void printDeviceTypeOperandsWithSegment(
   });
 }
 
+static ParseResult parseWaitClause(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+    mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &keywordOnly) {
+  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs;
+  llvm::SmallVector<int32_t> seg;
+
+  bool needCommaBeforeOperands = false;
+
+  // Keyword only
+  if (failed(parser.parseOptionalLParen())) {
+    keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+    return success();
+  }
+
+  // Parse keyword only attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(keywordAttrs.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
+      return failure();
+    needCommaBeforeOperands = true;
+  }
+
+  if (needCommaBeforeOperands && failed(parser.parseComma()))
+    return failure();
+
+  do {
+    if (failed(parser.parseLBrace()))
+      return failure();
+
+    int32_t crtOperandsSize = operands.size();
+
+    if (failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::None, [&]() {
+              if (parser.parseOperand(operands.emplace_back()) ||
+                  parser.parseColonType(types.emplace_back()))
+                return failure();
+              return success();
+            })))
+      return failure();
+
+    seg.push_back(operands.size() - crtOperandsSize);
+
+    if (failed(parser.parseRBrace()))
+      return failure();
+
+    if (succeeded(parser.parseOptionalLSquare())) {
+      if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
+          parser.parseRSquare())
+        return failure();
+    } else {
+      deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+          parser.getContext(), mlir::acc::DeviceType::None));
+    }
+  } while (succeeded(parser.parseOptionalComma()));
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
+  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
+
+  return success();
+}
+
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
+    return true;
+  return false;
+}
+
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+                             std::optional<mlir::ArrayAttr> deviceTypes) {
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
+
+  p << "[";
+  llvm::interleaveComma(*deviceTypes, p,
+                        [&](mlir::Attribute attr) { p << attr; });
+  p << "]";
+}
+
+static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
+  if (!hasDeviceTypeValues(attrs))
+    return false;
+  if (attrs->size() != 1)
+    return false;
+  if (auto deviceTypeAttr =
+          mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
+    return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
+  return false;
+}
+
+static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                            mlir::OperandRange operands, mlir::TypeRange types,
+                            std::optional<mlir::ArrayAttr> deviceTypes,
+                            std::optional<mlir::DenseI32ArrayAttr> segments,
+                            std::optional<mlir::ArrayAttr> keywordOnly) {
+
+  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
+    return;
+
+  p << "(";
+
+  printDeviceTypes(p, keywordOnly);
+  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
+    p << ", ";
+
+  unsigned opIdx = 0;
+  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
+    p << "{";
+    llvm::interleaveComma(
+        llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
+          p << operands[opIdx] << " : " << operands[opIdx].getType();
+          ++opIdx;
+        });
+    p << "}";
+    printSingleDeviceType(p, it.value());
+  });
+
+  p << ")";
+}
+
 static ParseResult parseDeviceTypeOperands(
     mlir::OpAsmP...
[truncated]

@clementval clementval force-pushed the acc_update_device_type_support branch from 11f94e3 to 96b042b Compare January 22, 2024 22:16
Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice job being so thorough with device_type support!

@clementval clementval merged commit 78ef032 into llvm:main Jan 25, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:openacc mlir openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants