In [1]:
import gymnasium as gym
import gymnasium as gym
import ale_py
gym.register_envs(ale_py)
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
import numpy as np



# 1. Custom environment wrapper that reports extra metrics through `info`.
class CustomCartPole(gym.Wrapper):
    """Wrap an env and add two custom entries to every step's ``info`` dict.

    Added keys:
      * ``"my_custom_metric"`` -- a float drawn from NumPy's global RNG via
        ``np.random.random()`` (a placeholder metric for demonstration).
      * ``"cart_position"`` -- ``obs[0]`` of the new observation (the cart
        position for CartPole-v1 observations).
    """

    def step(self, action):
        """Step the wrapped env and return its 5-tuple with ``info`` augmented.

        The ``info`` dict returned by the inner env is mutated in place.
        """
        observation, reward, terminated, truncated, info = self.env.step(action)
        # Keyword order preserves the original key insertion order.
        info.update(
            my_custom_metric=np.random.random(),
            cart_position=observation[0],
        )
        return observation, reward, terminated, truncated, info

# 2. Custom callbacks. RLlib calls these with keyword-only arguments, so the
#    parameter signatures below must keep the `*` marker.
class CustomMetricsCallback(DefaultCallbacks):
    """Forward the custom ``info`` entries from ``CustomCartPole`` to RLlib.

    * ``my_custom_metric`` is logged as a mean over a 100-value window.
    * ``cart_position`` is logged unreduced (``reduce=None``), i.e. as a raw
      list of values. The list is cleared on every reduce so it does not grow
      without bound across training iterations.

    After each training iteration, ``on_train_result`` summarizes the raw
    ``cart_position`` list into ``result["custom_metrics"]``.
    """

    def on_episode_step(
        self,
        *,
        episode,
        env_runner=None,
        metrics_logger=None,
        env=None,
        env_index,
        rl_module=None,
        worker=None,
        base_env=None,
        policies=None,
        **kwargs,
    ) -> None:
        """Log the latest step's custom ``info`` entries to ``metrics_logger``.

        Only ``episode`` and ``metrics_logger`` are used. The remaining
        parameters exist only to match the signature RLlib invokes. Nothing is
        logged when no ``metrics_logger`` is supplied or the last ``info`` is
        missing or empty.
        """
        if metrics_logger is None:
            return

        # Info dict from the most recent env.step() of this episode.
        info = episode.get_infos(-1)
        if not info:
            return

        if "my_custom_metric" in info:
            metrics_logger.log_value(
                "my_custom_metric",
                info["my_custom_metric"],
                reduce="mean",
                window=100,
            )
        if "cart_position" in info:
            # reduce=None keeps every raw value in a list. Without
            # clear_on_reduce=True that list would keep accumulating values
            # from all past iterations (an unbounded memory leak).
            metrics_logger.log_value(
                "cart_position",
                info["cart_position"],
                reduce=None,
                clear_on_reduce=True,
            )

    def on_train_result(
        self, *, algorithm, result, metrics_logger=None, **kwargs
    ):
        """Summarize the raw ``cart_position`` list into ``custom_metrics``.

        Writes ``cart_position_max`` (largest logged position) and
        ``cart_position_len`` (number of logged positions) into
        ``result["custom_metrics"]``, creating that dict if needed. Because
        the list is cleared on reduce, these summarize the recently logged
        positions rather than the full training history.

        Skips the summary when no positions were logged. The original code
        raised ``KeyError`` or ``ValueError`` in that case.
        """
        positions = result.get("env_runners", {}).get("cart_position")
        if positions is None or len(positions) == 0:
            return

        custom_metrics = result.setdefault("custom_metrics", {})
        custom_metrics["cart_position_max"] = max(positions)
        custom_metrics["cart_position_len"] = len(positions)

# 3. Environment factory, registered so RLlib can build the env by name.
def env_creator(env_config):
    """Return a ``CartPole-v1`` env wrapped in ``CustomCartPole``.

    ``env_config`` is accepted to satisfy RLlib's env-creator signature but is
    not used.
    """
    return CustomCartPole(gym.make("CartPole-v1"))

from ray.tune.registry import register_env

register_env("custom_cartpole", env_creator)

# PPO on RLlib's new API stack (RLModule/Learner + EnvRunner/ConnectorV2),
# using the env registered above and the custom metrics callbacks.
config = PPOConfig()
config = config.api_stack(
    enable_rl_module_and_learner=True,
    enable_env_runner_and_connector_v2=True,
)
config = config.environment("custom_cartpole")
config = config.callbacks(CustomMetricsCallback)

# 4. Train and inspect the result.
from ray.rllib.algorithms.ppo import PPO

# Build the PPO algorithm from `config` defined above.
# NOTE(review): `algo.stop()` is never called in this cell; call it once
# training is finished to release the Ray resources the algorithm holds.
algo = PPO(config=config)

# Run a single training iteration and keep its result dict.
result = algo.train()
# Bare expression so the notebook displays the result dict.
result

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2025-03-04 08:49:46,542	INFO worker.py:1821 -- Started a local Ray instance.
2025-03-04 08:49:58,940	INFO trainable.py:161 -- Trainable.setup took 16.921 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


{'timers': {'training_iteration': 21.68888509995304,
  'restore_workers': 1.8700025975704193e-05,
  'training_step': 21.688522999873385,
  'env_runner_sampling_timer': 3.65871710004285,
  'learner_update_timer': 18.014326999895275,
  'synch_weights': 0.012986700050532818},
 'env_runners': {'episode_len_max': 91,
  'num_module_steps_sampled_lifetime': {'default_policy': 4000},
  'num_env_steps_sampled': 4000,
  'module_episode_returns_mean': {'default_policy': 23.15},
  'num_env_steps_sampled_lifetime': 4000,
  'num_episodes': 178,
  'weights_seq_no': 0.0,
  'episode_len_min': 9,
  'num_agent_steps_sampled': {'default_agent': 4000},
  'num_agent_steps_sampled_lifetime': {'default_agent': 4000},
  'cart_position': [0.029517662,
   -0.024645122,
   0.033599343,
   -0.020088945,
   0.04158529,
   -0.011633695,
   0.05347695,
   -0.0070822784,
   0.0614726,
   -0.006431169,
   0.065577544,
   -0.0018736608,
   0.06579553,
   0.006590425,
   0.062128518,
   0.01115819,
   0.054576535,
   0.0