In [238]:
import numpy as np
from tabulate import tabulate

In [239]:
np.random.seed(47)

In [240]:
transition_probabilities = np.array([
    [0,0.8,0,0,0.2,0,0,0,0,0,0,0,0,0,0,0],
    [0.1,0,0.8,0,0,0.1,0,0,0,0,0,0,0,0,0,0],
    [0,0.1,0,0.8,0,0,0.1,0,0,0,0,0,0,0,0,0],
    [0,0,0.2,0,0,0,0,0.8,0,0,0,0,0,0,0,0],
    [0.1,0,0,0,0,0.8,0,0,0.1,0,0,0,0,0,0,0],
    [0,0.1,0,0,0,0,0.8,0,0,0.1,0,0,0,0,0,0],
    [0,0,0.1,0,0,0,0,0.8,0,0,0.1,0,0,0,0,0],
    [0,0,0,0.1,0,0,0.1,0,0,0,0,0.8,0,0,0,0],
    [0,0,0,0,0.1,0,0,0,0,0.8,0,0,0.1,0,0,0],
    [0,0,0,0,0,0.1,0,0,0,0,0.8,0,0,0.1,0,0],
    [0,0,0,0,0,0,0.1,0,0,0,0,0.8,0,0,0.1,0],
    [0,0,0,0,0,0,0,0.1,0,0,0.1,0,0,0,0,0.8],
    [0,0,0,0,0,0,0,0,0.2,0,0,0,0,0.8,0,0],
    [0,0,0,0,0,0,0,0,0,0.1,0,0,0.1,0,0.8,0],
    [0,0,0,0,0,0,0,0,0,0,0.1,0,0,0.1,0,0.8],
    [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1]])

print(transition_probabilities.shape)
print(tabulate(transition_probabilities.tolist(), tablefmt='grid'))

(16, 16)
+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 0   | 0.8 | 0   | 0   | 0.2 | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 0.1 | 0   | 0.8 | 0   | 0   | 0.1 | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 0   | 0.1 | 0   | 0.8 | 0   | 0   | 0.1 | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 0   | 0   | 0.2 | 0   | 0   | 0   | 0   | 0.8 | 0   | 0   | 0   | 0   | 0   | 0   | 0   | 0   |
+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 0.1 | 0   | 0   | 0   | 0   | 0.8 | 0   | 0   | 0.1 | 0   | 0   | 0   | 0   | 0   | 0   | 0   |
+-----+----

In [241]:
actionsMap = { action: i for i, action in enumerate(['up', 'down', 'left', 'right']) }
reverseActionsMap = { v: k for k, v in actionsMap.items() }
actionsMap, reverseActionsMap

({'up': 0, 'down': 1, 'left': 2, 'right': 3},
 {0: 'up', 1: 'down', 2: 'left', 3: 'right'})

In [242]:
def getValidActions(pos):
    row, col = pos

    actions = []
    if row > 0:
        actions.append(actionsMap['up'])
    if row < 3:
        actions.append(actionsMap['down'])
    if col > 0:
        actions.append(actionsMap['left'])
    if col < 3:
        actions.append(actionsMap['right'])
    return actions

def move(currentState, action):
    row, col = currentState
    if action == actionsMap['up']:
        row -= 1
    elif action == actionsMap['down']:
        row += 1
    elif action == actionsMap['left']:
        col -= 1
    elif action == actionsMap['right']:
        col += 1
    return (row, col)

In [243]:
policy = np.array([
    [np.random.choice(getValidActions((i, j))) for j in range(4)] for i in range(4)
])
policy

array([[3, 3, 1, 1],
       [0, 3, 3, 0],
       [1, 1, 3, 2],
       [0, 3, 2, 2]])

In [244]:
stateValues = np.random.uniform(1.0, 5.0, (4, 4))
stateValues

array([[3.05603442, 2.59837491, 2.41775177, 4.05044397],
       [1.20646147, 3.93370981, 4.07945725, 2.15228722],
       [3.78758545, 3.03761023, 4.22367201, 4.6929683 ],
       [2.46028107, 3.0530855 , 4.03717615, 1.5156356 ]])

In [None]:
def policyEvaluation(stateValues, policy, gamma=0.9, theta=0.01):
    delta = float('inf')
    while delta > theta:
        delta = 0
        newState = np.copy(stateValues)
        # for all states
        for i in range(4):
            for j in range(4):
                state = (i, j)

                if state == (1, 1) or state == (3, 3):
                    continue

                action = policy[i][j]
                next_state = move(state, action)


                ns_i, ns_j = next_state

                reward = -1 if next_state == (1, 1) else (10 if next_state == (3, 3) else -0.1)
                prob = transition_probabilities[i * 4 + j][ns_i * 4 + ns_j]

                new_value = prob * (reward + gamma * stateValues[ns_i][ns_j])
                delta = max(delta, abs(new_value - stateValues[i][j]))
                newState[i][j] = new_value

        stateValues[:, :] = newState[:, :]
        print("delta", delta)
    return stateValues

In [None]:
policyEvaluation(stateValues, policy)

delta 5.054081476396987
delta 3.638938663005831
delta 2.6200358373641994
delta 1.6416526950774561
delta 1.181989940455768
delta 0.8510327571281532
delta 0


array([[1.52865254, 2.23423964, 3.21422173, 4.57530796],
       [2.03227106, 3.93370981, 4.57530796, 6.46570549],
       [3.21422173, 4.57530796, 6.46570549, 9.09125763],
       [4.57530796, 6.46570549, 9.09125763, 1.5156356 ]])

In [247]:
print(tabulate(stateValues, tablefmt='grid'))

+---------+---------+---------+---------+
| 1.52865 | 2.23424 | 3.21422 | 4.57531 |
+---------+---------+---------+---------+
| 2.03227 | 3.93371 | 4.57531 | 6.46571 |
+---------+---------+---------+---------+
| 3.21422 | 4.57531 | 6.46571 | 9.09126 |
+---------+---------+---------+---------+
| 4.57531 | 6.46571 | 9.09126 | 1.51564 |
+---------+---------+---------+---------+


In [None]:
def policyImprovement(stateValues, policy, gamma=0.9):
    policy_stable = True
    newActions = np.copy(policy)

    for i in range(4):
        for j in range(4):
            state = (i, j)

            if state == (1, 1) or state == (3, 3):
                continue

            old_action = policy[i][j]
            best_value = float('-inf')
            best_action = old_action

            for action in getValidActions(state):
                next_state = move(state, action)
                ns_i, ns_j = next_state
                reward = -1 if next_state == (1, 1) else (10 if next_state == (3, 3) else -0.1)
                prob = transition_probabilities[i * 4 + j][ns_i * 4 + ns_j]
                print(next_state, stateValues)
                value = prob * (reward + gamma * stateValues[ns_i][ns_j])

                if value > best_value:
                    best_value = value
                    best_action = action

            newActions[i][j] = best_action
            if best_action != old_action:
                policy_stable = False

    return newActions, policy_stable


In [None]:
def policyIteration(stateValues, policy, gamma=0.9, theta=0.01):
    while True:
        stateValues = policyEvaluation(stateValues, policy)
        policy, stable = policyImprovement(stateValues, policy)
        print(policy, stable)

        if stable:
            break

    return stateValues, policy

In [None]:
stateValues, policy = policyIteration(stateValues, policy)

delta 0
(1, 0) [[1.52865254 2.23423964 3.21422173 4.57530796]
 [2.03227106 3.93370981 4.57530796 6.46570549]
 [3.21422173 4.57530796 6.46570549 9.09125763]
 [4.57530796 6.46570549 9.09125763 1.5156356 ]]
(0, 1) [[1.52865254 2.23423964 3.21422173 4.57530796]
 [2.03227106 3.93370981 4.57530796 6.46570549]
 [3.21422173 4.57530796 6.46570549 9.09125763]
 [4.57530796 6.46570549 9.09125763 1.5156356 ]]
(1, 1) [[1.52865254 2.23423964 3.21422173 4.57530796]
 [2.03227106 3.93370981 4.57530796 6.46570549]
 [3.21422173 4.57530796 6.46570549 9.09125763]
 [4.57530796 6.46570549 9.09125763 1.5156356 ]]
(0, 0) [[1.52865254 2.23423964 3.21422173 4.57530796]
 [2.03227106 3.93370981 4.57530796 6.46570549]
 [3.21422173 4.57530796 6.46570549 9.09125763]
 [4.57530796 6.46570549 9.09125763 1.5156356 ]]
(0, 2) [[1.52865254 2.23423964 3.21422173 4.57530796]
 [2.03227106 3.93370981 4.57530796 6.46570549]
 [3.21422173 4.57530796 6.46570549 9.09125763]
 [4.57530796 6.46570549 9.09125763 1.5156356 ]]
(1, 2) [[1.5

In [None]:
printableTable = []
for i in range(4):
    printableTable.append([reverseActionsMap[a] for a in policy[i]])

printableTable[1][1] = 'x'
printableTable[3][3] = 'win'
print(tabulate(printableTable, tablefmt='grid'))

+-------+-------+-------+------+
| right | right | right | down |
+-------+-------+-------+------+
| right | x     | right | down |
+-------+-------+-------+------+
| right | right | right | down |
+-------+-------+-------+------+
| right | right | right | win  |
+-------+-------+-------+------+
