In [1]:
import gymnasium as gym
import gymnasium as gym
import ale_py
gym.register_envs(ale_py)
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
import numpy as np



# 1. Custom environment wrapper that reports metrics through `info`
class CustomCartPole(gym.Wrapper):
    """Wrapper that copies per-step data into the `info` dict.

    After every step the wrapped env's `info` gains a ``"reward"`` entry
    holding that step's reward, and an ``"episode_done"`` entry (value 1)
    on the step where the episode terminates or is truncated.
    """

    def __init__(self, env):
        super().__init__(env)

    def step(self, action):
        """Step the wrapped env and annotate `info` before returning it."""
        observation, reward, terminated, truncated, info = self.env.step(action)

        # Expose the raw step reward so callbacks can read it from `info`.
        info["reward"] = reward

        # Flag the final step of an episode, whichever way it ended.
        episode_over = terminated or truncated
        if episode_over:
            info["episode_done"] = 1

        return observation, reward, terminated, truncated, info

# 2. Custom callbacks; note the keyword-only parameter signatures
class CustomMetricsCallback(DefaultCallbacks):
    """Callbacks that turn per-step `info` entries into custom metrics.

    Per step, the ``"reward"`` (and, when present, ``"episode_done"``)
    entries of the latest `info` dict are stored as temporary timestep
    data on the episode. At episode end they are summed and logged through
    `metrics_logger`; after each training iteration the logged values are
    mirrored into ``result["custom_metrics"]``.
    """

    # Metric keys logged in `on_episode_end` and mirrored by
    # `on_train_result` (kept in the original insertion order).
    _EXPORTED_METRIC_KEYS = (
        "custom_episode_return_mean",
        "custom_num_episodes",
        "custom_episode_return_max",
        "custom_episode_return_min",
    )

    def on_episode_step(
        self,
        *,
        episode,
        env_runner=None,
        metrics_logger=None,
        env=None,
        env_index,
        rl_module=None,
        worker=None,
        base_env=None,
        policies=None,
        **kwargs,
    ) -> None:
        """Record the latest step's reward (and done flag) on the episode."""
        # Fetch the most recent info dict.
        info = episode.get_infos(-1)
        episode.add_temporary_timestep_data("reward", info["reward"])

        if 'episode_done' in info:
            episode.add_temporary_timestep_data("episode", 1)

    def on_episode_start(self, *, episode, metrics_logger, **kwargs):
        """Print the temporary timestep data present when an episode starts."""
        rewards = episode.get_temporary_timestep_data("reward")
        episodes = episode.get_temporary_timestep_data("episode")
        print(f'on_episode_start: {rewards}')
        print(f'on_episode_start: {episodes}')

    def on_episode_end(self, *, episode, metrics_logger, **kwargs):
        """Sum the recorded per-step data and log it as custom metrics."""
        rewards = episode.get_temporary_timestep_data("reward")
        episodes = episode.get_temporary_timestep_data("episode")

        # Compute each sum once instead of once per log call.
        episode_return = np.sum(rewards)
        done_count = np.sum(episodes)

        # NOTE(review): no `reduce`/`window` is passed for the "mean" key, so
        # the logger's defaults apply -- confirm that smoothing is intended.
        metrics_logger.log_value(
            "custom_episode_return_mean", episode_return
        )

        metrics_logger.log_value(
            "custom_episode_return_max", episode_return, reduce='max'
        )

        metrics_logger.log_value(
            "custom_episode_return_min", episode_return, reduce='min'
        )

        # Sanity check: exactly one "episode_done" flag per finished episode.
        assert done_count == 1, f'episodes: {done_count}'
        metrics_logger.log_value(
            "custom_num_episodes", done_count, reduce='sum'
        )

    def on_train_result(
        self, *, algorithm, result, metrics_logger, **kwargs
    ):
        """Mirror the custom env-runner metrics into ``result["custom_metrics"]``.

        Only keys actually present under ``result["env_runners"]`` are
        copied, so an iteration in which the custom metrics were not reported
        no longer raises ``KeyError``.
        """
        custom_metrics = result.setdefault("custom_metrics", {})
        env_runner_results = result.get("env_runners", {})
        for key in self._EXPORTED_METRIC_KEYS:
            if key in env_runner_results:
                custom_metrics[key] = env_runner_results[key]

# 3. Configuration
def env_creator(env_config):
    """Return a fresh "CartPole-v1" env wrapped in `CustomCartPole`.

    `env_config` is accepted for the creator signature but not used.
    """
    return CustomCartPole(gym.make("CartPole-v1"))

from ray.tune.registry import register_env

# Make the wrapped environment available under the name "custom_cartpole".
register_env("custom_cartpole", env_creator)

# Build the PPO config step by step: enable the RLModule/Learner and
# EnvRunner/ConnectorV2 stack, select the env, and attach the callbacks.
config = PPOConfig()
config = config.api_stack(
    enable_rl_module_and_learner=True,
    enable_env_runner_and_connector_v2=True,
)
config = config.environment("custom_cartpole")
config = config.callbacks(CustomMetricsCallback)

# 4. Train and inspect the result
from ray.rllib.algorithms.ppo import PPO

algo = PPO(config=config)

# Run two training iterations; only the second iteration's result is kept.
algo.train()
result = algo.train()
result

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2025-03-18 10:39:25,279	INFO worker.py:1821 -- Started a local Ray instance.
2025-03-18 10:39:37,061	INFO trainable.py:161 -- Trainable.setup took 16.238 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


(SingleAgentEnvRunner pid=25020) on_episode_start: []
(SingleAgentEnvRunner pid=25020) on_episode_start: []
(SingleAgentEnvRunner pid=25020) on_episode_start: []
(SingleAgentEnvRunner pid=25020) on_episode_start: []
(SingleAgentEnvRunner pid=25020) on_episode_start: []
(SingleAgentEnvRunner pid=25020) on_episode_start: []
(SingleAgentEnvRunner pid=25020) on_episode_start: [] [repeated 360x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)


{'timers': {'training_iteration': 25.57689875600394,
  'restore_workers': 2.288194140419364e-05,
  'training_step': 25.57650632005534,
  'env_runner_sampling_timer': 3.7668462620675562,
  'learner_update_timer': 21.794673148978035,
  'synch_weights': 0.012533215032890439,
  'synch_env_connectors': 0.013212673020316288},
 'env_runners': {'num_agent_steps_sampled': {'default_agent': 4000},
  'custom_episode_return_mean': 19.55495634473576,
  'num_env_steps_sampled': 4000,
  'episode_len_min': 10,
  'agent_episode_returns_mean': {'default_agent': 36.56},
  'weights_seq_no': 1.0,
  'episode_len_mean': 36.56,
  'episode_return_max': 108.0,
  'custom_episode_return_min': 7.0,
  'episode_return_mean': 36.56,
  'num_module_steps_sampled_lifetime': {'default_policy': 8000},
  'num_episodes': 110,
  'num_agent_steps_sampled_lifetime': {'default_agent': 8000},
  'module_episode_returns_mean': {'default_policy': 36.56},
  'episode_return_min': 10.0,
  'num_module_steps_sampled': {'default_policy':