In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

# Intro

### Markov Property
$P(S_{t+1}\vert S_t, A_t) = P(S_{t+1}\vert S_t, A_t, S_{t-1}, A_{t-1},\ldots)$

### Transition Function
$p(s'\vert s, a) = P(S_t=s'\vert S_{t-1}=s, A_{t-1}=a)$

### Reward Function
$r(s,a)=\mathbb{E} [R_t\vert S_{t-1}=s, A_{t-1}=a]$<br>
$r(s,a,s')=\mathbb{E} [R_t\vert S_{t-1}=s, A_{t-1}=a, S_t=s']$

### Return
$G_t=R_{t+1}+\gamma R_{t+2}+\gamma^2 R_{t+3}+\ldots+\gamma^{T-t-1}R_T$

### MDP
$\mathcal{MDP(S,A,T,R,S_\theta,\gamma,H)}$<br>
$\mathcal{POMDP(S,A,T,R,S_\theta,\gamma,H,O,E)}$

# Bellman Equations

### State-Value Function
$v_\pi(s) = \sum_a \pi(a\vert s) \sum_{s',r} p(s',r\vert s, a) \left[r + \gamma v_\pi(s') \right],\forall s \in S$

### Action-Value Function
$q_\pi(s,a) = \sum_{s',r} p(s',r|s, a) \left[r + \gamma v_\pi(s') \right],\forall s \in S,\forall a \in A$

### Action Advantage
$a_\pi(s,a)=q_\pi(s,a)-v_\pi(s)$

### Bellman optimality equations

$v_\star(s)=\displaystyle\max_\pi v_\pi(s),\forall s \in S$<br>
$q_\star(s,a)=\sum_{s',r}p(s',r\vert s,a)[r+\gamma \displaystyle\max_{a'}q_\star(s',a')]$

# Policy Iteration

### Policy Evaluation
$v_{k+1}(s)=\sum_a \pi(a\vert s) \sum_{s',r}p(s',r\vert s,a)[r+\gamma v_k(s')]$

### Policy Improvement
$\pi'(s)=\text{argmax}_a \sum_{s',r}p(s',r\vert s,a)[r+\gamma v_k(s')]$

# Value Iteration
$v_{k+1}(s)=\displaystyle\max_a\sum_{s',r}p(s',r\vert s,a)[r+\gamma v_k(s')]$

# Total Regret

$\mathcal{T}=\sum_{e=1}^E \mathbb{E}[v_\star-q_\star(A_e)]$

# Softmax Exploration

$\pi(a)=\frac{\exp\left(\frac{Q(a)}{\tau}\right)}{\sum_{b=0}^B \exp\left(\frac{Q(b)}{\tau}\right)}$, where $\tau$ is the temperature.

In [89]:
def _build_transition_table():
    """Build the transition table of the 4x4 stochastic grid world.

    Layout of the result:
        {state: {action: [(probability, next_state, reward, next_state_is_terminal), ...]}}

    States are numbered row-major (state = 4 * row + column). Actions are
    0 = left, 1 = down, 2 = right, 3 = up. In a non-terminal state an action
    yields three equally weighted outcomes, listed in this order: the move
    one step "before" the action in the cycle left -> down -> right -> up,
    the action itself, and the move one step "after" it. Moves that would
    leave the grid keep the agent in place. Terminal states loop onto
    themselves with probability 1.0 and reward 0.
    """
    side = 4
    n_actions = 4
    terminal_states = {5, 7, 11, 12, 15}
    rewarding_state = 15  # entering this state is the only transition that pays 1.0
    # NOTE(review): the three outcomes sum to 0.9999, not 1.0 -- kept as in the
    # original table; confirm whether an exact 1/3 was intended.
    outcome_prob = 0.3333

    def move(state, direction):
        # Apply one deterministic move, clamping at the grid border.
        row, col = divmod(state, side)
        if direction == 0:      # left
            col = max(col - 1, 0)
        elif direction == 1:    # down
            row = min(row + 1, side - 1)
        elif direction == 2:    # right
            col = min(col + 1, side - 1)
        else:                   # up
            row = max(row - 1, 0)
        return row * side + col

    table = {}
    for state in range(side * side):
        if state in terminal_states:
            # Absorbing state: every action stays put (integer 0 reward).
            table[state] = {
                action: [(1.0, state, 0, True)] for action in range(n_actions)
            }
            continue

        per_action = {}
        for action in range(n_actions):
            outcomes = []
            for direction in ((action - 1) % n_actions, action, (action + 1) % n_actions):
                landing = move(state, direction)
                reward = 1.0 if landing == rewarding_state else 0.0
                outcomes.append((outcome_prob, landing, reward, landing in terminal_states))
            per_action[action] = outcomes
        table[state] = per_action
    return table


P = _build_transition_table()

# Integer codes for the four actions used as keys of the transition table.
LEFT, DOWN, RIGHT, UP = range(4)

# Hand-written starting policy, one action per state, listed state 0 .. 15
# (four states per line).
_INITIAL_POLICY = dict(enumerate((
    UP, UP, UP, UP,
    LEFT, UP, UP, UP,
    UP, DOWN, LEFT, LEFT,
    LEFT, RIGHT, RIGHT, LEFT,
)))


def pi(s):
    """Return the action the initial policy takes in state ``s``.

    Raises KeyError for a state outside 0..15.
    """
    return _INITIAL_POLICY[s]

In [None]:
def policy_evaluation(pi, P, gamma=1.0, theta=1e-10):
    """Iteratively estimate the state-value function of a deterministic policy.

    Repeats the backup
        V[s] = sum over outcomes of prob * (reward + gamma * prev_V[next_state])
    for the action ``pi(s)`` in every state, until the largest change between
    two consecutive sweeps falls below ``theta``.

    Parameters
    ----------
    pi : callable
        Maps a state index to the action taken in that state.
    P : dict
        ``P[s][a]`` is a list of ``(prob, next_state, reward, done)`` tuples;
        states are expected to be the integers ``0 .. len(P) - 1``.
    gamma : float
        Discount factor applied to the value of the next state.
    theta : float
        Convergence threshold on the max absolute change per sweep.

    Returns
    -------
    numpy.ndarray
        Array of length ``len(P)`` with the estimated value of each state.
    """
    n_states = len(P)
    prev_V = np.zeros(n_states)
    while True:
        V = np.zeros(n_states)
        for s in range(n_states):
            for prob, next_state, reward, done in P[s][pi(s)]:
                # `(not done)` zeroes the bootstrap term when the transition
                # ends the episode, so terminal states contribute no future value.
                V[s] += prob * (reward + gamma * prev_V[next_state] * (not done))
        if np.max(np.abs(prev_V - V)) < theta:
            break
        prev_V = V.copy()
    return V
# Evaluate the policy `pi` on the transition table `P` with the default gamma
# and theta; as the cell's last expression, the returned value array is displayed.
policy_evaluation(pi, P)

In [69]:
def policy_improvement(V, P, gamma=1.0):
    """Return the policy that acts greedily with respect to ``V``.

    Computes the action values
        Q[s][a] = sum over outcomes of prob * (reward + gamma * V[next_state])
    and picks, for every state, the action with the largest Q (the lowest
    action index wins ties, as ``np.argmax`` does).

    Parameters
    ----------
    V : array-like
        Value estimate for each state, indexed by state.
    P : dict
        ``P[s][a]`` is a list of ``(prob, next_state, reward, done)`` tuples;
        states are expected to be the integers ``0 .. len(P) - 1`` and the
        number of actions is taken from ``P[0]``.
    gamma : float
        Discount factor applied to the value of the next state.

    Returns
    -------
    callable
        ``new_pi(s)`` giving the greedy action for state ``s``; raises
        KeyError for a state not in ``0 .. len(P) - 1``.
    """
    Q = np.zeros((len(P), len(P[0])), dtype=np.float64)
    for s in range(len(P)):
        for a in range(len(P[s])):
            for prob, next_state, reward, done in P[s][a]:
                # `(not done)` drops the bootstrap term on terminal transitions.
                Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))
    # Build the state -> action table once; the policy is then a plain lookup
    # instead of re-running argmax over the whole Q table on every call.
    greedy_actions = dict(enumerate(np.argmax(Q, axis=1)))
    new_pi = lambda s: greedy_actions[s]
    return new_pi

In [72]:
def print_policy(pi):
    """Print a policy over states 0..15 as a 4x4 grid of arrows.

    ``pi`` is called with each state index and must return an action code
    (0 = left, 1 = down, 2 = right, 3 = up). Rows are separated by newlines,
    symbols by single spaces, and a blank line follows the grid.
    """
    arrows = {0:"←", 1:"↓", 2:"→", 3:"↑"}
    for row_start in range(0, 16, 4):
        row_actions = [pi(state) for state in range(row_start, row_start + 4)]
        print(" ".join(arrows[action] for action in row_actions))
    print()

# Two rounds of policy iteration: show the current policy, evaluate it, then
# replace it with the greedy policy for the resulting values.
# NOTE: this rebinds the notebook-level name `pi`, so re-running the cell
# continues from the last improved policy rather than from the initial one.
for i in tqdm(range(2)):
    print_policy(pi)
    V = policy_evaluation(pi, P)
    pi = policy_improvement(V, P)

In [85]:
def value_iteration(P, gamma=1.0, theta=1e-10, max_iterations=19, verbose=True):
    """Run value iteration and return the value estimate and its greedy policy.

    Each sweep computes
        Q[s][a] = sum over outcomes of prob * (reward + gamma * V[next_state])
    and sets ``V[s] = max_a Q[s][a]``. The loop stops when the largest change
    in ``V`` is below ``theta`` or after ``max_iterations`` sweeps.

    Parameters
    ----------
    P : dict
        ``P[s][a]`` is a list of ``(prob, next_state, reward, done)`` tuples;
        states are expected to be the integers ``0 .. len(P) - 1`` and the
        number of actions is taken from ``P[0]``.
    gamma : float
        Discount factor applied to the value of the next state.
    theta : float
        Convergence threshold on the max absolute change per sweep.
    max_iterations : int or None
        Upper bound on the number of sweeps. The default of 19 reproduces the
        cap that used to be hard-coded, so the result may not have converged;
        pass ``None`` to iterate until ``theta`` is met.
    verbose : bool
        When true (the default, matching the previous behaviour), print the
        greedy policy, the sweep number and the current values on every
        sweep. The trace is laid out as a 4x4 grid and relies on the
        notebook's ``print_policy``, so it only works for 16-state problems;
        pass ``False`` for any other problem size.

    Returns
    -------
    (numpy.ndarray, callable)
        The value estimate and a function mapping a state to its greedy
        action (KeyError for a state outside ``0 .. len(P) - 1``).

    Raises
    ------
    ValueError
        If ``max_iterations`` is given and smaller than 1.
    """
    if max_iterations is not None and max_iterations < 1:
        raise ValueError("max_iterations must be at least 1 or None")

    n_states, n_actions = len(P), len(P[0])
    V = np.zeros(n_states, dtype=np.float64)
    iteration = 1
    while max_iterations is None or iteration <= max_iterations:
        Q = np.zeros((n_states, n_actions), dtype=np.float64)
        for s in range(n_states):
            for a in range(len(P[s])):
                for prob, next_state, reward, done in P[s][a]:
                    # `(not done)` drops the bootstrap term on terminal transitions.
                    Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))

        if verbose:
            # Trace: greedy policy for this sweep's Q, the sweep number, then
            # the values the sweep started from (V is updated further below).
            sweep_actions = dict(enumerate(np.argmax(Q, axis=1)))
            print_policy(pi=lambda s: sweep_actions[s])
            print(iteration)
            for i in range(0, 16, 4):
                print(round(V[i], 3), round(V[i+1], 3), round(V[i+2], 3), round(V[i+3], 3))
        iteration += 1

        best_values = np.max(Q, axis=1)
        if np.max(np.abs(V - best_values)) < theta:
            break
        V = best_values

    # Build the state -> action table once so the returned policy is a plain
    # lookup rather than an argmax over the whole Q table on every call.
    greedy_actions = dict(enumerate(np.argmax(Q, axis=1)))
    pi = lambda s: greedy_actions[s]
    return V, pi

In [87]:
# Show the policy currently bound to `pi`, then run value iteration on `P`
# (which prints its own per-sweep trace) and show the greedy policy it returns.
print_policy(pi)
print_policy(value_iteration(P)[1])

↑ ↑ ↑ ↑
← ↑ ↑ ↑
↑ ↓ ← ←
← → → ←

← ← ← ←
← ← ← ←
← ← ← ←
← ← ↓ ←

1
0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0
← ← ← ←
← ← ← ←
← ← ← ←
← ↓ ↓ ←

2
0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0
0.0 0.0 0.333 0.0
← ← ← ←
← ← ← ←
← ↓ ← ←
← ↓ ↓ ←

3
0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0
0.0 0.0 0.111 0.0
0.0 0.111 0.444 0.0
← ← ← ←
← ← ← ←
↓ ↓ ← ←
← → ↓ ←

4
0.0 0.0 0.0 0.0
0.0 0.0 0.037 0.0
0.0 0.074 0.148 0.0
0.0 0.185 0.518 0.0
← ↓ ← ←
← ← ← ←
↓ ↓ ← ←
← → ↓ ←

5
0.0 0.0 0.012 0.0
0.0 0.0 0.049 0.0
0.025 0.111 0.21 0.0
0.0 0.259 0.568 0.0
↓ → ← ↑
← ← ← ←
↑ ↓ ← ←
← → ↓ ←

6
0.0 0.004 0.021 0.004
0.008 0.0 0.074 0.0
0.045 0.165 0.243 0.0
0.0 0.313 0.609 0.0
↓ ↑ → ↑
← ← ← ←
↑ ↓ ← ←
← → ↓ ←

7
0.004 0.008 0.033 0.01
0.018 0.0 0.088 0.0
0.073 0.2 0.282 0.0
0.0 0.362 0.64 0.0
↓ ↑ → ↑
← ← ← ←
↑ ↓ ← ←
← → ↓ ←

8
0.01 0.015 0.043 0.017
0.032 0.0 0.105 0.0
0.097 0.239 0.309 0.0
0.0 0.401 0.667 0.0
↓ ↑ → ↑
← ← ← ←
↑ ↓ ← ←
← → ↓ ←

9
0.019 0.023 0.055 0.026
0.046 0.0 0.118 0.0
0.122 