Solve Gymnasium Taxi environment

In [111]:
import numpy as np
import gymnasium as gym
env = gym.make('Taxi-v3',is_rainy=True,fickle_passenger=True)
env=env.unwrapped



Passenger locations:

0: Red

1: Green

2: Yellow

3: Blue

4: In taxi

Destinations:

0: Red

1: Green

2: Yellow

3: Blue

An observation is returned as an int() that encodes the corresponding state, calculated by ((taxi_row * 5 + taxi_col) * 5 + passenger_location) * 4 + destination



0: Move south (down)

1: Move north (up)

2: Move east (right)

3: Move west (left)

4: Pickup passenger

5: Drop off passenger


In [None]:
# get environment details
num_states = env.observation_space.n  # 16 cells
num_actions = env.action_space.n  # 4 actions
P = env.P  # transition probabilities: {state: [(trans. prob., next state, reward, done), ...]}
print(P)
# hyperparameters
gamma = 0.99  # discount factor
theta = 1e-3  # convergence threshold


{0: {0: [(0.8, 100, -1, False), (0.1, 0, -1, False), (0.1, 20, -1, False)], 1: [(0.8, 0, -1, False), (0.1, 0, -1, False), (0.1, 20, -1, False)], 2: [(0.8, 20, -1, False), (0.1, 100, -1, False), (0.1, 0, -1, False)], 3: [(0.8, 0, -1, False), (0.1, 0, -1, False), (0.1, 0, -1, False)], 4: [(1.0, 16, -1, False)], 5: [(1.0, 0, -10, False)]}, 1: {0: [(0.8, 101, -1, False), (0.1, 1, -1, False), (0.1, 21, -1, False)], 1: [(0.8, 1, -1, False), (0.1, 1, -1, False), (0.1, 21, -1, False)], 2: [(0.8, 21, -1, False), (0.1, 101, -1, False), (0.1, 1, -1, False)], 3: [(0.8, 1, -1, False), (0.1, 1, -1, False), (0.1, 1, -1, False)], 4: [(1.0, 17, -1, False)], 5: [(1.0, 1, -10, False)]}, 2: {0: [(0.8, 102, -1, False), (0.1, 2, -1, False), (0.1, 22, -1, False)], 1: [(0.8, 2, -1, False), (0.1, 2, -1, False), (0.1, 22, -1, False)], 2: [(0.8, 22, -1, False), (0.1, 102, -1, False), (0.1, 2, -1, False)], 3: [(0.8, 2, -1, False), (0.1, 2, -1, False), (0.1, 2, -1, False)], 4: [(1.0, 18, -1, False)], 5: [(1.0, 2, 

In [113]:
# initialise the value function with all zeros
V = np.zeros(num_states)
# initialise the policy with all zeros
policy = np.zeros(num_states, dtype=int)


In [114]:
# Value Iteration Algorithm
while True:
    delta = 0  # to track convergence

    # loop over all states
    for state in range(num_states):
        ### calculate the q-values ###
        v = V[state]
        vals = np.zeros(len(P[state].items()))
        vals[:] = -np.inf
        for idx, p in P[state].items(): #different action possible from state state
            val =0
            for pp in p: #possible outcomes of choosing an action
                prob = pp[0]
                next_state = pp[1]
                reward=pp[2]
                done = pp[3]
                val += prob*(reward+gamma*V[next_state])
            vals[idx]=val
        V[state]=np.max(vals)
        # update the delta
        delta = max(delta, abs(v - V[state]))
    
    # check for convergence
    if delta < theta:
        break
print(V)

### determine the optimal policy ###
for state in range(num_states):
    vals = np.zeros(len(P[state].items()))
    vals[:] = -np.inf
    for idx, p in P[state].items():
        val =0
        for pp in p:
            prob = pp[0]
            next_state = pp[1]
            reward=pp[2]
            done = pp[3]
            val += prob*(reward+gamma*V[next_state])
        vals[idx]=val
    policy[state]=np.argmax(vals)
print(policy)


[944.27715813 842.19068364 888.48840145 851.96693135 751.9094434
 842.19068364 746.9564528  786.67393437 837.85981672 789.86149541
 888.48840145 799.09415423 767.84765071 794.40213191 762.80241794
 851.96693135 954.83438655 851.71719183 898.48253541 861.59247827
 931.25402472 830.44033599 876.16078244 840.09461387 758.82144613
 849.83450855 753.82869343 793.86824254 828.50723453 780.98755665
 878.63093966 790.12814309 774.88879266 801.65876789 769.80305001
 859.69025633 941.67963909 859.4381177  888.52541131 869.39370389
 866.54402743 772.05238977 814.90517359 781.10027548 818.14509165
 915.44207923 812.8092501  855.61378985 811.71100676 765.05046727
 860.92810194 774.02559779 792.48727256 819.79580563 787.29990178
 878.99404257 876.31597315 925.70825103 870.64368916 888.89236724
 855.94671331 762.49096705 804.87394799 771.43963103 831.00539289
 929.6643524  825.5953448  868.99921243 801.71483324 755.56580525
 850.3923211  764.44251159 799.04479323 826.55402379 793.81973585
 886.186861

In [118]:
# evaluate the policy over a number of episodes
def evaluate_policy(env, policy, episodes=100):
    success = 0
    total_rewards = 0
    for _ in range(episodes):
        observation, info = env.reset()
        episode_over = 0
        total_reward =  0

        while not episode_over:
            # Take the action and see what happens
            action = policy[observation]
            observation, reward, terminated, truncated, info = env.step(action)
            total_reward += reward
            if(truncated):
                episode_over=1
                print("truncated")
            if(terminated):
                episode_over=2
                # print(observation, reward, terminated, truncated, info)
                
            # print(episode_over,reward)
        if(episode_over==2 and reward == 20): # do not look at total reward, look at the reward in the last step, 20 in that case (dependent on the problem / environment)
            success+=1
        total_rewards += total_reward
    total_reward=total_rewards/episodes
    return (success / episodes),total_reward

env = gym.make('Taxi-v3')
success_rate,reward = evaluate_policy(env, policy,10000)
print(f'Reward:{reward} SuccessRate: {success_rate}')

env = gym.make('Taxi-v3', is_rainy=True)
success_rate,reward = evaluate_policy(env, policy,10000)
print(f'Reward:{reward} SuccessRate: {success_rate}')

env = gym.make('Taxi-v3', fickle_passenger=True)
success_rate,reward = evaluate_policy(env, policy,10000)
print(f'Reward:{reward} SuccessRate: {success_rate}')


env = gym.make('Taxi-v3',is_rainy=True,fickle_passenger=True)
success_rate,reward = evaluate_policy(env, policy,10000)
print(f'Reward:{reward} SuccessRate: {success_rate}')

Reward:7.9519 SuccessRate: 1.0
Reward:4.4838 SuccessRate: 1.0
Reward:8.3648 SuccessRate: 1.0
Reward:5.0454 SuccessRate: 1.0
