In [1]:
import gym
import numpy as np
import random


def choose_action(state, Q, epsilon, action_space):
    """
    使用 ε-贪婪策略选择动作。
    """
    if random.uniform(0, 1) < epsilon:
        return action_space.sample()  # 探索
    else:
        return np.argmax(Q[state, :])  # 利用


def sarsa(env, num_episodes=1000, alpha=0.1, gamma=0.99, epsilon=0.1):
    """
    SARSA 算法实现。
    """
    # 初始化 Q 表
    Q = np.zeros((env.observation_space.n, env.action_space.n))

    for episode in range(num_episodes):
        # 重置环境并设置种子
        state, info = env.reset(seed=episode)  # 为每个回合设置不同的种子以确保多样性
        action = choose_action(state, Q, epsilon, env.action_space)
        done = False
        step = 0

        while not done:
            next_state, reward, done, truncated, info = env.step(action)

            if done or truncated:
                # 回合结束，只使用即时奖励更新 Q 值
                Q[state, action] += alpha * (reward - Q[state, action])
            else:
                # 选择下一个动作
                next_action = choose_action(next_state, Q, epsilon, env.action_space)
                # 更新 Q 值
                Q[state, action] += alpha * (reward + gamma * Q[next_state, next_action] - Q[state, action])
                # 更新 state 和 action
                state = next_state
                action = next_action

            step += 1

        # 每100个回合输出一次信息
        if (episode + 1) % 100 == 0:
            print(f"回合数: {episode + 1}, 步数: {step}")

    return Q


def evaluate_policy(env, Q, num_episodes=100):
    """
    使用学习到的 Q 表评估策略。
    """
    total_rewards = 0
    for episode in range(num_episodes):
        state, info = env.reset(seed=episode)  # 为评估设置种子
        done = False
        while not done:
            action = np.argmax(Q[state, :])
            next_state, reward, done, truncated, info = env.step(action)
            total_rewards += reward
            state = next_state
            if done or truncated:
                break
    average_reward = total_rewards / num_episodes
    print(f"平均奖励: {average_reward}")


def main():
    # 创建环境
    env = gym.make('Taxi-v3')

    # 设置随机种子以确保结果可重复
    np.random.seed(0)
    random.seed(0)
    # Gym 的种子现在在 reset 时设置，因此不需要在这里设置

    # 运行 SARSA 算法
    Q = sarsa(env, num_episodes=10000, alpha=0.1, gamma=0.99, epsilon=0.1)

    # 评估学习到的策略
    evaluate_policy(env, Q, num_episodes=100)

    # 关闭环境
    env.close()


if __name__ == "__main__":
    main()


  if not isinstance(terminated, (bool, np.bool8)):


KeyboardInterrupt: 