## Tabular Sarsa, Q-learning and Expected Sarsa Algorithms

Implementation of three tabular temporal-difference control algorithms (Sarsa, Q-learning and Expected Sarsa) on a simple gridworld problem.

In [318]:
import numpy as np
import matplotlib.pyplot as plt

In [319]:
class Environment():
    """Gridworld environment of size (grid_h, grid_w) with teleport cells.

    Locations are (row, col) tuples; states are their row-major indices
    ``row * grid_w + col``. Every step yields reward -1. Stepping onto the
    goal ends the episode; stepping onto a teleport cell sends the agent back
    to the start location. Moves that would leave the grid keep the agent in
    place.
    """

    # (row, col) offset for each action index.
    _MOVES = {
        0: (-1, 0),  # Up
        1: (0, -1),  # Left
        2: (1, 0),   # Down
        3: (0, 1),   # Right
    }

    def env_init(self, env_info=None):
        """Set up the grid layout.

        Args:
            env_info: optional dict that may override 'start_loc', 'goal_loc',
                'teleport', 'grid_w' and 'grid_h'. Absent keys (or no dict at
                all) give the default 4x6 grid with start (0, 0) and goal (3, 5).
        """
        # None instead of a mutable {} default, which would be shared across calls.
        env_info = {} if env_info is None else env_info
        # Normalise to tuples so equality / membership tests against agent_loc work
        # even if the caller passes lists.
        self.start_loc = tuple(env_info.get('start_loc', (0, 0)))
        self.goal_loc = tuple(env_info.get('goal_loc', (3, 5)))
        self.teleport = [tuple(loc) for loc in env_info.get(
            'teleport', [(0, 1), (1, 1), (1, 3), (2, 3), (3, 3), (2, 5)])]

        self.grid_w = env_info.get('grid_w', 6)
        self.grid_h = env_info.get('grid_h', 4)

    def state(self, loc):
        """Return the row-major state index of location ``loc``."""
        return loc[0] * self.grid_w + loc[1]

    def env_start(self):
        """Place the agent at the start location.

        Returns:
            ``(reward, state, terminal)`` = ``(0, start_state, False)``.
        """
        self.agent_loc = self.start_loc
        self.reward_state_term = (0, self.state(self.agent_loc), False)
        return self.reward_state_term

    def env_step(self, action):
        """Move the agent by ``action`` (0=Up, 1=Left, 2=Down, 3=Right).

        Returns:
            ``(reward, state, terminal)`` after the move; reward is always -1.

        Raises:
            Exception: if ``action`` is not one of the four action indices.
        """
        try:
            d_row, d_col = self._MOVES[action]
        except (KeyError, TypeError):
            raise Exception("Action not recognised!") from None

        row, col = self.agent_loc[0] + d_row, self.agent_loc[1] + d_col
        # Off-grid moves are ignored: the agent stays where it was.
        if 0 <= row < self.grid_h and 0 <= col < self.grid_w:
            self.agent_loc = (row, col)

        # The goal check takes precedence over teleporting.
        terminal = self.agent_loc == self.goal_loc
        if not terminal and self.agent_loc in self.teleport:
            self.agent_loc = self.start_loc

        self.reward_state_term = (-1, self.state(self.agent_loc), terminal)
        return self.reward_state_term

    def env_cleanup(self):
        """Reset the agent to the start location."""
        self.agent_loc = self.start_loc

        

In [343]:
class BaseTD0Agent():
    """Common machinery for tabular one-step TD control agents.

    Holds the action-value table ``Q`` (shape ``(n_states, n_actions)``) and
    builds epsilon-greedy / greedy policies from it. Subclasses implement
    ``agent_step`` with their own bootstrapped target.
    """

    # Display names for action indices 0..3 (0=Up, 1=Left, 2=Down, 3=Right).
    ACTION_NAMES = ('Up', 'Left', 'Down', 'Right')

    # States reported by the 'check_policy' message: the cells on the shortest
    # teleport-free route from (0, 0) to (3, 5) in the default 4x6 gridworld.
    _REPORT_STATES = (0, 6, 12, 13, 14, 8, 2, 3, 4, 10, 16, 22)

    def agent_init(self, agent_info=None):
        """Read hyper-parameters and randomly initialise the Q table.

        Args:
            agent_info: dict with 'gamma' (discount), 'alpha' (step size),
                'epsilon' (exploration rate), 'n_actions' and 'n_states'.
        """
        # None instead of a mutable {} default, which would be shared across calls.
        agent_info = {} if agent_info is None else agent_info
        self.gamma = agent_info.get('gamma')  # Discount factor
        self.alpha = agent_info.get('alpha')  # Step size
        self.epsilon = agent_info.get('epsilon')  # Epsilon-greedy exploration rate

        self.n_actions = agent_info.get('n_actions')
        self.n_states = agent_info.get('n_states')

        self.Q = np.random.uniform(size=(self.n_states, self.n_actions))
        self.policy = np.zeros((self.n_states, self.n_actions))

    def get_e_greedy_policy(self):
        """Return the epsilon-greedy policy table derived from ``Q``.

        Every action gets ``epsilon / n_actions``; the greedy action (first
        maximiser on ties) additionally gets ``1 - epsilon``.
        """
        policy = np.full(self.Q.shape, self.epsilon / self.n_actions)
        greedy = self.Q.argmax(axis=1)
        policy[np.arange(self.Q.shape[0]), greedy] += 1 - self.epsilon
        return policy

    def get_greedy_policy(self):
        """Return the deterministic greedy policy (one-hot rows) from ``Q``."""
        policy = np.zeros(self.Q.shape)
        policy[np.arange(self.Q.shape[0]), self.Q.argmax(axis=1)] = 1
        return policy

    def agent_start(self, state):
        """Pick the first action of an episode epsilon-greedily and remember it."""
        self.policy = self.get_e_greedy_policy()

        action = np.random.choice(range(self.n_actions), p=self.policy[state])
        self.last_state = state
        self.last_action = action
        return action

    def agent_step(self, reward, state):
        """Handle a non-terminal transition; implemented by subclasses."""
        pass

    def agent_end(self, reward):
        """Update ``Q`` for the final transition; terminal states have value 0."""
        target = reward
        self.Q[self.last_state, self.last_action] += self.alpha * (target - self.Q[self.last_state, self.last_action])

    def agent_cleanup(self):
        """Forget the last state/action at the end of an episode."""
        self.last_state = None
        self.last_action = None

    def agent_message(self, message):
        """Answer a query.

        'get_Q' returns the Q table; 'check_policy' prints the greedy action
        for each state in ``_REPORT_STATES``.

        Raises:
            Exception: for any other message.
        """
        if message == 'get_Q':
            return self.Q
        if message == 'check_policy':
            policy = self.get_greedy_policy()
            for s in self._REPORT_STATES:
                print(f"State {s}: ", self.get_action_name(policy[s]))
            return None
        raise Exception("SARSA Agent message not understood!")

    def get_action_name(self, x):
        """Return the name of the action with the largest entry in ``x``.

        Indices without a name (>= 4) map to 'Other'.
        """
        ind = int(np.argmax(x))
        if ind < len(self.ACTION_NAMES):
            return self.ACTION_NAMES[ind]
        return 'Other'


In [344]:
class SarsaAgent(BaseTD0Agent):
    """On-policy TD control: bootstraps from the action actually chosen next."""

    def agent_step(self, reward, state):
        """Pick the next action, then apply the Sarsa update for the previous step.

        Args:
            reward: reward received for the previous action.
            state: index of the state the environment moved to.

        Returns:
            The epsilon-greedy action chosen in ``state``.
        """
        # Sarsa needs A' before it can update Q(S, A), so act first.
        self.policy = self.get_e_greedy_policy()
        next_action = np.random.choice(range(self.n_actions), p=self.policy[state])

        s_prev, a_prev = self.last_state, self.last_action
        td_target = reward + self.gamma * self.Q[state, next_action]
        td_error = td_target - self.Q[s_prev, a_prev]
        self.Q[s_prev, a_prev] += self.alpha * td_error

        self.last_state, self.last_action = state, next_action
        return next_action

In [345]:
class QAgent(BaseTD0Agent):
    """Off-policy TD control (Q-learning): bootstraps from the best next-state value."""

    def agent_step(self, reward, state):
        """Apply the Q-learning update for the previous step, then act.

        Args:
            reward: reward received for the previous action.
            state: index of the state the environment moved to.

        Returns:
            The epsilon-greedy action chosen in ``state``.
        """
        s_prev, a_prev = self.last_state, self.last_action
        best_next = self.Q[state, :].max()
        td_target = reward + self.gamma * best_next
        self.Q[s_prev, a_prev] += self.alpha * (td_target - self.Q[s_prev, a_prev])

        # The behaviour policy is rebuilt after the update so it reflects the new Q.
        self.policy = self.get_e_greedy_policy()
        chosen = np.random.choice(range(self.n_actions), p=self.policy[state])

        self.last_state, self.last_action = state, chosen
        return chosen

In [346]:
class ExpectedSarsaAgent(BaseTD0Agent):
    """TD control that bootstraps from the policy-weighted mean of next-state values."""

    def agent_step(self, reward, state):
        """Apply the Expected Sarsa update for the previous step, then act.

        Args:
            reward: reward received for the previous action.
            state: index of the state the environment moved to.

        Returns:
            The epsilon-greedy action chosen in ``state``.
        """
        s_prev, a_prev = self.last_state, self.last_action
        # self.policy was last rebuilt by agent_start or the previous agent_step,
        # i.e. after the most recent change to Q within this episode.
        expected_next = np.dot(self.Q[state, :], self.policy[state, :])
        td_target = reward + self.gamma * expected_next
        self.Q[s_prev, a_prev] += self.alpha * (td_target - self.Q[s_prev, a_prev])

        self.policy = self.get_e_greedy_policy()
        chosen = np.random.choice(range(self.n_actions), p=self.policy[state])

        self.last_state, self.last_action = state, chosen
        return chosen

In [347]:
class Experiment():
    """Run episodes of an agent interacting with an environment.

    The environment must provide env_init/env_start/env_step/env_cleanup and
    the agent agent_init/agent_start/agent_step/agent_end/agent_cleanup.
    """

    def exp_init(self, env, agent, agent_info, env_info=None):
        """Attach and initialise the environment and agent.

        Args:
            env: environment object.
            agent: agent object.
            agent_info: passed to ``agent.agent_init``.
            env_info: optional configuration passed to ``env.env_init``.
        """
        self.env = env
        self.agent = agent

        # Only forward env_info when given, so env_init implementations that
        # take no argument keep working.
        if env_info is None:
            self.env.env_init()
        else:
            self.env.env_init(env_info)
        self.agent.agent_init(agent_info)

    def run_episode(self):
        """Run one episode to termination.

        Returns:
            The number of environment steps taken.
        """
        _, state, terminal = self.env.env_start()
        action = self.agent.agent_start(state)
        n_steps = 0

        while not terminal:
            reward, state, terminal = self.env.env_step(action)
            n_steps += 1
            if terminal:
                self.agent.agent_end(reward)
            else:
                action = self.agent.agent_step(reward, state)

        self.agent.agent_cleanup()
        self.env.env_cleanup()
        return n_steps

    def run_n_episodes(self, n):
        """Run ``n`` episodes; return the list of their step counts."""
        return [self.run_episode() for _ in range(n)]

In [348]:
## Sarsa Agent Example
# Undiscounted task, step size 0.2, 20% exploration on the 24-state, 4-action grid.
agent_info = dict(gamma=1, alpha=0.2, epsilon=0.2, n_actions=4, n_states=24)

env, agent, exp = Environment(), SarsaAgent(), Experiment()
exp.exp_init(env=env, agent=agent, agent_info=agent_info)

exp.run_n_episodes(150)
agent.agent_message('check_policy')

State 0:  Down
State 6:  Down
State 12:  Right
State 13:  Right
State 14:  Up
State 8:  Up
State 2:  Right
State 3:  Right
State 4:  Down
State 10:  Down
State 16:  Down
State 22:  Right


In [349]:
## Q Agent Example
# Same hyper-parameters as the Sarsa run; Q-learning is given only 50 episodes.
agent_info = dict(gamma=1, alpha=0.2, epsilon=0.2, n_actions=4, n_states=24)

env, agent, exp = Environment(), QAgent(), Experiment()
exp.exp_init(env=env, agent=agent, agent_info=agent_info)

exp.run_n_episodes(50)
agent.agent_message('check_policy')

State 0:  Down
State 6:  Down
State 12:  Right
State 13:  Right
State 14:  Up
State 8:  Up
State 2:  Right
State 3:  Right
State 4:  Down
State 10:  Down
State 16:  Down
State 22:  Right


In [350]:
## Expected Sarsa Agent Example
# Same hyper-parameters and episode budget as the Sarsa run.
agent_info = dict(gamma=1, alpha=0.2, epsilon=0.2, n_actions=4, n_states=24)

env, agent, exp = Environment(), ExpectedSarsaAgent(), Experiment()
exp.exp_init(env=env, agent=agent, agent_info=agent_info)

exp.run_n_episodes(150)
agent.agent_message('check_policy')

State 0:  Down
State 6:  Down
State 12:  Right
State 13:  Right
State 14:  Up
State 8:  Up
State 2:  Right
State 3:  Right
State 4:  Down
State 10:  Down
State 16:  Down
State 22:  Right
