Merge pull request #102 from omlins/ad
Add documentation, unit tests and small fixes for automatic differentiation
omlins committed Jul 13, 2023
2 parents 19b8835 + f9ce077 commit bc1fada
Showing 4 changed files with 35 additions and 15 deletions.
20 changes: 14 additions & 6 deletions src/ParallelKernel/parallel.jl
@@ -1,11 +1,13 @@
const PARALLEL_DOC = """
@parallel kernelcall
@parallel ∇=... kernelcall
!!! note "Advanced"
@parallel ranges kernelcall
@parallel nblocks nthreads kernelcall
@parallel ranges nblocks nthreads kernelcall
@parallel (...) kwargs... kernelcall
@parallel (...) configcall=... backendkwargs... kernelcall
@parallel ∇=... ad_mode=... ad_annotations=... (...) backendkwargs... kernelcall
Declare the `kernelcall` parallel. The kernel will automatically be called as required by the package for parallelization selected with [`@init_parallel_kernel`](@ref). Synchronizes at the end of the call (if a stream is given via keyword arguments, then it synchronizes only this stream).
@@ -15,7 +17,11 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re
- `ranges::Tuple{UnitRange{},UnitRange{},UnitRange{}} | Tuple{UnitRange{},UnitRange{}} | Tuple{UnitRange{}} | UnitRange{}`: the ranges of indices in each dimension for which computations must be performed.
- `nblocks::Tuple{Integer,Integer,Integer}`: the number of blocks to be used if the package CUDA or AMDGPU was selected with [`@init_parallel_kernel`](@ref).
- `nthreads::Tuple{Integer,Integer,Integer}`: the number of threads to be used if the package CUDA or AMDGPU was selected with [`@init_parallel_kernel`](@ref).
- `kwargs...`: keyword arguments to be passed further to CUDA or AMDGPU (ignored for Threads).
# Keyword arguments
!!! note "Advanced"
- `configcall=kernelcall`: a call to a kernel that is declared parallel, which is used for determining the kernel launch parameters. This keyword is useful, e.g., for generic automatic differentiation using the low-level submodule [`AD`](@ref).
- `backendkwargs...`: keyword arguments to be passed further to CUDA or AMDGPU (ignored for Threads).
!!! note "Performance note"
Kernel launch parameters are automatically defined with heuristics, where not defined with optional kernel arguments. For CUDA and AMDGPU, `nthreads` is typically set to (32,8,1) and `nblocks` accordingly to ensure that enough threads are launched.
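
For orientation, a minimal usage sketch of the call forms documented above (a sketch, not code from this commit: the kernel `f!`, the arrays `A`, `B`, `Ā`, `B̄`, the scalar argument, and the Threads backend are assumptions, and the gradient call additionally assumes Enzyme is available):

```julia
using ParallelStencil.ParallelKernel
@init_parallel_kernel(Threads, Float64)

# Hypothetical kernel declared with @parallel_indices.
@parallel_indices (ix) function f!(A, B, a)
    A[ix] = B[ix]^2 * a
    return
end

A, B = zeros(16), rand(16)
Ā, B̄ = ones(16), zeros(16)              # seed for A, gradient buffer for B

@parallel f!(A, B, 2.0)                  # plain call; synchronizes at the end
@parallel ∇=(A->Ā, B->B̄) f!(A, B, 2.0)   # gradient call; default ad_mode is Enzyme.Reverse
```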
@@ -39,12 +45,14 @@ macro parallel_indices(args...) check_initialized(); checkargs_parallel_indices(
##
const PARALLEL_ASYNC_DOC = """
@parallel_async kernelcall
@parallel_async ∇=... kernelcall
!!! note "Advanced"
@parallel_async ranges kernelcall
@parallel_async nblocks nthreads kernelcall
@parallel_async ranges nblocks nthreads kernelcall
@parallel_async (...) kwargs... kernelcall
@parallel_async (...) configcall=... backendkwargs... kernelcall
@parallel_async ∇=... ad_mode=... ad_annotations=... (...) backendkwargs... kernelcall
Declare the `kernelcall` parallel as with [`@parallel`](@ref) (see [`@parallel`](@ref) for more information); deactivates however automatic synchronization at the end of the call. Use [`@synchronize`](@ref) for synchronizing.
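
The asynchronous variant, in a correspondingly hedged sketch (same hypothetical kernel and arrays as in the sketch above):

```julia
# No automatic synchronization: do not read the results before synchronizing.
@parallel_async f!(A, B, 2.0)
# ... independent host work can overlap with the kernel here ...
@synchronize()   # required before A is used
```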
@@ -210,11 +218,11 @@ function parallel_call_ad(caller::Module, kernelcall::Expr, backend_kwargs_expr:
end
end
annotated_args = (:($(ad_annotations_byvar[var][1])($((var keys(ad_vars) ? (var, ad_vars[var]) : (var,))...))) for var in f_posargs)
ad_call = :(autodiff_deferred!($ad_mode, $f_name, $(annotated_args...)))
ad_call = :(ParallelStencil.ParallelKernel.AD.autodiff_deferred!($ad_mode, $f_name, $(annotated_args...)))
kwargs_remaining = filter(x->!(x in (:∇, :ad_mode, :ad_annotations)), keys(kwargs))
kwargs_remaining_expr = [:($key=$val) for (key,val) in kwargs_remaining]
if (async) return :( @parallel $(posargs...) $(backend_kwargs_expr...) $(kwargs_remaining_expr...) configcall=$kernelcall $ad_call ) #TODO: the package needs to be passed further here later.
else return :( @parallel_async $(posargs...) $(backend_kwargs_expr...) $(kwargs_remaining_expr...) configcall=$kernelcall $ad_call ) #...
if (async) return :( @parallel_async $(posargs...) $(backend_kwargs_expr...) $(kwargs_remaining_expr...) configcall=$kernelcall $ad_call ) #TODO: the package needs to be passed further here later.
else return :( @parallel $(posargs...) $(backend_kwargs_expr...) $(kwargs_remaining_expr...) configcall=$kernelcall $ad_call ) #...
end
end
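
For reference, the updated unit tests below assert the expansion that this branch of `parallel_call_ad` now produces; for example (taken from the tests, reformatted here for readability):

```julia
# What the user writes:
@parallel ∇=B->B̄ f!(A, B, a)

# What it expands to at the first macro expansion level (as asserted in the tests):
@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(
    Enzyme.Reverse, f!, EnzymeCore.Const(A), EnzymeCore.DuplicatedNoNeed(B, B̄), EnzymeCore.Const(a))
```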

13 changes: 9 additions & 4 deletions src/parallel.jl
@@ -14,12 +14,15 @@ See also: [`@init_parallel_stencil`](@ref)
--------------------------------------------------------------------------------
@parallel kernelcall
@parallel memopt=... kernelcall
@parallel ∇=... kernelcall
@parallel ∇=... memopt=... kernelcall
!!! note "Advanced"
@parallel ranges kernelcall
@parallel nblocks nthreads kernelcall
@parallel ranges nblocks nthreads kernelcall
@parallel (...) kwargs... kernelcall
@parallel (...) memopt=... configcall=... backendkwargs... kernelcall
@parallel ∇=... ad_mode=... ad_annotations=... (...) memopt=... backendkwargs... kernelcall
Declare the `kernelcall` parallel. The kernel will automatically be called as required by the package for parallelization selected with [`@init_parallel_kernel`](@ref). Synchronizes at the end of the call (if a stream is given via keyword arguments, then it synchronizes only this stream).
@@ -29,10 +32,12 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re
- `ranges::Tuple{UnitRange{},UnitRange{},UnitRange{}} | Tuple{UnitRange{},UnitRange{}} | Tuple{UnitRange{}} | UnitRange{}`: the ranges of indices in each dimension for which computations must be performed.
- `nblocks::Tuple{Integer,Integer,Integer}`: the number of blocks to be used if the package CUDA or AMDGPU was selected with [`@init_parallel_kernel`](@ref).
- `nthreads::Tuple{Integer,Integer,Integer}`: the number of threads to be used if the package CUDA or AMDGPU was selected with [`@init_parallel_kernel`](@ref).
- `kwargs...`: keyword arguments to be passed further to CUDA or AMDGPU (ignored for Threads).
# Optional keyword arguments
# Keyword arguments
- `memopt::Bool=false`: whether the kernel to be launched was generated with `memopt=true` (meaning the keyword was set in the kernel declaration).
!!! note "Advanced"
- `configcall=kernelcall`: a call to a kernel that is declared parallel, which is used for determining the kernel launch parameters. This keyword is useful, e.g., for generic automatic differentiation using the low-level submodule [`AD`](@ref).
- `backendkwargs...`: keyword arguments to be passed further to CUDA or AMDGPU (ignored for Threads).
!!! note "Performance note"
Kernel launch parameters are automatically defined with heuristics, where not defined with optional kernel arguments. For CUDA and AMDGPU, `nthreads` is typically set to (32,8,1) and `nblocks` accordingly to ensure that enough threads are launched.
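
A corresponding ParallelStencil-level sketch of the documented call forms (again a sketch under assumptions, not code from this commit: Threads backend, 2-D, the kernel `step!` and the arrays are hypothetical, and the gradient call assumes Enzyme is available):

```julia
using ParallelStencil, ParallelStencil.FiniteDifferences2D
@init_parallel_stencil(Threads, Float64, 2)

@parallel function step!(A, B, a)
    @all(A) = @all(B)^2 * a
    return
end

A, B = @zeros(8, 8), @rand(8, 8)
Ā, B̄ = @ones(8, 8), @zeros(8, 8)          # seed for A, gradient buffer for B

@parallel step!(A, B, 2.0)                 # plain call (pass memopt=true only if the kernel was declared with memopt=true)
@parallel ∇=(A->Ā, B->B̄) step!(A, B, 2.0)  # reverse-mode gradient call via Enzyme
```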
@@ -111,7 +116,7 @@ function parallel(source::LineNumberNode, caller::Module, args::Union{Symbol,Exp
parallel_kernel(metadata_module, metadata_function, caller, package, numbertype, ndims, kernelarg, posargs...; kwargs)
elseif is_call(args[end])
posargs, kwargs_expr, kernelarg = split_parallel_args(args)
kwargs, backend_kwargs_expr = extract_kwargs(caller, kwargs_expr, (:memopt, :configcall), "@parallel <kernelcall>", true; eval_args=(:memopt,))
kwargs, backend_kwargs_expr = extract_kwargs(caller, kwargs_expr, (:memopt, :configcall, :∇, :ad_mode, :ad_annotations), "@parallel <kernelcall>", true; eval_args=(:memopt,))
memopt = haskey(kwargs, :memopt) ? kwargs.memopt : get_memopt()
configcall = haskey(kwargs, :configcall) ? kwargs.configcall : kernelarg
configcall_kwarg_expr = :(configcall=$configcall)
10 changes: 5 additions & 5 deletions test/ParallelKernel/test_parallel.jl
@@ -55,11 +55,11 @@ end
end;
end;
@testset "@parallel ∇" begin
@test @prettystring(1, @parallel=B->f!(A, B, a)) == "@parallel_async configcall = f!(A, B, a) autodiff_deferred!(Enzyme.Reverse, f!, (EnzymeCore.Const)(A), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) f!(A, B, a)) == "@parallel_async configcall = f!(A, B, a) autodiff_deferred!(Enzyme.Reverse, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward f!(A, B, a)) == "@parallel_async configcall = f!(A, B, a) autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=B) f!(A, B, a)) == "@parallel_async configcall = f!(A, B, a) autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.Duplicated)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=(B,A), Active=b) f!(A, B, a, b)) == "@parallel_async configcall = f!(A, B, a, b) autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.Duplicated)(A, Ā), (EnzymeCore.Duplicated)(B, B̄), (EnzymeCore.Const)(a), (EnzymeCore.Active)(b))"
@test @prettystring(1, @parallel=B->f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Reverse, f!, (EnzymeCore.Const)(A), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Reverse, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=B) f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.Duplicated)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=(B,A), Active=b) f!(A, B, a, b)) == "@parallel configcall = f!(A, B, a, b) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.Duplicated)(A, Ā), (EnzymeCore.Duplicated)(B, B̄), (EnzymeCore.Const)(a), (EnzymeCore.Active)(b))"
end;
@testset "@parallel_indices" begin
@testset "addition of range arguments" begin
7 changes: 7 additions & 0 deletions test/test_parallel.jl
@@ -67,6 +67,13 @@ import ParallelStencil.@gorgeousexpand
@test @prettystring(2, @parallel ranges memopt=true f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))"
end;
end;
@testset "@parallel ∇" begin
@test @prettystring(1, @parallel=B->f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Reverse, f!, (EnzymeCore.Const)(A), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Reverse, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.DuplicatedNoNeed)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=B) f!(A, B, a)) == "@parallel configcall = f!(A, B, a) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.DuplicatedNoNeed)(A, Ā), (EnzymeCore.Duplicated)(B, B̄), (EnzymeCore.Const)(a))"
@test @prettystring(1, @parallel=(A->Ā, B->B̄) ad_mode=Enzyme.Forward ad_annotations=(Duplicated=(B,A), Active=b) f!(A, B, a, b)) == "@parallel configcall = f!(A, B, a, b) ParallelStencil.ParallelKernel.AD.autodiff_deferred!(Enzyme.Forward, f!, (EnzymeCore.Duplicated)(A, Ā), (EnzymeCore.Duplicated)(B, B̄), (EnzymeCore.Const)(a), (EnzymeCore.Active)(b))"
end;
@testset "@parallel <kernel>" begin
@testset "addition of range arguments" begin
expansion = @gorgeousstring(1, @parallel f(A, B, c::T) where T <: Integer = (@all(A) = @all(B)^c; return))
