In [113]:
from pathlib import Path
import torch
from ray.rllib.models.torch.torch_distributions import TorchDiagGaussian
from ray.tune.result import TRAINING_ITERATION
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.core import (
    COMPONENT_LEARNER,
    COMPONENT_LEARNER_GROUP,
    COMPONENT_RL_MODULE,
)
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_RETURN_MEAN,
    NUM_ENV_STEPS_SAMPLED_LIFETIME,
)
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    check,
    run_rllib_example_script_experiment,
)
from ray.tune.registry import get_trainable_cls, register_env

from pettingzoo.sisl import waterworld_v4

from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.tune.registry import get_trainable_cls, register_env
import os

import gymnasium as gym
import numpy as np
import tree  # pip install dm_tree

from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy, softmax
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_RETURN_MEAN,
)

torch, _ = try_import_torch()

In [None]:
num-predators=2
num-prey=2

In [None]:
# 设置参数
import sys
sys.argv = [
    'notebook_script.py',
    '--enable-new-api-stack',
    '--num-agents=4',
    # 新增参数用于指定predator和prey的数量
    '--n-predators=2',
    '--n-preys=2', 
    '--checkpoint-at-end',
    '--stop-reward=200.0',
    '--checkpoint-freq=1',
]

In [87]:
parser = add_rllib_example_script_args(
    default_iters=2,
    default_timesteps=10000,
    default_reward=0.0,
)

In [88]:
# 添加新的参数解析
parser.add_argument(
    "--n-predators",
    type=int,
    default=2,
    help="Number of predator agents"
)
parser.add_argument(
    "--n-preys", 
    type=int,
    default=2,
    help="Number of prey agents"
)
parser.add_argument(
    "--use-onnx-for-inference",
    action="store_true",
    help="Whether to convert the loaded module to ONNX format and then perform "
    "inference through this ONNX model.",
)
parser.add_argument(
    "--explore-during-inference",
    action="store_true",
    help="Whether the trained policy should use exploration during action "
    "inference.",
)
parser.add_argument(
    "--num-episodes-during-inference",
    type=int,
    default=10,
    help="Number of episodes to do inference over (after restoring from a checkpoint).",
)

_StoreAction(option_strings=['--num-episodes-during-inference'], dest='num_episodes_during_inference', nargs=None, const=None, default=10, type=<class 'int'>, choices=None, required=False, help='Number of episodes to do inference over (after restoring from a checkpoint).', metavar=None)

In [94]:

args = parser.parse_args()

# 验证参数
assert args.n_predators > 0, "Must set --n-predators > 0 when running this script!"
assert args.n_preys > 0, "Must set --n-preys > 0 when running this script!"
assert (
    args.enable_new_api_stack
), "Must set --enable-new-api-stack when running this script!"

# 计算总智能体数量
total_agents = args.n_predators + args.n_preys
print(f"参数解析完成: n_predators={args.n_predators}, n_preys={args.n_preys}, total_agents={total_agents}, algo={args.algo}")


参数解析完成: n_predators=2, n_preys=2, total_agents=4, algo=PPO


In [95]:

# 修改环境注册，传递predator和prey的数量
register_env("env", lambda _: PettingZooEnv(
    waterworld_v4.env(
        n_predators=args.n_predators,
        n_preys=args.n_preys  # 注意：这里应该是n_preys而不是n_prey
    )
))


In [97]:

# 创建新的policies字典，匹配环境中的agent命名
predator_policies = [f"predator_{i}" for i in range(args.n_predators)]
prey_policies = [f"prey_{i}" for i in range(args.n_preys)]
all_policies = predator_policies + prey_policies
print(all_policies)

['predator_0', 'predator_1', 'prey_0', 'prey_1']


In [106]:
# 创建RL module specs字典
rl_module_specs = {p: RLModuleSpec() for p in all_policies}
# print(f"创建的policies: {list(policies.keys())}")
# print(f"创建的RL module specs: {list(rl_module_specs.keys())}")

In [99]:
base_config = (
    get_trainable_cls(args.algo)
    .get_default_config()
    .environment("env")
    .multi_agent(
        # 在新API中，只需要指定policy_mapping_fn
        policies=set(all_policies),
        policy_mapping_fn=(lambda aid, *args, **kwargs: aid),
    )
    .training(
        vf_loss_coeff=0.005,
    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            rl_module_specs=rl_module_specs,
        ),
        model_config=DefaultModelConfig(vf_share_layers=True),
    )
)

# 训练
print("开始训练...")
results = run_rllib_example_script_experiment(base_config, args, keep_ray_up=True)
print("训练完成")

2025-06-26 07:26:13,297	INFO worker.py:1747 -- Calling ray.init() again after it has already been called.


开始训练...
== Status ==
Current time: 2025-06-26 07:26:13 (running for 00:00:00.11)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/24 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-06-25_23-20-00_551135_808463/artifacts/2025-06-26_07-26-13/PPO_2025-06-26_07-26-13/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-06-26 07:26:18 (running for 00:00:05.14)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/24 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-06-25_23-20-00_551135_808463/artifacts/2025-06-26_07-26-13/PPO_2025-06-26_07-26-13/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




Trial name,env_runner_group,env_runners,fault_tolerance,learners,num_env_steps_sampled_lifetime,num_training_step_calls_per_iteration,perf,timers
PPO_env_66eb0_00000,{'actor_manager_num_outstanding_async_reqs': 0},"{'env_reset_timer': np.float64(0.0023604979796800762), 'connector_pipeline_timer': np.float64(0.0003246039996156469), 'agent_steps': {'prey_1': 500.0, 'predator_1': 500.0, 'predator_0': 500.0, 'prey_0': 500.0}, 'num_module_steps_sampled_lifetime': {'prey_0': 2004.0, 'predator_0': 2000.0, 'prey_1': 2004.0, 'predator_1': 2004.0}, 'env_to_module_connector': {'connector_pipeline_timer': np.float64(0.00011837546502323462), 'timers': {'connectors': {'agent_to_module_mapping': np.float64(3.4250023006150638e-06), 'batch_individual_items': np.float64(1.3240884139739838e-05), 'add_observations_from_episodes_to_batch': np.float64(1.9349793506166415e-05), 'add_time_dim_to_batch_and_zero_pad': np.float64(9.745407055214614e-06), 'numpy_to_tensor': np.float64(2.3869756415713063e-05), 'add_states_from_episodes_to_batch': np.float64(6.313448366311252e-06)}}}, 'module_to_env_connector': {'timers': {'connectors': {'module_to_agent_unmapping': np.float64(2.8402581422446355e-06), 'listify_data_for_vector_env': np.float64(6.109693616853156e-06), 'normalize_and_clip_actions': np.float64(3.654940639490682e-05), 'un_batch_to_individual_items': np.float64(1.3381484580179156e-05), 'remove_single_ts_time_rank_from_batch': np.float64(1.0666992495726612e-06), 'get_actions': np.float64(8.932109222115932e-05), 'tensor_to_numpy': np.float64(3.264590290546003e-05)}}, 'connector_pipeline_timer': np.float64(0.00023805947165909827)}, 'episode_return_mean': -443.24515081505814, 'episode_len_min': 2000, 'agent_episode_returns_mean': {'predator_1': -67.81485630391575, 'prey_0': -73.71830380375121, 'predator_0': -92.7488216211424, 'prey_1': -208.9631690862488}, 'num_env_steps_sampled': 4000.0, 'module_episode_returns_mean': {'predator_0': -92.7488216211424, 'prey_0': -73.71830380375121, 'prey_1': -208.9631690862488, 'predator_1': -67.81485630391575}, 'episode_duration_sec_mean': 2.1081632410059683, 'num_agent_steps_sampled': {'prey_0': 1002.0, 'predator_0': 1000.0, 'prey_1': 1002.0, 'predator_1': 1002.0}, 'num_agent_steps_sampled_lifetime': {'prey_0': 2004.0, 'predator_0': 2000.0, 'prey_1': 2004.0, 'predator_1': 2004.0}, 'timers': {'connectors': {'batch_individual_items': np.float64(2.1828003809787333e-05), 'agent_to_module_mapping': np.float64(5.202004103921354e-06), 'add_observations_from_episodes_to_batch': np.float64(2.8144015232101083e-05), 'add_time_dim_to_batch_and_zero_pad': np.float64(1.6893507563509044e-05), 'numpy_to_tensor': np.float64(5.312749999575317e-05), 'add_states_from_episodes_to_batch': np.float64(7.5864954851567745e-06)}}, 'num_episodes': 2.0, 'episode_return_min': -451.69751868313847, 'num_module_steps_sampled': {'predator_0': 1000.0, 'prey_0': 1002.0, 'prey_1': 1002.0, 'predator_1': 1002.0}, 'sample': np.float64(2.135121788320539), 'num_episodes_lifetime': 4.0, 'rlmodule_inference_timer': np.float64(0.0001202673108937409), 'weights_seq_no': 1.0, 'episode_len_max': 2000, 'env_to_module_sum_episodes_length_in': np.float64(1881.0000036551799), 'env_step_timer': np.float64(0.00040608099219482545), 'episode_return_max': -429.24278635180315, 'num_env_steps_sampled_lifetime': 8000.0, 'episode_len_mean': 2000.0, 'env_to_module_sum_episodes_length_out': np.float64(1881.0000036551799), 'time_between_sampling': np.float64(3.108147768492927), 'num_env_steps_sampled_lifetime_throughput': np.float64(1151.9950801533569)}","{'num_healthy_workers': 2, 'num_remote_worker_restarts': 0}","{'__all_modules__': {'learner_connector': {'timers': {'connectors': {'agent_to_module_mapping': 0.001842057590256445, 'general_advantage_estimation': 0.05638679909199709, 'add_columns_from_episodes_to_train_batch': 0.03407630433503072, 'add_observations_from_episodes_to_batch': 8.166401385096833e-05, 'numpy_to_tensor': 0.0015490653048618698, 'add_one_ts_to_episodes_and_truncate': 0.0031309397477889433, 'batch_individual_items': 0.02191352123161778, 'add_time_dim_to_batch_and_zero_pad': 3.2909143483266236e-05, 'add_states_from_episodes_to_batch': 8.21120134787634e-06}}, 'connector_pipeline_timer': 0.11943638734082924}, 'learner_connector_sum_episodes_length_out': 4000.0, 'num_module_steps_trained_lifetime': 240640, 'num_env_steps_trained_lifetime': 1880000, 'num_trainable_parameters': 517140, 'num_non_trainable_parameters': 0, 'num_env_steps_trained': 940000, 'learner_connector_sum_episodes_length_in': 4000.0, 'num_module_steps_trained': 120320, 'num_env_steps_trained_lifetime_throughput': 356777.5169579141, 'num_module_steps_trained_throughput': 2124939.2681887764, 'num_module_steps_trained_lifetime_throughput': 2164152.696159943}, 'predator_1': {'weights_seq_no': 2.0, 'diff_num_grad_updates_vs_sampler_policy': np.float32(1.0), 'vf_loss_unclipped': np.float32(312.8168), 'num_module_steps_trained_lifetime': 60160, 'policy_loss': np.float32(0.031444043), 'vf_loss': np.float32(9.932936), 'curr_entropy_coeff': 0.0, 'entropy': np.float32(2.616324), 'num_trainable_parameters': 129285, 'gradients_default_optimizer_global_norm': np.float32(1.2542958), 'vf_explained_var': np.float32(0.0011557937), 'curr_kl_coeff': 0.20000000298023224, 'num_module_steps_trained': 30080, 'mean_kl_loss': np.float32(0.007951054), 'default_optimizer_learning_rate': 5e-05, 'total_loss': np.float32(0.08269893), 'module_train_batch_size_mean': 128.0, 'num_module_steps_trained_lifetime_throughput': 11416.93608688965}, 'prey_1': {'weights_seq_no': 2.0, 'diff_num_grad_updates_vs_sampler_policy': np.float32(1.0), 'vf_loss_unclipped': np.float32(1166.7649), 'num_module_steps_trained_lifetime': 60160, 'policy_loss': np.float32(-0.03917417), 'vf_loss': np.float32(9.918778), 'curr_entropy_coeff': 0.0, 'entropy': np.float32(2.8136663), 'num_trainable_parameters': 129285, 'gradients_default_optimizer_global_norm': np.float32(1.8443462), 'vf_explained_var': np.float32(-0.0003361702), 'curr_kl_coeff': 0.10000000149011612, 'num_module_steps_trained': 30080, 'mean_kl_loss': np.float32(0.00437786), 'default_optimizer_learning_rate': 5e-05, 'total_loss': np.float32(0.011295281), 'module_train_batch_size_mean': 128.0, 'num_module_steps_trained_lifetime_throughput': 11417.035603814104}, 'prey_0': {'policy_loss': np.float32(0.0025350451), 'curr_kl_coeff': 0.20000000298023224, 'num_trainable_parameters': 129285, 'num_module_steps_trained': 30080, 'mean_kl_loss': np.float32(0.0056399116), 'vf_explained_var': np.float32(0.0022902489), 'default_optimizer_learning_rate': 5e-05, 'total_loss': np.float32(0.048071153), 'vf_loss_unclipped': np.float32(157.75735), 'num_module_steps_trained_lifetime': 60160, 'module_train_batch_size_mean': 128.0, 'vf_loss': np.float32(8.881626), 'weights_seq_no': 2.0, 'diff_num_grad_updates_vs_sampler_policy': np.float32(1.0), 'curr_entropy_coeff': 0.0, 'entropy': np.float32(2.752587), 'gradients_default_optimizer_global_norm': np.float32(0.93448186), 'num_module_steps_trained_lifetime_throughput': 11416.996025149803}, 'predator_0': {'num_module_steps_trained_lifetime': 60160, 'module_train_batch_size_mean': 128.0, 'vf_loss': np.float32(8.901313), 'weights_seq_no': 2.0, 'diff_num_grad_updates_vs_sampler_policy': np.float32(1.0), 'vf_loss_unclipped': np.float32(156.68777), 'entropy': np.float32(2.786313), 'gradients_default_optimizer_global_norm': np.float32(2.142382), 'policy_loss': np.float32(-0.1252856), 'curr_kl_coeff': 0.30000001192092896, 'curr_entropy_coeff': 0.0, 'num_trainable_parameters': 129285, 'num_module_steps_trained': 30080, 'mean_kl_loss': np.float32(0.006620531), 'vf_explained_var': np.float32(0.00047051907), 'default_optimizer_learning_rate': 5e-05, 'total_loss': np.float32(-0.07879286), 'num_module_steps_trained_lifetime_throughput': 11416.912068213105}}",8000,1,"{'cpu_util_percent': np.float64(14.385714285714284), 'ram_util_percent': np.float64(23.800000000000004)}","{'training_iteration': 5.1811524428435956, 'restore_env_runners': 1.9143130339216443e-05, 'training_step': 5.18096036298899, 'env_runner_sampling_timer': 2.156525201426994, 'learner_update_timer': 3.019933069680119, 'synch_weights': 0.0041777585819363595, 'synch_env_connectors': 0.0015853579971008003}"


== Status ==
Current time: 2025-06-26 07:26:23 (running for 00:00:10.20)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/24 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-06-25_23-20-00_551135_808463/artifacts/2025-06-26_07-26-13/PPO_2025-06-26_07-26-13/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




2025-06-26 07:26:27,986	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/qrbao/ray_results/PPO_2025-06-26_07-26-13' in 0.0141s.


== Status ==
Current time: 2025-06-26 07:26:27 (running for 00:00:14.68)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/24 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-06-25_23-20-00_551135_808463/artifacts/2025-06-26_07-26-13/PPO_2025-06-26_07-26-13/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)
+---------------------+------------+---------------------+--------+------------------+------+-------------------+-----------------+---------------------+-----------------+---------------------+
| Trial name          | status     | loc                 |   iter |   total time (s) |   ts |   combined return |   return prey_1 |   return predator_1 |   return prey_0 |   return predator_0 |
|---------------------+------------+---------------------+--------+------------------+------+-------------------+-----------------+---------------------+-----------------+---------------------|
| PPO_env_66eb0_00000 | TERMINATED | 192.168.0.25:830892 |   

2025-06-26 07:26:28,300	INFO tune.py:1041 -- Total run time: 15.00 seconds (14.67 seconds for the tuning loop).


训练完成


In [66]:
# # 设置参数
# import sys
# sys.argv = [
#     'notebook_script.py',
#     '--enable-new-api-stack',
#     '--num-agents=4',
#     # '--env-config={"n_predators":2,"n_preys":2,"n_evaders":5,"n_obstacles":1,"n_poisons":1}',
#     '--checkpoint-at-end',
#     '--stop-reward=200.0',
#     '--checkpoint-freq=1',
# ]

In [54]:

# env = waterworld_v4.env(render_mode="human",n_predators=2,n_preys=2,n_evaders=5,n_obstacles=1,n_poisons=1)
# env.reset(seed=42)

# for agent in env.agent_iter():
#     observation, reward, termination, truncation, info = env.last()

#     if termination or truncation:
#         action = None
#     else:
#         # this is where you would insert your policy
#         action = env.action_space(agent).sample()

#     env.step(action)
# env.close()

In [None]:
# parser = add_rllib_example_script_args(
# default_iters=2,
# default_timesteps=10000,
# default_reward=0.0,
# )


In [None]:

# parser.add_argument(
# "--use-onnx-for-inference",
# action="store_true",
# help="Whether to convert the loaded module to ONNX format and then perform "
# "inference through this ONNX model.",
# )
# parser.add_argument(
# "--explore-during-inference",
# action="store_true",
# help="Whether the trained policy should use exploration during action "
# "inference.",
# )
# parser.add_argument(
# "--num-episodes-during-inference",
# type=int,
# default=10,
# help="Number of episodes to do inference over (after restoring from a checkpoint).",
# )


_StoreAction(option_strings=['--num-episodes-during-inference'], dest='num_episodes_during_inference', nargs=None, const=None, default=10, type=<class 'int'>, choices=None, required=False, help='Number of episodes to do inference over (after restoring from a checkpoint).', metavar=None)

In [None]:

# args = parser.parse_args()


In [None]:

# print(args)

Namespace(algo='PPO', enable_new_api_stack=True, framework='torch', env=None, num_env_runners=None, num_envs_per_env_runner=None, num_agents=4, evaluation_num_env_runners=0, evaluation_interval=0, evaluation_duration=10, evaluation_duration_unit='episodes', evaluation_parallel_to_training=False, output=None, log_level=None, no_tune=False, num_samples=1, max_concurrent_trials=None, verbose=2, checkpoint_freq=1, checkpoint_at_end=True, wandb_key=None, wandb_project=None, wandb_run_name=None, stop_reward=200.0, stop_iters=2, stop_timesteps=10000, as_test=False, as_release_test=False, num_learners=None, num_cpus_per_learner=None, num_gpus_per_learner=None, num_aggregator_actors_per_learner=None, num_cpus=0, local_mode=False, num_gpus=None, use_onnx_for_inference=False, explore_during_inference=False, num_episodes_during_inference=10)


In [102]:

# assert args.num_agents > 0, "Must set --num-agents > 0 when running this script!"
# assert (
# args.enable_new_api_stack
# ), "Must set --enable-new-api-stack when running this script!"

# print(f"参数解析完成: num_agents={args.num_agents}, algo={args.algo}")

In [103]:

# # 修改环境注册，传递predator和prey的数量
# register_env("env", lambda _: PettingZooEnv(
#     waterworld_v4.env(
#         n_predators=args.n_predators,  # 需要在args中定义
#         n_prey=args.n_prey            # 需要在args中定义
#     )
# ))


In [107]:

# # 创建policies字典
# policies = {f"pursuer_{i}" for i in range(args.num_agents)}
# policies
# {p: RLModuleSpec() for p in all_policies}

In [None]:

# base_config = (
#     get_trainable_cls(args.algo)
#     .get_default_config()
#     .environment("env")
#     .multi_agent(
#         policies=policies,
#         policy_mapping_fn=(lambda aid, *args, **kwargs: aid),
#     )
#     .training(
#         vf_loss_coeff=0.005,
#     )
#     .rl_module(
#         rl_module_spec=MultiRLModuleSpec(
#             rl_module_specs={p: RLModuleSpec() for p in policies},
#         ),
#         model_config=DefaultModelConfig(vf_share_layers=True),
#     )
# )


In [13]:
# print(base_config.env_config)

In [None]:

# # 训练
# print("开始训练...")
# results = run_rllib_example_script_experiment(base_config, args, keep_ray_up=True)
# print("训练完成")

2025-06-26 01:37:30,452	INFO worker.py:1747 -- Calling ray.init() again after it has already been called.


开始训练...
== Status ==
Current time: 2025-06-26 01:37:30 (running for 00:00:00.11)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/24 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-06-25_23-20-00_551135_808463/artifacts/2025-06-26_01-37-30/PPO_2025-06-26_01-37-30/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-06-26 01:37:35 (running for 00:00:05.13)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/24 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-06-25_23-20-00_551135_808463/artifacts/2025-06-26_01-37-30/PPO_2025-06-26_01-37-30/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-06-26 01:37:40 (running for 00:00:10.18)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/24 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-06-25_23-20-00_551135_808463/artifacts/2025-06-26_01-37-30/P

Trial name,env_runner_group,fault_tolerance,num_env_steps_sampled_lifetime,num_training_step_calls_per_iteration,perf,timers
PPO_env_afee0_00000,{'actor_manager_num_outstanding_async_reqs': 0},"{'num_healthy_workers': 0, 'num_remote_worker_restarts': 0}",0,1,"{'cpu_util_percent': np.float64(8.866666666666667), 'ram_util_percent': np.float64(21.366666666666664)}","{'training_iteration': 59.427974912861245, 'restore_env_runners': 0.01979821590532083, 'training_step': 59.407926392985104, 'env_runner_sampling_timer': 59.40787343050062, 'synch_env_connectors': 1.1440997599856928}"


== Status ==
Current time: 2025-06-26 01:38:36 (running for 00:01:05.60)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/24 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-06-25_23-20-00_551135_808463/artifacts/2025-06-26_01-37-30/PPO_2025-06-26_01-37-30/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




2025-06-26 01:38:37,894	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/qrbao/ray_results/PPO_2025-06-26_01-37-30' in 0.0035s.
2025-06-26 01:38:38,028	INFO tune.py:1041 -- Total run time: 67.57 seconds (67.43 seconds for the tuning loop).


== Status ==
Current time: 2025-06-26 01:38:37 (running for 00:01:07.44)
Using FIFO scheduling algorithm.
Logical resource usage: 3.0/24 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)
Result logdir: /tmp/ray/session_2025-06-25_23-20-00_551135_808463/artifacts/2025-06-26_01-37-30/PPO_2025-06-26_01-37-30/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)
+---------------------+------------+---------------------+--------+------------------+------+
| Trial name          | status     | loc                 |   iter |   total time (s) |   ts |
|---------------------+------------+---------------------+--------+------------------+------|
| PPO_env_afee0_00000 | TERMINATED | 192.168.0.25:817079 |      2 |          63.1343 |    0 |
+---------------------+------------+---------------------+--------+------------------+------+


训练完成


In [1]:
# 获取最佳结果
print("获取最佳checkpoint...")
best_result = results.get_best_result(
    metric=f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}", mode="max"
)

获取最佳checkpoint...


NameError: name 'results' is not defined

In [109]:
print("加载所有智能体的RLModule...")
rl_modules = {}
# 修改：使用新的智能体命名方式
predator_agents = [f"predator_{i}" for i in range(args.n_predators)]
prey_agents = [f"prey_{i}" for i in range(args.n_preys)]
all_agent_names = predator_agents + prey_agents


加载所有智能体的RLModule...


In [110]:

for agent_name in all_agent_names:
    rl_module_path = os.path.join(
        best_result.checkpoint.path,
        "learner_group",
        "learner",
        "rl_module",
        agent_name,
    )
    
    if os.path.exists(rl_module_path):
        rl_modules[agent_name] = RLModule.from_checkpoint(rl_module_path)
        print(f"成功加载 {agent_name} 的模型")
    else:
        print(f"警告: 找不到 {agent_name} 的模型路径: {rl_module_path}")
print(f"总共加载了 {len(rl_modules)} 个智能体模型")

成功加载 predator_0 的模型
成功加载 predator_1 的模型
成功加载 prey_0 的模型
成功加载 prey_1 的模型
总共加载了 4 个智能体模型


In [111]:

# 推理阶段
print("开始推理...")

def get_action_from_rl_module(rl_module, observation, explore=False):
    """从RLModule获取动作"""
    # 将观察转换为torch tensor并添加batch维度
    input_dict = {Columns.OBS: torch.from_numpy(observation).unsqueeze(0)}
    
    if explore:
        # 使用探索性前向传播
        rl_module_out = rl_module.forward_exploration(input_dict)
        action_dist_inputs = rl_module_out["action_dist_inputs"][0]
        action = TorchDiagGaussian.from_logits(action_dist_inputs.unsqueeze(0)).sample().squeeze(0).numpy()
    else:
        # 使用推理前向传播
        rl_module_out = rl_module.forward_inference(input_dict)
        action_dist_inputs = rl_module_out["action_dist_inputs"][0]
        action = TorchDiagGaussian.from_logits(action_dist_inputs.unsqueeze(0)).sample().squeeze(0).numpy()
    
    return action

开始推理...


In [112]:

# 运行推理episodes
for episode in range(args.num_episodes_during_inference):
    print(f"\n=== Episode {episode + 1} ===")
    
    # 创建环境（使用新的参数方式）
    env = waterworld_v4.env(
        render_mode="human",
        n_predators=args.n_predators,  # 修改：使用新的参数
        n_preys=args.n_preys,         # 修改：使用新的参数
        n_evaders=5,  # 可以根据需要调整
        n_poisons=10,  # 可以根据需要调整
    )
    env.reset(seed=42 + episode)
    
    episode_rewards = {agent: 0 for agent in env.agents}
    step_count = 0
    
    try:
        for agent in env.agent_iter():
            observation, reward, termination, truncation, info = env.last()
            
            # 累积奖励
            if agent in episode_rewards:
                episode_rewards[agent] += reward
            
            if termination or truncation:
                action = None
                print(f"{agent} 终止, 奖励: {reward}")
            else:
                # 使用对应的RLModule获取动作
                if agent in rl_modules:
                    try:
                        action = get_action_from_rl_module(
                            rl_modules[agent], 
                            observation, 
                            explore=args.explore_during_inference
                        )
                        print(f"{agent} 执行动作: {action}")
                    except Exception as e:
                        print(f"为 {agent} 获取动作时出错: {e}")
                        # 如果出错，使用随机动作作为备选
                        action = env.action_space(agent).sample()
                        print(f"{agent} 使用随机动作: {action}")
                else:
                    # 如果没有找到对应的模型，使用随机动作
                    action = env.action_space(agent).sample()
                    print(f"{agent} 没有找到对应模型，使用随机动作: {action}")
            
            env.step(action)
            step_count += 1
            
            # 每100步输出一次进度
            if step_count % 100 == 0:
                print(f"已执行 {step_count} 步")
                
    except KeyboardInterrupt:
        print("用户中断")
        break
    except Exception as e:
        print(f"Episode {episode + 1} 出现错误: {e}")
    finally:
        env.close()
    
    # 输出episode结果
    print(f"Episode {episode + 1} 完成")
    print("各智能体总奖励:")
    for agent, total_reward in episode_rewards.items():
        print(f"  {agent}: {total_reward:.2f}")
    print(f"总步数: {step_count}")

print("推理完成!")


=== Episode 1 ===
predator_0 执行动作: [-1.8125288 -0.5591454]
predator_1 执行动作: [-0.08958572  0.6042472 ]
prey_0 执行动作: [-1.0631835  0.5505582]
prey_1 执行动作: [-0.14060654 -0.10101031]
predator_0 执行动作: [-0.13771439  1.3597841 ]
predator_1 执行动作: [0.2649867  0.99578065]
prey_0 执行动作: [1.9193748 0.5030465]
prey_1 执行动作: [-0.2940722  0.6914272]
predator_0 执行动作: [-0.01998153  0.6650984 ]
predator_1 执行动作: [-0.7461475  -0.04232921]
prey_0 执行动作: [-0.22005966 -1.3864584 ]
prey_1 执行动作: [1.2614335 0.3888871]
predator_0 执行动作: [-0.12914672  0.7379982 ]
predator_1 执行动作: [ 0.58827925 -0.62155104]
prey_0 执行动作: [ 0.05086606 -1.5000789 ]
prey_1 执行动作: [0.23730011 0.2526495 ]
predator_0 执行动作: [-1.4005857  -0.07694232]
predator_1 执行动作: [-1.09248   -0.2871877]
prey_0 执行动作: [-0.21279551 -1.8017559 ]
prey_1 执行动作: [ 0.7457725  -0.08326474]
predator_0 执行动作: [ 0.5335653 -1.3469042]
predator_1 执行动作: [ 0.11541498 -1.0610658 ]
prey_0 执行动作: [ 0.29792756 -0.13114363]
prey_1 执行动作: [-1.4455218 -1.1708032]
predator_0 执行动作: [-1.

In [None]:
from pettingzoo.sisl import waterworld_v4

env = waterworld_v4.env(render_mode="human")
env.reset(seed=42)

for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()

    if termination or truncation:
        action = None
    else:
        # this is where you would insert your policy
        action = env.action_space(agent).sample()

    env.step(action)
env.close()

KeyboardInterrupt: 

: 