# Environment Setup

## Imports and trivial functions

In [430]:
import numpy as np
import matplotlib as plt
import copy
from pprint import pprint
import random
import operator

def pretty_print(two_d_array: list) -> None:
    for i in two_d_array:
        pv = [round(x, 2) for x in i]
        print(*pv, sep='\t')

## Setting up State Space

In [431]:
state_space = np.zeros((5, 10))
for i in range(0, len(state_space)-1):
    for j in range(0, len(state_space[i]-1)):
        state_space[i][j] = -1
for i in range(0, len(state_space[-1])):
    if 0 < i and i < len(state_space[-1])-1:
        state_space[-1][i] = -100
    elif i == len(state_space[-1])-1:
        state_space[-1][i] = 20
original_state_space = copy.deepcopy(state_space)
print(state_space)

[[  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [  -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.   -1.]
 [   0. -100. -100. -100. -100. -100. -100. -100. -100.   20.]]


## Define possible actions for each state

In [432]:
possible_actions = {}
for i in range(0, len(state_space)):
    for j in range(0, len(state_space[i])):
        list_of_actions = []
        # Each state has a maximum of 4 possible actions (New states to transition to)
        # We check the boundaries to ensure we do not add an action outside our state space
        if not (i-1 < 0):
            list_of_actions.append((i-1, j))
        if not (i+1 >= len(state_space)):
            list_of_actions.append((i+1, j))
        if not (j-1 < 0):
            list_of_actions.append((i, j-1))
        if not(j+1 >= len(state_space[i])):
            list_of_actions.append((i, j+1))
        possible_actions[(i, j)]= list_of_actions

### Defining the start node and the terminal states

In [433]:
init_state = (4, 0)
terminal_states = [(4, x) for x in range(1, len(state_space[-1]))]

## $\epsilon$-greedy Policy

In [434]:
# input: State Space, current state, List of possible actions (locations), epsilon value
# returns: new action (location tuple)
def policy(current_state: tuple, e: float = 0.1) -> tuple:
    action_payoffs = {}
    # retreiving payoffs of possible actions
    for action in possible_actions[current_state]:
        i, j = action
        action_payoffs[action] = state_space[i][j]
    # Decide with the e-greedy policy the action to return
    if random.random() > e:
        return max(action_payoffs, key=action_payoffs.get)
    return random.choice(list(action_payoffs.keys()))

# SARSA

Input: 
- Starting state 
- List of terminal states
- Iteration limit
- Learning rate $\alpha$
- Discount factor $\gamma$
- Policy value $\epsilon$

In [435]:
def SARSA(init: tuple, terminal_states: list, limit: int = 10000,
          alpha: float = 0.9, gamma: float = 0.9, epsilon: float = 0.1) -> list:
    for _ in range(0, limit):
        current_state = copy.deepcopy(init)
        while True:
            # retreive best action from policy
            action = policy(current_state=current_state, e=epsilon) 
            i, j = current_state
            k, l = action

            # SARSA Update rule
            state_space[i][j] += alpha * (original_state_space[k][l] 
                              + (gamma * state_space[k][l]) - state_space[i][j])

            # Move to next state
            current_state = action
            if current_state in terminal_states:
                break
    return state_space

# Q-Learning

In [436]:
def Q_Learning(init: tuple, terminal_states: list, limit: int = 10000,
          alpha: float = 0.9, gamma: float = 0.9, epsilon: float = 0.1) -> list:
    for _ in range(0, limit):
        current_state = copy.deepcopy(init)
        action = policy(current_state=current_state, e=epsilon)
        while True:
            # retreive best action from policy
            action = policy(current_state=current_state, e=epsilon)

            # retreiving list of possible pay-offs from current state
            action_payoffs = {}
            for act in possible_actions[current_state]:
                a, b = act
                action_payoffs[act] = state_space[a][b]
            
            i, j = current_state
            k, l = max(action_payoffs, key=action_payoffs.get)

            # Q-Learning update rule (Utilising max pay-off possible)
            state_space[i][j] += alpha * (original_state_space[i][j] 
                              + (gamma * state_space[k][l]) - state_space[i][j])
            
            # Moving to new state chosen by policy
            current_state = action
            if current_state in terminal_states:
                break
    return state_space

# Tests

## SARSA Test

In [442]:
def SARSA_test(limit, alpha, gamma, epsilon) -> None:
    state_space = SARSA(init_state, terminal_states, limit=limit, alpha=alpha, gamma=gamma, epsilon=epsilon)
    pretty_print(state_space)
    state_space = copy.deepcopy(original_state_space)

print('SARSA Test')
SARSA_test(limit=1000, alpha=0.9, gamma=0.9, epsilon=0.1)

SARSA Test
-34.36	-16.38	-10.24	-11.78	-11.87	-12.45	-10.3	-11.34	-10.0	-11.58
-10.05	-15.17	-10.82	-10.67	-34.18	-11.76	-13.97	-10.09	-11.7	-9.94
-13.13	-19.41	-116.42	-19.27	-11.19	-13.75	-11.36	-10.07	13.83	-16.26
-10.24	-13.95	-23.14	-19.03	-28.92	-63.73	-19.36	-171.99	19.05	36.36
-37.76	-100.0	-100.0	-100.0	-100.0	-100.0	-100.0	-100.0	-100.0	20.0


## Q-Learning Test

In [443]:
def Q_test(limit, alpha, gamma, epsilon) -> None:
    state_space = Q_Learning(init_state, terminal_states, limit=limit, alpha=alpha, gamma=gamma, epsilon=epsilon)
    pretty_print(state_space)
    state_space = copy.deepcopy(original_state_space)
print('Q-Learning Test')
Q_test(limit=1000, alpha=0.9, gamma=0.9, epsilon=0.1)

Q-Learning Test
-9.33	-16.38	-10.24	-11.78	-11.87	-12.45	-10.3	-11.34	-10.0	-11.58
-6.16	-6.63	-10.82	-10.67	-34.18	-11.76	-13.97	-10.09	-11.7	-9.94
-5.74	-6.16	-7.65	-19.27	-11.19	-13.75	-11.36	-10.07	13.83	-16.26
-5.26	-5.74	-6.16	-19.03	-28.92	-63.73	-19.36	-171.99	19.05	36.36
-4.74	-100.0	-100.0	-100.0	-100.0	-100.0	-100.0	-100.0	-100.0	20.0
