Skip to content

Commit

Permalink
Modify AMDGPU threads heurisitics
Browse files Browse the repository at this point in the history
  • Loading branch information
luraess committed Jul 29, 2023
1 parent 4a23081 commit 13f95ea
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ end

function parallel_call_gpu(ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool, package::Symbol; stream::Union{Symbol,Expr}=default_stream(package), shmem::Union{Symbol,Expr,Nothing}=nothing, launch::Bool=true, configcall::Expr=kernelcall)
maxsize = :(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)))
nthreads = :( ParallelStencil.ParallelKernel.compute_nthreads($maxsize) )
nthreads = :( ParallelStencil.ParallelKernel.compute_nthreads($maxsize, $package) )
nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) )
parallel_call_gpu(ranges, nblocks, nthreads, kernelcall, backend_kwargs_expr, async, package; stream=stream, shmem=shmem, launch=launch)
end
Expand Down Expand Up @@ -495,9 +495,11 @@ function compute_ranges(maxsize)
return (1:maxsize[1], 1:maxsize[2], 1:maxsize[3])
end

function compute_nthreads(maxsize; nthreads_max=NTHREADS_MAX, flatdim=0) # This is a heuristic, which results in (32,8,1) threads, except if maxsize[1] < 32 or maxsize[2] < 8.
function compute_nthreads(maxsize, package; nthreads_max=NTHREADS_MAX, flatdim=0) # This is a heuristic, which results in (32,8,1) threads, except if maxsize[1] < 32 or maxsize[2] < 8.
maxsize = promote_maxsize(maxsize)
nthreads_x = min(32, (flatdim==1) ? 1 : maxsize[1])
if (package == PKG_CUDA) nthreads_x = min(32, (flatdim==1) ? 1 : maxsize[1])
elseif (package == PKG_AMDGPU) nthreads_x = min(128, (flatdim==1) ? 1 : maxsize[1])
end
nthreads_y = min(ceil(Int,nthreads_max/nthreads_x), (flatdim==2) ? 1 : maxsize[2])
nthreads_z = min(ceil(Int,nthreads_max/(nthreads_x*nthreads_y)), (flatdim==3) ? 1 : maxsize[3])
return (nthreads_x, nthreads_y , nthreads_z)
Expand Down

0 comments on commit 13f95ea

Please sign in to comment.