In [None]:
import numpy as np

# State identifiers; they double as indices into the value table V.
S = list(range(4))
# Action identifiers available in every state.
A = list(range(2))
# State 3 only ever transitions to itself with zero reward in the tables below.
TERMINAL_STATES = [S[-1]]

# Memoised policy table: maps a state to the action chosen for it.
pi_storage = {}
def pi(s):
    """Return the action recorded for state ``s``.

    The first time a state is seen, an action is drawn with
    ``np.random.choice(A)`` and stored in ``pi_storage``; later calls return
    the stored action.
    """
    if s in pi_storage:
        return pi_storage[s]
    chosen = np.random.choice(A)
    pi_storage[s] = chosen
    return chosen

def policy_evaluation(V, S, A, SR, P, pi, theta=1e-6, gamma=1.0, max_iterations=None):
    """Evaluate the policy ``pi`` by repeated in-place sweeps over ``S``.

    Args:
        V: Mutable value table indexed by state; updated in place.
        S: Iterable of states to sweep over.
        A: Action set (unused here; kept for a uniform signature).
        SR: Callable ``SR(s, a)`` yielding ``(s_prime, r)`` outcome pairs.
        P: Callable ``P(s, a, s_prime, r)`` giving the outcome probability.
        pi: Callable ``pi(s)`` returning the action taken in state ``s``.
        theta: Sweeping stops once the largest per-state change in a sweep
            is below this threshold.
        gamma: Discount factor applied to the successor state's value.
        max_iterations: Optional cap on the number of sweeps. ``None``
            (the default) sweeps until the ``theta`` criterion is met, which
            may never happen, e.g. with ``gamma=1.0`` and a policy that
            cycles through a non-zero reward.

    Returns:
        The same ``V`` object that was passed in.
    """
    sweeps = 0
    while max_iterations is None or sweeps < max_iterations:
        delta = 0.0
        for s in S:
            v = V[s]
            # Query the policy once so SR and P are evaluated for the same
            # action even if pi is not perfectly repeatable.
            a = pi(s)
            # V is updated in place, so states later in the sweep already
            # see the values written earlier in the same sweep.
            V[s] = sum(
                P(s, a, s_prime, r) * (r + gamma * V[s_prime])
                for s_prime, r in SR(s, a)
            )
            delta = max(delta, abs(v - V[s]))
        sweeps += 1
        if delta < theta:
            break
    return V

def policy_improvement(V, S, A, SR, P, pi, gamma=1.0):
    """Make the stored policy greedy with respect to ``V``.

    For every state the one-step lookahead value of each action is computed
    and the best action is written to the module-level ``pi_storage`` table.

    Args:
        V: Value table indexed by state.
        S: Iterable of states to improve.
        A: Iterable of candidate actions.
        SR: Callable ``SR(s, a)`` yielding ``(s_prime, r)`` outcome pairs.
        P: Callable ``P(s, a, s_prime, r)`` giving the outcome probability.
        pi: Callable ``pi(s)`` returning the current action for ``s``.
        gamma: Discount factor applied to the successor state's value.

    Returns:
        ``True`` if ``pi`` reports the same action as before for every state,
        ``False`` if at least one state's action changed.
    """
    actions = list(A)
    policy_stable = True
    for s in S:
        old_action = pi(s)
        action_values = [
            sum(P(s, a, s_prime, r) * (r + gamma * V[s_prime]) for s_prime, r in SR(s, a))
            for a in actions
        ]
        # np.argmax yields a position in ``actions``, not an action; map it
        # back so the policy stays valid when actions are not 0..n-1.
        pi_storage[s] = actions[int(np.argmax(action_values))]
        # Re-query pi so the comparison reflects what the policy now reports.
        if old_action != pi(s):
            policy_stable = False
    return policy_stable


def policy_iteration(S, A, SR, P, pi, gamma=1.0, theta=1e-6):
    """Alternate policy evaluation and policy improvement until stable.

    Starts from an all-zero value table with one entry per state, then
    repeatedly evaluates ``pi`` and improves it until ``policy_improvement``
    reports that no action changed.

    Returns:
        A ``(V, pi)`` tuple: the final value table and the policy callable
        that was passed in.
    """
    V = np.zeros(len(S), dtype=float)
    stable = False
    while not stable:
        V = policy_evaluation(V, S, A, SR, P, pi, theta, gamma)
        stable = policy_improvement(V, S, A, SR, P, pi, gamma)
    return V, pi

# Outcome table: (state, action) -> list of (next_state, reward) pairs.
SR_PAIRS = {
    # state 0
    (0, 0): [(0, 0.0)],
    (0, 1): [(1, 0.0)],
    # state 1
    (1, 0): [(2, -1.0)],
    (1, 1): [(0, 0.0)],
    # state 2
    (2, 0): [(3, 1.0)],
    (2, 1): [(1, 0.0)],
    # state 3 loops back to itself under both actions
    (3, 0): [(3, 0.0)],
    (3, 1): [(3, 0.0)],
}

def SR(s, a):
    """Return the ``(next_state, reward)`` pairs for taking ``a`` in ``s``.

    Pairs not present in ``SR_PAIRS`` yield an empty list.
    """
    key = (s, a)
    if key in SR_PAIRS:
        return SR_PAIRS[key]
    return []

# Transition probabilities keyed by (state, action, next_state, reward).
P_PROB = {
    # state 0
    (0, 0, 0, 0.0): 1.0,
    (0, 1, 1, 0.0): 1.0,
    # state 1
    (1, 0, 2, -1.0): 1.0,
    (1, 1, 0, 0.0): 1.0,
    # state 2
    (2, 0, 3, 1.0): 1.0,
    (2, 1, 1, 0.0): 1.0,
    # state 3
    (3, 0, 3, 0.0): 1.0,
    (3, 1, 3, 0.0): 1.0,
}

def P(s, a, s_prime, r):
    """Return the probability of reaching ``s_prime`` with reward ``r``.

    The lookup is for taking action ``a`` in state ``s``; combinations not
    listed in ``P_PROB`` have probability ``0.0``.
    """
    try:
        return P_PROB[(s, a, s_prime, r)]
    except KeyError:
        return 0.0




[-9.99994933e-01 -9.99996417e-01  1.79150605e-06  0.00000000e+00]
