Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -989,9 +989,18 @@ static constexpr IntrinsicHandler handlers[]{
{"mask", asBox, handleDynamicOptional}}},
/*isElemental=*/false},
{"syncthreads", &I::genSyncThreads, {}, /*isElemental=*/false},
{"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
{"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
{"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
{"syncthreads_and_i4", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
{"syncthreads_and_l4", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
{"syncthreads_count_i4",
&I::genSyncThreadsCount,
{},
/*isElemental=*/false},
{"syncthreads_count_l4",
&I::genSyncThreadsCount,
{},
/*isElemental=*/false},
{"syncthreads_or_i4", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
{"syncthreads_or_l4", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
{"syncwarp", &I::genSyncWarp, {}, /*isElemental=*/false},
{"system",
&I::genSystem,
Expand Down
33 changes: 21 additions & 12 deletions flang/module/cudadevice.f90
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,32 @@ module cudadevice
procedure :: syncthreads
end interface

interface
attributes(device) integer function syncthreads_and(value)
integer, value :: value
interface syncthreads_and
attributes(device) integer function syncthreads_and_i4(value)
integer(4), value :: value
end function
end interface
attributes(device) integer function syncthreads_and_l4(value)
logical(4), value :: value
end function
end interface syncthreads_and

interface
attributes(device) integer function syncthreads_count(value)
integer, value :: value
interface syncthreads_count
attributes(device) integer function syncthreads_count_i4(value)
integer(4), value :: value
end function
end interface
attributes(device) integer function syncthreads_count_l4(value)
logical(4), value :: value
end function
end interface syncthreads_count

interface
attributes(device) integer function syncthreads_or(value)
integer, value :: value
interface syncthreads_or
attributes(device) integer function syncthreads_or_i4(value)
integer(4), value :: value
end function
end interface
attributes(device) integer function syncthreads_or_l4(value)
logical(4), value :: value
end function
end interface syncthreads_or

interface
attributes(device) subroutine syncwarp(mask)
Expand Down
26 changes: 22 additions & 4 deletions flang/test/Lower/CUDA/cuda-device-proc.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,23 @@ attributes(global) subroutine devsub()
integer(8) :: al
integer(8) :: time
integer :: smalltime
integer(4) :: res
integer(4) :: res, offset
integer(8) :: resl

integer :: tid
tid = threadIdx%x

call syncthreads()
call syncwarp(1)
call threadfence()
call threadfence_block()
call threadfence_system()
ret = syncthreads_and(1)
res = syncthreads_and(tid > offset)
ret = syncthreads_count(1)
ret = syncthreads_count(tid > offset)
ret = syncthreads_or(1)
ret = syncthreads_or(tid > offset)

ai = atomicadd(ai, 1_4)
al = atomicadd(al, 1_8)
Expand Down Expand Up @@ -100,9 +106,21 @@ end
! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%c1_i32_0) fastmath<contract> : (i32) -> i32
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%c1_i32_1) fastmath<contract> : (i32) -> i32
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%c1_i32_2) fastmath<contract> : (i32) -> i32
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%[[CMP]])
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%[[CMP]]) fastmath<contract> : (i1) -> i32
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%[[CMP]]) fastmath<contract> : (i1) -> i32
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i64
! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, f32
Expand Down