# Q-Learning 高级技巧与实战

---

## 学习目标

通过本教程，你将学会：
- 理解 Q-Learning 过估计问题及 Double Q-Learning 解决方案
- 掌握基于访问次数的学习率调度等高级技巧
- 使用 Gymnasium 标准环境进行训练
- 在 Taxi-v3 环境中实现完整训练流程
- 模型保存、加载与评估

## 前置知识

- Q-Learning 和 SARSA 基础
- Python 面向对象编程
- NumPy 和 Matplotlib

## 预计时间

50-70 分钟

---

## 第1部分：Q-Learning 的过估计问题

### 1.1 过估计现象

Q-Learning 更新使用 $\max$ 操作：

$$Q(S,A) \leftarrow Q(S,A) + \alpha[R + \gamma \max_a Q(S',a) - Q(S,A)]$$

**问题**：$\max$ 操作会系统性地高估 Q 值

**原因分析**：
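- 假设所有 Q 值估计都含有零均值噪声：$\hat{Q}(s,a) = Q^*(s,a) + \epsilon_a$，其中 $\mathbb{E}[\epsilon_a] = 0$
- 由于 $\max$ 是凸函数，由 Jensen 不等式得 $\mathbb{E}[\max_a \hat{Q}(s,a)] \geq \max_a \mathbb{E}[\hat{Q}(s,a)] = \max_a Q^*(s,a)$：单个估计是无偏的，但取最大值会偏向那些噪声恰好为正的动作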
- 这种偏差会通过 bootstrapping 传播和累积

### 1.2 Double Q-Learning

**核心思想**：解耦动作选择和价值评估

维护两个 Q 表 $Q_1$ 和 $Q_2$，每一步随机（各以 0.5 的概率）选择其中一个进行更新：

- 更新 $Q_1$：用 $Q_1$ 选择动作，用 $Q_2$ 评估
  $$Q_1(S,A) \leftarrow Q_1(S,A) + \alpha[R + \gamma Q_2(S', \arg\max_a Q_1(S',a)) - Q_1(S,A)]$$

- 更新 $Q_2$：用 $Q_2$ 选择动作，用 $Q_1$ 评估
  $$Q_2(S,A) \leftarrow Q_2(S,A) + \alpha[R + \gamma Q_1(S', \arg\max_a Q_2(S',a)) - Q_2(S,A)]$$

---

## 第2部分：代码实现

### 步骤1: 导入库

In [None]:
# ============================================================
# 导入必要的库
# ============================================================

import numpy as np
from collections import defaultdict
from typing import Tuple, List, Dict, Any, Optional
from dataclasses import dataclass, field
import json
import pickle
from pathlib import Path
import matplotlib.pyplot as plt

# Gymnasium is optional: environment-dependent cells check HAS_GYM before using `gym`.
try:
    import gymnasium as gym
except ImportError:
    HAS_GYM = False
    print("请安装 gymnasium: pip install gymnasium")
else:
    HAS_GYM = True
    print(f"Gymnasium 版本: {gym.__version__}")

# Fix NumPy's global RNG so agent exploration is reproducible across runs.
np.random.seed(42)

# Notebook-wide default figure size and font size.
plt.rcParams['figure.figsize'] = (12, 5)
plt.rcParams['font.size'] = 11

print("\n库导入完成")

### 步骤2: 实现 Double Q-Learning

In [None]:
@dataclass
class TrainingMetrics:
    """Per-episode statistics collected while training an agent.

    Attributes:
        episode_rewards: summed reward of each episode.
        episode_lengths: number of environment steps taken in each episode.
        epsilon_history: exploration rate logged once per episode.
    """
    episode_rewards: List[float] = field(default_factory=list)
    episode_lengths: List[int] = field(default_factory=list)
    epsilon_history: List[float] = field(default_factory=list)

    def get_moving_average(self, window: int = 100) -> np.ndarray:
        """Return the ``window``-episode moving average of the rewards.

        With fewer than ``window`` recorded episodes no full window exists, so
        the raw rewards are returned (as an array) instead.
        """
        rewards = self.episode_rewards
        if len(rewards) >= window:
            kernel = np.full(window, 1.0 / window)
            return np.convolve(rewards, kernel, mode='valid')
        return np.array(rewards)


class DoubleQLearningAgent:
    """
    Tabular Double Q-Learning agent.

    Two independent Q tables are kept. Each update picks one of them at random
    (probability 0.5); that table chooses the greedy next action while the
    other one supplies its value, decoupling selection from evaluation to
    reduce the overestimation caused by a single max operator.

    Attributes:
        q_table1, q_table2: state -> array of ``n_actions`` action values;
            unseen states start as all zeros.
        lr: learning rate (step size).
        gamma: discount factor.
        epsilon: current exploration rate of the epsilon-greedy policy.
    """

    def __init__(
        self,
        n_actions: int,
        learning_rate: float = 0.1,
        discount_factor: float = 0.99,
        epsilon: float = 1.0,
        epsilon_decay: float = 0.995,
        epsilon_min: float = 0.01
    ):
        self.n_actions = n_actions
        self.lr = learning_rate
        self.gamma = discount_factor
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        def zero_row() -> np.ndarray:
            # A fresh all-zero value row for a state seen for the first time.
            return np.zeros(n_actions)

        self.q_table1: Dict[Any, np.ndarray] = defaultdict(zero_row)
        self.q_table2: Dict[Any, np.ndarray] = defaultdict(zero_row)

        self.metrics = TrainingMetrics()

    def get_action(self, state: Any, training: bool = True) -> int:
        """Choose an action epsilon-greedily with respect to Q1 + Q2.

        When ``training`` is False no exploration happens; ties among the
        greedy actions are still broken uniformly at random.
        """
        explore = training and np.random.random() < self.epsilon
        if explore:
            return np.random.randint(self.n_actions)

        scores = self.q_table1[state] + self.q_table2[state]
        tied_best = np.flatnonzero(np.isclose(scores, scores.max()))
        return np.random.choice(tied_best)

    def _learn(self, online, evaluator, state, action, reward, next_state, done) -> float:
        """Move ``online[state][action]`` towards a target where ``online``
        selects the next action and ``evaluator`` values it; return the TD error."""
        estimate = online[state][action]
        if done:
            target = reward
        else:
            greedy = np.argmax(online[next_state])
            target = reward + self.gamma * evaluator[next_state][greedy]
        delta = target - estimate
        online[state][action] += self.lr * delta
        return delta

    def update(
        self,
        state: Any,
        action: int,
        reward: float,
        next_state: Any,
        done: bool
    ) -> float:
        """
        Apply one Double Q-Learning update and return its TD error.

        A fair coin decides which table learns from this transition.
        """
        if np.random.random() < 0.5:
            return self._learn(self.q_table1, self.q_table2,
                               state, action, reward, next_state, done)
        return self._learn(self.q_table2, self.q_table1,
                           state, action, reward, next_state, done)

    def decay_epsilon(self) -> None:
        """Multiply epsilon by its decay factor, never going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def get_q_values(self, state: Any) -> np.ndarray:
        """Return the element-wise mean of the two tables' rows for ``state``."""
        return (self.q_table1[state] + self.q_table2[state]) / 2


print("Double Q-Learning 智能体定义完成")

### 步骤3: 实现标准 Q-Learning (对比用)

In [None]:
class QLearningAgent:
    """Standard tabular Q-Learning agent with epsilon-greedy exploration.

    Attributes:
        q_table: state -> array of ``n_actions`` action values (all zeros for
            states not seen yet).
        lr: learning rate (step size).
        gamma: discount factor.
        epsilon, epsilon_decay, epsilon_min: exploration schedule.
    """

    def __init__(
        self,
        n_actions: int,
        learning_rate: float = 0.1,
        discount_factor: float = 0.99,
        epsilon: float = 1.0,
        epsilon_decay: float = 0.995,
        epsilon_min: float = 0.01
    ):
        self.n_actions = n_actions
        self.lr = learning_rate
        self.gamma = discount_factor
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.q_table: Dict[Any, np.ndarray] = defaultdict(
            lambda: np.zeros(n_actions)
        )
        self.metrics = TrainingMetrics()

    def get_action(self, state: Any, training: bool = True) -> int:
        """Epsilon-greedy action (no exploration when ``training`` is False).

        Ties between equally valued greedy actions are broken at random.
        """
        if training and np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        q_values = self.q_table[state]
        max_q = np.max(q_values)
        max_actions = np.where(np.isclose(q_values, max_q))[0]
        return np.random.choice(max_actions)

    def update(self, state, action, reward, next_state, done) -> float:
        """One Q-Learning step towards ``reward + gamma * max_a Q(next_state, a)``.

        When ``done`` is true the target is just ``reward`` (no bootstrap).
        Returns the TD error measured before the update.
        """
        current_q = self.q_table[state][action]
        if done:
            target = reward
        else:
            target = reward + self.gamma * np.max(self.q_table[next_state])
        td_error = target - current_q
        self.q_table[state][action] += self.lr * td_error
        return td_error

    def decay_epsilon(self) -> None:
        """Multiply epsilon by its decay factor, never going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    @staticmethod
    def _serializable_key(key: Any) -> Any:
        """Convert NumPy scalars (also nested in tuples) to plain Python values
        so that ``str(key)`` is a literal that ``ast.literal_eval`` can parse."""
        if isinstance(key, np.generic):
            return key.item()
        if isinstance(key, tuple):
            return tuple(QLearningAgent._serializable_key(k) for k in key)
        return key

    def save(self, filepath: str) -> None:
        """Save the Q table, epsilon and basic config as JSON.

        State keys are stored as ``str(key)``; ints, strings and tuples of
        those round-trip through :meth:`load`. Note that a string state that
        looks like a literal (e.g. ``"42"``) comes back as that literal.
        """
        data = {
            'q_table': {
                str(self._serializable_key(k)): v.tolist()
                for k, v in self.q_table.items()
            },
            'epsilon': self.epsilon,
            'config': {
                'n_actions': self.n_actions,
                'lr': self.lr,
                'gamma': self.gamma
            }
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)
        print(f"模型已保存到: {filepath}")

    def load(self, filepath: str) -> None:
        """Load a model written by :meth:`save`, replacing the current Q table.

        Keys are parsed with ``ast.literal_eval`` (never ``eval``), so a
        crafted model file cannot execute code; keys that are not hashable
        Python literals are kept as plain strings. ``epsilon`` falls back to
        0.01 when the file does not contain it.
        """
        import ast  # only needed when loading

        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.q_table = defaultdict(lambda: np.zeros(self.n_actions))
        for k, v in data['q_table'].items():
            try:
                key = ast.literal_eval(k)
                hash(key)  # the parsed value must be usable as a dict key
            except (ValueError, SyntaxError, TypeError):
                key = k
            self.q_table[key] = np.array(v)

        self.epsilon = data.get('epsilon', 0.01)
        print(f"模型已从 {filepath} 加载")


print("标准 Q-Learning 智能体定义完成")

---

## 第3部分：Gymnasium 环境实战

### 3.1 Taxi-v3 环境介绍

Taxi-v3 是 Gymnasium 内置的经典强化学习环境：

```
+---------+
|R: | : :G|    R, G, Y, B: 乘客位置/目的地
| : | : : |    |: 墙壁
| : : : : |    黄色方块: 出租车位置
| | : | : |
|Y| : |B: |
+---------+

状态空间: 500 种状态
  - 出租车位置: 25 (5x5)
  - 乘客位置: 5 (R,G,Y,B 或在车上)
  - 目的地: 4 (R,G,Y,B)

动作空间: 6 种动作
  - 0: 向南移动
  - 1: 向北移动
  - 2: 向东移动
  - 3: 向西移动
  - 4: 接乘客
  - 5: 放乘客

奖励设计:
  - 每步: -1
  - 成功送达: +20
  - 非法接/放: -10
```

In [None]:
# Explore the Taxi environment: space sizes, a seeded initial state and a text
# rendering of the map. Skipped when gymnasium is not installed.
if not HAS_GYM:
    print("跳过 Taxi 环境演示 (需要 gymnasium)")
else:
    # render_mode='ansi' makes env.render() return a printable string.
    env = gym.make('Taxi-v3', render_mode='ansi')

    print("Taxi-v3 环境信息:")
    print(f"  状态空间大小: {env.observation_space.n}")
    print(f"  动作空间大小: {env.action_space.n}")

    # Seeded reset so the printed start configuration is reproducible.
    state, info = env.reset(seed=42)
    print(f"\n初始状态: {state}")
    print(f"\n环境渲染:")
    print(env.render())

    # Human-readable labels for action indices 0..5.
    action_names = ['南', '北', '东', '西', '接乘客', '放乘客']
    print("\n动作说明:")
    for i, name in enumerate(action_names):
        print(f"  {i}: {name}")

    env.close()

### 3.2 训练函数

In [None]:
def train_agent(
    env,
    agent,
    episodes: int = 2000,
    max_steps: int = 200,
    verbose: bool = True,
    log_interval: int = 200
) -> TrainingMetrics:
    """
    Generic training loop for a tabular agent.

    Accepts Gymnasium environments (``reset`` -> ``(obs, info)``, ``step`` ->
    5-tuple), legacy Gym environments (``step`` -> 4-tuple) and simple custom
    environments (``reset`` -> obs, ``step`` -> ``(obs, reward, done)``).

    Only a genuine terminal transition is passed to ``agent.update`` as
    ``done``; a time-limit truncation ends the episode but still bootstraps
    from the next state, because the task itself has not ended there.

    Args:
        env: environment exposing ``reset()`` and ``step(action)``.
        agent: object with ``get_action``, ``update``, ``decay_epsilon``, an
            ``epsilon`` attribute and a writable ``metrics`` attribute.
        episodes: number of training episodes.
        max_steps: per-episode step cap enforced by this loop.
        verbose: print a progress line every ``log_interval`` episodes.
        log_interval: logging period in episodes; values <= 0 disable logging.

    Returns:
        The collected ``TrainingMetrics`` (also stored on ``agent.metrics``).
    """
    metrics = TrainingMetrics()

    for episode in range(episodes):
        # Gymnasium returns (obs, info); checking for the info dict keeps
        # tuple-valued observations from custom envs intact.
        result = env.reset()
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            state = result[0]
        else:
            state = result

        total_reward = 0.0
        steps = 0

        for _ in range(max_steps):
            action = agent.get_action(state, training=True)

            result = env.step(action)
            if len(result) == 5:
                next_state, reward, terminated, truncated, _ = result
            elif len(result) == 4:
                # Legacy Gym API: (obs, reward, done, info); a time-limit cut
                # is flagged in info["TimeLimit.truncated"].
                next_state, reward, done, info = result
                truncated = bool(done and isinstance(info, dict)
                                 and info.get("TimeLimit.truncated", False))
                terminated = bool(done) and not truncated
            else:
                next_state, reward, terminated = result
                truncated = False

            # Pass `terminated`, not `terminated or truncated`: a truncated
            # transition must still bootstrap from next_state.
            agent.update(state, action, reward, next_state, terminated)

            state = next_state
            total_reward += reward
            steps += 1

            if terminated or truncated:
                break

        agent.decay_epsilon()

        metrics.episode_rewards.append(total_reward)
        metrics.episode_lengths.append(steps)
        metrics.epsilon_history.append(agent.epsilon)

        if verbose and log_interval > 0 and (episode + 1) % log_interval == 0:
            avg_reward = np.mean(metrics.episode_rewards[-log_interval:])
            avg_steps = np.mean(metrics.episode_lengths[-log_interval:])
            print(f"Episode {episode + 1:4d} | "
                  f"Avg Reward: {avg_reward:8.2f} | "
                  f"Avg Steps: {avg_steps:6.1f} | "
                  f"ε: {agent.epsilon:.4f}")

    agent.metrics = metrics
    return metrics


print("训练函数定义完成")

### 3.3 在 Taxi 环境训练

In [None]:
# Train a standard Q-Learning agent on Taxi-v3 and report a short summary.
if not HAS_GYM:
    print("跳过 Taxi 训练 (需要 gymnasium)")
else:
    print("=" * 60)
    print("Taxi-v3 Q-Learning 训练")
    print("=" * 60)

    # No render mode is needed while training.
    env = gym.make('Taxi-v3')
    agent = QLearningAgent(
        n_actions=env.action_space.n,
        learning_rate=0.1,
        discount_factor=0.99,
        epsilon=1.0,
        epsilon_decay=0.995,
        epsilon_min=0.01,
    )

    metrics = train_agent(env, agent, episodes=2000, verbose=True)
    env.close()

    print(f"\n训练完成！")
    print(f"最后100回合平均奖励: {np.mean(metrics.episode_rewards[-100:]):.2f}")
    print(f"Q表大小: {len(agent.q_table)} 状态")

### 3.4 可视化训练过程

In [None]:
def plot_training_metrics(metrics: TrainingMetrics, title: str = "训练曲线", window: int = 50):
    """
    Plot per-episode reward, episode length and epsilon side by side.

    The reward and length panels show the raw series faintly plus a moving
    average. The averaging window is clamped to the number of recorded
    episodes, so short runs (fewer than ``window`` episodes) still plot
    instead of raising, and an empty history produces empty axes.

    Args:
        metrics: training history to plot.
        title: figure title.
        window: moving-average window in episodes; values <= 0 disable smoothing.
    """
    _, axes = plt.subplots(1, 3, figsize=(15, 4))

    def _plot_with_moving_average(ax, values, color):
        # Draw the raw series faintly, then its moving average on top.
        ax.plot(values, alpha=0.3, color=color)
        w = min(window, len(values))
        if w > 0:
            smoothed = np.convolve(values, np.ones(w) / w, mode='valid')
            # 'valid' entry i averages episodes i .. i+w-1; place it at the last one.
            ax.plot(range(w - 1, len(values)), smoothed, color=color, linewidth=2)

    # Reward curve
    _plot_with_moving_average(axes[0], metrics.episode_rewards, 'blue')
    axes[0].set_xlabel('Episode')
    axes[0].set_ylabel('Total Reward')
    axes[0].set_title('回合奖励')
    axes[0].grid(True, alpha=0.3)

    # Episode length curve
    _plot_with_moving_average(axes[1], metrics.episode_lengths, 'green')
    axes[1].set_xlabel('Episode')
    axes[1].set_ylabel('Steps')
    axes[1].set_title('回合步数')
    axes[1].grid(True, alpha=0.3)

    # Exploration-rate decay
    axes[2].plot(metrics.epsilon_history, color='red')
    axes[2].set_xlabel('Episode')
    axes[2].set_ylabel('Epsilon')
    axes[2].set_title('探索率衰减')
    axes[2].grid(True, alpha=0.3)

    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Plot the curves of the Taxi-v3 run; `metrics` comes from the training cell above.
if HAS_GYM:
    plot_training_metrics(metrics, "Taxi-v3 Q-Learning 训练曲线")

### 3.5 评估训练好的智能体

In [None]:
def evaluate_agent(
    env,
    agent,
    episodes: int = 100,
    render: bool = False,
    max_steps: int = 200,
    success_reward: float = 20
) -> Dict[str, float]:
    """
    Evaluate an agent with its greedy policy (``training=False``).

    An episode counts as a success when it ends in a terminal state (not a
    truncation) whose final reward equals ``success_reward``. The default of
    20 is Taxi-v3's reward for delivering the passenger.

    Args:
        env: environment exposing ``reset()`` and ``step(action)``. Gymnasium
            5-tuple, legacy Gym 4-tuple and custom 3-tuple step results are
            accepted.
        agent: object with ``get_action(state, training=False)``.
        episodes: number of evaluation episodes; must be positive.
        render: print ``env.render()`` after every step of the first 3
            episodes (the env must be created with ``render_mode='ansi'``).
        max_steps: per-episode step cap enforced by this function.
        success_reward: final reward that marks a successful episode.

    Returns:
        Dict with ``mean_reward``, ``std_reward``, ``mean_steps`` and
        ``success_rate`` (a percentage in 0-100).

    Raises:
        ValueError: if ``episodes`` is not positive.
    """
    if episodes <= 0:
        raise ValueError(f"episodes must be positive, got {episodes}")

    rewards = []
    steps = []
    successes = 0

    for ep in range(episodes):
        # Gymnasium returns (obs, info); checking for the info dict keeps
        # tuple-valued observations from custom envs intact.
        result = env.reset()
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            state = result[0]
        else:
            state = result

        total_reward = 0
        ep_steps = 0

        while True:
            action = agent.get_action(state, training=False)  # greedy policy

            result = env.step(action)
            if len(result) == 5:
                next_state, reward, terminated, truncated, _ = result
            elif len(result) == 4:
                # Legacy Gym API: (obs, reward, done, info).
                next_state, reward, done, info = result
                truncated = bool(done and isinstance(info, dict)
                                 and info.get("TimeLimit.truncated", False))
                terminated = bool(done) and not truncated
            else:
                next_state, reward, terminated = result
                truncated = False

            total_reward += reward
            ep_steps += 1
            state = next_state

            if render and ep < 3:
                print(env.render())

            if terminated or truncated:
                # A truncated episode never counts as a success.
                if terminated and reward == success_reward:
                    successes += 1
                break

            if ep_steps >= max_steps:
                break

        rewards.append(total_reward)
        steps.append(ep_steps)

    return {
        'mean_reward': float(np.mean(rewards)),
        'std_reward': float(np.std(rewards)),
        'mean_steps': float(np.mean(steps)),
        'success_rate': successes / episodes * 100
    }


# Evaluate the agent trained above with its greedy policy over 100 episodes
# on a fresh Taxi-v3 env, then print a summary.
if HAS_GYM:
    print("\n" + "="*50)
    print("评估训练好的智能体")
    print("="*50)
    
    env = gym.make('Taxi-v3')
    results = evaluate_agent(env, agent, episodes=100)
    
    print(f"\n评估结果 (100回合):")
    print(f"  平均奖励: {results['mean_reward']:.2f} ± {results['std_reward']:.2f}")
    print(f"  平均步数: {results['mean_steps']:.1f}")
    print(f"  成功率: {results['success_rate']:.1f}%")
    
    env.close()

### 3.6 演示训练好的智能体

In [None]:
def demo_agent(env, agent, episodes: int = 2, max_steps: int = 20):
    """
    Print a step-by-step rollout of the agent's greedy policy.

    Intended for Taxi-v3 created with ``render_mode='ansi'`` (``env.render()``
    must return something printable). Action indices are labelled with
    Taxi's six action names; any other index is printed as a number.

    Args:
        env: environment exposing ``reset``, ``step`` and ``render``.
        agent: object with ``get_action(state, training=False)``.
        episodes: number of episodes to show.
        max_steps: maximum number of steps shown per episode. A note is printed
            when an episode is cut off by this limit instead of ending normally.
    """
    action_names = ['南', '北', '东', '西', '接', '放']

    for ep in range(episodes):
        print(f"\n{'='*40}")
        print(f"演示回合 {ep + 1}")
        print(f"{'='*40}")

        # Gymnasium returns (obs, info); plain envs return obs only.
        result = env.reset()
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            state = result[0]
        else:
            state = result

        total_reward = 0
        steps = 0
        done = False

        print(f"\n初始状态:")
        print(env.render())

        while steps < max_steps:
            action = agent.get_action(state, training=False)

            result = env.step(action)
            if len(result) == 5:
                next_state, reward, terminated, truncated, _ = result
                done = terminated or truncated
            elif len(result) == 4:  # legacy Gym: (obs, reward, done, info)
                next_state, reward, done, _ = result
            else:
                next_state, reward, done = result

            total_reward += reward
            steps += 1

            label = action_names[action] if 0 <= action < len(action_names) else str(action)
            print(f"\n步骤 {steps}: 动作={label}, 奖励={reward}")
            print(env.render())

            state = next_state

            if done:
                print(f"\n回合结束！总奖励: {total_reward}, 步数: {steps}")
                break

        if not done:
            # The display cap, not the environment, ended this episode.
            print(f"\n已达到显示步数上限 ({max_steps})，总奖励: {total_reward}")


# Show one greedy episode; render_mode='ansi' makes env.render() return text.
if HAS_GYM:
    env = gym.make('Taxi-v3', render_mode='ansi')
    demo_agent(env, agent, episodes=1)
    env.close()

---

## 第4部分：高级技巧

### 4.1 学习率调度

固定学习率可能导致：
- 太大：Q值震荡，不稳定
- 太小：收敛过慢

**解决方案**：基于访问次数的衰减学习率

In [None]:
class AdaptiveLRQLearning(QLearningAgent):
    """Q-Learning whose step size shrinks with each state-action visit count.

    ``update`` ignores the fixed ``lr`` inherited from ``QLearningAgent`` and
    uses ``1 / (1 + N(s, a))`` instead.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # N(s, a): how many times each state-action pair has been updated.
        self.visit_count: Dict[Any, np.ndarray] = defaultdict(
            lambda: np.zeros(self.n_actions)
        )

    def get_learning_rate(self, state, action) -> float:
        """
        Return the step size ``1 / (1 + N(s, a))`` for the current count.

        ``update`` increments N before calling this, so the first update of a
        pair uses 1/2, the second 1/3, and so on.
        """
        return 1.0 / (1.0 + self.visit_count[state][action])

    def update(self, state, action, reward, next_state, done) -> float:
        """Q-Learning update with the visit-count step size; returns the TD error."""
        counts = self.visit_count[state]
        counts[action] += 1
        step_size = self.get_learning_rate(state, action)

        row = self.q_table[state]
        old_value = row[action]
        # Terminal transitions do not bootstrap from next_state.
        if done:
            target = reward
        else:
            target = reward + self.gamma * np.max(self.q_table[next_state])

        delta = target - old_value
        row[action] += step_size * delta
        return delta


print("自适应学习率 Q-Learning 定义完成")

### 4.2 Double Q-Learning 对比实验

In [None]:
# Train standard Q-Learning and Double Q-Learning with the same hyperparameters
# on Taxi-v3 and compare their smoothed episode-reward curves.
# NOTE(review): each agent gets a single unseeded run (global NumPy RNG,
# env.reset() without a seed), so the gap between the curves is noisy;
# average several seeds before drawing conclusions.
if HAS_GYM:
    print("="*60)
    print("Q-Learning vs Double Q-Learning 对比实验")
    print("="*60)
    
    # One env instance is reused for both agents; train_agent resets it at
    # the start of every episode.
    env = gym.make('Taxi-v3')
    
    # Standard Q-Learning
    print("\n训练标准 Q-Learning...")
    q_agent = QLearningAgent(
        n_actions=env.action_space.n,
        learning_rate=0.1,
        epsilon=1.0,
        epsilon_decay=0.995,
        epsilon_min=0.01
    )
    q_metrics = train_agent(env, q_agent, episodes=1000, verbose=False)
    
    # Double Q-Learning (same hyperparameters)
    print("训练 Double Q-Learning...")
    double_q_agent = DoubleQLearningAgent(
        n_actions=env.action_space.n,
        learning_rate=0.1,
        epsilon=1.0,
        epsilon_decay=0.995,
        epsilon_min=0.01
    )
    double_q_metrics = train_agent(env, double_q_agent, episodes=1000, verbose=False)
    
    env.close()
    
    # Plot the comparison
    fig, ax = plt.subplots(figsize=(10, 5))
    
    # 50-episode moving averages. NOTE(review): plotted from x=0, so each point
    # sits 49 episodes left of the last episode it averages.
    window = 50
    q_smooth = np.convolve(q_metrics.episode_rewards, np.ones(window)/window, mode='valid')
    double_smooth = np.convolve(double_q_metrics.episode_rewards, np.ones(window)/window, mode='valid')
    
    ax.plot(q_smooth, label='Q-Learning', alpha=0.8)
    ax.plot(double_smooth, label='Double Q-Learning', alpha=0.8)
    ax.set_xlabel('Episode')
    ax.set_ylabel('Total Reward')
    ax.set_title('Q-Learning vs Double Q-Learning on Taxi-v3')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()
    
    # Mean training reward over the final 100 episodes (these still include
    # epsilon-greedy exploration).
    print(f"\n最后100回合平均奖励:")
    print(f"  Q-Learning: {np.mean(q_metrics.episode_rewards[-100:]):.2f}")
    print(f"  Double Q-Learning: {np.mean(double_q_metrics.episode_rewards[-100:]):.2f}")

---

## 第5部分：模型保存与加载

In [None]:
if HAS_GYM:
    import os
    import tempfile

    # Save into a private temporary directory. A fixed file name in the current
    # working directory could overwrite (and then delete) a user's own file,
    # and would be left behind if loading or evaluation failed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "taxi_q_learning_model.json")
        agent.save(model_path)

        # Rebuild an agent with the same action count and restore the saved table.
        new_agent = QLearningAgent(n_actions=agent.n_actions)
        new_agent.load(model_path)

        # Sanity-check the restored model; close the env even if evaluation fails.
        env = gym.make('Taxi-v3')
        try:
            results = evaluate_agent(env, new_agent, episodes=50)
        finally:
            env.close()
        print(f"\n加载模型后的评估:")
        print(f"  平均奖励: {results['mean_reward']:.2f}")
        print(f"  成功率: {results['success_rate']:.1f}%")

    # Leaving the `with` block removed the directory and the model file in it.
    print(f"\n已清理临时模型文件")

---

## 总结

### 核心要点

1. **过估计问题**：Q-Learning 的 max 操作导致系统性高估
2. **Double Q-Learning**：通过解耦选择和评估减少偏差
3. **学习率调度**：基于访问次数自适应调整
4. **Gymnasium**：标准化的 RL 环境接口

### 超参数调优建议

| 参数 | 建议范围 | 说明 |
|------|----------|------|
| 学习率 | 0.05-0.5 | 表格型可用较大值 |
| 折扣因子 | 0.95-0.99 | 任务越长期越接近1 |
| 初始探索率 | 1.0 | 从完全探索开始 |
| 最终探索率 | 0.01-0.1 | 保持少量探索 |
| 衰减率 | 0.99-0.999 | 控制探索下降速度 |

### 表格型方法局限

- 状态空间必须离散且有限
- 无法处理高维/连续状态
- 无法泛化到未见状态

**解决方案**：深度 Q 网络 (DQN) - 用神经网络近似 Q 函数

---

## 单元测试

In [None]:
def run_tests():
    """
    Run the notebook's self-checks and print a pass/fail summary.

    Each check runs inside its own ``try``. Any exception, not only a failed
    ``assert``, is reported as a failure, and the remaining checks still run.
    The global NumPy RNG state is restored after the seeded check. The
    save/load check writes into a temporary directory that is always removed.

    Returns:
        True if no check failed (a skipped check does not count as a failure).
    """
    import os
    import tempfile

    print("开始单元测试...\n")
    passed = 0
    failed = 0

    # Test 1: Double Q-Learning updates both tables.
    try:
        agent = DoubleQLearningAgent(n_actions=4, learning_rate=0.5)
        state = (0, 0)
        next_state = (0, 1)

        # Seed for reproducibility, then restore the caller's RNG state so this
        # check does not silently re-seed later notebook cells.
        rng_state = np.random.get_state()
        np.random.seed(42)
        try:
            for _ in range(10):
                agent.update(state, 0, -1.0, next_state, False)
        finally:
            np.random.set_state(rng_state)

        # With seed 42 the first two draws (~0.37, ~0.95) pick Q1 and then Q2.
        # The reward is -1 and next-state values stay 0, so each updated entry
        # ends up strictly negative.
        assert agent.q_table1[state][0] != 0, "Q1 未被更新"
        assert agent.q_table2[state][0] != 0, "Q2 未被更新"
        print("测试1通过: Double Q-Learning 更新正确")
        passed += 1
    except Exception as e:
        print(f"测试1失败: {e}")
        failed += 1

    # Test 2: the adaptive learning rate starts at 1 and decays after a visit.
    try:
        agent = AdaptiveLRQLearning(n_actions=4)
        state = (0, 0)

        lr1 = agent.get_learning_rate(state, 0)
        assert np.isclose(lr1, 1.0), f"初始学习率错误: {lr1}"

        agent.update(state, 0, -1.0, (0, 1), False)
        lr2 = agent.get_learning_rate(state, 0)
        assert lr2 < lr1, "学习率应该衰减"

        print("测试2通过: 自适应学习率正确")
        passed += 1
    except Exception as e:
        print(f"测试2失败: {e}")
        failed += 1

    # Test 3: a short interaction loop with a real Gymnasium env.
    if HAS_GYM:
        try:
            env = gym.make('Taxi-v3')
            agent = QLearningAgent(n_actions=env.action_space.n)

            state, _ = env.reset()
            for _ in range(10):
                action = agent.get_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                agent.update(state, action, reward, next_state, terminated or truncated)
                state = next_state
                if terminated or truncated:
                    break

            env.close()
            print("测试3通过: Gymnasium 兼容性正确")
            passed += 1
        except Exception as e:
            print(f"测试3失败: {e}")
            failed += 1
    else:
        print("测试3跳过: 需要 gymnasium")

    # Test 4: save/load round trip through a temporary directory.
    try:
        agent = QLearningAgent(n_actions=4)
        state = (0, 0)
        agent.q_table[state] = np.array([1.0, 2.0, 3.0, 4.0])

        # The directory (and the file in it) is removed even if save/load fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            test_path = os.path.join(tmp_dir, "_test_model.json")
            agent.save(test_path)

            new_agent = QLearningAgent(n_actions=4)
            new_agent.load(test_path)

        assert np.allclose(new_agent.q_table[state], agent.q_table[state])

        print("测试4通过: 模型保存/加载正确")
        passed += 1
    except Exception as e:
        print(f"测试4失败: {e}")
        failed += 1

    print(f"\n{'='*50}")
    print(f"测试完成: {passed} 通过, {failed} 失败")
    if failed == 0:
        print("所有测试通过！")
    print(f"{'='*50}")

    return failed == 0


run_tests()

---

## 参考资料

1. van Hasselt, H. (2010). Double Q-learning. Advances in Neural Information Processing Systems 23 (NIPS 2010).
2. Sutton, R.S. & Barto, A.G. (2018). Reinforcement Learning: An Introduction, 2nd ed.
3. [Gymnasium Documentation](https://gymnasium.farama.org/)
4. [OpenAI Spinning Up](https://spinningup.openai.com/)

---

[返回目录](../README.md)