In [12]:
"""Example showing how one can implement a simple self-play training workflow.

Uses the open spiel adapter of RLlib with the "connect_four" game and
a multi-agent setup with a "main" policy and n "main_v[x]" policies
(x=version number), which are all at-some-point-frozen copies of
"main". At the very beginning, "main" plays against RandomPolicy.

Checks for the training progress after each training update via a custom
callback. We simply measure the win rate of "main" vs the opponent
("main_v[x]" or RandomPolicy at the beginning) by looking through the
achieved rewards in the episodes in the train batch. If this win rate
reaches some configurable threshold, we add a new policy to
the policy map (a frozen copy of the current "main" one) and change the
policy_mapping_fn to make new matches of "main" vs any of the previous
versions of "main" (including the just added one).

After training for n iterations, a configurable number of episodes can
be played by the user against the "main" agent on the command line.
"""

import functools

import numpy as np
import torch

from ray.tune.result import TRAINING_ITERATION
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.env.utils import try_import_pyspiel, try_import_open_spiel
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.examples.rl_modules.classes.random_rlm import RandomRLModule
from ray.rllib.examples.multi_agent.utils import (
    ask_user_for_action,
    SelfPlayCallback,
    SelfPlayCallbackOldAPIStack,
)
from ray.rllib.examples._old_api_stack.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.tune.registry import get_trainable_cls, register_env

open_spiel = try_import_open_spiel(error=True)
pyspiel = try_import_pyspiel(error=True)

# Import after try_import_open_spiel, so we can error out with hints.
from open_spiel.python.rl_environment import Environment  # noqa: E402



In [13]:

# Start from RLlib's shared example-script argument parser, then override a
# few defaults and register the options that are specific to this self-play
# example.
parser = add_rllib_example_script_args(default_timesteps=2_000_000)
parser.set_defaults(env="connect_four", checkpoint_freq=1, checkpoint_at_end=True)

# Win rate that "main" has to reach before a frozen snapshot of it is added
# to the league.
parser.add_argument(
    "--win-rate-threshold",
    help=(
        "Win-rate at which we setup another opponent by freezing the "
        "current main policy and playing against a uniform distribution "
        "of previously frozen 'main's from here on."
    ),
    default=0.95,
    type=float,
)

# League size that is used as a stopping criterion further down.
# NOTE(review): the value is parsed as a float although a league size is a
# whole number -- confirm whether `int` was intended.
parser.add_argument(
    "--min-league-size",
    help=(
        "Minimum number of policies/RLModules to consider the test passed. "
        "The initial league size is 2: `main` and `random`. "
        "`--min-league-size=3` thus means that one new policy/RLModule has been "
        "added so far (b/c the `main` one has reached the `--win-rate-threshold "
        "against the `random` Policy/RLModule)."
    ),
    default=3,
    type=float,
)

# Number of interactive games against a human after training.
parser.add_argument(
    "--num-episodes-human-play",
    help=(
        "How many episodes to play against the user on the command "
        "line after training has finished."
    ),
    default=10,
    type=int,
)

# Optional checkpoint to restore from; when set, the training cell is skipped.
parser.add_argument(
    "--from-checkpoint",
    help=(
        "Full path to a checkpoint file for restoring a previously saved "
        "Algorithm state."
    ),
    default=None,
    type=str,
)



_StoreAction(option_strings=['--from-checkpoint'], dest='from_checkpoint', nargs=None, const=None, default=None, type=<class 'str'>, choices=None, required=False, help='Full path to a checkpoint file for restoring a previously saved Algorithm state.', metavar=None)

In [14]:
import sys

# Replace the process arguments with a bare script name so that the
# `parser.parse_args()` call in the next cell sees no extra command-line
# options and therefore falls back to the parser defaults. Add option strings
# to this list to override individual settings.
sys.argv = ["notebook_script.py"]
print("参数设置完成")

参数设置完成


In [15]:

# Parse the script options. No argument list is passed, so argparse reads them
# from `sys.argv` (set in the cell above). The result is consumed as
# `args.<option>` by the following cells.
args = parser.parse_args()


In [16]:

def _make_open_spiel_env(_):
    """Env creator: wrap the OpenSpiel game named by `args.env` in an `OpenSpielEnv`.

    The single positional argument handed in by the registry is ignored.
    """
    return OpenSpielEnv(pyspiel.load_game(args.env))


# Make the environment available under the name "open_spiel_env", which is the
# name the algorithm config refers to.
register_env("open_spiel_env", _make_open_spiel_env)

def agent_to_module_mapping_fn(agent_id, episode, **kwargs):
    """Map an agent to the "main" or the "random" module for one episode.

    Args:
        agent_id: The agent's ID (0 or 1).
        episode: The episode object; only its `id_` attribute is read.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        "main" for exactly one of the two agent IDs, "random" for the other.
    """
    # The parity of the episode ID's hash picks the seat that "main" takes, so
    # "main" sometimes moves first (agent 0) and sometimes second (agent 1).
    main_seat = hash(episode.id_) % 2
    if main_seat == agent_id:
        return "main"
    return "random"

def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    """Map an agent to the "main" or the "random" policy for one episode.

    Args:
        agent_id: The agent's ID (0 or 1).
        episode: The episode object; only its `episode_id` attribute is read.
        worker: Accepted for signature compatibility and ignored.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        "main" if the parity of the episode ID equals `agent_id`, otherwise
        "random".
    """
    main_seat = episode.episode_id % 2
    if main_seat == agent_id:
        return "main"
    return "random"


In [17]:

# Several settings differ between RLlib's new and old API stack. Resolve all
# of them in one place so the config construction below stays linear.
if args.enable_new_api_stack:
    league_callback_cls = SelfPlayCallback
    envs_per_env_runner = 1
    # On the new stack only the module IDs are listed here; the module
    # classes are given through the `rl_module` settings further down.
    initial_policies = {"main", "random"}
    mapping_fn = agent_to_module_mapping_fn
else:
    league_callback_cls = SelfPlayCallbackOldAPIStack
    envs_per_env_runner = 5
    initial_policies = {
        # The policy we would like to optimize.
        "main": PolicySpec(),
        # The initial random opponent to play against.
        "random": PolicySpec(policy_class=RandomPolicy),
    }
    mapping_fn = policy_mapping_fn

# Default config of the selected algorithm, running on the registered env.
config = get_trainable_cls(args.algo).get_default_config()
config = config.environment("open_spiel_env")

# The main piece of this experiment: the league-based self-play callback,
# which controls adding new policies/modules to the league and matching the
# league members with each other.
config = config.callbacks(
    functools.partial(
        league_callback_cls,
        win_rate_threshold=args.win_rate_threshold,
    )
)

config = config.env_runners(
    num_env_runners=(args.num_env_runners or 2),
    num_envs_per_env_runner=envs_per_env_runner,
)

config = config.multi_agent(
    # Initial policy map: the random opponent and the default algorithm
    # policy. The self-play callback later expands it with snapshots taken
    # from "main", against which "main" then plays instead of "random".
    policies=initial_policies,
    # Assigns agents 0 and 1 to "main" and its opponent based on the episode
    # ID, so that "main" never plays against another "main".
    policy_mapping_fn=mapping_fn,
    # Only "main" is ever trained.
    policies_to_train=["main"],
)

config = config.rl_module(
    model_config=DefaultModelConfig(fcnet_hiddens=[512, 512]),
    rl_module_spec=MultiRLModuleSpec(
        rl_module_specs={
            "main": RLModuleSpec(),
            "random": RLModuleSpec(module_class=RandomRLModule),
        }
    ),
)

# `num_epochs` is overridden for PPO only; every other algorithm keeps its
# default training settings.
if args.algo == "PPO":
    config.training(num_epochs=20)

# Stopping criteria handed to the experiment runner as `stop=`: a sampled
# env-step budget, an iteration budget, and the league size to reach.
stop = {}
stop[NUM_ENV_STEPS_SAMPLED_LIFETIME] = args.stop_timesteps
stop[TRAINING_ITERATION] = args.stop_iters
stop["league_size"] = args.min_league_size


In [18]:

# Self-play training of the "main" policy. Training is skipped (and `results`
# stays None) when a checkpoint to restore from was given on the command line.
results = (
    None
    if args.from_checkpoint
    else run_rllib_example_script_experiment(
        config, args, stop=stop, keep_ray_up=True
    )
)


2025-06-24 23:01:16,887	INFO worker.py:1747 -- Calling ray.init() again after it has already been called.


0,1
Current time:,2025-06-24 23:02:27
Running for:,00:01:11.00
Memory:,23.8/125.5 GiB

Trial name,status,loc,iter,total time (s),ts,num_healthy_workers,actor_manager_num_ou tstanding_async_reqs,num_remote_worker_re starts
PPO_open_spiel_env_b2703_00000,TERMINATED,192.168.0.25:667699,35,64.6118,140000,2,0,0




Trial name,actor_manager_num_outstanding_async_reqs,agent_timesteps_total,counters,custom_metrics,env_runners,episode_media,info,league_size,num_agent_steps_sampled,num_agent_steps_sampled_lifetime,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_lifetime,num_env_steps_sampled_this_iter,num_env_steps_sampled_throughput_per_sec,num_env_steps_trained,num_env_steps_trained_this_iter,num_env_steps_trained_throughput_per_sec,num_healthy_workers,num_remote_worker_restarts,num_steps_trained_this_iter,perf,timers,win_rate
PPO_open_spiel_env_b2703_00000,0,139992,"{'num_env_steps_sampled': 140000, 'num_env_steps_trained': 140000, 'num_agent_steps_sampled': 139992, 'num_agent_steps_trained': 139992}",{},"{'episode_reward_max': 0.0, 'episode_reward_min': -0.8999999999999998, 'episode_reward_mean': np.float64(-0.011873350923482852), 'episode_len_mean': np.float64(10.62269129287599), 'episode_media': {}, 'episodes_timesteps_total': 4026, 'policy_reward_min': {'random': np.float64(-1.4000000000000001), 'main': np.float64(-1.2000000000000002)}, 'policy_reward_max': {'random': np.float64(1.0), 'main': np.float64(1.0)}, 'policy_reward_mean': {'random': np.float64(-0.9097625329815304), 'main': np.float64(0.8978891820580474)}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [0.0, -0.09999999999999998, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.4999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.10000000000000009, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.09999999999999998, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.10000000000000009, 0.0, -0.8999999999999998, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.20000000000000007, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.30000000000000004, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.09999999999999998, 0.0, 0.0, -0.10000000000000009, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.8000000000000002, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.10000000000000009, -0.09999999999999998, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.10000000000000009, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.10000000000000009, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.09999999999999998, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.20000000000000018], 'episode_lengths': [12, 25, 8, 8, 24, 21, 7, 8, 7, 9, 8, 13, 17, 8, 7, 22, 17, 8, 7, 7, 7, 22, 8, 8, 13, 7, 7, 8, 8, 13, 16, 9, 7, 8, 7, 9, 8, 20, 15, 8, 11, 20, 10, 20, 9, 9, 9, 8, 8, 7, 19, 17, 8, 8, 20, 8, 8, 8, 7, 7, 7, 15, 15, 8, 8, 10, 9, 7, 7, 7, 7, 7, 11, 7, 7, 9, 9, 22, 8, 7, 25, 12, 7, 7, 18, 7, 8, 18, 19, 7, 8, 7, 8, 22, 8, 24, 8, 7, 8, 12, 8, 21, 8, 18, 22, 8, 7, 19, 7, 7, 7, 8, 8, 7, 7, 18, 7, 8, 7, 19, 7, 8, 7, 8, 7, 8, 13, 7, 8, 8, 7, 10, 7, 7, 9, 9, 7, 16, 8, 14, 8, 8, 8, 19, 8, 7, 16, 7, 8, 7, 8, 20, 7, 7, 8, 7, 7, 8, 12, 7, 7, 17, 8, 7, 7, 8, 10, 21, 7, 15, 10, 7, 18, 12, 7, 13, 7, 8, 8, 8, 7, 8, 7, 12, 18, 14, 8, 8, 7, 9, 7, 10, 29, 10, 8, 8, 14, 14, 9, 8, 8, 10, 8, 8, 15, 9, 19, 8, 27, 7, 8, 8, 18, 10, 16, 15, 7, 16, 14, 18, 8, 19, 18, 8, 8, 7, 7, 25, 8, 10, 13, 7, 7, 7, 8, 12, 13, 15, 9, 15, 8, 17, 21, 34, 8, 8, 8, 8, 7, 7, 7, 7, 8, 21, 15, 9, 8, 8, 7, 8, 10, 8, 9, 7, 7, 8, 7, 10, 7, 7, 7, 17, 7, 8, 8, 9, 10, 7, 8, 7, 12, 8, 8, 15, 20, 10, 7, 8, 23, 9, 8, 14, 8, 8, 8, 8, 16, 15, 12, 7, 8, 7, 21, 7, 7, 8, 7, 8, 7, 
11, 9, 26, 7, 9, 9, 20, 7, 7, 15, 14, 15, 7, 9, 14, 7, 8, 9, 15, 21, 26, 16, 18, 8, 20, 8, 12, 8, 9, 7, 8, 11, 8, 13, 7, 15, 8, 7, 8, 8, 8, 8, 8, 7, 7, 19, 7, 17, 10, 8, 7, 8, 8, 8, 7, 14, 7, 7, 24, 7, 8, 13, 8, 7, 8, 7, 7, 7, 8, 22], 'policy_random_reward': [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.1, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.1, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.2, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.1, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.2000000000000002, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 0.9, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.4000000000000001, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.1, -1.1, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.1, -1.0, -1.0, 
-1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.1, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 0.9, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.1]}, 'sampler_perf': {'mean_raw_obs_processing_ms': np.float64(0.8892811469404559), 'mean_inference_ms': np.float64(0.5795703321644742), 'mean_action_processing_ms': np.float64(0.1543484845921387), 'mean_env_wait_ms': np.float64(0.13519458640054344), 'mean_env_render_ms': np.float64(0.0)}, 'num_faulty_episodes': 0, 'connector_metrics': {'ObsPreprocessorConnector_ms': np.float64(0.0019858254606302307), 'StateBufferConnector_ms': np.float64(0.003366891815668675), 'ViewRequirementAgentConnector_ms': np.float64(0.06566296152200422)}, 'num_episodes': 379, 'episode_return_max': 0.0, 'episode_return_min': -0.8999999999999998, 'episode_return_mean': np.float64(-0.011873350923482852), 'episodes_this_iter': 379}",{},"{'learner': {'main': {'learner_stats': {'allreduce_latency': np.float64(0.0), 'grad_gnorm': np.float32(0.9769495), 'cur_kl_coeff': np.float64(0.01875), 'cur_lr': np.float64(5e-05), 'total_loss': np.float64(0.2106280327652177), 'policy_loss': np.float64(-0.01755137841203524), 'vf_loss': np.float64(0.22808257865116877), 'vf_explained_var': np.float64(0.1395652818329194), 'kl': np.float64(0.005164267819673726), 'entropy': np.float64(0.45911177037393347), 'entropy_coeff': np.float64(0.0)}, 'model': {}, 'custom_metrics': {}, 'num_agent_steps_trained': np.float64(122.88235294117646), 'num_grad_updates_lifetime': np.float64(11450.5), 'diff_num_grad_updates_vs_sampler_policy': np.float64(169.5)}}, 'num_env_steps_sampled': 140000, 
'num_env_steps_trained': 140000, 'num_agent_steps_sampled': 139992, 'num_agent_steps_trained': 139992}",3,139992,139992,139992,140000,140000,4000,2110.04,140000,4000,2110.04,2,0,4000,"{'cpu_util_percent': np.float64(9.15), 'ram_util_percent': np.float64(19.0)}","{'training_iteration_time_ms': 1876.217, 'restore_workers_time_ms': 0.013, 'training_step_time_ms': 1876.18, 'sample_time_ms': 734.809, 'learn_time_ms': 1138.578, 'learn_throughput': 3513.153, 'synch_weights_time_ms': 2.591}",0.952507


2025-06-24 23:02:27,893	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/qrbao/ray_results/PPO_2025-06-24_23-01-16' in 0.0398s.
2025-06-24 23:02:28,479	INFO tune.py:1041 -- Total run time: 71.59 seconds (70.96 seconds for the tuning loop).


In [24]:
import sys

# Reset the process arguments to a bare script name once more.
# NOTE(review): no visible cell after this one calls `parser.parse_args()`
# again, so this reset does not change `args` -- confirm whether the cell is
# still needed.
sys.argv = ["notebook_script.py"]
print("参数设置完成")

参数设置完成


In [25]:

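# NOTE: This cell is disabled -- every line below is commented out, so running
# it does nothing. The output recorded under the cell appears to come from an
# earlier, uncommented run that the user interrupted.
# # Restore trained Algorithm (set to non-explore behavior) and play against
# # human on command line.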
# if args.num_episodes_human_play > 0:
#     num_episodes = 0
#     config.explore = False
#     algo = config.build()
#     if args.from_checkpoint:
#         algo.restore(args.from_checkpoint)
#     else:
#         checkpoint = results.get_best_result().checkpoint
#         if not checkpoint:
#             raise ValueError("No last checkpoint found in results!")
#         algo.restore(checkpoint)

#     if args.enable_new_api_stack:
#         rl_module = algo.get_module("main")

#     # Play from the command line against the trained agent
#     # in an actual (non-RLlib-wrapped) open-spiel env.
#     human_player = 1
#     env = Environment(args.env)

#     while num_episodes < args.num_episodes_human_play:
#         print("You play as {}".format("o" if human_player else "x"))
#         time_step = env.reset()
#         while not time_step.last():
#             player_id = time_step.observations["current_player"]
#             if player_id == human_player:
#                 action = ask_user_for_action(time_step)
#             else:
#                 obs = np.array(time_step.observations["info_state"][player_id])
#                 if args.enable_new_api_stack:
#                     action = np.argmax(
#                         rl_module.forward_inference(
#                             {"obs": torch.from_numpy(obs).unsqueeze(0).float()}
#                         )["action_dist_inputs"][0].numpy()
#                     )
#                 else:
#                     action = algo.compute_single_action(obs, policy_id="main")
#                 # In case computer chooses an invalid action, pick a
#                 # random one.
#                 legal = time_step.observations["legal_actions"][player_id]
#                 if action not in legal:
#                     action = np.random.choice(legal)
#             time_step = env.step([action])
#             print(f"\n{env.get_state}")

#         print(f"\n{env.get_state}")

#         print("End of game!")
#         if time_step.rewards[human_player] > 0:
#             print("You win")
#         elif time_step.rewards[human_player] < 0:
#             print("You lose")
#         else:
#             print("Draw")
#         # Switch order of players.
#         human_player = 1 - human_player

#         num_episodes += 1

#     algo.stop()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2025-06-24 23:04:53,992	INFO trainable.py:577 -- Restored on 192.168.0.25 from checkpoint: Checkpoint(filesystem=local, path=/home/qrbao/ray_results/PPO_2025-06-24_23-01-16/PPO_open_spiel_env_b2703_00000_0_2025-06-24_23-01-16/checkpoint_000034)


You play as o

.......
.......
.......
.......
.......
...x...

Choose an action from [0, 1, 2, 3, 4, 5, 6]:
Choose an action from [0, 1, 2, 3, 4, 5, 6]:


KeyboardInterrupt: Interrupted by user

In [None]:

# Roll out one seeded episode of a waterworld environment with random actions,
# printing the most recent reward of whichever agent is about to act.
# NOTE(review): `waterworld_v4` is not imported in any visible cell of this
# notebook -- it must be brought into scope before this cell can run.
env = waterworld_v4.env(
    render_mode="human",
    n_predators=2,
    n_preys=2,
    n_evaders=5,
    n_obstacles=1,
    n_poisons=1,
)

# Close the environment even if the rollout raises; otherwise the env created
# with `render_mode="human"` would be left open.
try:
    env.reset(seed=42)

    for agent in env.agent_iter():
        observation, reward, termination, truncation, info = env.last()
        print({agent: reward})

        if termination or truncation:
            # An agent that is done is stepped with `None` instead of an action.
            action = None
        else:
            # this is where you would insert your policy
            action = env.action_space(agent).sample()

        env.step(action)
finally:
    env.close()