diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index d619d47fc2359..dac14242bd800 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -2941,27 +2941,42 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter, } } +static inline mlir::ArrayAttr +getArrayAttr(fir::FirOpBuilder &b, + llvm::SmallVector &attributes) { + return attributes.empty() ? nullptr : b.getArrayAttr(attributes); +} + +static inline mlir::DenseI32ArrayAttr +getDenseI32ArrayAttr(fir::FirOpBuilder &builder, + llvm::SmallVector &values) { + return values.empty() ? nullptr : builder.getDenseI32ArrayAttr(values); +} + static void genACCUpdateOp(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation, Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::StatementContext &stmtCtx, const Fortran::parser::AccClauseList &accClauseList) { - mlir::Value ifCond, async, waitDevnum; + mlir::Value ifCond, waitDevnum; llvm::SmallVector dataClauseOperands, updateHostOperands, - waitOperands, deviceTypeOperands; - llvm::SmallVector deviceTypes; - - // Async and wait clause have optional values but can be present with - // no value as well. When there is no value, the op has an attribute to - // represent the clause. - bool addAsyncAttr = false; - bool addWaitAttr = false; - bool addIfPresentAttr = false; + waitOperands, deviceTypeOperands, asyncOperands; + llvm::SmallVector asyncOperandsDeviceTypes, + asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes; + llvm::SmallVector waitOperandsSegments; fir::FirOpBuilder &builder = converter.getFirOpBuilder(); - // Lower clauses values mapped to operands. + // device_type attribute is set to `none` until a device_type clause is + // encountered. + llvm::SmallVector crtDeviceTypes; + crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( + builder.getContext(), mlir::acc::DeviceType::None)); + + bool ifPresent = false; + + // Lower clauses values mapped to operands and array attributes. // Keep track of each group of operands separately as clauses can appear // more than once. for (const Fortran::parser::AccClause &clause : accClauseList.v) { @@ -2971,15 +2986,19 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter, genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx); } else if (const auto *asyncClause = std::get_if(&clause.u)) { - genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx); + genAsyncClause(converter, asyncClause, asyncOperands, + asyncOperandsDeviceTypes, asyncOnlyDeviceTypes, + crtDeviceTypes, stmtCtx); } else if (const auto *waitClause = std::get_if(&clause.u)) { - genWaitClause(converter, waitClause, waitOperands, waitDevnum, - addWaitAttr, stmtCtx); + genWaitClause(converter, waitClause, waitOperands, + waitOperandsDeviceTypes, waitOnlyDeviceTypes, + waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx); } else if (const auto *deviceTypeClause = std::get_if( &clause.u)) { - gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes); + crtDeviceTypes.clear(); + gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes); } else if (const auto *hostClause = std::get_if(&clause.u)) { genDataOperandOperations( @@ -2993,7 +3012,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter, dataClauseOperands, mlir::acc::DataClause::acc_update_device, false, /*implicit=*/false); } else if (std::get_if(&clause.u)) { - addIfPresentAttr = true; + ifPresent = true; } else if (const auto *selfClause = std::get_if(&clause.u)) { const std::optional &accSelfClause = @@ -3010,30 +3029,17 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter, dataClauseOperands.append(updateHostOperands); - // Prepare the operand segment size attribute and the operands value range. - llvm::SmallVector operands; - llvm::SmallVector operandSegments; - addOperand(operands, operandSegments, ifCond); - addOperand(operands, operandSegments, async); - addOperand(operands, operandSegments, waitDevnum); - addOperands(operands, operandSegments, waitOperands); - addOperands(operands, operandSegments, dataClauseOperands); - - mlir::acc::UpdateOp updateOp = createSimpleOp( - builder, currentLocation, operands, operandSegments); - if (!deviceTypes.empty()) - updateOp.setDeviceTypesAttr( - mlir::ArrayAttr::get(builder.getContext(), deviceTypes)); + builder.create( + currentLocation, ifCond, asyncOperands, + getArrayAttr(builder, asyncOperandsDeviceTypes), + getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands, + getDenseI32ArrayAttr(builder, waitOperandsSegments), + getArrayAttr(builder, waitOperandsDeviceTypes), + getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands, + ifPresent); genDataExitOperations( builder, updateHostOperands, /*structured=*/false); - - if (addAsyncAttr) - updateOp.setAsyncAttr(builder.getUnitAttr()); - if (addWaitAttr) - updateOp.setWaitAttr(builder.getUnitAttr()); - if (addIfPresentAttr) - updateOp.setIfPresentAttr(builder.getUnitAttr()); } static void diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90 index d2b15f8bd258e..ba036ac928118 100644 --- a/flang/test/Lower/OpenACC/acc-update.f90 +++ b/flang/test/Lower/OpenACC/acc-update.f90 @@ -61,17 +61,17 @@ subroutine acc_update !$acc update host(a) async ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref> {dataClause = #acc, name = "a", structured = false} -! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref>) attributes {async} +! CHECK: acc.update async() dataOperands(%[[DEVPTR_A]] : !fir.ref>) ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref>) {name = "a", structured = false} !$acc update host(a) wait ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref> {dataClause = #acc, name = "a", structured = false} -! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref>) attributes {wait} +! CHECK: acc.update wait dataOperands(%[[DEVPTR_A]] : !fir.ref>) ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref>) {name = "a", structured = false} !$acc update host(a) async wait ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref> {dataClause = #acc, name = "a", structured = false} -! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref>) attributes {async, wait} +! CHECK: acc.update async() wait dataOperands(%[[DEVPTR_A]] : !fir.ref>) ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref>) {name = "a", structured = false} !$acc update host(a) async(1) @@ -89,14 +89,14 @@ subroutine acc_update !$acc update host(a) wait(1) ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref> {dataClause = #acc, name = "a", structured = false} ! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32 -! CHECK: acc.update wait([[WAIT1]] : i32) dataOperands(%[[DEVPTR_A]] : !fir.ref>) +! CHECK: acc.update wait({[[WAIT1]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref>) ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref>) {name = "a", structured = false} !$acc update host(a) wait(queues: 1, 2) ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref> {dataClause = #acc, name = "a", structured = false} ! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32 -! CHECK: acc.update wait([[WAIT2]], [[WAIT3]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref>) +! CHECK: acc.update wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref>) ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref>) {name = "a", structured = false} !$acc update host(a) wait(devnum: 1: queues: 1, 2) @@ -104,17 +104,12 @@ subroutine acc_update ! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32 ! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32 ! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32 -! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref>) +! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref>) ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref>) {name = "a", structured = false} - !$acc update host(a) device_type(default, host) + !$acc update host(a) device_type(host, nvidia) async ! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref> {dataClause = #acc, name = "a", structured = false} -! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref>) attributes {device_types = [#acc.device_type, #acc.device_type]} -! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref>) {name = "a", structured = false} - - !$acc update host(a) device_type(*) -! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref> {dataClause = #acc, name = "a", structured = false} -! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref>) attributes {device_types = [#acc.device_type]} +! CHECK: acc.update async([#acc.device_type, #acc.device_type]) dataOperands(%[[DEVPTR_A]] : !fir.ref>) ! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref>) {name = "a", structured = false} end subroutine acc_update diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 992f2809644a6..87fd587782e7c 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -2196,14 +2196,16 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", }]; let arguments = (ins Optional:$ifCond, - Optional:$asyncOperand, - Optional:$waitDevnum, - Variadic:$waitOperands, - UnitAttr:$async, - UnitAttr:$wait, - OptionalAttr>:$device_types, - Variadic:$dataClauseOperands, - UnitAttr:$ifPresent); + Variadic:$asyncOperands, + OptionalAttr:$asyncOperandsDeviceType, + OptionalAttr:$async, + Optional:$waitDevnum, + Variadic:$waitOperands, + OptionalAttr:$waitOperandsSegments, + OptionalAttr:$waitOperandsDeviceType, + OptionalAttr:$wait, + Variadic:$dataClauseOperands, + UnitAttr:$ifPresent); let extraClassDeclaration = [{ /// The number of data operands. @@ -2211,14 +2213,41 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", /// The i-th data operand passed. Value getDataOperand(unsigned i); + + /// Return true if the op has the async attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasAsyncOnly(); + /// Return true if the op has the async attribute for the given device_type. + bool hasAsyncOnly(mlir::acc::DeviceType deviceType); + /// Return the value of the async clause if present. + mlir::Value getAsyncValue(); + /// Return the value of the async clause for the given device_type if + /// present. + mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType); + + /// Return true if the op has the wait attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasWaitOnly(); + /// Return true if the op has the wait attribute for the given device_type. + bool hasWaitOnly(mlir::acc::DeviceType deviceType); + /// Return the values of the wait clause if present. + mlir::Operation::operand_range getWaitValues(); + /// Return the values of the wait clause for the given device_type if + /// present. + mlir::Operation::operand_range + getWaitValues(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ oilist( `if` `(` $ifCond `)` - | `async` `(` $asyncOperand `:` type($asyncOperand) `)` + | `async` `` custom( + $asyncOperands, type($asyncOperands), + $asyncOperandsDeviceType, $async) | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)` - | `wait` `(` $waitOperands `:` type($waitOperands) `)` + | `wait` `` custom($waitOperands, + type($waitOperands), $waitOperandsDeviceType, + $waitOperandsSegments, $wait) | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` ) attr-dict-with-keyword diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index f6229e5192a0a..e1e69113bca16 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -963,6 +963,121 @@ static void printDeviceTypeOperandsWithSegment( }); } +static ParseResult parseWaitClause( + mlir::OpAsmParser &parser, + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &types, mlir::ArrayAttr &deviceTypes, + mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &keywordOnly) { + llvm::SmallVector deviceTypeAttrs, keywordAttrs; + llvm::SmallVector seg; + + bool needCommaBeforeOperands = false; + + // Keyword only + if (failed(parser.parseOptionalLParen())) { + keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get( + parser.getContext(), mlir::acc::DeviceType::None)); + keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs); + return success(); + } + + // Parse keyword only attributes + if (succeeded(parser.parseOptionalLSquare())) { + if (failed(parser.parseCommaSeparatedList([&]() { + if (parser.parseAttribute(keywordAttrs.emplace_back())) + return failure(); + return success(); + }))) + return failure(); + if (parser.parseRSquare()) + return failure(); + needCommaBeforeOperands = true; + } + + if (needCommaBeforeOperands && failed(parser.parseComma())) + return failure(); + + do { + if (failed(parser.parseLBrace())) + return failure(); + + int32_t crtOperandsSize = operands.size(); + + if (failed(parser.parseCommaSeparatedList( + mlir::AsmParser::Delimiter::None, [&]() { + if (parser.parseOperand(operands.emplace_back()) || + parser.parseColonType(types.emplace_back())) + return failure(); + return success(); + }))) + return failure(); + + seg.push_back(operands.size() - crtOperandsSize); + + if (failed(parser.parseRBrace())) + return failure(); + + if (succeeded(parser.parseOptionalLSquare())) { + if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) || + parser.parseRSquare()) + return failure(); + } else { + deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( + parser.getContext(), mlir::acc::DeviceType::None)); + } + } while (succeeded(parser.parseOptionalComma())); + + if (failed(parser.parseRParen())) + return failure(); + + deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs); + keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs); + segments = DenseI32ArrayAttr::get(parser.getContext(), seg); + + return success(); +} + +static bool hasOnlyDeviceTypeNone(std::optional attrs) { + if (!hasDeviceTypeValues(attrs)) + return false; + if (attrs->size() != 1) + return false; + if (auto deviceTypeAttr = + mlir::dyn_cast((*attrs)[0])) + return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None; + return false; +} + +static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::OperandRange operands, mlir::TypeRange types, + std::optional deviceTypes, + std::optional segments, + std::optional keywordOnly) { + + if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly)) + return; + + p << "("; + + printDeviceTypes(p, keywordOnly); + if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes)) + p << ", "; + + unsigned opIdx = 0; + llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { + p << "{"; + llvm::interleaveComma( + llvm::seq(0, (*segments)[it.index()]), p, [&](auto it) { + p << operands[opIdx] << " : " << operands[opIdx].getType(); + ++opIdx; + }); + p << "}"; + printSingleDeviceType(p, it.value()); + }); + + p << ")"; +} + static ParseResult parseDeviceTypeOperands( mlir::OpAsmParser &parser, llvm::SmallVectorImpl &operands, @@ -993,6 +1108,8 @@ static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional deviceTypes) { + if (!hasDeviceTypeValues(deviceTypes)) + return; llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) { p << std::get<1>(it) << " : " << std::get<1>(it).getType(); printSingleDeviceType(p, std::get<0>(it)); @@ -1068,15 +1185,10 @@ static void printDeviceTypeOperandsWithKeywordOnly( std::optional keywordOnlyDeviceTypes) { p << "("; - - if (operands.begin() == operands.end() && keywordOnlyDeviceTypes && - keywordOnlyDeviceTypes->size() == 1) { - auto deviceTypeAttr = - mlir::dyn_cast((*keywordOnlyDeviceTypes)[0]); - if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) { - p << ")"; - return; - } + if (operands.begin() == operands.end() && + hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) { + p << ")"; + return; } printDeviceTypes(p, keywordOnlyDeviceTypes); @@ -1452,14 +1564,9 @@ void printGangClause(OpAsmPrinter &p, Operation *op, p << "("; if (operands.begin() == operands.end() && - hasDeviceTypeValues(gangOnlyDeviceTypes) && - gangOnlyDeviceTypes->size() == 1) { - auto deviceTypeAttr = - mlir::dyn_cast((*gangOnlyDeviceTypes)[0]); - if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) { - p << ")"; - return; - } + hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) { + p << ")"; + return; } printDeviceTypes(p, gangOnlyDeviceTypes); @@ -2432,15 +2539,30 @@ LogicalResult acc::UpdateOp::verify() { if (getDataClauseOperands().empty()) return emitError("at least one value must be present in dataOperands"); - // The async attribute represent the async clause without value. Therefore the - // attribute and operand cannot appear at the same time. - if (getAsyncOperand() && getAsync()) - return emitError("async attribute cannot appear with asyncOperand"); + if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(), + getAsyncOperandsDeviceTypeAttr(), + "async"))) + return failure(); - // The wait attribute represent the wait clause without values. Therefore the - // attribute and operands cannot appear at the same time. - if (!getWaitOperands().empty() && getWait()) - return emitError("wait attribute cannot appear with waitOperands"); + if (failed(verifyDeviceTypeAndSegmentCountMatch( + *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), + getWaitOperandsDeviceTypeAttr(), "wait"))) + return failure(); + + for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType(); + ++dtypeInt) { + auto dtype = static_cast(dtypeInt); + + // The async attribute represent the async clause without value. Therefore + // the attribute and operand cannot appear at the same time. + if (getAsyncValue(dtype) && hasAsyncOnly(dtype)) + return emitError("async attribute cannot appear with asyncOperand"); + + // The wait attribute represent the wait clause without values. Therefore + // the attribute and operands cannot appear at the same time. + if (!getWaitValues(dtype).empty() && hasWaitOnly(dtype)) + return emitError("wait attribute cannot appear with waitOperands"); + } if (getWaitDevnum() && getWaitOperands().empty()) return emitError("wait_devnum cannot appear without waitOperands"); @@ -2459,7 +2581,7 @@ unsigned UpdateOp::getNumDataOperands() { } Value UpdateOp::getDataOperand(unsigned i) { - unsigned numOptional = getAsyncOperand() ? 1 : 0; + unsigned numOptional = getAsyncOperands().size(); numOptional += getWaitDevnum() ? 1 : 0; numOptional += getIfCond() ? 1 : 0; return getOperand(getWaitOperands().size() + numOptional + i); @@ -2470,6 +2592,46 @@ void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add>(context); } +bool UpdateOp::hasAsyncOnly() { + return hasAsyncOnly(mlir::acc::DeviceType::None); +} + +bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { + return hasDeviceType(getAsync(), deviceType); +} + +mlir::Value UpdateOp::getAsyncValue() { + return getAsyncValue(mlir::acc::DeviceType::None); +} + +mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) { + if (!hasDeviceTypeValues(getAsyncOperandsDeviceType())) + return {}; + + if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType)) + return getAsyncOperands()[*pos]; + + return {}; +} + +bool UpdateOp::hasWaitOnly() { + return hasWaitOnly(mlir::acc::DeviceType::None); +} + +bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { + return hasDeviceType(getWait(), deviceType); +} + +mlir::Operation::operand_range UpdateOp::getWaitValues() { + return getWaitValues(mlir::acc::DeviceType::None); +} + +mlir::Operation::operand_range +UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) { + return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), + getWaitOperandsSegments(), deviceType); +} + //===----------------------------------------------------------------------===// // WaitOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir index 57ae5856149d1..80d439f19d9f4 100644 --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -138,7 +138,7 @@ acc.update wait_devnum(%cst: index) dataOperands(%0: memref) %value = memref.alloc() : memref %0 = acc.update_device varPtr(%value : memref) -> memref // expected-error@+1 {{async attribute cannot appear with asyncOperand}} -acc.update async(%cst: index) dataOperands(%0 : memref) attributes {async} +acc.update async(%cst: index) dataOperands(%0 : memref) attributes {async = [#acc.device_type]} // ----- @@ -146,7 +146,7 @@ acc.update async(%cst: index) dataOperands(%0 : memref) attributes {async} %value = memref.alloc() : memref %0 = acc.update_device varPtr(%value : memref) -> memref // expected-error@+1 {{wait attribute cannot appear with waitOperands}} -acc.update wait(%cst: index) dataOperands(%0: memref) attributes {wait} +acc.update wait({%cst: index}) dataOperands(%0: memref) attributes {wait = [#acc.device_type]} // ----- diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index d4c884a837f87..45b41f1a77225 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -934,17 +934,17 @@ func.func @testupdateop(%a: memref, %b: memref, %c: memref) -> () acc.update async(%i32Value: i32) dataOperands(%0: memref) acc.update async(%i32Value: i32) dataOperands(%0: memref) acc.update async(%idxValue: index) dataOperands(%0: memref) - acc.update wait_devnum(%i64Value: i64) wait(%i32Value, %idxValue : i32, index) dataOperands(%0: memref) + acc.update wait_devnum(%i64Value: i64) wait({%i32Value : i32, %idxValue : index}) dataOperands(%0: memref) acc.update if(%ifCond) dataOperands(%0: memref) - acc.update dataOperands(%0: memref) attributes {acc.device_types = [#acc.device_type]} + acc.update dataOperands(%0: memref) acc.update dataOperands(%0, %1, %2 : memref, memref, memref) - acc.update dataOperands(%0, %1, %2 : memref, memref, memref) attributes {async} - acc.update dataOperands(%0, %1, %2 : memref, memref, memref) attributes {wait} + acc.update async() dataOperands(%0, %1, %2 : memref, memref, memref) + acc.update wait dataOperands(%0, %1, %2 : memref, memref, memref) acc.update dataOperands(%0, %1, %2 : memref, memref, memref) attributes {ifPresent} return } -// CHECK: func @testupdateop([[ARGA:%.*]]: memref, [[ARGB:%.*]]: memref, [[ARGC:%.*]]: memref) { +// CHECK: func.func @testupdateop(%{{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref) // CHECK: [[I64VALUE:%.*]] = arith.constant 1 : i64 // CHECK: [[I32VALUE:%.*]] = arith.constant 1 : i32 // CHECK: [[IDXVALUE:%.*]] = arith.constant 1 : index @@ -953,12 +953,12 @@ func.func @testupdateop(%a: memref, %b: memref, %c: memref) -> () // CHECK: acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref) // CHECK: acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref) // CHECK: acc.update async([[IDXVALUE]] : index) dataOperands(%{{.*}} : memref) -// CHECK: acc.update wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) dataOperands(%{{.*}} : memref) +// CHECK: acc.update wait_devnum([[I64VALUE]] : i64) wait({[[I32VALUE]] : i32, [[IDXVALUE]] : index}) dataOperands(%{{.*}} : memref) // CHECK: acc.update if([[IFCOND]]) dataOperands(%{{.*}} : memref) -// CHECK: acc.update dataOperands(%{{.*}} : memref) attributes {acc.device_types = [#acc.device_type]} +// CHECK: acc.update dataOperands(%{{.*}} : memref) // CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref, memref, memref) -// CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref, memref, memref) attributes {async} -// CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref, memref, memref) attributes {wait} +// CHECK: acc.update async() dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref, memref, memref) +// CHECK: acc.update wait dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref, memref, memref) // CHECK: acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref, memref, memref) attributes {ifPresent} // -----