Step 1: Import Required Libraries

In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt

Step 2: Define the Environment

In [None]:
# Side length of the square grid world.
grid_size = 5

# Available moves, keyed by action index and stored as the offset added to
# the agent's (x, y) position: 0 = Up, 1 = Down, 2 = Left, 3 = Right.
actions = {
    0: (-1, 0),
    1: (1, 0),
    2: (0, -1),
    3: (0, 1),
}

# One Q-value per (x, y, action) triple, all starting at zero.
q_table = np.zeros((grid_size, grid_size, len(actions)))

# Entering any cell yields -1, except the goal cell (4, 4), which yields 100.
rewards = np.full(shape=(grid_size, grid_size), fill_value=-1)
rewards[4, 4] = 100

def is_valid_move(x, y):
    """Tell whether position (x, y) lies inside the grid.

    Both coordinates must satisfy 0 <= value < grid_size.
    """
    x_in_bounds = 0 <= x < grid_size
    y_in_bounds = 0 <= y < grid_size
    return x_in_bounds and y_in_bounds

Step 3: Define SARSA Algorithm

In [None]:
# SARSA hyperparameters
alpha = 0.1  # Learning rate
gamma = 0.9  # Discount factor
epsilon = 0.1  # Exploration rate
episodes = 500  # Number of training episodes

# Seed both random generators used below so that a fresh "Restart & Run All"
# reproduces the same Q-table.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)

start_state = (0, 0)
goal_state = (4, 4)  # Must be the cell that holds the goal reward in `rewards`
action_ids = list(actions.keys())


def choose_action(x, y):
    """Pick an action for state (x, y) with an epsilon-greedy policy.

    With probability `epsilon` a uniformly random action id is returned.
    Otherwise the index of the largest Q-value in `q_table[x, y]` is returned
    (np.argmax resolves ties in favour of the lowest index). The index is used
    directly as the action id, which relies on `actions` being keyed 0..3.
    """
    if random.uniform(0, 1) < epsilon:
        return int(np.random.choice(action_ids))
    return int(np.argmax(q_table[x, y]))


for episode in range(episodes):
    x, y = start_state
    action = int(np.random.choice(action_ids))  # Initial action is uniformly random

    while (x, y) != goal_state:  # Until goal is reached
        dx, dy = actions[action]
        new_x, new_y = x + dx, y + dy

        # A move that would leave the grid keeps the agent where it is. The
        # transition is still processed below, so the wasted step is penalised
        # and a fresh action is selected. Skipping the update here would retry
        # the same blocked action forever and the episode would never end.
        if not is_valid_move(new_x, new_y):
            new_x, new_y = x, y

        reward = rewards[new_x, new_y]
        next_action = choose_action(new_x, new_y)

        # SARSA update: Q(s, a) += alpha * (r + gamma * Q(s', a') - Q(s, a))
        td_target = reward + gamma * q_table[new_x, new_y, next_action]
        q_table[x, y, action] += alpha * (td_target - q_table[x, y, action])

        x, y = new_x, new_y
        action = next_action


Step 4: Visualize the Learned Policy

In [None]:
# Collapse the action axis so each grid cell shows its largest Q-value.
best_q_per_cell = q_table.max(axis=2)

fig, ax = plt.subplots()
heatmap = ax.imshow(best_q_per_cell, cmap="coolwarm", interpolation="nearest")
fig.colorbar(heatmap, ax=ax, label="Q-Value Strength")
ax.set_title("SARSA Q-Table Heatmap (Learned Policy)")
plt.show()

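The heatmap shows, for each grid cell, the largest Q-value over the four actions, i.e. the agent's estimate of how valuable it is to be in that cell. Cells near the goal at (4, 4) in the bottom-right corner should have the highest values, which indicates that the agent has learned to navigate toward the goal. The goal cell itself stays at 0 because no action is ever taken from it.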