diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index db9ed72bc8725..fd89d27db74dc 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -3469,6 +3469,72 @@ static void genACC(Fortran::lower::AbstractConverter &converter, llvm_unreachable("unsupported declarative directive"); } +static bool hasDeviceType(llvm::SmallVector &arrayAttr, + mlir::acc::DeviceType deviceType) { + for (auto attr : arrayAttr) { + auto deviceTypeAttr = mlir::dyn_cast(attr); + if (deviceTypeAttr.getValue() == deviceType) + return true; + } + return false; +} + +template +static std::optional +getAttributeValueByDeviceType(llvm::SmallVector &attributes, + llvm::SmallVector &deviceTypes, + mlir::acc::DeviceType deviceType) { + assert(attributes.size() == deviceTypes.size() && + "expect same number of attributes"); + for (auto it : llvm::enumerate(deviceTypes)) { + auto deviceTypeAttr = mlir::dyn_cast(it.value()); + if (deviceTypeAttr.getValue() == deviceType) { + if constexpr (std::is_same_v) { + auto strAttr = mlir::dyn_cast(attributes[it.index()]); + return strAttr.getValue(); + } else if constexpr (std::is_same_v) { + auto intAttr = + mlir::dyn_cast(attributes[it.index()]); + return intAttr.getInt(); + } + } + } + return std::nullopt; +} + +static bool compareDeviceTypeInfo( + mlir::acc::RoutineOp op, + llvm::SmallVector &bindNameArrayAttr, + llvm::SmallVector &bindNameDeviceTypeArrayAttr, + llvm::SmallVector &gangArrayAttr, + llvm::SmallVector &gangDimArrayAttr, + llvm::SmallVector &gangDimDeviceTypeArrayAttr, + llvm::SmallVector &seqArrayAttr, + llvm::SmallVector &workerArrayAttr, + llvm::SmallVector &vectorArrayAttr) { + for (uint32_t dtypeInt = 0; + dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) { + auto dtype = static_cast(dtypeInt); + if (op.getBindNameValue(dtype) != + getAttributeValueByDeviceType( + bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype)) + return false; + if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype)) + return false; + if (op.getGangDimValue(dtype) != + getAttributeValueByDeviceType( + gangDimArrayAttr, gangDimDeviceTypeArrayAttr, dtype)) + return false; + if (op.hasSeq(dtype) != hasDeviceType(seqArrayAttr, dtype)) + return false; + if (op.hasWorker(dtype) != hasDeviceType(workerArrayAttr, dtype)) + return false; + if (op.hasVector(dtype) != hasDeviceType(vectorArrayAttr, dtype)) + return false; + } + return true; +} + static void attachRoutineInfo(mlir::func::FuncOp func, mlir::SymbolRefAttr routineAttr) { llvm::SmallVector routines; @@ -3518,17 +3584,23 @@ void Fortran::lower::genOpenACCRoutineConstruct( funcName = funcOp.getName(); } } - bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false, - hasNohost = false; - std::optional bindName = std::nullopt; - std::optional gangDim = std::nullopt; + bool hasNohost = false; + + llvm::SmallVector seqDeviceTypes, vectorDeviceTypes, + workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes, + gangDimDeviceTypes, gangDimValues; + + // device_type attribute is set to `none` until a device_type clause is + // encountered. + auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( + builder.getContext(), mlir::acc::DeviceType::None); for (const Fortran::parser::AccClause &clause : clauses.v) { if (std::get_if(&clause.u)) { - hasSeq = true; + seqDeviceTypes.push_back(crtDeviceTypeAttr); } else if (const auto *gangClause = std::get_if(&clause.u)) { - hasGang = true; + if (gangClause->v) { const Fortran::parser::AccGangArgList &x = *gangClause->v; for (const Fortran::parser::AccGangArg &gangArg : x.v) { @@ -3539,21 +3611,27 @@ void Fortran::lower::genOpenACCRoutineConstruct( if (!dimValue) mlir::emitError(loc, "dim value must be a constant positive integer"); - gangDim = *dimValue; + gangDimValues.push_back( + builder.getIntegerAttr(builder.getI64Type(), *dimValue)); + gangDimDeviceTypes.push_back(crtDeviceTypeAttr); } } + } else { + gangDeviceTypes.push_back(crtDeviceTypeAttr); } } else if (std::get_if(&clause.u)) { - hasVector = true; + vectorDeviceTypes.push_back(crtDeviceTypeAttr); } else if (std::get_if(&clause.u)) { - hasWorker = true; + workerDeviceTypes.push_back(crtDeviceTypeAttr); } else if (std::get_if(&clause.u)) { hasNohost = true; } else if (const auto *bindClause = std::get_if(&clause.u)) { if (const auto *name = std::get_if(&bindClause->v.u)) { - bindName = converter.mangleName(*name->symbol); + bindNames.push_back( + builder.getStringAttr(converter.mangleName(*name->symbol))); + bindNameDeviceTypes.push_back(crtDeviceTypeAttr); } else if (const auto charExpr = std::get_if( &bindClause->v.u)) { @@ -3562,8 +3640,18 @@ void Fortran::lower::genOpenACCRoutineConstruct( *charExpr); if (!name) mlir::emitError(loc, "Could not retrieve the bind name"); - bindName = *name; + bindNames.push_back(builder.getStringAttr(*name)); + bindNameDeviceTypes.push_back(crtDeviceTypeAttr); } + } else if (const auto *deviceTypeClause = + std::get_if( + &clause.u)) { + const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList = + deviceTypeClause->v; + assert(deviceTypeExprList.v.size() == 1 && + "expect only one device_type expr"); + crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( + builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v)); } } @@ -3575,12 +3663,11 @@ void Fortran::lower::genOpenACCRoutineConstruct( if (routineOp.getFuncName().str().compare(funcName) == 0) { // If the routine is already specified with the same clauses, just skip // the operation creation. - if (routineOp.getBindName() == bindName && - routineOp.getGang() == hasGang && - routineOp.getWorker() == hasWorker && - routineOp.getVector() == hasVector && routineOp.getSeq() == hasSeq && - routineOp.getNohost() == hasNohost && - routineOp.getGangDim() == gangDim) + if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes, + gangDeviceTypes, gangDimValues, + gangDimDeviceTypes, seqDeviceTypes, + workerDeviceTypes, vectorDeviceTypes) && + routineOp.getNohost() == hasNohost) return; mlir::emitError(loc, "Routine already specified with different clauses"); } @@ -3588,10 +3675,19 @@ void Fortran::lower::genOpenACCRoutineConstruct( modBuilder.create( loc, routineOpName.str(), funcName, - bindName ? builder.getStringAttr(*bindName) : mlir::StringAttr{}, hasGang, - hasWorker, hasVector, hasSeq, hasNohost, /*implicit=*/false, - gangDim ? builder.getIntegerAttr(builder.getIntegerType(32), *gangDim) - : mlir::IntegerAttr{}); + bindNames.empty() ? nullptr : builder.getArrayAttr(bindNames), + bindNameDeviceTypes.empty() ? nullptr + : builder.getArrayAttr(bindNameDeviceTypes), + workerDeviceTypes.empty() ? nullptr + : builder.getArrayAttr(workerDeviceTypes), + vectorDeviceTypes.empty() ? nullptr + : builder.getArrayAttr(vectorDeviceTypes), + seqDeviceTypes.empty() ? nullptr : builder.getArrayAttr(seqDeviceTypes), + hasNohost, /*implicit=*/false, + gangDeviceTypes.empty() ? nullptr : builder.getArrayAttr(gangDeviceTypes), + gangDimValues.empty() ? nullptr : builder.getArrayAttr(gangDimValues), + gangDimDeviceTypes.empty() ? nullptr + : builder.getArrayAttr(gangDimDeviceTypes)); if (funcOp) attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str())); diff --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90 index 8b94279503334..2fe150e70b0cf 100644 --- a/flang/test/Lower/OpenACC/acc-routine.f90 +++ b/flang/test/Lower/OpenACC/acc-routine.f90 @@ -2,12 +2,14 @@ ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s - +! CHECK: acc.routine @acc_routine_16 func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type], "_QPacc_routine16" [#acc.device_type]) +! CHECK: acc.routine @acc_routine_15 func(@_QPacc_routine17) worker ([#acc.device_type]) vector ([#acc.device_type]) +! CHECK: acc.routine @acc_routine_14 func(@_QPacc_routine16) gang([#acc.device_type]) seq ([#acc.device_type]) ! CHECK: acc.routine @acc_routine_10 func(@_QPacc_routine11) seq ! CHECK: acc.routine @acc_routine_9 func(@_QPacc_routine10) seq ! CHECK: acc.routine @acc_routine_8 func(@_QPacc_routine9) bind("_QPacc_routine9a") ! CHECK: acc.routine @acc_routine_7 func(@_QPacc_routine8) bind("routine8_") -! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim = 1 : i32) +! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim: 1 : i64) ! CHECK: acc.routine @acc_routine_5 func(@_QPacc_routine6) nohost ! CHECK: acc.routine @acc_routine_4 func(@_QPacc_routine5) worker ! CHECK: acc.routine @acc_routine_3 func(@_QPacc_routine4) vector @@ -106,3 +108,15 @@ subroutine acc_routine14() subroutine acc_routine15() !$acc routine bind(acc_routine16) end subroutine + +subroutine acc_routine16() + !$acc routine device_type(host) seq dtype(nvidia) gang +end subroutine + +subroutine acc_routine17() + !$acc routine device_type(host) worker dtype(multicore) vector +end subroutine + +subroutine acc_routine18() + !$acc routine device_type(host) bind(acc_routine17) dtype(multicore) bind(acc_routine16) +end subroutine diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 24f129d92805c..7344ab2852b9c 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -1994,27 +1994,63 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> { let arguments = (ins SymbolNameAttr:$sym_name, SymbolNameAttr:$func_name, - OptionalAttr:$bind_name, - UnitAttr:$gang, - UnitAttr:$worker, - UnitAttr:$vector, - UnitAttr:$seq, + OptionalAttr:$bindName, + OptionalAttr:$bindNameDeviceType, + OptionalAttr:$worker, + OptionalAttr:$vector, + OptionalAttr:$seq, UnitAttr:$nohost, UnitAttr:$implicit, - OptionalAttr:$gangDim); + OptionalAttr:$gang, + OptionalAttr:$gangDim, + OptionalAttr:$gangDimDeviceType); let extraClassDeclaration = [{ static StringRef getGangDimKeyword() { return "dim"; } + + /// Return true if the op has the worker attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasWorker(); + /// Return true if the op has the worker attribute for the given + /// device_type. + bool hasWorker(mlir::acc::DeviceType deviceType); + + /// Return true if the op has the vector attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasVector(); + /// Return true if the op has the vector attribute for the given + /// device_type. + bool hasVector(mlir::acc::DeviceType deviceType); + + /// Return true if the op has the seq attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasSeq(); + /// Return true if the op has the seq attribute for the given + /// device_type. + bool hasSeq(mlir::acc::DeviceType deviceType); + + /// Return true if the op has the gang attribute for the + /// mlir::acc::DeviceType::None device_type. + bool hasGang(); + /// Return true if the op has the gang attribute for the given + /// device_type. + bool hasGang(mlir::acc::DeviceType deviceType); + + std::optional getGangDimValue(); + std::optional getGangDimValue(mlir::acc::DeviceType deviceType); + + std::optional getBindNameValue(); + std::optional getBindNameValue(mlir::acc::DeviceType deviceType); }]; let assemblyFormat = [{ $sym_name `func` `(` $func_name `)` oilist ( - `bind` `(` $bind_name `)` - | `gang` `` custom($gang, $gangDim) - | `worker` $worker - | `vector` $vector - | `seq` $seq + `bind` `(` custom($bindName, $bindNameDeviceType) `)` + | `gang` `` custom($gang, $gangDim, $gangDimDeviceType) + | `worker` custom($worker) + | `vector` custom($vector) + | `seq` custom($seq) | `nohost` $nohost | `implicit` $implicit ) attr-dict-with-keyword diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 664a0161b79c1..20465f6bb86ed 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -1033,7 +1033,7 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly( return success(); } -bool hasDeviceTypeValues(std::optional arrayAttr) { +static bool hasDeviceTypeValues(std::optional arrayAttr) { if (arrayAttr && *arrayAttr && arrayAttr->size() > 0) return true; return false; @@ -2090,55 +2090,281 @@ LogicalResult acc::DeclareOp::verify() { // RoutineOp //===----------------------------------------------------------------------===// +static bool hasDeviceType(std::optional arrayAttr, + mlir::acc::DeviceType deviceType) { + if (!hasDeviceTypeValues(arrayAttr)) + return false; + + for (auto attr : *arrayAttr) { + auto deviceTypeAttr = mlir::dyn_cast(attr); + if (deviceTypeAttr.getValue() == deviceType) + return true; + } + + return false; +} + +static unsigned getParallelismForDeviceType(acc::RoutineOp op, + acc::DeviceType dtype) { + unsigned parallelism = 0; + parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0; + parallelism += op.hasWorker(dtype) ? 1 : 0; + parallelism += op.hasVector(dtype) ? 1 : 0; + parallelism += op.hasSeq(dtype) ? 1 : 0; + return parallelism; +} + LogicalResult acc::RoutineOp::verify() { - int parallelism = 0; - parallelism += getGang() ? 1 : 0; - parallelism += getWorker() ? 1 : 0; - parallelism += getVector() ? 1 : 0; - parallelism += getSeq() ? 1 : 0; + unsigned baseParallelism = + getParallelismForDeviceType(*this, acc::DeviceType::None); - if (parallelism > 1) + if (baseParallelism > 1) return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " "be present at the same time"; + for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType(); + ++dtypeInt) { + auto dtype = static_cast(dtypeInt); + if (dtype == acc::DeviceType::None) + continue; + unsigned parallelism = getParallelismForDeviceType(*this, dtype); + + if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1)) + return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " + "be present at the same time"; + } + return success(); } -static ParseResult parseRoutineGangClause(OpAsmParser &parser, UnitAttr &gang, - IntegerAttr &gangDim) { - // Since gang clause exists, ensure that unit attribute is set. - gang = UnitAttr::get(parser.getBuilder().getContext()); +static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, + mlir::ArrayAttr &deviceTypes) { + llvm::SmallVector bindNameAttrs; + llvm::SmallVector deviceTypeAttrs; - // Next, look for dim on gang. Don't initialize `gangDim` yet since - // we leave it without attribute if there is no `dim` specifier. - if (succeeded(parser.parseOptionalLParen())) { - // Look for syntax that looks like `dim = 1 : i32`. - // Thus first look for `dim =` - if (failed(parser.parseKeyword(RoutineOp::getGangDimKeyword())) || - failed(parser.parseEqual())) - return failure(); + if (failed(parser.parseCommaSeparatedList([&]() { + if (parser.parseAttribute(bindNameAttrs.emplace_back())) + return failure(); + if (failed(parser.parseOptionalLSquare())) { + deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( + parser.getContext(), mlir::acc::DeviceType::None)); + } else { + if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) || + parser.parseRSquare()) + return failure(); + } + return success(); + }))) + return failure(); - int64_t dimValue; - Type valueType; - // Now look for `1 : i32` - if (failed(parser.parseInteger(dimValue)) || - failed(parser.parseColonType(valueType))) - return failure(); + bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs); + deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs); + + return success(); +} - gangDim = IntegerAttr::get(valueType, dimValue); +static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, + std::optional bindName, + std::optional deviceTypes) { + llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p, + [&](const auto &pair) { + p << std::get<0>(pair); + printSingleDeviceType(p, std::get<1>(pair)); + }); +} + +static ParseResult parseRoutineGangClause(OpAsmParser &parser, + mlir::ArrayAttr &gang, + mlir::ArrayAttr &gangDim, + mlir::ArrayAttr &gangDimDeviceTypes) { + + llvm::SmallVector gangAttrs, gangDimAttrs, + gangDimDeviceTypeAttrs; + bool needCommaBeforeOperands = false; - if (failed(parser.parseRParen())) + // Gang keyword only + if (failed(parser.parseOptionalLParen())) { + gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get( + parser.getContext(), mlir::acc::DeviceType::None)); + gang = ArrayAttr::get(parser.getContext(), gangAttrs); + return success(); + } + + // Parse keyword only attributes + if (succeeded(parser.parseOptionalLSquare())) { + if (failed(parser.parseCommaSeparatedList([&]() { + if (parser.parseAttribute(gangAttrs.emplace_back())) + return failure(); + return success(); + }))) + return failure(); + if (parser.parseRSquare()) return failure(); + needCommaBeforeOperands = true; } + if (needCommaBeforeOperands && failed(parser.parseComma())) + return failure(); + + if (failed(parser.parseCommaSeparatedList([&]() { + if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) || + parser.parseColon() || + parser.parseAttribute(gangDimAttrs.emplace_back())) + return failure(); + if (succeeded(parser.parseOptionalLSquare())) { + if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) || + parser.parseRSquare()) + return failure(); + } else { + gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( + parser.getContext(), mlir::acc::DeviceType::None)); + } + return success(); + }))) + return failure(); + + if (failed(parser.parseRParen())) + return failure(); + + gang = ArrayAttr::get(parser.getContext(), gangAttrs); + gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs); + gangDimDeviceTypes = + ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs); + return success(); } -void printRoutineGangClause(OpAsmPrinter &p, Operation *op, UnitAttr gang, - IntegerAttr gangDim) { - if (gangDim) - p << "(" << RoutineOp::getGangDimKeyword() << " = " << gangDim.getValue() - << " : " << gangDim.getType() << ")"; +void printRoutineGangClause(OpAsmPrinter &p, Operation *op, + std::optional gang, + std::optional gangDim, + std::optional gangDimDeviceTypes) { + + if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) && + gang->size() == 1) { + auto deviceTypeAttr = mlir::dyn_cast((*gang)[0]); + if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) + return; + } + + p << "("; + + printDeviceTypes(p, gang); + + if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes)) + p << ", "; + + if (hasDeviceTypeValues(gangDimDeviceTypes)) + llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p, + [&](const auto &pair) { + p << acc::RoutineOp::getGangDimKeyword() << ": "; + p << std::get<0>(pair); + printSingleDeviceType(p, std::get<1>(pair)); + }); + + p << ")"; +} + +static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, + mlir::ArrayAttr &deviceTypes) { + llvm::SmallVector attributes; + // Keyword only + if (failed(parser.parseOptionalLParen())) { + attributes.push_back(mlir::acc::DeviceTypeAttr::get( + parser.getContext(), mlir::acc::DeviceType::None)); + deviceTypes = ArrayAttr::get(parser.getContext(), attributes); + return success(); + } + + // Parse device type attributes + if (succeeded(parser.parseOptionalLSquare())) { + if (failed(parser.parseCommaSeparatedList([&]() { + if (parser.parseAttribute(attributes.emplace_back())) + return failure(); + return success(); + }))) + return failure(); + if (parser.parseRSquare() || parser.parseRParen()) + return failure(); + } + deviceTypes = ArrayAttr::get(parser.getContext(), attributes); + return success(); +} + +static void +printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, + std::optional deviceTypes) { + + if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) { + auto deviceTypeAttr = + mlir::dyn_cast((*deviceTypes)[0]); + if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) + return; + } + + if (!hasDeviceTypeValues(deviceTypes)) + return; + + p << "(["; + llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) { + auto dTypeAttr = mlir::dyn_cast(attr); + p << dTypeAttr; + }); + p << "])"; +} + +bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); } + +bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) { + return hasDeviceType(getWorker(), deviceType); +} + +bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); } + +bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) { + return hasDeviceType(getVector(), deviceType); +} + +bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); } + +bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) { + return hasDeviceType(getSeq(), deviceType); +} + +std::optional RoutineOp::getBindNameValue() { + return getBindNameValue(mlir::acc::DeviceType::None); +} + +std::optional +RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) { + if (!hasDeviceTypeValues(getBindNameDeviceType())) + return std::nullopt; + if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) { + auto attr = (*getBindName())[*pos]; + auto stringAttr = dyn_cast(attr); + return stringAttr.getValue(); + } + return std::nullopt; +} + +bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); } + +bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) { + return hasDeviceType(getGang(), deviceType); +} + +std::optional RoutineOp::getGangDimValue() { + return getGangDimValue(mlir::acc::DeviceType::None); +} + +std::optional +RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) { + if (!hasDeviceTypeValues(getGangDimDeviceType())) + return std::nullopt; + if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) { + auto intAttr = mlir::dyn_cast((*getGangDim())[*pos]); + return intAttr.getInt(); + } + return std::nullopt; } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index 8fa37bc98294c..99b44183758d9 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -1656,7 +1656,7 @@ acc.routine @acc_func_rout5 func(@acc_func) bind("acc_func_gpu_worker") worker acc.routine @acc_func_rout6 func(@acc_func) bind("acc_func_gpu_seq") seq acc.routine @acc_func_rout7 func(@acc_func) bind("acc_func_gpu_imp_gang") implicit gang acc.routine @acc_func_rout8 func(@acc_func) bind("acc_func_gpu_vector_nohost") vector nohost -acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim = 1 : i32) +acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim: 1 : i64) // CHECK-LABEL: func.func @acc_func( // CHECK: attributes {acc.routine_info = #acc.routine_info<[@acc_func_rout1, @acc_func_rout2, @acc_func_rout3, @@ -1669,7 +1669,7 @@ acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang( // CHECK: acc.routine @acc_func_rout6 func(@acc_func) bind("acc_func_gpu_seq") seq // CHECK: acc.routine @acc_func_rout7 func(@acc_func) bind("acc_func_gpu_imp_gang") gang implicit // CHECK: acc.routine @acc_func_rout8 func(@acc_func) bind("acc_func_gpu_vector_nohost") vector nohost -// CHECK: acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim = 1 : i32) +// CHECK: acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim: 1 : i64) // ----- diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp index d78d7b0fdf676..474f887928992 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp @@ -347,3 +347,61 @@ TEST_F(OpenACCOpsTest, loopOpGangVectorWorkerTest) { } op->removeVectorAttr(); } + +TEST_F(OpenACCOpsTest, routineOpTest) { + OwningOpRef op = + b.create(loc, TypeRange{}, ValueRange{}); + + EXPECT_FALSE(op->hasSeq()); + EXPECT_FALSE(op->hasVector()); + EXPECT_FALSE(op->hasWorker()); + + for (auto d : dtypes) { + EXPECT_FALSE(op->hasSeq(d)); + EXPECT_FALSE(op->hasVector(d)); + EXPECT_FALSE(op->hasWorker(d)); + } + + auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None); + op->setSeqAttr(b.getArrayAttr({dtypeNone})); + EXPECT_TRUE(op->hasSeq()); + for (auto d : dtypesWithoutNone) + EXPECT_FALSE(op->hasSeq(d)); + op->removeSeqAttr(); + + op->setVectorAttr(b.getArrayAttr({dtypeNone})); + EXPECT_TRUE(op->hasVector()); + for (auto d : dtypesWithoutNone) + EXPECT_FALSE(op->hasVector(d)); + op->removeVectorAttr(); + + op->setWorkerAttr(b.getArrayAttr({dtypeNone})); + EXPECT_TRUE(op->hasWorker()); + for (auto d : dtypesWithoutNone) + EXPECT_FALSE(op->hasWorker(d)); + op->removeWorkerAttr(); + + op->setGangAttr(b.getArrayAttr({dtypeNone})); + EXPECT_TRUE(op->hasGang()); + for (auto d : dtypesWithoutNone) + EXPECT_FALSE(op->hasGang(d)); + op->removeGangAttr(); + + op->setGangDimDeviceTypeAttr(b.getArrayAttr({dtypeNone})); + op->setGangDimAttr(b.getArrayAttr({b.getIntegerAttr(b.getI64Type(), 8)})); + EXPECT_TRUE(op->getGangDimValue().has_value()); + EXPECT_EQ(op->getGangDimValue().value(), 8); + for (auto d : dtypesWithoutNone) + EXPECT_FALSE(op->getGangDimValue(d).has_value()); + op->removeGangDimDeviceTypeAttr(); + op->removeGangDimAttr(); + + op->setBindNameDeviceTypeAttr(b.getArrayAttr({dtypeNone})); + op->setBindNameAttr(b.getArrayAttr({b.getStringAttr("fname")})); + EXPECT_TRUE(op->getBindNameValue().has_value()); + EXPECT_EQ(op->getBindNameValue().value(), "fname"); + for (auto d : dtypesWithoutNone) + EXPECT_FALSE(op->getBindNameValue(d).has_value()); + op->removeBindNameDeviceTypeAttr(); + op->removeBindNameAttr(); +}