In [None]:
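# Install the required libraries (run in a terminal):
# pip install gym-super-mario-bros stable-baselines3 nes-py torch "numpy<2"
# NumPy is pinned below 2.0 because nes-py is known to fail under NumPy 2.x while
# loading the ROM, with "OverflowError: Python integer 1024 out of bounds for uint8".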

import gym
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.env_util import make_vec_env

# 1. Environment factory
def make_env():
    """Build a single Super Mario Bros environment with a reduced action set.

    The base environment is 'SuperMarioBros-v3', created through gym.make with
    apply_api_compatibility=True and render_mode='rgb_array'. It is then wrapped
    in JoypadSpace so the agent picks from the SIMPLE_MOVEMENT button
    combinations.

    Returns:
        The JoypadSpace-wrapped environment.

    NOTE(review): no observation preprocessing (grayscale / resize) is applied
    here, so the policy sees the emulator's full-resolution RGB frames. Combined
    with 4-frame stacking and a large PPO rollout, this may need a lot of RAM.
    Consider adding preprocessing if training runs out of memory.
    """
    base_env = gym.make(
        'SuperMarioBros-v3',
        apply_api_compatibility=True,
        render_mode='rgb_array',
    )
    # Restrict the controller to the SIMPLE_MOVEMENT button combinations.
    return JoypadSpace(base_env, SIMPLE_MOVEMENT)

# 2. Training settings
n_envs = 4                   # number of environment copies stepped together
total_timesteps = 1_000_000  # total environment steps to train for
checkpoint_freq = 50_000     # desired spacing between saved checkpoints, in steps
save_path = "./mario_ppo"    # directory for checkpoints and the final model

# 3. Vectorized training environment: n_envs instances built by make_env,
#    stepped in-process by DummyVecEnv, with each observation stacked from the
#    4 most recent frames.
env = VecFrameStack(
    make_vec_env(make_env, n_envs=n_envs, vec_env_cls=DummyVecEnv),
    n_stack=4,
)

# 4. PPO agent with a convolutional policy over the stacked frames.
#    NOTE(review): the rollout holds n_steps * n_envs full-resolution stacked
#    observations between updates. Without downscaling in make_env this may
#    need many GB of RAM, so confirm it fits on the target machine.
model = PPO(
    policy="CnnPolicy",
    env=env,
    learning_rate=1e-4,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    tensorboard_log="./mario_tensorboard/",
    verbose=1,
)

# 5. Periodic checkpoints. checkpoint_freq is divided by n_envs, presumably
#    because the callback counts vectorized steps (each advancing all n_envs
#    environments at once). The result is clamped to at least 1.
checkpoint_callback = CheckpointCallback(
    save_path=save_path,
    name_prefix="mario_ppo",
    save_freq=max(1, checkpoint_freq // n_envs),
)

# 6. Train for total_timesteps. The checkpoint callback writes intermediate
#    models, and TensorBoard logs are recorded under the run name "first_run".
model.learn(
    total_timesteps,
    tb_log_name="first_run",
    callback=checkpoint_callback,
)

# 7. Save the final weights in the same directory as the checkpoints.
model.save(f"{save_path}/mario_ppo_final")

# 8. Run the trained agent.
# The evaluation env is built through the same make_vec_env + VecFrameStack path
# as training, so observations pass through the same wrappers and 4-frame stack.
# make_env() already applies JoypadSpace(SIMPLE_MOVEMENT). Wrapping it a second
# time would turn an action index into a raw button bitmask (e.g. 128 for
# "right"). The inner JoypadSpace would then look that bitmask up in its
# index-keyed action map, so it must be applied only once.
test_env = make_vec_env(make_env, n_envs=1, vec_env_cls=DummyVecEnv)
test_env = VecFrameStack(test_env, n_stack=4)

try:
    obs = test_env.reset()
    for _ in range(10000):
        action, _state = model.predict(obs)
        # No manual reset on `dones`: per the stable-baselines3 VecEnv docs,
        # finished sub-environments are reset automatically inside step().
        obs, rewards, dones, info = test_env.step(action)
        # NOTE(review): the envs use render_mode='rgb_array', so this call may
        # only return frames rather than open a window. Confirm against the
        # installed stable-baselines3 version.
        test_env.render()
finally:
    # Release the emulator even if the rollout is interrupted.
    test_env.close()

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


  logger.warn(


OverflowError: Python integer 1024 out of bounds for uint8