Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flang/include/flang/Optimizer/Builder/IntrinsicCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ struct IntrinsicLibrary {
mlir::Value genBarrierArrive(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genBarrierArriveCnt(mlir::Type, llvm::ArrayRef<mlir::Value>);
void genBarrierInit(llvm::ArrayRef<fir::ExtendedValue>);
// Lowering for the CUDA `barrier_try_wait` intrinsic (barrier, token).
mlir::Value genBarrierTryWait(mlir::Type, llvm::ArrayRef<mlir::Value>);
// Lowering for the CUDA `barrier_try_wait_sleep` intrinsic
// (barrier, token, ns).
mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genBesselJn(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genBesselYn(mlir::Type,
Expand Down
60 changes: 60 additions & 0 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -358,6 +359,14 @@ static constexpr IntrinsicHandler handlers[]{
&I::genBarrierInit,
{{{"barrier", asAddr}, {"count", asValue}}},
/*isElemental=*/false},
{"barrier_try_wait",
&I::genBarrierTryWait,
{{{"barrier", asAddr}, {"token", asValue}}},
/*isElemental=*/false},
{"barrier_try_wait_sleep",
&I::genBarrierTryWaitSleep,
{{{"barrier", asAddr}, {"token", asValue}, {"ns", asValue}}},
/*isElemental=*/false},
{"bessel_jn",
&I::genBesselJn,
{{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}},
Expand Down Expand Up @@ -3282,6 +3291,57 @@ void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space);
}

// BARRIER_TRY_WAIT (CUDA)
//
// Lowers `barrier_try_wait(barrier, token)` to a polling loop around an
// inline-PTX `mbarrier.try_wait`:
//
//   status = 0
//   while (status == 0)
//     status = <mbarrier.try_wait barrier, token, 1000000>
//
// `args[0]` is the barrier address, `args[1]` the token and `resultType` the
// integer type of the status. The returned value is the result of the
// scf.while, i.e. the first non-zero status produced by the PTX sequence.
mlir::Value
IntrinsicLibrary::genBarrierTryWait(mlir::Type resultType,
                                    llvm::ArrayRef<mlir::Value> args) {
  assert(args.size() == 2);
  mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
  // Fixed third operand of the PTX instruction; it occupies the same slot as
  // the `ns` argument of barrier_try_wait_sleep.
  // NOTE(review): presumably a time hint in nanoseconds -- confirm against
  // the PTX ISA.
  mlir::Value ns =
      builder.createIntegerConstant(loc, builder.getI32Type(), 1000000);

  // Seed the loop-carried status directly with the constant 0. Spilling it to
  // a temporary and immediately reloading it yields the same value.
  auto whileOp = mlir::scf::WhileOp::create(
      builder, loc, mlir::TypeRange{resultType}, mlir::ValueRange{zero});

  // "before" region: keep iterating while the status is still 0. The status
  // starts at 0, so the predicate must be `eq`; with `ne` the condition is
  // false on entry and the try_wait below would never be executed.
  mlir::Block *beforeBlock = builder.createBlock(&whileOp.getBefore());
  mlir::Value status = beforeBlock->addArgument(resultType, loc);
  builder.setInsertionPointToStart(beforeBlock);
  mlir::Value keepWaiting = mlir::arith::CmpIOp::create(
      builder, loc, mlir::arith::CmpIPredicate::eq, status, zero);
  mlir::scf::ConditionOp::create(builder, loc, keepWaiting, status);

  // "after" region: one try_wait attempt. The incoming block argument is
  // required by scf.while but is not needed here, since every attempt
  // produces a fresh status.
  mlir::Block *afterBlock = builder.createBlock(&whileOp.getAfter());
  afterBlock->addArgument(resultType, loc);
  builder.setInsertionPointToStart(afterBlock);
  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
  mlir::Value barrier = builder.createConvert(loc, llvmPtrTy, args[0]);
  // selp turns the predicate written by try_wait into 1 or 0 in %0.
  mlir::Value attempt =
      mlir::NVVM::InlinePtxOp::create(
          builder, loc, {resultType}, {barrier, args[1], ns}, {},
          ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; "
          "selp.b32 %0, 1, 0, p;",
          {})
          .getResult(0);
  mlir::scf::YieldOp::create(builder, loc, attempt);

  // Resume emitting code after the loop and hand back its final status.
  builder.setInsertionPointAfter(whileOp);
  return whileOp.getResult(0);
}

// BARRIER_TRY_WAIT_SLEEP (CUDA)
//
// Emits a single inline-PTX `mbarrier.try_wait` for
// `barrier_try_wait_sleep(barrier, token, ns)`. The barrier address
// (`args[0]`) is converted to an LLVM pointer; the token (`args[1]`) and
// `ns` (`args[2]`) are forwarded unchanged. The result, of type
// `resultType`, is the 1/0 value that `selp` derives from the predicate.
mlir::Value
IntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType,
                                         llvm::ArrayRef<mlir::Value> args) {
  assert(args.size() == 3);
  mlir::Value barrierPtr = builder.createConvert(
      loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()), args[0]);
  mlir::Value token = args[1];
  mlir::Value sleepNs = args[2];
  auto tryWait = mlir::NVVM::InlinePtxOp::create(
      builder, loc, {resultType}, {barrierPtr, token, sleepNs}, {},
      ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; "
      "selp.b32 %0, 1, 0, p;",
      {});
  return tryWait.getResult(0);
}

// BESSEL_JN
fir::ExtendedValue
IntrinsicLibrary::genBesselJn(mlir::Type resultType,
Expand Down
27 changes: 21 additions & 6 deletions flang/module/cudadevice.f90
Original file line number Diff line number Diff line change
Expand Up @@ -1998,22 +1998,37 @@ attributes(device,host) logical function on_device() bind(c)

! TMA Operations

interface barrier_arrive
attributes(device) function barrier_arrive(barrier) result(token)
integer(8), shared :: barrier
integer(8) :: token
end function
attributes(device) function barrier_arrive_cnt(barrier, count) result(token)
integer(8), shared :: barrier
integer(4), value :: count
integer(8) :: token
end function
end interface

interface
attributes(device) subroutine barrier_init(barrier, count)
integer(8), shared :: barrier
integer(4), value :: count
end subroutine
end interface

interface barrier_arrive
attributes(device) function barrier_arrive(barrier) result(token)
interface
attributes(device) integer function barrier_try_wait(barrier, token)
integer(8), shared :: barrier
integer(8) :: token
integer(8), value :: token
end function
attributes(device) function barrier_arrive_cnt(barrier, count) result(token)
end interface

interface
attributes(device) integer function barrier_try_wait_sleep(barrier, token, ns)
integer(8), shared :: barrier
integer(4), value :: count
integer(8) :: token
integer(8), value :: token
integer(4), value :: ns
end function
end interface

Expand Down
22 changes: 22 additions & 0 deletions flang/test/Lower/CUDA/cuda-device-proc.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -492,3 +492,25 @@ end subroutine
! CHECK: %[[CASTED_CMP_XCHG_EV:.*]] = fir.convert %[[CMP_XCHG_EV]] : (i1) -> i32
! CHECK: %{{.*}} = arith.constant 1 : i32
! CHECK: %19 = arith.cmpi eq, %[[CASTED_CMP_XCHG_EV]], %{{.*}} : i32

! Lowering test kernel: calls barrier_try_wait on a shared integer(8) barrier
! with an integer(8) token and stores the integer result in istat. The
! variables are intentionally left undefined; only the generated IR matters.
attributes(global) subroutine test_barrier_try_wait()
integer :: istat
integer(8), shared :: barrier1
integer(8) :: token
istat = barrier_try_wait(barrier1, token)
end subroutine

! CHECK-LABEL: func.func @_QPtest_barrier_try_wait()
! CHECK: scf.while
! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %{{.*}}, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %c1000000{{.*}} : !llvm.ptr, i64, i32) -> i32

! Lowering test kernel: calls barrier_try_wait_sleep on a shared integer(8)
! barrier with an integer(8) token and an integer(4) third argument, storing
! the integer result in istat. Only the generated IR matters here.
attributes(global) subroutine test_barrier_try_wait_sleep()
integer :: istat
integer(8), shared :: barrier1
integer(8) :: token
integer(4) :: sleep_time
istat = barrier_try_wait_sleep(barrier1, token, sleep_time)
end subroutine

! CHECK-LABEL: func.func @_QPtest_barrier_try_wait_sleep()
! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %0, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, i64, i32) -> i32