diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index f66607dfa22f1..92a701a7b98c7 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3806,16 +3806,34 @@ class FirConverter : public Fortran::lower::AbstractConverter { return temps; } + // Check if the insertion point is currently in a device context. HostDevice + // subprograms are not considered fully device context, so this returns + // false for them. + static bool isDeviceContext(fir::FirOpBuilder &builder) { + if (builder.getRegion().getParentOfType()) + return true; + if (auto funcOp = + builder.getRegion().getParentOfType()) { + if (auto cudaProcAttr = + funcOp.getOperation()->getAttrOfType( + fir::getCUDAAttrName())) { + return cudaProcAttr.getValue() != fir::CUDAProcAttribute::Host && + cudaProcAttr.getValue() != fir::CUDAProcAttribute::HostDevice; + } + } + return false; + } + void genDataAssignment( const Fortran::evaluate::Assignment &assign, const Fortran::evaluate::ProcedureRef *userDefinedAssignment) { mlir::Location loc = getCurrentLocation(); fir::FirOpBuilder &builder = getFirOpBuilder(); - bool isInDeviceContext = - builder.getRegion().getParentOfType(); - bool isCUDATransfer = Fortran::evaluate::HasCUDAAttrs(assign.lhs) || - Fortran::evaluate::HasCUDAAttrs(assign.rhs); + bool isInDeviceContext = isDeviceContext(builder); + bool isCUDATransfer = (Fortran::evaluate::HasCUDAAttrs(assign.lhs) || + Fortran::evaluate::HasCUDAAttrs(assign.rhs)) && + !isInDeviceContext; bool hasCUDAImplicitTransfer = Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs); llvm::SmallVector implicitTemps; @@ -3878,7 +3896,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { Fortran::lower::StatementContext localStmtCtx; hlfir::Entity rhs = evaluateRhs(localStmtCtx); hlfir::Entity lhs = evaluateLhs(localStmtCtx); - if (isCUDATransfer && !hasCUDAImplicitTransfer && !isInDeviceContext) + if (isCUDATransfer && !hasCUDAImplicitTransfer) genCUDADataTransfer(builder, 
loc, assign, lhs, rhs); else builder.create(loc, rhs, lhs, diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf index 70483685d2001..0a2608639bce7 100644 --- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf +++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf @@ -141,3 +141,21 @@ end subroutine ! CHECK: fir.cuda_kernel<<<*, *>>> ! CHECK-NOT: fir.cuda_data_transfer ! CHECK: hlfir.assign + +attributes(global) subroutine sub5(a) + integer, device :: a + integer :: i + a = i +end subroutine + +! CHECK-LABEL: func.func @_QPsub5 +! CHECK-NOT: fir.cuda_data_transfer + +attributes(host,device) subroutine sub6(a) + integer, device :: a + integer :: i + a = i +end subroutine + +! CHECK-LABEL: func.func @_QPsub6 +! CHECK: fir.cuda_data_transfer