diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index e4c954159f71b..0d650f830b64e 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -708,6 +708,13 @@ mlir::Value createNullBoxProc(fir::FirOpBuilder &builder, mlir::Location loc,
 /// Set internal linkage attribute on a function.
 void setInternalLinkage(mlir::func::FuncOp);
+
+llvm::SmallVector<mlir::Value>
+elideExtentsAlreadyInType(mlir::Type type, mlir::ValueRange shape);
+
+llvm::SmallVector<mlir::Value>
+elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams);
+
 } // namespace fir::factory
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index da10969ebc702..c9eb5bc857ac0 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -222,6 +222,23 @@ inline bool HasCUDAAttr(const Symbol &sym) {
   return false;
 }
 
+inline bool NeedCUDAAlloc(const Symbol &sym) {
+  bool inDeviceSubprogram{IsCUDADeviceContext(&sym.owner())};
+  if (const auto *details{
+          sym.GetUltimate().detailsIf<ObjectEntityDetails>()}) {
+    if (details->cudaDataAttr() &&
+        (*details->cudaDataAttr() == common::CUDADataAttr::Device ||
+            *details->cudaDataAttr() == common::CUDADataAttr::Managed ||
+            *details->cudaDataAttr() == common::CUDADataAttr::Unified)) {
+      // Descriptor is allocated on host when in host context.
+      if (Fortran::semantics::IsAllocatable(sym))
+        return inDeviceSubprogram;
+      return true;
+    }
+  }
+  return false;
+}
+
 const Scope *FindCUDADeviceContext(const Scope *);
 std::optional<common::CUDADataAttr> GetCUDADataAttr(const Symbol *);
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index 413563fe95ca3..f31fbab41028c 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -693,6 +693,22 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
   if (ultimateSymbol.test(Fortran::semantics::Symbol::Flag::CrayPointee))
     return builder.create<fir::ZeroOp>(loc, fir::ReferenceType::get(ty));
 
+  if (Fortran::semantics::NeedCUDAAlloc(ultimateSymbol)) {
+    fir::CUDADataAttributeAttr cudaAttr =
+        Fortran::lower::translateSymbolCUDADataAttribute(builder.getContext(),
+                                                         ultimateSymbol);
+    llvm::SmallVector<mlir::Value> indices;
+    llvm::SmallVector<mlir::Value> elidedShape =
+        fir::factory::elideExtentsAlreadyInType(ty, shape);
+    llvm::SmallVector<mlir::Value> elidedLenParams =
+        fir::factory::elideLengthsAlreadyInType(ty, lenParams);
+    auto idxTy = builder.getIndexType();
+    for (mlir::Value sh : elidedShape)
+      indices.push_back(builder.createConvert(loc, idxTy, sh));
+    return builder.create<fir::CUDAAllocOp>(loc, ty, nm, symNm, cudaAttr,
+                                            lenParams, indices);
+  }
+
   // Let the builder do all the heavy lifting.
   if (!Fortran::semantics::IsProcedurePointer(ultimateSymbol))
     return builder.allocateLocal(loc, ty, nm, symNm, shape, lenParams, isTarg);
@@ -927,6 +943,19 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
       });
     }
   }
+  if (Fortran::semantics::NeedCUDAAlloc(var.getSymbol())) {
+    auto *builder = &converter.getFirOpBuilder();
+    mlir::Location loc = converter.getCurrentLocation();
+    fir::ExtendedValue exv =
+        converter.getSymbolExtendedValue(var.getSymbol(), &symMap);
+    auto *sym = &var.getSymbol();
+    converter.getFctCtx().attachCleanup([builder, loc, exv, sym]() {
+      fir::CUDADataAttributeAttr cudaAttr =
+          Fortran::lower::translateSymbolCUDADataAttribute(
+              builder->getContext(), *sym);
+      builder->create<fir::CUDAFreeOp>(loc, fir::getBase(exv), cudaAttr);
+    });
+  }
 }
 
 //===----------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index a6da387637264..bd018d7f015b8 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -176,8 +176,9 @@ mlir::Value fir::FirOpBuilder::createRealConstant(mlir::Location loc,
   llvm_unreachable("should use builtin floating-point type");
 }
 
-static llvm::SmallVector<mlir::Value>
-elideExtentsAlreadyInType(mlir::Type type, mlir::ValueRange shape) {
+llvm::SmallVector<mlir::Value>
+fir::factory::elideExtentsAlreadyInType(mlir::Type type,
+                                        mlir::ValueRange shape) {
   auto arrTy = mlir::dyn_cast<fir::SequenceType>(type);
   if (shape.empty() || !arrTy)
     return {};
@@ -191,8 +192,9 @@ elideExtentsAlreadyInType(mlir::Type type, mlir::ValueRange shape) {
   return dynamicShape;
 }
 
-static llvm::SmallVector<mlir::Value>
-elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams) {
+llvm::SmallVector<mlir::Value>
+fir::factory::elideLengthsAlreadyInType(mlir::Type type,
+                                        mlir::ValueRange lenParams) {
   if (lenParams.empty())
     return {};
   if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(type))
@@ -211,9 +213,9 @@ mlir::Value fir::FirOpBuilder::allocateLocal(
   // Convert the shape extents to `index`, as needed.
   llvm::SmallVector<mlir::Value> indices;
   llvm::SmallVector<mlir::Value> elidedShape =
-      elideExtentsAlreadyInType(ty, shape);
+      fir::factory::elideExtentsAlreadyInType(ty, shape);
   llvm::SmallVector<mlir::Value> elidedLenParams =
-      elideLengthsAlreadyInType(ty, lenParams);
+      fir::factory::elideLengthsAlreadyInType(ty, lenParams);
   auto idxTy = getIndexType();
   for (mlir::Value sh : elidedShape)
     indices.push_back(createConvert(loc, idxTy, sh));
@@ -283,9 +285,9 @@ fir::FirOpBuilder::createTemporary(mlir::Location loc, mlir::Type type,
                                    mlir::ValueRange lenParams,
                                    llvm::ArrayRef<mlir::NamedAttribute> attrs) {
   llvm::SmallVector<mlir::Value> dynamicShape =
-      elideExtentsAlreadyInType(type, shape);
+      fir::factory::elideExtentsAlreadyInType(type, shape);
   llvm::SmallVector<mlir::Value> dynamicLength =
-      elideLengthsAlreadyInType(type, lenParams);
+      fir::factory::elideLengthsAlreadyInType(type, lenParams);
   InsertPoint insPt;
   const bool hoistAlloc = dynamicShape.empty() && dynamicLength.empty();
   if (hoistAlloc) {
@@ -306,9 +308,9 @@ mlir::Value fir::FirOpBuilder::createHeapTemporary(
     mlir::ValueRange shape, mlir::ValueRange lenParams,
     llvm::ArrayRef<mlir::NamedAttribute> attrs) {
   llvm::SmallVector<mlir::Value> dynamicShape =
-      elideExtentsAlreadyInType(type, shape);
+      fir::factory::elideExtentsAlreadyInType(type, shape);
   llvm::SmallVector<mlir::Value> dynamicLength =
-      elideLengthsAlreadyInType(type, lenParams);
+      fir::factory::elideLengthsAlreadyInType(type, lenParams);
   assert(!mlir::isa<fir::ReferenceType>(type) && "cannot be a reference");
   return create<fir::AllocMemOp>(loc, type, /*unique_name=*/llvm::StringRef{},
@@ -660,7 +662,8 @@ mlir::Value fir::FirOpBuilder::createBox(mlir::Location loc, mlir::Type boxType,
   mlir::Type valueOrSequenceType = fir::unwrapPassByRefType(boxType);
   return create<fir::EmboxOp>(
       loc, boxType, addr, shape, slice,
-      elideLengthsAlreadyInType(valueOrSequenceType, lengths), tdesc);
+      fir::factory::elideLengthsAlreadyInType(valueOrSequenceType, lengths),
+      tdesc);
 }
 
 void fir::FirOpBuilder::dumpFunc() { getFunction().dump(); }
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 6773d0adced0c..5e6c18af2dd0f 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4033,6 +4033,21 @@ mlir::LogicalResult fir::CUDADeallocateOp::verify() {
   return mlir::success();
 }
 
+void fir::CUDAAllocOp::build(
+    mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Type inType,
+    llvm::StringRef uniqName, llvm::StringRef bindcName,
+    fir::CUDADataAttributeAttr cudaAttr, mlir::ValueRange typeparams,
+    mlir::ValueRange shape, llvm::ArrayRef<mlir::NamedAttribute> attributes) {
+  mlir::StringAttr nameAttr =
+      uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName);
+  mlir::StringAttr bindcAttr =
+      bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName);
+  build(builder, result, wrapAllocaResultType(inType),
+        mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape,
+        cudaAttr);
+  result.addAttributes(attributes);
+}
+
 //===----------------------------------------------------------------------===//
 // FIROpsDialect
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-data-attribute.cuf b/flang/test/Lower/CUDA/cuda-data-attribute.cuf
index 937c981bddd36..083a3cacc0206 100644
--- a/flang/test/Lower/CUDA/cuda-data-attribute.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-attribute.cuf
@@ -62,4 +62,29 @@ end subroutine
 ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32> {fir.bindc_name = "du", fir.cuda_attr = #fir.cuda<unified>})
 ! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<unified>, uniq_name = "_QMcuda_varFdummy_arg_unifiedEdu"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 
+subroutine cuda_alloc_free(n)
+  integer :: n
+  real, device :: a(10)
+  integer, unified :: u
+  real, managed :: b(n)
+end
+
+! CHECK-LABEL: func.func @_QMcuda_varPcuda_alloc_free
+! CHECK: %[[ALLOC_A:.*]] = fir.cuda_alloc !fir.array<10xf32> {bindc_name = "a", cuda_attr = #fir.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
+! CHECK: %[[SHAPE:.*]] = fir.shape %c10 : (index) -> !fir.shape<1>
+! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[ALLOC_A]](%[[SHAPE]]) {cuda_attr = #fir.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+
+! CHECK: %[[ALLOC_U:.*]] = fir.cuda_alloc i32 {bindc_name = "u", cuda_attr = #fir.cuda<unified>, uniq_name = "_QMcuda_varFcuda_alloc_freeEu"} -> !fir.ref<i32>
+! CHECK: %[[DECL_U:.*]]:2 = hlfir.declare %[[ALLOC_U]] {cuda_attr = #fir.cuda<unified>, uniq_name = "_QMcuda_varFcuda_alloc_freeEu"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+! CHECK: %[[ALLOC_B:.*]] = fir.cuda_alloc !fir.array<?xf32>, %{{.*}} : index {bindc_name = "b", cuda_attr = #fir.cuda<managed>, uniq_name = "_QMcuda_varFcuda_alloc_freeEb"} -> !fir.ref<!fir.array<?xf32>>
+! CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! CHECK: %[[DECL_B:.*]]:2 = hlfir.declare %[[ALLOC_B]](%[[SHAPE]]) {cuda_attr = #fir.cuda<managed>, uniq_name = "_QMcuda_varFcuda_alloc_freeEb"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+
+! CHECK: fir.cuda_free %[[DECL_B]]#1 : !fir.ref<!fir.array<?xf32>> {cuda_attr = #fir.cuda<managed>}
+! CHECK: fir.cuda_free %[[DECL_U]]#1 : !fir.ref<i32> {cuda_attr = #fir.cuda<unified>}
+! CHECK: fir.cuda_free %[[DECL_A]]#1 : !fir.ref<!fir.array<10xf32>> {cuda_attr = #fir.cuda<device>}
+
 end module
+
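Context note (not part of the patch): NeedCUDAAlloc is the predicate that decides whether a local is lowered to a fir.cuda_alloc with a matching fir.cuda_free cleanup instead of a plain fir.alloca. A minimal CUDA Fortran sketch of the distinction it encodes, using a hypothetical subroutine name:

subroutine host_example()
  ! Device/managed/unified locals declared in host code: NeedCUDAAlloc returns
  ! true, so lowering emits fir.cuda_alloc and attaches a fir.cuda_free cleanup.
  real, device :: a(10)
  ! Allocatable device data in host code: only the descriptor is created here
  ! ("Descriptor is allocated on host when in host context" above), so
  ! NeedCUDAAlloc returns false and device memory is expected to come from a
  ! later ALLOCATE statement.
  real, device, allocatable :: b(:)
end subroutine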