In [2]:
# Simple Q-learning example that solves a small maze problem

import numpy as np
import random as rd

def init_maze():
    """Return a fresh copy of the maze's move/reward grid.

    Entry ``[state][action]`` is -1 when the move from room ``state`` to room
    ``action`` is not allowed, and 0 when it is. A new nested list is built on
    every call, so callers may mutate the result freely.
    """
    layout = (
        (-1, -1, -1, -1,  0, -1),
        (-1, -1, -1,  0, -1,  0),
        (-1, -1, -1,  0, -1, -1),
        (-1,  0,  0, -1,  0, -1),
        ( 0, -1, -1,  0, -1,  0),
        (-1,  0, -1, -1,  0,  0),
    )
    return [list(row) for row in layout]

def is_valid_action(maze, state, action):
    """Return True unless the maze marks this move as blocked (-1)."""
    blocked = maze[state][action] == -1
    return not blocked

def init_q_table(n_states, n_actions):
    """Create a zero-filled float Q-table of shape (n_states, n_actions)."""
    table_shape = (n_states, n_actions)
    return np.zeros(table_shape)

def q_learning(q_table, state, action, reward, next_state, alpha, gamma):
    """Apply one tabular Q-learning (Bellman) update in place.

    Q[s][a] <- (1 - alpha) * Q[s][a] + alpha * (r + gamma * max_a' Q[s'][a'])

    Returns the same ``q_table`` object after mutation.
    """
    current = q_table[state][action]
    best_future = np.max(q_table[next_state])
    target = reward + gamma * best_future
    q_table[state][action] = (1 - alpha) * current + alpha * target
    return q_table

def epsilon_greedy(q_table, state, epsilon):
    """Pick an action for ``state`` using an epsilon-greedy rule.

    With probability ``epsilon`` a uniformly random action index is returned
    (exploration); otherwise the index of the highest Q-value is returned
    (exploitation; ties go to the lowest index, per ``np.argmax``).

    Args:
        q_table: 2-D table indexed as ``q_table[state][action]``.
        state: index of the current state.
        epsilon: exploration probability; 0 means always greedy.

    Returns:
        An action index in ``range(len(q_table[state]))``. Either branch may
        yield a move that is invalid in the maze; callers must check validity.
    """
    n_actions = len(q_table[state])
    if rd.random() < epsilon:
        # Draw from the table's own action count rather than a hard-coded 0..5,
        # so tables of any width get in-range actions.
        return rd.randint(0, n_actions - 1)
    return np.argmax(q_table[state])

def policy(q_table, state):
    """Greedy policy: index of the highest-valued action for ``state``.

    Ties resolve to the lowest index (``np.argmax`` semantics).
    """
    action_values = q_table[state]
    best_action = np.argmax(action_values)
    return best_action

def train(q_table, n_episodes, alpha, gamma, epsilon,
          goal_state=5, goal_reward=100, maze=None):
    """Train ``q_table`` with tabular Q-learning on the maze.

    Each episode starts in a uniformly random state and follows an
    epsilon-greedy policy until ``goal_state`` is reached. Moving to a room is
    modelled as taking the action with that room's index, so the next state
    equals the chosen action.

    Args:
        q_table: table indexed as ``q_table[state][action]``; updated in place
            via ``q_learning``.
        n_episodes: number of training episodes.
        alpha: learning rate.
        gamma: discount factor.
        epsilon: exploration probability passed to ``epsilon_greedy``.
        goal_state: terminal state that ends an episode (default 5).
        goal_reward: reward for an action that enters ``goal_state``
            (default 100); every other valid move earns 0.
        maze: grid where -1 marks an invalid move; defaults to ``init_maze()``.

    Returns:
        The updated ``q_table`` (the same object that was passed in).

    Raises:
        ValueError: if a non-goal state has no valid action.
    """
    if maze is None:
        maze = init_maze()
    n_states = len(q_table)
    for _ in range(n_episodes):
        state = rd.randint(0, n_states - 1)
        while state != goal_state:
            action = epsilon_greedy(q_table, state, epsilon)
            if not is_valid_action(maze, state, action):
                # The greedy pick can be an invalid move (argmax of an all-zero
                # row is 0). Simply redrawing spins forever when epsilon == 0,
                # so explore a random *valid* move instead.
                valid_actions = [a for a in range(len(maze[state]))
                                 if is_valid_action(maze, state, a)]
                if not valid_actions:
                    raise ValueError(f"state {state} has no valid action")
                action = rd.choice(valid_actions)
            reward = goal_reward if action == goal_state else 0
            next_state = action
            q_table = q_learning(q_table, state, action, reward, next_state, alpha, gamma)
            state = next_state
    return q_table

def print_q_table(q_table):
    """Print a header, then one line per state with each Q-value as a
    6-wide, 2-decimal float followed by a single space."""
    print("Q-table:")
    for row in q_table:
        print("".join("{:6.2f} ".format(value) for value in row))

SEED = 42  # fixed seed so "Restart & Run All" reproduces the same Q-table
rd.seed(SEED)  # start states and exploration all draw from the `random` module

n_rooms = len(init_maze())  # one state and one action per room
q_table = init_q_table(n_rooms, n_rooms)
n_episodes = 100
alpha = 0.1    # learning rate
gamma = 0.9    # discount factor
epsilon = 0.1  # exploration probability
q_table = train(q_table, n_episodes, alpha, gamma, epsilon)

print_q_table(q_table)

Q-table:
  0.00   0.00   0.00   0.00  44.17   0.00 
  0.00   0.00   0.00   7.29   0.00  99.99 
  0.00   0.00   0.00  69.92   0.00   0.00 
  0.00  89.76   3.95   0.00   0.00   0.00 
  0.23   0.00   0.00  72.91   0.00   0.00 
  0.00   0.00   0.00   0.00   0.00   0.00 


In [4]:
# Follow the learned greedy policy from room 2 until the goal (room 5).
# The step cap keeps the cell from hanging if the policy cycles or picks an
# invalid move (e.g. an all-zero Q row makes argmax return action 0).
MAX_STEPS = 20
state = 2
for _ in range(MAX_STEPS):
    if state == 5:
        break
    action = policy(q_table, state)
    print(f"{state}->{action}")
    state = action
if state != 5:
    print(f"goal 5 not reached within {MAX_STEPS} steps")

2->3
3->1
1->5
