Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Builder/IntrinsicCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ struct IntrinsicLibrary {
mlir::Value genTanpi(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genTransfer(mlir::Type,
Expand Down
43 changes: 33 additions & 10 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,13 @@ static constexpr IntrinsicHandler handlers[]{
&I::genTMABulkCommitGroup,
{{}},
/*isElemental=*/false},
{"tma_bulk_g2s",
&I::genTMABulkG2S,
{{{"barrier", asAddr},
{"src", asAddr},
{"dst", asAddr},
{"nbytes", asValue}}},
/*isElemental=*/false},
{"tma_bulk_wait_group",
&I::genTMABulkWaitGroup,
{{}},
Expand Down Expand Up @@ -3200,17 +3207,17 @@ IntrinsicLibrary::genAssociated(mlir::Type resultType,
return fir::runtime::genAssociated(builder, loc, pointerBox, targetBox);
}

static mlir::Value convertBarrierToLLVM(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value barrier) {
static mlir::Value convertPtrToNVVMSpace(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value barrier,
mlir::NVVM::NVVMMemorySpace space) {
mlir::Value llvmPtr = fir::ConvertOp::create(
builder, loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()),
barrier);
mlir::Value addrCast = mlir::LLVM::AddrSpaceCastOp::create(
builder, loc,
mlir::LLVM::LLVMPointerType::get(
builder.getContext(),
static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared)),
mlir::LLVM::LLVMPointerType::get(builder.getContext(),
static_cast<unsigned>(space)),
llvmPtr);
return addrCast;
}
Expand All @@ -3220,7 +3227,8 @@ mlir::Value
IntrinsicLibrary::genBarrierArrive(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 1);
mlir::Value barrier = convertBarrierToLLVM(builder, loc, args[0]);
mlir::Value barrier = convertPtrToNVVMSpace(
builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared);
return mlir::NVVM::MBarrierArriveSharedOp::create(builder, loc, resultType,
barrier)
.getResult();
Expand All @@ -3231,7 +3239,8 @@ mlir::Value
IntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 2);
mlir::Value barrier = convertBarrierToLLVM(builder, loc, args[0]);
mlir::Value barrier = convertPtrToNVVMSpace(
builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared);
mlir::Value token = fir::AllocaOp::create(builder, loc, resultType);
// TODO: the MBarrierArriveExpectTxOp is not taking the state argument and
// currently just the sink symbol `_`.
Expand All @@ -3244,8 +3253,8 @@ IntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType,
// BARRIER_INIT (CUDA)
void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 2);
mlir::Value barrier =
convertBarrierToLLVM(builder, loc, fir::getBase(args[0]));
mlir::Value barrier = convertPtrToNVVMSpace(
builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared);
mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, barrier,
fir::getBase(args[1]), {});
auto kind = mlir::NVVM::ProxyKindAttr::get(
Expand Down Expand Up @@ -9204,6 +9213,20 @@ void IntrinsicLibrary::genTMABulkCommitGroup(
mlir::NVVM::CpAsyncBulkCommitGroupOp::create(builder, loc);
}

// TMA_BULK_G2S (CUDA)
void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 4);
mlir::Value barrier = convertPtrToNVVMSpace(
builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared);
mlir::Value dst =
convertPtrToNVVMSpace(builder, loc, fir::getBase(args[2]),
mlir::NVVM::NVVMMemorySpace::SharedCluster);
mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]),
mlir::NVVM::NVVMMemorySpace::Global);
mlir::NVVM::CpAsyncBulkGlobalToSharedClusterOp::create(
builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
}

// TMA_BULK_WAIT_GROUP (CUDA)
void IntrinsicLibrary::genTMABulkWaitGroup(
llvm::ArrayRef<fir::ExtendedValue> args) {
Expand Down
11 changes: 11 additions & 0 deletions flang/module/cudadevice.f90
Original file line number Diff line number Diff line change
Expand Up @@ -2023,6 +2023,17 @@ attributes(device) subroutine tma_bulk_wait_group()
end subroutine
end interface

! Generic load, count is in bytes
interface
attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
!dir$ ignore_tkr src, dst
integer(8), shared :: barrier
integer(4), device :: src(*)
integer(4), shared :: dst(*)
integer(4), value :: nbytes
end subroutine
end interface

contains

attributes(device) subroutine syncthreads()
Expand Down
11 changes: 11 additions & 0 deletions flang/test/Lower/CUDA/cuda-device-proc.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,14 @@ end subroutine
! CHECK-LABEL: func.func @_QPtest_tma()
! CHECK: nvvm.cp.async.bulk.commit.group
! CHECK: nvvm.cp.async.bulk.wait_group 0

attributes(global) subroutine test_bulk_g2s(c, a, b, n)
real(8), device :: a(*)
real(8), shared :: tmpa(1024)
integer(8), shared :: barrier1
integer(4) :: tx_count
call tma_bulk_g2s(barrier1, a(j), tmpa, tx_count)
end subroutine

! CHECK-LABEL: func.func @_QPtest_bulk_g2s
! CHECK: nvvm.cp.async.bulk.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : <7>, <1>