allow passing a custom pmap (#63)
* allow passing a custom `pmap`

* add comment

* wip

* fix appending history

* hm
ericphanson committed Aug 31, 2021
1 parent c7418d1 commit 23096d8
Showing 3 changed files with 69 additions and 16 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -252,4 +252,5 @@ end

# Parallel execution
- The macro `@phyperopt` works in the same way as `@hyperopt` but distributes all computation on available workers. The usual caveats apply: code must be loaded on all workers, etc.
- `@phyperopt` accepts an optional second argument which is a `pmap`-like function. E.g. `(args...,) -> pmap(args...; on_error=...)`.
- The macro `@thyperopt` uses `ThreadPools.tmap` to evaluate the objective on all available threads. Beware of high memory consumption if your objective allocates a lot of memory.
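A minimal sketch of how the new second argument might be used (the worker setup, the objective `f`, and the search ranges below are illustrative placeholders, not part of this commit):

```julia
using Distributed
addprocs(2)
@everywhere using Hyperopt
@everywhere f(a, b) = (a - 3)^2 + (b - 10)^2  # placeholder objective, defined on all workers

# A pmap-like wrapper; here it records errors instead of aborting the whole sweep.
my_pmap(args...) = pmap(args...; on_error = identity)

ho = @phyperopt for i = 50, sampler = RandomSampler(), a = LinRange(1, 5, 50), b = exp10.(LinRange(-1, 3, 50))
    f(a, b)
end my_pmap
```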
69 changes: 53 additions & 16 deletions src/Hyperopt.jl
@@ -90,10 +90,10 @@ end
Base.length(ho::Hyperoptimizer) = ho.iterations


function Base.iterate(ho::Hyperoptimizer, state=1)
function Base.iterate(ho::Hyperoptimizer, state=1; update_history=true)
state > ho.iterations && return nothing
samples = ho.sampler(ho, state)
push!(ho.history, samples)
update_history && push!(ho.history, samples)
nt = (; Pair.((:i, ho.params...), (state, samples...))...)
nt, state+1
end
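The new `update_history` keyword lets a caller draw a sample without pushing it onto `ho.history`; the distributed code further down relies on this so that history and results are only recorded together, after the objective has run. A small illustrative sketch of the behaviour (the `Hyperoptimizer(10, a = ...)` constructor call follows the package's iterator-style usage but is an assumption here, not part of this diff):

```julia
using Hyperopt

ho = Hyperoptimizer(10, a = LinRange(1, 5, 10))  # assumed iterator-style constructor
nt, _ = iterate(ho, 1; update_history = false)
@assert isempty(ho.history)       # the drawn sample was not recorded
nt, _ = iterate(ho, 1)            # default: update_history = true
@assert length(ho.history) == 1   # now it was pushed onto the history
```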
@@ -166,7 +166,7 @@ function optimize(ho::Hyperoptimizer)
if e isa InterruptException
@info "Aborting hyperoptimization"
else
rethrow(e)
rethrow()
end
end
ho
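The switch from `rethrow(e)` to `rethrow()` follows the usual Julia idiom: inside a `catch` block, `rethrow()` re-raises the exception currently being handled and keeps its original backtrace. A standalone illustration of the idiom (not code from this commit):

```julia
try
    error("boom")
catch e
    if e isa InterruptException
        @info "Aborting"
    else
        rethrow()   # preserves the original stack trace of `error("boom")`
    end
end
```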
@@ -204,20 +204,57 @@ function pmacrobody(ex, params, ho_, pmap=pmap)
quote
function workaround_function()
ho = $(ho_)
# Getting the history right is tricky when using workers. The approach I've found to work is to
# save the actual array (not copy) in hist, temporarily use a new array that will later be discarded
# reassign the original array and then append the new history. If a new array is used, the change will not be visible in the original hyperoptimizer
hist = ho.history
ho.history = []
res = $(pmap)(1:ho.iterations) do i
$(Expr(:tuple, esc.(params)...)),_ = iterate(ho,i)

# We use a `RemoteChannel` to coordinate access to a single Hyperoptimizer object
# that lives on the manager process.
ho_channel = RemoteChannel(() -> Channel{Hyperoptimizer}(1), 1)

# We use a `deepcopy` to ensure we get the same semantics whether or not the code
# ends up executing on a remote process or not (i.e. always a copy). See
# <https://docs.julialang.org/en/v1/manual/distributed-computing/#Local-invocations>
put!(ho_channel, deepcopy(ho))

# We don't care about the results of the `pmap` since we update the hyperoptimizer
# inside the loop.
$(esc(pmap))(1:ho.iterations) do i
# We take the hyperoptimizer out of the channel, and get our new parameter values
local_ho = take!(ho_channel)
# We use `update_history` because we want to only update the history once we've
# finished the iteration and have a result to report back as well. Otherwise,
# some processes may observe the Hyperoptimizer in an inconsistent state with
# `length(ho.history) > length(ho.results)`. Moreover, if one run is very quick
# the history and results could become out of order with respect to one another.
$(Expr(:tuple, esc.(params)...)), _ = iterate(local_ho, i; update_history = false)
# Now we put it back so another processor can use the hyperoptimizer
put!(ho_channel, local_ho)

# Now run the objective
res = $(esc(ex.args[2])) # ex.args[2] = Body of the For loop

res, $(Expr(:tuple, esc.(params[2:end])...))
# Now update the results; we again grab the one true hyperoptimizer,
# and populate its history and results.
local_ho = take!(ho_channel)
push!(local_ho.history, $(Expr(:tuple, esc.(params[2:end])...)))
push!(local_ho.results, res)
put!(ho_channel, local_ho)

res
end
ho.history = hist
append!(ho.results, getindex.(res,1))
append!(ho.history, getindex.(res,2)) # history automatically appended by the iteration

# What we get out of the channel is an updated copy of our original hyperoptimizer.
# So now, back on the manager process, we take it out one last time and update
# the original hyperoptimizer.
updated_ho = take!(ho_channel)
close(ho_channel)

# Getting the history right is tricky. For some reason, we can't do `ho.history = updated_ho.history`.
# Instead, we must mutate the existing `ho.history` vector. (Similarly for results).
empty!(ho.history)
append!(ho.history, updated_ho.history)

empty!(ho.results)
append!(ho.results, updated_ho.results)

ho
end
workaround_function()
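For readers unfamiliar with the pattern above, here is a standalone sketch of the same `RemoteChannel` coordination idea, stripped of the macro machinery (the names and the toy workload are illustrative only):

```julia
using Distributed

# A single shared object lives in a RemoteChannel of capacity 1 on the manager (pid 1).
# Each pmap task takes it out, mutates it, and puts it back, so updates are serialized.
state_channel = RemoteChannel(() -> Channel{Vector{Int}}(1), 1)
put!(state_channel, Int[])

pmap(1:10) do i
    shared = take!(state_channel)  # exclusive access while the channel is empty
    push!(shared, i)               # record this iteration
    put!(state_channel, shared)    # hand the object back for the next task
    i^2                            # the pmap return value is not used for the shared state
end

collected = take!(state_channel)   # back on the manager: the fully updated object
close(state_channel)
```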
@@ -227,10 +264,10 @@ end
"""
Same as `@hyperopt` but uses `Distributed.pmap` for parallel evaluation of the cost function.
"""
macro phyperopt(ex)
macro phyperopt(ex, pmap=pmap)
pre = preprocess_expression(ex)
ho_ = create_ho(pre...)
pmacrobody(ex, pre[1], ho_)
pmacrobody(ex, pre[1], ho_, pmap)
end
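Any `pmap`-like callable can be passed as the macro's second argument; it just needs to accept a function and an iterable the way `Distributed.pmap(f, itr)` does (the macro calls it with a `do`-block over `1:ho.iterations`). A hypothetical serial stand-in, handy for debugging an objective before distributing it, might look like this (illustration only, not part of the commit):

```julia
using Hyperopt

# Runs the "parallel" sweep serially on the manager process.
serial_pmap(f, itr) = map(f, itr)

ho = @phyperopt for i = 10, sampler = RandomSampler(), a = LinRange(1, 5, 10)
    (a - 3)^2
end serial_pmap
```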

"""
15 changes: 15 additions & 0 deletions test/runtests.jl
@@ -3,6 +3,12 @@ Random.seed!(0)
using Hyperopt, Plots
f(a,b=true;c=10) = sum(@. 100 + (a-3)^2 + (b ? 10 : 20) + (c-100)^2) # This function must be defined outside testsets to avoid scoping issues

# Use a separate module to test for escaping issues
module MyPmapModule
using Distributed
my_pmap(args...) = Distributed.pmap(args...; on_error=identity)
end

@testset "Hyperopt" begin

@testset "Random sampler" begin
@@ -295,6 +301,15 @@ f(a,b=true;c=10) = sum(@. 100 + (a-3)^2 + (b ? 10 : 20) + (c-100)^2) # This function must be defined outside testsets to avoid scoping issues
f(horp.history[i][1:2]..., c=horp.history[i][3]) == horp.results[i]
end

# Test with a custom pmap
horp = @phyperopt for i=300, sampler=RandomSampler(), a = LinRange(1,5,50), b = [true, false], c = exp10.(LinRange(-1,3,50))
# println(i, "\t", a, "\t", b, "\t", c)
i > 100 && error("Too many iterations")
f(a,b,c=c)
end MyPmapModule.my_pmap
# iterations that error are not recorded in the history or results
@test length(horp.history) == 100
@test length(horp.results) == 100

end
@testset "BOHB" begin
