In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import math

# Intro

### Markov Property
$P(S_{t+1}\vert S_t, A_t) = P(S_{t+1}\vert S_t, A_t, S_{t-1}, A_{t-1},\ldots)$

### Transition Function
$p(s'\vert s, a) = P(S_t=s'\vert S_{t-1}=s, A_{t-1}=a)$

### Reward Function
$r(s,a)=\mathbb{E} [R_t\vert S_{t-1}=s, A_{t-1}=a]$<br>
$r(s,a,s')=\mathbb{E} [R_t\vert S_{t-1}=s, A_{t-1}=a, S_t=s']$

### Return
$G_t=R_{t+1}+\gamma R_{t+2}+\gamma^2 R_{t+3}+\ldots+\gamma^{T-t-1}R_T$

### MDP
$\mathcal{MDP(S,A,T,R,S_\theta,\gamma,H)}$<br>
$\mathcal{POMDP(S,A,T,R,S_\theta,\gamma,H,O,E)}$

# Bellman Equations

### State-Value Function
$v_\pi(s) = \sum_a \pi(a\vert s) \sum_{s',r} p(s',r\vert s, a) \left[r + \gamma v_\pi(s') \right],\forall s \in S$

### Action-Value Function
$q_\pi(s,a) = \sum_{s',r} p(s',r|s, a) \left[r + \gamma v_\pi(s') \right],\forall s \in S,\forall a \in A$

### Action Advantage
$a_\pi(s,a)=q_\pi(s,a)-v_\pi(s)$

### Bellman optimality equations

$v_\star(s)=\displaystyle\max_\pi v_\pi(s),\forall s \in S$<br>
$q_\star(s,a)=\sum_{s',r}p(s',r\vert s,a)[r+\gamma \displaystyle\max_{a'}q_\star(s',a')]$

# Policy Iteration

### Policy Evaluation
$v_{k+1}(s)=\sum_a \pi(a\vert s) \sum_{s',r}p(s',r\vert s,a)[r+\gamma v_k(s')]$

### Policy Improvement
$\pi'(s)=\text{argmax}_a \sum_{s',r}p(s',r\vert s,a)[r+\gamma v_k(s')]$

# Value Iteration
$v_{k+1}(s)=\displaystyle\max_a\sum_{s',r}p(s',r\vert s,a)[r+\gamma v_k(s')]$

In [5]:
def initialize_example():
    """Build the example MDP and a hand-written starting policy.

    Returns:
        P: nested dict ``P[state][action] -> list of outcomes`` covering
            states 0..15 and actions 0..3. Each outcome is a tuple
            ``(probability, next_state, reward, done)``.
        pi: callable mapping a state index (0..15) to an action index.

    Notes:
        * States 5, 7, 11, 12 and 15 only transition to themselves with
          probability 1.0 and ``done=True``.
        * The only non-zero reward (1.0) is on transitions into state 15.
        * Stochastic outcomes use 0.3333 each, so every non-absorbing
          (state, action) row sums to 0.9999 rather than exactly 1.
        * NOTE(review): the layout looks like the 4x4 slippery FrozenLake
          grid -- confirm against the source the table was copied from.
    """
    P = {
        # state: {action: [(transition probability, next state, reward, whether the transition is flagged terminal)]}
        0: {
            0: [
                (0.3333, 0, 0.0, False),
                (0.3333, 0, 0.0, False),
                (0.3333, 4, 0.0, False),
            ],
            1: [
                (0.3333, 0, 0.0, False),
                (0.3333, 4, 0.0, False),
                (0.3333, 1, 0.0, False),
            ],
            2: [
                (0.3333, 4, 0.0, False),
                (0.3333, 1, 0.0, False),
                (0.3333, 0, 0.0, False),
            ],
            3: [
                (0.3333, 1, 0.0, False),
                (0.3333, 0, 0.0, False),
                (0.3333, 0, 0.0, False),
            ],
        },
        1: {
            0: [
                (0.3333, 1, 0.0, False),
                (0.3333, 0, 0.0, False),
                (0.3333, 5, 0.0, True),
            ],
            1: [
                (0.3333, 0, 0.0, False),
                (0.3333, 5, 0.0, True),
                (0.3333, 2, 0.0, False),
            ],
            2: [
                (0.3333, 5, 0.0, True),
                (0.3333, 2, 0.0, False),
                (0.3333, 1, 0.0, False),
            ],
            3: [
                (0.3333, 2, 0.0, False),
                (0.3333, 1, 0.0, False),
                (0.3333, 0, 0.0, False),
            ],
        },
        2: {
            0: [
                (0.3333, 2, 0.0, False),
                (0.3333, 1, 0.0, False),
                (0.3333, 6, 0.0, False),
            ],
            1: [
                (0.3333, 1, 0.0, False),
                (0.3333, 6, 0.0, False),
                (0.3333, 3, 0.0, False),
            ],
            2: [
                (0.3333, 6, 0.0, False),
                (0.3333, 3, 0.0, False),
                (0.3333, 2, 0.0, False),
            ],
            3: [
                (0.3333, 3, 0.0, False),
                (0.3333, 2, 0.0, False),
                (0.3333, 1, 0.0, False),
            ],
        },
        3: {
            0: [
                (0.3333, 3, 0.0, False),
                (0.3333, 2, 0.0, False),
                (0.3333, 7, 0.0, True),
            ],
            1: [
                (0.3333, 2, 0.0, False),
                (0.3333, 7, 0.0, True),
                (0.3333, 3, 0.0, False),
            ],
            2: [
                (0.3333, 7, 0.0, True),
                (0.3333, 3, 0.0, False),
                (0.3333, 3, 0.0, False),
            ],
            3: [
                (0.3333, 3, 0.0, False),
                (0.3333, 3, 0.0, False),
                (0.3333, 2, 0.0, False),
            ],
        },
        4: {
            0: [
                (0.3333, 0, 0.0, False),
                (0.3333, 4, 0.0, False),
                (0.3333, 8, 0.0, False),
            ],
            1: [
                (0.3333, 4, 0.0, False),
                (0.3333, 8, 0.0, False),
                (0.3333, 5, 0.0, True),
            ],
            2: [
                (0.3333, 8, 0.0, False),
                (0.3333, 5, 0.0, True),
                (0.3333, 0, 0.0, False),
            ],
            3: [
                (0.3333, 5, 0.0, True),
                (0.3333, 0, 0.0, False),
                (0.3333, 4, 0.0, False),
            ],
        },
        # Absorbing state: every action stays here with zero reward.
        5: {
            0: [(1.0, 5, 0, True)],
            1: [(1.0, 5, 0, True)],
            2: [(1.0, 5, 0, True)],
            3: [(1.0, 5, 0, True)],
        },
        6: {
            0: [
                (0.3333, 2, 0.0, False),
                (0.3333, 5, 0.0, True),
                (0.3333, 10, 0.0, False),
            ],
            1: [
                (0.3333, 5, 0.0, True),
                (0.3333, 10, 0.0, False),
                (0.3333, 7, 0.0, True),
            ],
            2: [
                (0.3333, 10, 0.0, False),
                (0.3333, 7, 0.0, True),
                (0.3333, 2, 0.0, False),
            ],
            3: [
                (0.3333, 7, 0.0, True),
                (0.3333, 2, 0.0, False),
                (0.3333, 5, 0.0, True),
            ],
        },
        # Absorbing state: every action stays here with zero reward.
        7: {
            0: [(1.0, 7, 0, True)],
            1: [(1.0, 7, 0, True)],
            2: [(1.0, 7, 0, True)],
            3: [(1.0, 7, 0, True)],
        },
        8: {
            0: [
                (0.3333, 4, 0.0, False),
                (0.3333, 8, 0.0, False),
                (0.3333, 12, 0.0, True),
            ],
            1: [
                (0.3333, 8, 0.0, False),
                (0.3333, 12, 0.0, True),
                (0.3333, 9, 0.0, False),
            ],
            2: [
                (0.3333, 12, 0.0, True),
                (0.3333, 9, 0.0, False),
                (0.3333, 4, 0.0, False),
            ],
            3: [
                (0.3333, 9, 0.0, False),
                (0.3333, 4, 0.0, False),
                (0.3333, 8, 0.0, False),
            ],
        },
        9: {
            0: [
                (0.3333, 5, 0.0, True),
                (0.3333, 8, 0.0, False),
                (0.3333, 13, 0.0, False),
            ],
            1: [
                (0.3333, 8, 0.0, False),
                (0.3333, 13, 0.0, False),
                (0.3333, 10, 0.0, False),
            ],
            2: [
                (0.3333, 13, 0.0, False),
                (0.3333, 10, 0.0, False),
                (0.3333, 5, 0.0, True),
            ],
            3: [
                (0.3333, 10, 0.0, False),
                (0.3333, 5, 0.0, True),
                (0.3333, 8, 0.0, False),
            ],
        },
        10: {
            0: [
                (0.3333, 6, 0.0, False),
                (0.3333, 9, 0.0, False),
                (0.3333, 14, 0.0, False),
            ],
            1: [
                (0.3333, 9, 0.0, False),
                (0.3333, 14, 0.0, False),
                (0.3333, 11, 0.0, True),
            ],
            2: [
                (0.3333, 14, 0.0, False),
                (0.3333, 11, 0.0, True),
                (0.3333, 6, 0.0, False),
            ],
            3: [
                (0.3333, 11, 0.0, True),
                (0.3333, 6, 0.0, False),
                (0.3333, 9, 0.0, False),
            ],
        },
        # Absorbing state: every action stays here with zero reward.
        11: {
            0: [(1.0, 11, 0, True)],
            1: [(1.0, 11, 0, True)],
            2: [(1.0, 11, 0, True)],
            3: [(1.0, 11, 0, True)],
        },
        # Absorbing state: every action stays here with zero reward.
        12: {
            0: [(1.0, 12, 0, True)],
            1: [(1.0, 12, 0, True)],
            2: [(1.0, 12, 0, True)],
            3: [(1.0, 12, 0, True)],
        },
        13: {
            0: [
                (0.3333, 9, 0.0, False),
                (0.3333, 12, 0.0, True),
                (0.3333, 13, 0.0, False),
            ],
            1: [
                (0.3333, 12, 0.0, True),
                (0.3333, 13, 0.0, False),
                (0.3333, 14, 0.0, False),
            ],
            2: [
                (0.3333, 13, 0.0, False),
                (0.3333, 14, 0.0, False),
                (0.3333, 9, 0.0, False),
            ],
            3: [
                (0.3333, 14, 0.0, False),
                (0.3333, 9, 0.0, False),
                (0.3333, 12, 0.0, True),
            ],
        },
        # Only state whose transitions carry a non-zero reward (into state 15).
        14: {
            0: [
                (0.3333, 10, 0.0, False),
                (0.3333, 13, 0.0, False),
                (0.3333, 14, 0.0, False),
            ],
            1: [
                (0.3333, 13, 0.0, False),
                (0.3333, 14, 0.0, False),
                (0.3333, 15, 1.0, True),
            ],
            2: [
                (0.3333, 14, 0.0, False),
                (0.3333, 15, 1.0, True),
                (0.3333, 10, 0.0, False),
            ],
            3: [
                (0.3333, 15, 1.0, True),
                (0.3333, 10, 0.0, False),
                (0.3333, 13, 0.0, False),
            ],
        },
        # Absorbing state: every action stays here with zero reward.
        15: {
            0: [(1.0, 15, 0, True)],
            1: [(1.0, 15, 0, True)],
            2: [(1.0, 15, 0, True)],
            3: [(1.0, 15, 0, True)],
        },
    }
    
    # Symbolic names for the four action indices used as keys in P.
    LEFT, DOWN, RIGHT, UP = 0, 1, 2, 3
    
    # Starting policy: a fixed state -> action lookup (raises KeyError for
    # states outside 0..15). The dict is rebuilt on every call.
    pi = lambda s: {
        0:UP, 1:UP, 2:UP, 3:UP,
        4:LEFT, 5:UP, 6:UP, 7:UP,
        8:UP, 9:DOWN, 10:LEFT, 11:LEFT,
        12:LEFT, 13:RIGHT, 14:RIGHT, 15:LEFT
    }[s]

    return P, pi

In [6]:
def print_pi(pi):
    """Print the policy for states 0..15 as four rows of four arrows.

    Args:
        pi: callable mapping a state index to an action index in {0, 1, 2, 3}.
    """
    arrows = {0:"←", 1:"↓", 2:"→", 3:"↑"}
    for row_start in range(0, 16, 4):
        # Query the policy for the whole row first, then translate to glyphs.
        row_actions = [pi(state) for state in range(row_start, row_start + 4)]
        print(*[arrows[action] for action in row_actions])

def print_V(V):
    """Print the first 16 entries of V as four rows of four values.

    Each value is rounded to three decimal places before printing.

    Args:
        V: indexable sequence of at least 16 numeric state values.
    """
    for row_start in range(0, 16, 4):
        rounded_row = [round(V[state], 3) for state in range(row_start, row_start + 4)]
        print(*rounded_row)

def print_pi_V(pi, V, iteration):
    """Print a progress snapshot: blank line, iteration count, values, policy.

    Args:
        pi: policy callable passed through to ``print_pi``.
        V: state values passed through to ``print_V``.
        iteration: counter printed on its own line as the snapshot header.
    """
    # An empty first argument with a newline separator yields a blank line
    # followed by the iteration number on the next line.
    print("", iteration, sep="\n")
    print_V(V)
    print_pi(pi)

In [7]:
def policy_evaluation(pi, P, gamma=1.0, theta=1e-5):
    """Iteratively estimate the state values of a fixed policy.

    Every sweep rebuilds the value vector from the previous sweep's values,
    and sweeping stops once the largest per-state change is below ``theta``.

    Args:
        pi: callable mapping a state index to an action key of ``P[state]``.
        P: transition table, ``P[s][a]`` is a list of
            ``(prob, next_state, reward, done)`` tuples; states are indexed
            ``0..len(P)-1``.
        gamma: discount factor applied to the next state's value.
        theta: convergence threshold on the max absolute value change.

    Returns:
        ``(V, iteration)``: the value array from the final sweep and the
        number of sweeps that did *not* meet the threshold.
    """
    n_states = len(P)
    old_values = np.zeros(n_states)
    sweeps = 0
    while True:
        new_values = np.zeros(n_states)
        for state in range(n_states):
            outcomes = P[state][pi(state)]
            for p, s_next, r, terminal in outcomes:
                # Multiplying by (not terminal) zeroes the bootstrap term
                # for transitions flagged as terminal.
                new_values[state] += p * (r + gamma * old_values[s_next] * (not terminal))
        largest_change = np.abs(old_values - new_values).max()
        if largest_change < theta:
            return new_values, sweeps
        old_values = new_values.copy()
        sweeps += 1
    
def policy_improvement(V, P, gamma=1.0):
    """Return the policy that is greedy with respect to ``V``.

    Builds the action-value table ``Q[s][a]`` by a one-step lookahead through
    ``P`` and selects, for each state, the first action with the highest value.

    Args:
        V: indexable state values, one entry per state in ``P``.
        P: transition table, ``P[s][a]`` is a list of
            ``(prob, next_state, reward, done)`` tuples; states are indexed
            ``0..len(P)-1`` and the action count is taken from ``P[0]``.
        gamma: discount factor applied to the next state's value.

    Returns:
        A callable ``new_pi(s)`` giving the greedy action for state ``s``.
        It raises ``KeyError`` for a state index that is not in ``P``.
    """
    Q = np.zeros((len(P), len(P[0])), dtype=np.float64)
    for s in range(len(P)):
        for a in range(len(P[s])):
            for prob, next_state, reward, done in P[s][a]:
                # (not done) drops the bootstrap term on terminal transitions.
                Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))

    # Compute the greedy mapping once. The previous version re-ran argmax over
    # the full table and rebuilt this dict on every single policy lookup.
    greedy_action = dict(enumerate(np.argmax(Q, axis=1)))
    new_pi = lambda s: greedy_action[s]
    return new_pi

In [8]:
def policy_iteration(P, pi, gamma=1.0, theta=1e-10):
    """Alternate policy evaluation and greedy improvement until values settle.

    Each round evaluates the current policy, prints a snapshot, and replaces
    the policy with the greedy one. The loop ends when an evaluation yields
    values within ``theta`` (max absolute difference) of the previous round's.

    Args:
        P: transition table, ``P[s][a]`` is a list of
            ``(prob, next_state, reward, done)`` tuples.
        pi: initial policy, a callable mapping state index to action index.
        gamma: discount factor forwarded to evaluation and improvement.
        theta: threshold on the change in values between consecutive rounds.

    Returns:
        ``(V, pi, iteration)``: final state values, final policy, and the
        total number of evaluation sweeps summed over all rounds.
    """
    # No previous estimate exists on the first round. Comparing against an
    # all-zero vector (as before) would stop immediately, without a single
    # improvement step, whenever the initial policy's values are all zero.
    prev_V = None
    iteration = 0
    while True:
        V, i = policy_evaluation(pi, P, gamma=gamma)
        iteration += i

        if prev_V is not None and np.max(np.abs(V - prev_V)) < theta:
            break
        prev_V = V

        print_pi_V(pi, V, iteration)

        pi = policy_improvement(V, P, gamma=gamma)
    return V, pi, iteration

def value_iteration(P, gamma=1.0, theta=1e-5):
    """Run value iteration and return values, greedy policy and sweep count.

    Each sweep computes ``Q[s][a]`` by a one-step lookahead from the current
    ``V`` and then sets ``V`` to the row-wise maximum of ``Q``. Iteration stops
    when that maximum differs from ``V`` by less than ``theta`` everywhere.
    A snapshot is printed every 100 sweeps.

    Args:
        P: transition table, ``P[s][a]`` is a list of
            ``(prob, next_state, reward, done)`` tuples; states are indexed
            ``0..len(P)-1`` and the action count is taken from ``P[0]``.
        gamma: discount factor applied to the next state's value.
        theta: convergence threshold on the max absolute value change.

    Returns:
        ``(V, pi, iteration)``: the value array the final sweep started from,
        the policy greedy with respect to the final ``Q``, and the number of
        sweeps performed.
    """

    def greedy_policy(q_table):
        # Freeze the argmax once so each lookup is a plain dict access. The
        # previous lambda recomputed argmax and rebuilt the dict per call.
        best_action = dict(enumerate(np.argmax(q_table, axis=1)))
        return lambda s: best_action[s]

    V = np.zeros(len(P), dtype=np.float64)
    iteration = 0
    while True:
        Q = np.zeros((len(P), len(P[0])), dtype=np.float64)

        for s in range(len(P)):
            for a in range(len(P[s])):
                for prob, next_state, reward, done in P[s][a]:
                    # (not done) drops the bootstrap term on terminal transitions.
                    Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))
        iteration += 1

        # Row-wise maximum is needed for both the check and the update.
        new_V = np.max(Q, axis=1)

        if iteration % 100 == 0:
            # V here is still the estimate this sweep started from.
            print_pi_V(greedy_policy(Q), V, iteration)
        if np.max(np.abs(V - new_V)) < theta:
            break
        V = new_V
    return V, greedy_policy(Q), iteration

In [9]:
# Build the example MDP and its starting policy, then solve it with policy
# iteration. P and pi stay in the notebook namespace for the cells below.
print("POLICY ITERATION")
P, pi = initialize_example()
V_PI, pi_PI, iterations_PI = policy_iteration(P, pi)

POLICY ITERATION

54
0.0 0.0 0.0 0.0
0.125 0.0 0.0 0.0
0.25 0.375 0.35 0.0
0.0 0.525 0.675 0.0
↑ ↑ ↑ ↑
← ↑ ↑ ↑
↑ ↓ ← ←
← → → ←

307
0.751 0.375 0.359 0.18
0.751 0.0 0.343 0.0
0.751 0.752 0.671 0.0
0.0 0.834 0.917 0.0
← ← ← ←
← ← ← ←
↑ ↓ ← ←
← → ↓ ←

571
0.777 0.655 0.534 0.534
0.777 0.0 0.413 0.0
0.778 0.778 0.706 0.0
0.0 0.852 0.926 0.0
← ↑ ← ↑
← ← ← ←
↑ ↓ ← ←
← → ↓ ←

908
0.819 0.818 0.818 0.818
0.82 0.0 0.527 0.0
0.82 0.821 0.762 0.0
0.0 0.88 0.94 0.0
← ↑ ↑ ↑
← ← ← ←
↑ ↓ ← ←
← → ↓ ←


In [10]:
# Solve the same MDP with value iteration; relies on P from the previous cell.
print("VALUE ITERATION")
V_VI, pi_VI, iterations_VI = value_iteration(P)

VALUE ITERATION

100
0.739 0.712 0.693 0.683
0.745 0.0 0.47 0.0
0.757 0.774 0.721 0.0
0.0 0.847 0.923 0.0
← ↑ ↑ ↑
← ← ← ←
↑ ↓ ← ←
← → ↓ ←

200
0.813 0.81 0.808 0.807
0.813 0.0 0.522 0.0
0.815 0.817 0.759 0.0
0.0 0.878 0.939 0.0
← ↑ ↑ ↑
← ← ← ←
↑ ↓ ← ←
← → ↓ ←

300
0.819 0.818 0.817 0.817
0.819 0.0 0.526 0.0
0.82 0.821 0.762 0.0
0.0 0.88 0.94 0.0
← ↑ ↑ ↑
← ← ← ←
↑ ↓ ← ←
← → ↓ ←


In [11]:
# Compare the two solvers: sweep counts, then per-state agreement of the
# resulting policies and values (both rounded to 4 decimals).
print("PI iterations:", iterations_PI, "\nVI iterations:", iterations_VI)

same_policy = []
for state in range(16):
    same_policy.append(round(pi_VI(state), 4) == round(pi_PI(state), 4))
print("VI policy == PI policy:", same_policy)

same_values = []
for state in range(16):
    same_values.append(round(V_VI[state], 4) == round(V_PI[state], 4))
print("VI values == PI values:", same_values)

# Show the policy-iteration result as the reference solution.
print_pi(pi_PI)
print_V(V_PI)

PI iterations: 1245 
VI iterations: 336
VI policy == PI policy: [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]
VI values == PI values: [True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]
← ↑ ↑ ↑
← ← ← ←
↑ ↓ ← ←
← → ↓ ←
0.819 0.818 0.818 0.818
0.82 0.0 0.527 0.0
0.82 0.821 0.762 0.0
0.0 0.88 0.94 0.0


# Total Regret

$\mathcal{T}=\sum_{e=1}^E \mathbb{E}[v_\star-q_\star(A_e)]$

# Softmax Exploration

$\pi(a)=\frac{\exp\left(\frac{Q(a)}{\tau}\right)}{\sum_{b=0}^B \exp\left(\frac{Q(b)}{\tau}\right)}$, where $\tau$ is the temperature.