In [9]:
using Revise
using POMDPSimulators
using Interact
using Plots
plotly()
include("trafficlight.jl")

# Build the traffic-light problem from default TrafficParams() and solve it.
pomdp = TrafficLight(TrafficParams())
# The solver is constructed with max_iters=50; `sol` is kept because later cells
# read `sol.value_hist` and `sol.rng`.
sol = TrafficWorldSolver(max_iters=50)
# @time reports elapsed time and allocations of the solve; the trailing `;`
# suppresses display of the policy itself.
@time policy = solve(sol, pomdp);
# Run one simulation of the solved policy with a HistoryRecorder. This is the last
# expression of the cell, so its return value is what the notebook displays.
hr = HistoryRecorder()
simulate(hr, pomdp, policy)

position_points = LogSpaceAround((w.params).initial_state[1], 0.0, w.goal_position, 45) = [-10.0, -8.99395, -8.0799, -7.24946, -6.49497, -5.80948, -5.18669, -4.62086, -4.10678, -3.63972, -3.21537, -2.82983, -2.47956, -2.16132, -1.87219, -1.6095, -1.37084, -1.154, -0.956996, -0.77801, -0.615394, -0.467651, -0.33342, -0.211467, -0.100666, 0.0, 0.0988928, 0.207565, 0.326985, 0.458214, 0.602421, 0.760889, 0.935028, 1.12639, 1.33667, 1.56775, 1.82169, 2.10073, 2.40737, 2.74433, 3.11462, 3.52153, 3.96868, 4.46004, 5.0]
finished iteration 50
extracting policy...     done.
  9.425603 seconds (213.71 M allocations: 9.546 GiB, 24.08% gc time)


MDPHistory{SArray{Tuple{3},Float64,1,3},Symbol}(SArray{Tuple{3},Float64,1,3}[[-10.0, 10.0, 12.4271], [-9.0, 10.0, 12.3271], [-8.1, 9.0, 12.2271], [-7.3, 8.0, 12.1271], [-6.5, 8.0, 12.0271], [-5.8, 7.0, 11.9271], [-5.2, 6.0, 11.8271], [-4.6, 6.0, 11.7271], [-4.0, 6.0, 11.6271], [-3.5, 5.0, 11.5271]  …  [1.1, 3.0, -0.772883], [1.5, 4.0, -0.872883], [2.0, 5.0, -0.972883], [2.4, 4.0, -1.07288], [2.8, 4.0, -1.17288], [3.2, 4.0, -1.27288], [3.6, 4.0, -1.37288], [4.1, 5.0, -1.47288], [4.7, 6.0, -1.57288], [5.4, 7.0, -1.67288]], Symbol[:accelerate, :brake, :brake, :cruise, :brake, :brake, :cruise, :cruise, :brake, :cruise  …  :accelerate, :accelerate, :accelerate, :brake, :cruise, :cruise, :cruise, :accelerate, :accelerate, :accelerate], [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0  …  -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], Any[nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing  …  nothing, nothing, nothing, nothing, no

In [None]:

# Interactive widget: step through the entries of sol.value_hist by index `i`
# and plot each one over the first two state components.
@manipulate for i in 1:length(sol.value_hist)
    v = sol.value_hist[i]
    # The third state component is held at 0.1; values are clamped to [-13, 0]
    # before plotting.
    plot(TrafficLightVis(pomdp, f=s->clamp(evaluate(v, Vec3(s..., .1)), -13, 0)))
end

In [8]:

# Plot Position by Velocity for a given time.
# Interactive widget: `t` sweeps from -10 to 10 in steps of 0.1 and is used as the
# third state component when evaluating the value function.
@manipulate for t in range(-10,stop=10, step=.1)
    # Only the last recorded entry of sol.value_hist is shown.
    v = last(sol.value_hist)
    # Values are clamped to [-50, 0] before plotting.
    plot(TrafficLightVis(pomdp, f=s->clamp(evaluate(v, Vec3(s..., t)), -50, 0), title="Value at time: $t"))
end

In [None]:
# Plot the policy's chosen action index over Position by Velocity for a given time.
# Interactive widget: `t` sweeps from -10 to 10 in steps of 0.1 and is used as the
# third state component passed to action_ind.
@manipulate for t in range(-10,stop=10, step=.1)
    plot(TrafficLightVis(pomdp, f=s->action_ind(policy, Vec3(s..., t)), title="Policy at time $t"))
end

In [None]:
# Display the solved policy object (uses whatever `show` method its type defines).
policy

In [None]:
# Simulate the policy once more, this time keeping the recorded history so that
# every step can be printed.
hr = HistoryRecorder()
history = simulate(hr, pomdp, policy)
# Print length(history) first.
println(length(history))
# Walk the history, unpacking each step according to the "(s, a, r, sp)" spec
# string, and print the reward, the state transition and the action taken.
for (s, a, r, sp) in eachstep(history, "(s, a, r, sp)")    
    println("reward $r: $s -> $sp, after $a")
end

In [None]:
# Probe a single transition by hand: start from a fixed state, apply :accelerate,
# and show the successor, whether it is terminal, the reward, and the final value
# function evaluated at both states.
init_s =Vec3(-.6, 10, .1)
# NOTE(review): `Random` is not imported in the visible cells; presumably it is
# brought into scope by trafficlight.jl - confirm.
@show new_s = generate_s(pomdp, init_s, :accelerate, Random.GLOBAL_RNG)
@show isterminal(pomdp, new_s)
@show reward(pomdp, init_s, :accelerate, new_s)
# Use the last entry of the solver's value history.
v = last(sol.value_hist)
@show evaluate(v, new_s)
@show evaluate(v, init_s)


In [None]:
# 30 non-negative offsets, log-spaced so they cluster near zero: exponents run
# evenly from 0 to log10(11), so after subtracting 1 the values span 0 to ~10.
times = exp10.(range(0, stop=log10(11), length=30)) .- 1
# Mirror the offsets onto the negative side (dropping the duplicate zero at
# index 1) and prepend them, giving a grid symmetric about zero.
vcat(-reverse(times[2:end]), times)


In [None]:
# Debug cell: re-runs a hand-written grid value iteration for 10 sweeps so that a
# single grid point (index 1790) and one probe state can be traced sweep by sweep.
# NOTE: this cell overwrites sol.value_hist, discarding the history produced by
# the earlier solve.
w = pomdp    
# Uniform 30-point axes; the grid axes are, in order, position_points,
# velocity_points and time_points.
@show position_points = range(w.params.initial_state[1], stop=w.goal_position, length=30)
    velocity_points = range(w.params.v_limits[1], stop=w.params.v_limits[2], length=30)
    time_points = range(-w.params.period, stop=w.params.period, length=30)
    grid = RectangleGrid(position_points, velocity_points, time_points)
    # Reset the recorded history to an empty (untyped) array.
    sol.value_hist = []
    # Initial value estimate: zero at every grid point.
    data = zeros(length(grid))
    val = GIValue(grid, data)

    # Show the first value returned by interpolants for one sample state; the
    # second returned value is discarded.
    @show interp, _ = GridInterpolations.interpolants(val.grid, Vec3(.1, 10.0, .1))

# Ten sweeps over the grid. Each sweep reads only the previous estimate `val` and
# writes into a fresh `newdata` array.
for _ in 1:10
        # `similar` leaves newdata uninitialised; every index is assigned below.
        newdata = similar(data)
        for i in 1:length(grid)
            s = Vec3(ind2x(grid, i))
            # Hard-coded grid index whose backup is printed in detail.
            is_interesting = i == 1790
            if is_interesting
                println("INTERESTING: ", s, i)
            end
            if isterminal(w, s)
                # Terminal states are assigned value 0.
                newdata[i] = 0.0
            else
                # Maximise Q over the actions available in s.
                best_Q = -Inf
                for a in actions(w, s)

                    # One generate_sr call per action is used as the successor
                    # and reward for the backup.
                    sp, r = generate_sr(w, s, a, sol.rng)
                    Q = r + discount(w)*evaluate(val, sp)
                    best_Q = max(best_Q, Q)
                    if is_interesting
                        # Prints the running maximum best_Q, not this action's Q.
                        println(" VALUES: ", a, s, sp, best_Q)
                    end
                end
                newdata[i] = best_Q
            end
        end
        # The estimate used during this sweep is recorded before being replaced,
        # so the history begins with the all-zero estimate and the result of the
        # final sweep is never pushed.
        push!(sol.value_hist, val)
        val = GIValue(grid, newdata)
    
    # After each sweep, probe one fixed transition against the updated estimate.
    init_s =Vec3(-.173, 10, .1)
@show new_s = generate_s(pomdp, init_s, :accelerate, Random.GLOBAL_RNG)
@show isterminal(pomdp, new_s)
@show reward(pomdp, init_s, :accelerate, new_s)
@show evaluate(val, new_s)
@show evaluate(val, init_s)
end



In [None]:
"""
    LogSpaceAround(min, pivot, max, num_vals, base=10)

Return a `Vector{T}` of `num_vals` non-decreasing sample points that start at `min`, end
at `max`, contain `pivot` exactly, and are spaced logarithmically so that they are
densest near `pivot`.

The samples are split between the two sides of `pivot` in proportion to
`log(base, pivot - min + 1)` and `log(base, max - pivot + 1)`. The split is then adjusted
when necessary so that each non-degenerate side receives at least its endpoint; without
that adjustment a very short (or empty) side would put the pivot at an invalid index.

# Arguments
- `min`, `pivot`, `max`: must satisfy `min <= pivot <= max`. The result is stored in an
  `Array{T}`, so this is intended for floating-point `T`.
- `num_vals`: total number of points to return (must be positive).
- `base`: logarithm base controlling how strongly the points cluster around `pivot`.

# Throws
- `AssertionError` if `min <= pivot <= max` or `num_vals > 0` does not hold.
- `ArgumentError` if `num_vals` is too small to hold `pivot` plus every distinct endpoint.
"""
function LogSpaceAround(min::T, pivot::T, max::T, num_vals::Int, base::T = T(10)) where {T<:Real}
    @assert(pivot >= min)
    @assert(pivot <= max)
    @assert(num_vals > 0)
    low_range = log(base, pivot - min + 1)
    hi_range = log(base, max - pivot + 1)

    # `num_low` is the index at which `pivot` is stored. A non-degenerate low side needs
    # one extra slot for `min` (so num_low >= 2); a non-degenerate high side needs at
    # least one slot after the pivot for `max` (so num_low <= num_vals - 1).
    min_low = iszero(low_range) ? 1 : 2
    max_low = iszero(hi_range) ? num_vals : num_vals - 1
    min_low <= max_low || throw(ArgumentError(
        "num_vals = $num_vals is too small to hold the pivot and both endpoints"))

    total_range = low_range + hi_range
    # When min == pivot == max the proportional split is 0/0; put the pivot first.
    num_low = if iszero(total_range)
        1
    else
        clamp(round(Int64, low_range / total_range * num_vals), min_low, max_low)
    end
    num_hi = num_vals - num_low

    out = Array{T,1}(undef, num_vals)
    low_vals = view(out, 1:num_low - 1)
    hi_vals = view(out, (num_low + 1):num_vals)
    # Fill each side with evenly spaced exponents (excluding the pivot's exponent 0) ...
    low_vals .= range(low_range, stop=0, length=num_low)[1:end-1]
    out[num_low] = pivot
    hi_vals .= range(0, stop=hi_range, length=num_hi + 1)[2:end]
    # ... then map the exponents back to offsets from the pivot.
    @. low_vals = (base ^ low_vals - 1)*-1 + pivot
    @. hi_vals = (base ^ hi_vals - 1) + pivot
    return out
end

# Smoke-test LogSpaceAround on two intervals: one with the pivot at zero and one
# with every point positive. Each result is printed on its own line.
out = LogSpaceAround(-10.0, 0.0, 5.0, 10)
println(out)
out = LogSpaceAround(10.0, 15.0, 30.0, 10)
println(out)

