Skip to content

Commit

Permalink
Fix align and pass package to thread comp
Browse files Browse the repository at this point in the history
  • Loading branch information
luraess committed Jul 30, 2023
1 parent a98bb96 commit abc049c
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,13 @@ end

function parallel_call_gpu(nblocks::Union{Symbol,Expr}, nthreads::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool, package::Symbol; stream::Union{Symbol,Expr}=default_stream(package), shmem::Union{Symbol,Expr,Nothing}=nothing, launch::Bool=true, configcall::Expr=kernelcall)
maxsize = :( $nblocks .* $nthreads )
ranges = :(ParallelStencil.ParallelKernel.compute_ranges($maxsize))
ranges = :( ParallelStencil.ParallelKernel.compute_ranges($maxsize) )
parallel_call_gpu(ranges, nblocks, nthreads, kernelcall, backend_kwargs_expr, async, package; stream=stream, shmem=shmem, launch=launch)
end

function parallel_call_gpu(ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool, package::Symbol; stream::Union{Symbol,Expr}=default_stream(package), shmem::Union{Symbol,Expr,Nothing}=nothing, launch::Bool=true, configcall::Expr=kernelcall)
maxsize = :(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)))
nthreads = :( ParallelStencil.ParallelKernel.compute_nthreads($maxsize, package) )
maxsize = :( length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)) )
nthreads = :( ParallelStencil.ParallelKernel.compute_nthreads($maxsize, $package) )
nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) )
parallel_call_gpu(ranges, nblocks, nthreads, kernelcall, backend_kwargs_expr, async, package; stream=stream, shmem=shmem, launch=launch)
end
Expand Down Expand Up @@ -284,7 +284,7 @@ end

function parallel_call_threads(nblocks::Union{Symbol,Expr}, nthreads::Union{Symbol,Expr}, kernelcall::Expr, async::Bool; launch::Bool=true, configcall::Expr=kernelcall)
maxsize = :( $nblocks .* $nthreads )
ranges = :(ParallelStencil.ParallelKernel.compute_ranges($maxsize))
ranges = :( ParallelStencil.ParallelKernel.compute_ranges($maxsize) )
parallel_call_threads(ranges, kernelcall, async; launch=launch)
end

Expand Down Expand Up @@ -497,8 +497,9 @@ end

function compute_nthreads(maxsize, package; nthreads_max=NTHREADS_MAX, flatdim=0) # This is a heuristic, which results in (32,8,1) threads, except if maxsize[1] < 32 or maxsize[2] < 8.
maxsize = promote_maxsize(maxsize)
if (package == PKG_CUDA) nthreads_x = min(32, (flatdim==1) ? 1 : maxsize[1])
elseif (package == PKG_AMDGPU) nthreads_x = min(128, (flatdim==1) ? 1 : maxsize[1])
if (Symbol(package) == PKG_CUDA) nthreads_x = min(32, (flatdim==1) ? 1 : maxsize[1])
elseif (Symbol(package) == PKG_AMDGPU) nthreads_x = min(128, (flatdim==1) ? 1 : maxsize[1])
else nthreads_x = min(32, (flatdim==1) ? 1 : maxsize[1])
end
nthreads_y = min(ceil(Int,nthreads_max/nthreads_x), (flatdim==2) ? 1 : maxsize[2])
nthreads_z = min(ceil(Int,nthreads_max/(nthreads_x*nthreads_y)), (flatdim==3) ? 1 : maxsize[3])
Expand Down

0 comments on commit abc049c

Please sign in to comment.