Skip to content

Conversation

@clementval
Copy link
Contributor

@clementval clementval requested a review from wangzpgi October 12, 2025 03:46
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Oct 12, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 12, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

https://docs.nvidia.com/hpc-sdk/compilers/cuda-fortran-prog-guide/#load-and-store-functions-using-bulk-tma-operations


Full diff: https://github.com/llvm/llvm-project/pull/163034.diff

4 Files Affected:

  • (modified) flang/include/flang/Optimizer/Builder/IntrinsicCall.h (+1)
  • (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+33-10)
  • (modified) flang/module/cudadevice.f90 (+11)
  • (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+11)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 0e3c9aa22f994..2adfd6f2510d4 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -458,6 +458,7 @@ struct IntrinsicLibrary {
   mlir::Value genTanpi(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
   void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue>);
   void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genTransfer(mlir::Type,
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 7c5c5fb053ef2..5fe2a76128e0d 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1020,6 +1020,13 @@ static constexpr IntrinsicHandler handlers[]{
      &I::genTMABulkCommitGroup,
      {{}},
      /*isElemental=*/false},
+    {"tma_bulk_g2s",
+     &I::genTMABulkG2S,
+     {{{"barrier", asAddr},
+       {"src", asAddr},
+       {"dst", asAddr},
+       {"nbytes", asValue}}},
+     /*isElemental=*/false},
     {"tma_bulk_wait_group",
      &I::genTMABulkWaitGroup,
      {{}},
@@ -3200,17 +3207,17 @@ IntrinsicLibrary::genAssociated(mlir::Type resultType,
   return fir::runtime::genAssociated(builder, loc, pointerBox, targetBox);
 }
 
-static mlir::Value convertBarrierToLLVM(fir::FirOpBuilder &builder,
-                                        mlir::Location loc,
-                                        mlir::Value barrier) {
+static mlir::Value convertPtrToNVVMSpace(fir::FirOpBuilder &builder,
+                                         mlir::Location loc,
+                                         mlir::Value barrier,
+                                         mlir::NVVM::NVVMMemorySpace space) {
   mlir::Value llvmPtr = fir::ConvertOp::create(
       builder, loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()),
       barrier);
   mlir::Value addrCast = mlir::LLVM::AddrSpaceCastOp::create(
       builder, loc,
-      mlir::LLVM::LLVMPointerType::get(
-          builder.getContext(),
-          static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared)),
+      mlir::LLVM::LLVMPointerType::get(builder.getContext(),
+                                       static_cast<unsigned>(space)),
       llvmPtr);
   return addrCast;
 }
@@ -3220,7 +3227,8 @@ mlir::Value
 IntrinsicLibrary::genBarrierArrive(mlir::Type resultType,
                                    llvm::ArrayRef<mlir::Value> args) {
   assert(args.size() == 1);
-  mlir::Value barrier = convertBarrierToLLVM(builder, loc, args[0]);
+  mlir::Value barrier = convertPtrToNVVMSpace(
+      builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared);
   return mlir::NVVM::MBarrierArriveSharedOp::create(builder, loc, resultType,
                                                     barrier)
       .getResult();
@@ -3231,7 +3239,8 @@ mlir::Value
 IntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType,
                                       llvm::ArrayRef<mlir::Value> args) {
   assert(args.size() == 2);
-  mlir::Value barrier = convertBarrierToLLVM(builder, loc, args[0]);
+  mlir::Value barrier = convertPtrToNVVMSpace(
+      builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared);
   mlir::Value token = fir::AllocaOp::create(builder, loc, resultType);
   // TODO: the MBarrierArriveExpectTxOp is not taking the state argument and
   // currently just the sink symbol `_`.
@@ -3244,8 +3253,8 @@ IntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType,
 // BARRIER_INIT (CUDA)
 void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
   assert(args.size() == 2);
-  mlir::Value barrier =
-      convertBarrierToLLVM(builder, loc, fir::getBase(args[0]));
+  mlir::Value barrier = convertPtrToNVVMSpace(
+      builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared);
   mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, barrier,
                                            fir::getBase(args[1]), {});
   auto kind = mlir::NVVM::ProxyKindAttr::get(
@@ -9204,6 +9213,20 @@ void IntrinsicLibrary::genTMABulkCommitGroup(
   mlir::NVVM::CpAsyncBulkCommitGroupOp::create(builder, loc);
 }
 
+// TMA_BULK_G2S (CUDA)
+void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 4);
+  mlir::Value barrier = convertPtrToNVVMSpace(
+      builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared);
+  mlir::Value dst =
+      convertPtrToNVVMSpace(builder, loc, fir::getBase(args[2]),
+                            mlir::NVVM::NVVMMemorySpace::SharedCluster);
+  mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]),
+                                          mlir::NVVM::NVVMMemorySpace::Global);
+  mlir::NVVM::CpAsyncBulkGlobalToSharedClusterOp::create(
+      builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
+}
+
 // TMA_BULK_WAIT_GROUP (CUDA)
 void IntrinsicLibrary::genTMABulkWaitGroup(
     llvm::ArrayRef<fir::ExtendedValue> args) {
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 106f3e20aeaee..b9e4c2848cd00 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -2023,6 +2023,17 @@ attributes(device) subroutine tma_bulk_wait_group()
     end subroutine
   end interface
 
+  ! Generic load, count is in bytes
+  interface
+    attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
+      !dir$ ignore_tkr src, dst
+      integer(8), shared :: barrier
+      integer(4), device :: src(*)
+      integer(4), shared :: dst(*)
+      integer(4), value  :: nbytes
+    end subroutine
+  end interface
+
 contains
 
   attributes(device) subroutine syncthreads()
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 697b17b2cf2b1..83ee0118638b2 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -437,3 +437,14 @@ end subroutine
 ! CHECK-LABEL: func.func @_QPtest_tma()
 ! CHECK: nvvm.cp.async.bulk.commit.group
 ! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+attributes(global) subroutine test_bulk_g2s(c, a, b, n)
+  real(8), device :: a(*)
+  real(8), shared :: tmpa(1024)
+  integer(8), shared :: barrier1
+  integer(4) :: tx_count
+  call tma_bulk_g2s(barrier1, a(j), tmpa, tx_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_bulk_g2s
+! CHECK: nvvm.cp.async.bulk.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : <7>, <1>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants