# Importing required modules

In [1]:
from abc import ABC, abstractmethod
import random
import sys

# Building the environment model (simulator)

Abstract interface for any environment (compare this with the abstract interface for dynamic programming):

In [2]:
# Base Environment is an abstarct base class that inherits from ABC (Abstract Base Class)

class BaseEnvironment(ABC):
    
    @abstractmethod
    # Override this method to return the set of states
    def get_states(self):
        pass
    
    @abstractmethod
    # Override this method to return the set of actions available in the state
    def get_actions(self, state):
        pass
    
    @abstractmethod
    def get_all_actions(self):
        pass
    
    @abstractmethod
    # Ovveride this method to implement action execution
    def do_action_and_get_reward(self, action):
        pass

<code>GridWorld</code> environment: a 4x4 grid where the agent can move up-down and left-right (no diagonal moves). Transitions are deterministic. Cell (3,3) is the target and it is the only one with a chance to stop. Rewards are -1 for every move and 10 when the target is reached. The environment is the same we consider to test dynamic programming, but this time the agent does not know about the transition and reward probabilities.

In [3]:
class GridWorld(BaseEnvironment):
    
    # Top row top column in the grid is (0,0) coordinate
    UP = (-1,0)
    DOWN = (1,0)
    LEFT = (0,-1)
    RIGHT = (0,1)
    STAY = (0,0)
    
    # 'dict' mapping actions to symbolic names
    label = {(-1,0):"UP", (1,0):"DOWN", (0,-1):"LEFT", (0,1):"RIGHT", (0,0):"STAY"}
    
    def __init__(self):
        self.states = list()
        # States are just a list of coordinates
        for ri in range(4):
            for co in range(4):
                self.states.append((ri,co))
        # Actions are a map of state to list of actions to allowable actions in the state
        self.actions = dict()
        self.actions[(0,0)] = [GridWorld.RIGHT, GridWorld.DOWN]
        self.actions[(0,1)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN]
        self.actions[(0,2)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN]
        self.actions[(0,3)] = [GridWorld.LEFT, GridWorld.DOWN]
        self.actions[(1,0)] = [GridWorld.RIGHT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(1,1)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(1,2)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(1,3)] = [GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(2,0)] = [GridWorld.RIGHT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(2,1)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(2,2)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(2,3)] = [GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.actions[(3,0)] = [GridWorld.RIGHT, GridWorld.UP]
        self.actions[(3,1)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.UP]
        self.actions[(3,2)] = [GridWorld.RIGHT, GridWorld.LEFT, GridWorld.UP]
        self.actions[(3,3)] = [GridWorld.STAY, GridWorld.LEFT, GridWorld.UP]
        self.all_actions = [GridWorld.STAY, GridWorld.RIGHT, GridWorld.LEFT, GridWorld.DOWN, GridWorld.UP]
        self.current_state = (0,0)
    
    def get_states(self):
        return self.states
        
    def get_actions(self, state):
        return self.actions[state]
    
    def get_all_actions(self, state):
        return all_actions
    
    # Deterministc transition function
    def transition(self, current_state, action):
        # The next state has the following coordinates (x,y)
        # x = current_state[0] + action[0]
        # y = current_state[1] + action[1] 
        return (current_state[0] + action[0], current_state[1] + action[1]) 
    
    def do_action_and_get_reward(self, action):
        # If the action is not executable in the state, do nothing and return reward 0
        if action not in self.get_actions(self.current_state):
            return 0
        # Compute the next state based on the action
        self.current_state = self.transition(self.current_state, action)
        if (self.current_state != (3,3)):
            # if the new state coincides with the next state the probability is 1
            # otherwise the probability is 0
            return -1
        else:
            # ...except the goal 
            return 10
        
    def task_done(self):
        return self.current_state == (3,3)
    
    def reset(self):
        self.current_state = (0,0)

# Computing policies

Base class for all policies. The method <code>apply</code> enforces the policy, the method <code>update</code> changes the policy.

In [4]:
class BasePolicy(ABC):
    
    @abstractmethod
    # Ovveride this method to implement policy application
    # Returns the action given the state
    def apply(self, state):
        pass

    def update(self, state, action):
        pass

A concrete policy to initialize the search for an optimal policy. Implements an $\epsilon$-greedy technique for action choice.

In [5]:
class EpsilonGreedyPolicy(BasePolicy):
    
    def __init__(self, environment, Q_table, epsilon):
        self.environment = environment
        self.Q_table = Q_table
        self.epsilon = epsilon
        
    def apply(self, state):
        actions = self.environment.get_actions(state)
        if random.random() < self.epsilon:
            # Choose an action at random with probability epsilon
            return random.choice(actions)
        else:
            # Choose the best action accordin to Q_table with probability 1-epsilon
            # If all actions have the same Q-value then break ties randomly
            max_action_value = sys.float_info.min
            best_action = random.choice(actions)
            for action in actions:
                if self.Q_table[state][action] > max_action_value:
                    max_action_value = self.Q_table[state][action]
                    best_action = action
            return action

# SARSA algorithm (on-policy)

In [6]:
class SARSA:
    
    def __init__(self, environment, gamma, alpha, epsilon, episodes):
        self.environment = environment
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.episodes = episodes
        self.Q_table = dict()
        # Initialize the value of each state-action pair to 0
        for state in environment.get_states():
            self.Q_table[state] = dict()
            for action in environment.get_actions(state):
                self.Q_table[state][action] = 0
        # Use epsilon-greedy policy for learning
        self.policy = EpsilonGreedyPolicy(environment, self.Q_table, epsilon)
            
    def apply(self):
        for e in range(self.episodes):
            state = self.environment.current_state 
            action = self.policy.apply(self.environment.current_state) 
            while not self.environment.task_done():
                reward = self.environment.do_action_and_get_reward(action)
                next_state = self.environment.current_state
                next_action = self.policy.apply(next_state)
                temp_diff = self.Q_table[next_state][next_action] - self.Q_table[state][action]
                self.Q_table[state][action] =\
                self.Q_table[state][action] + self.alpha * (reward + self.gamma * temp_diff) 
                state = next_state
                action = next_action
            # Must reset the environment before trying another episode
            self.environment.reset()

# Experimenting the SARSA algorithm

In [7]:
grid_world = GridWorld()
sarsa = SARSA(grid_world, 0.9, 0.2, 0.1, 10000)

In [8]:
sarsa.apply()

Printing the optimal policy $\pi^*$ based on the $Q^*$ values approximated by SARSA. Creating a new class for a greedy policy:

In [9]:
class GreedyPolicy(BasePolicy):
    
    def __init__(self, environment, Q_table):
        self.environment = environment
        self.Q_table = Q_table
        
    def apply(self, state):
        actions = self.environment.get_actions(state)
        max_action_value = -1 * sys.float_info.max
        best_action = None
        for action in actions:
            if self.Q_table[state][action] > max_action_value:
                max_action_value = self.Q_table[state][action]
                best_action = action
        return best_action

In [10]:
greedy_policy = GreedyPolicy(grid_world, sarsa.Q_table)

In [11]:
for ri in range(4):
    for co in range(4):
        print(GridWorld.label[greedy_policy.apply((ri,co))],end=' ')
    print()

RIGHT RIGHT RIGHT DOWN 
DOWN RIGHT RIGHT DOWN 
RIGHT DOWN RIGHT DOWN 
RIGHT RIGHT RIGHT STAY 


Printing the values in the $Q$-table:

In [12]:
for ri in range(4):
    for co in range(4):
        print("({},{}):".format(ri,co))
        actions = sarsa.environment.get_actions((ri,co))
        for action in actions:
            print("{} -> {}".format(GridWorld.label[action],sarsa.Q_table[(ri,co)][action]))
        print

(0,0):
RIGHT -> -9619.168836722341
DOWN -> -9629.453339471214
(0,1):
RIGHT -> -9588.476628707229
LEFT -> -9628.653121787027
DOWN -> -9623.179955730731
(0,2):
RIGHT -> -9511.264371346395
LEFT -> -9615.571254825223
DOWN -> -9575.05384408508
(0,3):
LEFT -> -9572.428297889253
DOWN -> -9146.73128108183
(1,0):
RIGHT -> -9616.820355468895
DOWN -> -9598.615428556577
UP -> -9626.479809595492
(1,1):
RIGHT -> -9583.8374049192
LEFT -> -9628.883395572204
DOWN -> -9618.863064087307
UP -> -9615.652798438943
(1,2):
RIGHT -> -9490.539092228908
LEFT -> -9616.01339499991
DOWN -> -9520.612343120301
UP -> -9586.407055723097
(1,3):
LEFT -> -9553.07674360966
DOWN -> -6867.006218288734
UP -> -9316.439783456959
(2,0):
RIGHT -> -9496.93852631532
DOWN -> -9504.156402499693
UP -> -9626.40275635479
(2,1):
RIGHT -> -9488.840163294892
LEFT -> -9522.659789212608
DOWN -> -9468.780487603939
UP -> -9624.510760661622
(2,2):
RIGHT -> -9260.927023339187
LEFT -> -9504.247495946158
DOWN -> -9406.050696941951
UP -> -9585.1159

# Q-learning algorithm (off-policy)

In [13]:
class Qlearning:
    
    def __init__(self, environment, gamma, alpha, epsilon, episodes):
        self.environment = environment
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.episodes = episodes
        self.Q_table = dict()
        # Initialize the value of each state-action pair to 0
        for state in environment.get_states():
            self.Q_table[state] = dict()
            for action in environment.get_actions(state):
                self.Q_table[state][action] = 0
        # Use epsilon-greedy policy for learning
        self.policy = EpsilonGreedyPolicy(environment, self.Q_table, epsilon)
            
    def apply(self):
        for e in range(self.episodes):
            while not self.environment.task_done():
                state = self.environment.current_state 
                action = self.policy.apply(self.environment.current_state) 
                reward = self.environment.do_action_and_get_reward(action)
                next_state = self.environment.current_state
                # Choose maximum Q-value for next state
                max_next_value = -1 * sys.float_info.max
                next_actions = self.environment.get_actions(next_state)
                for next_action in next_actions:
                    next_value = self.Q_table[next_state][next_action]
                    if next_value > max_next_value:
                        max_next_value = next_value
                # Temporal difference is computed differently w.r.t. SARSA
                temp_diff = max_next_value - self.Q_table[state][action]
                # Update equation is the same as SARSA
                self.Q_table[state][action] =\
                self.Q_table[state][action] + self.alpha * (reward + self.gamma * temp_diff) 
            # Must reset the environment before trying another episode
            self.environment.reset()

# Testing the Q-learning algorithm

In [14]:
grid_world = GridWorld()
q_learning = Qlearning(grid_world, 0.9, 0.2, 0.1, 1000)

In [15]:
q_learning.apply()

In [16]:
greedy_policy = GreedyPolicy(grid_world, q_learning.Q_table)
for ri in range(4):
    for co in range(4):
        print(GridWorld.label[greedy_policy.apply((ri,co))],end=' ')
    print()

RIGHT RIGHT RIGHT DOWN 
RIGHT RIGHT RIGHT DOWN 
RIGHT RIGHT RIGHT DOWN 
RIGHT RIGHT RIGHT STAY 
