# SARSA: The Cautious Learner

Welcome to SARSA - the algorithm that learns safer behavior by accounting for its own exploratory mistakes!

## What You'll Learn

By the end of this notebook, you'll understand:
- The SARSA update rule (with a cautious driver analogy!)
- On-policy vs off-policy: the crucial difference
- Why SARSA is "safer" than Q-learning
- The famous cliff walking example
- When to use SARSA vs Q-learning

**Prerequisites:** Notebook 3 (Q-Learning)

**Time:** ~30 minutes

---
## The Big Picture: The Two Drivers

Imagine two drivers learning a mountain road with cliffs:

```
    ┌────────────────────────────────────────────────────────────────┐
    │          SARSA vs Q-LEARNING: THE TWO DRIVERS                  │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Q-LEARNING = The Aggressive Driver                           │
    │    "The shortest path is right along the cliff edge."         │
    │    "Sure, I might slip sometimes, but the optimal             │
    │     route is the fastest one!"                                │
    │                                                                │
    │    → Learns the OPTIMAL path (ignoring exploration risk)      │
    │    → Assumes future actions are PERFECT (greedy)              │
    │    → Falls off cliff occasionally during learning             │
    │                                                                │
    │  SARSA = The Cautious Driver                                  │
    │    "I know I sometimes make mistakes."                        │
    │    "I'll stay away from the cliff edge because                │
    │     I might accidentally steer wrong!"                        │
    │                                                                │
    │    → Learns a SAFE path (accounting for own mistakes)         │
    │    → Assumes future actions include EXPLORATION               │
    │    → Falls off the cliff LESS OFTEN during learning           │
    │                                                                │
    │  KEY INSIGHT:                                                 │
    │    Q-learning asks: "What's best if I'm PERFECT?"            │
    │    SARSA asks: "What's best given how I ACTUALLY behave?"    │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Rectangle, Circle, FancyArrowPatch, Arrow
from matplotlib.colors import LinearSegmentedColormap
from collections import defaultdict

# Visualize the two drivers
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: Q-Learning (aggressive)
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('Q-Learning\n"The Aggressive Driver"', fontsize=14, fontweight='bold', color='#d32f2f')

# Draw road and cliff
road = Rectangle((0, 3), 10, 4, facecolor='#e0e0e0', edgecolor='black', linewidth=2)
ax1.add_patch(road)
cliff = Rectangle((0, 0), 10, 3, facecolor='#795548', edgecolor='black', linewidth=2)
ax1.add_patch(cliff)
ax1.text(5, 1.5, 'CLIFF', ha='center', va='center', fontsize=14, color='white', fontweight='bold')

# Q-learning path (along edge)
path_x = [1, 3, 5, 7, 9]
path_y = [5, 3.5, 3.5, 3.5, 5]
ax1.plot(path_x, path_y, 'r-', linewidth=3, marker='o', markersize=10, label='Q-learning path')
ax1.text(5, 4.5, 'Optimal but RISKY!', ha='center', fontsize=11, color='#d32f2f', fontweight='bold')

# Occasional fall
ax1.annotate('', xy=(5, 2), xytext=(5, 3.5),
             arrowprops=dict(arrowstyle='->', lw=2, color='#d32f2f', linestyle='--'))
ax1.text(5.3, 2.5, 'Sometimes\nfalls!', fontsize=9, color='#d32f2f')

ax1.text(5, 9, 'Learns: "Edge is fastest"', ha='center', fontsize=11, style='italic')
ax1.text(5, 8.2, '(Ignores exploration mistakes)', ha='center', fontsize=10, color='#666')

# Right: SARSA (cautious)
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('SARSA\n"The Cautious Driver"', fontsize=14, fontweight='bold', color='#388e3c')

# Draw road and cliff
road2 = Rectangle((0, 3), 10, 4, facecolor='#e0e0e0', edgecolor='black', linewidth=2)
ax2.add_patch(road2)
cliff2 = Rectangle((0, 0), 10, 3, facecolor='#795548', edgecolor='black', linewidth=2)
ax2.add_patch(cliff2)
ax2.text(5, 1.5, 'CLIFF', ha='center', va='center', fontsize=14, color='white', fontweight='bold')

# SARSA path (away from edge)
path_x2 = [1, 3, 5, 7, 9]
path_y2 = [5, 5.5, 6, 5.5, 5]
ax2.plot(path_x2, path_y2, 'g-', linewidth=3, marker='o', markersize=10, label='SARSA path')
ax2.text(5, 4.5, 'Longer but SAFE!', ha='center', fontsize=11, color='#388e3c', fontweight='bold')

ax2.text(5, 9, 'Learns: "Stay away from edge"', ha='center', fontsize=11, style='italic')
ax2.text(5, 8.2, '(Accounts for own mistakes)', ha='center', fontsize=10, color='#666')

plt.tight_layout()
plt.show()

print("\n" + "="*70)
print("THE KEY DIFFERENCE")
print("="*70)
print("""
Q-LEARNING (Off-Policy):
  Updates using: max Q(s', a')
  → "What if I take the BEST action next?"
  → Learns optimal policy, but risky during learning

SARSA (On-Policy):
  Updates using: Q(s', a') where a' is the ACTUAL next action
  → "What if I take the action I WOULD actually take next?"
  → Learns a policy that accounts for exploration
""")
print("="*70)

---
## The SARSA Update Rule

SARSA gets its name from the quintuple: **(S, A, R, S', A')**

```
    ┌────────────────────────────────────────────────────────────────┐
    │                   THE SARSA UPDATE RULE                        │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Q(s, a) ← Q(s, a) + α × ( r + γ×Q(s', a') - Q(s, a) )        │
    │                              └────┬─────┘                      │
    │                            ACTUAL next action                  │
    │                          (from exploration policy!)            │
    │                                                                │
    │  Compare to Q-Learning:                                       │
    │  Q(s, a) ← Q(s, a) + α × ( r + γ×max_a'Q(s', a') - Q(s, a) )  │
    │                              └─────────┬───────┘               │
    │                                 BEST action                    │
    │                             (ignores exploration!)             │
    │                                                                │
    │  THE NAME: S-A-R-S'-A'                                        │
    │    State → Action → Reward → next State → next Action         │
    │                                                                │
    │  We need ALL FIVE elements to do the update!                  │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
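
To make the two rules concrete, here is a minimal sketch that applies both updates to a single transition. The Q-values, the reward, and the helper names `sarsa_update` and `q_learning_update` are hypothetical, chosen only for illustration; they are not part of the implementation later in this notebook.

```python
def sarsa_update(q_sa, reward, q_next, next_action, alpha, gamma):
    """One SARSA step: bootstrap from the action actually chosen in s'."""
    td_target = reward + gamma * q_next[next_action]
    return q_sa + alpha * (td_target - q_sa)


def q_learning_update(q_sa, reward, q_next, alpha, gamma):
    """One Q-learning step: bootstrap from the best action in s'."""
    td_target = reward + gamma * max(q_next)
    return q_sa + alpha * (td_target - q_sa)


# Hypothetical numbers for one transition (s, a, r, s', a')
q_sa = -10.0                            # current estimate of Q(s, a)
q_next = [-8.0, -5.0, -100.0, -9.0]     # Q(s', .) for UP, RIGHT, DOWN, LEFT
reward, alpha, gamma = -1.0, 0.5, 1.0
explored_action = 2                     # a' = DOWN, assumed picked by exploration

new_sarsa = sarsa_update(q_sa, reward, q_next, explored_action, alpha, gamma)
new_qlearn = q_learning_update(q_sa, reward, q_next, alpha, gamma)

print(f"SARSA:      Q(s, a) = {new_sarsa:.1f}")   # -55.5
print(f"Q-learning: Q(s, a) = {new_qlearn:.1f}")  # -8.0
```

With these assumed numbers, the exploratory DOWN drags the SARSA estimate far below the Q-learning one, because SARSA bootstraps from the action that was actually chosen while Q-learning bootstraps from the best available action.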

In [None]:
# Visualize the SARSA update components

fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 10)
ax.axis('off')

ax.text(7, 9.5, 'The SARSA Quintuple: S, A, R, S\', A\'', ha='center', fontsize=16, fontweight='bold')

# Timeline boxes
elements = [
    {'name': 'S', 'desc': 'State', 'x': 1.5, 'color': '#bbdefb', 'edge': '#1976d2'},
    {'name': 'A', 'desc': 'Action', 'x': 4, 'color': '#fff3e0', 'edge': '#f57c00'},
    {'name': 'R', 'desc': 'Reward', 'x': 6.5, 'color': '#c8e6c9', 'edge': '#388e3c'},
    {'name': "S'", 'desc': 'Next State', 'x': 9, 'color': '#bbdefb', 'edge': '#1976d2'},
    {'name': "A'", 'desc': 'Next Action', 'x': 11.5, 'color': '#ffcdd2', 'edge': '#d32f2f'},
]

for i, elem in enumerate(elements):
    # Box
    box = FancyBboxPatch((elem['x']-0.8, 6), 1.6, 2, boxstyle="round,pad=0.1",
                          facecolor=elem['color'], edgecolor=elem['edge'], linewidth=3)
    ax.add_patch(box)
    ax.text(elem['x'], 7.3, elem['name'], ha='center', va='center', fontsize=18, fontweight='bold')
    ax.text(elem['x'], 6.5, elem['desc'], ha='center', fontsize=10)
    
    # Arrow to next
    if i < len(elements) - 1:
        next_x = elements[i+1]['x']
        ax.annotate('', xy=(next_x-0.9, 7), xytext=(elem['x']+0.9, 7),
                    arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Highlight the key difference
highlight = FancyBboxPatch((10.5, 5.8), 2.2, 2.5, boxstyle="round,pad=0.05",
                            facecolor='none', edgecolor='#d32f2f', linewidth=3, linestyle='--')
ax.add_patch(highlight)
ax.text(11.6, 5.4, 'KEY: Uses ACTUAL\nnext action, not max!', ha='center', fontsize=10, color='#d32f2f', fontweight='bold')

# Update equation at bottom
eq_box = FancyBboxPatch((2, 1.5), 10, 2.5, boxstyle="round,pad=0.1",
                         facecolor='#f5f5f5', edgecolor='#333', linewidth=2)
ax.add_patch(eq_box)
ax.text(7, 3.3, 'SARSA Update:', ha='center', fontsize=12, fontweight='bold')
# Double quotes avoid escaping: in a raw string, \' would render a literal backslash
ax.text(7, 2.5, "Q(S, A) ← Q(S, A) + α × ( R + γ×Q(S', A') - Q(S, A) )",
        ha='center', fontsize=13, family='monospace')
ax.text(7, 1.8, 'Uses all five elements!', ha='center', fontsize=10, color='#666', style='italic')

plt.tight_layout()
plt.show()

---
## On-Policy vs Off-Policy: The Core Distinction

```
    ┌────────────────────────────────────────────────────────────────┐
    │              ON-POLICY vs OFF-POLICY LEARNING                  │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  ON-POLICY (SARSA):                                           │
    │    "Learn about the policy I'm actually following"            │
    │                                                                │
    │    • Target policy = Behavior policy (same!)                  │
    │    • Evaluates: Q^π (value of current policy)                 │
    │    • If I explore, my Q-values account for exploration        │
    │                                                                │
    │  OFF-POLICY (Q-Learning):                                     │
    │    "Learn about the optimal policy while following another"   │
    │                                                                │
    │    • Target policy ≠ Behavior policy (different!)             │
    │    • Evaluates: Q* (value of optimal policy)                  │
    │    • Q-values assume greedy actions, even if we explore       │
    │                                                                │
    │  ANALOGY:                                                     │
    │    ON-POLICY: "How good am I at basketball right now?"        │
    │    OFF-POLICY: "How good would I be if I were perfect?"       │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
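
The gap between the two targets can be computed directly. The sketch below, under assumed Q-values for a cell right above the cliff, lists the ε-greedy action probabilities and compares the greedy value (what Q-learning bootstraps from) with the average value under the ε-greedy behavior policy (what SARSA's sampled targets average to while Q is held fixed). The helper `epsilon_greedy_probs` and all numbers are illustrative assumptions.

```python
def epsilon_greedy_probs(q_values, epsilon):
    """Probability of each action under ε-greedy (ties go to the first maximum)."""
    n = len(q_values)
    greedy = max(range(n), key=lambda a: q_values[a])
    probs = [epsilon / n] * n          # every action can be picked at random
    probs[greedy] += 1.0 - epsilon     # the greedy action gets the remaining mass
    return probs


# Hypothetical Q(s', .) for UP, RIGHT, DOWN, LEFT; DOWN leads into the cliff
q_next = [-14.0, -12.0, -100.0, -14.0]
epsilon = 0.1

probs = epsilon_greedy_probs(q_next, epsilon)
greedy_value = max(q_next)
behavior_value = sum(p * q for p, q in zip(probs, q_next))

print("Action probabilities:", [round(p, 3) for p in probs])
print(f"Greedy value (off-policy target):   {greedy_value:.2f}")
print(f"ε-greedy value (on-policy average): {behavior_value:.2f}")
```

Note that with four actions and ε = 0.1, each specific random action (including DOWN) has probability ε/4 = 2.5%, and the greedy action has probability 92.5%. Even that small chance of a -100 outcome pulls the on-policy value below the greedy one.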

In [None]:
# Visualize on-policy vs off-policy

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: On-policy (SARSA)
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('ON-POLICY (SARSA)\n"Same policy for acting and learning"', 
              fontsize=14, fontweight='bold', color='#388e3c')

# Behavior policy box
bp1 = FancyBboxPatch((3, 5), 4, 2, boxstyle="round,pad=0.1",
                      facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=3)
ax1.add_patch(bp1)
ax1.text(5, 6.3, 'ε-Greedy Policy', ha='center', fontsize=12, fontweight='bold')
ax1.text(5, 5.5, '(Used for BOTH)', ha='center', fontsize=10)

# Arrows
ax1.annotate('', xy=(3.5, 4.9), xytext=(3.5, 4),
             arrowprops=dict(arrowstyle='->', lw=2, color='#388e3c'))
ax1.text(3.5, 3.5, 'Acting', ha='center', fontsize=10, color='#388e3c', fontweight='bold')

ax1.annotate('', xy=(6.5, 4.9), xytext=(6.5, 4),
             arrowprops=dict(arrowstyle='->', lw=2, color='#388e3c'))
ax1.text(6.5, 3.5, 'Learning', ha='center', fontsize=10, color='#388e3c', fontweight='bold')

# Result
ax1.text(5, 2, 'Learns Q^π (value of THIS policy)', ha='center', fontsize=11, 
         style='italic', color='#388e3c')
ax1.text(5, 1.2, 'Accounts for exploration mistakes!', ha='center', fontsize=10, color='#666')

ax1.text(5, 8.5, '✓ Safer learning\n✓ Accounts for own behavior', 
         ha='center', fontsize=10, color='#388e3c')

# Right: Off-policy (Q-Learning)
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('OFF-POLICY (Q-Learning)\n"Different policies for acting and learning"', 
              fontsize=14, fontweight='bold', color='#d32f2f')

# Behavior policy
bp2 = FancyBboxPatch((1.5, 5), 3, 2, boxstyle="round,pad=0.1",
                      facecolor='#bbdefb', edgecolor='#1976d2', linewidth=2)
ax2.add_patch(bp2)
ax2.text(3, 6.2, 'ε-Greedy', ha='center', fontsize=11, fontweight='bold')
ax2.text(3, 5.5, 'Behavior Policy', ha='center', fontsize=9)

# Target policy
tp2 = FancyBboxPatch((5.5, 5), 3, 2, boxstyle="round,pad=0.1",
                      facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=2)
ax2.add_patch(tp2)
ax2.text(7, 6.2, 'Greedy', ha='center', fontsize=11, fontweight='bold')
ax2.text(7, 5.5, 'Target Policy', ha='center', fontsize=9)

# Arrows
ax2.annotate('', xy=(3, 4.9), xytext=(3, 4),
             arrowprops=dict(arrowstyle='->', lw=2, color='#1976d2'))
ax2.text(3, 3.5, 'Acting', ha='center', fontsize=10, color='#1976d2', fontweight='bold')

ax2.annotate('', xy=(7, 4.9), xytext=(7, 4),
             arrowprops=dict(arrowstyle='->', lw=2, color='#d32f2f'))
ax2.text(7, 3.5, 'Learning', ha='center', fontsize=10, color='#d32f2f', fontweight='bold')

# Result
ax2.text(5, 2, 'Learns Q* (value of OPTIMAL policy)', ha='center', fontsize=11, 
         style='italic', color='#d32f2f')
ax2.text(5, 1.2, 'Ignores exploration in value estimates!', ha='center', fontsize=10, color='#666')

ax2.text(5, 8.5, '✓ Finds optimal policy\n✗ Risky during learning', 
         ha='center', fontsize=10, color='#d32f2f')

plt.tight_layout()
plt.show()

---
## The Cliff Walking Environment

This classic example perfectly illustrates the difference between SARSA and Q-learning:

```
    ┌────────────────────────────────────────────────────────────────┐
    │                    CLIFF WALKING                               │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  ┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐           │
    │  │   │   │   │   │   │   │   │   │   │   │   │   │           │
    │  ├───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┤           │
    │  │   │   │   │   │   │   │   │   │   │   │   │   │           │
    │  ├───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┤           │
    │  │   │   │   │   │   │   │   │   │   │   │   │   │           │
    │  ├───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┤           │
    │  │ S │ C │ C │ C │ C │ C │ C │ C │ C │ C │ C │ G │           │
    │  └───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘           │
    │                                                                │
    │  S = Start (3, 0)                                             │
    │  G = Goal (3, 11)                                             │
    │  C = Cliff (-100 reward, sent back to start!)                 │
    │  Regular step = -1 reward                                     │
    │                                                                │
    │  THE DILEMMA:                                                 │
    │    • Shortest path: Along the row just above the cliff        │
    │    • Safest path: Up, across the top, then down               │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
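
The trade-off in the dilemma can be checked with a little arithmetic. The sketch below is a standalone helper (the name `path_return` is an assumption, separate from the environment class defined next) that replays a fixed action sequence on the grid described above and sums the rewards.

```python
def path_return(actions, start=(3, 0), goal=(3, 11), height=4, width=12):
    """Total reward for a fixed action sequence (0=UP, 1=RIGHT, 2=DOWN, 3=LEFT)."""
    cliff = {(3, c) for c in range(1, 11)}
    moves = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}
    pos, total = start, 0
    for a in actions:
        dr, dc = moves[a]
        # Clamp to the grid, as walls do in the environment
        row = min(max(pos[0] + dr, 0), height - 1)
        col = min(max(pos[1] + dc, 0), width - 1)
        pos = (row, col)
        if pos in cliff:
            total += -100      # fell off: big penalty, back to start
            pos = start
        else:
            total += -1        # regular step
        if pos == goal:
            break
    return total, pos


UP, RIGHT, DOWN = 0, 1, 2
edge_path = [UP] + [RIGHT] * 11 + [DOWN]           # hugs the cliff: 13 steps
top_path = [UP] * 3 + [RIGHT] * 11 + [DOWN] * 3    # along the top row: 17 steps
slip_path = [UP, RIGHT, DOWN] + edge_path          # one fall, then the edge path

print("Edge path:", path_return(edge_path))   # (-13, (3, 11))
print("Top path: ", path_return(top_path))    # (-17, (3, 11))
print("One slip: ", path_return(slip_path))   # (-115, (3, 11))
```

The edge path is only 4 steps shorter than the top path, yet a single slip costs over 100. That asymmetry is what makes the choice between the two routes interesting once exploration is involved.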

In [None]:
class CliffWalking:
    """
    The Cliff Walking environment.
    
    A 4x12 grid where the agent must go from Start to Goal.
    The bottom row (except Start and Goal) is a cliff.
    Falling off the cliff gives -100 reward and returns to start.
    """
    
    def __init__(self):
        self.height = 4
        self.width = 12
        self.start = (3, 0)    # Bottom left
        self.goal = (3, 11)    # Bottom right
        self.cliff = [(3, i) for i in range(1, 11)]  # Bottom row except start/goal
        
        self.action_names = ['UP', 'RIGHT', 'DOWN', 'LEFT']
        self.action_symbols = ['↑', '→', '↓', '←']
        self.reset()
    
    def reset(self):
        """Reset to start position."""
        self.pos = self.start
        return self.pos
    
    def step(self, action):
        """Take an action and return (next_state, reward, done)."""
        row, col = self.pos
        
        # Move
        if action == 0: row = max(0, row - 1)              # UP
        elif action == 1: col = min(self.width - 1, col + 1)  # RIGHT
        elif action == 2: row = min(self.height - 1, row + 1) # DOWN
        elif action == 3: col = max(0, col - 1)            # LEFT
        
        self.pos = (row, col)
        
        # Check if fell off cliff
        if self.pos in self.cliff:
            self.pos = self.start  # Sent back to start!
            return self.pos, -100, False
        
        # Check if reached goal
        if self.pos == self.goal:
            return self.pos, -1, True
        
        return self.pos, -1, False


# Visualize the cliff walking environment
env = CliffWalking()

fig, ax = plt.subplots(figsize=(14, 5))

# Draw grid
for row in range(env.height):
    for col in range(env.width):
        y = env.height - 1 - row  # Flip y for visualization
        
        if (row, col) == env.start:
            color = '#bbdefb'
            label = 'S'
        elif (row, col) == env.goal:
            color = '#c8e6c9'
            label = 'G'
        elif (row, col) in env.cliff:
            color = '#795548'
            label = 'C'
        else:
            color = 'white'
            label = ''
        
        rect = Rectangle((col, y), 1, 1, facecolor=color, edgecolor='black', linewidth=1)
        ax.add_patch(rect)
        
        text_color = 'white' if (row, col) in env.cliff else 'black'
        ax.text(col + 0.5, y + 0.5, label, ha='center', va='center', 
                fontsize=12, fontweight='bold', color=text_color)

ax.set_xlim(0, env.width)
ax.set_ylim(0, env.height)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('Cliff Walking Environment\n(S = Start, G = Goal, C = Cliff)', fontsize=14, fontweight='bold')

# Legend
ax.text(6, -0.5, 'Cliff: -100 reward, back to start!   |   Regular step: -1 reward', 
        ha='center', fontsize=11, color='#666')

plt.tight_layout()
plt.show()

print("\nThe agent must navigate from S to G without falling off the cliff!")
print("Two possible strategies:")
print("  1. Short path: Along the edge (risky with exploration)")
print("  2. Safe path: Up, across, down (longer but safer)")

---
## Implementing SARSA

```
    ┌────────────────────────────────────────────────────────────────┐
    │                    SARSA ALGORITHM                             │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Initialize Q(s, a) = 0 for all s, a                          │
    │                                                                │
    │  For each episode:                                            │
    │      s ← initial state                                        │
    │      a ← ε-greedy action for s                                │
    │                                                                │
    │      For each step in episode:                                │
    │          Take action a, observe r, s'                         │
    │          a' ← ε-greedy action for s'  ← KEY DIFFERENCE!       │
    │                                                                │
    │          # SARSA update (uses actual next action a')          │
    │          Q(s, a) ← Q(s, a) + α × (r + γ×Q(s', a') - Q(s, a))  │
    │                                                                │
    │          s ← s'                                               │
    │          a ← a'                                               │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
def sarsa(env, n_episodes=500, alpha=0.5, gamma=1.0, epsilon=0.1, verbose=False):
    """
    SARSA: On-policy TD control.
    
    The key difference from Q-learning:
    We use the ACTUAL next action a' (from ε-greedy policy),
    not the max action.
    
    Args:
        env: The environment
        n_episodes: Number of episodes to train
        alpha: Learning rate
        gamma: Discount factor
        epsilon: Exploration rate
        verbose: Print progress
    
    Returns:
        Q: Learned action-value function
        rewards_history: Total reward per episode
        paths: Saved paths for visualization
    """
    # Initialize Q-values
    Q = defaultdict(lambda: np.zeros(4))
    rewards_history = []
    paths = []
    
    def epsilon_greedy(state):
        """Choose action using ε-greedy policy."""
        if np.random.random() < epsilon:
            return np.random.randint(0, 4)  # Explore
        return np.argmax(Q[state])  # Exploit
    
    for episode in range(n_episodes):
        # Initialize
        state = env.reset()
        action = epsilon_greedy(state)  # Choose FIRST action
        total_reward = 0
        path = [state]
        
        for step in range(200):  # Max steps per episode
            # Take action, observe reward and next state
            next_state, reward, done = env.step(action)
            total_reward += reward
            path.append(next_state)
            
            # ========================================
            # KEY: Choose next action BEFORE update!
            # This is what makes SARSA "on-policy"
            # ========================================
            next_action = epsilon_greedy(next_state)
            
            # ========================================
            # SARSA UPDATE: uses Q(s', a') not max!
            # ========================================
            td_target = reward + gamma * Q[next_state][next_action]
            td_error = td_target - Q[state][action]
            Q[state][action] += alpha * td_error
            
            # Move to next state-action pair
            state = next_state
            action = next_action  # Use the SAME action we chose!
            
            if done:
                break
        
        rewards_history.append(total_reward)
        
        # Save some paths for visualization
        if episode in [0, n_episodes//4, n_episodes//2, n_episodes-1]:
            paths.append((episode, path.copy()))
        
        if verbose and (episode + 1) % 100 == 0:
            avg_reward = np.mean(rewards_history[-50:])
            print(f"Episode {episode+1:4d} | Avg Reward (last 50): {avg_reward:.1f}")
    
    return dict(Q), rewards_history, paths


def q_learning(env, n_episodes=500, alpha=0.5, gamma=1.0, epsilon=0.1, verbose=False):
    """
    Q-Learning: Off-policy TD control.
    
    Uses max Q(s', a') instead of actual next action.
    """
    Q = defaultdict(lambda: np.zeros(4))
    rewards_history = []
    paths = []
    
    for episode in range(n_episodes):
        state = env.reset()
        total_reward = 0
        path = [state]
        
        for step in range(200):
            # ε-greedy action selection
            if np.random.random() < epsilon:
                action = np.random.randint(0, 4)
            else:
                action = np.argmax(Q[state])
            
            next_state, reward, done = env.step(action)
            total_reward += reward
            path.append(next_state)
            
            # ========================================
            # Q-LEARNING UPDATE: uses MAX, not actual!
            # ========================================
            td_target = reward + gamma * np.max(Q[next_state])
            td_error = td_target - Q[state][action]
            Q[state][action] += alpha * td_error
            
            state = next_state
            
            if done:
                break
        
        rewards_history.append(total_reward)
        
        if episode in [0, n_episodes//4, n_episodes//2, n_episodes-1]:
            paths.append((episode, path.copy()))
        
        if verbose and (episode + 1) % 100 == 0:
            avg_reward = np.mean(rewards_history[-50:])
            print(f"Episode {episode+1:4d} | Avg Reward (last 50): {avg_reward:.1f}")
    
    return dict(Q), rewards_history, paths


# Train both algorithms
print("TRAINING ON CLIFF WALKING")
print("="*60)

env = CliffWalking()

print("\nTraining SARSA...")
Q_sarsa, rewards_sarsa, paths_sarsa = sarsa(env, n_episodes=500, verbose=True)

print("\nTraining Q-Learning...")
Q_qlearn, rewards_qlearn, paths_qlearn = q_learning(env, n_episodes=500, verbose=True)

print("\n" + "="*60)
print("TRAINING COMPLETE!")

In [None]:
# Compare learning curves

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Raw rewards
ax1 = axes[0]
ax1.plot(rewards_sarsa, alpha=0.3, color='#388e3c', label='SARSA (raw)')
ax1.plot(rewards_qlearn, alpha=0.3, color='#d32f2f', label='Q-Learning (raw)')

# Smoothed
window = 20
sarsa_smooth = np.convolve(rewards_sarsa, np.ones(window)/window, mode='valid')
qlearn_smooth = np.convolve(rewards_qlearn, np.ones(window)/window, mode='valid')

ax1.plot(range(window-1, len(rewards_sarsa)), sarsa_smooth, 
         color='#388e3c', linewidth=3, label='SARSA (smoothed)')
ax1.plot(range(window-1, len(rewards_qlearn)), qlearn_smooth, 
         color='#d32f2f', linewidth=3, label='Q-Learning (smoothed)')

ax1.axhline(y=-13, color='gray', linestyle='--', alpha=0.5, label='Optimal return (-13, no falls)')

ax1.set_xlabel('Episode', fontsize=12)
ax1.set_ylabel('Total Reward per Episode', fontsize=12)
ax1.set_title('Learning Curves: SARSA vs Q-Learning', fontsize=14, fontweight='bold')
ax1.legend(loc='lower right')
ax1.grid(True, alpha=0.3)

# Right: Focus on later episodes (after learning)
ax2 = axes[1]
late_sarsa = rewards_sarsa[-100:]
late_qlearn = rewards_qlearn[-100:]

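ax2.boxplot([late_sarsa, late_qlearn])
# Set tick labels separately; boxplot's `labels` keyword is deprecated in newer matplotlib
ax2.set_xticklabels(['SARSA', 'Q-Learning'])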
ax2.axhline(y=-13, color='gray', linestyle='--', alpha=0.5, label='Optimal')

ax2.set_ylabel('Total Reward (last 100 episodes)', fontsize=12)
ax2.set_title('Reward Distribution After Learning', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='y')

# Add mean annotations
ax2.text(1, np.mean(late_sarsa) + 3, f'Mean: {np.mean(late_sarsa):.1f}', 
         ha='center', fontsize=10, color='#388e3c')
ax2.text(2, np.mean(late_qlearn) + 3, f'Mean: {np.mean(late_qlearn):.1f}', 
         ha='center', fontsize=10, color='#d32f2f')

plt.tight_layout()
plt.show()

print("\nOBSERVATIONS:")
print("-"*60)
print(f"SARSA average reward (last 100): {np.mean(late_sarsa):.1f} (std: {np.std(late_sarsa):.1f})")
print(f"Q-Learning average reward (last 100): {np.mean(late_qlearn):.1f} (std: {np.std(late_qlearn):.1f})")
print("\nQ-Learning typically shows MORE VARIANCE (occasional cliff falls while exploring)")
print("SARSA is typically MORE CONSISTENT (its path keeps away from the cliff)")

In [None]:
# Visualize the learned paths!

def extract_greedy_path(Q, env):
    """Extract the greedy path from Q-values."""
    path = []
    state = env.reset()
    path.append(state)
    
    for _ in range(50):  # Max steps
        action = np.argmax(Q.get(state, np.zeros(4)))
        next_state, _, done = env.step(action)
        path.append(next_state)
        state = next_state
        if done:
            break
    
    return path


# Get greedy paths
sarsa_path = extract_greedy_path(Q_sarsa, CliffWalking())
qlearn_path = extract_greedy_path(Q_qlearn, CliffWalking())

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

for ax, path, title, color in [(axes[0], sarsa_path, 'SARSA: Safe Path', '#388e3c'),
                                 (axes[1], qlearn_path, 'Q-Learning: Optimal Path', '#d32f2f')]:
    # Draw grid
    for row in range(env.height):
        for col in range(env.width):
            y = env.height - 1 - row
            
            if (row, col) == env.start:
                cell_color = '#bbdefb'
            elif (row, col) == env.goal:
                cell_color = '#c8e6c9'
            elif (row, col) in env.cliff:
                cell_color = '#795548'
            else:
                cell_color = 'white'
            
            rect = Rectangle((col, y), 1, 1, facecolor=cell_color, 
                              edgecolor='black', linewidth=1)
            ax.add_patch(rect)
    
    # Draw path
    path_x = [p[1] + 0.5 for p in path]
    path_y = [env.height - 1 - p[0] + 0.5 for p in path]
    
    ax.plot(path_x, path_y, color=color, linewidth=4, marker='o', markersize=8, 
            alpha=0.8, label='Greedy path')
    
    # Start and end markers
    ax.scatter([path_x[0]], [path_y[0]], s=200, c='blue', marker='s', zorder=5, label='Start')
    ax.scatter([path_x[-1]], [path_y[-1]], s=200, c='green', marker='*', zorder=5, label='Goal')
    
    ax.set_xlim(0, env.width)
    ax.set_ylim(0, env.height)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title(f'{title}\n(Path length: {len(path)-1} steps)', fontsize=14, fontweight='bold', color=color)

plt.suptitle('Learned Paths: SARSA vs Q-Learning', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print("\nPATH ANALYSIS:")
print("-"*60)
print(f"SARSA path length: {len(sarsa_path)-1} steps (typically the longer, safer route)")
print(f"Q-Learning path length: {len(qlearn_path)-1} steps (typically the 13-step route along the edge)")
print("\nBut Q-Learning's 'optimal' path is only optimal if you NEVER make mistakes!")
print("With ε-greedy exploration, Q-Learning occasionally falls off the cliff.")

In [None]:
# Visualize Q-values to understand the learned policies

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

for ax, Q, title in [(axes[0], Q_sarsa, 'SARSA: Avoids cliff edge'),
                     (axes[1], Q_qlearn, 'Q-Learning: Edge is optimal')]:
    
    # Draw grid with values and arrows
    for row in range(env.height):
        for col in range(env.width):
            y = env.height - 1 - row
            
            if (row, col) == env.start:
                cell_color = '#bbdefb'
            elif (row, col) == env.goal:
                cell_color = '#c8e6c9'
            elif (row, col) in env.cliff:
                cell_color = '#795548'
            else:
                cell_color = 'white'
            
            rect = Rectangle((col, y), 1, 1, facecolor=cell_color, 
                              edgecolor='black', linewidth=1)
            ax.add_patch(rect)
            
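            # Draw arrow for best action (skip cliff, goal, and never-visited cells,
            # whose all-zero Q-values would otherwise show a spurious UP arrow)
            if (row, col) not in env.cliff and (row, col) != env.goal and (row, col) in Q: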
                q_vals = Q.get((row, col), np.zeros(4))
                best_action = np.argmax(q_vals)
                
                # Arrow directions
                dx, dy = [(0, 0.35), (0.35, 0), (0, -0.35), (-0.35, 0)][best_action]
                
                ax.arrow(col + 0.5 - dx/2, y + 0.5 - dy/2, dx, dy,
                         head_width=0.15, head_length=0.1, 
                         fc='black', ec='black', linewidth=1)
    
    ax.set_xlim(0, env.width)
    ax.set_ylim(0, env.height)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title(title, fontsize=14, fontweight='bold')

plt.suptitle('Learned Policies (Arrows = Best Action)', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print("\nNOTICE:")
print("-"*60)
print("SARSA: Arrows point UP early (away from cliff edge)")
print("Q-Learning: Arrows point RIGHT immediately (along cliff edge)")
print("\nThis is the core difference in learned behavior!")

---
## Why SARSA is Safer: A Deeper Look

```
    ┌────────────────────────────────────────────────────────────────┐
    │              WHY SARSA AVOIDS THE CLIFF                        │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Consider being one step from the cliff edge:                 │
    │                                                                │
    │      ┌───┬───┐                                                │
    │      │ A │   │  A = Agent                                     │
    │      ├───┼───┤                                                │
    │      │ C │ C │  C = Cliff                                     │
    │      └───┴───┘                                                │
    │                                                                │
    │  Q-LEARNING thinks:                                           │
    │    "If I move right (best action), Q(right) is high"          │
    │    "I'll learn that being here is GOOD"                       │
    │    → Ignores the chance of a random DOWN while exploring!     │
    │                                                                │
    │  SARSA thinks:                                                │
    │    "My next action will be ε-greedy"                          │
    │    "10% of the time, I'll randomly go DOWN (into cliff!)"     │
    │    → Q(s, a) includes the ACTUAL risk of exploration          │
    │    → Being near the cliff edge has LOWER value                │
    │    → I learn to stay away from dangerous positions!           │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
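
To put rough numbers on the box above, here is a small sketch that compares the two bootstrap targets for one transition into a cliff-edge state. The Q-values, reward, γ, and ε are made up for illustration, not taken from the agents trained in this notebook. The helper `epsilon_greedy_probs` is hypothetical and assumes exploration picks uniformly among all four actions. SARSA itself samples a single next action, so the second number is the average of its target over many such samples.

```python
import numpy as np

def epsilon_greedy_probs(q_values, epsilon):
    """Action probabilities under ε-greedy (assumes uniform random exploration)."""
    n = len(q_values)
    probs = np.full(n, epsilon / n)              # every action can be picked at random
    probs[np.argmax(q_values)] += 1.0 - epsilon  # greedy action gets the remaining mass
    return probs

# Hypothetical Q-values for the next state s' on the cliff edge.
# Order: UP, RIGHT, DOWN, LEFT (DOWN falls into the cliff).
q_next = np.array([-14.0, -12.0, -100.0, -14.0])
reward, gamma, epsilon = -1.0, 1.0, 0.1

# Q-learning target: bootstrap from the best next action
q_learning_target = reward + gamma * np.max(q_next)

# Average SARSA target: bootstrap from whatever ε-greedy actually picks
probs = epsilon_greedy_probs(q_next, epsilon)
expected_sarsa_target = reward + gamma * np.dot(probs, q_next)

print(f"Q-learning target:    {q_learning_target:.2f}")
print(f"Average SARSA target: {expected_sarsa_target:.2f}")
```

With these made-up numbers the average SARSA target is lower (about -15.3 versus -13.0), because the small chance of a random DOWN drags the estimate down.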

In [None]:
# Demonstrate the difference in Q-values near the cliff

print("Q-VALUES NEAR THE CLIFF")
print("="*60)

# Position (2, 1) is one row above the cliff at (3, 1)
critical_state = (2, 1)  # One step above cliff

print(f"\nState {critical_state} (one step above cliff):")
print("-"*40)

print("\nSARSA Q-values:")
sarsa_q = Q_sarsa.get(critical_state, np.zeros(4))
for i, name in enumerate(['UP', 'RIGHT', 'DOWN', 'LEFT']):
    danger = " ← DANGER!" if name == 'DOWN' else ""
    print(f"  {name}: {sarsa_q[i]:.2f}{danger}")

print("\nQ-Learning Q-values:")
qlearn_q = Q_qlearn.get(critical_state, np.zeros(4))
for i, name in enumerate(['UP', 'RIGHT', 'DOWN', 'LEFT']):
    danger = " ← DANGER!" if name == 'DOWN' else ""
    print(f"  {name}: {qlearn_q[i]:.2f}{danger}")

print("\n" + "="*60)
print("ANALYSIS:")
print("-"*60)
print(f"SARSA Q(DOWN): {sarsa_q[2]:.2f} (very negative - cliff death!)")
print(f"Q-Learning Q(DOWN): {qlearn_q[2]:.2f}")
print("\nSARSA learns that DOWN is terrible because with ε-greedy,")
print("there's always a chance of accidentally going DOWN!")
print("\nQ-Learning ignores this risk because it assumes max action.")

In [None]:
# Effect of epsilon on the difference

print("EFFECT OF EXPLORATION RATE (ε)")
print("="*60)
print("\nWith higher ε, SARSA becomes MORE cautious!")
print("(More exploration = more chance of accidents)")
print("-"*60)

epsilons = [0.0, 0.1, 0.2, 0.3]
sarsa_results = []
qlearn_results = []

for eps in epsilons:
    env = CliffWalking()
    
    _, rewards_s, _ = sarsa(env, n_episodes=300, epsilon=eps)
    _, rewards_q, _ = q_learning(env, n_episodes=300, epsilon=eps)
    
    sarsa_results.append(np.mean(rewards_s[-50:]))
    qlearn_results.append(np.mean(rewards_q[-50:]))
    
    print(f"ε = {eps:.1f}: SARSA avg = {sarsa_results[-1]:.1f}, Q-Learning avg = {qlearn_results[-1]:.1f}")

# Visualize
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(epsilons))
width = 0.35

bars1 = ax.bar(x - width/2, sarsa_results, width, label='SARSA', color='#388e3c')
bars2 = ax.bar(x + width/2, qlearn_results, width, label='Q-Learning', color='#d32f2f')

ax.axhline(y=-13, color='gray', linestyle='--', alpha=0.5, label='Optimal (no exploration)')

ax.set_xlabel('Epsilon (ε)', fontsize=12)
ax.set_ylabel('Average Reward', fontsize=12)
ax.set_title('Effect of Exploration Rate on Performance', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels([f'ε={e}' for e in epsilons])
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\n" + "="*60)
print("INSIGHT:")
print("-"*60)
print("• With ε=0 (no exploration), both are similar")
print("• As ε increases, Q-Learning suffers more cliff falls!")
print("• SARSA adapts by taking safer paths as ε increases")
print("• Q-Learning keeps learning the 'optimal' edge path")
print("="*60)

---
## When to Use SARSA vs Q-Learning

```
    ┌────────────────────────────────────────────────────────────────┐
    │          SARSA vs Q-LEARNING: WHEN TO USE EACH                │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  USE SARSA WHEN:                                              │
    │    ✓ Exploration mistakes are COSTLY                          │
    │      (robotics, real-world systems, cliff-like scenarios)     │
    │    ✓ You care about the policy you're ACTUALLY following      │
    │    ✓ Safety during learning is important                      │
    │    ✓ You'll continue using ε-greedy after learning            │
    │                                                                │
    │  USE Q-LEARNING WHEN:                                         │
    │    ✓ You want the truly OPTIMAL policy                        │
    │    ✓ Exploration mistakes are cheap (simulation)              │
    │    ✓ You'll use greedy policy after learning (ε → 0)          │
    │    ✓ You have a separate exploration strategy                 │
    │                                                                │
    │  EXAMPLES:                                                    │
    │                                                                │
    │    SARSA: Robot navigation (falling is expensive!)           │
    │           Real-world trading (losses are real!)              │
    │           Medical treatment (mistakes can harm patients)      │
    │                                                                │
    │    Q-Learning: Video game AI (can reset)                     │
    │                Simulation environments (no real cost)         │
    │                When you'll stop exploring after training      │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
# Summary comparison table

fig, ax = plt.subplots(figsize=(12, 8))
ax.set_xlim(0, 12)
ax.set_ylim(0, 10)
ax.axis('off')

ax.text(6, 9.5, 'SARSA vs Q-Learning: Complete Comparison', ha='center', fontsize=18, fontweight='bold')

# Table headers
headers = ['Property', 'SARSA', 'Q-Learning']
header_x = [1.5, 5, 9]
for x, h in zip(header_x, headers):
    ax.text(x, 8.5, h, ha='center', fontsize=12, fontweight='bold')

ax.axhline(y=8.2, xmin=0.05, xmax=0.95, color='black', linewidth=2)

# Table rows
rows = [
    ('Type', 'On-policy', 'Off-policy'),
    ('Update uses', "Q(s', a')", "max over a' of Q(s', a')"),
    ('Learns', 'Q^π (current policy)', 'Q* (optimal policy)'),
    ('Safety', 'Safer ✓', 'Riskier'),
    ('Optimal policy?', 'Not always', 'Yes ✓'),
    ('Best for', 'Costly mistakes', 'Cheap simulation'),
]

for i, (prop, sarsa_val, ql_val) in enumerate(rows):
    y = 7.5 - i * 1.0
    ax.text(1.5, y, prop, ha='center', fontsize=11)
    
    sarsa_color = '#388e3c' if '✓' in sarsa_val else 'black'
    ql_color = '#388e3c' if '✓' in ql_val else 'black'
    
    ax.text(5, y, sarsa_val, ha='center', fontsize=11, color=sarsa_color)
    ax.text(9, y, ql_val, ha='center', fontsize=11, color=ql_color)

# Summary boxes
sarsa_box = FancyBboxPatch((3, 0.3), 3, 1.2, boxstyle="round,pad=0.1",
                           facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
ax.add_patch(sarsa_box)
ax.text(4.5, 1.1, 'SARSA', ha='center', fontsize=11, fontweight='bold', color='#388e3c')
ax.text(4.5, 0.6, 'Safe & Practical', ha='center', fontsize=10)

ql_box = FancyBboxPatch((7, 0.3), 3, 1.2, boxstyle="round,pad=0.1",
                         facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=2)
ax.add_patch(ql_box)
ax.text(8.5, 1.1, 'Q-Learning', ha='center', fontsize=11, fontweight='bold', color='#d32f2f')
ax.text(8.5, 0.6, 'Optimal & Bold', ha='center', fontsize=10)

plt.tight_layout()
plt.show()

---
## Summary: Key Takeaways

### The SARSA Update

```
Q(S, A) ← Q(S, A) + α × (R + γ×Q(S', A') - Q(S, A))
                                     └── actual next action!
```
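
As a minimal sketch, the update rule can be written as a single function. The name `sarsa_update`, its default `alpha` and `gamma`, and the dictionary-of-arrays layout for `Q` are assumptions made for this example; this is not the `sarsa` training function used earlier in the notebook.

```python
import numpy as np
from collections import defaultdict

def sarsa_update(Q, s, a, r, s_next, a_next, alpha=0.1, gamma=1.0, done=False):
    """Apply one SARSA update to Q in place and return the TD error."""
    # Terminal states have no future value, so the bootstrap term is dropped
    bootstrap = 0.0 if done else Q[s_next][a_next]
    td_error = r + gamma * bootstrap - Q[s][a]
    Q[s][a] += alpha * td_error
    return td_error

# Hypothetical example: 4 actions per state, all Q-values start at zero
Q = defaultdict(lambda: np.zeros(4))
Q[(2, 2)][1] = -10.0  # pretend we already have an estimate for the next pair

td = sarsa_update(Q, s=(2, 1), a=1, r=-1.0, s_next=(2, 2), a_next=1)
print(f"TD error: {td:.2f}, updated Q: {Q[(2, 1)][1]:.2f}")
```

Note that `a_next` must already be chosen by the behavior policy before the update runs; that is the practical difference from Q-learning, which would take the maximum over `Q[s_next]` instead.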

### Key Concepts

| Concept | Description |
|---------|-------------|
| **On-Policy** | Learns about the policy it's following |
| **S-A-R-S'-A'** | Uses all 5 elements: State, Action, Reward, next State, next Action |
| **Safety** | Accounts for exploration in value estimates |

### SARSA vs Q-Learning

| | SARSA | Q-Learning |
|-|-------|------------|
| Type | On-policy | Off-policy |
| Update | Q(s', a') | max over a' of Q(s', a') |
| Learns | Q^π | Q* |
| Safety | Higher | Lower |
| Use when | Mistakes costly | Simulation OK |

---
## Test Your Understanding

**1. What does SARSA stand for and why?**
<details>
<summary>Click to reveal answer</summary>
SARSA stands for State-Action-Reward-State'-Action'. It's named after the quintuple (S, A, R, S', A') that it uses for each update. The key is that it needs the next action A' before making the update, unlike Q-learning, which uses the maximum over next actions and so does not need to know A'.
</details>

**2. What's the difference between on-policy and off-policy learning?**
<details>
<summary>Click to reveal answer</summary>
On-policy (SARSA): Learns about the SAME policy it uses for acting. The behavior policy = target policy.

Off-policy (Q-Learning): Learns about a DIFFERENT policy than it follows. Behavior policy (ε-greedy) ≠ Target policy (greedy).
</details>

**3. Why does SARSA take the safer path in cliff walking?**
<details>
<summary>Click to reveal answer</summary>
SARSA's Q-values include the actual risk of exploration. Near the cliff, an ε-greedy agent takes a random action with probability ε, and some of those random actions step off the edge. Because SARSA bootstraps from the action it actually takes next, states near the cliff end up with lower values. SARSA learns to stay away!

Q-learning assumes perfect (greedy) future behavior, so it thinks the edge is fine.
</details>

**4. When should you use SARSA instead of Q-learning?**
<details>
<summary>Click to reveal answer</summary>
Use SARSA when:
- Mistakes are costly (real robots, trading, medical systems)
- You'll continue exploring after learning
- Safety during learning matters more than finding the absolute optimal policy

Use Q-learning when:
- You're in simulation where mistakes are free
- You'll use a greedy policy after training
- You want the truly optimal policy
</details>

**5. What happens to SARSA's path as ε increases?**
<details>
<summary>Click to reveal answer</summary>
As ε increases, SARSA becomes MORE cautious. Higher ε means more random actions, which means higher risk of accidentally falling off the cliff. SARSA's Q-values reflect this higher risk, so it learns to stay even further from dangerous areas.

Q-learning's learned policy doesn't change much with ε because it always assumes greedy future behavior.
</details>

---
## What's Next?

Excellent work! You now understand SARSA and the crucial distinction between on-policy and off-policy learning.

In the final notebook of this section, we'll compare ALL the algorithms we've learned!

**Continue to:** [Notebook 5: Comparing Algorithms](05_comparing_algorithms.ipynb)

---

*SARSA: "I know I'm not perfect, so I'll plan accordingly."*