In [98]:
%matplotlib notebook
import random
import numpy as np
from copy import deepcopy
import gym
import torch
from collections import namedtuple, defaultdict
import matplotlib.pyplot as pp

In [99]:
# Create the Gym Taxi environment that the cells below train on and replay.
env = gym.make('Taxi-v2')

In [126]:
class TaxiPolicy:
    """Tabular action-value policy trained with a TD control algorithm.

    The table ``Q`` has one row per discrete state and one column per
    action. The default shape (500 states, 6 actions) matches the table the
    notebook uses for the Taxi environment.

    Exploration convention (kept from the original code): ``epsilon`` is the
    probability of taking the *greedy* action, and a random action is taken
    with probability ``1 - epsilon``. ``epsilon == 0`` is special-cased to
    mean "always greedy". This is the reverse of the usual epsilon-greedy
    naming, so ``epsilon=0.9`` means roughly 10% exploration.
    """

    # Algorithm names accepted by train().
    ALGORITHMS = ('sarsa', 'q-learning', 'expected-sarsa', 'double-q-learning')

    def __init__(self, num_states=500, num_actions=6):
        """Create an all-zero Q table of shape ``[num_states, num_actions]``."""
        self.num_states = num_states
        self.num_actions = num_actions
        self.Q = torch.zeros([num_states, num_actions], dtype=torch.float32)

    def action(self, state, epsilon=0.0, Q=None):
        """Pick an action index for ``state``.

        Args:
            state: integer row index into the Q table.
            epsilon: probability of acting greedily (see class docstring);
                0 means always greedy.
            Q: optional alternative table to be greedy with respect to;
                defaults to ``self.Q``.

        Returns:
            The chosen action as a plain ``int``.
        """
        if Q is None:
            Q = self.Q

        # A uniform draw above epsilon triggers a uniformly random action.
        if epsilon > 0 and random.uniform(0, 1) > epsilon:
            return random.randint(0, self.num_actions - 1)

        # Otherwise choose the greedy (max-value) action.
        _, greedy_idx = torch.max(Q[state], 0)
        return int(greedy_idx)

    def sarsa_update(self, state, action, s_, r, learning_rate=0.1, discount=0.9):
        """Return the scaled TD error for one transition without applying it.

        The bootstrap value is ``Q[s_, a*]`` where ``a*`` is the greedy action
        in ``s_``. ``learning_rate`` and ``discount`` used to be read from
        undefined globals; they are now explicit parameters.

        Returns:
            ``learning_rate * (r + discount * Q[s_, a*] - Q[state, action])``
            as a 0-dim tensor.
        """
        target = r + discount * self.Q[s_, self.action(s_)]
        return learning_rate * (target - self.Q[state, action])

    def train(self, env, iterations, epsilon, learning_rate, discount, algo):
        """Run ``iterations`` episodes on ``env`` and update ``self.Q`` in place.

        Args:
            env: object with ``reset() -> state`` and
                ``step(action) -> (next_state, reward, done, info)``.
            iterations: number of episodes to run.
            epsilon: greedy probability used by the behaviour policy
                (see class docstring).
            learning_rate: step size applied to the TD error.
            discount: discount factor applied to the bootstrap value.
            algo: one of ``ALGORITHMS``.

        Raises:
            ValueError: if ``algo`` is not a supported algorithm name.
        """
        # Validate before touching the environment.
        if algo not in self.ALGORITHMS:
            raise ValueError('Invalid algo')

        # Second table for double Q-learning. The behaviour policy and the
        # final policy both keep using self.Q.
        Q2 = torch.zeros_like(self.Q) if algo == 'double-q-learning' else None

        # action() treats epsilon == 0 as fully greedy, so the expectation in
        # expected SARSA must put all mass on the greedy action in that case.
        greedy_prob = epsilon if epsilon > 0 else 1.0

        for _ in range(iterations):
            state = env.reset()
            a_ = None
            ep_ended = False
            while not ep_ended:
                # SARSA must execute the action it bootstrapped from. Compare
                # against None explicitly because action index 0 is falsy.
                action = a_ if a_ is not None else self.action(state, epsilon)

                s_, r, ep_ended, _info = env.step(action)

                # Table that receives this step's update.
                table = self.Q
                if algo == 'sarsa':
                    # On-policy: bootstrap from the action taken next.
                    a_ = self.action(s_, epsilon)
                    update = self.Q[s_, a_]
                elif algo == 'q-learning':
                    # Off-policy: bootstrap from the greedy action.
                    update = self.Q[s_, self.action(s_)]
                elif algo == 'expected-sarsa':
                    # Expected value of Q[s_] under the behaviour policy:
                    # (1 - p) spread uniformly, plus p on the greedy action.
                    next_action_probs = torch.full(
                        [self.num_actions],
                        (1 - greedy_prob) / self.num_actions,
                        dtype=torch.float32,
                    )
                    next_action_probs[self.action(s_)] += greedy_prob
                    update = torch.sum(next_action_probs * self.Q[s_, :])
                else:
                    # double-q-learning: a fair coin picks which table is
                    # updated. The action is selected with the updated table
                    # and evaluated with the other one.
                    # random.randint is inclusive on both ends.
                    if random.randint(0, 1) == 0:
                        update = Q2[s_, self.action(s_, Q=self.Q)]
                    else:
                        table = Q2
                        update = self.Q[s_, self.action(s_, Q=Q2)]

                table[state, action] += learning_rate * (
                    r + (discount * update) - table[state, action]
                )
                state = s_

    def play(self, env):
        """Render one greedy episode on ``env`` from reset until it ends."""
        state = env.reset()
        ep_ended = False

        env.render()
        while not ep_ended:
            action = self.action(state, epsilon=0)
            state, _r, ep_ended, _info = env.step(action)

            env.render()

In [127]:
# Train a fresh policy with Expected SARSA.
# NOTE: TaxiPolicy.action takes a random action only when a uniform draw
# exceeds epsilon, so epsilon=0.9 here means mostly greedy moves.
t = TaxiPolicy()
t.train(env, iterations=10000, epsilon=0.9, learning_rate=0.1, discount=0.9, algo='expected-sarsa')

In [133]:
# Scratch check: random.randint includes both endpoints, so this returns
# 0, 1 or 2 (three possible outcomes, not two).
random.randint(0, 2)

0

In [128]:
# Show the learned Q table, then render one greedy episode (play() calls
# action() with epsilon=0).
print(t.Q)
t.play(env)

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.6579, -0.7136, -1.6771, -0.1949,  1.6689, -8.0344],
        [ 0.8803,  2.2542, -0.1101,  2.2944,  5.4320, -5.2717],
        ...,
        [-1.2585,  4.4897, -1.1682, -1.2324, -1.9044, -1.9918],
        [-2.5955, -2.5871, -2.5888, -0.6030, -3.8501, -3.7924],
        [ 0.5403,  0.8449, -0.1032, 15.9046, -1.0015, -1.0000]])
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|Y| : |[34;1mB[0m:[43m [0m|
+---------+

+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|Y| : |[34;1m[43mB[0m[0m: |
+---------+
  (West)
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : | : |
|Y| : |[42mB[0m: |
+---------+
  (Pickup)
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : : : |
| | : |[42m_[0m: |
|Y| : |B: |
+---------+
  (North)
+---------+
|R: | : :[35mG[0m|
| : : : : |
| : : :[42m_[0m: |
| | : | : |
|Y| : |B: |
+---------+
  (North)
+---------+
|R: | : :[35mG[