In [1]:
import numpy as np

# GridWorld environment
class GridWorld:
    """Deterministic square grid world with obstacles and a single goal cell.

    States are ``(row, col)`` tuples. Actions are integers:
    0 = up (row - 1), 1 = down (row + 1), 2 = left (col - 1), 3 = right (col + 1).
    Moves past the border are clipped to the edge; moves into an obstacle leave
    the agent where it was.

    Rewards: +1 and ``done=True`` on entering the goal, -1 for bumping into an
    obstacle, -0.1 for every other step.

    Parameters
    ----------
    size : int, default 4
        Side length of the (square) grid.
    goal : tuple of int, optional
        Goal cell; defaults to the bottom-right corner ``(size - 1, size - 1)``.
    obstacles : iterable of tuple of int, optional
        Blocked cells. Defaults to the original layout ``[(1, 1), (2, 2)]``,
        dropping any default cell that falls outside the grid or coincides with
        the start or goal.
    start : tuple of int, default (0, 0)
        Cell that ``reset`` returns the agent to.

    Raises
    ------
    ValueError
        If ``goal`` or ``start`` lies outside the grid.
    """

    n_actions = 4  # up, down, left, right
    # Action index -> (row delta, col delta).
    _MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    def __init__(self, size=4, goal=None, obstacles=None, start=(0, 0)):
        self.size = size
        # Values are never read by this class; kept so callers can read the grid shape.
        self.grid = np.zeros((size, size))
        self.goal = (size - 1, size - 1) if goal is None else tuple(goal)
        self.start = tuple(start)
        for name, cell in (("goal", self.goal), ("start", self.start)):
            if not self._in_bounds(cell):
                raise ValueError(f"{name} {cell} lies outside the {size}x{size} grid")
        if obstacles is None:
            # Original 4x4 layout, restricted to cells that fit and don't block start/goal.
            obstacles = [
                cell for cell in [(1, 1), (2, 2)]
                if self._in_bounds(cell) and cell not in (self.goal, self.start)
            ]
        self.obstacles = [tuple(cell) for cell in obstacles]
        self.state = self.start

    def _in_bounds(self, cell):
        """Return True if ``cell`` is a valid ``(row, col)`` on this grid."""
        row, col = cell
        return 0 <= row < self.size and 0 <= col < self.size

    def reset(self):
        """Move the agent back to the start cell and return that state."""
        self.state = self.start
        return self.state

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done)``.

        Raises
        ------
        ValueError
            If ``action`` is not one of 0, 1, 2, 3.
        """
        if action not in self._MOVES:
            raise ValueError(f"invalid action {action!r}; expected one of {sorted(self._MOVES)}")
        d_row, d_col = self._MOVES[action]
        last = self.size - 1
        row = min(max(self.state[0] + d_row, 0), last)
        col = min(max(self.state[1] + d_col, 0), last)

        if (row, col) in self.obstacles:
            return self.state, -1, False  # blocked: stay in place, penalty

        self.state = (row, col)
        if self.state == self.goal:
            return self.state, 1, True  # goal reached, episode ends
        return self.state, -0.1, False  # small per-step cost

In [2]:
# SARSA algorithm
def sarsa(env, episodes=1000, alpha=0.1, gamma=0.9, epsilon=0.1, max_steps=None, policy=None):
    """Learn a tabular action-value function with on-policy SARSA.

    Parameters
    ----------
    env : object
        Environment with ``reset() -> state`` and
        ``step(action) -> (next_state, reward, done)``, where states are
        ``(row, col)`` tuples. If ``env`` has a ``grid`` array, its shape sizes
        the table (else 4x4); ``env.n_actions`` sets the action count (else 4).
    episodes : int
        Number of training episodes.
    alpha : float
        Learning rate.
    gamma : float
        Discount factor.
    epsilon : float
        Exploration rate passed to the behaviour policy.
    max_steps : int or None
        Optional cap on steps per episode; ``None`` (default) runs each episode
        until the environment reports ``done``.
    policy : callable or None
        Behaviour policy ``policy(Q, state, epsilon) -> action``. Defaults to
        ``epsilon_greedy``.

    Returns
    -------
    np.ndarray
        Q-table indexed as ``Q[row, col, action]``.
    """
    if policy is None:
        # Looked up at call time: epsilon_greedy is defined after this function.
        policy = epsilon_greedy
    grid = getattr(env, "grid", None)
    n_rows, n_cols = (4, 4) if grid is None else grid.shape
    n_actions = getattr(env, "n_actions", 4)
    Q = np.zeros((n_rows, n_cols, n_actions))

    for _ in range(episodes):
        state = env.reset()
        action = policy(Q, state, epsilon)
        steps = 0
        while max_steps is None or steps < max_steps:
            next_state, reward, done = env.step(action)
            steps += 1
            idx = (state[0], state[1], action)
            if done:
                # Terminal transition: there is no successor value to bootstrap from.
                Q[idx] += alpha * (reward - Q[idx])
                break
            next_action = policy(Q, next_state, epsilon)
            td_target = reward + gamma * Q[next_state[0], next_state[1], next_action]
            Q[idx] += alpha * (td_target - Q[idx])
            state, action = next_state, next_action

    return Q

# Epsilon-greedy policy
def epsilon_greedy(Q, state, epsilon):
    """Choose an action for ``state`` using an epsilon-greedy rule.

    Parameters
    ----------
    Q : np.ndarray
        Action-value table indexed as ``Q[row, col, action]``.
    state : tuple of int
        ``(row, col)`` cell to choose an action for.
    epsilon : float
        Probability of taking a uniformly random action instead of the greedy one.

    Returns
    -------
    int
        Action index in ``range(Q.shape[-1])``. Ties among greedy actions go to
        the lowest index (``np.argmax`` returns the first maximum).
    """
    n_actions = Q.shape[-1]  # derive from the table rather than hard-coding 4
    if np.random.rand() < epsilon:
        return np.random.randint(n_actions)  # explore
    return int(np.argmax(Q[state[0], state[1]]))  # exploit

In [3]:
# Create the environment
# Fix the global NumPy seed so training is reproducible under Restart & Run All
# (the epsilon-greedy policy draws from np.random).
SEED = 0
np.random.seed(SEED)

# Create the environment
env = GridWorld()

# Train the agent using SARSA
Q = sarsa(env)

# Display the learned Q-values, indexed as Q[row, col, action]
print("Learned Q-values:")
print(Q)

# Greedy action per cell; obstacle (#) and goal (G) cells are labelled instead
ARROWS = np.array(["↑", "↓", "←", "→"])
greedy_policy = ARROWS[np.argmax(Q, axis=2)]
greedy_policy[env.goal] = "G"
for cell in env.obstacles:
    greedy_policy[cell] = "#"
print("Greedy policy:")
print("\n".join(" ".join(row) for row in greedy_policy))

Learned Q-values:
[[[-7.58203927e-02 -2.21332697e-01 -8.03103998e-02  1.01758607e-01]
  [ 1.85416259e-02 -7.94668587e-01 -9.18628402e-02  2.63921212e-01]
  [ 2.14710696e-01  3.35895095e-01  4.58664887e-02  4.28536254e-01]
  [ 3.42520323e-01  6.04396907e-01  1.71958241e-01  3.59684445e-01]]

 [[-1.76011780e-01  7.63881071e-03 -1.76046386e-01 -6.12182931e-01]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 4.04524455e-04 -2.75168144e-01 -3.76079459e-01  5.74770617e-01]
  [ 3.21126167e-01  7.97530760e-01  2.33385769e-01  4.72334521e-01]]

 [[-1.08612768e-01 -9.78765048e-02 -1.15383760e-01  2.11106793e-01]
  [-2.35828447e-01  4.43081162e-01 -5.69797899e-02 -2.58444638e-01]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 4.49299676e-01  1.00000000e+00 -1.27816915e-01  6.88258539e-01]]

 [[-6.56779679e-02 -5.59311490e-02 -5.60543500e-02  6.21589966e-02]
  [ 7.95661052e-03  4.90866746e-02 -2.15200000e-02  6.36621323e-01]
  [-1.90000000e-01 -1.90