In [1]:
using POMDPs
using Random # for AbstractRNG
using POMDPModelTools

In [2]:
"""
    ChainMDP(len, p_success, discount)

One-dimensional chain MDP with integer states and `Symbol` actions (`:left` / `:right`).

# Fields
- `len`: chain length; transitions are clamped to the range `1:len`.
- `p_success`: probability that a move goes in the intended direction (otherwise it goes
  the opposite way).
- `discount`: discount factor. NOTE(review): no `POMDPs.discount` method is visible in this
  notebook — confirm this field is actually consumed somewhere.
"""
struct ChainMDP <: MDP{Int, Symbol}
    len::Int            # number of states in the chain
    p_success::Float64  # chance the chosen direction is the one taken
    discount::Float64   # discount factor
end


In [3]:
# Sample the next state of the chain after taking action `a` in state `s`.
# Any action other than `:right` is treated as `:left`.
function POMDPs.generate_s(p::ChainMDP, s::Int, a::Symbol, rng::AbstractRNG)
    # Deterministic shortcuts (no random draw): moving right from `p.len - 2`
    # always lands on `p.len - 1`, and moving left from 2 always lands on 1.
    a == :right && s + 2 == p.len && return p.len - 1
    a == :left && s == 2 && return 1

    # Neighbouring states, clamped to the ends of the chain.
    up = min(s + 1, p.len)
    down = max(s - 1, 1)
    intended, slip = a == :right ? (up, down) : (down, up)

    # With probability `p_success` the move goes the intended way, otherwise the other way.
    return rand(rng) < p.p_success ? intended : slip
end


In [41]:
# Reward for being in state `s`; the action `a` does not affect it.
# +10 in state 2, -10 in state `p.len - 2`, 0 everywhere else (state 2 wins if both match).
function POMDPs.reward(p::ChainMDP, s::Int, a::Symbol)
    s == 2 && return 10
    s + 2 == p.len && return -10
    return 0
end

In [5]:
# Always start in the middle of the chain. Derived from `m.len` instead of a hard-coded 4,
# so chains of any length start inside the valid state range (for `len == 8` this is 4).
POMDPs.initialstate_distribution(m::ChainMDP) = Deterministic(div(m.len, 2))

In [6]:
# An episode ends in state 1 or in state `p.len - 1`.
POMDPs.isterminal(p::ChainMDP, s::Int) = s == 1 || s + 1 == p.len

In [7]:
using POMDPSimulators
using POMDPPolicies

# Default chain: length 6 + 2 = 8, with p_success = discount = 0.9.
function ChainMDP()
    n_inner = 6
    return ChainMDP(n_inner + 2, 0.9, 0.9)
end
m = ChainMDP()

# policy that maps every input to a right action
policy = FunctionPolicy(s -> :right)

# Walk one episode (at most 10 steps): print s, a and r on separate lines, call `render`,
# then print a blank separator line.
for (s, a, r) in stepthrough(m, policy, "s,a,r", max_steps=10)
    @show s a r
    render(m, (s, a, r))
    println()
end


┌ Info: Precompiling POMDPSimulators [e0d0a172-29c6-5d4e-96d0-f262df5d01fd]
└ @ Base loading.jl:1186


s = 4
a = :right
r = 0

s = 5
a = :right
r = 0

s = 6
a = :right
r = 10



In [8]:
using POMDPSimulators
using POMDPPolicies

# Default chain (same as before): length 6 + 2 = 8, p_success = discount = 0.9.
function ChainMDP()
    inner_states = 6
    return ChainMDP(inner_states + 2, 0.9, 0.9)
end
m = ChainMDP()

# Policy that maps every input to a right action.
# NOTE(review): if this cell was meant to try the opposite direction, use `s -> :left`.
policy = FunctionPolicy(s -> :right)

# Walk one episode (at most 10 steps): call `render` and print each step as a compact tuple.
for (s, a, r) in stepthrough(m, policy, "s,a,r", max_steps=10)
    render(m, (s, a, r))
    println("s,a,r:($s,$a,$r)")
end


s,a,r:(4,right,0)
s,a,r:(3,right,0)
s,a,r:(4,right,0)
s,a,r:(5,right,0)
s,a,r:(6,right,10)


In [46]:
n_agents = 5
agents = Any[]
n_states = 10

# One empty Q-table and one always-move-right policy per agent
# (`Any[...]` keeps both containers as `Vector{Any}`).
Q_tables = Any[Dict{Int32,Float32}() for _ in 1:n_agents]
policies = Any[(s -> :right) for _ in 1:n_agents]
# policy = s->:right

"""
    update_Q(Q_table, s, a, r, sp, t)

Placeholder update rule: print the transition `(s, a, r, sp)` and step index `t` on one line.
`Q_table` is left untouched.
"""
update_Q(Q_table, s, a, r, sp, t) = println("s:$s, a:$a, sp:$sp, r:$r, t:$t")
# Start every episode in the middle of the chain.
function POMDPs.initialstate_distribution(m::ChainMDP)
    # Use the chain's own length rather than the global `n_states`: `run_chain` builds
    # `ChainMDP(n_states + 2, ...)` from its *argument*, which need not match the global.
    # `div` also avoids the InexactError that `Int64(x / 2)` raises for odd lengths.
    return Deterministic(div(m.len, 2))
end
"""
    run_chain(policies, update_Q, n_agents, n_states, Q_tables, epochs;
              max_steps=10, p_success=0.9, discount=0.9)

Run `n_agents` independent agents on a `ChainMDP` of length `n_states + 2` for `epochs`
episodes. Within each episode all agents advance one step at a time, in lock-step.

# Arguments
- `policies`: state -> action functions; agent `i` acts with `policies[i]`.
- `update_Q`: callback invoked as `update_Q(Q_tables[i], s, a, r, sp, t)` after each step.
- `n_agents`: number of agents; uses the first `n_agents` entries of `policies` and `Q_tables`.
- `n_states`: the chain is built with length `n_states + 2`.
- `Q_tables`: per-agent tables, passed through to `update_Q`.
- `epochs`: number of episodes to run.

# Keywords
- `max_steps=10`: per-episode step cap forwarded to `stepthrough`.
- `p_success=0.9`, `discount=0.9`: forwarded to the `ChainMDP` constructor.
"""
function run_chain(policies, update_Q, n_agents, n_states, Q_tables, epochs;
                   max_steps=10, p_success=0.9, discount=0.9)
    # ChainMDP is an immutable struct, so one instance can serve every agent and epoch.
    m = ChainMDP(n_states + 2, p_success, discount)
    for e in 1:epochs
        # One lazily-advanced trajectory per agent; `Stateful` lets us pull one step at a time.
        agents = [
            Iterators.Stateful(
                stepthrough(m, FunctionPolicy(policies[i]), "s,a,r,sp,t"; max_steps=max_steps),
            ) for i in 1:n_agents
        ]

        println("epoch: $e")
        done = false
        while !done
            for i in 1:n_agents
                if isempty(agents[i])
                    println("agent $i is done")
                    continue
                end
                # `res` is a named tuple (s, a, r, sp, t), splatted positionally into update_Q.
                res = popfirst!(agents[i])
                update_Q(Q_tables[i], res...)
                println("print agent $i result: $res")
            end
            done = all(isempty, agents)
        end
    end
    return nothing
end
# Ten epochs with the always-move-right agents and the logging update rule.
run_chain(policies, update_Q, n_agents, n_states, Q_tables, 10)

epoch: 1
s:6, a:right, sp:7, r:0, t:1
print agent 1 result: (s = 6, a = :right, r = 0, sp = 7, t = 1)
s:6, a:right, sp:7, r:0, t:1
print agent 2 result: (s = 6, a = :right, r = 0, sp = 7, t = 1)
s:6, a:right, sp:7, r:0, t:1
print agent 3 result: (s = 6, a = :right, r = 0, sp = 7, t = 1)
s:6, a:right, sp:7, r:0, t:1
print agent 4 result: (s = 6, a = :right, r = 0, sp = 7, t = 1)
s:6, a:right, sp:7, r:0, t:1
print agent 5 result: (s = 6, a = :right, r = 0, sp = 7, t = 1)
s:7, a:right, sp:6, r:0, t:2
print agent 1 result: (s = 7, a = :right, r = 0, sp = 6, t = 2)
s:7, a:right, sp:8, r:0, t:2
print agent 2 result: (s = 7, a = :right, r = 0, sp = 8, t = 2)
s:7, a:right, sp:6, r:0, t:2
print agent 3 result: (s = 7, a = :right, r = 0, sp = 6, t = 2)
s:7, a:right, sp:8, r:0, t:2
print agent 4 result: (s = 7, a = :right, r = 0, sp = 8, t = 2)
s:7, a:right, sp:8, r:0, t:2
print agent 5 result: (s = 7, a = :right, r = 0, sp = 8, t = 2)
s:6, a:right, sp:7, r:0, t:3
print agent 1 result: (s = 6, a 

s:7, a:right, sp:6, r:0, t:2
print agent 1 result: (s = 7, a = :right, r = 0, sp = 6, t = 2)
s:7, a:right, sp:8, r:0, t:2
print agent 2 result: (s = 7, a = :right, r = 0, sp = 8, t = 2)
s:7, a:right, sp:8, r:0, t:2
print agent 3 result: (s = 7, a = :right, r = 0, sp = 8, t = 2)
s:7, a:right, sp:8, r:0, t:2
print agent 4 result: (s = 7, a = :right, r = 0, sp = 8, t = 2)
s:7, a:right, sp:6, r:0, t:2
print agent 5 result: (s = 7, a = :right, r = 0, sp = 6, t = 2)
s:6, a:right, sp:7, r:0, t:3
print agent 1 result: (s = 6, a = :right, r = 0, sp = 7, t = 3)
s:8, a:right, sp:9, r:0, t:3
print agent 2 result: (s = 8, a = :right, r = 0, sp = 9, t = 3)
s:8, a:right, sp:7, r:0, t:3
print agent 3 result: (s = 8, a = :right, r = 0, sp = 7, t = 3)
s:8, a:right, sp:9, r:0, t:3
print agent 4 result: (s = 8, a = :right, r = 0, sp = 9, t = 3)
s:6, a:right, sp:7, r:0, t:3
print agent 5 result: (s = 6, a = :right, r = 0, sp = 7, t = 3)
s:7, a:right, sp:8, r:0, t:4
print agent 1 result: (s = 7, a = :right,

s:8, a:right, sp:9, r:0, t:3
print agent 3 result: (s = 8, a = :right, r = 0, sp = 9, t = 3)
s:8, a:right, sp:9, r:0, t:3
print agent 4 result: (s = 8, a = :right, r = 0, sp = 9, t = 3)
s:6, a:right, sp:7, r:0, t:3
print agent 5 result: (s = 6, a = :right, r = 0, sp = 7, t = 3)
s:9, a:right, sp:8, r:0, t:4
print agent 1 result: (s = 9, a = :right, r = 0, sp = 8, t = 4)
s:9, a:right, sp:10, r:0, t:4
print agent 2 result: (s = 9, a = :right, r = 0, sp = 10, t = 4)
s:9, a:right, sp:10, r:0, t:4
print agent 3 result: (s = 9, a = :right, r = 0, sp = 10, t = 4)
s:9, a:right, sp:10, r:0, t:4
print agent 4 result: (s = 9, a = :right, r = 0, sp = 10, t = 4)
s:7, a:right, sp:6, r:0, t:4
print agent 5 result: (s = 7, a = :right, r = 0, sp = 6, t = 4)
s:8, a:right, sp:9, r:0, t:5
print agent 1 result: (s = 8, a = :right, r = 0, sp = 9, t = 5)
s:10, a:right, sp:11, r:-10, t:5
print agent 2 result: (s = 10, a = :right, r = -10, sp = 11, t = 5)
s:10, a:right, sp:11, r:-10, t:5
print agent 3 result: (

In [45]:
"""
    square(x)

Return `x * x`.
"""
square(x) = x * x

vals = Dict(1 => 4, 2 => 5, 3 => 6)
# Same keys, squared values.
Dict(k => square(v) for (k, v) in pairs(vals))

Dict{Int64,Int64} with 3 entries:
  2 => 25
  3 => 36
  1 => 16

In [47]:
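# "UCB" agents: tabular Q-learning with a greedy action choice (no upper-confidence exploration bonus is applied yet)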
"""
    curry(f, x)

Return a function that calls `f` with `x` prepended to whatever arguments it receives.
"""
function curry(f, x)
    return (rest...) -> f(x, rest...)
end
# Fresh per-agent containers, filled by the per-agent setup loop.
Q_tables = []
policies = []
# Integer action indices (the Q-table keys) and their mapping to/from the MDP's action symbols.
actions = [1, 2]
action_map = Dict(1 => :left, 2 => :right)
rev_action_map = Dict(:left => 1, :right => 2)
# Every state of the chain, which is built with length n_states + 2.
states = 1:(n_states + 2)

"""
    ucb_pol(Q_table, actions, s)

Choose an action symbol for state `s`.

If `Q_table` has an entry for `s`, return the action whose value in `Q_table[s]` is largest.
Ties are broken by `findmax`, i.e. by the inner `Dict`'s iteration order. Otherwise pick a
uniformly random element of `actions`, log it, and return it. Indices are translated to
symbols through the global `action_map`.

NOTE(review): despite the name there is no upper-confidence bonus. When `Q_table` already
contains every state, the random branch is never taken and the policy is purely greedy.
"""
function ucb_pol(Q_table, actions, s)
    if haskey(Q_table, s)
        # `findmax` on a Dict returns (max value, key holding it); the key is the action index.
        _, idx = findmax(Q_table[s])
        return action_map[idx]
    end
    act = action_map[rand(actions)]
    # Terminate the line so the message does not run into the caller's next output line.
    println("selected action $act")
    return act
end

"""
    update_Q(Q_table, s, a, r, sp, t; alpha=0.95, gamma=0.95)

Apply one in-place tabular Q-learning update for the transition `(s, a, r, sp)`:

    Q[s][a] += alpha * (r + gamma * max(Q[sp]) - Q[s][a])

`a` is an action symbol, mapped to its table key through the global `rev_action_map`.
`t` (the step index) is accepted to match the callback signature used by `run_chain` and is
otherwise unused. `alpha` (learning rate) and `gamma` (discount) default to the previously
hard-coded 0.95.
"""
function update_Q(Q_table, s, a, r, sp, t; alpha=0.95, gamma=0.95)
    q_s = Q_table[s]
    ai = rev_action_map[a]
    # Bootstrap from the best action value of the successor state.
    best_next = maximum(values(Q_table[sp]))
    q_s[ai] += alpha * (r + gamma * best_next - q_s[ai])
end

#setup
for i in 1:n_agents
    push!(Q_tables, Dict{Int32, Dict{Int32, Float32}}())
    for state in states
      Q_tables[i][state] = Dict{Int32, Float32}()
      # print(Q_tables[i])
      for action in actions
            Q_tables[i][state][action] = 0.0
      end
    end
    push!(policies, curry(curry(ucb_pol, Q_tables[i]), actions))
end

# Train a single agent (the first Q-table/policy pair) on the 10-state chain for 100 epochs.
n_states = 10
n_agents = 1
agents = Any[]
run_chain(policies, update_Q, n_agents, n_states, Q_tables, 100)

epoch: 1
print agent 1 result: (s = 6, a = :right, r = 0, sp = 7, t = 1)
print agent 1 result: (s = 7, a = :right, r = 0, sp = 8, t = 2)
print agent 1 result: (s = 8, a = :right, r = 0, sp = 9, t = 3)
print agent 1 result: (s = 9, a = :right, r = 0, sp = 10, t = 4)
print agent 1 result: (s = 10, a = :right, r = -10, sp = 11, t = 5)
epoch: 2
print agent 1 result: (s = 6, a = :right, r = 0, sp = 7, t = 1)
print agent 1 result: (s = 7, a = :right, r = 0, sp = 8, t = 2)
print agent 1 result: (s = 8, a = :right, r = 0, sp = 9, t = 3)
print agent 1 result: (s = 9, a = :right, r = 0, sp = 10, t = 4)
print agent 1 result: (s = 10, a = :left, r = -10, sp = 9, t = 5)
print agent 1 result: (s = 9, a = :right, r = 0, sp = 10, t = 6)
print agent 1 result: (s = 10, a = :right, r = -10, sp = 11, t = 7)
epoch: 3
print agent 1 result: (s = 6, a = :right, r = 0, sp = 7, t = 1)
print agent 1 result: (s = 7, a = :right, r = 0, sp = 8, t = 2)
print agent 1 result: (s = 8, a = :right, r = 0, sp = 7, t = 3)


epoch: 18
print agent 1 result: (s = 6, a = :left, r = 0, sp = 7, t = 1)
print agent 1 result: (s = 7, a = :left, r = 0, sp = 6, t = 2)
print agent 1 result: (s = 6, a = :left, r = 0, sp = 5, t = 3)
print agent 1 result: (s = 5, a = :left, r = 0, sp = 4, t = 4)
print agent 1 result: (s = 4, a = :left, r = 0, sp = 3, t = 5)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 6)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 7)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 8)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 1, t = 9)
epoch: 19
print agent 1 result: (s = 6, a = :left, r = 0, sp = 5, t = 1)
print agent 1 result: (s = 5, a = :left, r = 0, sp = 4, t = 2)
print agent 1 result: (s = 4, a = :left, r = 0, sp = 5, t = 3)
print agent 1 result: (s = 5, a = :left, r = 0, sp = 4, t = 4)
print agent 1 result: (s = 4, a = :left, r = 0, sp = 3, t = 5)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 6)
print agent 1 result: (s = 2, a

print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 10)
epoch: 35
print agent 1 result: (s = 6, a = :left, r = 0, sp = 5, t = 1)
print agent 1 result: (s = 5, a = :left, r = 0, sp = 4, t = 2)
print agent 1 result: (s = 4, a = :left, r = 0, sp = 5, t = 3)
print agent 1 result: (s = 5, a = :left, r = 0, sp = 4, t = 4)
print agent 1 result: (s = 4, a = :left, r = 0, sp = 3, t = 5)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 6)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 7)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 8)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 9)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 10)
epoch: 36
print agent 1 result: (s = 6, a = :left, r = 0, sp = 5, t = 1)
print agent 1 result: (s = 5, a = :left, r = 0, sp = 4, t = 2)
print agent 1 result: (s = 4, a = :left, r = 0, sp = 3, t = 3)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 4)
print agent 1 result: (s = 2,

print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 4)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 5)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 6)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 7)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 8)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 9)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 10)
epoch: 52
print agent 1 result: (s = 6, a = :left, r = 0, sp = 5, t = 1)
print agent 1 result: (s = 5, a = :left, r = 0, sp = 4, t = 2)
print agent 1 result: (s = 4, a = :left, r = 0, sp = 3, t = 3)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 4)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 5)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 6)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 7)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 8)
print agent 1 result: (s = 2, a = 

print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 7)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 8)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 9)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 10)
epoch: 69
print agent 1 result: (s = 6, a = :left, r = 0, sp = 5, t = 1)
print agent 1 result: (s = 5, a = :left, r = 0, sp = 4, t = 2)
print agent 1 result: (s = 4, a = :left, r = 0, sp = 3, t = 3)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 4)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 5)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 6)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 7)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 8)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 9)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 10)
epoch: 70
print agent 1 result: (s = 6, a = :left, r = 0, sp = 5, t = 1)
print agent 1 result: (

print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 9)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 10)
epoch: 86
print agent 1 result: (s = 6, a = :left, r = 0, sp = 5, t = 1)
print agent 1 result: (s = 5, a = :left, r = 0, sp = 4, t = 2)
print agent 1 result: (s = 4, a = :left, r = 0, sp = 3, t = 3)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 4)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 5)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 6)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 7)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 8)
print agent 1 result: (s = 2, a = :right, r = 10, sp = 3, t = 9)
print agent 1 result: (s = 3, a = :left, r = 0, sp = 2, t = 10)
epoch: 87
print agent 1 result: (s = 6, a = :left, r = 0, sp = 7, t = 1)
print agent 1 result: (s = 7, a = :left, r = 0, sp = 6, t = 2)
print agent 1 result: (s = 6, a = :left, r = 0, sp = 5, t = 3)
print agent 1 result: (s 