In [1]:
!pip install gymnasium stable-baselines3 moviepy pyvirtualdisplay


Collecting gymnasium
  Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting stable-baselines3
  Downloading stable_baselines3-2.3.2-py3-none-any.whl (182 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m182.3/182.3 kB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
Collecting pyvirtualdisplay
  Downloading PyVirtualDisplay-3.0-py3-none-any.whl (15 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.13->stable-baselines3)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.13->stable-baselines3)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-

In [2]:
import numpy as np
import gymnasium as gym
from stable_baselines3 import PPO
from gymnasium.wrappers import RecordVideo
import os
from IPython.display import Video

# 创建 CartPole 环境
env = gym.make('CartPole-v1')

# 创建 PPO 模型
model = PPO('MlpPolicy', env, verbose=1)

# 定义一个函数来计算回报
def compute_returns(rewards, gamma):
    returns = []
    G = 0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.insert(0, G)
    return returns

# 采样多个回合并计算平均回报基线
n_episodes = 10
all_returns = []

for _ in range(n_episodes):
    obs, _ = env.reset()
    done = False
    rewards = []
    while not done:
        action, _ = model.predict(obs)
        obs, reward, terminated, truncated, _ = env.step(action)
        rewards.append(reward)
        done = terminated or truncated

    returns = compute_returns(rewards, gamma=0.99)
    total_return = sum(rewards)
    all_returns.append(total_return)

baseline = np.mean(all_returns)
print("Baseline (average return):", baseline)

# 训练模型
model.learn(total_timesteps=100000)

# 保存模型
model.save("ppo_cartpole")

# 加载模型
model = PPO.load("ppo_cartpole")

# 设置视频存储文件夹
video_folder = 'recorded_videos'
os.makedirs(video_folder, exist_ok=True)

# 创建新环境并包装以录制视频
eval_env = gym.make('CartPole-v1', render_mode='rgb_array')
eval_env = RecordVideo(eval_env, video_folder=video_folder, name_prefix='ppo_cartpole')

# 重置环境并开始评估
obs, _ = eval_env.reset()
total_reward = 0
for i in range(1000):
    action, _ = model.predict(obs)
    obs, reward, terminated, truncated, _ = eval_env.step(action)
    total_reward += reward
    done = terminated or truncated
    if done:
        obs, _ = eval_env.reset()
        print("Total reward for this episode:", total_reward)
        total_reward = 0

eval_env.close()

# 列出生成的视频文件
video_file = [f for f in os.listdir(video_folder) if f.endswith('.mp4')][0]
video_path = os.path.join(video_folder, video_file)
print("Video path:", video_path)

# 显示视频
Video(video_path)


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Baseline (average return): 21.8
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 19.8     |
|    ep_rew_mean     | 19.8     |
| time/              |          |
|    fps             | 727      |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 28.9        |
|    ep_rew_mean          | 28.9        |
| time/                   |             |
|    fps                  | 536         |
|    iterations           | 2           |
|    time_elapsed         | 7           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.010089325 |
|    clip_fraction        | 0.125       |
|    clip_range           | 0

  logger.warn(


Moviepy - Building video /content/recorded_videos/ppo_cartpole-episode-0.mp4.
Moviepy - Writing video /content/recorded_videos/ppo_cartpole-episode-0.mp4





Moviepy - Done !
Moviepy - video ready /content/recorded_videos/ppo_cartpole-episode-0.mp4
Total reward for this episode: 500.0
Moviepy - Building video /content/recorded_videos/ppo_cartpole-episode-1.mp4.
Moviepy - Writing video /content/recorded_videos/ppo_cartpole-episode-1.mp4





Moviepy - Done !
Moviepy - video ready /content/recorded_videos/ppo_cartpole-episode-1.mp4
Total reward for this episode: 500.0
Video path: recorded_videos/ppo_cartpole-episode-1.mp4
