In [1]:
import gym 
import numpy as np 
from tqdm import trange # Processing Bar

In [2]:
# experiment settings (each hyper-parameter is assigned exactly once)
alpha = 1.0        # learning rate
gamma = 0.9        # discount factor
# exploration parameter: epsilon_greedy takes a random action whenever
# np.random.random() <= epsilon, so 1.0 means purely random behaviour
# during training
epsilon = 1.0
n_episodes = 4000  # number of training episodes
seed = 41684       # random seed

In [3]:
# fix NumPy's global RNG so the np.random draws made later are repeatable
np.random.seed(seed=seed)

In [4]:
# Taxi actions: 0 = south, 1 = north, 2 = east, 3 = west, 4 = pickup, 5 = dropoff
# `.env` takes the environment underneath the object returned by gym.make
# (presumably to drop the episode step limit -- TODO confirm).
env = gym.make("Taxi-v2").env
# Seed the environment's own RNG too: np.random.seed() alone does not control
# how the environment samples start states, so runs would not be reproducible.
# The trailing semicolon keeps the returned seed list out of the cell output.
env.seed(seed);

In [5]:
# sizes of the discrete state and action spaces (used to shape the Q table)
n_states = env.observation_space.n
n_actions = env.action_space.n
print('{} states'.format(n_states))
print('{} actions'.format(n_actions))
# show the current state of the environment
env.render()

500 states
6 actions
+---------+
|[35mR[0m: | : :G|
| : : : : |
| :[43m [0m: : : |
| | : | : |
|Y| : |[34;1mB[0m: |
+---------+



In [6]:
# initialize the agent’s Q-table to zeros
def init_q(s, a):
    """
    Build an all-zero Q table.

    s: number of states (rows)
    a: number of actions (columns)

    Returns a NumPy array of shape (s, a) filled with 0.0.
    """
    table_shape = (s, a)
    return np.zeros(shape=table_shape)

# epsilon-greedy exploration strategy
def epsilon_greedy(Q, epsilon, n_actions, s):
    """
    Pick an action for state `s` with an epsilon-greedy rule.

    Q: Q Table, indexed as Q[state, action]
    epsilon: exploration parameter (chance of taking a random action)
    n_actions: number of actions
    s: state

    Returns a uniformly random action index when the draw from
    np.random.random() is <= epsilon, otherwise the index of the largest
    Q value in row `s`.
    """
    explore = np.random.random() <= epsilon
    if not explore:
        # exploit: best known action for this state
        return np.argmax(Q[s, :])
    # explore: any action, uniformly at random
    return np.random.randint(n_actions)
    
# Q-learning process (off-policy TD control)
def q_learning(alpha, gamma, epsilon, n_episodes):
    """
    Train a tabular Q-learning agent on the module-level `env`.

    alpha: learning rate
    gamma: discount factor
    epsilon: exploration parameter passed to epsilon_greedy
    n_episodes: number of episodes

    Uses the globals `env`, `n_states` and `n_actions`, and calls
    env.close() once training has finished.

    Returns the learned Q table (one row per state, one column per action).
    """
    # initialize Q table
    Q = init_q(n_states, n_actions)
    for _episode in trange(n_episodes):
        # initial state
        s = env.reset()
        done = False
        while not done:
            a = epsilon_greedy(Q, epsilon, n_actions, s)
            s_, reward, done, _ = env.step(a)
            # Q-learning target: bootstrap from the best action value in the
            # next state, regardless of which action will actually be taken
            td_target = reward + (gamma * np.max(Q[s_, :]))
            Q[s, a] += alpha * (td_target - Q[s, a])
            # the `while not done` test ends the episode; no explicit break
            s = s_
    env.close()
    return Q

In [7]:
# training


# run the training loop and keep the learned table
Q = q_learning(alpha, gamma, epsilon, n_episodes)

# (label, value) pairs to display: the settings used, then selected Q entries
report = [
    ("alpha: ", alpha),
    ("n_episodes: ", n_episodes),
    ("Q(462,4): ", Q[462, 4]),
    ("Q(398,3): ", Q[398, 3]),
    ("Q(253,0): ", Q[253, 0]),
    ("Q(377,1): ", Q[377, 1]),
    ("Q(83,5): ", Q[83, 5]),
]
for label, value in report:
    print(label, value)
print("------------------------------")

100%|██████████| 4000/4000 [04:05<00:00, 16.30it/s]

alpha:  1.0
n_episodes:  4000
Q(462,4):  -11.374402515013
Q(398,3):  4.348907000000002
Q(253,0):  -0.5856821172999982
Q(377,1):  9.683000000000002
Q(83,5):  -12.82326603716053
------------------------------





In [8]:
# further (label, value) pairs for selected entries of the learned Q table
report = [
    ("Q(421,2): ", Q[421, 2]),
    ("Q(126,0): ", Q[126, 0]),
    ("Q(343,1): ", Q[343, 1]),
    ("Q(11,3): ", Q[11, 3]),
    ("Q(444,4): ", Q[444, 4]),
    ("Q(496,1): ", Q[496, 1]),
    ("Q(257,0): ", Q[257, 0]),
    ("Q(222,2): ", Q[222, 2]),
    ("Q(391,3): ", Q[391, 3]),
    ("Q(82,5): ", Q[82, 5]),
]
for label, value in report:
    print(label, value)

Q(421,2):  -3.1369622635116987
Q(126,0):  -3.8232660371605287
Q(343,1):  -2.3744025150129984
Q(11,3):  -2.3744025150129984
Q(444,4):  -12.136962263511698
Q(496,1):  2.9140163000000023
Q(257,0):  5.9432300000000025
Q(222,2):  0.46035320300000193
Q(391,3):  -3.8232660371605287
Q(82,5):  -10.527113905569998
