-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][flang][openacc] Add device_type support for update op #78764
[mlir][flang][openacc] Add device_type support for update op #78764
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-openacc Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesAdd support for device_type information on the acc.update operation and update lowering from Flang. Patch is 29.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78764.diff 6 Files Affected:
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 682ca06cabd6f6..541ea2e114324f 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2840,27 +2840,42 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
}
}
+static inline mlir::ArrayAttr
+getArrayAttr(fir::FirOpBuilder &b,
+ llvm::SmallVector<mlir::Attribute> &attributes) {
+ return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
+}
+
+static inline mlir::DenseI32ArrayAttr
+getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
+ llvm::SmallVector<int32_t> &values) {
+ return values.empty() ? nullptr : builder.getDenseI32ArrayAttr(values);
+}
+
static void
genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
const Fortran::parser::AccClauseList &accClauseList) {
- mlir::Value ifCond, async, waitDevnum;
+ mlir::Value ifCond, waitDevnum;
llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
- waitOperands, deviceTypeOperands;
- llvm::SmallVector<mlir::Attribute> deviceTypes;
-
- // Async and wait clause have optional values but can be present with
- // no value as well. When there is no value, the op has an attribute to
- // represent the clause.
- bool addAsyncAttr = false;
- bool addWaitAttr = false;
- bool addIfPresentAttr = false;
+ waitOperands, deviceTypeOperands, asyncOperands;
+ llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
+ asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
+ llvm::SmallVector<int32_t> waitOperandsSegments;
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
- // Lower clauses values mapped to operands.
+ // device_type attribute is set to `none` until a device_type clause is
+ // encountered.
+ llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+ crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), mlir::acc::DeviceType::None));
+
+ bool ifPresent = false;
+
+ // Lower clauses values mapped to operands and array attributes.
// Keep track of each group of operands separately as clauses can appear
// more than once.
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -2870,15 +2885,19 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
} else if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
- genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
+ genAsyncClause(converter, asyncClause, asyncOperands,
+ asyncOperandsDeviceTypes, asyncOnlyDeviceTypes,
+ crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
- genWaitClause(converter, waitClause, waitOperands, waitDevnum,
- addWaitAttr, stmtCtx);
+ genWaitClause(converter, waitClause, waitOperands,
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
+ waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
- gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
+ crtDeviceTypes.clear();
+ gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
} else if (const auto *hostClause =
std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -2892,7 +2911,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
dataClauseOperands, mlir::acc::DataClause::acc_update_device, false,
/*implicit=*/false);
} else if (std::get_if<Fortran::parser::AccClause::IfPresent>(&clause.u)) {
- addIfPresentAttr = true;
+ ifPresent = true;
} else if (const auto *selfClause =
std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
const std::optional<Fortran::parser::AccSelfClause> &accSelfClause =
@@ -2909,30 +2928,17 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
dataClauseOperands.append(updateHostOperands);
- // Prepare the operand segment size attribute and the operands value range.
- llvm::SmallVector<mlir::Value> operands;
- llvm::SmallVector<int32_t> operandSegments;
- addOperand(operands, operandSegments, ifCond);
- addOperand(operands, operandSegments, async);
- addOperand(operands, operandSegments, waitDevnum);
- addOperands(operands, operandSegments, waitOperands);
- addOperands(operands, operandSegments, dataClauseOperands);
-
- mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
- builder, currentLocation, operands, operandSegments);
- if (!deviceTypes.empty())
- updateOp.setDeviceTypesAttr(
- mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
+ builder.create<mlir::acc::UpdateOp>(
+ currentLocation, ifCond, asyncOperands,
+ getArrayAttr(builder, asyncOperandsDeviceTypes),
+ getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
+ getDenseI32ArrayAttr(builder, waitOperandsSegments),
+ getArrayAttr(builder, waitOperandsDeviceTypes),
+ getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
+ ifPresent);
genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
builder, updateHostOperands, /*structured=*/false);
-
- if (addAsyncAttr)
- updateOp.setAsyncAttr(builder.getUnitAttr());
- if (addWaitAttr)
- updateOp.setWaitAttr(builder.getUnitAttr());
- if (addIfPresentAttr)
- updateOp.setIfPresentAttr(builder.getUnitAttr());
}
static void
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index d2b15f8bd258e7..ac7a56c56b1f20 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -61,17 +61,17 @@ subroutine acc_update
!$acc update host(a) async
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
+! CHECK: acc.update async dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) wait
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
+! CHECK: acc.update wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) async wait
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
+! CHECK: acc.update async wait dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) async(1)
@@ -89,14 +89,14 @@ subroutine acc_update
!$acc update host(a) wait(1)
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait([[WAIT1]] : i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({[[WAIT1]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) wait(queues: 1, 2)
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-! CHECK: acc.update wait([[WAIT2]], [[WAIT3]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
!$acc update host(a) wait(devnum: 1: queues: 1, 2)
@@ -104,17 +104,12 @@ subroutine acc_update
! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
+! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
- !$acc update host(a) device_type(default, host)
+ !$acc update host(a) device_type(host, nvidia) async
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<default>, #acc.device_type<host>]}
-! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
-
- !$acc update host(a) device_type(*)
-! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<star>]}
+! CHECK: acc.update async([#acc.device_type<host>, #acc.device_type<nvidia>]) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
end subroutine acc_update
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 7344ab2852b9ce..5b678e84b93ee4 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -2187,14 +2187,16 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
}];
let arguments = (ins Optional<I1>:$ifCond,
- Optional<IntOrIndex>:$asyncOperand,
- Optional<IntOrIndex>:$waitDevnum,
- Variadic<IntOrIndex>:$waitOperands,
- UnitAttr:$async,
- UnitAttr:$wait,
- OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
- Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- UnitAttr:$ifPresent);
+ Variadic<IntOrIndex>:$asyncOperands,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$async,
+ Optional<IntOrIndex>:$waitDevnum,
+ Variadic<IntOrIndex>:$waitOperands,
+ OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$wait,
+ Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
+ UnitAttr:$ifPresent);
let extraClassDeclaration = [{
/// The number of data operands.
@@ -2202,14 +2204,41 @@ def OpenACC_UpdateOp : OpenACC_Op<"update",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);
+
+ /// Return true if the op has the async attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasAsync();
+ /// Return true if the op has the async attribute for the given device_type.
+ bool hasAsync(mlir::acc::DeviceType deviceType);
+ /// Return the value of the async clause if present.
+ mlir::Value getAsyncValue();
+ /// Return the value of the async clause for the given device_type if
+ /// present.
+ mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
+
+ /// Return true if the op has the wait attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasWait();
+ /// Return true if the op has the wait attribute for the given device_type.
+ bool hasWait(mlir::acc::DeviceType deviceType);
+ /// Return the values of the wait clause if present.
+ mlir::Operation::operand_range getWaitValues();
+ /// Return the values of the wait clause for the given device_type if
+ /// present.
+ mlir::Operation::operand_range
+ getWaitValues(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
oilist(
`if` `(` $ifCond `)`
- | `async` `(` $asyncOperand `:` type($asyncOperand) `)`
+ | `async` `` custom<DeviceTypeOperandsWithKeywordOnly>(
+ $asyncOperands, type($asyncOperands),
+ $asyncOperandsDeviceType, $async)
| `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
- | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+ | `wait` `` custom<WaitClause>($waitOperands,
+ type($waitOperands), $waitOperandsDeviceType,
+ $waitOperandsSegments, $wait)
| `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
)
attr-dict-with-keyword
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bc03adbcae64df..4e31f7b163b9dc 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -936,6 +936,138 @@ static void printDeviceTypeOperandsWithSegment(
});
}
+static ParseResult parseWaitClause(
+ mlir::OpAsmParser &parser,
+ llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+ llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+ mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &keywordOnly) {
+ llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs;
+ llvm::SmallVector<int32_t> seg;
+
+ bool needCommaBeforeOperands = false;
+
+ // Keyword only
+ if (failed(parser.parseOptionalLParen())) {
+ keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+ parser.getContext(), mlir::acc::DeviceType::None));
+ keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+ return success();
+ }
+
+ // Parse keyword only attributes
+ if (succeeded(parser.parseOptionalLSquare())) {
+ if (failed(parser.parseCommaSeparatedList([&]() {
+ if (parser.parseAttribute(keywordAttrs.emplace_back()))
+ return failure();
+ return success();
+ })))
+ return failure();
+ if (parser.parseRSquare())
+ return failure();
+ needCommaBeforeOperands = true;
+ }
+
+ if (needCommaBeforeOperands && failed(parser.parseComma()))
+ return failure();
+
+ do {
+ if (failed(parser.parseLBrace()))
+ return failure();
+
+ int32_t crtOperandsSize = operands.size();
+
+ if (failed(parser.parseCommaSeparatedList(
+ mlir::AsmParser::Delimiter::None, [&]() {
+ if (parser.parseOperand(operands.emplace_back()) ||
+ parser.parseColonType(types.emplace_back()))
+ return failure();
+ return success();
+ })))
+ return failure();
+
+ seg.push_back(operands.size() - crtOperandsSize);
+
+ if (failed(parser.parseRBrace()))
+ return failure();
+
+ if (succeeded(parser.parseOptionalLSquare())) {
+ if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
+ parser.parseRSquare())
+ return failure();
+ } else {
+ deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+ parser.getContext(), mlir::acc::DeviceType::None));
+ }
+ } while (succeeded(parser.parseOptionalComma()));
+
+ if (failed(parser.parseRParen()))
+ return failure();
+
+ deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
+ keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
+ segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
+
+ return success();
+}
+
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+ if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
+ return true;
+ return false;
+}
+
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+ std::optional<mlir::ArrayAttr> deviceTypes) {
+ if (!hasDeviceTypeValues(deviceTypes))
+ return;
+
+ p << "[";
+ llvm::interleaveComma(*deviceTypes, p,
+ [&](mlir::Attribute attr) { p << attr; });
+ p << "]";
+}
+
+static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
+ if (!hasDeviceTypeValues(attrs))
+ return false;
+ if (attrs->size() != 1)
+ return false;
+ if (auto deviceTypeAttr =
+ mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
+ return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
+ return false;
+}
+
+static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
+ mlir::OperandRange operands, mlir::TypeRange types,
+ std::optional<mlir::ArrayAttr> deviceTypes,
+ std::optional<mlir::DenseI32ArrayAttr> segments,
+ std::optional<mlir::ArrayAttr> keywordOnly) {
+
+ if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
+ return;
+
+ p << "(";
+
+ printDeviceTypes(p, keywordOnly);
+ if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
+ p << ", ";
+
+ unsigned opIdx = 0;
+ llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
+ p << "{";
+ llvm::interleaveComma(
+ llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
+ p << operands[opIdx] << " : " << operands[opIdx].getType();
+ ++opIdx;
+ });
+ p << "}";
+ printSingleDeviceType(p, it.value());
+ });
+
+ p << ")";
+}
+
static ParseResult parseDeviceTypeOperands(
mlir::OpAsmP...
[truncated]
|
11f94e3
to
96b042b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice job being so thorough with device_type support!
Add support for device_type information on the acc.update operation and update lowering from Flang.