https://medium.com/@kireetr/working-the-book-policy-iteration-b9ae8441979f
    
https://raw.githubusercontent.com/kireet/ml-playground/master/sutton_barto/four/jacks.py

In [1]:
import logging
import math
import collections

# Jack's car rental: each night Jack chooses how to divide his rental cars between 2 locations in preparation for
# the next day's rentals. See book (pg 93) for more details. The unmodified version of the problem is run by default.


# memo table for poisson_prob, keyed by the (n, lam) argument pair
poisson_cache = {}


def poisson_prob(n, lam):
    """
    poisson probability mass exp(-lam) * lam**n / n!, memoized in poisson_cache
    :param n: the event count
    :param lam: the poisson lambda
    :return: the probability of observing exactly n events
    """
    key = (n, lam)
    try:
        return poisson_cache[key]
    except KeyError:
        pass  # first request for this pair, compute and remember it below

    value = math.exp(-lam) * (lam ** n) / float(math.factorial(n))
    poisson_cache[key] = value
    return value


def loc_rental_probability(start, end, rent_lam, return_lam):
    """
    enumerate the ways a single location can go from `start` cars in the morning to `end` cars at night.
    for every feasible number of rentals (0..start) the matching number of returns is implied by `end`.
    :param start: the day start count
    :param end: the day end count
    :param rent_lam: the rental poisson lambda
    :param return_lam: the return poisson lambda
    :return: a list of (probability, reward) tuples, one per feasible rental count; reward is 10 per car rented.
    """
    def at_least(count, lam):
        # P(X >= count): start from 1 and remove the mass of 0..count-1 one term at a time
        tail = 1
        for k in range(count):
            tail -= poisson_prob(k, lam)
        return tail

    lot_is_full = end == 20
    outcomes = []
    for rented in range(start + 1):
        left_on_lot = start - rented
        returned = end - left_on_lot
        if returned < 0:
            continue  # not feasible: would need a negative number of returns

        if left_on_lot == 0:
            # the lot was emptied, so any number of customers >= rented leads to this same outcome
            rent_prob = at_least(rented, rent_lam)
        else:
            rent_prob = poisson_prob(rented, rent_lam)

        if lot_is_full:
            # lot size is capped at 20, so any number of returns >= returned leads to this same outcome
            return_prob = at_least(returned, return_lam)
        else:
            return_prob = poisson_prob(returned, return_lam)

        outcomes.append((rent_prob * return_prob, 10 * rented))

    return outcomes


def rental_probabilities(start_state):
    """
    calculate the day activity (rental/return activity) probabilities
    :param start_state: the start of day state, a (location 1 count, location 2 count) pair
    :return: a dict of next_state -> (transition probability, probability-weighted reward) covering every
             end-of-day state (0..20, 0..20). The second element is the sum of p * reward over the outcomes
             that lead to next_state, i.e. it is already weighted by probability and is not divided by the
             transition probability.
    """
    # location 2's outcome list depends only on its own end count, so build the 21 lists once
    # instead of rebuilding each of them for every location 1 end count.
    l2_outcomes = [loc_rental_probability(start_state[1], l2, 4, 2) for l2 in range(0, 21)]

    rv = {}
    for l1 in range(0, 21):
        # likewise location 1's list does not depend on l2, so compute it outside the inner loop
        l1_probs = loc_rental_probability(start_state[0], l1, 3, 3)
        for l2, l2_probs in enumerate(l2_outcomes):
            prob = 0
            reward = 0
            # the two locations are treated as independent, so joint probabilities multiply
            for l1p, l1r in l1_probs:
                for l2p, l2r in l2_probs:
                    p = l1p * l2p
                    prob += p
                    reward += p * (l1r + l2r)
            rv[(l1, l2)] = (prob, reward)

    return rv


def calculate_value(start_state, action, v, gamma, modified):
    """
    one-step backup: the return of taking `action` overnight from `start_state` and then following `v`.
    :param start_state: num cars at each location at the end of the day
    :param action: cars moved overnight from location 1 to location 2 (negative moves cars to location 1)
    :param v: the value function, indexed by state
    :param gamma: the discount applied to next-state values
    :param modified: True to apply the modified rules (one free 1 -> 2 shuttle, $4 storage fee above 10 cars)
    :return: move cost plus the probability-weighted rental rewards and discounted next-state values
    """
    # overnight move: each location is capped at 20 cars, moving costs $2 per car
    at_first = min(20, start_state[0] - action)
    at_second = min(20, start_state[1] + action)
    day_start_state = (at_first, at_second)
    total = -2 * abs(action)

    if modified:
        if action > 0:
            total += 2  # can shuttle one car from 1 -> 2 for free!
        # $4 storage fee for each location holding > 10 cars
        total -= 4 * sum(1 for cars in day_start_state if cars > 10)

    # fold in the day's activity: weighted reward plus discounted value of each possible end-of-day state
    outcomes = rental_probabilities(day_start_state)
    for next_state, (prob, weighted_reward) in outcomes.items():
        total += weighted_reward + prob * gamma * v[next_state]

    return total

# the state is a tuple of cars at each location, the action is the number of cars moved from location 1 to location 2.
# note that no more than 20 cars can be at either location. at most 5 cars can be moved. Policies are deterministic.
def eval_policy(v, p, modified, gamma=0.9, eps=0.1):
    """
    iterative policy evaluation; updates `v` in place until the largest per-state change is <= eps.
    every sweep computes new values from a snapshot taken at the start of that sweep.
    :param v: the value function (state -> value), modified in place
    :param p: the policy (state -> action)
    :param modified: passed through to calculate_value
    :param gamma: the discount factor
    :param eps: the convergence threshold on the largest change in a sweep
    :raises ValueError: if 10000 sweeps complete without convergence
    """
    for sweep in range(1, 10001):
        snapshot = collections.OrderedDict(v)
        delta = 0
        for state, previous in snapshot.items():
            updated = calculate_value(state, p[state], snapshot, gamma, modified)
            v[state] = updated
            delta = max(delta, abs(updated - previous))

        if delta <= eps:
            logging.info('converged after %d iterations', sweep)
            return

        logging.info('%d: %f', sweep, delta)

    raise ValueError('value function non-convergence')


def action_space(state):
    """
    the actions available in `state`. an action is the number of cars moved from location 1 to location 2;
    negative means moving cars to location 1.
    :param state: the (location 1 count, location 2 count) pair
    :return: a range of actions; at most 5 cars can move, and no more than the source location holds
    """
    at_first, at_second = state[0], state[1]
    lowest = -min(at_second, 5)   # most cars that can go 2 -> 1
    highest = min(at_first, 5)    # most cars that can go 1 -> 2
    return range(lowest, highest + 1)


def improve_policy(v, p, modified, gamma=0.9):
    """
    greedy policy improvement; updates `p` in place.
    for each state, every action in its action space is backed up through calculate_value and the policy
    switches to an action only when its return is strictly greater than the best seen so far (starting
    from the state's current value v[state]).
    :param v: the value function (state -> value)
    :param p: the policy (state -> action), modified in place
    :param modified: passed through to calculate_value
    :param gamma: the discount factor
    :return: True if no state changed its action, False otherwise
    """
    stable = True
    for state in v:
        current_action = p[state]
        current_value = v[state]

        best_action, best_return = current_action, current_value
        for candidate in action_space(state):
            candidate_return = calculate_value(state, candidate, v, gamma, modified)
            if candidate_return > best_return:
                best_action, best_return = candidate, candidate_return

        if best_action == current_action:
            continue

        logging.info('updating p[%s]: %d:%f -> %d:%f (%f)', state, current_action, current_value, best_action, best_return, best_return - current_value)
        p[state] = best_action
        stable = False

    return stable

def print_value_function(v):
    """
    log the value function as a 21 x 21 grid: one row per location 1 count (20 at the top, 0 at the bottom),
    one column per location 2 count (0..20), values rounded to 1 decimal place.
    :param v: the value function, indexed by (location 1 count, location 2 count)
    """
    logging.info('    0     1     2     3     4     5     6     7     8     9    10    11    12    13    14    15    16    17    18    19    20')
    logging.info('    -------------------------------------------------------------------------------------------------------------------------')
    for row in reversed(range(21)):
        cells = ''.join(str(round(v[(row, col)], 1)).rjust(5) + ' ' for col in range(21))
        logging.info('%s| %s', str(row).ljust(2), cells)

if __name__ == '__main__':
    logging.basicConfig(level='INFO', format='%(message)s')

    # the policy function -- start by doing nothing in every state
    p = collections.OrderedDict(((i, j), 0) for i in range(21) for j in range(21))

    # the value function -- initialize all values to zero
    v = collections.OrderedDict(((i, j), 0) for i in range(21) for j in range(21))

    # True to run policy iteration with modifications from problem (1 free car shipment + $4 storage cost)
    use_modified_reward = False
    policy_stable = False
    iteration = 0
    while not policy_stable:
        # log the current policy as a grid: rows are location 1 counts (20 down to 0), columns location 2 counts
        logging.info('POLICY')
        logging.info('    0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20')
        logging.info('    -------------------------------------------------------------')
        for i in reversed(range(21)):
            policy = ''.join(str(p[(i, j)]).ljust(2) + ' ' for j in range(21))
            logging.info('%s| %s', str(i).ljust(2), policy)

        # evaluate the current policy, then greedily improve it; stop once no action changes
        eval_policy(v, p, use_modified_reward)
        policy_stable = improve_policy(v, p, use_modified_reward)

        iteration += 1
        logging.info('policy improvement iteration %d, stable=%s', iteration, policy_stable)

POLICY
    0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20
    -------------------------------------------------------------
20| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
19| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
18| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
17| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
16| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
15| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
14| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
13| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
12| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
11| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
10| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
9 | 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
8 | 0  0  0  0  0  0  0  0  0  0  0  0  0  0 

updating p[(9, 11)]: 0:554.185065 -> 2:556.408365 (2.223300)
updating p[(9, 12)]: 0:560.464947 -> 2:562.092897 (1.627950)
updating p[(9, 13)]: 0:566.441444 -> 2:567.492948 (1.051504)
updating p[(9, 14)]: 0:572.125975 -> 1:572.842475 (0.716499)
updating p[(9, 15)]: 0:577.526026 -> 1:577.957978 (0.431952)
updating p[(9, 16)]: 0:582.641529 -> 1:582.774501 (0.132972)
updating p[(10, 0)]: 0:463.926746 -> 5:484.285770 (20.359023)
updating p[(10, 1)]: 0:473.900817 -> 5:492.729903 (18.829086)
updating p[(10, 2)]: 0:483.766119 -> 5:500.762284 (16.996165)
updating p[(10, 3)]: 0:493.397659 -> 5:508.406109 (15.008451)
updating p[(10, 4)]: 0:502.681075 -> 5:515.683135 (13.002061)
updating p[(10, 5)]: 0:511.551827 -> 5:522.611907 (11.060080)
updating p[(10, 6)]: 0:519.995961 -> 4:529.506454 (9.510494)
updating p[(10, 7)]: 0:528.028341 -> 4:536.103304 (8.074962)
updating p[(10, 8)]: 0:535.672167 -> 4:542.383186 (6.711019)
updating p[(10, 9)]: 0:542.949193 -> 3:548.431869 (5.482676)
updating p[(10, 10

updating p[(16, 11)]: 0:563.749404 -> 5:576.907628 (13.158224)
updating p[(16, 12)]: 0:570.029286 -> 5:581.724151 (11.694865)
updating p[(16, 13)]: 0:576.005782 -> 5:586.202260 (10.196477)
updating p[(16, 14)]: 0:581.690314 -> 5:590.266823 (8.576508)
updating p[(16, 15)]: 0:587.090365 -> 4:593.821758 (6.731393)
updating p[(16, 16)]: 0:592.205868 -> 4:597.363404 (5.157535)
updating p[(16, 17)]: 0:597.022392 -> 3:600.638640 (3.616249)
updating p[(16, 18)]: 0:601.500500 -> 2:603.678957 (2.178457)
updating p[(16, 19)]: 0:605.565063 -> 1:606.520866 (0.955803)
updating p[(17, 0)]: 0:471.729239 -> 5:505.083112 (33.353873)
updating p[(17, 1)]: 0:481.703310 -> 5:513.527246 (31.823936)
updating p[(17, 2)]: 0:491.568612 -> 5:521.559627 (29.991015)
updating p[(17, 3)]: 0:501.200152 -> 5:529.203452 (28.003300)
updating p[(17, 4)]: 0:510.483567 -> 5:536.480478 (25.996910)
updating p[(17, 5)]: 0:519.354320 -> 5:543.409250 (24.054930)
updating p[(17, 6)]: 0:527.798453 -> 5:550.006099 (22.207646)
updat

updating p[(3, 16)]: 0:563.715638 -> -1:563.884167 (0.168529)
updating p[(4, 1)]: 1:463.328127 -> 0:463.521931 (0.193804)
updating p[(4, 2)]: 1:472.895440 -> 0:473.352300 (0.456860)
updating p[(4, 3)]: 1:482.059719 -> 0:482.892168 (0.832450)
updating p[(5, 0)]: 2:461.328127 -> 1:461.521931 (0.193804)
updating p[(5, 1)]: 2:470.895440 -> 1:471.352300 (0.456860)
updating p[(5, 2)]: 2:480.059719 -> 1:480.892168 (0.832450)
updating p[(5, 3)]: 1:489.917908 -> 0:490.075464 (0.157556)
updating p[(5, 4)]: 1:498.543376 -> 0:499.117608 (0.574232)
updating p[(5, 5)]: 1:506.681708 -> 0:507.640415 (0.958707)
updating p[(5, 6)]: 1:514.371519 -> 0:515.656987 (1.285468)
updating p[(6, 0)]: 3:468.895440 -> 2:469.352300 (0.456860)
updating p[(6, 1)]: 3:478.059719 -> 2:478.892168 (0.832450)
updating p[(6, 2)]: 2:487.917908 -> 1:488.075464 (0.157556)
updating p[(6, 3)]: 2:496.543376 -> 1:497.117608 (0.574232)
updating p[(6, 4)]: 2:504.681708 -> 1:505.640415 (0.958707)
updating p[(6, 5)]: 2:512.371519 -> 0:

updating p[(14, 2)]: 5:534.722437 -> 4:534.921814 (0.199377)
updating p[(14, 3)]: 5:541.076033 -> 3:541.570646 (0.494613)
updating p[(14, 4)]: 5:547.120466 -> 2:548.012075 (0.891610)
updating p[(14, 5)]: 5:552.904928 -> 1:554.236793 (1.331865)
updating p[(14, 6)]: 5:558.454919 -> 0:560.235146 (1.780227)
updating p[(14, 7)]: 5:563.783898 -> 0:565.937380 (2.153481)
updating p[(14, 8)]: 5:568.901199 -> 0:571.197207 (2.296008)
updating p[(14, 9)]: 5:573.815211 -> 0:576.141658 (2.326447)
updating p[(14, 10)]: 5:578.532180 -> 0:580.847727 (2.315547)
updating p[(14, 11)]: 5:583.051335 -> 0:585.358792 (2.307457)
updating p[(14, 12)]: 5:587.356732 -> 0:589.700372 (2.343640)
updating p[(14, 13)]: 5:591.405996 -> 0:593.890239 (2.484244)
updating p[(14, 14)]: 4:596.143437 -> 0:597.943227 (1.799790)
updating p[(14, 15)]: 4:599.794453 -> 0:601.871765 (2.077312)
updating p[(14, 16)]: 3:604.173294 -> 0:605.682468 (1.509174)
updating p[(14, 17)]: 2:608.322091 -> 0:609.368525 (1.046434)
updating p[(14, 

15| 5  5  4  3  2  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
14| 5  5  4  3  2  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
13| 5  5  4  3  2  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
12| 5  5  4  3  2  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
11| 5  4  4  3  2  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
10| 5  4  3  3  2  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
9 | 4  4  3  2  2  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
8 | 3  3  3  2  1  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
7 | 3  2  2  2  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
6 | 2  2  1  1  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
5 | 1  1  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
4 | 1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  -1 
3 | 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  -1 -1 -1 -1 -1 
2 | 0  0  0  0  0  0  0  0  0  0  0  -1 -1 -1 -1 -1 -1 -2 -2 -2 -2 
1 | 0  0  0  0  0  0  0  0  0  -1 -1 -1 -2 -2 -2

updating p[(19, 11)]: 2:607.254463 -> 1:607.361082 (0.106619)
updating p[(20, 10)]: 3:605.254463 -> 2:605.361082 (0.106619)
policy improvement iteration 4, stable=False
POLICY
    0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20
    -------------------------------------------------------------
20| 5  5  5  5  4  4  3  3  3  3  2  2  2  2  2  1  1  1  0  0  0  
19| 5  5  5  4  4  3  3  2  2  2  2  1  1  1  1  1  0  0  0  0  0  
18| 5  5  5  4  3  3  2  2  1  1  1  1  0  0  0  0  0  0  0  0  0  
17| 5  5  5  4  3  2  2  1  1  0  0  0  0  0  0  0  0  0  0  0  0  
16| 5  5  5  4  3  2  1  1  0  0  0  0  0  0  0  0  0  0  0  0  0  
15| 5  5  5  4  3  2  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
14| 5  5  4  4  3  2  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
13| 5  5  4  3  3  2  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
12| 5  5  4  3  2  2  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
11| 5  4  4  3  2  1  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
10| 4  4  3 

In [2]:
v

OrderedDict([((0, 0), 421.0955450308663),
             ((0, 1), 431.06776452774017),
             ((0, 2), 440.92344624129066),
             ((0, 3), 450.52874547455497),
             ((0, 4), 459.7615134349537),
             ((0, 5), 468.55356455496957),
             ((0, 6), 476.89181752226835),
             ((0, 7), 484.79543525636547),
             ((0, 8), 492.65872102396673),
             ((0, 9), 500.1543376165245),
             ((0, 10), 507.56636485498154),
             ((0, 11), 514.6717995526984),
             ((0, 12), 521.4287847529695),
             ((0, 13), 528.0289288694281),
             ((0, 14), 534.4226585823884),
             ((0, 15), 540.5126958173897),
             ((0, 16), 546.3189425290266),
             ((0, 17), 551.8581500930081),
             ((0, 18), 557.282641599184),
             ((0, 19), 562.4896433732555),
             ((0, 20), 567.4435194921382),
             ((1, 0), 430.96928890383424),
             ((1, 1), 440.94144216725726),
             (