allow passing a custom pmap #63

Merged · 5 commits · Aug 31, 2021
1 change: 1 addition & 0 deletions README.md
@@ -252,4 +252,5 @@ end

# Parallel execution
- The macro `@phyperopt` works in the same way as `@hyperopt` but distributes all computation on available workers. The usual caveats apply: code must be loaded on all workers, etc.
- `@phyperopt` accepts an optional second argument, a `pmap`-like function, e.g. `(args...,) -> pmap(args...; on_error=...)`; see the sketch after this list.
- The macro `@thyperopt` uses `ThreadPools.tmap` to evaluate the objective on all available threads. Beware of high memory consumption if your objective allocates a lot of memory.
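
A minimal usage sketch of the custom-`pmap` argument, modeled on the test added in this PR (`my_pmap` and the objective are illustrative, not part of the package):

```julia
using Distributed, Hyperopt

# Any function with pmap's calling convention can be passed; this one returns
# errors as values instead of aborting the whole run.
my_pmap(args...) = pmap(args...; on_error = identity)

ho = @phyperopt for i = 100, a = LinRange(1, 5, 50), b = [true, false]
    (a - 3)^2 + (b ? 10 : 20)
end my_pmap
```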
69 changes: 53 additions & 16 deletions src/Hyperopt.jl
@@ -86,10 +86,10 @@ end
Base.length(ho::Hyperoptimizer) = ho.iterations


function Base.iterate(ho::Hyperoptimizer, state=1)
function Base.iterate(ho::Hyperoptimizer, state=1; update_history=true)
state > ho.iterations && return nothing
samples = ho.sampler(ho, state)
push!(ho.history, samples)
update_history && push!(ho.history, samples)
nt = (; Pair.((:i, ho.params...), (state, samples...))...)
nt, state+1
end
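
The new `update_history` keyword lets a caller draw a sample without recording it, then push the history entry and the result together after the objective has run. A hedged sketch of manual iteration using it (assuming the package's `Hyperoptimizer(iterations; params...)` constructor; `objective` is illustrative):

```julia
using Hyperopt

ho = Hyperoptimizer(10, a = LinRange(1, 5, 10), b = [true, false])
objective(a, b) = (a - 3)^2 + (b ? 10 : 20)

for i in 1:ho.iterations
    nt, _ = iterate(ho, i; update_history = false)  # nt == (i = i, a = ..., b = ...)
    res = objective(nt.a, nt.b)
    push!(ho.history, (nt.a, nt.b))  # record the sample and its result together,
    push!(ho.results, res)           # so history and results never drift apart
end
```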
@@ -157,7 +157,7 @@ function optimize(ho::Hyperoptimizer)
if e isa InterruptException
@info "Aborting hyperoptimization"
else
rethrow(e)
Contributor Author: This will mess with the stacktrace and point to this line instead of the original one, so I fixed it to a bare `rethrow()`, which does the right thing in a catch block.

Owner: Aha, I learned something new 😃

rethrow()
end
end
ho
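
To illustrate the point from the review thread above, a self-contained sketch (not from the PR): inside a `catch` block, a bare `rethrow()` re-raises the exception currently being handled, so the reported stack trace still points at the original failure site.

```julia
function fragile()
    error("boom")  # the stack trace should point here...
end

try
    fragile()
catch e
    # ...and with a bare rethrow() it does; per the review comment,
    # rethrow(e) would implicate this line instead. This mirrors the
    # pattern used in optimize().
    e isa InterruptException || rethrow()
end
```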
@@ -195,20 +195,57 @@ function pmacrobody(ex, params, ho_, pmap=pmap)
quote
function workaround_function()
ho = $(ho_)
# Getting the history right is tricky when using workers. The approach I've found to work is to
# save the actual array (not a copy) in hist, temporarily use a new array that will later be discarded,
# reassign the original array, and then append the new history. If a new array is used, the change will not be visible in the original hyperoptimizer
hist = ho.history
ho.history = []
res = $(pmap)(1:ho.iterations) do i
$(Expr(:tuple, esc.(params)...)),_ = iterate(ho,i)

# We use a `RemoteChannel` to coordinate access to a single Hyperoptimizer object
# that lives on the manager process.
ho_channel = RemoteChannel(() -> Channel{Hyperoptimizer}(1), 1)

# We use a `deepcopy` to ensure we get the same semantics whether or not the code
# ends up executing on a remote process or not (i.e. always a copy). See
# <https://docs.julialang.org/en/v1/manual/distributed-computing/#Local-invocations>
put!(ho_channel, deepcopy(ho))

# We don't care about the results of the `pmap` since we update the hyperoptimizer
# inside the loop.
$(esc(pmap))(1:ho.iterations) do i
# We take the hyperoptimizer out of the channel, and get our new parameter values
local_ho = take!(ho_channel)
# We use `update_history` because we want to only update the history once we've
# finished the iteration and have a result to report back as well. Otherwise,
# some processes may observe the Hyperoptimizer in an inconsistent state with
# `length(ho.history) > length(ho.results)`. Moreover, if one run is very quick
# the history and results could become out of order with respect to one another.
$(Expr(:tuple, esc.(params)...)), _ = iterate(local_ho, i; update_history = false)
# Now we put it back so another processor can use the hyperoptimizer
put!(ho_channel, local_ho)

# Now run the objective
res = $(esc(ex.args[2])) # ex.args[2] = Body of the For loop

res, $(Expr(:tuple, esc.(params[2:end])...))
# Now update the results; we again grab the one true hyperoptimizer,
# and populate its history and results.
local_ho = take!(ho_channel)
push!(local_ho.history, $(Expr(:tuple, esc.(params[2:end])...)))
push!(local_ho.results, res)
put!(ho_channel, local_ho)

res
end
ho.history = hist
append!(ho.results, getindex.(res,1))
append!(ho.history, getindex.(res,2)) # history automatically appended by the iteration

# What we get out of the channel is an updated copy of our original hyperoptimizer.
# So now, back on the manager process, we take it out one last time and update
# the original hyperoptimizer.
updated_ho = take!(ho_channel)
close(ho_channel)

# Getting the history right is tricky. For some reason, we can't do `ho.history = updated_ho.history`.
# Instead, we must mutate the existing `ho.history` vector. (Similarly for results).
empty!(ho.history)
append!(ho.history, updated_ho.history)

empty!(ho.results)
append!(ho.results, updated_ho.results)

ho
end
workaround_function()
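
The coordination pattern above, distilled into a standalone sketch (the channel and state names are illustrative): a size-1 `RemoteChannel` serializes access to a single object, because `take!` empties the channel and blocks every other worker until the matching `put!`.

```julia
using Distributed
addprocs(2)

# A size-1 channel on the manager (pid 1) holds the one true state object.
state_channel = RemoteChannel(() -> Channel{Vector{Int}}(1), 1)
put!(state_channel, Int[])

pmap(1:10) do i
    state = take!(state_channel)  # acquire exclusive access
    push!(state, i)               # mutate the shared state
    put!(state_channel, state)    # release it for the next worker
    i
end

final_state = take!(state_channel)
close(state_channel)
@assert length(final_state) == 10  # one entry per task, race-free
```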
@@ -218,10 +255,10 @@ end
"""
Same as `@hyperopt` but uses `Distributed.pmap` for parallel evaluation of the cost function.
"""
macro phyperopt(ex)
macro phyperopt(ex, pmap=pmap)
pre = preprocess_expression(ex)
ho_ = create_ho(pre...)
pmacrobody(ex, pre[1], ho_)
pmacrobody(ex, pre[1], ho_, pmap)
end

"""
15 changes: 15 additions & 0 deletions test/runtests.jl
@@ -3,6 +3,12 @@ Random.seed!(0)
using Hyperopt, Plots
f(a,b=true;c=10) = sum(@. 100 + (a-3)^2 + (b ? 10 : 20) + (c-100)^2) # This function must be defined outside testsets to avoid scoping issues

# Use a separate module to test for escaping issues
module MyPmapModule
using Distributed
my_pmap(args...) = Distributed.pmap(args...; on_error=identity)
end

@testset "Hyperopt" begin

@testset "Random sampler" begin
@@ -295,6 +301,15 @@ f(a,b=true;c=10) = sum(@. 100 + (a-3)^2 + (b ? 10 : 20) + (c-100)^2) # This func
f(horp.history[i][1:2]..., c=horp.history[i][3]) == horp.results[i]
end

# Test with a custom pmap
horp = @phyperopt for i=300, sampler=RandomSampler(), a = LinRange(1,5,50), b = [true, false], c = exp10.(LinRange(-1,3,50))
# println(i, "\t", a, "\t", b, "\t", c)
i > 100 && error("Too many iterations")
f(a,b,c=c)
end MyPmapModule.my_pmap
# errors cannot be recorded in the history or results, so only the 100 successful iterations appear
@test length(horp.history) == 100
@test length(horp.results) == 100

end
@testset "BOHB" begin