# Part 5: Temporal Difference Learning

In this notebook, we'll learn **Temporal Difference (TD)** methods, a central class of model-free RL algorithms that combines ideas from Monte Carlo and Dynamic Programming.

## Recap from Notebook 04
- **Monte Carlo methods** learn from complete episodes without needing the MDP model
- **First-visit MC** estimates values by averaging returns from sampled episodes
- **MC Control** (GLIE MC) finds optimal policies using ε-greedy exploration
- **High variance**: Returns have high variance due to stochastic episodes
- **Limitation**: Must wait until episode ends — cannot learn online during episodes

## What This Notebook Covers
- TD(0) prediction: online learning from single steps
- The bias-variance tradeoff (TD vs MC)
- SARSA (On-policy TD control)
- Q-Learning (Off-policy TD control)
- Comparison of on-policy vs off-policy learning

## What This Notebook Does NOT Cover

| Topic | Why Not Here | How It Differs From What We Cover |
|-------|--------------|-----------------------------------|
| **Policy gradient methods** | TD learning improves value functions. Policy gradients directly optimize the policy using gradient ascent, which is a different approach. | We learn Q(s,a) and derive the policy greedily. Policy gradient methods parameterize π(a\|s;θ) directly and use ∇_θ J(θ) to improve it. This suits continuous actions better but relies on sampled gradient estimates, which are often high-variance. |
| **Actor-critic methods** | Actor-critic combines value functions with policy gradients. We focus on pure value-based TD methods first. | We use either on-policy (SARSA) or off-policy (Q-learning) value learning. Actor-critic maintains both a policy (actor) and value function (critic) simultaneously — more complex but often more stable. |
| **Deep reinforcement learning** | We use tabular Q-tables. Deep RL uses neural networks for function approximation, adding convergence challenges. | In this notebook, we maintain Q[s,a] tables for all pairs. Deep Q-Networks (DQN) use neural networks Q(s,a;θ) with experience replay and target networks — essential for complex domains. |
| **Eligibility traces (TD(λ))** | We focus on one-step TD (TD(0)). Eligibility traces generalize this to n-step returns and give a continuous spectrum between TD(0) and MC. | We update using just the next reward and value: R + γV(s'). TD(λ) uses a λ-weighted average of all n-step returns, which bridges MC and TD but requires additional bookkeeping. |

## Preview: Bootstrapping - The Key Insight

Temporal Difference learning introduces a powerful concept called **bootstrapping**:

**Monte Carlo** waits until episode end:
$$V(S_t) \leftarrow V(S_t) + \alpha [G_t - V(S_t)]$$
where $G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \ldots$ (complete return)

**TD Learning** updates immediately using current estimate:
$$V(S_t) \leftarrow V(S_t) + \alpha [R_{t+1} + \gamma V(S_{t+1}) - V(S_t)]$$

The magic: $R_{t+1} + \gamma V(S_{t+1})$ is called the **TD target** — it's an estimate of the return using our current value function! This allows us to:
- Learn **online** (every step, not just at episode end)
- Work with **continuing tasks** (don't need episodes to terminate)
- Reduce **variance** (the target depends on one reward and one transition rather than a full trajectory), at the cost of some bias

## How to Read This Notebook
1. **Theory and algorithms**: Each section introduces TD prediction and control algorithms with mathematical foundations
2. **Bootstrapping intuition**: Understand the key difference between MC (actual returns) and TD (estimated returns)
3. **Step-by-step implementations**: Run code cells to see TD(0), SARSA, and Q-learning solve FrozenLake
4. **On-policy vs off-policy**: Compare how SARSA learns about its behavior policy while Q-learning learns the optimal policy
5. **Visualizations**: Observe how policies and values evolve during online learning

Let's begin!

## Prerequisites
- Understanding of MDPs and Bellman equations (Notebooks 01-02)
- Monte Carlo methods (Notebook 04)

## Setup

In [None]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import time

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")
np.random.seed(42)

print("Setup complete!")

In [None]:
# Create environment and helper variables
env = gym.make("FrozenLake-v1", is_slippery=True)

n_states = env.observation_space.n
n_actions = env.action_space.n
action_names = ['LEFT', 'DOWN', 'RIGHT', 'UP']
action_arrows = ['←', '↓', '→', '↑']

print("FrozenLake Environment")
print("=" * 40)
print(f"States: {n_states}")
print(f"Actions: {n_actions}")

In [None]:
# Visualization helper functions
def plot_value_function(V, title="Value Function", ax=None):
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 6))
    
    desc = env.unwrapped.desc.astype(str)
    nrow, ncol = desc.shape
    V_grid = V.reshape(nrow, ncol)
    
    im = ax.imshow(V_grid, cmap='RdYlGn', vmin=0, vmax=max(V.max(), 0.01))
    plt.colorbar(im, ax=ax, shrink=0.8)
    
    for i in range(nrow):
        for j in range(ncol):
            state = i * ncol + j
            cell = desc[i, j]
            color = 'white' if V_grid[i, j] < V.max() / 2 else 'black'
            ax.text(j, i, f'{cell}\n{V[state]:.3f}', ha='center', va='center',
                   fontsize=9, color=color)
    
    ax.set_xticks(range(ncol))
    ax.set_yticks(range(nrow))
    ax.set_title(title)
    return ax

def plot_policy(Q, title="Policy", ax=None):
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 6))
    
    desc = env.unwrapped.desc.astype(str)
    nrow, ncol = desc.shape
    colors = {'S': 'lightblue', 'F': 'white', 'H': 'lightcoral', 'G': 'lightgreen'}
    
    for i in range(nrow):
        for j in range(ncol):
            state = i * ncol + j
            cell = desc[i, j]
            
            rect = plt.Rectangle((j, nrow-1-i), 1, 1, fill=True,
                                 facecolor=colors.get(cell, 'white'), edgecolor='black')
            ax.add_patch(rect)
            
            best_action = np.argmax(Q[state])
            
            if cell not in ['H', 'G']:
                ax.text(j + 0.5, nrow - 1 - i + 0.5, 
                       f'{cell}\n{action_arrows[best_action]}',
                       ha='center', va='center', fontsize=14, fontweight='bold')
            else:
                ax.text(j + 0.5, nrow - 1 - i + 0.5, cell,
                       ha='center', va='center', fontsize=14, fontweight='bold')
    
    ax.set_xlim(0, ncol)
    ax.set_ylim(0, nrow)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title(title)
    return ax

print("Visualization functions ready!")

---
# 1. What is Temporal Difference Learning?

**Temporal Difference (TD)** learning combines ideas from:
- **Monte Carlo**: Learn from experience (model-free)
- **Dynamic Programming**: Bootstrap (update estimates based on other estimates)

## Key Insight: Bootstrapping

**Monte Carlo** waits until the end of episode to update:
$$V(S_t) \leftarrow V(S_t) + \alpha (G_t - V(S_t))$$

where $G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \ldots$ (full return)

**TD** updates immediately using estimated return:
$$V(S_t) \leftarrow V(S_t) + \alpha (R_{t+1} + \gamma V(S_{t+1}) - V(S_t))$$

The term $R_{t+1} + \gamma V(S_{t+1})$ is called the **TD target**.

The difference $(R_{t+1} + \gamma V(S_{t+1}) - V(S_t))$ is the **TD error** ($\delta$).
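
To make the two updates concrete, here is a small worked example on a made-up three-step episode. The rewards, value estimates, `gamma`, and `alpha` are illustrative numbers chosen for this sketch, not values from FrozenLake.

```python
# Hypothetical three-step episode: rewards R_1, R_2, R_3 (illustrative values)
rewards = [0.0, 0.0, 1.0]
gamma = 0.9
alpha = 0.1

# Current value estimates for S_0, S_1, S_2 (illustrative values)
V = [0.5, 0.6, 0.8]

# Monte Carlo: the target for S_0 is the full discounted return G_0,
# which is only known once the episode has ended.
G0 = sum(gamma**k * r for k, r in enumerate(rewards))
v_mc = V[0] + alpha * (G0 - V[0])

# TD(0): the target for S_0 needs only R_1 and the current estimate V(S_1),
# so the update can be applied right after the first step.
td_target = rewards[0] + gamma * V[1]
td_error = td_target - V[0]
v_td = V[0] + alpha * td_error

print(f"G_0 = {G0:.3f}, MC update gives V(S_0) = {v_mc:.4f}")
print(f"TD target = {td_target:.3f}, delta = {td_error:.3f}, "
      f"TD update gives V(S_0) = {v_td:.4f}")
```

The two updates move $V(S_0)$ by different amounts because they use different targets: MC uses what actually happened over the whole episode, while TD trusts the current estimate of the next state.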

## Advantages of TD

1. **Online learning**: Update after every step, not just at episode end
2. **Works for continuing tasks**: Don't need episodes to terminate
3. **Lower variance**: Uses single reward + estimate instead of full return
4. **Often faster convergence**: In practice TD frequently learns faster than MC, although this is an empirical observation rather than a guarantee

## Disadvantages

1. **Biased**: Bootstrapping introduces bias from current estimates
2. **Depends on initialization**: Bad initial values can slow learning

---
# 2. TD(0) Prediction

The simplest TD method: update after each step using immediate reward and next state estimate.

$$V(S_t) \leftarrow V(S_t) + \alpha [R_{t+1} + \gamma V(S_{t+1}) - V(S_t)]$$

Where:
- $\alpha$ is the learning rate
- $R_{t+1} + \gamma V(S_{t+1})$ is the TD target
- $R_{t+1} + \gamma V(S_{t+1}) - V(S_t)$ is the TD error $\delta_t$
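
A useful way to connect the TD error to the Monte Carlo error: if $V$ is held fixed for the whole episode, the MC error $G_t - V(S_t)$ equals the discounted sum of the TD errors along the trajectory. The sketch below checks this on a made-up episode. The states, rewards, and value estimates are arbitrary illustrative numbers, and the fixed-$V$ assumption matters: when $V$ is updated during the episode, as in online TD(0), the identity holds only approximately.

```python
import numpy as np

gamma = 0.9

# Hypothetical episode visiting states 0 -> 1 -> 2 -> terminal
states = [0, 1, 2]
rewards = [1.0, 0.0, 2.0]            # R_1, R_2, R_3
V = np.array([0.3, -0.2, 1.5])       # arbitrary fixed estimates
next_values = [V[1], V[2], 0.0]      # V(terminal) is 0 by definition

# One-step TD errors: delta_t = R_{t+1} + gamma * V(S_{t+1}) - V(S_t)
deltas = [r + gamma * v_next - V[s]
          for s, r, v_next in zip(states, rewards, next_values)]

# Monte Carlo error for the first state: G_0 - V(S_0)
G0 = sum(gamma**k * r for k, r in enumerate(rewards))
mc_error = G0 - V[0]

# Discounted sum of the TD errors along the episode
td_sum = sum(gamma**k * d for k, d in enumerate(deltas))

print(f"MC error          = {mc_error:.4f}")
print(f"Sum of TD errors  = {td_sum:.4f}")
```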

In [None]:
def td0_prediction(env, policy, gamma, alpha, n_episodes):
    """
    TD(0) Prediction for estimating V^π.
    
    Args:
        env: Gymnasium environment
        policy: Policy to evaluate (π[s,a] probabilities)
        gamma: Discount factor
        alpha: Learning rate
        n_episodes: Number of episodes
    
    Returns:
        V: Estimated state value function
        V_history: V at intervals for visualization
        td_errors: TD errors during training
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    
    V = np.zeros(n_states)
    V_history = [V.copy()]
    td_errors = []
    
    for episode in range(n_episodes):
        state, _ = env.reset()
        done = False
        
        while not done:
            # Select action according to policy
            action = np.random.choice(n_actions, p=policy[state])
            
            # Take action
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
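            # TD update
            # V(s) <- V(s) + α * [r + γV(s') - V(s)]
            # Stop bootstrapping only on true termination; after a time-limit
            # truncation, V(s') is still a valid estimate of the future return.
            td_target = reward + gamma * V[next_state] * (not terminated)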
            td_error = td_target - V[state]
            V[state] += alpha * td_error
            
            td_errors.append(td_error)
            state = next_state
        
        # Save history at intervals
        if (episode + 1) % max(1, n_episodes // 10) == 0:
            V_history.append(V.copy())
    
    return V, V_history, td_errors

In [None]:
# Run TD(0) prediction for random policy
uniform_policy = np.ones((n_states, n_actions)) / n_actions

print("TD(0) Prediction")
print("=" * 50)

V_td, V_history_td, td_errors = td0_prediction(
    env, uniform_policy, gamma=0.99, alpha=0.1, n_episodes=50000
)

print(f"\nEstimated V^π (random policy):")
print(V_td.reshape(4, 4).round(4))

In [None]:
# Compare TD(0) with true values (from DP)
def extract_mdp(env):
    n_s = env.observation_space.n
    n_a = env.action_space.n
    P = np.zeros((n_s, n_a, n_s))
    R = np.zeros((n_s, n_a))
    for s in range(n_s):
        for a in range(n_a):
            for prob, next_s, reward, done in env.unwrapped.P[s][a]:
                P[s, a, next_s] += prob
                R[s, a] += prob * reward
    return P, R

def policy_evaluation_dp(P, R, policy, gamma, theta=1e-8):
    n_states = P.shape[0]
    n_actions = P.shape[1]
    V = np.zeros(n_states)
    while True:
        V_new = np.zeros(n_states)
        for s in range(n_states):
            for a in range(n_actions):
                V_new[s] += policy[s, a] * (R[s, a] + gamma * np.sum(P[s, a] * V))
        if np.max(np.abs(V_new - V)) < theta:
            break
        V = V_new
    return V

P, R = extract_mdp(env)
V_true = policy_evaluation_dp(P, R, uniform_policy, gamma=0.99)

print("Comparison: TD(0) vs True Values (DP)")
print("=" * 60)
print(f"Mean Absolute Error: {np.mean(np.abs(V_td - V_true)):.4f}")
print(f"Max Absolute Error: {np.max(np.abs(V_td - V_true)):.4f}")

In [None]:
# Visualize TD(0) convergence
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Plot V at different stages
# V_history_td[k] holds the snapshot taken after k * save_every episodes
episodes_at = [0, 5000, 10000, 25000, 50000]
save_every = 50000 // 10
for ax, ep in zip(axes.flat[:-1], episodes_at):
    hist_idx = min(ep // save_every, len(V_history_td) - 1)
    plot_value_function(V_history_td[hist_idx], title=f"After {ep} episodes", ax=ax)

# TD error over time
ax = axes.flat[-1]
window = 1000
td_errors_smooth = np.convolve(np.abs(td_errors), np.ones(window)/window, mode='valid')
ax.plot(td_errors_smooth)
ax.set_xlabel('Step')
ax.set_ylabel('|TD Error| (moving avg)')
ax.set_title('TD Error Over Time')

plt.suptitle("TD(0) Prediction Convergence", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

---
# 3. SARSA: On-Policy TD Control

**SARSA** (State-Action-Reward-State-Action) is an on-policy TD control algorithm.

## The SARSA Update

$$Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha [R_{t+1} + \gamma Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t)]$$

The name comes from the quintuple: $(S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1})$

## On-Policy

SARSA is **on-policy**: it learns about the policy it's following.
- Uses $A_{t+1}$ which is selected by the same policy
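- The Q-values track $Q^\pi$ for the current behavior policy $\pi$; if ε is decayed toward zero while every state-action pair keeps being visited (GLIE), they converge to $Q^*$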
- Typically uses ε-greedy policy
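
For reference, the action probabilities of an ε-greedy policy can be written out explicitly. The helper `epsilon_greedy_probs` below is a hypothetical name introduced for this sketch, and the Q-values are illustrative. It assumes the common variant in which the exploratory draw is uniform over all actions, including the greedy one.

```python
import numpy as np

def epsilon_greedy_probs(q_row, epsilon):
    """Action probabilities of an epsilon-greedy policy for one state's Q-values."""
    n_actions = len(q_row)
    # With probability epsilon, pick uniformly among all actions
    probs = np.full(n_actions, epsilon / n_actions)
    # With probability 1 - epsilon, pick the greedy action
    probs[int(np.argmax(q_row))] += 1.0 - epsilon
    return probs

q_row = np.array([0.1, 0.5, 0.2, 0.0])   # illustrative Q-values for 4 actions
probs = epsilon_greedy_probs(q_row, epsilon=0.1)
print(probs)
```

Because SARSA samples $A_{t+1}$ from these probabilities, its target reflects the occasional exploratory action, which is what makes it on-policy.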

In [None]:
def sarsa(env, gamma, alpha, n_episodes, epsilon=0.1, 
          epsilon_decay=1.0, min_epsilon=0.01):
    """
    SARSA: On-policy TD Control.
    
    Args:
        env: Gymnasium environment
        gamma: Discount factor
        alpha: Learning rate
        n_episodes: Number of episodes
        epsilon: Exploration rate
        epsilon_decay: Decay rate for epsilon
        min_epsilon: Minimum epsilon
    
    Returns:
        Q: Learned Q-values
        stats: Training statistics
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    
    Q = np.zeros((n_states, n_actions))
    
    episode_rewards = []
    episode_lengths = []
    epsilons = []
    
    def epsilon_greedy_action(state, eps):
        if np.random.random() < eps:
            return np.random.randint(n_actions)
        return np.argmax(Q[state])
    
    for episode in range(n_episodes):
        state, _ = env.reset()
        action = epsilon_greedy_action(state, epsilon)
        
        total_reward = 0
        steps = 0
        done = False
        
        while not done:
            # Take action
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            # Choose next action (for SARSA update)
            next_action = epsilon_greedy_action(next_state, epsilon)
            
            # SARSA update
            # Q(s,a) <- Q(s,a) + α * [r + γ*Q(s',a') - Q(s,a)]
            # Stop bootstrapping only on true termination, not on time-limit truncation
            td_target = reward + gamma * Q[next_state, next_action] * (not terminated)
            td_error = td_target - Q[state, action]
            Q[state, action] += alpha * td_error
            
            state = next_state
            action = next_action
            total_reward += reward
            steps += 1
        
        episode_rewards.append(total_reward)
        episode_lengths.append(steps)
        epsilons.append(epsilon)
        
        # Decay epsilon
        epsilon = max(min_epsilon, epsilon * epsilon_decay)
    
    stats = {
        'episode_rewards': episode_rewards,
        'episode_lengths': episode_lengths,
        'epsilons': epsilons
    }
    
    return Q, stats

In [None]:
# Run SARSA
print("SARSA Training")
print("=" * 50)

start_time = time.time()
Q_sarsa, stats_sarsa = sarsa(
    env, gamma=0.99, alpha=0.1, n_episodes=100000,
    epsilon=1.0, epsilon_decay=0.99995, min_epsilon=0.01
)
sarsa_time = time.time() - start_time

print(f"Training time: {sarsa_time:.2f} seconds")
print(f"Final epsilon: {stats_sarsa['epsilons'][-1]:.4f}")

In [None]:
# Plot SARSA training progress
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Learning curve
window = 1000
rewards_smooth = np.convolve(stats_sarsa['episode_rewards'], 
                              np.ones(window)/window, mode='valid')
axes[0, 0].plot(rewards_smooth)
axes[0, 0].set_xlabel('Episode')
axes[0, 0].set_ylabel('Reward (moving avg)')
axes[0, 0].set_title(f'SARSA Learning Curve (window={window})')

# Epsilon decay
axes[0, 1].plot(stats_sarsa['epsilons'])
axes[0, 1].set_xlabel('Episode')
axes[0, 1].set_ylabel('Epsilon')
axes[0, 1].set_title('Exploration Rate Decay')

# Value function
V_sarsa = np.max(Q_sarsa, axis=1)
plot_value_function(V_sarsa, title="Learned V = max Q(s,a)", ax=axes[1, 0])

# Policy
plot_policy(Q_sarsa, title="Learned Policy", ax=axes[1, 1])

plt.suptitle("SARSA Results (100,000 episodes)", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

---
# 4. Q-Learning: Off-Policy TD Control

**Q-Learning** is an off-policy TD control algorithm and one of the best-known algorithms in RL.

## The Q-Learning Update

$$Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha [R_{t+1} + \gamma \max_a Q(S_{t+1}, a) - Q(S_t, A_t)]$$

## Off-Policy

Q-Learning is **off-policy**: it learns about the optimal policy while following a different (exploratory) policy.

Key difference from SARSA:
- **SARSA**: Uses $Q(S_{t+1}, A_{t+1})$ where $A_{t+1}$ comes from the behavior policy
- **Q-Learning**: Uses $\max_a Q(S_{t+1}, a)$ - the value of the best action

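This means Q-Learning estimates $Q^*$ directly, whatever exploratory policy generates the data, provided every state-action pair keeps being visited (and, in theory, the step sizes decay appropriately).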
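
The difference between the two targets is easiest to see on a single transition. The numbers below are illustrative, and `next_action` stands for a hypothetical exploratory action that the ε-greedy behavior policy happened to pick.

```python
import numpy as np

gamma = 0.99
reward = 0.0

# Illustrative Q-values for the next state S_{t+1} (4 actions)
q_next = np.array([0.2, 0.7, 0.1, 0.4])

# Hypothetical action actually chosen in S_{t+1} by the behavior policy
next_action = 3

# SARSA bootstraps from the action that will really be taken
sarsa_target = reward + gamma * q_next[next_action]

# Q-Learning bootstraps from the best action, whether or not it is taken
q_learning_target = reward + gamma * np.max(q_next)

print(f"SARSA target:      {sarsa_target:.3f}")
print(f"Q-Learning target: {q_learning_target:.3f}")
```

The two targets coincide whenever the behavior policy picks the greedy action, so the algorithms differ only on exploratory steps.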

In [None]:
def q_learning(env, gamma, alpha, n_episodes, epsilon=0.1,
               epsilon_decay=1.0, min_epsilon=0.01):
    """
    Q-Learning: Off-policy TD Control.
    
    Args:
        env: Gymnasium environment
        gamma: Discount factor
        alpha: Learning rate
        n_episodes: Number of episodes
        epsilon: Exploration rate
        epsilon_decay: Decay rate for epsilon
        min_epsilon: Minimum epsilon
    
    Returns:
        Q: Learned Q-values
        stats: Training statistics
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    
    Q = np.zeros((n_states, n_actions))
    
    episode_rewards = []
    episode_lengths = []
    epsilons = []
    
    def epsilon_greedy_action(state, eps):
        if np.random.random() < eps:
            return np.random.randint(n_actions)
        return np.argmax(Q[state])
    
    for episode in range(n_episodes):
        state, _ = env.reset()
        
        total_reward = 0
        steps = 0
        done = False
        
        while not done:
            # Choose action using ε-greedy
            action = epsilon_greedy_action(state, epsilon)
            
            # Take action
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            # Q-Learning update
            # Q(s,a) <- Q(s,a) + α * [r + γ*max_a' Q(s',a') - Q(s,a)]
            # Stop bootstrapping only on true termination, not on time-limit truncation
            td_target = reward + gamma * np.max(Q[next_state]) * (not terminated)
            td_error = td_target - Q[state, action]
            Q[state, action] += alpha * td_error
            
            state = next_state
            total_reward += reward
            steps += 1
        
        episode_rewards.append(total_reward)
        episode_lengths.append(steps)
        epsilons.append(epsilon)
        
        # Decay epsilon
        epsilon = max(min_epsilon, epsilon * epsilon_decay)
    
    stats = {
        'episode_rewards': episode_rewards,
        'episode_lengths': episode_lengths,
        'epsilons': epsilons
    }
    
    return Q, stats

In [None]:
# Run Q-Learning
print("Q-Learning Training")
print("=" * 50)

start_time = time.time()
Q_qlearn, stats_qlearn = q_learning(
    env, gamma=0.99, alpha=0.1, n_episodes=100000,
    epsilon=1.0, epsilon_decay=0.99995, min_epsilon=0.01
)
qlearn_time = time.time() - start_time

print(f"Training time: {qlearn_time:.2f} seconds")
print(f"Final epsilon: {stats_qlearn['epsilons'][-1]:.4f}")

In [None]:
# Plot Q-Learning training progress
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Learning curve
window = 1000
rewards_smooth = np.convolve(stats_qlearn['episode_rewards'], 
                              np.ones(window)/window, mode='valid')
axes[0, 0].plot(rewards_smooth)
axes[0, 0].set_xlabel('Episode')
axes[0, 0].set_ylabel('Reward (moving avg)')
axes[0, 0].set_title(f'Q-Learning Curve (window={window})')

# Epsilon decay
axes[0, 1].plot(stats_qlearn['epsilons'])
axes[0, 1].set_xlabel('Episode')
axes[0, 1].set_ylabel('Epsilon')
axes[0, 1].set_title('Exploration Rate Decay')

# Value function
V_qlearn = np.max(Q_qlearn, axis=1)
plot_value_function(V_qlearn, title="Learned V = max Q(s,a)", ax=axes[1, 0])

# Policy
plot_policy(Q_qlearn, title="Learned Policy", ax=axes[1, 1])

plt.suptitle("Q-Learning Results (100,000 episodes)", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

---
# 5. SARSA vs Q-Learning Comparison

Let's compare the two algorithms in detail.

In [None]:
# Evaluate both policies
def evaluate_policy(env, Q, n_episodes=10000):
    """Evaluate a greedy policy derived from Q."""
    rewards = []
    
    for _ in range(n_episodes):
        state, _ = env.reset()
        total_reward = 0
        done = False
        
        while not done:
            action = np.argmax(Q[state])
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated
        
        rewards.append(total_reward)
    
    return np.array(rewards)

# Get optimal Q from DP for comparison
def value_iteration(P, R, gamma, theta=1e-8):
    n_states, n_actions = R.shape
    V = np.zeros(n_states)
    while True:
        V_new = np.zeros(n_states)
        for s in range(n_states):
            V_new[s] = np.max([R[s, a] + gamma * np.sum(P[s, a] * V) for a in range(n_actions)])
        if np.max(np.abs(V_new - V)) < theta:
            break
        V = V_new
    Q = np.zeros((n_states, n_actions))
    for s in range(n_states):
        for a in range(n_actions):
            Q[s, a] = R[s, a] + gamma * np.sum(P[s, a] * V)
    return Q

Q_optimal = value_iteration(P, R, gamma=0.99)

# Evaluate
print("Policy Evaluation Comparison")
print("=" * 50)

rewards_sarsa_eval = evaluate_policy(env, Q_sarsa, n_episodes=10000)
rewards_qlearn_eval = evaluate_policy(env, Q_qlearn, n_episodes=10000)
rewards_optimal_eval = evaluate_policy(env, Q_optimal, n_episodes=10000)

print(f"SARSA: Success rate = {np.mean(rewards_sarsa_eval)*100:.2f}%")
print(f"Q-Learning: Success rate = {np.mean(rewards_qlearn_eval)*100:.2f}%")
print(f"Optimal (DP): Success rate = {np.mean(rewards_optimal_eval)*100:.2f}%")

In [None]:
# Compare learning curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Learning curves side by side
window = 1000
sarsa_smooth = np.convolve(stats_sarsa['episode_rewards'], 
                            np.ones(window)/window, mode='valid')
qlearn_smooth = np.convolve(stats_qlearn['episode_rewards'], 
                             np.ones(window)/window, mode='valid')

axes[0].plot(sarsa_smooth, label='SARSA', alpha=0.8)
axes[0].plot(qlearn_smooth, label='Q-Learning', alpha=0.8)
axes[0].set_xlabel('Episode')
axes[0].set_ylabel('Reward (moving avg)')
axes[0].set_title('Learning Curves Comparison')
axes[0].legend()

# Success rate comparison
methods = ['SARSA', 'Q-Learning', 'Optimal (DP)']
success_rates = [
    np.mean(rewards_sarsa_eval)*100,
    np.mean(rewards_qlearn_eval)*100,
    np.mean(rewards_optimal_eval)*100
]
colors = ['steelblue', 'orange', 'green']

bars = axes[1].bar(methods, success_rates, color=colors, edgecolor='black')
axes[1].set_ylabel('Success Rate (%)')
axes[1].set_title('Final Policy Performance')
axes[1].set_ylim(0, 100)
for bar, rate in zip(bars, success_rates):
    axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                f'{rate:.1f}%', ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Compare Q-values with optimal
print("Q-value Comparison with Optimal")
print("=" * 50)

sarsa_error = np.mean(np.abs(Q_sarsa - Q_optimal))
qlearn_error = np.mean(np.abs(Q_qlearn - Q_optimal))

print(f"SARSA Mean Absolute Q-error: {sarsa_error:.4f}")
print(f"Q-Learning Mean Absolute Q-error: {qlearn_error:.4f}")

In [None]:
# Visualize policies side by side
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

plot_policy(Q_sarsa, title="SARSA Policy", ax=axes[0])
plot_policy(Q_qlearn, title="Q-Learning Policy", ax=axes[1])
plot_policy(Q_optimal, title="Optimal Policy (DP)", ax=axes[2])

plt.suptitle("Policy Comparison", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

---
# 6. Key Differences: SARSA vs Q-Learning

| Aspect | SARSA | Q-Learning |
|--------|-------|------------|
| **Type** | On-policy | Off-policy |
| **Update uses** | $Q(S', A')$ where $A'$ from policy | $\max_a Q(S', a)$ |
| **Learns** | $Q^\pi$ for behavior policy | $Q^*$ optimal Q |
| **Behavior** | More conservative/safe | More aggressive/risky |
| **Convergence** | To $Q^\pi$ (to $Q^*$ if ε decays to zero) | To $Q^*$ (given sufficient exploration) |

## On-Policy vs Off-Policy

**On-policy (SARSA)**:
- Learns about the policy it's following
- Takes exploration into account
- May be safer in dangerous environments

**Off-policy (Q-Learning)**:
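- Learns the optimal policy while following any sufficiently exploratory policy
- Can use experience from any source (for example, a replay buffer)
- Can reuse old experience, which often improves sample efficiency, but the learned greedy policy ignores the cost of exploratory mistakes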
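
As a sketch of what "experience from any source" means, the Q-Learning update can be applied to a fixed set of logged transitions without interacting with an environment at all. The tiny deterministic chain and the `buffer` below are hypothetical and exist only for this example.

```python
import numpy as np

gamma, alpha = 0.9, 0.5

# Hypothetical logged transitions: (state, action, reward, next_state, terminated).
# Two states; action 0 stays put, action 1 moves right, and moving right
# from state 1 ends the episode with reward 1.
buffer = [
    (0, 0, 0.0, 0, False),
    (0, 1, 0.0, 1, False),
    (1, 0, 0.0, 1, False),
    (1, 1, 1.0, 1, True),   # next_state is a placeholder, ignored on termination
]

Q = np.zeros((2, 2))
for _ in range(300):                       # replay the stored data many times
    for s, a, r, s_next, terminated in buffer:
        target = r + gamma * np.max(Q[s_next]) * (not terminated)
        Q[s, a] += alpha * (target - Q[s, a])

print(Q.round(3))
```

The update never asks which policy produced the transitions, which is why the same data can be replayed repeatedly. SARSA cannot do this directly, because its target needs the next action chosen by the policy being learned.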

---
# 7. Effect of Learning Rate

The learning rate α controls how much new information overrides old information.
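
The update $V \leftarrow V + \alpha(\text{target} - V)$ is an exponential moving average of the targets, which gives some intuition before running the full experiment. The sketch below uses synthetic targets (not FrozenLake data) to show two effects: with a fixed target, a larger α closes the gap faster, and with noisy targets, a larger α keeps fluctuating more.

```python
import numpy as np

rng = np.random.default_rng(0)
alphas = [0.01, 0.1, 0.9]

# Part 1: fixed target of 1.0, estimate starts at 0.
# After n updates the remaining error is (1 - alpha) ** n.
n_updates = 20
after_fixed = {}
for alpha in alphas:
    v = 0.0
    for _ in range(n_updates):
        v += alpha * (1.0 - v)
    after_fixed[alpha] = v

# Part 2: noisy targets (0 or 1 with equal probability, mean 0.5).
# Measure how much the estimate still fluctuates late in training.
targets = rng.integers(0, 2, size=5000).astype(float)
spread = {}
for alpha in alphas:
    v, trace = 0.0, []
    for t in targets:
        v += alpha * (t - v)
        trace.append(v)
    spread[alpha] = float(np.std(trace[-1000:]))

for alpha in alphas:
    print(f"alpha = {alpha}: after {n_updates} fixed-target updates = "
          f"{after_fixed[alpha]:.3f}, late-training std = {spread[alpha]:.3f}")
```

A small α averages over many samples but adapts slowly; a large α adapts quickly but mostly tracks the latest sample.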

In [None]:
# Test different learning rates
alphas = [0.01, 0.1, 0.5, 0.9]
results_alpha = {}

print("Testing different learning rates (Q-Learning)")
print("=" * 50)

for alpha in alphas:
    Q, stats = q_learning(
        env, gamma=0.99, alpha=alpha, n_episodes=50000,
        epsilon=1.0, epsilon_decay=0.9999, min_epsilon=0.01
    )
    rewards = evaluate_policy(env, Q, n_episodes=5000)
    results_alpha[alpha] = {
        'Q': Q,
        'stats': stats,
        'success_rate': np.mean(rewards) * 100
    }
    print(f"α = {alpha}: Success rate = {results_alpha[alpha]['success_rate']:.2f}%")

In [None]:
# Plot learning rate comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Learning curves
window = 500
for alpha in alphas:
    rewards_smooth = np.convolve(results_alpha[alpha]['stats']['episode_rewards'],
                                  np.ones(window)/window, mode='valid')
    axes[0].plot(rewards_smooth, label=f'α = {alpha}')

axes[0].set_xlabel('Episode')
axes[0].set_ylabel('Reward (moving avg)')
axes[0].set_title('Learning Curves for Different α')
axes[0].legend()

# Final success rates
success_rates = [results_alpha[a]['success_rate'] for a in alphas]
axes[1].bar([str(a) for a in alphas], success_rates, color='steelblue', edgecolor='black')
axes[1].set_xlabel('Learning Rate (α)')
axes[1].set_ylabel('Success Rate (%)')
axes[1].set_title('Final Performance vs Learning Rate')

for i, (a, rate) in enumerate(zip(alphas, success_rates)):
    axes[1].text(i, rate + 1, f'{rate:.1f}%', ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

---
# Summary and Concept Map

In this notebook, we learned three core Temporal Difference algorithms:

```
TEMPORAL DIFFERENCE LEARNING
============================

Key Idea: Bootstrap = Update estimates using other estimates
Advantage: Learn online (every step), model-free, works for continuing tasks
────────────────────────────────────────────────────


TD(0) PREDICTION
────────────────
Problem: Estimate V^π(s) for a given policy π

Algorithm: Update after each step using immediate reward + next state estimate
V(S_t) ← V(S_t) + α [R_{t+1} + γV(S_{t+1}) - V(S_t)]

Components:
- TD target: R_{t+1} + γV(S_{t+1})
- TD error (δ): R_{t+1} + γV(S_{t+1}) - V(S_t)

Properties:
- Online learning (updates every step)
- Lower variance than MC (uses single reward)
- Biased (depends on current estimate V)
- Works for continuing tasks


SARSA (On-Policy TD Control)
─────────────────────────────
Problem: Find optimal policy using on-policy learning

Algorithm: Update Q using action A_{t+1} from behavior policy
Q(S_t, A_t) ← Q(S_t, A_t) + α [R_{t+1} + γQ(S_{t+1}, A_{t+1}) - Q(S_t, A_t)]

Name: State-Action-Reward-State-Action (S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1})

Key property: **On-policy**
- Learns about the policy it's following (typically ε-greedy)
- A_{t+1} comes from the same policy we're evaluating
- Converges to Q^π for behavior policy π
- More conservative/safe (accounts for exploration)

Use when: You want the agent to learn a policy that includes exploration


Q-LEARNING (Off-Policy TD Control)
───────────────────────────────────
Problem: Find optimal policy using off-policy learning

Algorithm: Update Q using max over actions (not actual next action)
Q(S_t, A_t) ← Q(S_t, A_t) + α [R_{t+1} + γ max_a Q(S_{t+1}, a) - Q(S_t, A_t)]

Key property: **Off-policy**
- Learns optimal policy while following any behavior policy
- Uses max_a Q(S_{t+1}, a) instead of Q(S_{t+1}, A_{t+1})
- Converges to Q* (optimal Q) for any behavior policy that keeps visiting all state-action pairs
- More aggressive/risky (assumes greedy actions)

Use when: You want to learn the optimal policy from any experience


ON-POLICY VS OFF-POLICY
────────────────────────

| Aspect | On-Policy (SARSA) | Off-Policy (Q-Learning) |
|--------|-------------------|-------------------------|
| **Learns** | Q^π (for behavior policy) | Q* (optimal policy) |
| **Update uses** | A_{t+1} from policy | max_a Q(s',a) |
| **Behavior** | Conservative (safe) | Aggressive (risky) |
| **Exploration** | Must explore to learn | Learns from any policy |
| **Example** | Learn to drive cautiously while being cautious | Learn optimal driving while practicing cautiously |


TD VS MC VS DP
──────────────

| Property | DP | MC | TD |
|----------|----|----|----|
| Model-free | No ✗ | Yes ✓ | Yes ✓ |
| Bootstraps | Yes ✓ | No ✗ | Yes ✓ |
| Online (step-by-step) | Yes ✓ | No ✗ | Yes ✓ |
| Works for continuing tasks | Yes ✓ | No ✗ | Yes ✓ |
| Unbiased | Yes ✓ | Yes ✓ | No ✗ |
| Low variance | Yes ✓ | No ✗ | Yes ✓ |


BIAS-VARIANCE TRADEOFF
──────────────────────

Monte Carlo:
- Uses actual returns G_t
- Unbiased (correct on average)
- High variance (different episodes vary a lot)

Temporal Difference:
- Uses estimated returns R + γV(s')
- Biased (depends on current estimate)
- Low variance (single step of randomness)

Trade-off: TD often learns faster in practice despite bias!
```

## What's Next?

In the final notebook (**06_algorithm_comparison.ipynb**), we'll:
- Compare all algorithms (DP, MC, TD) side by side
- Discuss when to use which method
- Review hyperparameter tuning strategies
- Summarize the entire RL tutorial

## Key Takeaways

1. **TD combines best of MC and DP**: Model-free like MC, bootstraps like DP
2. **Updates every step**: Don't need to wait for episode end (online learning)
3. **SARSA (on-policy)**: Learns about the policy being followed, more conservative
4. **Q-Learning (off-policy)**: Learns optimal policy regardless of behavior, more aggressive
5. **Bias-variance tradeoff**: TD has lower variance but is biased; MC is unbiased but high variance
6. **Learning rate α**: Controls how much new information overrides old estimates

---
# Your Turn

Now it's time to test your understanding with some hands-on exercises!

## Exercise 1: Implement TD(0) Prediction from Scratch

Modify the TD(0) prediction implementation to track and visualize the TD error over time.

**Task**: Complete the code below to implement TD(0) with enhanced tracking:

```python
# YOUR CODE HERE
# Implement TD(0) prediction that returns both V and TD errors per state

def td0_with_tracking(env, policy, gamma, alpha, n_episodes):
    """
    TD(0) Prediction with detailed tracking.
    
    Returns:
        V: Final value function
        td_errors_by_state: Dictionary mapping state -> list of TD errors
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    
    V = np.zeros(n_states)
    td_errors_by_state = {s: [] for s in range(n_states)}
    
    # TODO: Implement the algorithm
    # Hint: For each step, record the TD error for the visited state
    
    pass  # Replace with your implementation

# TODO: Run your implementation and plot TD errors for states 0, 6, 14
```

<details>
<summary>Click to see hint</summary>

For each episode:
1. Reset environment and get initial state
2. Until done:
   - Select action from policy
   - Take step and observe next_state, reward
   - Compute TD error: δ = r + γV(s') - V(s)
   - Update: V(s) += α * δ
   - Record δ in td_errors_by_state[s]

</details>

<details>
<summary>Click to see solution</summary>

```python
def td0_with_tracking(env, policy, gamma, alpha, n_episodes):
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    
    V = np.zeros(n_states)
    td_errors_by_state = {s: [] for s in range(n_states)}
    
    for episode in range(n_episodes):
        state, _ = env.reset()
        done = False
        
        while not done:
            action = np.random.choice(n_actions, p=policy[state])
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
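            # Compute TD error; bootstrap from V[next_state] unless the episode
            # truly terminated (a time-limit truncation still has a successor value)
            td_target = reward + gamma * V[next_state] * (not terminated)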
            td_error = td_target - V[state]
            
            # Update V
            V[state] += alpha * td_error
            
            # Track TD error
            td_errors_by_state[state].append(td_error)
            
            state = next_state
    
    return V, td_errors_by_state

# Run and visualize
uniform_policy = np.ones((n_states, n_actions)) / n_actions
V, td_errors = td0_with_tracking(env, uniform_policy, gamma=0.99, alpha=0.1, n_episodes=10000)

# Plot TD errors for selected states
fig, ax = plt.subplots(figsize=(12, 5))
for s in [0, 6, 14]:
    if len(td_errors[s]) > 0:
        # Moving average for smoothness
        window = 100
        errors = td_errors[s]
        if len(errors) >= window:
            smooth = np.convolve(errors, np.ones(window)/window, mode='valid')
            ax.plot(smooth, label=f'State {s}', alpha=0.7)

ax.set_xlabel('Visit Count')
ax.set_ylabel('TD Error (smoothed)')
ax.set_title('TD Error Convergence for Selected States')
ax.legend()
plt.show()

print(f"Final V(0) = {V[0]:.4f}")
```

</details>

## Exercise 2: Tune SARSA or Q-Learning Hyperparameters

**Task**: The learning rate α and the initial exploration rate ε significantly affect learning. Experiment with different values, find the combination that gives the highest success rate, and compare how quickly each one converges.

```python
# YOUR CODE HERE
# Test combinations of α and ε for Q-Learning

learning_rates = [0.01, 0.1, 0.5]
epsilons = [0.05, 0.1, 0.2]
results = {}

for alpha in learning_rates:
    for epsilon_start in epsilons:
        # TODO: Run Q-learning with these hyperparameters
        # TODO: Evaluate the learned policy
        # TODO: Record success rate and learning curve
        pass

# TODO: Create a heatmap showing success rate for each (α, ε) combination
```

<details>
<summary>Click to see hint</summary>

For each (α, ε) combination:
1. Run q_learning() with n_episodes=50000
2. Evaluate resulting Q using evaluate_policy()
3. Store success rate in results[(alpha, epsilon)]
4. Use plt.imshow() or seaborn.heatmap() to visualize

</details>
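
When comparing initial ε values, it helps to know how long each one stays above the floor. The sketch below assumes that `q_learning()` multiplies ε by `epsilon_decay` once per episode and clips it at `min_epsilon`; check this against the implementation earlier in the notebook, because the schedule is an assumption here. `episodes_until_floor` is a hypothetical helper.

```python
import math

def episodes_until_floor(epsilon, epsilon_decay, min_epsilon):
    """Number of per-episode decays before epsilon first reaches min_epsilon.

    Assumes the schedule epsilon <- max(min_epsilon, epsilon * epsilon_decay).
    """
    if epsilon <= min_epsilon:
        return 0
    if epsilon_decay >= 1.0:
        return math.inf  # epsilon never decays
    return math.ceil(math.log(min_epsilon / epsilon) / math.log(epsilon_decay))

for eps0 in (0.05, 0.1, 0.2):
    n = episodes_until_floor(eps0, epsilon_decay=0.9999, min_epsilon=0.01)
    print(f"epsilon_start={eps0}: reaches the floor after about {n} episodes")
```

Under this assumption, all three settings reach ε = 0.01 well before the 50,000 training episodes end, so differences between them mostly reflect how much exploration happens early on.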

<details>
<summary>Click to see solution</summary>

```python
learning_rates = [0.01, 0.1, 0.5]
epsilons = [0.05, 0.1, 0.2]
results = np.zeros((len(learning_rates), len(epsilons)))

print("Testing hyperparameter combinations...")
for i, alpha in enumerate(learning_rates):
    for j, epsilon_start in enumerate(epsilons):
        # Run Q-learning
        Q, stats = q_learning(
            env, gamma=0.99, alpha=alpha, n_episodes=50000,
            epsilon=epsilon_start, epsilon_decay=0.9999, min_epsilon=0.01
        )
        
        # Evaluate
        rewards = evaluate_policy(env, Q, n_episodes=5000)
        success_rate = np.mean(rewards) * 100
        results[i, j] = success_rate
        
        print(f"α={alpha:.2f}, ε={epsilon_start:.2f}: {success_rate:.1f}% success")

# Visualize
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(results, cmap='RdYlGn', vmin=0, vmax=100)
ax.set_xticks(range(len(epsilons)))
ax.set_yticks(range(len(learning_rates)))
ax.set_xticklabels([f'{e}' for e in epsilons])
ax.set_yticklabels([f'{a}' for a in learning_rates])
ax.set_xlabel('Initial Epsilon (ε)')
ax.set_ylabel('Learning Rate (α)')
ax.set_title('Q-Learning Success Rate (%) by Hyperparameters')

# Add text annotations
for i in range(len(learning_rates)):
    for j in range(len(epsilons)):
        ax.text(j, i, f'{results[i, j]:.1f}',
                ha="center", va="center", color="black", fontweight='bold')

plt.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()

# Find best combination
best_idx = np.unravel_index(np.argmax(results), results.shape)
print(f"\nBest combination: α={learning_rates[best_idx[0]]}, ε={epsilons[best_idx[1]]}")
print(f"Success rate: {results[best_idx]:.1f}%")
```

</details>

## Exercise 3: Compare On-Policy (SARSA) vs Off-Policy (Q-Learning) on Cliff Walking

**Task**: Train SARSA and Q-Learning on the classic "Cliff Walking" environment, which highlights the difference between the two: SARSA learns a safer path, while Q-Learning learns the optimal (but risky) path.

```python
# YOUR CODE HERE
# Create CliffWalking environment and compare SARSA vs Q-Learning

# Step 1: Create the environment
# env_cliff = gym.make('CliffWalking-v0')

# Step 2: Run SARSA
# Q_sarsa, stats_sarsa = sarsa(env_cliff, gamma=0.99, alpha=0.1, 
#                               n_episodes=5000, epsilon=0.1)

# Step 3: Run Q-Learning  
# Q_qlearn, stats_qlearn = q_learning(env_cliff, gamma=0.99, alpha=0.1,
#                                      n_episodes=5000, epsilon=0.1)

# Step 4: Visualize the learned paths
# Hint: Extract the greedy policy from each Q and visualize the path from start
```

<details>
<summary>Click to see hint</summary>

CliffWalking is a 4×12 grid where:
- Start: bottom-left (state 36)
- Goal: bottom-right (state 47)
- Cliff: bottom row between start and goal (states 37-46)

SARSA (on-policy) will learn a safer path above the cliff because it accounts for ε-greedy exploration.

Q-Learning (off-policy) will learn the optimal path along the cliff edge because it assumes greedy actions.
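
The difference shows up directly in the bootstrap term of each update. Here is a minimal sketch with hypothetical Q-values for a cell on the cliff edge (the numbers and the helper name are assumptions chosen for illustration). Q-Learning bootstraps from the maximum Q-value, while SARSA bootstraps from the action that ε-greedy actually samples, which on average equals the expected value under that policy.

```python
import numpy as np

def expected_epsilon_greedy_value(q_row, epsilon):
    """Expected Q-value of the next action under an epsilon-greedy policy."""
    n_actions = len(q_row)
    probs = np.full(n_actions, epsilon / n_actions)
    probs[np.argmax(q_row)] += 1.0 - epsilon
    return float(np.dot(probs, q_row))

# Hypothetical Q-values at a cliff-edge cell, indexed [up, right, down, left];
# "down" steps into the cliff, hence the large negative value.
q_next = np.array([-14.0, -12.0, -100.0, -14.0])
reward, gamma, epsilon = -1.0, 0.99, 0.1

q_learning_target = reward + gamma * np.max(q_next)
sarsa_target_avg = reward + gamma * expected_epsilon_greedy_value(q_next, epsilon)

print(f"Q-Learning target:         {q_learning_target:.3f}")
print(f"SARSA target (on average): {sarsa_target_avg:.3f}")
```

Because the occasional random step into the cliff is included in SARSA's average target, cells along the edge look worse to SARSA than to Q-Learning, which pushes its path away from the cliff.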

</details>

<details>
<summary>Click to see solution</summary>

```python
# Create environment
env_cliff = gym.make('CliffWalking-v0')
n_states_cliff = env_cliff.observation_space.n
n_actions_cliff = env_cliff.action_space.n

print("Cliff Walking Environment")
print("Grid: 4 rows × 12 columns")
print("Start: bottom-left, Goal: bottom-right")
print("Cliff: bottom row between start and goal\n")

# Run SARSA (on-policy)
print("Training SARSA (on-policy)...")
Q_sarsa_cliff, stats_sarsa_cliff = sarsa(
    env_cliff, gamma=0.99, alpha=0.5, n_episodes=5000,
    epsilon=0.1, epsilon_decay=1.0, min_epsilon=0.1
)

# Run Q-Learning (off-policy)
print("Training Q-Learning (off-policy)...")
Q_qlearn_cliff, stats_qlearn_cliff = q_learning(
    env_cliff, gamma=0.99, alpha=0.5, n_episodes=5000,
    epsilon=0.1, epsilon_decay=1.0, min_epsilon=0.1
)

# Visualize learned policies
def visualize_cliff_policy(Q, title):
    """Visualize policy on cliff walking grid."""
    policy_grid = np.zeros((4, 12), dtype=int)
    for s in range(48):
        row = s // 12
        col = s % 12
        policy_grid[row, col] = np.argmax(Q[s])
    
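    # Draw on the current axes (selected by the caller via plt.sca) instead of
    # opening a new figure, so the side-by-side subplots are actually filled
    ax = plt.gca()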
    
    # Draw grid
    for i in range(4):
        for j in range(12):
            state = i * 12 + j
            
            # Color cells
            if state == 36:
                color = 'lightblue'  # Start
                text = 'S'
            elif state == 47:
                color = 'lightgreen'  # Goal
                text = 'G'
            elif 37 <= state <= 46:
                color = 'red'  # Cliff
                text = 'C'
            else:
                color = 'white'
                action = policy_grid[i, j]
                # CliffWalking action order: 0=up, 1=right, 2=down, 3=left
                text = ['↑', '→', '↓', '←'][action]
            
            rect = plt.Rectangle((j, 3-i), 1, 1, facecolor=color, edgecolor='black')
            ax.add_patch(rect)
            ax.text(j+0.5, 3-i+0.5, text, ha='center', va='center', 
                   fontsize=12, fontweight='bold')
    
    ax.set_xlim(0, 12)
    ax.set_ylim(0, 4)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title(title, fontsize=14)
    return ax

# Plot both policies
fig, axes = plt.subplots(1, 2, figsize=(16, 4))
plt.sca(axes[0])
visualize_cliff_policy(Q_sarsa_cliff, "SARSA Policy (Safe Path)")
plt.sca(axes[1])
visualize_cliff_policy(Q_qlearn_cliff, "Q-Learning Policy (Optimal but Risky)")
plt.tight_layout()
plt.show()

# Compare learning curves
fig, ax = plt.subplots(figsize=(10, 5))
window = 100
sarsa_smooth = np.convolve(stats_sarsa_cliff['episode_rewards'], 
                            np.ones(window)/window, mode='valid')
qlearn_smooth = np.convolve(stats_qlearn_cliff['episode_rewards'],
                             np.ones(window)/window, mode='valid')

ax.plot(sarsa_smooth, label='SARSA (on-policy)', linewidth=2)
ax.plot(qlearn_smooth, label='Q-Learning (off-policy)', linewidth=2)
ax.set_xlabel('Episode')
ax.set_ylabel('Episode Reward (moving avg)')
ax.set_title('Learning Curves: SARSA vs Q-Learning on Cliff Walking')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nKey Observation:")
print("- SARSA learns a SAFE path (goes up and around the cliff)")
print("- Q-Learning learns the OPTIMAL path (along the cliff edge)")
print("- During training with ε-greedy, SARSA accounts for exploration risk")
print("- Q-Learning assumes greedy actions, so learns the shortest path")
```

</details>
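
The arrow grid shows the greedy action in every cell, but the task also asks for the path from the start. One possible sketch follows, assuming the standard CliffWalking layout (4×12 grid, start 36, goal 47, cliff 37-46) and the action order 0=up, 1=right, 2=down, 3=left. `greedy_path` is a hypothetical helper that replays the deterministic grid moves itself rather than stepping the environment, and `Q_demo` is a hand-built table used only to show the output format.

```python
import numpy as np

# Row/column offsets for actions 0=up, 1=right, 2=down, 3=left (assumed order)
MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}

def greedy_path(Q, start=36, goal=47, n_rows=4, n_cols=12, max_steps=100):
    """Follow argmax(Q[s]) from start; stop at the goal, the cliff, or max_steps."""
    cliff = set(range(37, 47))
    state, path = start, [start]
    for _ in range(max_steps):
        if state == goal or state in cliff:
            break
        d_row, d_col = MOVES[int(np.argmax(Q[state]))]
        # Moves that would leave the grid keep the agent in place
        row = min(max(state // n_cols + d_row, 0), n_rows - 1)
        col = min(max(state % n_cols + d_col, 0), n_cols - 1)
        state = row * n_cols + col
        path.append(state)
    return path

# Hand-built Q-table that walks along the cliff edge (illustration only)
Q_demo = np.zeros((48, 4))
Q_demo[36, 0] = 1.0       # start: go up
Q_demo[24:35, 1] = 1.0    # row 2, columns 0-10: go right
Q_demo[35, 2] = 1.0       # row 2, column 11: go down into the goal
print(greedy_path(Q_demo))
```

Calling `greedy_path(Q_sarsa_cliff)` and `greedy_path(Q_qlearn_cliff)` on the tables from the solution lets you compare path lengths directly; a path that ends in a cliff state or hits `max_steps` indicates the greedy policy has not converged yet.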

## Conceptual Question: When to Use SARSA vs Q-Learning?

**Question**: You're building an RL agent to control a robot in a warehouse. The robot can fall off loading docks (bad!) but you want it to learn the fastest routes. Should you use SARSA or Q-Learning during training? Why?

<details>
<summary>Click to see hint</summary>

Think about:
- Safety during learning (robot can get damaged)
- The final deployed policy (will it be greedy or ε-greedy?)
- Whether the agent needs to account for its own exploration

</details>

<details>
<summary>Click to see answer</summary>

**Use SARSA during training** because:

1. **Safety matters**: SARSA is on-policy, meaning it learns about the ε-greedy policy it's actually following during training. Its value estimates therefore include the cost of occasional random actions, so it learns to keep away from dangerous areas (loading docks) where a single exploratory step could cause a fall.

2. **Realistic learning**: The robot will account for the fact that it sometimes takes random actions (due to ε-greedy), so the learned policy will be more cautious.

3. **Deployment consideration**: If you deploy with ε-greedy (small ε for safety), SARSA's learned policy is appropriate. If you deploy with a purely greedy policy, you might want Q-Learning.

**However, Q-Learning could be better if**:
- You can train in simulation (no real damage from falls)
- You plan to deploy a fully greedy policy (ε=0)
- You want the theoretically optimal solution

**Best approach in practice**:
- Train with SARSA for safe exploration
- As training progresses, decay ε toward 0
- Or train with Q-Learning in simulation, then transfer to the real robot with a small ε (or none), so it rarely takes random actions near the docks

**Real-world consideration**: Many modern systems use off-policy methods (like Q-Learning) but with experience replay buffers and safety constraints, getting benefits of both approaches.

</details>

In [None]:
print("Congratulations! You've completed Part 5 of the RL Tutorial!")
print("\nKey takeaways:")
print("- TD methods update after every step using bootstrapping")
print("- SARSA is on-policy: learns about the policy it follows")
print("- Q-Learning is off-policy: learns optimal policy directly")
print("- Both are model-free and work for continuing tasks")
print("- Learning rate α controls the speed/stability trade-off")
print("\nNext: 06_algorithm_comparison.ipynb")