# SARSA vs. Q-Learning

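## Learning Objectives

By the end of this tutorial you will understand:
- How the SARSA algorithm works and how to implement it
- The essential difference between on-policy and off-policy learning
- How Expected SARSA reduces variance
- Why the algorithms behave differently in the Cliff Walking environment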

## Prerequisites

- Basics of the Q-Learning algorithm
- The concept of temporal-difference (TD) learning

---

## Part 1: On-Policy vs Off-Policy

### 1.1 Basic Concepts

| Term | Definition |
|------|------------|
| **Behavior policy** | The policy actually used to interact with the environment |
| **Target policy** | The policy being learned and improved |

### 1.2 Off-Policy: Q-Learning

$$Q(S,A) \leftarrow Q(S,A) + \alpha[R + \gamma \max_a Q(S',a) - Q(S,A)]$$

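- Uses **max** to bootstrap from the best action in the next state
- Ignores which action is actually taken next
- Behavior policy ≠ target policy: the agent acts ε-greedily but learns the value of the greedy policy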

### 1.3 On-Policy: SARSA

$$Q(S,A) \leftarrow Q(S,A) + \alpha[R + \gamma Q(S',A') - Q(S,A)]$$

- Uses the **action actually taken**, $A'$
- Behavior policy = target policy: the ε-greedy policy is both followed and evaluated
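
To make the difference concrete, the sketch below computes both TD targets for one and the same transition. The Q-values, the reward, and the sampled next action are made-up illustrative numbers, not results from the experiment later in this tutorial.

```python
import numpy as np

gamma = 0.9                                 # discount factor (illustrative value)
reward = -1.0                               # reward R observed on the transition
q_next = np.array([1.0, 2.0, 0.0, -5.0])    # hypothetical Q(S', a) for 4 actions
next_action = 3                             # hypothetical A' picked by an exploring policy

# Q-Learning bootstraps from the best next action, whatever is actually taken.
q_learning_target = reward + gamma * np.max(q_next)

# SARSA bootstraps from the action the behavior policy really selected.
sarsa_target = reward + gamma * q_next[next_action]

print(f"Q-Learning target: {q_learning_target:.2f}")
print(f"SARSA target:      {sarsa_target:.2f}")
```

When the sampled $A'$ happens to be the greedy action, the two targets coincide; they differ only on exploratory steps.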

---

## Part 2: The SARSA Algorithm

### 2.1 Origin of the Name

**S**tate-**A**ction-**R**eward-**S**tate-**A**ction

Each update needs the quintuple $(S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1})$.

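### 2.2 Pseudocode

```
Algorithm: SARSA

1. Initialize Q(s, a) = 0
2. For each episode:
   a. Initialize state S
   b. Choose action A from S using the policy  ← key: choose before entering the loop
   c. Repeat:
      i.   Take A, observe R, S'
      ii.  Choose A' from S' using the policy  ← key: choose the next action before updating
      iii. Q(S, A) ← Q(S, A) + α[R + γ Q(S', A') - Q(S, A)]
      iv.  S ← S', A ← A'  ← key: carry the action forward
   d. Until S is terminal
```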

---

## Part 3: Implementation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from typing import Tuple, List, Dict, Any

np.random.seed(42)
plt.rcParams['figure.figsize'] = (12, 5)

print("Imports complete")

In [None]:
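class CliffWalkingEnv:
    """Cliff Walking grid world.

    Rewards: -1 per step, -100 for stepping into the cliff (the agent is sent
    back to the start and the episode continues), 0 for the step that reaches the goal.
    """
    # action index -> (row delta, col delta): 0 = up, 1 = right, 2 = down, 3 = left
    ACTIONS = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}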
    
    def __init__(self, height=4, width=12):
        self.height, self.width = height, width
        self.start = (height - 1, 0)
        self.goal = (height - 1, width - 1)
        self.cliff = [(height - 1, j) for j in range(1, width - 1)]
        self.state = self.start
        self.n_actions = 4
    
    def reset(self):
        self.state = self.start
        return self.state
    
    def step(self, action):
        di, dj = self.ACTIONS[action]
        new_i = int(np.clip(self.state[0] + di, 0, self.height - 1))
        new_j = int(np.clip(self.state[1] + dj, 0, self.width - 1))
        next_state = (new_i, new_j)
        
        if next_state in self.cliff:
            self.state = self.start
            return self.state, -100.0, False
        
        self.state = next_state
        if self.state == self.goal:
            return self.state, 0.0, True
        return self.state, -1.0, False
    
    def render(self, path=None):
        grid = [['.' for _ in range(self.width)] for _ in range(self.height)]
        for pos in self.cliff: grid[pos[0]][pos[1]] = 'C'
        grid[self.start[0]][self.start[1]] = 'S'
        grid[self.goal[0]][self.goal[1]] = 'G'
        if path:
            for pos in path[1:-1]:
                if pos not in self.cliff and pos != self.start and pos != self.goal:
                    grid[pos[0]][pos[1]] = '*'
        print("┌" + "─" * (self.width * 2 + 1) + "┐")
        for row in grid: print("│ " + " ".join(row) + " │")
        print("└" + "─" * (self.width * 2 + 1) + "┘")

In [None]:
env = CliffWalkingEnv()
print("Cliff Walking environment:")
env.render()
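
As a sanity check on the reward scheme, here is a minimal, self-contained sketch that replays a fixed list of moves under the same rules as `CliffWalkingEnv.step` on the default 4x12 grid. The helper `rollout` and the move names are hypothetical and exist only for this illustration.

```python
import numpy as np

HEIGHT, WIDTH = 4, 12
DELTAS = {"up": (-1, 0), "right": (0, 1), "down": (1, 0), "left": (0, -1)}
start, goal = (HEIGHT - 1, 0), (HEIGHT - 1, WIDTH - 1)
cliff = {(HEIGHT - 1, j) for j in range(1, WIDTH - 1)}

def rollout(moves):
    """Replay a fixed move list; return (final state, total reward)."""
    state, total = start, 0.0
    for move in moves:
        di, dj = DELTAS[move]
        nxt = (int(np.clip(state[0] + di, 0, HEIGHT - 1)),
               int(np.clip(state[1] + dj, 0, WIDTH - 1)))
        if nxt in cliff:          # fall: -100 and back to the start, episode continues
            state, total = start, total - 100.0
        elif nxt == goal:         # the goal step carries reward 0 and ends the episode
            return nxt, total
        else:                     # ordinary step
            state, total = nxt, total - 1.0
    return state, total

# Shortest route: up once, along the row above the cliff, then down into the goal.
edge_route = ["up"] + ["right"] * 11 + ["down"]
final_state, ret = rollout(edge_route)
print(len(edge_route), final_state, ret)
```

The shortest route takes 13 steps, but because the step into the goal is rewarded with 0 in this environment, its return is -12 rather than -13.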

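### 3.1 SARSA Agent

In [None]:
class SARSAAgent:
    """
    SARSA agent (on-policy TD control)
    
    Key differences from Q-Learning:
    - The update uses the next action A' that is actually taken
    - It learns the value function of the current policy, not of the optimal policy
    """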
    
    def __init__(self, n_actions, learning_rate=0.1, discount_factor=0.99,
                 epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01):
        self.n_actions = n_actions
        self.lr = learning_rate
        self.gamma = discount_factor
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.q_table = defaultdict(lambda: np.zeros(n_actions))

In [None]:
def get_action(self, state, training=True):
    """ε-greedy action selection; ties among greedy actions are broken at random."""
    if training and np.random.random() < self.epsilon:
        return np.random.randint(self.n_actions)
    q_values = self.q_table[state]
    return np.random.choice(np.where(np.isclose(q_values, np.max(q_values)))[0])

In [None]:
def update(self, state, action, reward, next_state, next_action, done):
    """
    SARSA update rule

    Q(S,A) ← Q(S,A) + α[R + γ Q(S',A') - Q(S,A)]

    Note: unlike Q-Learning, this needs the next_action argument.
    Returns the TD error.
    """
    current_q = self.q_table[state][action]
    # Terminal transitions have no successor value to bootstrap from.
    target = reward if done else reward + self.gamma * self.q_table[next_state][next_action]
    self.q_table[state][action] += self.lr * (target - current_q)
    return target - current_q

In [None]:
def decay_epsilon(self):
    self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

# Attach the functions defined in the cells above to the class as methods
SARSAAgent.get_action = get_action
SARSAAgent.update = update
SARSAAgent.decay_epsilon = decay_epsilon

### 3.2 Q-Learning Agent (for comparison)

In [None]:
class QLearningAgent:
    """Q-Learning agent (off-policy)"""
    
    def __init__(self, n_actions, learning_rate=0.1, discount_factor=0.99,
                 epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01):
        self.n_actions = n_actions
        self.lr = learning_rate
        self.gamma = discount_factor
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.q_table = defaultdict(lambda: np.zeros(n_actions))
    
    def get_action(self, state, training=True):
        if training and np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        q_values = self.q_table[state]
        return np.random.choice(np.where(np.isclose(q_values, np.max(q_values)))[0])
    
    def update(self, state, action, reward, next_state, done):
        """Q-Learning: bootstrap from the max over next actions, not the action actually taken"""
        current_q = self.q_table[state][action]
        target = reward if done else reward + self.gamma * np.max(self.q_table[next_state])
        self.q_table[state][action] += self.lr * (target - current_q)
    
    def decay_epsilon(self):
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
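
The agents above share the same multiplicative ε schedule. As a rough guide to how slow the default schedule is, the sketch below works out how many calls to `decay_epsilon()` it takes to reach `epsilon_min` with the default constructor values, once in closed form and once by replaying the decay. It is a standalone calculation, not part of the agents.

```python
import math

epsilon, decay, eps_min = 1.0, 0.995, 0.01   # default constructor values

# Closed form: smallest n with epsilon * decay**n <= eps_min.
n_closed = math.ceil(math.log(eps_min / epsilon) / math.log(decay))

# Same answer by replaying decay_epsilon() once per episode.
n_loop, eps = 0, epsilon
while eps > eps_min:
    eps = max(eps_min, eps * decay)
    n_loop += 1

# Exploration rate still in effect after a 500-episode run.
eps_after_500 = max(eps_min, epsilon * decay ** 500)

print(n_closed, n_loop, round(eps_after_500, 3))
```

With these defaults the floor is reached only after roughly 900 episodes, so a 500-episode run never finishes decaying (ε is still about 0.08 at the end). The comparison experiment in Part 4 sidesteps this by fixing ε.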

### 3.3 Training Functions

In [None]:
def train_q_learning(env, agent, episodes=500, max_steps=200, verbose=False):
    """Train Q-Learning (works for any agent whose update() takes no next_action)"""
    history = {'rewards': [], 'steps': []}
    
    for episode in range(episodes):
        state = env.reset()
        total_reward, steps = 0.0, 0
        
        for _ in range(max_steps):
            action = agent.get_action(state, training=True)
            next_state, reward, done = env.step(action)
            agent.update(state, action, reward, next_state, done)
            state, total_reward, steps = next_state, total_reward + reward, steps + 1
            if done: break
        
        agent.decay_epsilon()
        history['rewards'].append(total_reward)
        history['steps'].append(steps)
        
        if verbose and (episode + 1) % 100 == 0:
            print(f"Q-Learning Episode {episode+1}: Avg = {np.mean(history['rewards'][-100:]):.2f}")
    
    return history

In [None]:
def train_sarsa(env, agent, episodes=500, max_steps=200, verbose=False):
    """Train SARSA; note when each action is selected"""
    history = {'rewards': [], 'steps': []}
    
    for episode in range(episodes):
        state = env.reset()
        action = agent.get_action(state, training=True)  # key: choose the initial action first
        total_reward, steps = 0.0, 0
        
        for _ in range(max_steps):
            next_state, reward, done = env.step(action)
            next_action = agent.get_action(next_state, training=True)  # key: choose the next action before updating
            agent.update(state, action, reward, next_state, next_action, done)
            state, action = next_state, next_action  # key: carry the action forward
            total_reward += reward
            steps += 1
            if done: break
        
        agent.decay_epsilon()
        history['rewards'].append(total_reward)
        history['steps'].append(steps)
        
        if verbose and (episode + 1) % 100 == 0:
            print(f"SARSA Episode {episode+1}: Avg = {np.mean(history['rewards'][-100:]):.2f}")
    
    return history

---

## Part 4: Comparison Experiment

In [None]:
print("="*60)
print("Cliff Walking: Q-Learning vs SARSA")
print("="*60)

EPISODES = 500
LEARNING_RATE = 0.5
EPSILON = 0.1  # fixed exploration rate (no decay)

env = CliffWalkingEnv()

In [None]:
# Train Q-Learning
print("\nTraining Q-Learning...")
q_agent = QLearningAgent(n_actions=4, learning_rate=LEARNING_RATE, epsilon=EPSILON,
                          epsilon_decay=1.0, epsilon_min=EPSILON)
q_history = train_q_learning(env, q_agent, episodes=EPISODES, verbose=True)

In [None]:
# Train SARSA
print("\nTraining SARSA...")
sarsa_agent = SARSAAgent(n_actions=4, learning_rate=LEARNING_RATE, epsilon=EPSILON,
                          epsilon_decay=1.0, epsilon_min=EPSILON)
sarsa_history = train_sarsa(env, sarsa_agent, episodes=EPISODES, verbose=True)

### 4.1 Learning Curve Comparison

In [None]:
def plot_comparison(q_history, sarsa_history, window=10):
    """Plot smoothed reward and step curves for both agents"""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Reward curves (moving average over `window` episodes)
    q_smooth = np.convolve(q_history['rewards'], np.ones(window)/window, mode='valid')
    sarsa_smooth = np.convolve(sarsa_history['rewards'], np.ones(window)/window, mode='valid')
    
    axes[0].plot(q_smooth, label='Q-Learning', color='blue', alpha=0.8)
    axes[0].plot(sarsa_smooth, label='SARSA', color='red', alpha=0.8)
    # The shortest path takes 13 steps and the goal step is rewarded 0 here, so the best return is -12
    axes[0].axhline(y=-12, color='green', linestyle='--', alpha=0.5, label='Optimal (-12)')
    axes[0].set_xlabel('Episode')
    axes[0].set_ylabel('Total Reward')
    axes[0].set_title('Learning Curves')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Step curves
    q_steps = np.convolve(q_history['steps'], np.ones(window)/window, mode='valid')
    sarsa_steps = np.convolve(sarsa_history['steps'], np.ones(window)/window, mode='valid')
    
    axes[1].plot(q_steps, label='Q-Learning', color='blue', alpha=0.8)
    axes[1].plot(sarsa_steps, label='SARSA', color='red', alpha=0.8)
    axes[1].set_xlabel('Episode')
    axes[1].set_ylabel('Steps')
    axes[1].set_title('Steps per Episode')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
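
The smoothing in `plot_comparison` is a plain moving average. The toy example below, with made-up reward values, shows what `np.convolve(..., mode='valid')` returns and why the smoothed curve is shorter than the raw history.

```python
import numpy as np

rewards = np.array([-100.0, -50.0, -30.0, -20.0, -13.0])   # made-up episode returns
window = 3

# A uniform kernel turns the convolution into a moving average;
# mode='valid' keeps only positions where the window fits entirely.
smooth = np.convolve(rewards, np.ones(window) / window, mode="valid")

print(smooth)
print(len(rewards), "->", len(smooth))   # n -> n - window + 1
```

Point `k` of the smoothed curve averages episodes `k` through `k + window - 1`, so the x-axis of the plot is shifted slightly relative to the raw episode index.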

In [None]:
plot_comparison(q_history, sarsa_history)

print("\nStatistics over the last 100 episodes:")
print(f"Q-Learning: mean reward = {np.mean(q_history['rewards'][-100:]):.2f}")
print(f"SARSA:      mean reward = {np.mean(sarsa_history['rewards'][-100:]):.2f}")

### 4.2 Policy Path Comparison

In [None]:
def extract_path(agent, env, max_steps=50):
    """Roll out the greedy policy and return the visited states"""
    state = env.reset()
    path = [state]
    for _ in range(max_steps):
        action = agent.get_action(state, training=False)
        next_state, _, done = env.step(action)
        path.append(next_state)
        state = next_state
        if done: break
    return path

In [None]:
print("\n" + "="*60)
print("Learned policy paths")
print("="*60)

print("\nQ-Learning (tends toward the shortest path along the cliff edge):")
q_path = extract_path(q_agent, env)
env.reset()
env.render(q_path)
print(f"Path length: {len(q_path) - 1} steps")

In [None]:
print("\nSARSA (tends toward the safe path away from the cliff):")
sarsa_path = extract_path(sarsa_agent, env)
env.reset()
env.render(sarsa_path)
print(f"Path length: {len(sarsa_path) - 1} steps")

---

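## Part 5: Analyzing the Behavioral Difference

### 5.1 Why does Q-Learning hug the cliff edge?

The Q-Learning update uses $\max$, so it learns the **value of the optimal (greedy) policy**:

- It evaluates states as if the greedy policy were followed from then on, in which case the agent never falls off the cliff
- The path along the cliff edge is the shortest, so its return is the highest
- During training, however, ε-greedy exploration still sends the agent over the edge from time to time, which lowers its online reward

### 5.2 Why does SARSA choose the safe path?

SARSA uses the action actually taken, so it learns the **value of the current (ε-greedy) policy**:

- Its estimates account for the random actions that exploration injects
- Next to the cliff, a single exploratory action can cause a fall
- States next to the cliff therefore get lower values, and paths farther from the cliff look better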
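
A back-of-the-envelope calculation shows how large this exploration risk is. The sketch assumes the settings of the experiment above (ε = 0.1, 4 actions) and the default 4x12 grid, where 10 cells lie directly above the cliff. It is deliberately simplified: it assumes the greedy action in those cells is never "down", and it ignores the fact that other random moves also change the route.

```python
epsilon, n_actions = 0.1, 4   # values used in the comparison experiment
edge_cells = 10               # cells directly above the cliff on the default grid

# With probability epsilon the action is drawn uniformly, so "down"
# (into the cliff) is picked with probability epsilon / n_actions.
p_fall_per_step = epsilon / n_actions

# Probability of at least one fall while crossing all edge cells once.
p_fall_per_pass = 1 - (1 - p_fall_per_step) ** edge_cells

print(f"per step: {p_fall_per_step:.3f}")
print(f"per pass: {p_fall_per_pass:.3f}")
```

Under these assumptions, roughly one pass in five along the edge ends in a -100 penalty. SARSA's value estimates absorb that cost, while Q-Learning's do not.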

In [None]:
def visualize_values(q_agent, sarsa_agent, env):
    """Compare the state values V(s) = max_a Q(s, a) of both agents"""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    for idx, (agent, name) in enumerate([(q_agent, 'Q-Learning'), (sarsa_agent, 'SARSA')]):
        v_table = np.zeros((env.height, env.width))
        for i in range(env.height):
            for j in range(env.width):
                if (i, j) in agent.q_table:
                    v_table[i, j] = np.max(agent.q_table[(i, j)])
        
        im = axes[idx].imshow(v_table, cmap='RdYlGn', aspect='auto')
        axes[idx].set_title(f'{name} value function V(s)')
        plt.colorbar(im, ax=axes[idx])
        
        for pos in env.cliff:
            axes[idx].add_patch(plt.Rectangle((pos[1]-0.5, pos[0]-0.5), 1, 1,
                                               fill=True, color='black', alpha=0.5))
    
    plt.tight_layout()
    plt.show()

In [None]:
visualize_values(q_agent, sarsa_agent, env)

---

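## Part 6: Expected SARSA

### 6.1 How It Works

Expected SARSA replaces the sampled value $Q(S',A')$ in the SARSA target with its expectation under the current policy. The target then no longer depends on which $A'$ happens to be drawn, which removes that source of variance from the update.

$$Q(S,A) \leftarrow Q(S,A) + \alpha \left[ R + \gamma \mathbb{E}_\pi[Q(S',A')] - Q(S,A) \right]$$

For an ε-greedy policy: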

$$\mathbb{E}[Q(S',A')] = \frac{\epsilon}{|A|} \sum_a Q(S',a) + (1-\epsilon) \max_a Q(S',a)$$
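
As a quick worked check, the sketch below evaluates this expression for one hypothetical vector of next-state Q-values and confirms that it matches the probability-weighted sum over actions. The numbers are illustrative only.

```python
import numpy as np

epsilon = 0.2
q_next = np.array([1.0, 2.0, 0.5, 0.5])   # hypothetical Q(S', a)
n_actions = len(q_next)

# Form 1: the closed-form expression for an epsilon-greedy policy.
closed_form = epsilon / n_actions * q_next.sum() + (1 - epsilon) * q_next.max()

# Form 2: build the epsilon-greedy probabilities explicitly, then take the weighted sum.
probs = np.full(n_actions, epsilon / n_actions)
probs[np.argmax(q_next)] += 1 - epsilon
weighted_sum = probs @ q_next

print(closed_form, weighted_sum)
```

Both forms give 1.8 here: the greedy action carries probability 0.85 and each of the other three carries 0.05.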

In [None]:
class ExpectedSARSAAgent:
    """Expected SARSA agent"""
    
    def __init__(self, n_actions, learning_rate=0.1, discount_factor=0.99,
                 epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01):
        self.n_actions = n_actions
        self.lr = learning_rate
        self.gamma = discount_factor
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.q_table = defaultdict(lambda: np.zeros(n_actions))
    
    def get_action(self, state, training=True):
        if training and np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        q_values = self.q_table[state]
        return np.random.choice(np.where(np.isclose(q_values, np.max(q_values)))[0])
    
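    def _get_expected_q(self, state):
        """Expected Q-value of `state` under the ε-greedy policy"""
        q_values = self.q_table[state]
        probs = np.ones(self.n_actions) * self.epsilon / self.n_actions
        # On ties argmax puts the greedy mass on the first maximal action;
        # the expectation is unaffected because tied actions share the same Q-value.
        probs[np.argmax(q_values)] += 1 - self.epsilon
        return np.dot(probs, q_values)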
    
    def update(self, state, action, reward, next_state, done):
        current_q = self.q_table[state][action]
        target = reward if done else reward + self.gamma * self._get_expected_q(next_state)
        self.q_table[state][action] += self.lr * (target - current_q)
    
    def decay_epsilon(self):
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

In [None]:
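print("\nTraining Expected SARSA...")
exp_sarsa = ExpectedSARSAAgent(n_actions=4, learning_rate=LEARNING_RATE,
                                epsilon=EPSILON, epsilon_decay=1.0, epsilon_min=EPSILON)
# train_q_learning is reused because update() has the same signature;
# the progress lines it prints are therefore labeled "Q-Learning".
exp_sarsa_history = train_q_learning(env, exp_sarsa, episodes=EPISODES, verbose=True)

print("\nPath learned by Expected SARSA:")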
exp_sarsa_path = extract_path(exp_sarsa, env)
env.reset()
env.render(exp_sarsa_path)

---

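## Summary

### Algorithm Comparison

| Property | Q-Learning | SARSA | Expected SARSA |
|----------|------------|-------|----------------|
| Type | Off-policy | On-policy | On-policy (as used here) |
| Update target | $\max_a Q(S',a)$ | $Q(S',A')$ | $\mathbb{E}[Q(S',A')]$ |
| Variance of the update target | Low | High | Low |
| Behavior near risk | Aggressive | Conservative | Moderate |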
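
The variance row can be illustrated with a small simulation. The sketch below fixes one next state with hypothetical Q-values (one action is a fall into the cliff), samples $A'$ from the ε-greedy policy many times to form SARSA targets, and compares them with the single Expected SARSA target. The seed, sample size, and Q-values are arbitrary choices for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
epsilon, gamma, reward = 0.1, 0.99, -1.0
q_next = np.array([-12.0, -11.0, -100.0, -13.0])   # hypothetical Q(S', a) next to the cliff

# Epsilon-greedy action probabilities in S'.
probs = np.full(len(q_next), epsilon / len(q_next))
probs[np.argmax(q_next)] += 1 - epsilon

# SARSA: the target depends on the sampled A', so it varies from visit to visit.
sampled_actions = rng.choice(len(q_next), size=10_000, p=probs)
sarsa_targets = reward + gamma * q_next[sampled_actions]

# Expected SARSA: the target averages over A', so it is one fixed number for this S'.
expected_target = reward + gamma * (probs @ q_next)

print(f"SARSA targets: mean = {sarsa_targets.mean():.2f}, var = {sarsa_targets.var():.2f}")
print(f"Expected SARSA target: {expected_target:.2f} (no variance from A')")
```

The SARSA targets average out to the Expected SARSA target, but individual samples scatter widely whenever an exploratory action lands on the bad one.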

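### Recommendations

- **Q-Learning**: when the goal is the optimal greedy policy and mistakes during training are acceptable
- **SARSA**: when exploration itself must be safe (e.g. robot control)
- **Expected SARSA**: a middle ground, with SARSA-style targets but lower variance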

---

## Unit Tests

In [None]:
def run_tests():
    print("Running unit tests...\n")
    passed = 0
    
    # Test 1: SARSA update
    agent = SARSAAgent(n_actions=4, learning_rate=0.5, discount_factor=0.9)
    agent.q_table[(0,1)] = np.array([1.0, 2.0, 0.0, 0.0])
    agent.update((0,0), 0, -1.0, (0,1), 1, False)
    expected = 0.5 * (-1 + 0.9 * 2.0)  # 0.4
    assert np.isclose(agent.q_table[(0,0)][0], expected), "SARSA update is wrong"
    print("✓ Test 1: SARSA update is correct")
    passed += 1
    
    # Test 2: Expected SARSA expectation
    agent = ExpectedSARSAAgent(n_actions=4, epsilon=0.2)
    agent.q_table[(0,0)] = np.array([1.0, 2.0, 0.5, 0.5])
    expected_q = agent._get_expected_q((0,0))
    manual = 0.05*1.0 + 0.85*2.0 + 0.05*0.5 + 0.05*0.5  # 1.8
    assert np.isclose(expected_q, manual), f"Expectation is wrong: {expected_q}"
    print("✓ Test 2: Expected SARSA expectation is correct")
    passed += 1
    
    print(f"\nAll {passed} tests passed!")

run_tests()

---

## References

1. Rummery & Niranjan (1994). On-Line Q-Learning Using Connectionist Systems.
2. Sutton & Barto (2018). Reinforcement Learning: An Introduction, Chapter 6.
3. Van Seijen et al. (2009). A Theoretical and Empirical Analysis of Expected Sarsa.