In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import animation
from IPython.display import HTML
import gymnasium as gym


def plot_policy(action_values, sample_states=None, action_meanings=None):
    """
    Plot the greedy action for selected states as an annotated grid.

    The selected states are laid out row by row on the smallest square grid
    that holds them. Each cell is annotated with the label of the action
    that has the highest Q-value for that state; unused trailing cells are
    masked out so they are neither drawn nor counted in the colour scale.

    Args:
        action_values: (n_states, n_actions) array of Q-values,
            e.g. (500, 6) for Taxi
        sample_states: state indices to visualize (default: up to 25
            distinct states drawn at random from the Q-table)
        action_meanings: dict mapping action indices to display labels
            (default: S, N, E, W, P, D). Actions without an entry are
            shown by their index.

    Raises:
        ValueError: if sample_states is empty
    """
    if action_meanings is None:
        action_meanings = {0: 'S', 1: 'N', 2: 'E', 3: 'W', 4: 'P', 5: 'D'}

    action_values = np.asarray(action_values)
    if sample_states is None:
        # Size the sample from the table itself instead of assuming 500 states.
        n_states = action_values.shape[0]
        sample_states = np.random.choice(n_states, size=min(25, n_states),
                                         replace=False)
    sample_states = np.asarray(sample_states, dtype=int)
    if len(sample_states) == 0:
        raise ValueError("sample_states must contain at least one state")

    # Get best actions for sampled states
    best_actions = action_values[sample_states].argmax(axis=-1)

    # Fill the square grid in row-major order; leftover cells keep -1.
    grid_size = int(np.ceil(np.sqrt(len(sample_states))))
    action_grid = np.full((grid_size, grid_size), -1, dtype=int)
    action_grid.flat[:len(best_actions)] = best_actions

    is_padding = np.ones(action_grid.shape, dtype=bool)
    is_padding.flat[:len(best_actions)] = False

    # Build the labels by comparing against the untouched integer grid, so a
    # label that happens to equal another action index is never remapped.
    action_grid_str = action_grid.astype(object)
    for action, label in action_meanings.items():
        action_grid_str[action_grid == action] = label
    action_grid_str[is_padding] = ''

    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(action_grid, mask=is_padding, annot=action_grid_str, fmt='',
                cbar=False, cmap='coolwarm',
                annot_kws={'weight': 'bold', 'size': 14},
                linewidths=2, ax=ax)
    ax.set_title(f"Policy for {len(sample_states)} Sample States", size=16)
    ax.axis('off')
    plt.tight_layout()
    plt.show()


def plot_state_values(action_values, sample_states=None):
    """
    Plot state values V(s) = max_a Q(s,a) for selected states.

    The selected states are laid out row by row on the smallest square grid
    that holds them; unused trailing cells stay NaN.

    Args:
        action_values: (n_states, n_actions) array of Q-values,
            e.g. (500, 6) for Taxi
        sample_states: state indices to visualize (default: up to 25
            distinct states drawn at random from the Q-table)

    Raises:
        ValueError: if sample_states is empty
    """
    action_values = np.asarray(action_values)
    if sample_states is None:
        # Size the sample from the table itself instead of assuming 500 states.
        n_states = action_values.shape[0]
        sample_states = np.random.choice(n_states, size=min(25, n_states),
                                         replace=False)
    sample_states = np.asarray(sample_states, dtype=int)
    if len(sample_states) == 0:
        raise ValueError("sample_states must contain at least one state")

    # Compute state values
    state_values = action_values[sample_states].max(axis=-1)

    # Fill the square grid in row-major order; leftover cells remain NaN.
    grid_size = int(np.ceil(np.sqrt(len(sample_states))))
    value_grid = np.full((grid_size, grid_size), np.nan)
    value_grid.flat[:len(state_values)] = state_values

    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(value_grid, annot=True, fmt=".2f", cmap='coolwarm',
                annot_kws={'weight': 'bold', 'size': 10}, linewidths=2,
                ax=ax, cbar_kws={'label': 'State Value'})
    ax.set_title(f"State Values for {len(sample_states)} Sample States", size=16)
    ax.axis('off')
    plt.tight_layout()
    plt.show()


def plot_action_values_heatmap(action_values, sample_states=None):
    """
    Plot heatmap of action values for sampled states.

    One row per selected state and one column per action, each cell
    annotated with its Q-value.

    Args:
        action_values: (n_states, n_actions) array of Q-values,
            e.g. (500, 6) for Taxi
        sample_states: state indices to visualize (default: up to 20
            distinct states drawn at random from the Q-table)
    """
    action_values = np.asarray(action_values)
    n_states, n_actions = action_values.shape

    if sample_states is None:
        # Size the sample from the table itself instead of assuming 500 states.
        sample_states = np.random.choice(n_states, size=min(20, n_states),
                                         replace=False)

    action_names = ['South', 'North', 'East', 'West', 'Pickup', 'Dropoff']
    if n_actions != len(action_names):
        # The six names above only fit a six-action table; fall back to
        # generic labels so the tick labels always line up with the columns.
        action_names = [f'Action {a}' for a in range(n_actions)]

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(action_values[sample_states], annot=True, fmt=".2f",
                cmap='coolwarm', xticklabels=action_names,
                yticklabels=[f'State {s}' for s in sample_states],
                linewidths=0.5, ax=ax, cbar_kws={'label': 'Q-value'})
    ax.set_title("Action Values Q(s,a) for Sample States", size=16)
    ax.set_xlabel("Actions", size=12)
    ax.set_ylabel("States", size=12)
    plt.tight_layout()
    plt.show()


def display_video(frames):
    """
    Create HTML5 video from frames.

    The figure is created while the 'Agg' backend is active, the previous
    backend is restored afterwards, and the figure is closed once the video
    has been encoded so that it does not linger in pyplot's figure registry.

    Args:
        frames: non-empty list of RGB arrays, shown at 200 ms per frame

    Returns:
        IPython HTML object wrapping the encoded video

    Raises:
        ValueError: if frames is empty
    """
    if len(frames) == 0:
        raise ValueError("frames must contain at least one image")

    import matplotlib
    orig_backend = matplotlib.get_backend()
    matplotlib.use('Agg')
    try:
        fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    finally:
        # Restore the caller's backend even if figure creation fails.
        matplotlib.use(orig_backend)

    try:
        ax.set_axis_off()
        ax.set_aspect('equal')
        ax.set_position([0, 0, 1, 1])
        im = ax.imshow(frames[0])

        def update(frame):
            im.set_data(frame)
            return [im]

        anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                       interval=200, blit=True, repeat=False)
        video_html = anim.to_html5_video()
    finally:
        # The encoded video is all that is needed; release the figure.
        plt.close(fig)

    return HTML(video_html)


def test_agent(env, policy, episodes=5, render=True):
    """
    Roll out a policy for several episodes and report summary statistics.

    Each episode runs until the environment reports termination or
    truncation. The mean and standard deviation of episode reward and
    episode length are printed at the end.

    Args:
        env: Gymnasium environment (reset() -> (state, info),
            step() -> 5-tuple)
        policy: callable taking a state; returns either an action or a
            NumPy array of action probabilities to sample from
        episodes: number of episodes to run
        render: if True, collect env.render() after every reset and step

    Returns:
        HTML video of the collected frames if render=True and at least one
        frame was collected, otherwise None
    """
    captured_frames = []
    episode_returns = []
    episode_steps = []

    def snapshot():
        # Record the current frame only when a video was requested.
        if render:
            captured_frames.append(env.render())

    for _ in range(episodes):
        state, _ = env.reset()
        snapshot()

        ep_return = 0
        ep_len = 0
        finished = False
        while not finished:
            choice = policy(state)
            # An ndarray is treated as a distribution over actions.
            if isinstance(choice, np.ndarray):
                action = np.random.choice(len(choice), p=choice)
            else:
                action = choice

            state, reward, terminated, truncated, _ = env.step(action)
            finished = terminated or truncated

            ep_return += reward
            ep_len += 1
            snapshot()

        episode_returns.append(ep_return)
        episode_steps.append(ep_len)

    print(f"Average Reward: {np.mean(episode_returns):.2f} (+/- {np.std(episode_returns):.2f})")
    print(f"Average Episode Length: {np.mean(episode_steps):.2f} (+/- {np.std(episode_steps):.2f})")

    if not (render and captured_frames):
        return None
    return display_video(captured_frames)


def create_greedy_policy(action_values):
    """
    Build a deterministic policy that always picks the highest-valued action.

    The returned function reads `action_values` at call time, so later
    in-place updates to the table are reflected in its choices.

    Args:
        action_values: (n_states, n_actions) array of Q-values

    Returns:
        function mapping a state index to the index of its best action
    """
    def greedy_action(state):
        q_row = action_values[state]
        return q_row.argmax()

    return greedy_action


def create_epsilon_greedy_policy(action_values, epsilon=0.1):
    """
    Build an epsilon-greedy policy over a table of action values.

    Every action receives epsilon / n_actions probability, and the
    currently best action receives the remaining 1 - epsilon on top.

    Args:
        action_values: (n_states, n_actions) array of Q-values
        epsilon: exploration probability

    Returns:
        function mapping a state index to an array of action probabilities
    """
    # The number of actions is fixed when the policy is created.
    n_actions = action_values.shape[1]

    def action_probabilities(state):
        greedy = action_values[state].argmax()
        distribution = np.ones(n_actions) * epsilon / n_actions
        distribution[greedy] += 1.0 - epsilon
        return distribution

    return action_probabilities


def analyze_action_distribution(action_values):
    """
    Summarize which actions are greedy and how the Q-values are spread.

    Draws a bar chart of how often each action is the best one across
    states, next to a histogram of every Q-value with its mean marked,
    then prints a count and percentage for each of the six named actions.

    Args:
        action_values: (500, 6) array of Q-values
    """
    action_names = ['South', 'North', 'East', 'West', 'Pickup', 'Dropoff']
    best_actions = action_values.argmax(axis=-1)

    fig, (ax_counts, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: how many states prefer each action (only actions that occur).
    present, occurrences = np.unique(best_actions, return_counts=True)
    bar_labels = [action_names[i] for i in present]
    ax_counts.bar(bar_labels, occurrences, color='steelblue')
    ax_counts.set_title("Distribution of Best Actions Across States", size=14)
    ax_counts.set_xlabel("Action", size=12)
    ax_counts.set_ylabel("Count", size=12)
    ax_counts.tick_params(axis='x', rotation=45)

    # Right panel: histogram over every entry of the Q-table.
    all_values = action_values.ravel()
    mean_value = np.mean(all_values)
    ax_hist.hist(all_values, bins=50, color='coral', alpha=0.7, edgecolor='black')
    ax_hist.set_title("Distribution of All Q-values", size=14)
    ax_hist.set_xlabel("Q-value", size=12)
    ax_hist.set_ylabel("Frequency", size=12)
    ax_hist.axvline(mean_value, color='red', linestyle='--',
                    linewidth=2, label=f'Mean: {mean_value:.2f}')
    ax_hist.legend()

    plt.tight_layout()
    plt.show()

    print("\nAction Statistics:")
    n_states = len(best_actions)
    for action_idx, name in enumerate(action_names):
        count = np.sum(best_actions == action_idx)
        pct = 100 * count / n_states
        print(f"{name:8s}: {count:3d} states ({pct:5.2f}%)")



In [None]:
# Taxi environment created in "rgb_array" render mode; env.render() is used
# below to obtain frames for plotting and video.
env = gym.make('Taxi-v3', render_mode="rgb_array")

In [None]:
# Start a fresh episode and show its first rendered frame as a sanity check.
env.reset()
frame = env.render()
plt.imshow(frame)

<matplotlib.image.AxesImage at 0x78190706c470>

In [None]:
# Inspect the environment's observation space and the number of actions.
print(f"Observation shape for the environment is {env.observation_space}")
print(f'Number of actions : {env.action_space.n}')

Observation shape for the environment is Discrete(500)
Number of actions : 6


In [None]:
# Q-table: one row per state and one column per action, initialised to zero.
# The sizes are read from the environment (500 x 6 for Taxi-v3) instead of
# being hard-coded.
action_values_proper = np.zeros((env.observation_space.n, env.action_space.n))

In [None]:
def target_policy(state):
    """
    Pick a greedy action for `state`, breaking ties uniformly at random.

    Reads the notebook-level Q-table `action_values_proper`.

    Args:
        state: state index into the Q-table

    Returns:
        index of one of the actions whose Q-value equals the row maximum
    """
    q_row = action_values_proper[state]
    top_value = np.max(q_row)
    # Collect every action tied for the maximum so none is favoured.
    tied_actions = [idx for idx, q in enumerate(q_row) if q == top_value]
    return np.random.choice(tied_actions)





In [None]:
def behaviour_policy(state, epsilon):
    """
    Epsilon-greedy action selection used to generate experience.

    Args:
        state: state index
        epsilon: probability of taking a uniformly random action (one of 6)

    Returns:
        a random action with probability epsilon, otherwise the action
        chosen by target_policy(state)
    """
    explore = np.random.random() < epsilon
    if not explore:
        return target_policy(state)
    return np.random.randint(6)

In [None]:
def Q_learning(behaviour_policy, target_policy, action_values_proper, episodes=500,
               gamma=0.99, alpha=0.15, epsilon=0.2, environment=None):
    """
    Tabular Q-learning; updates `action_values_proper` in place.

    For every step the action is chosen by the behaviour policy and the
    table entry is moved towards
        reward + gamma * Q(next_state, target_policy(next_state)).
    When the step terminates the episode, the target is just the reward:
    a terminal state has no future return, so nothing is bootstrapped
    from it. A truncated (time-limited) step still bootstraps.

    Args:
        behaviour_policy: callable (state, epsilon) -> action used to act
        target_policy: callable (state) -> action used to evaluate the
            next state in the update target
        action_values_proper: (n_states, n_actions) Q-table, modified
            in place
        episodes: number of training episodes
        gamma: discount factor
        alpha: learning rate
        epsilon: exploration rate forwarded to behaviour_policy
        environment: Gymnasium-style environment to train on (default:
            the notebook-level `env`)

    Returns:
        None
    """
    if environment is None:
        # Keep the original behaviour of training on the notebook's env.
        environment = env

    for _ in range(episodes):
        state, _info = environment.reset()
        terminated = False
        truncated = False
        while not (terminated or truncated):
            action = behaviour_policy(state, epsilon)
            next_state, reward, terminated, truncated, _info = environment.step(action)

            if terminated:
                future_value = 0.0
            else:
                next_action = target_policy(next_state)
                future_value = action_values_proper[next_state][next_action]

            td_error = reward + gamma * future_value - action_values_proper[state][action]
            action_values_proper[state][action] += alpha * td_error
            state = next_state



In [None]:
# Train the Q-table in place: 8000 episodes with a high exploration rate.
Q_learning(
    behaviour_policy,
    target_policy,
    action_values_proper,
    episodes=8000,
    gamma=0.995,
    alpha=0.3,
    epsilon=0.7,
)

In [None]:
# Greedy policy over the trained Q-table, used for evaluation below.
policy = create_greedy_policy(action_values_proper)

In [None]:
# Roll out the greedy policy for 30 episodes; test_agent prints the reward and
# episode-length statistics, and the trailing expression displays the video.
video = test_agent(env, policy, episodes=30)
video

Average Reward: 8.33 (+/- 2.30)
Average Episode Length: 12.67 (+/- 2.30)


  return datetime.utcnow().replace(tzinfo=utc)
