# Importing required modules

In [1]:
from abc import ABC, abstractmethod
import random
import sys

# Building the environment model (simulator)

Abstract interface for any environment (compare this with the abstract interface for dynamic programming):

In [2]:
class BaseEnvironment(ABC):
    
    @abstractmethod
    # Override this method to return the set of states
    def get_states(self):
        pass
    
    @abstractmethod
    # Override this method to return the set of actions available in the state
    def get_actions(self, state):
        pass
    
    @abstractmethod
    def get_all_actions(self):
        pass
    
    @abstractmethod
    # Ovveride this method to implement action execution
    def do_action_and_get_reward(self, action):
        pass

<code>GridWorld</code> environment: a 4x4 grid where the agent can move up-down and left-right (no diagonal moves). Transitions are deterministic. Cell (3,3) is the target and it is the only one with a chance to stop. Rewards are -1 for every move and 10 when the target is reached. The environment is the same we consider to test dynamic programming, but this time the agent does not know about the transition and reward probabilities.

In [29]:
class GridWorld(BaseEnvironment):
    
    UP = (-1,0)
    DOWN = (1,0)
    LEFT = (0,-1)
    RIGHT = (0,1)
    STAY = (0,0)
    
    label = {(-1,0):"UP", (1,0):"DOWN", (0,-1):"LEFT", (0,1):"RIGHT", (0,0):"STAY"}
    
    def __init__(self):
        self.states = list()
        for ri in range(4):
            for co in range(4):
                self.states.append((ri,co))
        self.actions = dict()
        self.actions[(0,0)] = [GridWorld.RIGHT, GridWorld.DOWN]
        self.actions[(0,1)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN]
        self.actions[(0,2)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN]
        self.actions[(0,3)] = [GridWorld.LEFT, GridWorld.DOWN]
        self.actions[(1,0)] = [GridWorld.RIGHT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(1,1)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(1,2)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(1,3)] = [GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(2,0)] = [GridWorld.RIGHT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(2,1)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(2,2)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(2,3)] = [GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(3,0)] = [GridWorld.RIGHT, GridWorld.UP]
        self.actions[(3,1)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.UP]
        self.actions[(3,2)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.UP]
        self.actions[(3,3)] = [GridWorld.STAY, GridWorld.LEFT, GridWorld.UP]
        self.all_actions = [GridWorld.STAY, GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.current_state = (0,0)
    
    def get_states(self):
        return self.states
        
    def get_actions(self, state):
        return self.actions[state]
    
    def get_all_actions(self, state):
        return all_actions
    
    def transition(self, current_state, action):
        return (current_state[0] + action[0], current_state[1] + action[1]) 
    
    def do_action_and_get_reward(self, action):
        # If the action is not executable in the state, do nothing and return reward 0
        if action not in self.get_actions(self.current_state):
            return 0
        # Compute the next state based on the action
        self.current_state = self.transition(self.current_state, action)
        if (self.current_state != (3,3)):
            # The reward is -1 for every next state...
            return -1
        else:
            # ...except the goal 
            return 10
        
    def task_done(self):
        return self.current_state == (3,3)
    
    def reset(self):
        self.current_state = (0,0)

# Computing policies

Base class for all policies. The method <code>apply</code> enforces the policy, the method <code>update</code> changes the policy.

In [30]:
class BasePolicy(ABC):
    
    @abstractmethod
    # Ovveride this method to implement policy application
    # Returns the action given the state
    def apply(self, state):
        pass

A concrete policy to initialize the search for an optimal policy. Implements an $\epsilon$-greedy technique for action choice.

In [31]:
class EpsilonGreedyPolicy(BasePolicy):
    
    def __init__(self, environment, Q_table, epsilon):
        self.environment = environment
        self.Q_table = Q_table
        self.epsilon = epsilon
        
    def apply(self, state):
        actions = self.environment.get_actions(state)
        if random.random() < self.epsilon:
            # Choose an action at random with probability epsilon
            return random.choice(actions)
        else:
            # Choose the best action accordin to Q_table with probability 1-epsilon
            # If all actions have the same Q-value then break ties randomly
            max_action_value = sys.float_info.min
            best_action = random.choice(actions)
            for action in actions:
                if self.Q_table[state][action] > max_action_value:
                    max_action_value = self.Q_table[state][action]
                    best_action = action
            return action

# SARSA algorithm (on-policy)

In [32]:
class SARSA:
    
    def __init__(self, environment, gamma, alpha, epsilon, episodes):
        self.environment = environment
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.episodes = episodes
        self.Q_table = dict()
        # Initialize the value of each state-action pair to 0
        for state in environment.get_states():
            self.Q_table[state] = dict()
            for action in environment.get_actions(state):
                self.Q_table[state][action] = 0
        # Use epsilon-greedy policy for learning
        self.policy = EpsilonGreedyPolicy(environment, self.Q_table, epsilon)
            
    def apply(self):
        for e in range(self.episodes):
            state = self.environment.current_state 
            action = self.policy.apply(self.environment.current_state) 
            while not self.environment.task_done():
                reward = self.environment.do_action_and_get_reward(action)
                next_state = self.environment.current_state
                next_action = self.policy.apply(next_state)
                temp_diff = self.Q_table[next_state][next_action] - self.Q_table[state][action]
                self.Q_table[state][action] =\
                self.Q_table[state][action] + self.alpha * (reward + self.gamma * temp_diff) 
                state = next_state
                action = next_action
            # Must reset the environment before trying another episode
            self.environment.reset()

# Experimenting the SARSA algorithm

In [86]:
grid_world = GridWorld()
sarsa = SARSA(grid_world, 0.9, 0.2, 0.1, 10000)

In [87]:
sarsa.apply()

Printing the optimal policy $\pi^*$ based on the $Q^*$ values approximated by SARSA. Creating a new class for a greedy policy:

In [88]:
class GreedyPolicy(BasePolicy):
    
    def __init__(self, environment, Q_table):
        self.environment = environment
        self.Q_table = Q_table
        
    def apply(self, state):
        actions = self.environment.get_actions(state)
        max_action_value = -1 * sys.float_info.max
        best_action = None
        for action in actions:
            if self.Q_table[state][action] > max_action_value:
                max_action_value = self.Q_table[state][action]
                best_action = action
        return best_action

In [89]:
greedy_policy = GreedyPolicy(grid_world, sarsa.Q_table)

In [90]:
for ri in range(4):
    for co in range(4):
        print(GridWorld.label[greedy_policy.apply((ri,co))],end=' ')
    print()

RIGHT RIGHT RIGHT DOWN 
RIGHT RIGHT RIGHT DOWN 
DOWN RIGHT RIGHT DOWN 
RIGHT RIGHT RIGHT STAY 


Printing the values in the $Q$-table:

In [91]:
for ri in range(4):
    for co in range(4):
        print("({},{}):".format(ri,co))
        actions = sarsa.environment.get_actions((ri,co))
        for action in actions:
            print("{} -> {}".format(GridWorld.label[action],sarsa.Q_table[(ri,co)][action]))
        print

(0,0):
RIGHT -> -8954.944147686572
DOWN -> -8976.29853339469

(0,1):
RIGHT -> -8838.632941265932
LEFT -> -8975.113138891733
DOWN -> -8945.044777910418

(0,2):
RIGHT -> -8738.83332289077
LEFT -> -8946.77269966464
DOWN -> -8827.442427288508

(0,3):
LEFT -> -8816.005791926542
DOWN -> -8765.222350595872

(1,0):
RIGHT -> -8953.818456591764
DOWN -> -8987.791377855614
UP -> -8980.166695323658

(1,1):
RIGHT -> -8835.048928878821
LEFT -> -8980.226243728392
DOWN -> -8950.212536771904
UP -> -8926.726320871927

(1,2):
RIGHT -> -8744.386451547296
LEFT -> -8950.18451232361
DOWN -> -8842.994399859861
UP -> -8815.59092695276

(1,3):
LEFT -> -8819.373495055672
DOWN -> -7149.172693729554
UP -> -8774.949994169725

(2,0):
RIGHT -> -8931.720218290611
DOWN -> -8912.073888633506
UP -> -8985.308049993017

(2,1):
RIGHT -> -8877.572894442688
LEFT -> -8959.432874406475
DOWN -> -8911.623763644262
UP -> -8947.847369481122

(2,2):
RIGHT -> -8705.336324870621
LEFT -> -8924.712174444288
DOWN -> -8821.542356627033
UP 

# Q-learning algorithm (off-policy)

In [92]:
class Qlearning:
    
    def __init__(self, environment, gamma, alpha, epsilon, episodes):
        self.environment = environment
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.episodes = episodes
        self.Q_table = dict()
        # Initialize the value of each state-action pair to 0
        for state in environment.get_states():
            self.Q_table[state] = dict()
            for action in environment.get_actions(state):
                self.Q_table[state][action] = 0
        # Use epsilon-greedy policy for learning
        self.policy = EpsilonGreedyPolicy(environment, self.Q_table, epsilon)
            
    def apply(self):
        for e in range(self.episodes):
            while not self.environment.task_done():
                state = self.environment.current_state 
                action = self.policy.apply(self.environment.current_state) 
                reward = self.environment.do_action_and_get_reward(action)
                next_state = self.environment.current_state
                # Choose maximum Q-value for next state
                max_next_value = -1 * sys.float_info.max
                next_actions = self.environment.get_actions(next_state)
                for next_action in next_actions:
                    next_value = self.Q_table[next_state][next_action]
                    if next_value > max_next_value:
                        max_next_value = next_value
                # Temporal difference is computed differently w.r.t. SARSA
                temp_diff = max_next_value - self.Q_table[state][action]
                # Update equation is the same as SARSA
                self.Q_table[state][action] =\
                self.Q_table[state][action] + self.alpha * (reward + self.gamma * temp_diff) 
            # Must reset the environment before trying another episode
            self.environment.reset()

# Testing the Q-learning algorithm

In [94]:
grid_world = GridWorld()
q_learning = Qlearning(grid_world, 0.9, 0.2, 0.1, 1000)

In [95]:
q_learning.apply()

In [96]:
greedy_policy = GreedyPolicy(grid_world, q_learning.Q_table)
for ri in range(4):
    for co in range(4):
        print(GridWorld.label[greedy_policy.apply((ri,co))],end=' ')
    print()

RIGHT RIGHT RIGHT DOWN 
RIGHT RIGHT RIGHT DOWN 
RIGHT RIGHT RIGHT DOWN 
RIGHT RIGHT RIGHT STAY 
