Skip to content

Commit

Permalink
Merge pull request #134 from omlins/ndims-gen
Browse files Browse the repository at this point in the history
Enable ndims expansion and interpolation into kernel signature
  • Loading branch information
omlins committed Jan 4, 2024
2 parents 65233fb + d38781f commit 63f0d8f
Show file tree
Hide file tree
Showing 11 changed files with 171 additions and 49 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/ci.yml
Expand Up @@ -12,9 +12,10 @@ jobs:
fail-fast: false
matrix:
version:
# - '1.8' # Minimum required Julia version (due to CellArrays' AMDGPU dependency 1.7 and due to Enzyme 1.8).
#- '1.8' # Minimum required Julia version (due to CellArrays' AMDGPU dependency 1.7 and due to Enzyme 1.8).
- '1.9'
- '1' # Latest stable 1.x release of Julia
# - 'nightly'
#- 'nightly'
os:
- ubuntu-latest
- macOS-latest
Expand Down
4 changes: 3 additions & 1 deletion src/ParallelKernel/parallel.jl
Expand Up @@ -40,10 +40,12 @@ const PARALLEL_INDICES_DOC = """
@parallel_indices indices kernel
@parallel_indices indices inbounds=... kernel
Declare the `kernel` parallel and generate the given parallel `indices` inside the `kernel` using the package for parallelization selected with [`@init_parallel_kernel`](@ref).
# Keyword arguments
- `inbounds::Bool`: whether to apply `@inbounds` to the kernel. The default is `false` or as set with the `inbounds` keyword argument of [`@init_parallel_kernel`](@ref).
Declare the `kernel` parallel and generate the given parallel `indices` inside the `kernel` using the package for parallelization selected with [`@init_parallel_kernel`](@ref).
See also: [`@init_parallel_kernel`](@ref)
"""
@doc PARALLEL_INDICES_DOC
macro parallel_indices(args...) check_initialized(__module__); checkargs_parallel_indices(args...); esc(parallel_indices(__module__, args...)); end
Expand Down
26 changes: 25 additions & 1 deletion src/ParallelKernel/shared.jl
Expand Up @@ -130,6 +130,7 @@ get_body(kernel::Expr) = return kernel.args[2]
set_body!(kernel::Expr, body::Expr) = ((kernel.args[2] = body); return kernel)
get_name(kernel::Expr) = return splitdef(kernel)[:name]


function set_name(kernel::Expr, name::Symbol)
kernel_elems = splitdef(kernel)
kernel_elems[:name] = name
Expand All @@ -144,6 +145,29 @@ function push_to_signature!(kernel::Expr, arg::Expr)
return kernel
end

function substitute_in_kernel(kernel::Expr, old, new; signature_only=false)
if signature_only
kernel_elems = splitdef(kernel)
body = kernel_elems[:body] # save to restore later
kernel_elems[:body] = :(return)
kernel = combinedef(kernel_elems)
end
kernel = substitute(kernel, old, new)
if signature_only
kernel_elems = splitdef(kernel)
kernel_elems[:body] = body
kernel = combinedef(kernel_elems)
end
return kernel
end

function in_signature(kernel::Expr, x)
kernel_elems = splitdef(kernel)
kernel_elems[:body] = :()
signature = combinedef(kernel_elems)
return inexpr_walk(signature, x)
end

function remove_return(body::Expr)
if !(body.args[end] in [:(return), :(return nothing), :(nothing)])
@ArgumentError("invalid kernel in @parallel kernel definition: the last statement must be a `return nothing` statement ('return' or 'return nothing' or 'nothing') as required for any GPU kernels.")
Expand Down Expand Up @@ -262,7 +286,7 @@ end

function split_kwargs(kwargs; separator=:(=), keyword_type=Symbol)
if !all(is_kwarg.(kwargs; separator=separator, keyword_type=keyword_type)) @ModuleInternalError("not all of kwargs are keyword arguments.") end
return Dict(x.args[1] => x.args[2] for x in kwargs)
return Dict{keyword_type,Any}(x.args[1] => x.args[2] for x in kwargs)
end

function validate_kwargkeys(kwargs::Dict, valid_kwargs::Tuple, macroname::String)
Expand Down
105 changes: 75 additions & 30 deletions src/parallel.jl
Expand Up @@ -9,7 +9,8 @@ Declare the `kernel` parallel and containing stencil computations be performed w
# Optional keyword arguments
- `inbounds::Bool`: whether to apply `@inbounds` to the kernel. The default is `false` or as set with the `inbounds` keyword argument of [`@init_parallel_stencil`](@ref).
- `memopt::Bool=false`: whether to perform advanced stencil-specific on-chip memory optimisations. If `memopt=true` is set, then it must also be set in the corresponding kernel call(s).
- `ndims::Integer`: the number of dimensions used for the stencil computations in the kernels: 1, 2 or 3. A default can be set with the `ndims` keyword argument of [`@init_parallel_stencil`](@ref).
!!! note "Advanced optional keyword arguments"
- `ndims::Integer|Tuple`: the number of dimensions used for the stencil computations in the kernels: 1, 2 or 3 (or a tuple containing any of the previous in order to generate a method for each of the given values - this can only work correctly if the macros used *and loaded* work for any of the chosen values of `ndims`!). A default can be set with the `ndims` keyword argument of [`@init_parallel_stencil`](@ref). The value of `ndims` can be interpolated into the kernel method signatures with `\$ndims` (e.g., `@parallel ndims=(1,3) function f(A::Data.Array{\$ndims}) ... end`). This enables dispatching on the number of dimensions in the kernel methods.
See also: [`@init_parallel_stencil`](@ref)
Expand Down Expand Up @@ -53,8 +54,20 @@ See also: [`@init_parallel_kernel`](@ref)
macro parallel(args...) check_initialized(__module__); checkargs_parallel(args...); esc(parallel(__source__, __module__, args...)); end


##
const PARALLEL_INDICES_DOC = """
$(replace(ParallelKernel.PARALLEL_INDICES_DOC, "@init_parallel_kernel" => "@init_parallel_stencil")) Using splat syntax for the `indices` (e.g., `@parallel_indices (I...)`) enables to generate a tuple of parallel indices (`I` in this example) of length `ndims` selected with [`@init_parallel_stencil`](@ref). This makes it possible to write kernels that are agnostic to the number of dimensions (writing, e.g., `A[I...]` to access elements of the array `A`).
@parallel_indices indices kernel
@parallel_indices indices inbounds=... memopt=... ndims=... kernel
Declare the `kernel` parallel and generate the given parallel `indices` inside the `kernel` using the package for parallelization selected with [`@init_parallel_stencil`](@ref).
# Optional keyword arguments
- `inbounds::Bool`: whether to apply `@inbounds` to the kernel. The default is `false` or as set with the `inbounds` keyword argument of [`@init_parallel_stencil`](@ref).
- `memopt::Bool=false`: whether to perform advanced stencil-specific on-chip memory optimisations. If `memopt=true` is set, then it must also be set in the corresponding kernel call(s).
!!! note "Advanced optional keyword arguments"
- `ndims::Integer|Tuple`: the number of indexing dimensions desired when using splat syntax for the `indices`: 1, 2 or 3 (a default `ndims` value can be set with the corresponding keyword argument of [`@init_parallel_stencil`](@ref)), or a tuple containing any of the previous in order to generate a method for each of the given `ndims` values (e.g., `@parallel_indices (I...) ndims=(2,3)`). Concretely, the splat syntax generates a tuple of parallel indices (`I` in this example) where the length is given by the `ndims` value (`2` for the first method and `3` for the second method in this example). This makes it possible to write kernels that are agnostic to the number of dimensions (writing, e.g., `A[I...]` to access elements of the array `A`). The value of `ndims` can be interpolated into the kernel method signatures with `\$ndims` (e.g., `@parallel ndims=(1,3) function f(A::Data.Array{\$ndims}) ... end`). This enables dispatching on the number of dimensions in the kernel methods.
See also: [`@init_parallel_stencil`](@ref)
"""
@doc PARALLEL_INDICES_DOC
macro parallel_indices(args...) check_initialized(__module__); checkargs_parallel_indices(args...); esc(parallel_indices(__source__, __module__, args...)); end
Expand Down Expand Up @@ -111,15 +124,25 @@ parallel_async(source::LineNumberNode, caller::Module, args::Union{Symbol,Expr}.
function parallel(source::LineNumberNode, caller::Module, args::Union{Symbol,Expr}...; package::Symbol=get_package(caller), async::Bool=false)
if is_kernel(args[end])
posargs, kwargs_expr, kernelarg = split_parallel_args(args, is_call=false)
kwargs = extract_kwargs(caller, kwargs_expr, (:ndims, :inbounds, :memopt, :optvars, :loopdim, :loopsize, :optranges, :useshmemhalos, :optimize_halo_read, :metadata_module, :metadata_function), "@parallel <kernel>"; eval_args=(:ndims, :inbounds, :memopt, :loopdim, :optranges, :useshmemhalos, :optimize_halo_read, :metadata_module))
numbertype = get_numbertype(caller)
if !haskey(kwargs, :metadata_module)
get_name(kernelarg)
metadata_module, metadata_function = create_metadata_storage(source, caller, kernelarg)
kwargs = extract_kwargs(caller, kwargs_expr, (:ndims, :inbounds, :memopt, :optvars, :loopdim, :loopsize, :optranges, :useshmemhalos, :optimize_halo_read, :metadata_module, :metadata_function), "@parallel <kernel>"; eval_args=(:ndims, :inbounds, :memopt, :loopdim, :optranges, :useshmemhalos, :optimize_halo_read, :metadata_module))
ndims = haskey(kwargs, :ndims) ? kwargs.ndims : get_ndims(caller)
is_parallel_kernel = true
if typeof(ndims) <: Tuple
expand_ndims_tuple(ndims, is_parallel_kernel, kernelarg, kwargs, posargs...)
else
metadata_module, metadata_function = kwargs.metadata_module, kwargs.metadata_function
if in_signature(kernelarg, :($(Expr(:$, :ndims))))
interpolate_ndims(ndims, is_parallel_kernel, kernelarg, kwargs_expr, posargs...)
else
numbertype = get_numbertype(caller)
if !haskey(kwargs, :metadata_module)
get_name(kernelarg)
metadata_module, metadata_function = create_metadata_storage(source, caller, kernelarg)
else
metadata_module, metadata_function = kwargs.metadata_module, kwargs.metadata_function
end
parallel_kernel(metadata_module, metadata_function, caller, package, ndims, numbertype, kernelarg, posargs...; kwargs)
end
end
parallel_kernel(metadata_module, metadata_function, caller, package, numbertype, kernelarg, posargs...; kwargs)
elseif is_call(args[end])
posargs, kwargs_expr, kernelarg = split_parallel_args(args)
kwargs, backend_kwargs_expr = extract_kwargs(caller, kwargs_expr, (:memopt, :configcall, :∇, :ad_mode, :ad_annotations), "@parallel <kernelcall>", true; eval_args=(:memopt,))
Expand All @@ -146,35 +169,58 @@ function parallel_indices(source::LineNumberNode, caller::Module, args::Union{Sy
posargs, kwargs_expr, kernelarg = split_parallel_args(args, is_call=false)
kwargs = extract_kwargs(caller, kwargs_expr, (:ndims, :inbounds, :memopt, :optvars, :loopdim, :loopsize, :optranges, :useshmemhalos, :optimize_halo_read, :metadata_module, :metadata_function), "@parallel_indices"; eval_args=(:ndims, :inbounds, :memopt, :loopdim, :optranges, :useshmemhalos, :optimize_halo_read, :metadata_module))
indices_expr = posargs[1]
if is_splatarg(indices_expr)
parallel_indices_splatarg(caller, package, kwargs_expr, posargs..., kernelarg; kwargs)
ndims = haskey(kwargs, :ndims) ? kwargs.ndims : get_ndims(caller)
if typeof(ndims) <: Tuple
expand_ndims_tuple(ndims, is_parallel_kernel, kernelarg, kwargs, posargs...)
else
if !haskey(kwargs, :metadata_module)
get_name(kernelarg)
metadata_module, metadata_function = create_metadata_storage(source, caller, kernelarg)
if in_signature(kernelarg, :($(Expr(:$, :ndims))))
interpolate_ndims(ndims, is_parallel_kernel, kernelarg, kwargs_expr, posargs...)
elseif is_splatarg(indices_expr)
parallel_indices_splatarg(caller, package, ndims, kwargs_expr, posargs..., kernelarg; kwargs)
else
metadata_module, metadata_function = kwargs.metadata_module, kwargs.metadata_function
end
inbounds = haskey(kwargs, :inbounds) ? kwargs.inbounds : get_inbounds(caller)
memopt = haskey(kwargs, :memopt) ? kwargs.memopt : get_memopt(caller)
if memopt
quote
$(parallel_indices_memopt(metadata_module, metadata_function, is_parallel_kernel, caller, package, posargs..., kernelarg; kwargs...)) #TODO: the package and numbertype will have to be passed here further once supported as kwargs (currently removed from call: package, numbertype, )
$metadata_function
if !haskey(kwargs, :metadata_module)
get_name(kernelarg)
metadata_module, metadata_function = create_metadata_storage(source, caller, kernelarg)
else
metadata_module, metadata_function = kwargs.metadata_module, kwargs.metadata_function
end
inbounds = haskey(kwargs, :inbounds) ? kwargs.inbounds : get_inbounds(caller)
memopt = haskey(kwargs, :memopt) ? kwargs.memopt : get_memopt(caller)
if memopt
quote
$(parallel_indices_memopt(metadata_module, metadata_function, is_parallel_kernel, caller, package, posargs..., kernelarg; kwargs...)) #TODO: the package and numbertype will have to be passed here further once supported as kwargs (currently removed from call: package, numbertype, )
$metadata_function
end
else
kwargs_expr = :(inbounds=$inbounds)
ParallelKernel.parallel_indices(caller, posargs..., kwargs_expr, kernelarg; package=package)
end
else
kwargs_expr = :(inbounds=$inbounds)
ParallelKernel.parallel_indices(caller, posargs..., kwargs_expr, kernelarg; package=package)
end
end
end


## @PARALLEL KERNEL FUNCTIONS

function parallel_indices_splatarg(caller::Module, package::Symbol, kwargs_expr, alias_indices::Expr, kernel::Expr; kwargs::NamedTuple)
function expand_ndims_tuple(ndims::Tuple, is_parallel_kernel::Bool, kernel::Expr, kwargs::NamedTuple, posargs...)
if !(typeof(ndims) <: NTuple{N,<:Integer} where N) @ArgumentError("$macroname: argument 'ndims' must be an integer or a tuple of integers (obtained: $ndims).") end
kwargs_expr = (:($key=$(getproperty(kwargs, key))) for key in keys(kwargs) if key != :ndims)
if (is_parallel_kernel) ndims_methods_expr = (:(@parallel $(posargs...) ndims=$i $(kwargs_expr...) $kernel) for i in ndims)
else ndims_methods_expr = (:(@parallel_indices $(posargs...) ndims=$i $(kwargs_expr...) $kernel) for i in ndims)
end
return quote $(ndims_methods_expr...) end
end

function interpolate_ndims(ndims::Integer, is_parallel_kernel::Bool, kernel::Expr, kwargs_expr, posargs...)
if (ndims < 1 || ndims > 3) @ArgumentError("$macroname: argument 'ndims' is invalid or missing (valid values are 1, 2 or 3; 'ndims' an be set globally in @init_parallel_stencil and overwritten per kernel if needed).") end
kernel = substitute_in_kernel(kernel, :($(Expr(:$, :ndims))), ndims; signature_only=true)
if (is_parallel_kernel) return :(@parallel $(posargs...) $(kwargs_expr...) $kernel)
else return :(@parallel_indices $(posargs...) $(kwargs_expr...) $kernel)
end
end

function parallel_indices_splatarg(caller::Module, package::Symbol, ndims::Integer, kwargs_expr, alias_indices::Expr, kernel::Expr; kwargs::NamedTuple)
if !@capture(alias_indices, (I_...)) @ArgumentError("@parallel_indices: argument 'indices' must be a tuple of indices, a single index or a variable followed by the splat operator representing a tuple of indices (e.g. (ix, iy, iz) or (ix, iy) or ix or I...).") end
ndims = haskey(kwargs, :ndims) ? kwargs.ndims : get_ndims(caller)
if (ndims < 1 || ndims > 3) @ArgumentError("@parallel_indices: argument 'ndims' is required for the syntax `@parallel_indices I...`` and is invalid or missing (valid values are 1, 2 or 3; 'ndims' an be set globally in @init_parallel_stencil and overwritten per kernel if needed).") end
indices = get_indices_expr(ndims).args
indices_expr = Expr(:tuple, indices...)
Expand All @@ -193,12 +239,11 @@ function parallel_indices_memopt(metadata_module::Module, metadata_function::Exp
body = add_return(body)
set_body!(kernel, body)
indices = extract_tuple(indices)
return :(@parallel_indices $(Expr(:tuple, indices[1:end-1]...)) inbounds=$inbounds memopt=false metadata_module=$metadata_module metadata_function=$metadata_function $kernel) #TODO: the package and numbertype will have to be passed here further once supported as kwargs (currently removed from signature: package::Symbol, numbertype::DataType, )
return :(@parallel_indices $(Expr(:tuple, indices[1:end-1]...)) ndims=$ndims inbounds=$inbounds memopt=false metadata_module=$metadata_module metadata_function=$metadata_function $kernel) #TODO: the package and numbertype will have to be passed here further once supported as kwargs (currently removed from signature: package::Symbol, numbertype::DataType, )
end

function parallel_kernel(metadata_module::Module, metadata_function::Expr, caller::Module, package::Symbol, numbertype::DataType, kernel::Expr; kwargs::NamedTuple)
function parallel_kernel(metadata_module::Module, metadata_function::Expr, caller::Module, package::Symbol, ndims::Integer, numbertype::DataType, kernel::Expr; kwargs::NamedTuple)
is_parallel_kernel = true
ndims = haskey(kwargs, :ndims) ? kwargs.ndims : get_ndims(caller)
if (ndims < 1 || ndims > 3) @ArgumentError("@parallel: argument 'ndims' is invalid or missing (valid values are 1, 2 or 3; 'ndims' an be set globally in @init_parallel_stencil and overwritten per kernel if needed).") end
inbounds = haskey(kwargs, :inbounds) ? kwargs.inbounds : get_inbounds(caller)
memopt = haskey(kwargs, :memopt) ? kwargs.memopt : get_memopt(caller)
Expand Down
2 changes: 1 addition & 1 deletion src/shared.jl
Expand Up @@ -9,7 +9,7 @@ elseif ENABLE_AMDGPU
using AMDGPU
end
import MacroTools: @capture, postwalk, splitarg # NOTE: inexpr_walk used instead of MacroTools.inexpr
import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, insert_device_types, is_kernel, is_call, gensym_world, isgpu, @isgpu, substitute, inexpr_walk, add_inbounds, cast, @ranges, @rangelengths, @return_value, @return_nothing
import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, insert_device_types, is_kernel, is_call, gensym_world, isgpu, @isgpu, substitute, substitute_in_kernel, in_signature, inexpr_walk, add_inbounds, cast, @ranges, @rangelengths, @return_value, @return_nothing
import .ParallelKernel: PKG_CUDA, PKG_AMDGPU, PKG_THREADS, PKG_NONE, NUMBERTYPE_NONE, SUPPORTED_NUMBERTYPES, SUPPORTED_PACKAGES, ERRMSG_UNSUPPORTED_PACKAGE, INT_CUDA, INT_AMDGPU, INT_THREADS, INDICES, PKNumber, RANGES_VARNAME, RANGES_TYPE, RANGELENGTH_XYZ_TYPE, RANGELENGTHS_VARNAMES, THREADIDS_VARNAMES, GENSYM_SEPARATOR, AD_SUPPORTED_ANNOTATIONS
import .ParallelKernel: @require, @symbols, symbols, longnameof, @prettyexpand, @prettystring, prettystring, @gorgeousexpand, @gorgeousstring, gorgeousstring

Expand Down

0 comments on commit 63f0d8f

Please sign in to comment.