diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 2adfd6f2510d4..c3cd119b96174 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -459,6 +459,7 @@ struct IntrinsicLibrary {
   mlir::Value genTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
   void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
   void genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue>);
   void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genTransfer(mlir::Type,
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 5fe2a76128e0d..e07baafcef0d7 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1027,6 +1027,10 @@ static constexpr IntrinsicHandler handlers[]{
        {"dst", asAddr},
        {"nbytes", asValue}}},
      /*isElemental=*/false},
+    {"tma_bulk_s2g",
+     &I::genTMABulkS2G,
+     {{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
+     /*isElemental=*/false},
     {"tma_bulk_wait_group",
      &I::genTMABulkWaitGroup,
      {{}},
@@ -9227,6 +9231,20 @@ void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {
       builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
 }
 
+// TMA_BULK_S2G (CUDA)
+// Lowers `call tma_bulk_s2g(src, dst, nbytes)` to
+// nvvm.cp.async.bulk.global.shared.cta: `src` is cast to the NVVM shared
+// address space, `dst` to the global one, and `nbytes` is forwarded unchanged.
+void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 3);
+  mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[0]),
+                                          mlir::NVVM::NVVMMemorySpace::Shared);
+  mlir::Value dst = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]),
+                                          mlir::NVVM::NVVMMemorySpace::Global);
+  mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(
+      builder, loc, dst, src, fir::getBase(args[2]), {}, {});
+}
+
 // TMA_BULK_WAIT_GROUP (CUDA)
 void IntrinsicLibrary::genTMABulkWaitGroup(
     llvm::ArrayRef<fir::ExtendedValue> args) {
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index a8b9aa8b57ef9..22df9cdf410d5 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -2034,6 +2034,15 @@ attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
     end subroutine
   end interface
 
+  interface
+    attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
+      !dir$ ignore_tkr src, dst
+      integer(4), shared :: src(*)
+      integer(4), device :: dst(*)
+      integer(4), value :: nbytes
+    end subroutine
+  end interface
+
 contains
 
   attributes(device) subroutine syncthreads()
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 83ee0118638b2..29c348c5260a5 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -438,7 +438,7 @@ end subroutine
 ! CHECK: nvvm.cp.async.bulk.commit.group
 ! CHECK: nvvm.cp.async.bulk.wait_group 0
 
-attributes(global) subroutine test_bulk_g2s(c, a, b, n)
+attributes(global) subroutine test_bulk_g2s(a)
   real(8), device :: a(*)
   real(8), shared :: tmpa(1024)
   integer(8), shared :: barrier1
@@ -448,3 +448,13 @@ end subroutine
 
 ! CHECK-LABEL: func.func @_QPtest_bulk_g2s
 ! CHECK: nvvm.cp.async.bulk.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : <7>, <1>
+
+attributes(global) subroutine test_bulk_s2g(a)
+  real(8), device :: a(*)
+  real(8), shared :: tmpa(1024)
+  integer(4) :: tx_count
+  call tma_bulk_s2g(tmpa, a(j), tx_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_bulk_s2g
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>