Copyright **`(c)`** 2024 Giovanni Squillero `<giovanni.squillero@polito.it>`  
[`https://github.com/squillero/computational-intelligence`](https://github.com/squillero/computational-intelligence)  
Free under certain conditions — see the [`license`](https://github.com/squillero/computational-intelligence/blob/master/LICENSE.md) for details.  

In [1]:
from itertools import permutations
from tqdm.auto import tqdm
from icecream import ic

In [11]:
DIM = 3
GOAL_STATE = tuple(list(range(1, DIM**2)) + [0])

In [12]:
def print_board(board):
    r"""Print board"""
    print("\n".join("|".join(f'{i:1}' for i in board[r * DIM : (r + 1) * DIM]) for r in range(DIM)))


print_board(GOAL_STATE)


def i2rc(i):
    r"""Convert index to (row, column)"""
    return divmod(i, DIM)


def rc2i(r, c):
    r"""Convert (row, column) to index"""
    return r * DIM + c

1|2|3
4|5|6
7|8|0


In [13]:
ACTIONS = list()
for i in range(DIM**2):
    valid = list()
    r, c = i2rc(i)
    if r > 0:
        valid.append(rc2i(r - 1, c))
    if r < DIM - 1:
        valid.append(rc2i(r + 1, c))
    if c > 0:
        valid.append(rc2i(r, c - 1))
    if c < DIM - 1:
        valid.append(rc2i(r, c + 1))
    ACTIONS.append(tuple(valid))

for i, a in enumerate(ACTIONS):
    actions = [i2rc(_) for _ in a]
    state = i2rc(i)
    ic(state, actions)

ic| state: (0, 0), actions: [(1, 0), (0, 1)]
ic| state: (0, 1), actions: [(1, 1), (0, 0), (0, 2)]
ic| state: (0, 2), actions: [(1, 2), (0, 1)]
ic| state: (1, 0), actions: [(0, 0), (2, 0), (1, 1)]
ic| state: (1, 1), actions: [(0, 1), (2, 1), (1, 0), (1, 2)]
ic| state: (1, 2), actions: [(0, 2), (2, 2), (1, 1)]
ic| state: (2, 0), actions: [(1, 0), (2, 1)]
ic| state: (2, 1), actions: [(1, 1), (2, 0), (2, 2)]
ic| state: (2, 2), actions: [(1, 2), (2, 1)]


In [14]:
def do_action(state, action):
    r"""Perform action in state, returns new state"""
    new_state = list(state)
    new_state[state.index(0)] = new_state[action]
    new_state[action] = 0
    return tuple(new_state)


def get_possible_actions(state, final=GOAL_STATE):
    r"""Gets a list of sction/reward for current state"""
    if state == final:
        return [(state.index(0), 0)]
    else:
        return [(a, -1) for a in ACTIONS[state.index(0)]]

In [15]:
def greedy_policy(state, value):
    r"""Given a state a a value function, returns list of optimal greedy moves with associated immediate reward"""
    q = dict()
    for a, r in get_possible_actions(state):
        new_state = do_action(state, a)
        q[a] = r + value[new_state]
    max_v = max(q.values())
    return set((a, -1 if v < 0 else 0) for a, v in q.items() if v == max_v)


def random_policy(state):
    r"""Given a state a a value function, returns list of optimal greedy moves with associated immediate reward"""
    return set(get_possible_actions(state))


def describe(value):
    for i in range(1, 1 - int(min(value.values()))):
        s = [s for s, v in value.items() if v == -i]
        if s:
            print(f"Found {len(s):,} states at distance {i}")

## Reachability Analysis

In [16]:
fronteer = [GOAL_STATE]
reachable = set()
while fronteer:
    s = fronteer.pop()
    reachable.add(s)
    for a, _ in get_possible_actions(s, final=None):
        ns = do_action(s, a)
        if ns not in reachable:
            reachable.add(ns)
            fronteer.append(ns)
print(f"Found {len(reachable)} states out of {len(list(permutations(range(DIM**2))))}")

Found 181440 states out of 362880


## Value Iteration

In [8]:
value = {tuple(s): 0 for s in permutations(range(DIM**2)) if tuple(s) in reachable}
print(f"v() contains {len(value):,} states")

current_policy = {s: random_policy(s) for s in value.keys()}
for steps in tqdm(range(1000)):
    new_value = dict()
    for state in value:
        new_value[state] = 0
        actions = current_policy[state]
        for a, r in actions:
            new_value[state] += 1 / len(actions) * (r + value[do_action(state, a)])
    value = new_value

value

v() contains 12 states


  0%|          | 0/1000 [00:00<?, ?it/s]

{(0, 1, 3, 2): -19.99999999999997,
 (0, 2, 1, 3): -19.99999999999997,
 (0, 3, 2, 1): -35.99999999999994,
 (1, 0, 3, 2): -10.999999999999986,
 (1, 2, 0, 3): -10.999999999999986,
 (1, 2, 3, 0): 0.0,
 (2, 0, 1, 3): -26.99999999999996,
 (2, 3, 0, 1): -34.99999999999994,
 (2, 3, 1, 0): -31.99999999999995,
 (3, 0, 2, 1): -34.99999999999994,
 (3, 1, 0, 2): -26.99999999999996,
 (3, 1, 2, 0): -31.99999999999995}

## Value Update

In [17]:
value = {tuple(s): 0 for s in permutations(range(DIM**2)) if tuple(s) in reachable}
print(f"v() contains {len(value):,} states")

current_policy = dict()
new_policy = {s: random_policy(s) for s in value.keys()}

stopping_condition = False
steps = 0
with tqdm() as pbar:
    while not stopping_condition:
        steps += 1
        current_policy = new_policy
        new_value = dict()
        for state in value:
            new_value[state] = 0
            actions = current_policy[state]
            for a, r in actions:
                new_value[state] += 1 / len(actions) * (r + value[do_action(state, a)])
        epsilon = max(abs(new_value[s]-value[s]) for s in value)
        value = new_value
        new_policy = {s: greedy_policy(s, value) for s in value.keys()}
        pbar.update(1)
        if epsilon < 1e-10:
            stopping_condition = True
describe(value)

v() contains 181,440 states


0it [00:00, ?it/s]

Found 2 states at distance 1
Found 4 states at distance 2
Found 8 states at distance 3
Found 16 states at distance 4
Found 20 states at distance 5
Found 39 states at distance 6
Found 62 states at distance 7
Found 116 states at distance 8
Found 152 states at distance 9
Found 286 states at distance 10
Found 396 states at distance 11
Found 748 states at distance 12
Found 1,024 states at distance 13
Found 1,893 states at distance 14
Found 2,512 states at distance 15
Found 4,485 states at distance 16
Found 5,638 states at distance 17
Found 9,529 states at distance 18
Found 10,878 states at distance 19
Found 16,993 states at distance 20
Found 17,110 states at distance 21
Found 23,952 states at distance 22
Found 20,224 states at distance 23
Found 24,047 states at distance 24
Found 13,926 states at distance 25
Found 14,560 states at distance 26
Found 6,274 states at distance 27
Found 3,720 states at distance 28
Found 570 states at distance 29
Found 133 states at distance 30


In [18]:
value = {tuple(s): 0 for s in permutations(range(DIM**2)) if tuple(s) in reachable}
print(f"v() contains {len(value):,} states")

current_policy = dict()
new_policy = {s: random_policy(s) for s in value.keys()}

stopping_condition = False
steps = 0
with tqdm() as pbar:
    while not stopping_condition:
        steps += 1
        current_policy = new_policy
        for state in value:
            value[state] = 0
            actions = current_policy[state]
            for a, r in actions:
                value[state] += 1 / len(actions) * (r + value[do_action(state, a)])
        new_policy = {s: greedy_policy(s, value) for s in value.keys()}
        pbar.update(1)
        if steps > 32:
            stopping_condition = True
describe(value)

v() contains 181,440 states


0it [00:00, ?it/s]

Found 2 states at distance 1
Found 4 states at distance 2
Found 8 states at distance 3
Found 16 states at distance 4
Found 20 states at distance 5
Found 39 states at distance 6
Found 62 states at distance 7
Found 116 states at distance 8
Found 152 states at distance 9
Found 286 states at distance 10
Found 396 states at distance 11
Found 748 states at distance 12
Found 1,024 states at distance 13
Found 1,893 states at distance 14
Found 2,512 states at distance 15
Found 4,485 states at distance 16
Found 5,638 states at distance 17
Found 9,529 states at distance 18
Found 10,878 states at distance 19
Found 16,993 states at distance 20
Found 17,110 states at distance 21
Found 23,952 states at distance 22
Found 20,224 states at distance 23
Found 24,047 states at distance 24
Found 13,926 states at distance 25
Found 14,560 states at distance 26
Found 6,274 states at distance 27
Found 3,720 states at distance 28
Found 570 states at distance 29
Found 133 states at distance 30
