Skip to content

Commit

Permalink
CUCB: allow tuning the confidence radius.
Browse files Browse the repository at this point in the history
  • Loading branch information
dourouc05 committed Dec 21, 2020
1 parent 55efdd7 commit 5f69863
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/policies/cucb.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
# CUCB, "combinatorial upper confidence bound".
# Based on http://proceedings.mlr.press/v28/chen13a.html.

mutable struct CUCB <: Policy end
mutable struct CUCB <: Policy
α::Float64
end

CUCB() = CUCB(sqrt(0.5))

struct CUCBDetails <: PolicyDetails
solver_time::Float64
end

function choose_action(instance::CombinatorialInstance{T}, ::CUCB, state::State{T}; with_trace::Bool=false) where T
function choose_action(instance::CombinatorialInstance{T}, algo::CUCB, state::State{T}; with_trace::Bool=false) where T
if any(v == 0 for v in values(state.arm_counts))
# There is at least one arm that has never been tried: maximise the arms that have never been tested.
weights = Dict(arm => (state.arm_counts[arm] == 0.0) ? 1.0 : 0.0 for arm in keys(state.arm_counts))
else
# All arms have been seen, thus this formula makes sense (no zero-valued arm count).
weights = Dict(arm => state.arm_average_reward[arm] + sqrt((log(state.round)) / (2 * state.arm_counts[arm])) for arm in keys(state.arm_counts))
# weights = Dict(arm => state.arm_average_reward[arm] + sqrt((3 * log(state.round)) / (2 * state.arm_counts[arm])) for arm in keys(state.arm_counts))
numerator = algo.α * sqrt(log(state.round))
weights = Dict(arm => state.arm_average_reward[arm] + numerator * sqrt(1.0 / state.arm_counts[arm]) for arm in keys(state.arm_counts))
end

t0 = time_ns()
Expand Down

0 comments on commit 5f69863

Please sign in to comment.