In [6]:
import random
import numpy as np
from copy import deepcopy
import gym
import torch
from collections import namedtuple, defaultdict

In [7]:
# Configuration: seed every RNG the notebook touches so Restart & Run All
# reproduces the same training run and the same rendered episode.
SEED = 0
random.seed(SEED)        # epsilon-greedy exploration and Dyna planning samples
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): newer gym releases ship Taxi-v3 instead of Taxi-v2 — confirm
# the installed gym version before re-running.
env = gym.make('Taxi-v2')
env.seed(SEED)           # makes env.reset() start states reproducible

In [32]:
class TaxiPolicy:
    """Tabular Dyna-Q agent for a discrete gym environment such as Taxi.

    ``Q`` is a dense ``(num_states, num_actions)`` float32 tensor indexed by the
    environment's integer state id. ``Qmem`` is the learned model used for
    planning: ``Qmem[state][action] -> (next_state, reward)``, keeping only the
    most recently observed outcome for each pair.

    NOTE: in this class ``epsilon`` is the probability of acting *greedily*; a
    uniformly random action is taken with probability ``1 - epsilon``. The
    value ``epsilon <= 0`` (the default) is special-cased to mean purely greedy.
    """

    def __init__(self, num_states=500, num_actions=6):
        """Create a zero-initialised Q-table and an empty planning model.

        Args:
            num_states: number of discrete states. The default of 500 is
                Taxi's flattened (row, column, passenger_loc, destination) id.
            num_actions: number of discrete actions. The default of 6 is Taxi's.
        """
        self.Q = torch.zeros([num_states, num_actions], dtype=torch.float32)
        # Learned model: state -> {action: (next_state, reward)}
        self.Qmem = {}
        self.num_actions = num_actions

    # Behavioral action
    def action_b(self, state):
        """Return a uniformly random behavioural action (``state`` is ignored)."""
        return random.randint(0, self.num_actions - 1)

    def p_action_b(self, state, action):
        """Return the probability that :meth:`action_b` picks ``action`` (uniform)."""
        return 1 / self.num_actions

    def _greedy_action(self, state, Q=None):
        """Return the index of the highest-valued action for ``state`` in ``Q`` (default ``self.Q``)."""
        table = self.Q if Q is None else Q
        _, best = torch.max(table[state], 0)
        return int(best)

    # Action using the Q-table
    def action(self, state, epsilon=0.0, Q=None):
        """Pick an action for ``state`` from the Q-table.

        Args:
            state: integer state id.
            epsilon: probability of acting greedily. When ``epsilon > 0`` a
                uniformly random action is returned with probability
                ``1 - epsilon``; ``epsilon <= 0`` means always greedy.
            Q: optional Q-table to act on instead of ``self.Q``.

        Returns:
            The chosen action index as an ``int``.
        """
        # random.uniform is only drawn when epsilon > 0 (short-circuit), so the
        # RNG stream matches the greedy-only case exactly.
        if epsilon > 0 and random.uniform(0, 1) > epsilon:
            return random.randint(0, self.num_actions - 1)
        return self._greedy_action(state, Q)

    def p_action(self, state, action, epsilon=0.0):
        """Return the probability that :meth:`action` returns ``action`` in ``state``.

        This mirrors :meth:`action`:

        * With ``epsilon <= 0`` the policy is purely greedy. The greedy action
          has probability 1 and every other action has probability 0.
        * Otherwise the greedy action gets ``epsilon + (1 - epsilon) / n`` and
          every other action gets ``(1 - epsilon) / n``.
        """
        is_greedy = int(action) == self._greedy_action(state)
        if not epsilon > 0:
            return 1.0 if is_greedy else 0.0
        explore = (1 - epsilon) / self.num_actions
        return epsilon + explore if is_greedy else explore

    def _q_update(self, state, action, reward, next_state, learning_rate, discount):
        """Apply one Q-learning backup of Q[state, action] towards reward + discount * max Q[next_state]."""
        best_next = self.Q[next_state, self._greedy_action(next_state)]
        self.Q[state, action] += learning_rate * (reward + (discount * best_next) - self.Q[state, action])

    # Trains DynaQ
    def train(self, env, iterations, epsilon, learning_rate, discount, planning_iterations):
        """Train the Q-table with tabular Dyna-Q.

        Each real step does the following:

        1. Pick an action with :meth:`action`.
        2. Record the observed transition in ``self.Qmem``.
        3. Apply a Q-learning update.
        4. Replay ``planning_iterations`` remembered transitions. Each replay
           samples a state uniformly, then an action uniformly for that state.

        Args:
            env: environment whose ``reset()`` returns a state id and whose
                ``step(action)`` returns ``(next_state, reward, done, info)``.
            iterations: number of episodes to run.
            epsilon: greedy probability passed to :meth:`action`.
            learning_rate: step size (alpha).
            discount: discount factor (gamma).
            planning_iterations: number of simulated updates per real step.
        """
        for _ in range(iterations):
            state = env.reset()
            ep_ended = False
            while not ep_ended:
                action = self.action(state, epsilon)

                # Fourth element is the env's info object, unused here.
                s_, r, ep_ended, _info = env.step(action)

                # Remember the transition for planning.
                self.Qmem.setdefault(state, {})[action] = (s_, r)

                # NOTE(review): this bootstraps from s_ even when the episode
                # ended. It relies on terminal-state Q rows staying 0, which
                # holds if env.reset() never returns a terminal state, because
                # no action is ever taken from one here.
                self._q_update(state, action, r, s_, learning_rate, discount)

                state = s_

                # Dyna planning. The set of remembered states does not change
                # inside this loop, so build the sampling list once per step.
                seen_states = list(self.Qmem)
                for _ in range(planning_iterations):
                    sim_state = random.choice(seen_states)
                    sim_action = random.choice(list(self.Qmem[sim_state]))
                    sim_next, sim_reward = self.Qmem[sim_state][sim_action]
                    self._q_update(sim_state, sim_action, sim_reward, sim_next, learning_rate, discount)

    def play(self, env):
        """Run one purely greedy episode on ``env``, rendering the start and every step."""
        state = env.reset()
        ep_ended = False

        env.render()
        while not ep_ended:
            action = self.action(state, epsilon=0)
            state, _reward, ep_ended, _info = env.step(action)
            env.render()

In [33]:
# Dyna-Q hyperparameters. NOTE: TaxiPolicy.action treats `epsilon` as the
# probability of acting greedily (a random action is taken with probability
# 1 - epsilon), so 0.9 means roughly 10% exploration.
dyna_q_hyperparams = dict(
    iterations=1000,            # training episodes
    epsilon=0.9,                # greedy probability
    learning_rate=0.1,          # alpha
    discount=0.9,               # gamma
    planning_iterations=10,     # simulated updates per real step
)

t = TaxiPolicy()
t.train(env, **dyna_q_hyperparams)

In [34]:
# Show the learned Q-table, then render one greedy episode with it.
learned_q = t.Q
print(learned_q)
t.play(env)

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.5661,  2.9035,  1.5471,  2.8889,  4.3486, -6.1194],
        [ 4.2854,  5.9018,  4.2842,  5.9250,  7.7146, -3.0856],
        ...,
        [ 0.0000,  0.0000,  0.0000,  2.9131, -7.3813, -7.3806],
        [ 1.4466,  2.7467,  1.4710,  2.8746, -7.4487, -7.4995],
        [14.2848, 11.8332, 14.2871, 16.9997,  5.2838,  5.2972]])
+---------+
|[35mR[0m: | : :G|
| :[43m [0m: : : |
| : : : : |
| | : | : |
|Y| : |[34;1mB[0m: |
+---------+

+---------+
|[35mR[0m: | : :G|
| : :[43m [0m: : |
| : : : : |
| | : | : |
|Y| : |[34;1mB[0m: |
+---------+
  (East)
+---------+
|[35mR[0m: | : :G|
| : : : : |
| : :[43m [0m: : |
| | : | : |
|Y| : |[34;1mB[0m: |
+---------+
  (South)
+---------+
|[35mR[0m: | : :G|
| : : : : |
| : : :[43m [0m: |
| | : | : |
|Y| : |[34;1mB[0m: |
+---------+
  (East)
+---------+
|[35mR[0m: | : :G|
| : : : : |
| : : : : |
| | : |[43m [0m: |
|Y| : |[34;1mB[0m: |
+---------+
  (South