# Algorithm VF Functions?

In [1]:
import numpy as np
import sys
from glob import glob
from os import path

import ray
from ray.tune import Trainable, Tuner
from ray.tune.registry import register_trainable, validate_trainable
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.rllib.policy.policy import Policy
from ray.tune.registry import get_trainable_cls, register_env

from pettingzoo.sisl import waterworld_v4


parser = add_rllib_example_script_args(
    default_iters=10,
    default_timesteps=1000000,
    default_reward=300,
)
args = parser.parse_args(args=[])
args.num_env_runners = 10
args.env = 'waterworld'
args.algo = 'PPO'
args.num_agents = 4
args.test_agents = 4

checkpoint_path = f"/root/test/{args.env}/{args.algo}/{args.num_agents}_agent"

sup = sorted(glob(checkpoint_path+'/*'))[0]

pols = glob(sup+"/policies/*")
specs = {path.basename(p) : Policy.from_checkpoint(p) for p in pols}

register_env(f"{args.num_agents}_agent_env", lambda _: ParallelPettingZooEnv(waterworld_v4.parallel_env(n_pursuers=args.num_agents)))
policies = {f"pursuer_{i}" for i in range(args.num_agents)}


resto_config = (
    get_trainable_cls("PPO")
    .get_default_config()
    .environment(f"{args.num_agents}_agent_env")
    .multi_agent(
        policies=policies,
        policy_mapping_fn=(lambda aid, *args, **kwargs: aid),
    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            rl_module_specs={p: RLModuleSpec() for p in policies},
        ),
    )
    .evaluation(
        evaluation_interval=1,
    )
)
resto_algo = resto_config.build()
""" Known-good weight transfer

for test_id in range(args.test_agents):
    train_id = np.random.randint(args.num_agents)
    resto_algo.get_policy(f"pursuer_{test_id}").set_weights(specs[f"pursuer_{train_id}"].get_weights())
"""

for test_id in range(args.test_agents):
    resto_algo.remove_policy(f"pursuer_{test_id}")
    resto_algo.add_policy(f"pursuer_{test_id}", policy=specs[f"pursuer_{test_id}"])

print(f"Iter 0 eval = {resto_algo.evaluate()['env_runners']['episode_reward_mean']}")
print(f"Iter 1 train = {resto_algo.train()['env_runners']['episode_reward_mean']}")

  from .autonotebook import tqdm as notebook_tqdm
2024-11-04 19:34:25,483	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-11-04 19:34:25,944	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
  gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")
`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self

Iter 0 eval = 157.5417865937023




Iter 1 train = 114.53457892178979


In [12]:
print(f"Iter 0 eval = {resto_algo.evaluate()['env_runners']['episode_reward_max']}")
print(f"Iter 1 train = {resto_algo.train()['env_runners']['episode_reward_max']}")

{'env_runners': {'episode_reward_max': 461.5181588385271,
  'episode_reward_min': -20.753103857376605,
  'episode_reward_mean': 217.74148432444377,
  'episode_len_mean': 500.0,
  'episode_media': {},
  'episodes_timesteps_total': 5000,
  'policy_reward_min': {'pursuer_0': -6.865795740206084,
   'pursuer_1': -35.87920937729004,
   'pursuer_2': -3.864316381727646,
   'pursuer_3': -120.14498397644783},
  'policy_reward_max': {'pursuer_0': 230.30065341508384,
   'pursuer_1': 177.46971310653652,
   'pursuer_2': 205.22379008417556,
   'pursuer_3': 75.05780568295731},
  'policy_reward_mean': {'pursuer_0': 107.08053636479556,
   'pursuer_1': 78.92683764095337,
   'pursuer_2': 71.6896900930275,
   'pursuer_3': -39.955579774332435},
  'custom_metrics': {},
  'hist_stats': {'episode_reward': [217.63442235367253,
    253.2640716211442,
    461.5181588385271,
    175.2852620876895,
    -20.753103857376605,
    13.641541068033922,
    172.43116994939476,
    285.90102996606805,
    296.8287696208130

In [7]:
type(resto_algo.reward_estimators)

dict

In [14]:
resto_algo.get_policy('pursuer_0').get_weights().keys()

dict_keys(['_logits._model.0.weight', '_logits._model.0.bias', '_hidden_layers.0._model.0.weight', '_hidden_layers.0._model.0.bias', '_hidden_layers.1._model.0.weight', '_hidden_layers.1._model.0.bias', '_value_branch_separate.0._model.0.weight', '_value_branch_separate.0._model.0.bias', '_value_branch_separate.1._model.0.weight', '_value_branch_separate.1._model.0.bias', '_value_branch._model.0.weight', '_value_branch._model.0.bias'])

# Metrics

In [34]:
import numpy as np
import sys
from glob import glob
from os import path

import ray
from ray.tune import Trainable, Tuner
from ray.tune.registry import register_trainable, validate_trainable
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.rllib.policy.policy import Policy
from ray.tune.registry import get_trainable_cls, register_env

from pettingzoo.sisl import waterworld_v4


parser = add_rllib_example_script_args(
    default_iters=10,
    default_timesteps=1000000,
    default_reward=300,
)
args = parser.parse_args(args=[])
args.num_env_runners = 10
args.env = 'waterworld'
args.algo = 'PPO'
args.num_agents = 4
args.test_agents = 4

checkpoint_path = f"/root/test/{args.env}/{args.algo}/{args.num_agents}_agent"

sup = sorted(glob(checkpoint_path+'/*'))[0]

pols = glob(sup+"/policies/*")
specs = {path.basename(p) : Policy.from_checkpoint(p) for p in pols}

register_env(f"{args.num_agents}_agent_env", lambda _: ParallelPettingZooEnv(waterworld_v4.parallel_env(n_pursuers=args.num_agents)))
policies = {f"pursuer_{i}" for i in range(args.num_agents)}


resto_config = (
    get_trainable_cls("PPO")
    .get_default_config()
    .environment(f"{args.num_agents}_agent_env")
    .multi_agent(
        policies=policies,
        policy_mapping_fn=(lambda aid, *args, **kwargs: aid),
    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            rl_module_specs={p: RLModuleSpec() for p in policies},
        ),
    )
    .evaluation(
        evaluation_interval=1,
    )
)

from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.air.integrations.wandb import WandbLoggerCallback, setup_wandb

class WandbCallbackWrapper(DefaultCallbacks,WandbLoggerCallback):
    """ """
    # setup(self, *args, **kwargs)
    # log_trial_start(self)
    # log_trial_result(self)
    # log_trial_end(self)
    # on_experiment_end(self)

    def __init__(self, *args, **kwargs):
        DefaultCallbacks.__init__(self, **kwargs)
        WandbLoggerCallback.__init__(self, **kwargs)
        self.setup(self, *args, **kwargs)

    def on_episode_start(self):
        self.log_trial_start(self)

    def on_episode_end(self):
        self.log_trial_end(self)

    def on_train_result(self):
        self.log_trial_result(self)

    #self.on_experiment_end(self)


resto_algo = resto_config.callbacks(
    WandbCallbackWrapper
)

args.wandb_project='delete_me'
args.wandb_key='913528a8e92bf601b6eb055a459bcc89130c7f5f'
conf={}

setup_wandb(conf, args.wandb_key, project=args.wandb_project)

resto_algo = resto_config.build()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))


ValueError: Please pass the project name as argument or through the WANDB_PROJECT_NAME environment variable.

In [37]:
import numpy as np
import sys
from glob import glob
from os import path

import ray
from ray.tune import Trainable, Tuner
from ray.tune.registry import register_trainable, validate_trainable
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.rllib.policy.policy import Policy
from ray.tune.registry import get_trainable_cls, register_env

from pettingzoo.sisl import waterworld_v4


parser = add_rllib_example_script_args(
    default_iters=10,
    default_timesteps=1000000,
    default_reward=300,
)
args = parser.parse_args(args=[])
args.num_env_runners = 10
args.env = 'waterworld'
args.algo = 'PPO'
args.num_agents = 4
args.test_agents = 4

checkpoint_path = f"/root/test/{args.env}/{args.algo}/{args.num_agents}_agent"

sup = sorted(glob(checkpoint_path+'/*'))[0]

pols = glob(sup+"/policies/*")
specs = {path.basename(p) : Policy.from_checkpoint(p) for p in pols}

register_env(f"{args.num_agents}_agent_env", lambda _: ParallelPettingZooEnv(waterworld_v4.parallel_env(n_pursuers=args.num_agents)))
policies = {f"pursuer_{i}" for i in range(args.num_agents)}


resto_config = (
    get_trainable_cls("PPO")
    .get_default_config()
    .environment(f"{args.num_agents}_agent_env")
    .multi_agent(
        policies=policies,
        policy_mapping_fn=(lambda aid, *args, **kwargs: aid),
    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            rl_module_specs={p: RLModuleSpec() for p in policies},
        ),
    )
    .evaluation(
        evaluation_interval=1,
    )
)
resto_algo = resto_config.build()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))


In [38]:
from ray.tune.utils import flatten_dict
from ray.air.integrations.wandb import _is_allowed_type

result = resto_algo.train()

flat_result = flatten_dict(result, delimiter="/")
log = {}
for k, v in flat_result.items():
    log[k] = v

In [41]:
print(log.keys())

dict_keys(['num_healthy_workers', 'num_in_flight_async_sample_reqs', 'num_remote_worker_restarts', 'num_agent_steps_sampled', 'num_agent_steps_trained', 'num_env_steps_sampled', 'num_env_steps_trained', 'num_env_steps_sampled_this_iter', 'num_env_steps_trained_this_iter', 'num_env_steps_sampled_throughput_per_sec', 'num_env_steps_trained_throughput_per_sec', 'timesteps_total', 'num_env_steps_sampled_lifetime', 'num_agent_steps_sampled_lifetime', 'num_steps_trained_this_iter', 'agent_timesteps_total', 'done', 'training_iteration', 'trial_id', 'date', 'timestamp', 'time_this_iter_s', 'time_total_s', 'pid', 'hostname', 'node_ip', 'time_since_restore', 'iterations_since_restore', 'evaluation/num_agent_steps_sampled_this_iter', 'evaluation/num_env_steps_sampled_this_iter', 'evaluation/timesteps_this_iter', 'evaluation/num_healthy_workers', 'evaluation/num_in_flight_async_reqs', 'evaluation/num_remote_worker_restarts', 'info/num_env_steps_sampled', 'info/num_env_steps_trained', 'info/num_age