In [1]:
using POMDPs
using Distributions: Normal, Uniform, DiscreteUniform, Multinomial, Hypergeometric, Binomial
using Random
using LinearAlgebra
using Statistics
using Plots
import POMDPs: initialstate, gen, actions, discount, isterminal
Random.seed!(1);

In [2]:
mutable struct m <: POMDPs.POMDP{Array{Int64,1},Int,Int}  #initialize POMDP 
    discount_factor::Float64
end
m() = m(.95)
discount(p::m) = p.discount_factor

using POMDPModelTools 
actions(::m) = [1, 2]



actions (generic function with 11 methods)

In [3]:
function gen(m::m, s::Array{Int64,1}, a::Int, rng::AbstractRNG) #function to generate (s,a) \to (s',r)

    if a == 1
       sp = waitT(s)
    end
    if a == 2
       sp = lockT(s)
    end

  
    if a == 1
    r = 10*s[1]+ -5 * s[2] + 10*s[3] - 50 * s[4] #reward for wait
    end
  
    if a == 2
    r =  -100 #reward forlockdown
    end
  
    return (sp=sp, r=r)
end;

In [4]:
np = Uniform(0.3,0.6)
function waitT(s::Array{Int64,1})  #state change if action "wait"
c=3

p= rand(np)
ri = 0.05
  
A = zeros(Int8, s[1])
  
new_cases = 0
if s[1] != 0 && s[2] !=0
  for i = 1:s[2] #each infected person contacts c=3 people and infects them with probability p
    for j = 1:c
      ind = rand(1:s[1])
      b = rand(Uniform(0,1)) < p
      A[ind]= b ? 1 : A[ind]
    end
  end
 end
if s[1] !=0
  for i=1:s[1] #assume low chance of outside infection 
    a = rand(Uniform(0,1)) > ri 
    new_cases= !a ? new_cases+1 : new_cases+0
  end
end
 

s2p = [0.2,0.5,0.25,0.05] #multinomial for infected people
s2 = Multinomial(s[2],s2p)
s2r = rand(s2)

s3p = [0.0,0.02,0.98,0.0] #multinomial for immune people
s3 = Multinomial(s[3],s3p)
s3r = rand(s3)
  

n = min(s[1],new_cases + sum(A))
s=  [s[1]-n + s2r[1],s2r[2] + n+s3r[2],s3r[3] + s2r[3] ,s[4] + s2r[4]+s3r[4]]
return s
  
end

waitT (generic function with 1 method)

In [5]:
function lockT(s::Array{Int64,1}) #state change if action "lockdown"

ri = 0.02
  new_cases = 0
if s[1] !=0
  for i=1:s[1]
    a = rand(Uniform(0,1)) > ri
    new_cases= !a ? new_cases+1 : new_cases+0
  end
end
 

s2p = [0.2,0.5,0.25,0.05]
s2 = Multinomial(s[2],s2p)
s2r = rand(s2)

s3p = [0.0,0.02,0.98,0.0]
s3 = Multinomial(s[3],s3p)
s3r = rand(s3)
  

n = min(s[1],new_cases)
s=  [s[1]-n + s2r[1],s2r[2] + n+s3r[2],s3r[3] + s2r[3] ,s[4] + s2r[4]+s3r[4]]
return s
  
end

lockT (generic function with 1 method)

In [6]:
using Distributions

function hyp(s::Array{Int64,1},sen::Float64,spe::Float64,smpl::Int,x::Int) #calculate PDF for observation distribution
  #println(s[2])
  #println(s[3])
  #println(smpl)
  if s[2] < 0
    s[2]=0
  end
  h=Hypergeometric(s[2],s[1]+s[3],smpl)
  
    p=0
    for i=0:min(smpl,d.s[2]) #get anywhere from 0 to min(m,inf) infected people
        p = p + pdf(h,i) * bin(x,i,smpl-i,sen,spe) #probability if i infected people, times probability of x positives given i infected people
    end
    return p
end
function bin(t::Int, I::Int, NI::Int,sen::Float64,spe::Float64)
  bI = Binomial(I,sen)
  bNI = Binomial(NI,1-spe)
  tp =0
  for i=0:t
    #println(i)
    p = pdf(bI,i) * pdf(bNI,t-i)
    #print(p)
    tp = tp +p
  end
  return tp
end

function Obs(s::Array{Int64,1},sen::Float64,spe::Float64,smpl::Int) #generate observation in state s, sample size smpl, sensitivity sen and specificity spe
  inf = 0
 
  tinf = s[2]
  tnotinf = s[1]+s[3]
  
  posit = 0
  
  for i=1:smpl
    a = rand(Uniform(0,tinf+tnotinf))<tinf
    inf = a ? inf+1 : inf+0
    tinf = a ? tinf-1 : tinf+0
    tnotinf = !a ? tnotinf-1 : tnotinf+0
  end
  for i=1:inf
    posit = rand(Uniform(0,1)) < sen ? posit+1 : posit+0
  end
  for i=1:smpl-inf
    posit = rand(Uniform(0,1)) > spe ? posit+1 : posit+0
  end
  
  return posit
  
end

Obs (generic function with 1 method)

In [7]:
struct mDist
    s::Array{Int,1}
end


function Base.rand(rng::AbstractRNG, d::Random.SamplerTrivial{mDist}) #generate observations
  return Obs(d[].s,0.98,0.98,d[].s[1]+d[].s[2]+d[].s[3])
end

function POMDPs.pdf(d::mDist, x::Int) #PDF of observation distribution
  return hyp(d.s,0.98,0.98,d.s[1]+d.s[2]+d.s[3],x)
end

function POMDPs.observation(p::m,a::Int,sp::Array{Int,1})
     return mDist(sp) 
end

In [8]:
d = mDist([10,20,30,40])


mDist([10, 20, 30, 40])

In [9]:
struct initialD #initial state
end


function Base.rand(rng::AbstractRNG, d::Random.SamplerTrivial{initialD}) #deterministic sample since initial state is deterministic (doesn't have to be)
  return [90,10,0,0]
end

start = initialD()

initialD()

In [10]:
using Distributions, LinearAlgebra, Statistics

POMDPs.initialstate(m::m) = start

In [11]:
using POMDPPolicies: FunctionPolicy #heuristic policy toevaluate leaf nodes

function my_heuristic(b::Array{Int64,1})

return 1
 
end

heuristic_policy = FunctionPolicy(my_heuristic)

FunctionPolicy{typeof(my_heuristic)}(my_heuristic)

In [12]:
#initialize solver
using BasicPOMCP 
using POMDPSimulators
solver = POMCPSolver(tree_queries=100000, c=7000, max_depth=20, estimate_value = FORollout(heuristic_policy)) 
pomdp = m()
planner = solve(solver, pomdp);

In [13]:
Random.seed!(2);
using ParticleFilters
i=0
rw=0
filter = SIRParticleFilter(pomdp, 2000)

for (s,a,r,sp,o,b) in stepthrough(pomdp, planner, filter, "s,a,r,sp,o,b")
    @show (s,a,r,sp,o)
  println(mean(b))
    rw+=r
i+=1
  if i == 30
    break
  end
end
println(rw)

(s, a, r, sp, o) = ([90, 10, 0, 0], 1, 850, [74, 22, 2, 2], 22)
[90.0, 10.0, 0.0, 0.0]
(s, a, r, sp, o) = ([74, 22, 2, 2], 2, -100, [76, 14, 8, 2], 18)
[77.257, 19.465, 2.734, 0.544]
(s, a, r, sp, o) = ([76, 14, 8, 2], 2, -100, [77, 7, 13, 3], 7)
[77.226, 15.324, 6.2785, 1.1715]
(s, a, r, sp, o) = ([77, 7, 13, 3], 1, 715, [67, 16, 14, 3], 20)
[80.5975, 6.221, 11.004, 2.1775]
(s, a, r, sp, o) = ([67, 16, 14, 3], 2, -100, [67, 10, 20, 3], 11)
[67.244, 18.227, 12.084, 2.445]
(s, a, r, sp, o) = ([67, 10, 20, 3], 1, 670, [47, 24, 24, 5], 28)
[70.0705, 9.7985, 16.7205, 3.4105]
(s, a, r, sp, o) = ([47, 24, 24, 5], 2, -100, [53, 13, 27, 7], 14)
[57.5615, 19.79, 19.0485, 3.6]
(s, a, r, sp, o) = ([53, 13, 27, 7], 2, -100, [57, 3, 32, 8], 4)
[59.931, 12.432, 23.1805, 4.4565]
(s, a, r, sp, o) = ([57, 3, 32, 8], 1, 475, [55, 6, 31, 8], 9)
[63.0635, 3.6845, 27.4915, 5.7605]
(s, a, r, sp, o) = ([55, 6, 31, 8], 1, 430, [48, 9, 35, 8], 11)
[57.8195, 7.94, 28.223, 6.0175]
(s, a, r, sp, o) = ([48, 9, 35,

In [14]:
using POMDPPolicies: FunctionPolicy

function cutoff(b)

  if i % 2 == 0
    return 2
  else
    return 1
  end
end

cutoff_p = FunctionPolicy(cutoff)

function simple(b)

  if mean(b)[2] > 8
    return 2
  else
    return 1
  end
end

simple_p = FunctionPolicy(simple)



FunctionPolicy{typeof(simple)}(simple)

In [15]:
using ParticleFilters
filter = SIRParticleFilter(pomdp, 2000)
i = 0

rwe = 0

for (s,a,r,sp,o,b) in stepthrough(pomdp, simple_p, filter, "s,a,r,sp,o,b")
    @show (s,a,r,sp,o)
  println(mean(b))
    rwe+=r
    i+=1
  if i == 30
    break
  end
end
println(rwe)

(s, a, r, sp, o) = ([90, 10, 0, 0], 2, -100, [91, 4, 4, 1], 6)
[90.0, 10.0, 0.0, 0.0]
(s, a, r, sp, o) = ([91, 4, 4, 1], 1, 880, [80, 15, 4, 1], 18)
[91.3235, 5.055, 3.017, 0.6045]
(s, a, r, sp, o) = ([80, 15, 4, 1], 2, -100, [79, 9, 11, 1], 14)
[78.7845, 16.2975, 4.095, 0.823]
(s, a, r, sp, o) = ([79, 9, 11, 1], 2, -100, [76, 9, 14, 1], 14)
[79.2375, 11.7905, 7.4815, 1.4905]
(s, a, r, sp, o) = ([76, 9, 14, 1], 2, -100, [77, 4, 18, 1], 5)
[78.088, 10.9795, 9.1, 1.8325]
(s, a, r, sp, o) = ([77, 4, 18, 1], 1, 880, [64, 16, 19, 1], 16)
[80.245, 4.264, 12.8095, 2.6815]
(s, a, r, sp, o) = ([64, 16, 19, 1], 2, -100, [67, 10, 21, 2], 10)
[69.6155, 14.2015, 13.328, 2.855]
(s, a, r, sp, o) = ([67, 10, 21, 2], 2, -100, [68, 7, 23, 2], 9)
[71.2025, 8.5355, 16.665, 3.597]
(s, a, r, sp, o) = ([68, 7, 23, 2], 1, 775, [60, 18, 20, 2], 21)
[70.968, 6.974, 18.118, 3.94]
(s, a, r, sp, o) = ([60, 18, 20, 2], 2, -100, [60, 11, 27, 2], 13)
[57.713, 18.885, 19.393, 4.009]
(s, a, r, sp, o) = ([60, 11, 27, 2]

In [16]:
using ParticleFilters
filter = SIRParticleFilter(pomdp, 2000)
i = 0

rwe = 0

for (s,a,r,sp,o,b) in stepthrough(pomdp, cutoff_p, filter, "s,a,r,sp,o,b")
    @show (s,a,r,sp,o)
  println(mean(b))
    rwe+=r
    i+=1
  if i == 30
    break
  end
end
println(rwe)

(s, a, r, sp, o) = ([90, 10, 0, 0], 2, -100, [92, 5, 2, 1], 7)
[90.0, 10.0, 0.0, 0.0]
(s, a, r, sp, o) = ([92, 5, 2, 1], 1, 865, [83, 13, 3, 1], 14)
[90.86, 5.751, 2.8265, 0.5625]
(s, a, r, sp, o) = ([83, 13, 3, 1], 2, -100, [84, 7, 8, 1], 11)
[81.9535, 12.8785, 4.286, 0.882]
(s, a, r, sp, o) = ([84, 7, 8, 1], 1, 835, [78, 13, 8, 1], 16)
[82.245, 9.1505, 7.1385, 1.466]
(s, a, r, sp, o) = ([78, 13, 8, 1], 2, -100, [77, 11, 10, 2], 11)
[73.336, 15.123, 9.5355, 2.0055]
(s, a, r, sp, o) = ([77, 11, 10, 2], 1, 715, [63, 25, 10, 2], 23)
[74.8445, 9.446, 12.987, 2.7225]
(s, a, r, sp, o) = ([63, 25, 10, 2], 2, -100, [70, 12, 15, 3], 13)
[61.9615, 19.5465, 15.415, 3.077]
(s, a, r, sp, o) = ([70, 12, 15, 3], 1, 640, [54, 26, 16, 4], 26)
[64.58, 11.494, 19.9025, 4.0235]
(s, a, r, sp, o) = ([54, 26, 16, 4], 2, -100, [61, 12, 21, 6], 13)
[52.338, 19.743, 23.2255, 4.6935]
(s, a, r, sp, o) = ([61, 12, 21, 6], 1, 460, [47, 24, 21, 8], 25)
[55.163, 11.6005, 27.6135, 5.623]
(s, a, r, sp, o) = ([47, 24, 

Evaluation

In [17]:
Random.seed!(1);

In [18]:
#evaluate simple policy

using POMDPPolicies: FunctionPolicy

function simple(b) #try every k days
  if i % k == 0
    return 2
  else
    return 1
  end
end

simple_p = FunctionPolicy(simple)

using ParticleFilters
filter = SIRParticleFilter(pomdp, 2000)
k = 0
for c = 0:5
  
  k += 1
  print("Every: ")
  println(k)
  tr = 0
  for iter = 1:10
    i = 0
    rwe = 0

for (s,a,r,sp,o,b) in stepthrough(pomdp, simple_p, filter, "s,a,r,sp,o,b")
    #@show (s,a,r,sp,o)
  #println(mean(b))
    rwe+=r
    i+=1
  if i == 30
        #print(sp)
    break
  end
     
end
     
   tr += rwe
  end
  println(tr/10)
end

Every: 1
-3000.0
Every: 2
694.5
Every: 3
-263.0
Every: 4
-1964.5
Every: 5
-4092.0
Every: 6
-2272.0


In [19]:
#cutoff heuristic

using POMDPPolicies: FunctionPolicy

function cutoff(b) #lockdown if greater than k in belief state
  if mean(b)[2] > k
    return 2
  else
    return 1
  end
end

cutoff_p = FunctionPolicy(cutoff)

using ParticleFilters
filter = SIRParticleFilter(pomdp, 2000)
k = -1
for c = 1:15
  
  k += 1
  print("Cutoff: ")
  println(k)
  tr = 0
  for iter = 1:30
    i = 0
    rwe = 0

for (s,a,r,sp,o,b) in stepthrough(pomdp, cutoff_p, filter, "s,a,r,sp,o,b")
    #@show (s,a,r,sp,o)
  #println(mean(b))
    rwe+=r
    i+=1
  if i == 30
        #print(sp)
    break
  end
end
   tr += rwe
    #println(rwe)
  end
  println(tr/30)
end

Cutoff: 0
-3000.0
Cutoff: 1
-2490.5
Cutoff: 2
-1294.0
Cutoff: 3
546.8333333333334
Cutoff: 4
1666.0
Cutoff: 5
2587.1666666666665
Cutoff: 6
2348.5
Cutoff: 7
3065.6666666666665
Cutoff: 8
2834.8333333333335
Cutoff: 9
3161.8333333333335
Cutoff: 10
2409.1666666666665
Cutoff: 11
1720.6666666666667
Cutoff: 12
1756.0
Cutoff: 13
1250.3333333333333
Cutoff: 14
11.333333333333334


In [20]:
Random.seed!(2);

In [21]:
#POMCP evaluation
using ParticleFilters
filter = SIRParticleFilter(pomdp, 2000)
tr = 0
for iter = 1:10
  i = 0
  rwe = 0

for (s,a,r,sp,o,b) in stepthrough(pomdp, planner, filter, "s,a,r,sp,o,b")
    #@show (s,a,r,sp,o)
  #println(mean(b))
    rwe+=r
    i+=1
  if i == 30
    break
  end
end
  println(rwe)
   tr += rwe
end
println("Average: ")
  println(tr/10)

3510
1645
2175
1355
1495
1400
2445
3530
3010
4525
Average: 
2509.0
