In [1]:
import numpy as np

# --- 1. ENVIRONMENT ---
class GridWorld:
    """Square, deterministic grid world with terminal top-left/bottom-right corners.

    States are the integers ``0 .. grid_size**2 - 1`` numbered row by row from
    the top-left cell. Actions are ``0`` (UP), ``1`` (DOWN), ``2`` (LEFT) and
    ``3`` (RIGHT). Every move made from a non-terminal state yields a reward of
    -1; a move that would leave the grid keeps the agent on the border.
    """

    # (row delta, column delta) for each action id.
    _MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    def __init__(self, grid_size=4):
        """Create a ``grid_size`` x ``grid_size`` grid world.

        Args:
            grid_size: Side length of the square grid (default 4, i.e. 16
                states).

        Raises:
            ValueError: If ``grid_size`` is smaller than 2 (the two terminal
                corners would coincide and ``reset`` could never return).
        """
        if grid_size < 2:
            raise ValueError(f"grid_size must be at least 2, got {grid_size!r}")
        self.grid_size = grid_size
        self.terminal_states = [0, self.num_states - 1]
        self.actions = [0, 1, 2, 3]  # UP, DOWN, LEFT, RIGHT

    @property
    def num_states(self):
        """Total number of cells (states) in the grid."""
        return self.grid_size * self.grid_size

    def step(self, state, action):
        """Apply ``action`` in ``state``.

        Args:
            state: Current state id.
            action: One of ``self.actions``.

        Returns:
            Tuple ``(next_state, reward, done)``. From a terminal state the
            agent stays put with reward 0 and ``done=True``. Otherwise the
            reward is -1 and ``done`` tells whether ``next_state`` is terminal.

        Raises:
            ValueError: If ``state`` is outside the grid or ``action`` is not
                a known action id.
        """
        if state in self.terminal_states:
            return state, 0, True
        if not 0 <= state < self.num_states:
            raise ValueError(
                f"state must be in [0, {self.num_states}), got {state!r}"
            )
        try:
            d_row, d_col = self._MOVES[action]
        except (KeyError, TypeError):
            raise ValueError(
                f"unknown action {action!r}; expected one of {self.actions}"
            ) from None

        last = self.grid_size - 1
        row, col = divmod(state, self.grid_size)
        # Clamp to the grid: bumping into a wall leaves that coordinate unchanged.
        row = min(max(row + d_row, 0), last)
        col = min(max(col + d_col, 0), last)

        next_state = row * self.grid_size + col
        return next_state, -1, next_state in self.terminal_states

    def reset(self):
        """Return a uniformly random non-terminal start state."""
        start_state = np.random.randint(0, self.num_states)
        while start_state in self.terminal_states:
            start_state = np.random.randint(0, self.num_states)
        return start_state

# --- 2. THE ALGORITHM: SARSA ---
def _epsilon_greedy(q_values, actions, epsilon):
    """Choose an action index epsilon-greedily with respect to ``q_values``.

    With probability ``epsilon`` a uniformly random entry of ``actions`` is
    returned. Otherwise one of the highest-valued action indices is returned,
    with ties broken uniformly at random. ``np.argmax`` would always pick the
    first maximum, so an all-zero row would always yield action 0 (UP).
    """
    if np.random.rand() < epsilon:
        return np.random.choice(actions)
    best_actions = np.flatnonzero(q_values == q_values.max())
    return np.random.choice(best_actions)


def sarsa_learning(*, num_episodes=5000, alpha=0.1, gamma=1.0, epsilon=0.1):
    """Learn a tabular action-value function on ``GridWorld`` with SARSA.

    SARSA is on-policy: each update bootstraps from the action A' that the
    same epsilon-greedy behaviour policy actually selects in S', i.e.
    ``Q(S,A) += alpha * (R + gamma * Q(S',A') - Q(S,A))``, with the value of
    a terminal S' taken as 0.

    Args:
        num_episodes: Number of training episodes; each starts at a random
            non-terminal state (``env.reset()``) and runs until ``done``.
        alpha: Learning rate.
        gamma: Discount factor.
        epsilon: Probability of taking a uniformly random action.

    Returns:
        NumPy array ``Q`` of shape ``(grid_size**2, len(env.actions))``.
    """
    env = GridWorld()
    num_states = env.grid_size * env.grid_size
    Q = np.zeros((num_states, len(env.actions)))

    print(f"Training with SARSA ({num_episodes} Episodes)...")

    for _ in range(num_episodes):
        state = env.reset()
        # SARSA chooses A before entering the loop, then A' before each update.
        action = _epsilon_greedy(Q[state], env.actions, epsilon)

        while True:
            next_state, reward, done = env.step(state, action)
            if done:
                # Terminal S' has value 0, so the target is just the reward;
                # no A' is needed because the episode ends here.
                Q[state, action] += alpha * (reward - Q[state, action])
                break

            next_action = _epsilon_greedy(Q[next_state], env.actions, epsilon)
            target = reward + gamma * Q[next_state, next_action]
            Q[state, action] += alpha * (target - Q[state, action])

            state, action = next_state, next_action

    return Q

# --- 3. EXECUTION & RESULTS ---
def print_policy(Q, terminal_states=None):
    """Print the greedy policy encoded by ``Q`` as a square grid of arrows.

    Args:
        Q: Array-like of shape ``(num_states, 4)`` whose columns are the
            action values for UP, DOWN, LEFT, RIGHT. ``num_states`` must be a
            positive perfect square; states are laid out row by row.
        terminal_states: States drawn as ``T`` instead of an arrow. Defaults
            to the first and last state (``0`` and ``num_states - 1``).

    Raises:
        ValueError: If ``Q`` is not 2-D with 4 columns, or its number of rows
            is not a positive perfect square.
    """
    actions_map = {0: '↑', 1: '↓', 2: '←', 3: '→'}
    q_table = np.asarray(Q)
    if q_table.ndim != 2 or q_table.shape[1] != len(actions_map):
        raise ValueError(
            f"Q must have shape (num_states, {len(actions_map)}), got {q_table.shape}"
        )

    num_states = q_table.shape[0]
    side = int(round(num_states ** 0.5))
    if num_states == 0 or side * side != num_states:
        raise ValueError(
            f"number of states must be a positive perfect square, got {num_states}"
        )

    if terminal_states is None:
        terminal_states = (0, num_states - 1)
    terminals = set(terminal_states)
    separator = "-" * (4 * side + 1)  # 17 dashes for the default 4x4 grid

    print("\nFinal Policy (SARSA):")
    print(separator)

    cells = [
        " T " if s in terminals else f" {actions_map[int(np.argmax(q_table[s]))]} "
        for s in range(num_states)
    ]
    for start in range(0, num_states, side):
        print("|".join(cells[start:start + side]))
        print(separator)

if __name__ == "__main__":
    # Script entry point: train the agent, then render its greedy policy.
    print_policy(sarsa_learning())

Training with SARSA (5000 Episodes)...

Final Policy (SARSA):
-----------------
 T | ← | ← | ← 
-----------------
 ↑ | ← | → | ↓ 
-----------------
 ↑ | ↑ | → | ↓ 
-----------------
 → | → | → | T 
-----------------


This code implements the **SARSA (State-Action-Reward-State-Action) reinforcement learning algorithm** to learn a policy for navigating a **GridWorld environment**: from any starting cell, reach one of the two terminal corners in as few steps as possible.

Let's break it down:

**1. `GridWorld` Class:**
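   - This class defines the environment: a 4x4 grid whose 16 cells are numbered 0–15 row by row from the top-left. `terminal_states` (0, top-left, and 15, bottom-right) are the goal states. `actions` are encoded as 0 = UP, 1 = DOWN, 2 = LEFT, 3 = RIGHT.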
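   - The `step` method takes a current state and an action, then returns the `next_state`, the `reward`, and whether the episode is `done` (i.e., whether `next_state` is terminal). Every move made from a non-terminal state costs -1, including the final move into a terminal state. A move that would leave the grid keeps the agent in place. Calling `step` on a terminal state returns that state with reward 0 and `done=True`.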
   - The `reset` method initializes the agent to a random starting state that is not a terminal state.

**2. `sarsa_learning()` Function:**
   - This is the core of the SARSA algorithm.
   - **Parameters:**
     - `num_episodes`: How many training episodes to run. Each episode starts from a random non-terminal state and ends when a terminal state is reached.
     - `alpha` (learning rate): How much to update the Q-value based on the new information.
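     - `gamma` (discount factor): How much future rewards are valued relative to immediate ones. `gamma = 1.0` means no discounting, which suits this episodic task where every step costs -1.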
     - `epsilon` (exploration rate): The probability of choosing a random action (exploration) instead of the best known action (exploitation).
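   - **Q-Table:** `Q = np.zeros((16, 4))` is initialized to zeros. Entry `Q[s, a]` stores the estimated expected return (sum of future rewards) for taking `action` `a` in `state` `s` and then continuing to follow the current epsilon-greedy policy. It is not a maximum over future actions; that would be Q-learning's estimate.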
   - **Learning Loop:** For each episode:
     - The agent starts in a `reset` state.
     - It chooses an `action` using an **epsilon-greedy policy** (mostly exploits the best known action, but sometimes explores randomly).
     - It enters a `while not done` loop:
       - Takes the chosen `action`, observes the `next_state`, `reward`, and whether the episode is `done`.
       - Chooses the `next_action` also using an epsilon-greedy policy based on the `next_state`.
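       - **SARSA Update Rule:** The Q-value for the current `(state, action)` pair is updated as $Q(s,a) \leftarrow Q(s,a) + \alpha\,[R + \gamma\,Q(s',a') - Q(s,a)]$. This uses the observed `reward` and the Q-value of the `(next_state, next_action)` pair, which is taken as 0 when `next_state` is terminal. This is the key difference from Q-learning, which bootstraps from $\max_{a'} Q(s',a')$ instead. SARSA is 'on-policy', meaning it learns the value of the policy it is currently following, exploration included.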
       - The `state` and `action` are then updated to `next_state` and `next_action` to continue the episode.
   - It returns the learned `Q` table.

**3. `print_policy(Q)` Function:**
   - This function takes the final `Q` table and visualizes the greedy `policy` derived from it (the action with the highest learned Q-value in each state), i.e. the policy SARSA has learned.
   - For each non-terminal state, it finds the action with the highest Q-value and prints the corresponding arrow (↑, ↓, ←, →). Terminal states are marked 'T'.

**4. `if __name__ == "__main__":` Block:**
   - This ensures that `sarsa_learning()` is called to train the agent, and then `print_policy()` is called to display the results, when the script is run directly.