In [61]:
import numpy as np
import gym
import random
import matplotlib.pyplot as plt
import time
from IPython.display import clear_output

In [69]:
class Agent:
    """Tabular Q-learning agent with an epsilon-greedy policy.

    The Q-table has one row per discrete observation and one column per
    discrete action of the environment handed to the constructor.
    """

    def __init__(self, env):
        """Build a zero-initialised Q-table and reset ``env``.

        Args:
            env: Environment exposing ``observation_space.n``,
                ``action_space.n``, ``action_space.sample()`` and ``reset()``.
        """
        # Keep our own reference so selectAction() samples from the environment
        # this agent was built for instead of a notebook global named `env`.
        self.env = env
        self.q_table = np.zeros((env.observation_space.n, env.action_space.n))
        
        self.alpha = 0.9 # learning rate
        self.gamma = 0.96 # discount factor
        self.epsilon = 0.85 # exploration rate
        self.action = None  # last action chosen; must be set before learn()
        
        self.current_state = env.reset()
        print("state in constructor: ", self.current_state)
        
        self.n_episodes = 100000 # Episodes to play
        self.n_steps = 200 # Max steps in an episode
        
    def learn(self, new_state, reward):
        """Apply one Q-learning update for (current_state, action).

        Uses ``self.current_state`` and ``self.action`` as the transition's
        origin; the caller is responsible for advancing ``current_state``
        afterwards.

        Args:
            new_state: Index of the state reached after taking ``self.action``.
            reward: Reward received for that transition.
        """
        self.reward = reward
        
        old_value = self.q_table[self.current_state, self.action]
        # Bootstrapped target: immediate reward plus the discounted value of
        # the best action available in the state we landed in.
        td_target = reward + self.gamma*np.max(self.q_table[new_state, :])
        self.q_table[self.current_state, self.action] = old_value + \
            self.alpha*(td_target - old_value)
        
    def selectAction(self):
        """Return an action for ``self.current_state`` (epsilon-greedy).

        With probability ``self.epsilon`` a random action is sampled from the
        agent's environment; otherwise the action with the highest Q-value in
        the current state's row is returned.
        """
        if random.uniform(0, 1) < self.epsilon:
            return self.env.action_space.sample()
        return np.argmax(self.q_table[self.current_state, :])
        
    def updateParameters(self):
        """Lower the exploration rate by 0.05, never going below 0.1."""
        self.epsilon = np.maximum(self.epsilon - 0.05, 0.1)
        #self.alpha = np.maximum(self.alpha - 0.05, 0.6)

In [70]:
# Create the Taxi-v3 environment and a fresh agent bound to it. Constructing
# the agent also resets the environment and prints the initial state.
env = gym.make("Taxi-v3")
agent = Agent(env)

state in constructor:  324


In [66]:
# Random play: step the environment with uniformly sampled actions (no
# learning involved) and animate each rendered frame in the cell output.
env.reset()

for i in range(100):
    action = env.action_space.sample()
    state, reward, done, info = env.step(action)
    env.render()
    if done:
        # Stop right away so the final frame is left on screen.
        break
    # Short pause between frames; wait=True defers the wipe until the next
    # frame is printed.
    time.sleep(0.01)
    clear_output(wait=True)

+---------+
|[35mR[0m: | : :G|
| : | : : |
| : : : : |
| | : | : |
|Y| : |[34;1mB[0m:[43m [0m|
+---------+
  (East)


In [None]:
#agent = Agent(env)

# Q-learning training loop: play up to n_episodes episodes of at most n_steps
# steps each, updating the Q-table after every transition.
for i in range(agent.n_episodes):
    # Start each episode from the environment's fresh initial state. This has
    # to be `current_state` (the attribute selectAction/learn read); otherwise
    # the first update of an episode is applied to the previous episode's
    # final state.
    agent.current_state = env.reset()
    for j in range(agent.n_steps):        
        
        agent.action = agent.selectAction()
        state, reward, done, info = env.step(agent.action)
        agent.learn(state, reward)
        agent.current_state = state

        env.render()
        
        if(done):
            break
        
        # Live view of the Q-table; the pause and re-draw run on every step.
        print(agent.q_table)
        time.sleep(0.01)
        clear_output(wait=True)
    
    # Optional exploration decay, currently disabled:
    #if ((i%10000) == 0):
        #agent.updateParameters()


+---------+
|[35mR[0m:[43m [0m| : :[34;1mG[0m|
| : | : : |
| : : : : |
| | : | : |
|Y| : |B: |
+---------+
  (East)
[[ -4.27944662  -2.76725068  -3.62591483  -4.70965882 -12.02536308
  -13.83504594]
 [ -5.99706579  -5.42260838  -6.20273021  -5.40374235  -4.60872152
  -14.42420685]
 [  8.35138805   8.57277084   3.14617107   9.62650743  11.6663391
   -0.05943873]
 ...
 [ -2.68924824  -1.95808774  -2.68847064  -3.35906412 -11.85479195
  -10.8539136 ]
 [ -2.63070396  -1.93311     -1.939896    -1.939896   -11.6210736
  -10.9394496 ]
 [  0.          -1.862136    -0.9         -0.99         0.
   -9.        ]]


In [68]:
# Show the agent's current exploration rate and learning rate.
# NOTE(review): the saved output below (about 0.35 and 0.6) does not match the
# Agent defaults (0.85 and 0.9); this cell's execution count (68) precedes
# the class cell's (69), so it presumably reflects an earlier kernel state —
# re-run after Restart & Run All to confirm.
print(agent.epsilon, agent.alpha)

0.34999999999999976 0.6
