In [1]:
using POMDPs
using Random # for AbstractRNG
using POMDPModelTools
using Pkg
Pkg.add("JSON")

[32m[1m  Updating[22m[39m registry at `~/.julia/registries/General`
[32m[1m  Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`
[32m[1m  Updating[22m[39m git-repo `https://github.com/JuliaPOMDP/Registry`
[?25l[2K[?25h[32m[1m Resolving[22m[39m package versions...
[32m[1m  Updating[22m[39m `~/.julia/environments/v1.0/Project.toml`
[90m [no changes][39m
[32m[1m  Updating[22m[39m `~/.julia/environments/v1.0/Manifest.toml`
[90m [no changes][39m


In [2]:
struct ChainMDP <: MDP{Int, Symbol}
    len::Int
    p_success::Float64
    discount::Float64
    theta::Int
end


In [3]:
function POMDPs.generate_s(p::ChainMDP, s::Int, a::Symbol, rng::AbstractRNG)
    if a == :right
        success = min(s+1, p.len)
        failure = max(s-1, 1)
    else # a == :left
        success = max(s-1, 1)
        failure = min(s+1, p.len)
    end
    if s + 1 == p.len
        return p.len 
    elseif  s == 2
        return 1
    end
    return rand(rng) < p.p_success ? success : failure
end


In [4]:
# theta = 10
function POMDPs.reward(p::ChainMDP, s::Int, a::Symbol)
    if s == 2
        return p.theta
    end
    if s + 1 == p.len
        return -p.theta
    end
    if s == 0 || s == p.len
        return 0
    end
    return -1
end

In [5]:
POMDPs.initialstate_distribution(m::ChainMDP) = Deterministic(4)

In [6]:
function POMDPs.isterminal(p::ChainMDP, s::Int)
    if s == 1
        return true
    end
    if s == p.len
        return true
    end
    return false
end        

In [7]:
using POMDPSimulators
using POMDPPolicies

ChainMDP() = ChainMDP(6+2,1.0,.9, 10)
m = ChainMDP()

# policy that maps every input to a right action
policy = FunctionPolicy(s->:right)

for (s, a, r) in stepthrough(m, policy, "s,a,r", max_steps=10)
    @show s
    @show a
    @show r
    render(m, (s,a,r))
    println()
end


s = 4
a = :right
r = -1

s = 5
a = :right
r = -1

s = 6
a = :right
r = -1

s = 7
a = :right
r = -10



In [8]:
using POMDPSimulators
using POMDPPolicies

ChainMDP() = ChainMDP(6+2,1.0,.9, 10)
m = ChainMDP()

# policy that maps every input to a left action
policy = FunctionPolicy(s->:left)

function POMDPs.initialstate_distribution(m::ChainMDP)
    return Deterministic(4)
end
for (s, a, r) in stepthrough(m, policy, "s,a,r", max_steps=10)
    # @show s
    # @show a
    # @show r
    render(m, (s,a,r))
    println("s,a,r:($s,$a,$r)")
end


s,a,r:(4,left,-1)
s,a,r:(3,left,-1)
s,a,r:(2,left,10)


In [9]:

ChainMDP() = ChainMDP(6+2,1.0,.9, 10)
m = ChainMDP()

# policy that maps every input to a right action
policy = FunctionPolicy(s->:right)

function POMDPs.initialstate_distribution(m::ChainMDP)
    return Deterministic(4)
end
for (s, a, r) in stepthrough(m, policy, "s,a,r", max_steps=10)
    # @show s
    # @show a
    # @show r
    render(m, (s,a,r))
    println("s,a,r:($s,$a,$r)")
end


s,a,r:(4,right,-1)
s,a,r:(5,right,-1)
s,a,r:(6,right,-1)
s,a,r:(7,right,-10)


In [10]:
n_agents = 1
agents = Any[]
mdps = Any[]
n_states = 6
n_actions = 2
epochs = 10
H = 10
actions = [1, 2]
action_map = Dict(1 => :left, :2 => :right)
rev_action_map = Dict(:left => 1, :right => 2)
states = 1:(n_states+2)
# setup agents
Q_tables = []
N_tables = zeros((n_agents, n_states+2, n_actions))
print(N_tables)
policies = []
theta = 10
# policy that maps every input to a right action
for i in 1:n_agents
    push!(Q_tables, Dict{Int32,Float32}())
    push!(policies, s->:right)
    push!(mdps, ChainMDP(n_states+2,.9,.9, theta))
end
# policy = s->:right

function update_Q(Q_table, s, a, r, sp, t)
    alpha=.95
    gamma=.95
    # print("s:$s, a:$a, sp:$sp, r:$r")
    Q_table[s][rev_action_map[a]] +=  alpha * (r + gamma * 
        findmax(Q_table[sp])[1] - Q_table[s][rev_action_map[a]])
    #println(Q_table)
end
function POMDPs.initialstate_distribution(m::ChainMDP)
    return Deterministic(Int64((n_states+2)/2))
end
function run_chain!(;policies, mdps, reward_reveal_condition, true_mdp, update_Q, n_agents, n_states,
                    Q_tables, N_tables, epochs, steps)
    for e in 1:epochs
        agents = []
        for i in 1:n_agents
            m = mdps[i]
            push!(agents, 
                  Iterators.Stateful(stepthrough(m, FunctionPolicy(policies[i]), "s,a,r,sp,t", max_steps=steps)))
        end

        println("epoch: $e")
        done = false
        t = 0
        while ! done
           done = true
            
           for i in 1:n_agents
                if isempty(agents[i])
                    println("agent $i is done")
                    continue
                end
                res = popfirst!(agents[i])
                r = res[:r]
                if reward_reveal_condition(r)
                    for mdp in mpds
                        mdp = true_mdp
                    end
                end
                #println("before update: N_table: $N_tables")
                N_tables[i,res[:s],rev_action_map[res[:a]]] += 1
                #println("after update:  N_table: $N_tables")
                # println("update Q")
                update_Q(Q_tables[i],res...)
                t = res[:t]
                println("t: $t, print agent $i result: $res")
          
           end
           # println("t:$t")
           for i in 1:n_agents
                if ! isempty(agents[i])
                    done = false
                end
            end
        end
    end
end
run_chain!(policies=policies,
           mdps=mdps,  
           true_mdp=mdps[1],
           reward_reveal_condition= r -> false,
           update_Q=update_Q,
           n_agents=n_agents,
           n_states=n_states,
           Q_tables=Q_tables,
           N_tables=N_tables,
           epochs=epochs,
           steps=H)

[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]

[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]epoch: 1


KeyError: KeyError: key 4 not found

In [11]:
function square(x)
    return x * x
end

vals = Dict(1 => 4, 2 => 5, 3 => 6)
Dict(key=> square(val) for (key, val) in vals)

Dict{Int64,Int64} with 3 entries:
  2 => 25
  3 => 36
  1 => 16

In [12]:
A = [1 2; 3 4; 5 6]
B = A .+ 6
C = B .+ 6
D = hcat([A , B , C])

3×1 Array{Array{Int64,2},2}:
 [1 2; 3 4; 5 6]      
 [7 8; 9 10; 11 12]   
 [13 14; 15 16; 17 18]

In [13]:
D[2,:]

1-element Array{Array{Int64,2},1}:
 [7 8; 9 10; 11 12]

In [57]:
function ucb_pol(Q_tables, N_tables, i, actions, s)
    # return 
    C = 10.0
    if haskey(Q_tables[i], s)
       # println("ucb_pol:s:$s, N_tables:$N_tables")
       ucbs = Dict(key => sqrt(log(sum(N_tables[i][s,:])/N_tables[i][s,key])) 
                for (key,val) in Q_tables[i][s])
       # println("ucb: $(ucbs)")
       val, idx = findmax(Dict(key => Q_tables[i][s][key] + C* sqrt(log(sum(N_tables[i][s,:])/N_tables[i][s,key])) 
                for (key,val) in Q_tables[i][s])) # need to test
       return action_map[idx]
    else
       act = action_map[rand(actions, 1)[1]]
       println("selected random action $act!!!!!!!!!!!!!!!!!!!!!") 
       return act
    end
end




ucb_pol (generic function with 1 method)

In [28]:
function ran_pol(Q_tables, N_tables, i, actions, s)
    # return 
    act = action_map[rand(actions, 1)[1]]
    return act
end

ran_pol (generic function with 1 method)

In [15]:
curry(f, a) = (xs...) -> f(a, xs...)
curry2(f, a, b) = (xs...) -> f(a, b, xs...)
curry3(f, a, b, c) = (xs...) -> f(a, b, c, xs...)
curry4(f, a, b, c, d) = (xs...) -> f(a, b, c, d, xs...)
curry5(f, a, b, c, d, e) = (xs...) -> f(a, b, c, d, e, xs...)
curry6(f, a, b, c, d, e, g) = (xs...) -> f(a, b, c, d, e, g, xs...)

curry6 (generic function with 1 method)

In [16]:
n_states = 10
curry(f, x) = (xs...) -> f(x, xs...)
Q_tables = []
mdps = Any[]
N_tables = zeros((n_agents, n_states+2, n_actions))
policies = []
epochs = 500
theta = 10
H = 10
# Setup Q table according to MDP

n_agents = 1
agents = Any[]
n_states = 10


10

In [17]:
function POMDPs.reward(p::ChainMDP, s::Int, a::Symbol)
    if s == 2
        return 10
    end
    if s + 1 == p.len
        return -10
    end
    if s == 0 || s == p.len
        return 0
    end
    return -1
end

In [18]:
# UCB




#setup
for i in 1:n_agents
    push!(Q_tables, Dict{Int32, Dict{Int32, Float32}}())
    for state in states
      Q_tables[i][state] = Dict{Int32, Float32}()
      # print(Q_tables[i])
      for action in actions
         Q_tables[i][state][action] = 0 
      end
    end
    N_table = N_tables[i,:,:]
    push!(policies, curry(curry(curry(curry(ucb_pol, Q_tables), N_tables),i), actions))
    push!(mdps, ChainMDP(n_states+2,.9,.9, 10))
end

run_chain!(policies, mdps, theta, update_Q, n_agents,  n_states, Q_tables, N_tables, epochs, H)

MethodError: MethodError: no method matching run_chain!(::Array{Any,1}, ::Array{Any,1}, ::Int64, ::typeof(update_Q), ::Int64, ::Int64, ::Array{Any,1}, ::Array{Float64,3}, ::Int64, ::Int64)

In [19]:

n_agents = 10
Q_tables = [deepcopy(Q_tables[1]) for i in range(1,n_agents)]
policies = [deepcopy(policies[1]) for i in range(1,n_agents)]
mdps = [deepcopy(mdps[1]) for i in range(1,n_agents)]
N_tables = zeros((n_agents, n_states+2, n_actions))
run_chain!(policies, mdps, update_Q, n_agents,  n_states, Q_tables, N_tables, epochs, H)

MethodError: MethodError: no method matching run_chain!(::Array{getfield(Main, Symbol("##32#33")){getfield(Main, Symbol("##32#33")){getfield(Main, Symbol("##32#33")){getfield(Main, Symbol("##32#33")){typeof(ucb_pol),Array{Any,1}},Array{Float64,3}},Int64},Array{Int64,1}},1}, ::Array{ChainMDP,1}, ::typeof(update_Q), ::Int64, ::Int64, ::Array{Dict{Int32,Dict{Int32,Float32}},1}, ::Array{Float64,3}, ::Int64, ::Int64)

In [20]:
# paramter for seed sampling
using Distributions
theta = 10 * sign(rand(Bernoulli(0.5))-.5)
ntheta = - theta
print("theta: $theta, ntheta: $ntheta")
# Thompson sampling
#Should just be seed sampling, but you do it every step

theta: 10.0, ntheta: -10.0

In [21]:

using POMDPs
include("./ChainMDP.jl")
num_states = 10

mdp = PFChainMDP.PChainMDP(num_states+2,.9,.9, 10)
POMDPs.reward(mdp, 1, :left)

-1

In [22]:
function update_Q(Q_table, s, a, r, sp, t, rev_action_map)
    alpha=.95
    gamma=.95
    # print("s:$s, a:$a, sp:$sp, r:$r")
    Q_table[s][rev_action_map[a]] +=  alpha * (r + gamma * 
        findmax(Q_table[sp])[1] - Q_table[s][rev_action_map[a]])
    #println(Q_table)
end

update_Q (generic function with 2 methods)

In [72]:
mdps = [deepcopy(mdp) for i in 1:2]
isequal(mdps[1], mdps[2])

mdp1 = PFChainMDP.PChainMDP(num_states+2,.9,.9, 10)

mdp2 = PFChainMDP.PChainMDP(num_states+2,.9,.9, 10)
mdp1 === mdp2

true

In [74]:
# UCB
using Distributions
include("./Agent.jl")
num_agents = 5
agents = Any[]
num_states = 20
num_actions = 2
epochs = 30
H = floor(Int, 3 * num_states / 2)
actions = [1, 2]
action_map = Dict(1 => :left, :2 => :right)
rev_action_map = Dict(:left => 1, :right => 2)
states = 1:(num_states+2)

theta = 10
# Setup Q table according to MDP



#setup
#agmdp = PFChainMDP.PChainMDP(num_states+2,1.0,.9, theta * sign(rand(Bernoulli(0.5))-.5))
(Q_tables, N_tables, policies) = PFAgent.setup_agents(states, num_states, num_agents,
                                                            actions, num_actions, ucb_pol)
true_mdp = PFChainMDP.PChainMDP(num_states+2,1.0,.9, theta)

function chain_found_target(r)
   if r == -10 || r == 10
        return true
   end
   return false
end

PFAgent.run_chain!(
           policies=policies,
           found_target=chain_found_target,
           mdps=[deepcopy(true_mdp) for i in 1:num_agents],
           update_Q=update_Q,
           n_agents=num_agents,
           n_states=num_states,
           Q_tables=Q_tables,
           N_tables=N_tables,
           epochs=epochs,
           steps=H,
           rev_action_map=rev_action_map,
           stop_early=false)
print(Q_tables)

agent 1 is done
agent 1 is done
agent 2 is done
agent 1 is done
agent 2 is done
agent 3 is done
agent 1 is done
agent 2 is done
agent 3 is done
agent 4 is done
agent 1 is done
agent 1 is done
agent 2 is done
agent 1 is done
agent 2 is done
agent 1 is done
agent 2 is done
agent 3 is done
agent 1 is done
agent 2 is done
agent 3 is done
agent 4 is done
e: 3, t: 1, agent 1, result: (s = 12, a = :right, r = -1, sp = 13, t = 1)
e: 3, t: 2, agent 1, result: (s = 13, a = :left, r = -1, sp = 12, t = 2)
e: 3, t: 1, agent 2, result: (s = 12, a = :right, r = -1, sp = 13, t = 1)
e: 3, t: 3, agent 1, result: (s = 12, a = :right, r = -1, sp = 13, t = 3)
e: 3, t: 2, agent 2, result: (s = 13, a = :left, r = -1, sp = 12, t = 2)
e: 3, t: 1, agent 3, result: (s = 12, a = :right, r = -1, sp = 13, t = 1)
e: 3, t: 4, agent 1, result: (s = 13, a = :right, r = -1, sp = 14, t = 4)
e: 3, t: 3, agent 2, result: (s = 12, a = :right, r = -1, sp = 13, t = 3)
e: 3, t: 2, agent 3, result: (s = 13, a = :right, r = -1, 

e: 3, t: 26, agent 4, result: (s = 15, a = :right, r = -1, sp = 16, t = 26)
e: 3, t: 25, agent 5, result: (s = 16, a = :right, r = -1, sp = 17, t = 25)
e: 3, t: 30, agent 1, result: (s = 9, a = :right, r = -1, sp = 10, t = 30)
e: 3, t: 29, agent 2, result: (s = 16, a = :right, r = -1, sp = 17, t = 29)
e: 3, t: 28, agent 3, result: (s = 9, a = :right, r = -1, sp = 10, t = 28)
e: 3, t: 27, agent 4, result: (s = 16, a = :right, r = -1, sp = 17, t = 27)
e: 3, t: 26, agent 5, result: (s = 17, a = :right, r = -1, sp = 18, t = 26)
agent 1 is done
e: 3, t: 30, agent 2, result: (s = 17, a = :left, r = -1, sp = 16, t = 30)
e: 3, t: 29, agent 3, result: (s = 10, a = :right, r = -1, sp = 11, t = 29)
e: 3, t: 28, agent 4, result: (s = 17, a = :left, r = -1, sp = 16, t = 28)
e: 3, t: 27, agent 5, result: (s = 18, a = :left, r = -1, sp = 17, t = 27)
agent 1 is done
agent 2 is done
e: 3, t: 30, agent 3, result: (s = 11, a = :right, r = -1, sp = 12, t = 30)
e: 3, t: 29, agent 4, result: (s = 16, a = :l

e: 6, t: 23, agent 4, result: (s = 14, a = :right, r = -1, sp = 15, t = 23)
e: 6, t: 22, agent 5, result: (s = 17, a = :right, r = -1, sp = 18, t = 22)
agent 1 is done
e: 6, t: 26, agent 2, result: (s = 7, a = :left, r = -1, sp = 6, t = 26)
e: 6, t: 25, agent 3, result: (s = 6, a = :right, r = -1, sp = 7, t = 25)
e: 6, t: 24, agent 4, result: (s = 15, a = :left, r = -1, sp = 14, t = 24)
e: 6, t: 23, agent 5, result: (s = 18, a = :right, r = -1, sp = 19, t = 23)
agent 1 is done
e: 6, t: 27, agent 2, result: (s = 6, a = :right, r = -1, sp = 7, t = 27)
e: 6, t: 26, agent 3, result: (s = 7, a = :right, r = -1, sp = 8, t = 26)
e: 6, t: 25, agent 4, result: (s = 14, a = :left, r = -1, sp = 13, t = 25)
e: 6, t: 24, agent 5, result: (s = 19, a = :right, r = -1, sp = 20, t = 24)
agent 1 is done
e: 6, t: 28, agent 2, result: (s = 7, a = :right, r = -1, sp = 8, t = 28)
e: 6, t: 27, agent 3, result: (s = 8, a = :left, r = -1, sp = 7, t = 27)
e: 6, t: 26, agent 4, result: (s = 13, a = :left, r = -1

e: 9, t: 24, agent 1, result: (s = 7, a = :right, r = -1, sp = 8, t = 24)
e: 9, t: 23, agent 2, result: (s = 8, a = :left, r = -1, sp = 7, t = 23)
e: 9, t: 22, agent 3, result: (s = 7, a = :right, r = -1, sp = 8, t = 22)
e: 9, t: 21, agent 4, result: (s = 8, a = :left, r = -1, sp = 7, t = 21)
e: 9, t: 20, agent 5, result: (s = 11, a = :right, r = -1, sp = 12, t = 20)
e: 9, t: 25, agent 1, result: (s = 8, a = :left, r = -1, sp = 7, t = 25)
e: 9, t: 24, agent 2, result: (s = 7, a = :left, r = -1, sp = 6, t = 24)
e: 9, t: 23, agent 3, result: (s = 8, a = :right, r = -1, sp = 9, t = 23)
e: 9, t: 22, agent 4, result: (s = 7, a = :right, r = -1, sp = 8, t = 22)
e: 9, t: 21, agent 5, result: (s = 12, a = :left, r = -1, sp = 11, t = 21)
e: 9, t: 26, agent 1, result: (s = 7, a = :left, r = -1, sp = 6, t = 26)
e: 9, t: 25, agent 2, result: (s = 6, a = :left, r = -1, sp = 5, t = 25)
e: 9, t: 24, agent 3, result: (s = 9, a = :right, r = -1, sp = 10, t = 24)
e: 9, t: 23, agent 4, result: (s = 8, a 

e: 12, t: 16, agent 4, result: (s = 5, a = :right, r = -1, sp = 6, t = 16)
e: 12, t: 15, agent 5, result: (s = 10, a = :left, r = -1, sp = 9, t = 15)
e: 12, t: 20, agent 1, result: (s = 15, a = :right, r = -1, sp = 16, t = 20)
e: 12, t: 19, agent 2, result: (s = 4, a = :right, r = -1, sp = 5, t = 19)
e: 12, t: 18, agent 3, result: (s = 5, a = :left, r = -1, sp = 4, t = 18)
e: 12, t: 17, agent 4, result: (s = 6, a = :right, r = -1, sp = 7, t = 17)
e: 12, t: 16, agent 5, result: (s = 9, a = :left, r = -1, sp = 8, t = 16)
e: 12, t: 21, agent 1, result: (s = 16, a = :left, r = -1, sp = 15, t = 21)
e: 12, t: 20, agent 2, result: (s = 5, a = :left, r = -1, sp = 4, t = 20)
e: 12, t: 19, agent 3, result: (s = 4, a = :left, r = -1, sp = 3, t = 19)
e: 12, t: 18, agent 4, result: (s = 7, a = :left, r = -1, sp = 6, t = 18)
e: 12, t: 17, agent 5, result: (s = 8, a = :left, r = -1, sp = 7, t = 17)
e: 12, t: 22, agent 1, result: (s = 15, a = :left, r = -1, sp = 14, t = 22)
e: 12, t: 21, agent 2, resu

e: 15, t: 15, agent 3, result: (s = 10, a = :right, r = -1, sp = 11, t = 15)
e: 15, t: 14, agent 4, result: (s = 19, a = :right, r = -1, sp = 20, t = 14)
e: 15, t: 13, agent 5, result: (s = 8, a = :left, r = -1, sp = 7, t = 13)
e: 15, t: 18, agent 1, result: (s = 15, a = :left, r = -1, sp = 14, t = 18)
e: 15, t: 17, agent 2, result: (s = 10, a = :right, r = -1, sp = 11, t = 17)
e: 15, t: 16, agent 3, result: (s = 11, a = :left, r = -1, sp = 10, t = 16)
e: 15, t: 15, agent 4, result: (s = 20, a = :left, r = -1, sp = 19, t = 15)
e: 15, t: 14, agent 5, result: (s = 7, a = :left, r = -1, sp = 6, t = 14)
e: 15, t: 19, agent 1, result: (s = 14, a = :left, r = -1, sp = 13, t = 19)
e: 15, t: 18, agent 2, result: (s = 11, a = :left, r = -1, sp = 10, t = 18)
e: 15, t: 17, agent 3, result: (s = 10, a = :left, r = -1, sp = 9, t = 17)
e: 15, t: 16, agent 4, result: (s = 19, a = :left, r = -1, sp = 18, t = 16)
e: 15, t: 15, agent 5, result: (s = 6, a = :right, r = -1, sp = 7, t = 15)
e: 15, t: 20, a

e: 18, t: 13, agent 1, result: (s = 8, a = :right, r = -1, sp = 9, t = 13)
e: 18, t: 12, agent 2, result: (s = 11, a = :left, r = -1, sp = 10, t = 12)
e: 18, t: 11, agent 3, result: (s = 10, a = :right, r = -1, sp = 11, t = 11)
e: 18, t: 10, agent 4, result: (s = 11, a = :left, r = -1, sp = 10, t = 10)
e: 18, t: 9, agent 5, result: (s = 12, a = :left, r = -1, sp = 11, t = 9)
e: 18, t: 14, agent 1, result: (s = 9, a = :right, r = -1, sp = 10, t = 14)
e: 18, t: 13, agent 2, result: (s = 10, a = :left, r = -1, sp = 9, t = 13)
e: 18, t: 12, agent 3, result: (s = 11, a = :right, r = -1, sp = 12, t = 12)
e: 18, t: 11, agent 4, result: (s = 10, a = :left, r = -1, sp = 9, t = 11)
e: 18, t: 10, agent 5, result: (s = 11, a = :right, r = -1, sp = 12, t = 10)
e: 18, t: 15, agent 1, result: (s = 10, a = :right, r = -1, sp = 11, t = 15)
e: 18, t: 14, agent 2, result: (s = 9, a = :left, r = -1, sp = 8, t = 14)
e: 18, t: 13, agent 3, result: (s = 12, a = :right, r = -1, sp = 13, t = 13)
e: 18, t: 12, 



e: 21, t: 9, agent 1, result: (s = 12, a = :right, r = -1, sp = 13, t = 9)
e: 21, t: 8, agent 2, result: (s = 7, a = :right, r = -1, sp = 8, t = 8)
e: 21, t: 7, agent 3, result: (s = 14, a = :left, r = -1, sp = 13, t = 7)
e: 21, t: 6, agent 4, result: (s = 13, a = :right, r = -1, sp = 14, t = 6)
e: 21, t: 5, agent 5, result: (s = 8, a = :left, r = -1, sp = 7, t = 5)
e: 21, t: 10, agent 1, result: (s = 13, a = :right, r = -1, sp = 14, t = 10)
e: 21, t: 9, agent 2, result: (s = 8, a = :left, r = -1, sp = 7, t = 9)
e: 21, t: 8, agent 3, result: (s = 13, a = :left, r = -1, sp = 12, t = 8)
e: 21, t: 7, agent 4, result: (s = 14, a = :left, r = -1, sp = 13, t = 7)
e: 21, t: 6, agent 5, result: (s = 7, a = :right, r = -1, sp = 8, t = 6)
e: 21, t: 11, agent 1, result: (s = 14, a = :left, r = -1, sp = 13, t = 11)
e: 21, t: 10, agent 2, result: (s = 7, a = :left, r = -1, sp = 6, t = 10)
e: 21, t: 9, agent 3, result: (s = 12, a = :left, r = -1, sp = 11, t = 9)
e: 21, t: 8, agent 4, result: (s = 13

e: 24, t: 5, agent 1, result: (s = 8, a = :left, r = -1, sp = 7, t = 5)
e: 24, t: 4, agent 2, result: (s = 15, a = :right, r = -1, sp = 16, t = 4)
e: 24, t: 3, agent 3, result: (s = 14, a = :right, r = -1, sp = 15, t = 3)
e: 24, t: 2, agent 4, result: (s = 13, a = :right, r = -1, sp = 14, t = 2)
e: 24, t: 1, agent 5, result: (s = 12, a = :right, r = -1, sp = 13, t = 1)
e: 24, t: 6, agent 1, result: (s = 7, a = :left, r = -1, sp = 6, t = 6)
e: 24, t: 5, agent 2, result: (s = 16, a = :left, r = -1, sp = 15, t = 5)
e: 24, t: 4, agent 3, result: (s = 15, a = :right, r = -1, sp = 16, t = 4)
e: 24, t: 3, agent 4, result: (s = 14, a = :right, r = -1, sp = 15, t = 3)
e: 24, t: 2, agent 5, result: (s = 13, a = :right, r = -1, sp = 14, t = 2)
e: 24, t: 7, agent 1, result: (s = 6, a = :left, r = -1, sp = 5, t = 7)
e: 24, t: 6, agent 2, result: (s = 15, a = :left, r = -1, sp = 14, t = 6)
e: 24, t: 5, agent 3, result: (s = 16, a = :left, r = -1, sp = 15, t = 5)
e: 24, t: 4, agent 4, result: (s = 15

e: 27, t: 3, agent 1, result: (s = 10, a = :left, r = -1, sp = 9, t = 3)
e: 27, t: 2, agent 2, result: (s = 13, a = :left, r = -1, sp = 12, t = 2)
e: 27, t: 1, agent 3, result: (s = 12, a = :right, r = -1, sp = 13, t = 1)
e: 27, t: 4, agent 1, result: (s = 9, a = :left, r = -1, sp = 8, t = 4)
e: 27, t: 3, agent 2, result: (s = 12, a = :left, r = -1, sp = 11, t = 3)
e: 27, t: 2, agent 3, result: (s = 13, a = :right, r = -1, sp = 14, t = 2)
e: 27, t: 1, agent 4, result: (s = 12, a = :left, r = -1, sp = 11, t = 1)
e: 27, t: 5, agent 1, result: (s = 8, a = :left, r = -1, sp = 7, t = 5)
e: 27, t: 4, agent 2, result: (s = 11, a = :left, r = -1, sp = 10, t = 4)
e: 27, t: 3, agent 3, result: (s = 14, a = :right, r = -1, sp = 15, t = 3)
e: 27, t: 2, agent 4, result: (s = 11, a = :left, r = -1, sp = 10, t = 2)
e: 27, t: 1, agent 5, result: (s = 12, a = :right, r = -1, sp = 13, t = 1)
e: 27, t: 6, agent 1, result: (s = 7, a = :left, r = -1, sp = 6, t = 6)
e: 27, t: 5, agent 2, result: (s = 10, a 

e: 27, t: 30, agent 5, result: (s = 15, a = :left, r = -1, sp = 14, t = 30)
agent 3 is done
agent 3 is done
agent 3 is done
agent 3 is done
agent 3 is done
agent 3 is done
agent 3 is done
agent 3 is done
agent 3 is done
agent 3 is done
agent 3 is done
agent 3 is done
agent 3 is done
agent 3 is done
agent 5 is done
agent 1 is done
agent 3 is done
agent 5 is done
agent 1 is done
agent 2 is done
agent 3 is done
agent 5 is done
agent 1 is done
agent 2 is done
agent 3 is done
agent 5 is done
agent 1 is done
agent 1 is done
agent 2 is done
agent 3 is done
agent 1 is done
agent 2 is done
agent 3 is done
agent 1 is done
agent 2 is done
agent 3 is done
agent 4 is done
e: 30, t: 1, agent 1, result: (s = 12, a = :right, r = -1, sp = 13, t = 1)
e: 30, t: 2, agent 1, result: (s = 13, a = :left, r = -1, sp = 12, t = 2)
e: 30, t: 1, agent 2, result: (s = 12, a = :right, r = -1, sp = 13, t = 1)
e: 30, t: 3, agent 1, result: (s = 12, a = :left, r = -1, sp = 11, t = 3)
e: 30, t: 2, agent 2, result: (s =

e: 30, t: 26, agent 3, result: (s = 9, a = :left, r = -1, sp = 8, t = 26)
e: 30, t: 25, agent 4, result: (s = 12, a = :right, r = -1, sp = 13, t = 25)
e: 30, t: 24, agent 5, result: (s = 13, a = :left, r = -1, sp = 12, t = 24)
e: 30, t: 29, agent 1, result: (s = 6, a = :left, r = -1, sp = 5, t = 29)
e: 30, t: 28, agent 2, result: (s = 13, a = :right, r = -1, sp = 14, t = 28)
e: 30, t: 27, agent 3, result: (s = 8, a = :left, r = -1, sp = 7, t = 27)
e: 30, t: 26, agent 4, result: (s = 13, a = :right, r = -1, sp = 14, t = 26)
e: 30, t: 25, agent 5, result: (s = 12, a = :left, r = -1, sp = 11, t = 25)
e: 30, t: 30, agent 1, result: (s = 5, a = :left, r = -1, sp = 4, t = 30)
e: 30, t: 29, agent 2, result: (s = 14, a = :right, r = -1, sp = 15, t = 29)
e: 30, t: 28, agent 3, result: (s = 7, a = :left, r = -1, sp = 6, t = 28)
e: 30, t: 27, agent 4, result: (s = 14, a = :left, r = -1, sp = 13, t = 27)
e: 30, t: 26, agent 5, result: (s = 11, a = :left, r = -1, sp = 10, t = 26)
agent 1 is done
e:

In [85]:

include("./Agent.jl")
# Thompson sampling

# Solution

# solve mdp in both directions
# Then each agent will pick on of the policies.
# common
num_agents = 1
agents = Any[]
num_states = 10
num_actions = 2
epochs = 100
H = floor(Int, 3 * num_states / 2)
actions = [1, 2]
action_map = Dict(1 => :left, :2 => :right)
rev_action_map = Dict(:left => 1, :right => 2)
states = 1:(num_states+2)


theta = 10


(Q_tables1, N_tables1, policies1) = PFAgent.setup_agents(states, num_states, num_agents,
                                                            actions, num_actions, ucb_pol)

true_mdp = PFChainMDP.PChainMDP(num_states+2,1.0,.9, theta)
PFAgent.run_chain!(policies=policies1,
           found_target=chain_found_target,
           mdps=[true_mdp],
           update_Q=update_Q,
           n_agents=num_agents,
           n_states=num_states,
           Q_tables=Q_tables1,
           N_tables=N_tables1,
           epochs=epochs,
           steps=H,
           rev_action_map=rev_action_map,
           stop_early=false)

# theta = -10
using JSON
println("Q1:$(json(Q_tables1,2))")
println("N1:$(json(N_tables1,2))")

true_mdp2 = PFChainMDP.PChainMDP(num_states+2,1.0,.9, -theta)
(Q_tables2, N_tables2, policies2) = PFAgent.setup_agents(states, num_states, num_agents,
                                                            actions, num_actions, ucb_pol)
PFAgent.run_chain!(policies=policies2,
           found_target=chain_found_target,
           true_mdp=true_mdp2,
           update_Q=update_Q,
           n_agents=num_agents,
           n_states=num_states,
           Q_tables=Q_tables2,
           N_tables=N_tables2,
           epochs=epochs,
           steps=H,
           rev_action_map=rev_action_map,
           stop_early=false)



println("Q2: $(json(Q_tables2[1],2))")
println("N2: $(json(N_tables2[1],2))")

e: 10, t: 1, agent 1, result: (s = 7, a = :right, r = -1, sp = 8, t = 1)
e: 10, t: 2, agent 1, result: (s = 8, a = :left, r = -1, sp = 7, t = 2)
e: 10, t: 3, agent 1, result: (s = 7, a = :left, r = -1, sp = 6, t = 3)
e: 10, t: 4, agent 1, result: (s = 6, a = :left, r = -1, sp = 5, t = 4)
e: 10, t: 5, agent 1, result: (s = 5, a = :right, r = -1, sp = 6, t = 5)
e: 10, t: 6, agent 1, result: (s = 6, a = :right, r = -1, sp = 7, t = 6)
e: 10, t: 7, agent 1, result: (s = 7, a = :left, r = -1, sp = 6, t = 7)
e: 10, t: 8, agent 1, result: (s = 6, a = :left, r = -1, sp = 5, t = 8)
e: 10, t: 9, agent 1, result: (s = 5, a = :left, r = -1, sp = 4, t = 9)
e: 10, t: 10, agent 1, result: (s = 4, a = :right, r = -1, sp = 5, t = 10)
e: 10, t: 11, agent 1, result: (s = 5, a = :left, r = -1, sp = 4, t = 11)
e: 10, t: 12, agent 1, result: (s = 4, a = :left, r = -1, sp = 3, t = 12)
e: 10, t: 13, agent 1, result: (s = 3, a = :left, r = -1, sp = 2, t = 13)
e: 10, t: 14, agent 1, result: (s = 2, a = :right, r



UndefKeywordError: UndefKeywordError: keyword argument mdps not assigned

In [86]:
# Now we need a policy which randomly picks between the 2
function thomp_pol_func(Q_tables, N_tables, i, actions, s)
    # passed in i ignored, kept for consistency
    i = rand(1:length(Q_tables))
    println("i:$i")
    # return 
    if haskey(Q_tables[i], s)
       # println("ucb: $(ucbs)")
       val, idx = findmax(Q_tables[i][s])
       println("Selected $val, $idx from $(Q_tables[i][s]) for $i, $s")
       #print("$(Q_tables[i][s])")
       return action_map[idx]
    else
       act = action_map[rand(actions, 1)[1]]
       println("random action $act") 
       return act
    end
end

thomp_pol_func (generic function with 1 method)

In [98]:
#Now we run with that policy
include("./Agent.jl")
function run_thompson_chain_simulations(nruns, Q_tables1, Q_tables2, N_tables1, N_tables2)
    runs = []
    for i in 1:nruns
        num_agents = 1
        agents = Any[]
        num_states = 10
        num_actions = 2
        epochs=1
        num_agents = 2
        thomp_policies = []
        H = floor(Int, 3 * num_states / 2)
        actions = [1, 2]
        action_map = Dict(1 => :left, :2 => :right)
        rev_action_map = Dict(:left => 1, :right => 2)
        states = 1:(num_states+2)
        print("Before q's")
        true_mdp = PFChainMDP.PChainMDP(num_states+2,1.0,.9, theta)
        Q_tables_thomp = Dict(1 => Q_tables1[1], 2 => Q_tables2[1])
        N_tables_thomp = Dict(1 => N_tables1[1], 2 => N_tables2[1])
        println("Q_T: $Q_tables_thomp")
        println("N_T: $N_tables_thomp")
        push!(thomp_policies, curry(curry(curry(curry(thomp_pol_func, Q_tables_thomp), N_tables_thomp),0), actions))
        push!(thomp_policies, curry(curry(curry(curry(thomp_pol_func, Q_tables_thomp), N_tables_thomp),0), actions))
        println("after policies")
        (Q_tables3, N_tables3, trash) = PFAgent.setup_agents(states, num_states, num_agents,
                                                                    actions, num_actions, ucb_pol)
        r_history = PFAgent.run_chain!(
                   policies=thomp_policies,
                   found_target=chain_found_target,
                   mdps=[true_mdp, true_mdp],
                   update_Q=update_Q,
                   n_agents=num_agents,
                   n_states=num_states,
                   Q_tables=Q_tables3,
                   N_tables=N_tables3,
                   epochs=epochs,
                   steps=H,
                   rev_action_map=rev_action_map,
                   stop_early=true)

        R = (num_states - 2) / 2
        println([r for (e,i,t,r) in r_history if i == 1])
        println([r for (e,i,t,r) in r_history if i == 2])
        reg_a1 = R - sum([r for (e,i,t,r) in r_history if i == 1])
        reg_a2 = R - sum([r for (e,i,t,r) in r_history if i == 2])
        push!(runs, (run, reg_a1, reg_a2))
    end
    return runs
end



run_thompson_chain_simulations (generic function with 1 method)

In [100]:
println(run_thompson_chain_simulations(1, Q_tables1, Q_tables2, N_tables1, N_tables2))

Before q'sQ_T: Dict(2=>Dict(2=>Dict(2=>0.0,1=>0.0),11=>Dict(2=>0.0,1=>0.0),7=>Dict(2=>0.0,1=>0.0),9=>Dict(2=>0.0,1=>0.0),10=>Dict(2=>0.0,1=>0.0),8=>Dict(2=>0.0,1=>0.0),6=>Dict(2=>0.0,1=>0.0),4=>Dict(2=>0.0,1=>0.0),3=>Dict(2=>0.0,1=>0.0),5=>Dict(2=>0.0,1=>0.0),12=>Dict(2=>0.0,1=>0.0),1=>Dict(2=>0.0,1=>0.0)),1=>Dict(2=>Dict(2=>10.0,1=>10.0),11=>Dict(2=>-9.99875,1=>-9.975),7=>Dict(2=>0.950118,1=>3.21343),9=>Dict(2=>-1.09252,1=>0.950118),10=>Dict(2=>-10.4526,1=>-0.0973877),8=>Dict(2=>-0.0973877,1=>2.05276),6=>Dict(2=>2.05276,1=>4.43519),4=>Dict(2=>4.43519,1=>7.075),3=>Dict(2=>5.72125,1=>8.5),5=>Dict(2=>3.21343,1=>5.72125),12=>Dict(2=>0.0,1=>0.0),1=>Dict(2=>0.0,1=>0.0)))
N_T: Dict(2=>[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0],1=>[0.0 0.0; 19.0 20.0; 47.0 29.0; 76.0 49.0; 110.0 72.0; 133.0 90.0; 165.0 114.0; 89.0 62.0; 53.0 37.0; 28.0 5.0; 2.0 3.0; 0.0 0.0])
after policies


BoundsError: BoundsError: attempt to access (Any[Dict(2=>Dict(2=>0.0,1=>0.0),11=>Dict(2=>0.0,1=>0.0),7=>Dict(2=>0.0,1=>0.0),9=>Dict(2=>0.0,1=>0.0),10=>Dict(2=>0.0,1=>0.0),8=>Dict(2=>0.0,1=>0.0),6=>Dict(2=>0.0,1=>0.0),4=>Dict(2=>0.0,1=>0.0),3=>Dict(2=>0.0,1=>0.0),5=>Dict(2=>0.0,1=>0.0),12=>Dict(2=>0.0,1=>0.0),1=>Dict(2=>0.0,1=>0.0)), Dict(2=>Dict(2=>0.0,1=>0.0),11=>Dict(2=>0.0,1=>0.0),7=>Dict(2=>0.0,1=>0.0),9=>Dict(2=>0.0,1=>0.0),10=>Dict(2=>0.0,1=>0.0),8=>Dict(2=>0.0,1=>0.0),6=>Dict(2=>0.0,1=>0.0),4=>Dict(2=>0.0,1=>0.0),3=>Dict(2=>0.0,1=>0.0),5=>Dict(2=>0.0,1=>0.0),12=>Dict(2=>0.0,1=>0.0),1=>Dict(2=>0.0,1=>0.0))], Dict(2=>[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0],1=>[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0]), Any[##1#2{##1#2{##1#2{##1#2{typeof(ucb_pol),Array{Any,1}},Dict{Int32,Array{Float64,2}}},Int64},Array{Int64,1}}(##1#2{##1#2{##1#2{typeof(ucb_pol),Array{Any,1}},Dict{Int32,Array{Float64,2}}},Int64}(##1#2{##1#2{typeof(ucb_pol),Array{Any,1}},Dict{Int32,Array{Float64,2}}}(##1#2{typeof(ucb_pol),Array{Any,1}}(ucb_pol, Any[Dict(2=>Dict(2=>0.0,1=>0.0),11=>Dict(2=>0.0,1=>0.0),7=>Dict(2=>0.0,1=>0.0),9=>Dict(2=>0.0,1=>0.0),10=>Dict(2=>0.0,1=>0.0),8=>Dict(2=>0.0,1=>0.0),6=>Dict(2=>0.0,1=>0.0),4=>Dict(2=>0.0,1=>0.0),3=>Dict(2=>0.0,1=>0.0),5=>Dict(2=>0.0,1=>0.0),12=>Dict(2=>0.0,1=>0.0),1=>Dict(2=>0.0,1=>0.0)), Dict(2=>Dict(2=>0.0,1=>0.0),11=>Dict(2=>0.0,1=>0.0),7=>Dict(2=>0.0,1=>0.0),9=>Dict(2=>0.0,1=>0.0),10=>Dict(2=>0.0,1=>0.0),8=>Dict(2=>0.0,1=>0.0),6=>Dict(2=>0.0,1=>0.0),4=>Dict(2=>0.0,1=>0.0),3=>Dict(2=>0.0,1=>0.0),5=>Dict(2=>0.0,1=>0.0),12=>Dict(2=>0.0,1=>0.0),1=>Dict(2=>0.0,1=>0.0))]), Dict(2=>[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0],1=>[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0])), 1), [1, 2]), ##1#2{##1#2{##1#2{##1#2{typeof(ucb_pol),Array{Any,1}},Dict{Int32,Array{Float64,2}}},Int64},Array{Int64,1}}(##1#2{##1#2{##1#2{typeof(ucb_pol),Array{Any,1}},Dict{Int32,Array{Float64,2}}},Int64}(##1#2{##1#2{typeof(ucb_pol),Array{Any,1}},Dict{Int32,Array{Float64,2}}}(##1#2{typeof(ucb_pol),Array{Any,1}}(ucb_pol, Any[Dict(2=>Dict(2=>0.0,1=>0.0),11=>Dict(2=>0.0,1=>0.0),7=>Dict(2=>0.0,1=>0.0),9=>Dict(2=>0.0,1=>0.0),10=>Dict(2=>0.0,1=>0.0),8=>Dict(2=>0.0,1=>0.0),6=>Dict(2=>0.0,1=>0.0),4=>Dict(2=>0.0,1=>0.0),3=>Dict(2=>0.0,1=>0.0),5=>Dict(2=>0.0,1=>0.0),12=>Dict(2=>0.0,1=>0.0),1=>Dict(2=>0.0,1=>0.0)), Dict(2=>Dict(2=>0.0,1=>0.0),11=>Dict(2=>0.0,1=>0.0),7=>Dict(2=>0.0,1=>0.0),9=>Dict(2=>0.0,1=>0.0),10=>Dict(2=>0.0,1=>0.0),8=>Dict(2=>0.0,1=>0.0),6=>Dict(2=>0.0,1=>0.0),4=>Dict(2=>0.0,1=>0.0),3=>Dict(2=>0.0,1=>0.0),5=>Dict(2=>0.0,1=>0.0),12=>Dict(2=>0.0,1=>0.0),1=>Dict(2=>0.0,1=>0.0))]), Dict(2=>[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0],1=>[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0])), 2), [1, 2])])
  at index [4]