In [1]:
using POMDPs
using Random # for AbstractRNG
using POMDPModelTools # for Deterministic

In [2]:
struct BabyPOMDP <: POMDP{Bool, Bool, Bool}
    r_feed::Float64
    r_hungry::Float64
    p_become_hungry::Float64
    p_cry_when_hungry::Float64
    p_cry_when_not_hungry::Float64
    discount::Float64   
end

BabyPOMDP() = BabyPOMDP(-5., -10., 0.1, 0.8, 0.1, 0.9);

In [3]:
function POMDPs.gen(m::BabyPOMDP, s, a, rng)
    # transition model
    if a # feed
        sp = false
    elseif s # hungry
        sp = true
    else # not hungry
        sp = rand(rng) < m.p_become_hungry
    end
    
    # observation model
    if sp # hungry
        o = rand(rng) < m.p_cry_when_hungry
    else # not hungry
        o = rand(rng) < m.p_cry_when_not_hungry
    end
    
    # reward model
    r = s*m.r_hungry + a*m.r_feed
    
    # create and return a NamedTuple
    return (sp=sp, o=o, r=r)
end

In [4]:
POMDPs.initialstate_distribution(m::BabyPOMDP) = Deterministic(false)

In [5]:
using POMDPSimulators
using POMDPPolicies

m = BabyPOMDP()

# policy that maps every input to a feed (true) action
policy = FunctionPolicy(o->true)

for (s, a, r) in stepthrough(m, policy, "s,a,r", max_steps=10)
    @show s
    @show a
    @show r
    println()
end

MethodError: MethodError: no method matching initialstate_distribution(::BabyPOMDP)
Closest candidates are:
  initialstate_distribution(!Matched::FullyObservablePOMDP) at /Users/coreyshono/.julia/packages/POMDPModelTools/p5dFB/src/fully_observable_pomdp.jl:44
  initialstate_distribution(!Matched::UnderlyingMDP) at /Users/coreyshono/.julia/packages/POMDPModelTools/p5dFB/src/underlying_mdp.jl:22

In [10]:
using QuickPOMDPs
using POMDPPolicies

In [8]:
function update_d(tod)
    # Fitted to a sine wave with noise 
    # TODO: need to adjust equation params when converting to realistic data 
    # TODO: adjust clamp when realistic data (maybe functionalize clamp) 
    # TODO: fit model dependent on TOD 
    # Maybe fit a BN to this data? 
    noise_odt = 1
    
    d = 2.5*sin((tod+1.75)*pi/2.5)+5
    d += rand(rng, Normal(0,noise_odt), 1) 
    d = max(min(round(odt),10), 0) # Clamp to 0-10 int 
    
    return d
end

function update_tou(tod) 
    TOU_SCHEDULE = [2,2,3,4,2]; 
    # TODO: Change tou_schedule to take as model input rather than hard code 
    tou = TOU_SCHEDULE[tod] 
    return tou 
end

update_tou (generic function with 1 method)

In [None]:
mountaincar = QuickMDP(
    function (s, a, rng)        
        x, v = s
        vp = clamp(v + a*0.001 + cos(3*x)*-0.0025, -0.07, 0.07)
        xp = x + vp
        if xp > 0.5
            r = 100.0
        else
            r = -1.0
        end
        return (sp=(xp, vp), r=r)
    end,
    actions = [-1., 0., 1.],
    initialstate = (-0.5, 0.0),
    discount = 0.95,
    isterminal = s -> s[1] > 0.5
)

smarthome = QuickMDP(
    function (s, a, rng) 
        d, soc, tou, tod, t = s
        dp = update_d(tod) 
        soc_p = soc + a
        tod_p = rem(s.tod + 1, 5)
        tou_p = update_tou(tod_p) 
        tp = t + 1 
        
        r = -tou_p * (d + a)
        return (sp=(dp, soc_p, tou_p, tod_p, tp), r=r) 
        actions = [-1, 0, 1]
        initialstate = (2, 3, 2, 1, 1) 
        states = [collect(1:10), collect(1:5), collect(1:5), collect(1:5), collect(1:20)] 
        isterminal = s -> s[5] = 11
    end
        
)
      






In [16]:
smarthome = QuickMDP(
    function (s, a, rng) 
        d, soc, tou, tod, t = s
        dp = update_d(tod) 
        soc_p = soc + a
        tod_p = rem(s.tod + 1, 5)
        tou_p = update_tou(tod_p) 
        tp = t + 1 
        
        r = -tou_p * (d + a)
        return (sp=(dp, soc_p, tou_p, tod_p, tp), r=r) 
        end, 
    actions = [-1, 0, 1], 
    initialstate = (2, 3, 2, 1, 1),
    states = (collect(1:10), collect(1:5), collect(1:5), collect(1:5), collect(1:20)), 
    isterminal = s -> s[5] = 11
)


QuickMDP{UUID("55463fca-c857-4e72-bbd8-7e6ccf06feed"),Array{Int64,1},Int64,NamedTuple{(:stateindex, :isterminal, :actionindex, :initialstate, :gen, :actions, :states, :discount),Tuple{Dict{Array{Int64,1},Int64},getfield(Main, Symbol("##18#20")),Dict{Int64,Int64},NTuple{5,Int64},getfield(Main, Symbol("##17#19")),Array{Int64,1},NTuple{5,Array{Int64,1}},Float64}}}((stateindex = Dict([1, 2, 3, 4, 5, 6, 7, 8, 9, 10] => 1,[1, 2, 3, 4, 5] => 4,[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] => 5), isterminal = getfield(Main, Symbol("##18#20"))(), actionindex = Dict(0 => 2,-1 => 1,1 => 3), initialstate = (2, 3, 2, 1, 1), gen = getfield(Main, Symbol("##17#19"))(), actions = [-1, 0, 1], states = ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]), discount = 1.0))

In [21]:
#u = evaluate(m, FunctionPolicy(x->:left))
function my_policy(x)
    if x > 2
        return -1
    else
        return 1
    end
end

u = evaluate(smarthome, FunctionPolicy(x->1)) 

MethodError: MethodError: no method matching !(::Int64)
Closest candidates are:
  !(!Matched::Missing) at missing.jl:79
  !(!Matched::Bool) at bool.jl:35
  !(!Matched::Function) at operators.jl:894