In [281]:
import numpy as np
from tabulate import tabulate

In [282]:
np.random.seed(47)

In [283]:
transition_probabilities = np.array([
    [0,0.8,0,0,0.2,0,0,0,0,0,0,0,0,0,0,0],
    [0.1,0,0.8,0,0,0.1,0,0,0,0,0,0,0,0,0,0],
    [0,0.1,0,0.8,0,0,0.1,0,0,0,0,0,0,0,0,0],
    [0,0,0.2,0,0,0,0,0.8,0,0,0,0,0,0,0,0],
    [0.1,0,0,0,0,0.8,0,0,0.1,0,0,0,0,0,0,0],
    [0,0.1,0,0,0,0,0.8,0,0,0.1,0,0,0,0,0,0],
    [0,0,0.1,0,0,0,0,0.8,0,0,0.1,0,0,0,0,0],
    [0,0,0,0.1,0,0,0.1,0,0,0,0,0.8,0,0,0,0],
    [0,0,0,0,0.1,0,0,0,0,0.8,0,0,0.1,0,0,0],
    [0,0,0,0,0,0.1,0,0,0,0,0.8,0,0,0.1,0,0],
    [0,0,0,0,0,0,0.1,0,0,0,0,0.8,0,0,0.1,0],
    [0,0,0,0,0,0,0,0.1,0,0,0.1,0,0,0,0,0.8],
    [0,0,0,0,0,0,0,0,0.2,0,0,0,0,0.8,0,0],
    [0,0,0,0,0,0,0,0,0,0.1,0,0,0.1,0,0.8,0],
    [0,0,0,0,0,0,0,0,0,0,0.1,0,0,0.1,0,0.8],
    [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1]])

print(transition_probabilities.shape)
print(tabulate(transition_probabilities.tolist(), tablefmt='grid'))

(16, 16)
+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 0   | 0.8 | 0   | 0   | 0.2 | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 0.1 | 0   | 0.8 | 0   | 0   | 0.1 | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 0   | 0.1 | 0   | 0.8 | 0   | 0   | 0.1 | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 0   | 0   | 0.2 | 0   | 0   | 0   | 0   | 0.8 | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 0.1 | 0   | 0   | 0   | 0   | 0.8 | 0   | 0   | 0.1 | 0   | 0   | 0   | 0   | 0   | 0   | 0   |
+-----+----

In [284]:
actionsMap = { action: i for i, action in enumerate(['up', 'down', 'left', 'right']) }
reverseActionsMap = { v: k for k, v in actionsMap.items() }
actionsMap, reverseActionsMap

({'up': 0, 'down': 1, 'left': 2, 'right': 3},
 {0: 'up', 1: 'down', 2: 'left', 3: 'right'})

In [285]:
def getValidActions(pos):
    row, col = pos

    actions = []
    if row > 0:
        actions.append(actionsMap['up'])
    if row < 3:
        actions.append(actionsMap['down'])
    if col > 0:
        actions.append(actionsMap['left'])
    if col < 3:
        actions.append(actionsMap['right'])
    return actions

def move(currentState, action):
    row, col = currentState
    if action == actionsMap['up']:
        row -= 1
    elif action == actionsMap['down']:
        row += 1
    elif action == actionsMap['left']:
        col -= 1
    elif action == actionsMap['right']:
        col += 1
    return (row, col)

In [286]:
policy = np.array([
    [np.random.choice(getValidActions((i, j))) for j in range(4)] for i in range(4)
])
policy

array([[3, 3, 1, 1],
       [0, 3, 3, 0],
       [1, 1, 3, 2],
       [0, 3, 2, 2]])

In [287]:
stateValues = np.random.uniform(1.0, 5.0, (4, 4))
stateValues

array([[3.05603442, 2.59837491, 2.41775177, 4.05044397],
       [1.20646147, 3.93370981, 4.07945725, 2.15228722],
       [3.78758545, 3.03761023, 4.22367201, 4.6929683 ],
       [2.46028107, 3.0530855 , 4.03717615, 1.5156356 ]])

In [288]:
def policyEvaluation(stateValues, policy, gamma=0.9, theta=0.01):
    delta = float('inf')
    while delta > theta:
        delta = 0
        newState = np.copy(stateValues)
        # for all states
        for i in range(4):
            for j in range(4):
                state = (i, j)

                if state == (1, 1) or state == (3, 3):
                    continue

                action = policy[i][j]
                next_state = move(state, action)


                ns_i, ns_j = next_state

                reward = -1 if next_state == (1, 1) else (10 if next_state == (3, 3) else -0.1)
                prob = transition_probabilities[i * 4 + j][ns_i * 4 + ns_j]

                new_value = prob * (reward + gamma * stateValues[ns_i][ns_j])
                delta = max(delta, abs(new_value - stateValues[i][j]))
                newState[i][j] = new_value

        stateValues[:, :] = newState[:, :]
        print("delta", delta)
    return stateValues

In [289]:
policyEvaluation(stateValues, policy)

delta 4.322837820822732
delta 3.1124432309923673
delta 1.0682153591669101
delta 0.20168632136830542
delta 0.06039050144166146
delta 0.0130692736246662
delta 0.003913304493419673


array([[-0.14632251, -0.09281683, -0.01829031, -0.09281683],
       [-0.02281683,  3.93370981, -0.09281683, -0.01829031],
       [-0.01199336, -0.01831474, -0.09231908, -0.01828607],
       [-0.02215607, -0.09244756, -0.01831474,  1.5156356 ]])

In [290]:
print(tabulate(stateValues, tablefmt='grid'))

+------------+------------+------------+------------+
| -0.146323  | -0.0928168 | -0.0182903 | -0.0928168 |
+------------+------------+------------+------------+
| -0.0228168 |  3.93371   | -0.0928168 | -0.0182903 |
+------------+------------+------------+------------+
| -0.0119934 | -0.0183147 | -0.0923191 | -0.0182861 |
+------------+------------+------------+------------+
| -0.0221561 | -0.0924476 | -0.0183147 |  1.51564   |
+------------+------------+------------+------------+


In [291]:
def policyImprovement(stateValues, policy, gamma=0.9):
    policy_stable = True
    newActions = np.copy(policy)

    for i in range(4):
        for j in range(4):
            state = (i, j)

            if state == (1, 1) or state == (3, 3):
                continue

            old_action = policy[i][j]
            best_value = float('-inf')
            best_action = old_action

            for action in getValidActions(state):
                next_state = move(state, action)
                ns_i, ns_j = next_state
                reward = -1 if next_state == (1, 1) else (10 if next_state == (3, 3) else -0.1)
                prob = transition_probabilities[i * 4 + j][ns_i * 4 + ns_j]
                print(next_state, stateValues)
                value = prob * (reward + gamma * stateValues[ns_i][ns_j])

                if value > best_value:
                    best_value = value
                    best_action = action

            newActions[i][j] = best_action
            if best_action != old_action:
                policy_stable = False

    return newActions, policy_stable


In [292]:
def policyIteration(stateValues, policy, gamma=0.9, theta=0.01):
    while True:
        stateValues = policyEvaluation(stateValues, policy)
        policy, stable = policyImprovement(stateValues, policy)
        print(policy, stable)

        if stable:
            break

    return stateValues, policy

In [293]:
stateValues, policy = policyIteration(stateValues, policy)

delta 0.0008468889308783645
(1, 0) [[-0.14682812 -0.09316903 -0.01835351 -0.09316903]
 [-0.02316903  3.93370981 -0.09316903 -0.01835351]
 [-0.01199405 -0.01832028 -0.09316597 -0.01830872]
 [-0.0221588  -0.09318661 -0.01832028  1.5156356 ]]
(0, 1) [[-0.14682812 -0.09316903 -0.01835351 -0.09316903]
 [-0.02316903  3.93370981 -0.09316903 -0.01835351]
 [-0.01199405 -0.01832028 -0.09316597 -0.01830872]
 [-0.0221588  -0.09318661 -0.01832028  1.5156356 ]]
(1, 1) [[-0.14682812 -0.09316903 -0.01835351 -0.09316903]
 [-0.02316903  3.93370981 -0.09316903 -0.01835351]
 [-0.01199405 -0.01832028 -0.09316597 -0.01830872]
 [-0.0221588  -0.09318661 -0.01832028  1.5156356 ]]
(0, 0) [[-0.14682812 -0.09316903 -0.01835351 -0.09316903]
 [-0.02316903  3.93370981 -0.09316903 -0.01835351]
 [-0.01199405 -0.01832028 -0.09316597 -0.01830872]
 [-0.0221588  -0.09318661 -0.01832028  1.5156356 ]]
(0, 2) [[-0.14682812 -0.09316903 -0.01835351 -0.09316903]
 [-0.02316903  3.93370981 -0.09316903 -0.01835351]
 [-0.01199405 -

In [294]:
printableTable = [[reverseActionsMap[a] for a in policy[i]] for i in range(4)]

printableTable[1][1] = 'x'
printableTable[3][3] = 'win'
print(tabulate(printableTable, tablefmt='grid'))

+-------+-------+-------+------+
| right | right | right | down |
+-------+-------+-------+------+
| right | x     | right | down |
+-------+-------+-------+------+
| right | right | right | down |
+-------+-------+-------+------+
| right | right | right | win  |
+-------+-------+-------+------+
