In [1]:
import random

ROUND_TO = 1

class Observation:
    
    def __init__(self, state):
        self.cart_pos = round(state[0], ROUND_TO)
        self.cart_velocity = round(state[1], ROUND_TO)
        self.pole_pos = round(state[2], ROUND_TO)
        self.pole_velocity = round(state[3], ROUND_TO)
        
    def __repr__(self):
        return "({}, {}, {}, {})".format(self.cart_pos, self.cart_velocity, self.pole_pos, self.pole_velocity)
    
    def __hash__(self):
        return hash(repr(self))
    
    def __eq__(self, other):
        return (self.cart_pos == other.cart_pos
            and self.cart_velocity == other.cart_velocity
            and self.pole_pos == other.pole_pos
            and self.pole_velocity == other.pole_velocity)
            
def epsilon_greedy(state, q_values, actions, epsilon):
    if random.random() < epsilon:
        return random.choice(actions)
    
    
    current_max = None
    current_actions = []

    for a in actions:
        val = q_values.get((state, a), 0)
        if current_max == None or val > current_max:
            current_max = val
            current_actions = [a]
        elif val == current_max:
            current_actions.append(a)
    
    return random.choice(current_actions)

In [2]:
import gym

ACTIONS = [0, 1, 2, 3, 4, 5]
OBS_SPACE_BUCKETS = 10

Q = {}

EPISODES = 100000
NUM_LOGS = 10
ALPHA = 0.5
EPSILON = 0.1
GAMMA = 1

policy = epsilon_greedy
environment = gym.make('Taxi-v2')

total_reward = 0

for i in range(0, EPISODES):

    s = environment.reset()
    a = policy(s, Q, ACTIONS, EPSILON)
    done = False


    while not done:
        s_prime, r, done, info = environment.step(a)
        a_prime = policy(s_prime, Q, ACTIONS, EPSILON)
        
        Q[(s, a)] = Q.get((s, a), 0) + ALPHA * (r + GAMMA * Q.get((s_prime, a_prime), 0) - Q.get((s, a), 0))
        
        s = s_prime
        a = a_prime
        
        total_reward += r
    
    if i % (EPISODES // NUM_LOGS) == 0:
        print("Episode {} with average reward {}".format(i, total_reward / (EPISODES // NUM_LOGS)))
        total_reward = 0

environment.close()

Episode 0 with average reward -0.659
Episode 1000 with average reward -46.028
Episode 2000 with average reward -0.704
Episode 3000 with average reward 0.016
Episode 4000 with average reward 0.412
Episode 5000 with average reward -0.638
Episode 6000 with average reward -0.344
Episode 7000 with average reward -0.646
Episode 8000 with average reward 0.692
Episode 9000 with average reward -0.763
Episode 10000 with average reward -0.496
Episode 11000 with average reward -1.146
Episode 12000 with average reward -1.003
Episode 13000 with average reward -1.345
Episode 14000 with average reward -1.977
Episode 15000 with average reward -1.802
Episode 16000 with average reward -2.456
Episode 17000 with average reward -0.63
Episode 18000 with average reward -0.23
Episode 19000 with average reward -1.206
Episode 20000 with average reward 0.436
Episode 21000 with average reward -0.409
Episode 22000 with average reward -0.911
Episode 23000 with average reward 0.285
Episode 24000 with average reward -

In [9]:
from gym import wrappers

valid_env = gym.make('Taxi-v2')
# valid_env = wrappers.Monitor(valid_env, "./gym-results", force=True)

valid_s = valid_env.reset()
valid_a = policy(s, Q, ACTIONS, EPSILON)
valid_done = False
valid_reward = 0
valid_timesteps = 0

while not valid_done:
    valid_env.render()
    valid_s_prime, valid_r, valid_done, valid_info = valid_env.step(valid_a)
    valid_a_prime = policy(valid_s_prime, Q, ACTIONS, 0)

#     valid_a_prime = 0 if valid_s_prime.cart_velocity > 0 else 1

    valid_s = valid_s_prime
    valid_a = valid_a_prime

    valid_reward += r
    valid_timesteps += 1
#     print(valid_a)

print(valid_reward)
valid_env.close()

+---------+
|[34;1m[43mR[0m[0m: | : :G|
| : : : : |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+

+---------+
|[34;1mR[0m: | : :G|
|[43m [0m: : : : |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (South)
+---------+
|[34;1m[43mR[0m[0m: | : :G|
| : : : : |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (North)
+---------+
|[42mR[0m: | : :G|
| : : : : |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (Pickup)
+---------+
|R: | : :G|
|[42m_[0m: : : : |
| : : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (South)
+---------+
|R: | : :G|
| : : : : |
|[42m_[0m: : : : |
| | : | : |
|[35mY[0m| : |B: |
+---------+
  (South)
+---------+
|R: | : :G|
| : : : : |
| : : : : |
|[42m_[0m| : | : |
|[35mY[0m| : |B: |
+---------+
  (South)
+---------+
|R: | : :G|
| : : : : |
| : : : : |
| | : | : |
|[35m[42mY[0m[0m| : |B: |
+---------+
  (South)
160


In [4]:
Q

{(189, 0): -10.277526159906142,
 (289, 1): -14.886870178787529,
 (189, 4): -19.048181916311847,
 (189, 3): -10.359971952541123,
 (169, 3): -11.911288078866276,
 (149, 4): -16.853207900178603,
 (149, 2): -10.838006073160233,
 (169, 0): -3.17471213210736,
 (269, 5): -13.702518191502104,
 (269, 0): -11.227389377273697,
 (369, 3): -9.445001765240514,
 (369, 4): -14.550609233690194,
 (369, 5): -14.43559079727834,
 (369, 0): -8.84948417075476,
 (469, 2): -16.71099417394884,
 (489, 5): -19.503105341403348,
 (489, 0): -13.646323773905845,
 (489, 1): -8.063731949970311,
 (389, 3): -5.7020296836036435,
 (369, 2): -13.50219417028538,
 (389, 5): -19.195978364176696,
 (389, 0): -14.130238591502653,
 (489, 2): -16.004524100234093,
 (489, 4): -20.970003383574237,
 (489, 3): -16.294426560737765,
 (469, 5): -16.383183288335136,
 (469, 3): -14.29555172500039,
 (469, 0): -16.418592921370426,
 (469, 1): -4.401951635091303,
 (369, 1): -4.636290925869411,
 (269, 3): -3.484884810995764,
 (249, 5): -13.119217