In [79]:
from copy import deepcopy
import numpy as np
from scipy.stats import poisson

In [83]:
credit = 10
move_cost = 2
mean_req_first = 3
mean_req_second = 4
mean_ret_first = 3
mean_ret_second = 2
num_cars = 10
max_moves = 5
discount_factor = 0.9
max_req_ret_per_day = 11
free_shuttles = 2
parking_limit = 10
parking_cost = 4

In [84]:
values = np.zeros((num_cars+1, num_cars+1)) # Since we can have 0 to num_cars
policy = np.zeros((num_cars+1, num_cars+1), dtype=int)
actions = np.arange(-max_moves, max_moves+1)

In [85]:
# Precompute poisson PMFs because it takes a long time otherwise
poisson_pmfs = {}
for i in range(max_req_ret_per_day):
    poisson_pmfs[(i, mean_req_first)] = poisson.pmf(i, mean_req_first)
    poisson_pmfs[(i, mean_req_second)] = poisson.pmf(i, mean_req_second)
    poisson_pmfs[(i, mean_ret_first)] = poisson.pmf(i, mean_ret_first)
    poisson_pmfs[(i, mean_ret_second)] = poisson.pmf(i, mean_ret_second)
        
        

In [86]:
def get_return(i, j, action, values):
    return_val = 0.0
    remaining_cars_first = min(i - action, num_cars)
    remaining_cars_second = min(j + action, num_cars)
    for req_1 in range(max_req_ret_per_day):
        for req_2 in range(max_req_ret_per_day):

            num_cars_first = remaining_cars_first
            num_cars_second = remaining_cars_second
            fullfilled_req_first = min(num_cars_first, req_1)
            fullfilled_req_second = min(num_cars_second, req_2)

            reward = (fullfilled_req_first + fullfilled_req_second) * credit
            num_cars_first -= fullfilled_req_first
            num_cars_second -= fullfilled_req_second

            for ret_1 in range(max_req_ret_per_day):
                for ret_2 in range(max_req_ret_per_day):
                    num_cars_first_ = min(num_cars_first + ret_1, num_cars)
                    num_cars_second_ = min(num_cars_second + ret_2, num_cars)
                    parking_deduction = 0
                    if num_cars_first_ > parking_limit:
                        parking_deduction += parking_cost * (num_cars_first_ - parking_limit)
                    if num_cars_second_ > parking_limit:
                        parking_deduction += parking_cost * (num_cars_second_ - parking_limit)
                        
                    return_val += (poisson_pmfs[(req_1, mean_req_first)] * \
                                   poisson_pmfs[(req_2, mean_req_second)] * \
                                   poisson_pmfs[(ret_1, mean_ret_first)] * \
                                   poisson_pmfs[(ret_2, mean_ret_second)]) * \
                    (reward - parking_deduction + discount_factor * values[num_cars_first_, num_cars_second_])
    
    # Shuttle some cars for free if going from loc 1 to loc 2. Action is 
    # positive while going from loc 1 to loc 2
    if action > 0:
        return_val -= move_cost * (np.abs(action) - free_shuttles)
    else:
        return_val -= move_cost * np.abs(action)

    return return_val

def evaluate_policy():
    
    theta = 1e-3
    while True:
        delta = 0
        for i in range(num_cars+1):
            for j in range(num_cars+1):
                old_value_s = values[i, j]
                action = policy[i, j]
                values[i, j] = get_return(i, j, action, values)
                delta = max(delta, np.abs(old_value_s - values[i, j]))
        print ("Delta:", delta)
        if delta < theta:
            break
    return values
        
def improve_policy():
    stable = True
    for i in range(num_cars+1):
        for j in range(num_cars+1):
            old_policy = policy[i, j]
            returns = []
            for action in actions:
                if (action >= 0 and action <= i) or (action >= -j and action <= 0):
                    returns.append(get_return(i, j, action, values))
                else:
                    returns.append(-np.inf)
            policy[i, j] = actions[returns.index(max(returns))]
            if policy[i, j] != old_policy:
                print ("Policy Change for State:", i, j, policy[i, j], old_policy)
                stable = False
    return not stable

In [87]:
count = 0
values = evaluate_policy()
while (improve_policy()):
    count += 1
    print ("Iteration:", count)
    print (policy)
    print ("-----------------------")
    values = evaluate_policy()


Delta: 168.92571804672676
Delta: 103.32033602341775
Delta: 65.12881929454386
Delta: 52.08388752216038
Delta: 42.56742660599363
Delta: 35.62883015911794
Delta: 30.10747413895072
Delta: 25.128838400995505
Delta: 20.804558479089337
Delta: 17.137194067394944
Delta: 14.072404460240477
Delta: 11.533878237028034
Delta: 9.442454912708342
Delta: 7.724907030902955
Delta: 6.317119889868593
Delta: 5.164572248501429
Delta: 4.221652764196165
Delta: 3.450563030577598
Delta: 2.8201533798621767
Delta: 2.304838554585558
Delta: 1.8836456153702557
Delta: 1.5394031477551948
Delta: 1.258062429834979
Delta: 1.0281346207053161
Delta: 0.8402268213425259
Delta: 0.6866609492298039
Delta: 0.5611612888955051
Delta: 0.4585986615585398
Delta: 0.3747811111074384
Delta: 0.30628272643281207
Delta: 0.2503036903947873
Delta: 0.20455587803081698
Delta: 0.16716934929667104
Delta: 0.1366159248314034
Delta: 0.11164672613625726
Delta: 0.09124112965639597
Delta: 0.0745650496894541
Delta: 0.06093684540030608
Delta: 0.0497994587

Delta: 0.011783712376427502
Delta: 0.009594434345615355
Delta: 0.007811896793668893
Delta: 0.006360533720055628
Delta: 0.005178817315936612
Delta: 0.004216650536477573
Delta: 0.003433243609379133
Delta: 0.002795384976366222
Delta: 0.00227603340550786
Delta: 0.0018531715974177132
Delta: 0.0015088728360979076
Delta: 0.0012285409653713941
Delta: 0.0010002916577604992
Delta: 0.0008144485415755298
Policy Change for State: 4 0 0 1
Policy Change for State: 5 3 0 1
Policy Change for State: 6 2 1 2
Policy Change for State: 7 1 2 3
Policy Change for State: 8 0 3 4
Policy Change for State: 8 7 0 1
Policy Change for State: 9 6 1 2
Policy Change for State: 10 5 2 3
Iteration: 3
[[ 0  0  0  0  0  0  0 -1 -1 -2 -2]
 [ 0  0  0  0  0  0  0  0 -1 -1 -2]
 [ 0  0  0  0  0  0  0  0  0 -1 -1]
 [ 0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0]
 [ 1  1  1  0  0  0  0  0  0  0  0]
 [ 2  2  1  1  1  0  0  0  0  0  0]
 [ 3  2  2  2  1  1  0  0  0  0  0]
 [ 3  3  3  2  2  1  1  0  0  0  0]
 [

In [88]:
policy

array([[ 0,  0,  0,  0,  0,  0,  0, -1, -1, -2, -2],
       [ 0,  0,  0,  0,  0,  0,  0,  0, -1, -1, -2],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0, -1, -1],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 1,  1,  1,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 2,  2,  1,  1,  1,  0,  0,  0,  0,  0,  0],
       [ 3,  2,  2,  2,  1,  1,  0,  0,  0,  0,  0],
       [ 3,  3,  3,  2,  2,  1,  1,  0,  0,  0,  0],
       [ 4,  4,  3,  3,  2,  2,  1,  1,  1,  0,  0],
       [ 5,  4,  4,  3,  3,  2,  2,  2,  1,  1,  0]])

In [89]:
values

array([[404.27950443, 414.21581766, 424.02855816, 433.57465167,
        442.72089802, 451.38655396, 459.53997504, 467.3633309 ,
        474.98326557, 482.33000471, 489.35632214],
       [414.11370281, 424.04993838, 433.86225647, 443.40721434,
        452.55136351, 461.21393895, 469.36341912, 476.98335347,
        484.33009136, 491.35640833, 497.86555012],
       [423.52994724, 433.46578424, 443.27608358, 452.8157655 ,
        461.95029786, 470.59889528, 478.73075396, 486.33021781,
        493.35653405, 499.86566339, 506.16777323],
       [432.2626675 , 442.19748626, 452.00268304, 461.52910373,
        470.63957799, 479.25346029, 487.34212511, 494.8924857 ,
        501.86578878, 508.16789796, 513.64634657],
       [440.21234751, 450.14529143, 459.94110015, 469.4431408 ,
        478.50946195, 487.05990978, 495.07040531, 502.53348494,
        509.41512301, 515.62511678, 521.01540489],
       [448.14540438, 457.94119458, 467.44323147, 476.60151919,
        485.6011124 , 494.05578507, 501.9