In [None]:
#!pip install ray[rllib,tune]
#!pip install pettingzoo pygame pymunk
#!pip install torch

In [2]:
import ray
import glob
import sys

from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv, ParallelPettingZooEnv
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.tune.registry import get_trainable_cls, register_env

from pettingzoo.sisl import waterworld_v4

In [5]:
import sys
from glob import glob
from os import path
from ray.rllib.policy.policy import Policy

# Restore the 2-agent PPO run from its checkpoint: load each trained policy, rebuild an
# algorithm with the same env/policy layout, and copy the trained weights into it.

ray.shutdown()

# Trained to about 0 combined return
checkpoint_path = "/root/ray_results/PPO_2024-08-28_20-57-45/PPO_2_agent_env_2cf59_00000_0_2024-08-28_20-57-45/checkpoint_000000"
# Each entry under <checkpoint>/policies/ is loaded as one Policy, keyed by its basename
# (which the weight copy below expects to be the policy ID, e.g. "pursuer_0").
pols = glob(path.join(checkpoint_path, "policies", "*"))
specs = {path.basename(p): Policy.from_checkpoint(p) for p in pols}
assert specs, f"No policy checkpoints found under {checkpoint_path!r}"
# Note: building SingleAgentRLModuleSpec(load_state_path=p) per policy instead gave
# non-deterministic weights (i.e. freshly initialised modules), so the trained weights
# are copied from the restored Policy objects after build() below.

num_agents = 2

# `n=num_agents` binds the value at registration time; a bare reference to the global
# would be looked up whenever the env is created, so a later cell reassigning
# `num_agents` would silently change what this env name builds.
register_env(
    f"{num_agents}_agent_env",
    lambda _, n=num_agents: ParallelPettingZooEnv(waterworld_v4.parallel_env(n_pursuers=n)),
)
policies = {f"pursuer_{i}" for i in range(num_agents)}

missing = set(policies).difference(specs)
assert not missing, f"Checkpoint has no weights for {sorted(missing)}"


resto_config = (
    get_trainable_cls("PPO")
    .get_default_config()
    .environment(f"{num_agents}_agent_env")
    .multi_agent(
        policies=policies,
        # Exact 1:1 mapping from AgentID to ModuleID.
        policy_mapping_fn=(lambda aid, *args, **kwargs: aid),
    )
    .rl_module(
        rl_module_spec=MultiAgentRLModuleSpec(
            module_specs={p: SingleAgentRLModuleSpec() for p in policies},
        ),
    )
    .evaluation(
        evaluation_interval=1,
    )
)

resto_algo = resto_config.build()
# Copy every trained policy's weights into the identically named policy of the new algo.
for pid in sorted(policies):
    resto_algo.get_policy(pid).set_weights(specs[pid].get_weights())
#resto_algo.get_policy("pursuer_0").get_weights()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2024-08-29 02:06:21,682	INFO worker.py:1772 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


Above, we restore the trained algorithm with its original number of agents (2) and copy in the trained weights; below, we evaluate it.

In [7]:
# Run one evaluation round of the restored 2-agent algorithm (evaluation is enabled via
# evaluation_interval=1 in resto_config); the returned dict holds episode/policy rewards.
resto_algo.evaluate()

{'env_runners': {'episode_reward_max': np.float64(104.48447046235398),
  'episode_reward_min': np.float64(-49.97758701259906),
  'episode_reward_mean': np.float64(34.31495169008534),
  'episode_len_mean': np.float64(500.0),
  'episode_media': {},
  'episodes_timesteps_total': 5000,
  'policy_reward_min': {'pursuer_0': np.float64(-27.80399612520721),
   'pursuer_1': np.float64(-74.73581968384396)},
  'policy_reward_max': {'pursuer_0': np.float64(108.50397211584374),
   'pursuer_1': np.float64(66.59594375510522)},
  'policy_reward_mean': {'pursuer_0': np.float64(32.97170784631929),
   'pursuer_1': np.float64(1.3432438437660394)},
  'custom_metrics': {},
  'hist_stats': {'episode_reward': [np.float64(44.487376281462424),
    np.float64(98.65380257932057),
    np.float64(84.50964155764572),
    np.float64(20.08502968193733),
    np.float64(104.48447046235398),
    np.float64(-38.577462399422544),
    np.float64(-49.97758701259906),
    np.float64(20.465628189378837),
    np.float64(91.2633

Next, let's try a different number of test agents: each of the 6 test agents receives the weights of one of the 2 trained policies, chosen at random.

In [3]:
import sys
from glob import glob
from os import path
from ray.rllib.policy.policy import Policy
import numpy as np

# Evaluate the 2-agent checkpoint in a 6-agent env: every test agent gets the weights of
# one randomly chosen trained policy.

ray.shutdown()

# Trained to about 0 combined return
checkpoint_path = "/root/ray_results/PPO_2024-08-28_20-57-45/PPO_2_agent_env_2cf59_00000_0_2024-08-28_20-57-45/checkpoint_000000"
# Each entry under <checkpoint>/policies/ is loaded as one Policy, keyed by its basename.
pols = glob(path.join(checkpoint_path, "policies", "*"))
specs = {path.basename(p): Policy.from_checkpoint(p) for p in pols}
assert specs, f"No policy checkpoints found under {checkpoint_path!r}"
# Note: building SingleAgentRLModuleSpec(load_state_path=p) per policy instead gave
# non-deterministic weights (i.e. freshly initialised modules), so the trained weights
# are copied from the restored Policy objects after build() below.

num_trained_agents = 2
num_test_agents = 6
# Seed for the random test-agent -> trained-policy assignment below. Keep it fixed so
# re-running the cell evaluates the same assignment; change it to sample another one.
ASSIGNMENT_SEED = 0

missing = {f"pursuer_{i}" for i in range(num_trained_agents)}.difference(specs)
assert not missing, f"Checkpoint has no weights for {sorted(missing)}"

# `n=num_test_agents` binds the value at registration time so a later reassignment of
# the notebook-global cannot change what this env name builds.
register_env(
    f"{num_test_agents}_agent_env",
    lambda _, n=num_test_agents: ParallelPettingZooEnv(waterworld_v4.parallel_env(n_pursuers=n)),
)
policies = {f"pursuer_{i}" for i in range(num_test_agents)}


resto_config = (
    get_trainable_cls("PPO")
    .get_default_config()
    .environment(f"{num_test_agents}_agent_env")
    .multi_agent(
        policies=policies,
        # Exact 1:1 mapping from AgentID to ModuleID.
        policy_mapping_fn=(lambda aid, *args, **kwargs: aid),
    )
    .rl_module(
        rl_module_spec=MultiAgentRLModuleSpec(
            module_specs={p: SingleAgentRLModuleSpec() for p in policies},
        ),
    )
    .evaluation(
        evaluation_interval=1,
    )
)

resto_algo = resto_config.build()

# Give every test agent the weights of a randomly chosen trained policy (seeded, so the
# assignment - and hence the evaluation below - is reproducible).
rng = np.random.default_rng(ASSIGNMENT_SEED)
policy_assignment = {
    f"pursuer_{test_id}": f"pursuer_{rng.integers(num_trained_agents)}"
    for test_id in range(num_test_agents)
}
for test_pid, train_pid in policy_assignment.items():
    resto_algo.get_policy(test_pid).set_weights(specs[train_pid].get_weights())
# Show the assignment: per-policy results below are only interpretable alongside it.
print("Test agent -> trained policy:", policy_assignment)

resto_algo.evaluate()

  gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")
`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2024-09-03 17:51:00,086	INFO worker.py:1772 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8266 [39m[22m


{'env_runners': {'episode_reward_max': np.float64(-133.08651258384768),
  'episode_reward_min': np.float64(-429.6043485772751),
  'episode_reward_mean': np.float64(-279.01460767396213),
  'episode_len_mean': np.float64(500.0),
  'episode_media': {},
  'episodes_timesteps_total': 5000,
  'policy_reward_min': {'pursuer_0': np.float64(-70.69341337433875),
   'pursuer_1': np.float64(-83.28156152153042),
   'pursuer_2': np.float64(-89.37824978401237),
   'pursuer_3': np.float64(-64.26033845731828),
   'pursuer_4': np.float64(-104.50749893251731),
   'pursuer_5': np.float64(-180.21766134359572)},
  'policy_reward_max': {'pursuer_0': np.float64(-8.571273725909617),
   'pursuer_1': np.float64(60.30311215169101),
   'pursuer_2': np.float64(42.74708175708947),
   'pursuer_3': np.float64(-7.6108996017414645),
   'pursuer_4': np.float64(67.86120096904311),
   'pursuer_5': np.float64(-44.10802049107301)},
  'policy_reward_mean': {'pursuer_0': np.float64(-42.284014583047856),
   'pursuer_1': np.floa

Check that the restored 6-agent algorithm can also continue training:

In [4]:
# Run one training iteration on the restored 6-agent algorithm; because resto_config sets
# evaluation_interval=1, the result also contains an 'evaluation' section.
resto_algo.train()



{'evaluation': {'env_runners': {'episode_reward_max': np.float64(-364.769848341295),
   'episode_reward_min': np.float64(-619.0375683355406),
   'episode_reward_mean': np.float64(-493.78093420107905),
   'episode_len_mean': np.float64(500.0),
   'episode_media': {},
   'episodes_timesteps_total': 5000,
   'policy_reward_min': {'pursuer_0': np.float64(-113.90485384019159),
    'pursuer_1': np.float64(-109.27251741852379),
    'pursuer_2': np.float64(-114.55402342215696),
    'pursuer_3': np.float64(-111.76090622319133),
    'pursuer_4': np.float64(-109.0597747992956),
    'pursuer_5': np.float64(-217.50166684483233)},
   'policy_reward_max': {'pursuer_0': np.float64(11.303415129637894),
    'pursuer_1': np.float64(0.2069972847033158),
    'pursuer_2': np.float64(-46.84567547040106),
    'pursuer_3': np.float64(-6.933874989835202),
    'pursuer_4': np.float64(-7.903867561959164),
    'pursuer_5': np.float64(-102.74657319368887)},
   'policy_reward_mean': {'pursuer_0': np.float64(-62.6888

## Testing Callbacks



In [34]:
import ray
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv, ParallelPettingZooEnv
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.tune.registry import get_trainable_cls, register_env
from pettingzoo.sisl import waterworld_v4

# NOTE(review): `parser` is not used anywhere in the visible notebook (its defaults for
# iterations/timesteps/reward are never read here); it looks like a leftover from the
# RLlib example-script template - confirm before removing.
parser = add_rllib_example_script_args(
    default_iters=200,
    default_timesteps=1000000,
    default_reward=300,
)

from typing import Dict, Tuple
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.policy import Policy

class MyCallbacks(DefaultCallbacks):
    """Callbacks that log each episode's total reward normalised by the policy count.

    For every finished episode the total reward divided by the number of policies is
    stored as the custom metric ``mean_agent_return`` and appended to the histogram
    ``mean_agent_return_hist``. Each training result is tagged with ``num_agents``.

    All counts come from what RLlib passes in (``policies`` / ``algorithm``), never from
    notebook globals, so reassigning ``policies`` in another cell cannot change them.
    """

    def on_episode_start(
        self,
        *,
        episode: Episode,
        env_index: int,
        **kwargs,
    ) -> None:
        """Create the per-episode histogram list that ``on_episode_end`` appends to."""
        episode.hist_data["mean_agent_return_hist"] = []

    def on_episode_end(
        self,
        *,
        episode: Episode,
        env_index: int,
        policies: "Dict[str, Policy] | None" = None,
        **kwargs,
    ) -> None:
        """Record the episode's total reward divided by the number of policies.

        Args:
            episode: The finished episode; ``total_reward`` is read and
                ``custom_metrics`` / ``hist_data`` are written.
            env_index: Index of the sub-environment (not used).
            policies: Policy map RLlib passes for the worker that ran the episode.
                With the 1:1 agent->policy mapping used here, its size is the
                number of agents.

        Raises:
            ValueError: If ``policies`` is missing or empty (cannot normalise).
        """
        if policies is None or len(policies) == 0:
            raise ValueError(
                "on_episode_end needs a non-empty `policies` mapping to normalise the return"
            )
        print(f"Total Reward: {episode.total_reward:>5}")
        mean_agent_return = episode.total_reward / len(policies)
        episode.custom_metrics["mean_agent_return"] = mean_agent_return
        episode.hist_data["mean_agent_return_hist"].append(mean_agent_return)

    def on_train_result(self, *, algorithm, result: dict, **kwargs) -> None:
        """Tag the training result with the number of policies configured on ``algorithm``."""
        result["num_agents"] = len(algorithm.config.policies)


ray.shutdown()

# Build a fresh 2-agent PPO algorithm with MyCallbacks attached to test the custom metrics.

num_agents = 2
# `n=num_agents` binds the value at registration time; a bare reference to the global
# would be looked up whenever the env is created, so a later cell reassigning
# `num_agents` would silently change what this env name builds.
register_env(
    f"{num_agents}_agent_env",
    lambda _, n=num_agents: ParallelPettingZooEnv(waterworld_v4.parallel_env(n_pursuers=n)),
)
policies = {f"pursuer_{i}" for i in range(num_agents)}

config = (
    get_trainable_cls("PPO")
    .get_default_config()
    .environment(f"{num_agents}_agent_env")
    .multi_agent(
        policies=policies,
        # Exact 1:1 mapping from AgentID to ModuleID.
        policy_mapping_fn=(lambda aid, *args, **kwargs: aid),
    )
    .rl_module(
        rl_module_spec=MultiAgentRLModuleSpec(
            module_specs={p: SingleAgentRLModuleSpec() for p in policies},
        ),
    )
    # Evaluation is intentionally not configured here; only training metrics are checked.
    .callbacks(MyCallbacks)
)

algo = config.build()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2024-09-06 18:44:55,926	INFO worker.py:1772 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


In [35]:
# One training iteration with MyCallbacks attached; the rollout workers print each
# finished episode's total reward, and the custom metrics end up in the result dict.
algo.train()

[36m(RolloutWorker pid=2004060)[0m Total Reward: -331.3670194328563
[36m(RolloutWorker pid=2004061)[0m Total Reward: -288.67849504476527[32m [repeated 6x across cluster][0m


{'custom_metrics': {},
 'episode_media': {},
 'info': {'learner': {'pursuer_0': {'learner_stats': {'allreduce_latency': np.float64(0.0),
     'grad_gnorm': np.float32(2.320964),
     'cur_kl_coeff': np.float64(0.20000000000000004),
     'cur_lr': np.float64(5.0000000000000016e-05),
     'total_loss': np.float64(8.972203861673673),
     'policy_loss': np.float64(-0.007902693240127216),
     'vf_loss': np.float64(8.97887042115132),
     'vf_explained_var': np.float64(2.0121286312739053e-05),
     'kl': np.float64(0.006180924484601747),
     'entropy': np.float64(2.805913825829824),
     'entropy_coeff': np.float64(0.0)},
    'model': {},
    'custom_metrics': {},
    'num_agent_steps_trained': np.float64(125.0),
    'num_grad_updates_lifetime': np.float64(480.5),
    'diff_num_grad_updates_vs_sampler_policy': np.float64(479.5)},
   'pursuer_1': {'learner_stats': {'allreduce_latency': np.float64(0.0),
     'grad_gnorm': np.float32(1.347748),
     'cur_kl_coeff': np.float64(0.2000000000000