Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][cuda] Lower device to host and device to device transfer #87387

Merged
merged 6 commits into from
Apr 5, 2024

Conversation

clementval
Copy link
Contributor

Add more support for CUDA data transfer in assignment. This patch adds device to device and device to host support. If device symbols are present on the rhs, some implicit data transfer are initiated. A temporary is created and the data are transferred to the host. The expression is evaluated on the host and the assignment is done.

Copy link

github-actions bot commented Apr 2, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

flang/include/flang/Evaluate/tools.h Outdated Show resolved Hide resolved
unsigned deviceSymbols = 0;
for (const Symbol &sym : CollectSymbols(expr)) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may want to guard against derived type component symbol here, you may flag them as host while the base entity is on the device.

E.g:

  type t
    integer :: i(10)
  end type
  type(t), device :: tdev
  integer :: host(10)
  host = tdev%i + 10

tdev%i + 10 would likely be flagged as requiring data transfer because CollectSymbols will collect both tdev and i, that I think both have ObjectEntityDetails, but where tdev would have the device attribute and not i (it is a symbol belonging to the type definition scope).

You can detect component symbols using sym.owner().IsDerivedType().

I am assuming the following is not possible (even with pointer/allocatable attribute):

  type t
    integer, device :: i(10)
  end type

The current code is probably OK from a correctness point of view though (doing the transfer before the addition).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I'll update that.

I am assuming the following is not possible (even with pointer/allocatable attribute):

So it turns out this is possible with allocatable component:
From 3.2.1

Members of a derived type may not have the device attribute unless they are allocatable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it turns out this is possible with allocatable component

Then you should not apply my comment for allocatable components. BTW, what is the semantic of an allocatable component without the device attribute in a derived type instance that is device, is the allocatable component considered to be host or device?

 type t
  real, allocatable :: x(:)
 end type
 type(t), device :: a
 ! Is a%x considered to be on the host or on the device?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The semantic is not clearly stated. If members of the derived-type are not allocatable or pointer they are on the device. I would say that this should also apply to allocatable. Talking with Brent it seems that the correct use is to make the derived type managed or unified when it has allocatable components. I'll work with him so that we can have this written in the specs.
I'll add a TODO until we have a more precise semantic on this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for clarifying this!

flang/lib/Lower/Bridge.cpp Outdated Show resolved Hide resolved
flang/lib/Lower/Bridge.cpp Show resolved Hide resolved
flang/lib/Lower/Bridge.cpp Show resolved Hide resolved
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:semantics labels Apr 3, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 3, 2024

@llvm/pr-subscribers-flang-semantics

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Add more support for CUDA data transfer in assignment. This patch adds device to device and device to host support. If device symbols are present on the rhs, some implicit data transfer are initiated. A temporary is created and the data are transferred to the host. The expression is evaluated on the host and the assignment is done.


Full diff: https://github.com/llvm/llvm-project/pull/87387.diff

3 Files Affected:

  • (modified) flang/include/flang/Evaluate/tools.h (+18)
  • (modified) flang/lib/Lower/Bridge.cpp (+86-11)
  • (modified) flang/test/Lower/CUDA/cuda-data-transfer.cuf (+42)
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 8c872a0579c8ed..9a65d89a3333e0 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1240,6 +1240,24 @@ inline bool HasCUDAAttrs(const Expr<SomeType> &expr) {
   return false;
 }
 
+/// Check if the expression is a mix of host and device variables that require
+/// implicit data transfer.
+inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
+  unsigned hostSymbols{0};
+  unsigned deviceSymbols{0};
+  for (const Symbol &sym : CollectSymbols(expr)) {
+    if (const auto *details =
+            sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
+      if (details->cudaDataAttr()) {
+        ++deviceSymbols;
+      } else {
+        ++hostSymbols;
+      }
+    }
+  }
+  return hostSymbols > 0 && deviceSymbols > 0;
+}
+
 } // namespace Fortran::evaluate
 
 namespace Fortran::semantics {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 5bba0978617c79..771da038c66c7f 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3710,16 +3710,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return false;
   }
 
-  static void genCUDADataTransfer(fir::FirOpBuilder &builder,
-                                  mlir::Location loc, bool lhsIsDevice,
-                                  hlfir::Entity &lhs, bool rhsIsDevice,
-                                  hlfir::Entity &rhs) {
+  void genCUDADataTransfer(fir::FirOpBuilder &builder, mlir::Location loc,
+                           const Fortran::evaluate::Assignment &assign,
+                           hlfir::Entity &lhs, hlfir::Entity &rhs) {
+    bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs);
+    bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
     if (rhs.isBoxAddressOrValue() || lhs.isBoxAddressOrValue())
       TODO(loc, "CUDA data transfler with descriptors");
+
+    // device = host
     if (lhsIsDevice && !rhsIsDevice) {
       auto transferKindAttr = fir::CUDADataTransferKindAttr::get(
           builder.getContext(), fir::CUDADataTransferKind::HostDevice);
-      // device = host
       if (!rhs.isVariable()) {
         auto associate = hlfir::genAssociateExpr(
             loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
@@ -3732,7 +3734,71 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       }
       return;
     }
-    TODO(loc, "Assignement with CUDA Fortran variables");
+
+    // host = device
+    if (!lhsIsDevice && rhsIsDevice) {
+      auto transferKindAttr = fir::CUDADataTransferKindAttr::get(
+          builder.getContext(), fir::CUDADataTransferKind::DeviceHost);
+      if (!rhs.isVariable()) {
+        // evaluateRhs loads scalar. Look for the memory reference to be used in
+        // the transfer.
+        if (mlir::isa_and_nonnull<fir::LoadOp>(rhs.getDefiningOp())) {
+          auto loadOp = mlir::dyn_cast<fir::LoadOp>(rhs.getDefiningOp());
+          builder.create<fir::CUDADataTransferOp>(loc, loadOp.getMemref(), lhs,
+                                                  transferKindAttr);
+          return;
+        }
+      } else {
+        builder.create<fir::CUDADataTransferOp>(loc, rhs, lhs,
+                                                transferKindAttr);
+      }
+      return;
+    }
+
+    if (lhsIsDevice && rhsIsDevice) {
+      assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal");
+      auto transferKindAttr = fir::CUDADataTransferKindAttr::get(
+          builder.getContext(), fir::CUDADataTransferKind::DeviceDevice);
+      builder.create<fir::CUDADataTransferOp>(loc, rhs, lhs, transferKindAttr);
+      return;
+    }
+    llvm_unreachable("Unhandled CUDA data transfer");
+  }
+
+  llvm::SmallVector<mlir::Value>
+  genCUDAImplicitDataTransfer(fir::FirOpBuilder &builder, mlir::Location loc,
+                              const Fortran::evaluate::Assignment &assign) {
+    llvm::SmallVector<mlir::Value> temps;
+    localSymbols.pushScope();
+    auto transferKindAttr = fir::CUDADataTransferKindAttr::get(
+        builder.getContext(), fir::CUDADataTransferKind::DeviceHost);
+    unsigned nbDeviceResidentObject = 0;
+    for (const Fortran::semantics::Symbol &sym :
+         Fortran::evaluate::CollectSymbols(assign.rhs)) {
+      if (const auto *details =
+              sym.GetUltimate()
+                  .detailsIf<Fortran::semantics::ObjectEntityDetails>()) {
+        if (details->cudaDataAttr()) {
+          // TODO: This should probably being checked in semantic and give a
+          // proper error.
+          assert(
+              nbDeviceResidentObject <= 1 &&
+              "Only one reference to the device resident object is supported");
+          auto addr = getSymbolAddress(sym);
+          hlfir::Entity entity{addr};
+          auto [temp, cleanup] =
+              hlfir::createTempFromMold(loc, builder, entity);
+          auto needCleanup = fir::getIntIfConstant(cleanup);
+          if (needCleanup && *needCleanup)
+            temps.push_back(temp);
+          addSymbol(sym, temp, /*forced=*/true);
+          builder.create<fir::CUDADataTransferOp>(loc, addr, temp,
+                                                  transferKindAttr);
+          ++nbDeviceResidentObject;
+        }
+      }
+    }
+    return temps;
   }
 
   void genDataAssignment(
@@ -3741,8 +3807,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     mlir::Location loc = getCurrentLocation();
     fir::FirOpBuilder &builder = getFirOpBuilder();
 
-    bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs);
-    bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
+    bool isCUDATransfer = Fortran::evaluate::HasCUDAAttrs(assign.lhs) ||
+                          Fortran::evaluate::HasCUDAAttrs(assign.rhs);
+    bool hasCUDAImplicitTransfer =
+        Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs);
+    llvm::SmallVector<mlir::Value> implicitTemps;
+    if (hasCUDAImplicitTransfer)
+      implicitTemps = genCUDAImplicitDataTransfer(builder, loc, assign);
 
     // Gather some information about the assignment that will impact how it is
     // lowered.
@@ -3800,12 +3871,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       Fortran::lower::StatementContext localStmtCtx;
       hlfir::Entity rhs = evaluateRhs(localStmtCtx);
       hlfir::Entity lhs = evaluateLhs(localStmtCtx);
-      if (lhsIsDevice || rhsIsDevice) {
-        genCUDADataTransfer(builder, loc, lhsIsDevice, lhs, rhsIsDevice, rhs);
-      } else {
+      if (isCUDATransfer && !hasCUDAImplicitTransfer)
+        genCUDADataTransfer(builder, loc, assign, lhs, rhs);
+      else
         builder.create<hlfir::AssignOp>(loc, rhs, lhs,
                                         isWholeAllocatableAssignment,
                                         keepLhsLengthInAllocatableAssignment);
+      if (hasCUDAImplicitTransfer) {
+        localSymbols.popScope();
+        for (mlir::Value temp : implicitTemps)
+          builder.create<fir::FreeMemOp>(loc, temp);
       }
       return;
     }
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 54226b8623e6a9..1bca867e08905a 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -55,3 +55,45 @@ end
 ! CHECK: %[[ASSOC:.*]]:3 = hlfir.associate %[[ELEMENTAL]](%{{.*}}) {uniq_name = ".cuf_host_tmp"} : (!hlfir.expr<10xi32>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, i1)
 ! CHECK: fir.cuda_data_transfer %[[ASSOC]]#0 to %[[ADEV]]#0 {transfer_kind = #fir.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
 ! CHECK: hlfir.end_associate %[[ASSOC]]#1, %[[ASSOC]]#2 : !fir.ref<!fir.array<10xi32>>, i1
+
+subroutine sub2()
+  integer, device :: m
+  integer, device :: adev(10), bdev(10)
+  integer :: i, ahost(10), bhost(10)
+
+  ahost = adev
+
+  i = m
+
+  ahost(1:5) = adev(1:5)
+
+  bdev = adev
+
+  ! Implicit data transfer of adev before evaluation.
+  bhost = ahost + adev
+
+end
+
+! CHECK-LABEL: func.func @_QPsub2()
+! CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {cuda_attr = #fir.cuda<device>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+! CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub2Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+! CHECK: %[[BDEV:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {cuda_attr = #fir.cuda<device>, uniq_name = "_QFsub2Ebdev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+! CHECK: %[[BHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub2Ebhost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+! CHECK: %[[I:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub2Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[M:.*]]:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFsub2Em"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: fir.cuda_data_transfer %[[ADEV]]#0 to %[[AHOST]]#0 {transfer_kind = #fir.cuda_transfer<device_host>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
+! CHECK: fir.cuda_data_transfer %[[M]]#0 to %[[I]]#0 {transfer_kind = #fir.cuda_transfer<device_host>} : !fir.ref<i32>, !fir.ref<i32>
+
+! CHECK: %[[DES_ADEV:.*]] = hlfir.designate %[[ADEV]]#0 (%{{.*}}:%{{.*}}:%{{.*}})  shape %{{.*}} : (!fir.ref<!fir.array<10xi32>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
+! CHECK: %[[DES_AHOST:.*]] = hlfir.designate %[[AHOST]]#0 (%{{.*}}:%{{.*}}:%{{.*}})  shape %{{.*}} : (!fir.ref<!fir.array<10xi32>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
+! CHECK: fir.cuda_data_transfer %[[DES_ADEV]] to %[[DES_AHOST]] {transfer_kind = #fir.cuda_transfer<device_host>} : !fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>
+
+! CHECK: fir.cuda_data_transfer %[[ADEV]]#0 to %[[BDEV]]#0 {transfer_kind = #fir.cuda_transfer<device_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
+
+! CHECK: %[[TEMP:.*]] = fir.allocmem !fir.array<10xi32> {bindc_name = ".tmp", uniq_name = ""}
+! CHECK: %[[DECL_TEMP:.*]]:2 = hlfir.declare %[[TEMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>)
+! CHECK: %[[ADEV_TEMP:.*]]:2 = hlfir.declare %21#0 {cuda_attr = #fir.cuda<device>, uniq_name = "_QFsub2Eadev"} : (!fir.heap<!fir.array<10xi32>>) -> (!fir.heap<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>)
+! CHECK: fir.cuda_data_transfer %[[ADEV]]#1 to %[[DECL_TEMP]]#0 {transfer_kind = #fir.cuda_transfer<device_host>} : !fir.ref<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>
+! CHECK: %[[ELEMENTAL:.*]] = hlfir.elemental %4 unordered : (!fir.shape<1>) -> !hlfir.expr<10xi32> {
+! CHECK: hlfir.assign %[[ELEMENTAL]] to %[[BHOST]]#0 : !hlfir.expr<10xi32>, !fir.ref<!fir.array<10xi32>>
+! CHECK: fir.freemem %[[DECL_TEMP]]#0 : !fir.heap<!fir.array<10xi32>>

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 3, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Add more support for CUDA data transfer in assignment. This patch adds device to device and device to host support. If device symbols are present on the rhs, some implicit data transfer are initiated. A temporary is created and the data are transferred to the host. The expression is evaluated on the host and the assignment is done.


Full diff: https://github.com/llvm/llvm-project/pull/87387.diff

3 Files Affected:

  • (modified) flang/include/flang/Evaluate/tools.h (+18)
  • (modified) flang/lib/Lower/Bridge.cpp (+86-11)
  • (modified) flang/test/Lower/CUDA/cuda-data-transfer.cuf (+42)
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 8c872a0579c8ed..9a65d89a3333e0 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1240,6 +1240,24 @@ inline bool HasCUDAAttrs(const Expr<SomeType> &expr) {
   return false;
 }
 
+/// Check if the expression is a mix of host and device variables that require
+/// implicit data transfer.
+inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
+  unsigned hostSymbols{0};
+  unsigned deviceSymbols{0};
+  for (const Symbol &sym : CollectSymbols(expr)) {
+    if (const auto *details =
+            sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
+      if (details->cudaDataAttr()) {
+        ++deviceSymbols;
+      } else {
+        ++hostSymbols;
+      }
+    }
+  }
+  return hostSymbols > 0 && deviceSymbols > 0;
+}
+
 } // namespace Fortran::evaluate
 
 namespace Fortran::semantics {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 5bba0978617c79..771da038c66c7f 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3710,16 +3710,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return false;
   }
 
-  static void genCUDADataTransfer(fir::FirOpBuilder &builder,
-                                  mlir::Location loc, bool lhsIsDevice,
-                                  hlfir::Entity &lhs, bool rhsIsDevice,
-                                  hlfir::Entity &rhs) {
+  void genCUDADataTransfer(fir::FirOpBuilder &builder, mlir::Location loc,
+                           const Fortran::evaluate::Assignment &assign,
+                           hlfir::Entity &lhs, hlfir::Entity &rhs) {
+    bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs);
+    bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
     if (rhs.isBoxAddressOrValue() || lhs.isBoxAddressOrValue())
       TODO(loc, "CUDA data transfler with descriptors");
+
+    // device = host
     if (lhsIsDevice && !rhsIsDevice) {
       auto transferKindAttr = fir::CUDADataTransferKindAttr::get(
           builder.getContext(), fir::CUDADataTransferKind::HostDevice);
-      // device = host
       if (!rhs.isVariable()) {
         auto associate = hlfir::genAssociateExpr(
             loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
@@ -3732,7 +3734,71 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       }
       return;
     }
-    TODO(loc, "Assignement with CUDA Fortran variables");
+
+    // host = device
+    if (!lhsIsDevice && rhsIsDevice) {
+      auto transferKindAttr = fir::CUDADataTransferKindAttr::get(
+          builder.getContext(), fir::CUDADataTransferKind::DeviceHost);
+      if (!rhs.isVariable()) {
+        // evaluateRhs loads scalar. Look for the memory reference to be used in
+        // the transfer.
+        if (mlir::isa_and_nonnull<fir::LoadOp>(rhs.getDefiningOp())) {
+          auto loadOp = mlir::dyn_cast<fir::LoadOp>(rhs.getDefiningOp());
+          builder.create<fir::CUDADataTransferOp>(loc, loadOp.getMemref(), lhs,
+                                                  transferKindAttr);
+          return;
+        }
+      } else {
+        builder.create<fir::CUDADataTransferOp>(loc, rhs, lhs,
+                                                transferKindAttr);
+      }
+      return;
+    }
+
+    if (lhsIsDevice && rhsIsDevice) {
+      assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal");
+      auto transferKindAttr = fir::CUDADataTransferKindAttr::get(
+          builder.getContext(), fir::CUDADataTransferKind::DeviceDevice);
+      builder.create<fir::CUDADataTransferOp>(loc, rhs, lhs, transferKindAttr);
+      return;
+    }
+    llvm_unreachable("Unhandled CUDA data transfer");
+  }
+
+  llvm::SmallVector<mlir::Value>
+  genCUDAImplicitDataTransfer(fir::FirOpBuilder &builder, mlir::Location loc,
+                              const Fortran::evaluate::Assignment &assign) {
+    llvm::SmallVector<mlir::Value> temps;
+    localSymbols.pushScope();
+    auto transferKindAttr = fir::CUDADataTransferKindAttr::get(
+        builder.getContext(), fir::CUDADataTransferKind::DeviceHost);
+    unsigned nbDeviceResidentObject = 0;
+    for (const Fortran::semantics::Symbol &sym :
+         Fortran::evaluate::CollectSymbols(assign.rhs)) {
+      if (const auto *details =
+              sym.GetUltimate()
+                  .detailsIf<Fortran::semantics::ObjectEntityDetails>()) {
+        if (details->cudaDataAttr()) {
+          // TODO: This should probably being checked in semantic and give a
+          // proper error.
+          assert(
+              nbDeviceResidentObject <= 1 &&
+              "Only one reference to the device resident object is supported");
+          auto addr = getSymbolAddress(sym);
+          hlfir::Entity entity{addr};
+          auto [temp, cleanup] =
+              hlfir::createTempFromMold(loc, builder, entity);
+          auto needCleanup = fir::getIntIfConstant(cleanup);
+          if (needCleanup && *needCleanup)
+            temps.push_back(temp);
+          addSymbol(sym, temp, /*forced=*/true);
+          builder.create<fir::CUDADataTransferOp>(loc, addr, temp,
+                                                  transferKindAttr);
+          ++nbDeviceResidentObject;
+        }
+      }
+    }
+    return temps;
   }
 
   void genDataAssignment(
@@ -3741,8 +3807,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     mlir::Location loc = getCurrentLocation();
     fir::FirOpBuilder &builder = getFirOpBuilder();
 
-    bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs);
-    bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
+    bool isCUDATransfer = Fortran::evaluate::HasCUDAAttrs(assign.lhs) ||
+                          Fortran::evaluate::HasCUDAAttrs(assign.rhs);
+    bool hasCUDAImplicitTransfer =
+        Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs);
+    llvm::SmallVector<mlir::Value> implicitTemps;
+    if (hasCUDAImplicitTransfer)
+      implicitTemps = genCUDAImplicitDataTransfer(builder, loc, assign);
 
     // Gather some information about the assignment that will impact how it is
     // lowered.
@@ -3800,12 +3871,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       Fortran::lower::StatementContext localStmtCtx;
       hlfir::Entity rhs = evaluateRhs(localStmtCtx);
       hlfir::Entity lhs = evaluateLhs(localStmtCtx);
-      if (lhsIsDevice || rhsIsDevice) {
-        genCUDADataTransfer(builder, loc, lhsIsDevice, lhs, rhsIsDevice, rhs);
-      } else {
+      if (isCUDATransfer && !hasCUDAImplicitTransfer)
+        genCUDADataTransfer(builder, loc, assign, lhs, rhs);
+      else
         builder.create<hlfir::AssignOp>(loc, rhs, lhs,
                                         isWholeAllocatableAssignment,
                                         keepLhsLengthInAllocatableAssignment);
+      if (hasCUDAImplicitTransfer) {
+        localSymbols.popScope();
+        for (mlir::Value temp : implicitTemps)
+          builder.create<fir::FreeMemOp>(loc, temp);
       }
       return;
     }
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 54226b8623e6a9..1bca867e08905a 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -55,3 +55,45 @@ end
 ! CHECK: %[[ASSOC:.*]]:3 = hlfir.associate %[[ELEMENTAL]](%{{.*}}) {uniq_name = ".cuf_host_tmp"} : (!hlfir.expr<10xi32>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, i1)
 ! CHECK: fir.cuda_data_transfer %[[ASSOC]]#0 to %[[ADEV]]#0 {transfer_kind = #fir.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
 ! CHECK: hlfir.end_associate %[[ASSOC]]#1, %[[ASSOC]]#2 : !fir.ref<!fir.array<10xi32>>, i1
+
+subroutine sub2()
+  integer, device :: m
+  integer, device :: adev(10), bdev(10)
+  integer :: i, ahost(10), bhost(10)
+
+  ahost = adev
+
+  i = m
+
+  ahost(1:5) = adev(1:5)
+
+  bdev = adev
+
+  ! Implicit data transfer of adev before evaluation.
+  bhost = ahost + adev
+
+end
+
+! CHECK-LABEL: func.func @_QPsub2()
+! CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {cuda_attr = #fir.cuda<device>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+! CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub2Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+! CHECK: %[[BDEV:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {cuda_attr = #fir.cuda<device>, uniq_name = "_QFsub2Ebdev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+! CHECK: %[[BHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub2Ebhost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+! CHECK: %[[I:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub2Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[M:.*]]:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFsub2Em"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: fir.cuda_data_transfer %[[ADEV]]#0 to %[[AHOST]]#0 {transfer_kind = #fir.cuda_transfer<device_host>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
+! CHECK: fir.cuda_data_transfer %[[M]]#0 to %[[I]]#0 {transfer_kind = #fir.cuda_transfer<device_host>} : !fir.ref<i32>, !fir.ref<i32>
+
+! CHECK: %[[DES_ADEV:.*]] = hlfir.designate %[[ADEV]]#0 (%{{.*}}:%{{.*}}:%{{.*}})  shape %{{.*}} : (!fir.ref<!fir.array<10xi32>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
+! CHECK: %[[DES_AHOST:.*]] = hlfir.designate %[[AHOST]]#0 (%{{.*}}:%{{.*}}:%{{.*}})  shape %{{.*}} : (!fir.ref<!fir.array<10xi32>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
+! CHECK: fir.cuda_data_transfer %[[DES_ADEV]] to %[[DES_AHOST]] {transfer_kind = #fir.cuda_transfer<device_host>} : !fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>
+
+! CHECK: fir.cuda_data_transfer %[[ADEV]]#0 to %[[BDEV]]#0 {transfer_kind = #fir.cuda_transfer<device_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
+
+! CHECK: %[[TEMP:.*]] = fir.allocmem !fir.array<10xi32> {bindc_name = ".tmp", uniq_name = ""}
+! CHECK: %[[DECL_TEMP:.*]]:2 = hlfir.declare %[[TEMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>)
+! CHECK: %[[ADEV_TEMP:.*]]:2 = hlfir.declare %21#0 {cuda_attr = #fir.cuda<device>, uniq_name = "_QFsub2Eadev"} : (!fir.heap<!fir.array<10xi32>>) -> (!fir.heap<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>)
+! CHECK: fir.cuda_data_transfer %[[ADEV]]#1 to %[[DECL_TEMP]]#0 {transfer_kind = #fir.cuda_transfer<device_host>} : !fir.ref<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>
+! CHECK: %[[ELEMENTAL:.*]] = hlfir.elemental %4 unordered : (!fir.shape<1>) -> !hlfir.expr<10xi32> {
+! CHECK: hlfir.assign %[[ELEMENTAL]] to %[[BHOST]]#0 : !hlfir.expr<10xi32>, !fir.ref<!fir.array<10xi32>>
+! CHECK: fir.freemem %[[DECL_TEMP]]#0 : !fir.heap<!fir.array<10xi32>>

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing the comments! LGTM

@clementval clementval merged commit 953aa10 into llvm:main Apr 5, 2024
4 checks passed
@clementval clementval deleted the cuf_device_to_host branch April 5, 2024 16:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:semantics flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants