https://medium.com/@kireetr/working-the-book-policy-iteration-b9ae8441979f
    
https://raw.githubusercontent.com/kireet/ml-playground/master/sutton_barto/four/jacks.py

In [None]:
import logging
import math
import collections

# Jack's car rental. Jack can choose how to divvy up his rental cars between 2 locations each night in preparation for
# the next day's rentals. See book (pg 93) for more details. This is the unmodified version.


poisson_cache = {}
def poisson_prob(n, lam):
    """
    probability that a poisson random variable with rate lam takes exactly the value n.
    results are memoised in the module level poisson_cache dict, keyed by (n, lam).
    :param n: the event count
    :param lam: the poisson lambda (expected count)
    :return: exp(-lam) * lam^n / n!
    """
    key = (n, lam)
    if key not in poisson_cache:
        # first request for this (n, lam) pair: evaluate the pmf once and remember it
        poisson_cache[key] = math.exp(-lam) * (lam ** n) / float(math.factorial(n))
    return poisson_cache[key]


def loc_rental_probability(start, end, rent_lam, return_lam, *, max_cars=20, rental_credit=10):
    """
    calculate the rental probabilities and rewards for a single location given a particular start and end count.

    every possible number of rentals (0..start) is considered; for each one the number of returns needed to reach
    the end count is derived, and outcomes that would need a negative number of returns are skipped.
    :param start: the day start count
    :param end: the day end count
    :param rent_lam: the rental poisson lambda
    :param return_lam: the return poisson lambda
    :param max_cars: the lot capacity; an end count equal to this absorbs every larger number of returns
    :param rental_credit: the reward earned per rented car
    :return: a list of probability/reward tuples enumerating the possible outcomes.
    """
    outcomes = []
    for rented in range(start + 1):
        remaining = start - rented
        returned = end - remaining
        if returned < 0:
            # infeasible: more cars left on the lot than the requested end count
            continue

        if remaining == 0:
            # cars that can be rented is capped at the number on the lot, but in theory an infinite number of
            # customers could arrive, so this outcome takes the whole tail P(requests >= rented).
            # the tail is built by repeated subtraction to keep the floating point result stable across versions.
            rent_prob = 1
            for i in range(rented):
                rent_prob -= poisson_prob(i, rent_lam)
        else:
            rent_prob = poisson_prob(rented, rent_lam)

        if end == max_cars:
            # lot size is capped, but in theory an infinite number of cars could be returned, so a full lot
            # takes the whole tail P(returns >= returned).
            return_prob = 1
            for i in range(returned):
                return_prob -= poisson_prob(i, return_lam)
        else:
            return_prob = poisson_prob(returned, return_lam)

        outcomes.append((rent_prob * return_prob, rental_credit * rented))

    return outcomes


rental_probabilities_cache = {}
def rental_probabilities(start_state):
    """
    calculate the day activity (rental/return activity) probabilities.

    location 1 uses rental/return lambdas of 3 and 3, location 2 uses 4 and 2. results are memoised in the module
    level rental_probabilities_cache dict, keyed by start_state.
    :param start_state: the start of day state, a (location 1 count, location 2 count) tuple
    :return: a dict of next_state -> (transition probability, reward) tuples covering every end of day state
             (0..20, 0..20). the reward entry is the probability weighted sum of rental income over all the ways
             of reaching that next state, i.e. that state's contribution to the expected reward.
    """
    key = start_state
    if key in rental_probabilities_cache:
        return rental_probabilities_cache[key]

    # each location's outcome list depends only on that location's own end count, so build the 21 lists per
    # location once up front instead of recomputing them for every (l1, l2) combination.
    l1_outcomes = [loc_rental_probability(start_state[0], l1, 3, 3) for l1 in range(0, 21)]
    l2_outcomes = [loc_rental_probability(start_state[1], l2, 4, 2) for l2 in range(0, 21)]

    rv = {}
    for l1, l1_probs in enumerate(l1_outcomes):
        for l2, l2_probs in enumerate(l2_outcomes):
            prob = 0
            reward = 0
            # the two locations are treated as independent, so joint probabilities are simple products
            for l1p, l1r in l1_probs:
                for l2p, l2r in l2_probs:
                    joint = l1p * l2p
                    prob += joint
                    reward += joint * (l1r + l2r)
            rv[(l1, l2)] = (prob, reward)

    rental_probabilities_cache[key] = rv
    return rv


def calculate_value(start_state, action, v, gamma, modified):
    """
    expected return of taking an action from a state and then following the value function v.
    :param start_state: num cars at each location at the end of the day
    :param action: the number of cars moved overnight from location 1 to location 2 (negative moves the other way)
    :param v: the value function, state -> value
    :param gamma: the discount factor
    :param modified: True to apply the modified problem's free shuttle and storage fee
    :return: the move cost plus the expected rental reward plus the discounted expected next state value
    """
    # overnight move: counts shift by the action, and neither lot can hold more than 20 cars
    cars_loc1 = min(20, start_state[0] - action)
    cars_loc2 = min(20, start_state[1] + action)
    day_start_state = (cars_loc1, cars_loc2)

    # moving costs $2 per car regardless of direction
    total = -2 * abs(action)

    if modified:
        if action > 0:
            total += 2 # can shuttle one car from 1 -> 2 for free!
        # $4 storage fee for each location holding > 10 cars
        total -= 4 * sum(1 for cars in day_start_state if cars > 10)

    # fold in the day activity: each next state contributes its weighted reward and discounted value
    for next_state, (prob, exp_reward) in rental_probabilities(day_start_state).items():
        total += exp_reward + prob * gamma * v[next_state]

    return total

# the state is a tuple of cars at each location, the action is the number of cars moved from location 1 to location 2.
# note that no more than 20 cars can be at either location. at most 5 cars can be moved. Policies are deterministic.
def eval_policy(v, p, modified, gamma=0.9, eps=0.1):
    """
    iterative policy evaluation: sweep all states until the largest value change is at most eps.
    :param v: the value function, state -> value. updated in place.
    :param p: the policy, state -> action
    :param modified: passed through to calculate_value
    :param gamma: the discount factor
    :param eps: the convergence threshold on the largest per state change in a sweep
    :raises ValueError: if 10000 sweeps complete without converging
    """
    for sweep in range(1, 10001):
        # every state in a sweep is backed up from the same frozen copy of the previous values
        snapshot = collections.OrderedDict(v)
        delta = 0
        for state, previous in snapshot.items():
            updated = calculate_value(state, p[state], snapshot, gamma, modified)
            v[state] = updated
            delta = max(delta, abs(updated - previous))

        if delta <= eps:
            logging.info('converged after %d iterations', sweep)
            return
        logging.info('%d: %f', sweep, delta)

    raise ValueError('value function non-convergence')


def action_space(state):
    """
    the actions available in a state. an action is the number of cars moved from location 1 to location 2;
    a negative value moves cars to location 1.
    :param state: a (location 1 count, location 2 count) tuple
    :return: a range of the allowed actions, lowest first
    """
    # at most 5 cars can move in either direction, and never more than the source location holds
    most_to_loc1 = min(state[1], 5)
    most_to_loc2 = min(state[0], 5)
    return range(-most_to_loc1, most_to_loc2 + 1)


def improve_policy(v, p, modified, gamma=0.9):
    """
    policy improvement step: for every state, switch to the action with the highest one step return under v.

    an action only replaces the current one when its return is strictly greater than v[state] (and than any
    other candidate seen so far), so ties keep the existing action.
    :param v: the value function, state -> value, for the current policy
    :param p: the policy, state -> action. updated in place.
    :param modified: passed through to calculate_value
    :param gamma: the discount factor
    :return: True if no state's action changed (the policy is stable), False otherwise
    """
    stable = True
    for state in v:
        current_action = p[state]

        best_action = current_action
        best_return = v[state]
        for action in action_space(state):
            ret = calculate_value(state, action, v, gamma, modified)
            if ret > best_return:
                best_action = action
                best_return = ret

        if best_action != current_action:
            logging.info('updating p[%s]: %d:%f -> %d:%f (%f)', state, current_action, v[state], best_action, best_return, best_return - v[state])
            p[state] = best_action
            stable = False

    return stable

def print_value_function(v):
    """
    log the value function as a grid, one row per location 1 count (20 at the top, 0 at the bottom) and one
    column per location 2 count (0..20). values are rounded to one decimal place.
    :param v: the value function, (location 1 count, location 2 count) -> value
    """
    logging.info('    0     1     2     3     4     5     6     7     8     9    10    11    12    13    14    15    16    17    18    19    20')
    logging.info('    -------------------------------------------------------------------------------------------------------------------------')
    for i in reversed(range(21)):
        # each cell is right aligned in 5 characters and followed by a single space
        row = ''.join(str(round(v[(i, j)], 1)).rjust(5) + ' ' for j in range(21))
        logging.info('%s| %s', str(i).ljust(2), row)

if __name__ == '__main__':
    logging.basicConfig(level='INFO', format='%(message)s')

    # the policy function -- start by doing nothing (move 0 cars) in every (location 1, location 2) state
    p = collections.OrderedDict(((i, j), 0) for i in range(21) for j in range(21))

    # the value function -- initialize all values to zero
    v = collections.OrderedDict(((i, j), 0) for i in range(21) for j in range(21))

    use_modified_reward = False # True to run policy iteration with modifications from problem (1 free car shipment + $4 storage cost)
    policy_stable = False
    iteration = 0
    # policy iteration: log the policy, evaluate it, improve it, and repeat until an improvement pass changes nothing
    while not policy_stable:
        logging.info('POLICY')
        logging.info('    0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20')
        logging.info('    -------------------------------------------------------------')
        # rows are location 1 counts (20 at the top), columns are location 2 counts
        for i in reversed(range(21)):
            policy = ''.join(str(p[(i, j)]).ljust(2) + ' ' for j in range(21))
            logging.info('%s| %s', str(i).ljust(2), policy)

        eval_policy(v, p, use_modified_reward)
        policy_stable = improve_policy(v, p, use_modified_reward)

        iteration += 1
        logging.info('policy improvement iteration %d, stable=%s', iteration, policy_stable)

POLICY
    0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20
    -------------------------------------------------------------
20| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
19| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
18| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
17| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
16| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
15| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
14| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
13| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
12| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
11| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
10| 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
9 | 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  
8 | 0  0  0  0  0  0  0  0  0  0  0  0  0  0 

(0,8),  -5 :  466.13160143319703
(0,8),  -4 :  471.508803279755
(0,8),  -3 :  475.3182095402839
(0,8),  -2 :  477.518065606009
(0,8),  -1 :  478.3163567318444
(0,8),  0 :  478.14035487913384


updating p[(4, 20)]: 0:578.572754 -> -1:579.024415 (0.451661)
updating p[(5, 0)]: 0:446.573592 -> 2:449.532502 (2.958910)
updating p[(5, 1)]: 0:456.547663 -> 2:459.164041 (2.616379)
updating p[(5, 2)]: 0:466.412965 -> 2:468.447457 (2.034492)
updating p[(5, 3)]: 0:476.044504 -> 1:477.508803 (1.464299)
updating p[(5, 4)]: 0:485.327920 -> 1:486.379556 (1.051635)
updating p[(5, 5)]: 0:494.198673 -> 1:494.823689 (0.625017)
updating p[(5, 6)]: 0:502.642806 -> 1:502.856070 (0.213264)
updating p[(6, 0)]: 0:451.468133 -> 3:457.164041 (5.695908)
updating p[(6, 1)]: 0:461.442204 -> 3:466.447457 (5.005253)
updating p[(6, 2)]: 0:471.307506 -> 2:475.508803 (4.201297)
updating p[(6, 3)]: 0:480.939046 -> 2:484.379556 (3.440510)
updating p[(6, 4)]: 0:490.222462 -> 2:492.823689 (2.601228)
updating p[(6, 5)]: 0:499.093214 -> 2:500.856070 (1.762856)
updating p[(6, 6)]: 0:507.537348 -> 1:508.762284 (1.224936)
updating p[(6, 7)]: 0:515.569728 -> 1:516.406109 (0.836381)
updating p[(6, 8)]: 0:523.213554 -> 1:

updating p[(13, 16)]: 0:589.650636 -> 3:592.377608 (2.726973)
updating p[(13, 17)]: 0:594.467159 -> 2:596.266823 (1.799663)
updating p[(13, 18)]: 0:598.945267 -> 1:599.821758 (0.876491)
updating p[(13, 19)]: 0:603.009830 -> 1:603.363404 (0.353573)
updating p[(14, 0)]: 0:469.686412 -> 5:499.349204 (29.662793)
updating p[(14, 1)]: 0:479.660482 -> 5:507.793338 (28.132856)
updating p[(14, 2)]: 0:489.525784 -> 5:515.825719 (26.299934)
updating p[(14, 3)]: 0:499.157324 -> 5:523.469544 (24.312220)
updating p[(14, 4)]: 0:508.440740 -> 5:530.746570 (22.305830)
updating p[(14, 5)]: 0:517.311492 -> 5:537.675342 (20.363850)
updating p[(14, 6)]: 0:525.755626 -> 5:544.272191 (18.516566)
updating p[(14, 7)]: 0:533.788007 -> 5:550.552073 (16.764067)
updating p[(14, 8)]: 0:541.431832 -> 5:556.528570 (15.096738)
updating p[(14, 9)]: 0:548.708858 -> 5:562.213102 (13.504244)
updating p[(14, 10)]: 0:555.637630 -> 5:567.613152 (11.975522)
updating p[(14, 11)]: 0:562.234479 -> 5:572.728655 (10.494176)
updati

updating p[(20, 8)]: 0:544.371616 -> 5:565.419940 (21.048324)
updating p[(20, 9)]: 0:551.648642 -> 5:571.104472 (19.455830)
updating p[(20, 10)]: 0:558.577414 -> 5:576.504522 (17.927108)
updating p[(20, 11)]: 0:565.174263 -> 5:581.620025 (16.445762)
updating p[(20, 12)]: 0:571.454145 -> 5:586.436549 (14.982404)
updating p[(20, 13)]: 0:577.430641 -> 5:590.914657 (13.484016)
updating p[(20, 14)]: 0:583.115173 -> 5:594.979220 (11.864047)
updating p[(20, 15)]: 0:588.515224 -> 5:598.520866 (10.005642)
updating p[(20, 16)]: 0:593.630727 -> 4:601.193899 (7.563172)
updating p[(20, 17)]: 0:598.447251 -> 3:603.721808 (5.274558)
updating p[(20, 18)]: 0:602.925359 -> 2:606.123976 (3.198618)
updating p[(20, 19)]: 0:606.989922 -> 1:608.417353 (1.427431)


In [None]:
# the policy function -- do nothing (move 0 cars) in every state
p = collections.OrderedDict(((i, j), 0) for i in range(21) for j in range(21))

# the value function -- initialize all values to zero
v = collections.OrderedDict(((i, j), 0) for i in range(21) for j in range(21))

# evaluate the do-nothing policy on the unmodified problem; v is updated in place
eval_policy(v, p, False)

1: 70.000000
2: 62.999649
3: 56.675706
4: 50.792824
5: 45.038889
6: 39.373431
7: 34.033296
8: 29.268275
9: 25.191935
10: 21.787735
11: 18.968603
12: 16.627533
13: 14.665218
14: 13.000268
15: 11.570073
16: 10.327904
17: 9.239195


In [None]:
import numpy as np
# the policy's actions as a flat array, in the insertion order of p
np.array(list(p.values()))

In [None]:
# display the value function (state -> value) as it stands after the most recent eval_policy call
v

In [None]:
# sum of the reward entries over every next state for a day that starts in state (0, 8)
np.array([reward for _prob, reward in rental_probabilities_cache[(0, 8)].values()]).sum()

In [None]:
# the cached (transition probability, reward) entry for a day that starts in state (0, 8) and also ends in (0, 8)
rental_probabilities_cache[(0, 8)][(0, 8)]