# Introduction

So far we have considered single-agent environments, but they have their limitations. What if we want an agent to cooperate or compete with other agents? Neither [StableBaselines3](https://stable-baselines3.readthedocs.io/en/master/) nor [gymnasium](https://gymnasium.farama.org/) supports this natively, but we can work around the problem with the [pettingzoo](https://pettingzoo.farama.org/) library. In this notebook we extend our training possibilities by creating multi-agent environments with PettingZoo and training agents in them with StableBaselines3.

# PettingZoo API

PettingZoo is a library that brings a gymnasium-style API to multi-agent environments. It provides the AEC (Agent Environment Cycle) API for environments in which agents act one after another, and the Parallel API for environments in which all agents act and observe simultaneously. The library also offers various wrappers that add further functionality. We describe some of them below.

### AEC API

The AEC (Agent Environment Cycle) API can represent any type of game for multi-agent reinforcement learning. Agents act in turns, one after another. In the following example we create a PettingZoo environment for the game rock paper scissors. There is neither a model nor a policy, only two players taking random actions.

In [1]:
from pettingzoo.classic import rps_v2

env = rps_v2.env(render_mode="human")
env.reset(seed=42)

for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()

    if termination or truncation:
        action = None
    else:
        # this is where you would insert your policy
        action = env.action_space(agent).sample()

    env.step(action)
env.close()
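To make the turn-taking explicit without depending on PettingZoo internals, here is a minimal sketch of the same loop run against a hypothetical toy game. `ToyTurnEnv` is invented for illustration; it only imitates the AEC calling convention (`agent_iter`, `last`, `step`) and the rule, as used in the loop above, that an agent which has finished must still be stepped once with `None` before it leaves the game.

```python
import random


class ToyTurnEnv:
    """Hypothetical two-player, turn-based game (NOT part of PettingZoo).

    Players take turns adding 1, 2 or 3 to a shared counter; whoever brings it
    to 10 or more wins. The class only imitates the AEC calling convention
    (agent_iter / last / step) so the loop below runs without PettingZoo.
    """

    def __init__(self, target=10):
        self.possible_agents = ["player_0", "player_1"]
        self.target = target

    def reset(self):
        self.counter = 0
        self.agents = list(self.possible_agents)
        self.agent_selection = self.agents[0]
        self.rewards = {a: 0 for a in self.agents}
        self.terminations = {a: False for a in self.agents}
        self.history = []  # (agent, action) pairs in the order they were played

    def agent_iter(self):
        # Keep yielding the agent whose turn it is until every agent has left
        while self.agents:
            yield self.agent_selection

    def last(self):
        a = self.agent_selection
        # observation, reward, termination, truncation, info
        return self.counter, self.rewards[a], self.terminations[a], False, {}

    def step(self, action):
        agent = self.agent_selection
        other = self.possible_agents[1 - self.possible_agents.index(agent)]
        if self.terminations[agent]:
            # A finished agent must pass None, after which it leaves the game
            assert action is None
            self.agents.remove(agent)
        else:
            self.history.append((agent, action))
            self.counter += action
            if self.counter >= self.target:
                self.rewards = {agent: 1, other: -1}
                self.terminations = {a: True for a in self.possible_agents}
        # Pass the turn to the other agent if it is still in the game
        if other in self.agents:
            self.agent_selection = other


toy_env = ToyTurnEnv()
toy_env.reset()
rng = random.Random(0)
for agent in toy_env.agent_iter():
    observation, reward, termination, truncation, info = toy_env.last()
    if termination or truncation:
        action = None
    else:
        action = rng.choice([1, 2, 3])  # random "policy"
    toy_env.step(action)

print(toy_env.history)
print(toy_env.rewards)
```

The printed history alternates strictly between `player_0` and `player_1`, which is exactly the property the AEC API guarantees for turn-based games.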

In most games, not every move is available to a player at every point. In that case we need an action mask that marks which actions are currently legal for a given player. Below is a simple chess example with random moves and masking. This PettingZoo environment returns the observation as a dict containing the environment observation and the action mask. We will show later how to train models with masks; for now, this is how a mask can be passed to the sample function.

In [2]:
from pettingzoo.classic import chess_v6

env = chess_v6.env(render_mode="human")
env.metadata['render_fps'] = 30
env.reset(seed=42)

for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()

    if termination or truncation:
        action = None
    else:
        # invalid action masking is optional and environment-dependent
        if "action_mask" in info:
            mask = info["action_mask"]
        elif isinstance(observation, dict) and "action_mask" in observation:
            mask = observation["action_mask"]
        else:
            mask = None
        # this is where you would insert your policy
        action = env.action_space(agent).sample(mask)

    env.step(action)
env.close()
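Conceptually, sampling with a mask just means choosing uniformly among the indices the mask marks as legal. The plain-Python sketch below illustrates the idea; `sample_masked` is a hypothetical helper, not a gymnasium function, and gymnasium's own `Discrete.sample(mask)` expects the mask as a NumPy `int8` array rather than a list.

```python
import random


def sample_masked(mask, rng=random):
    """Return a uniformly random action index i among those with mask[i] == 1.

    Plain-Python sketch of masked sampling. gymnasium's Discrete.sample(mask)
    expects the mask as a NumPy int8 array rather than a list.
    """
    legal = [i for i, allowed in enumerate(mask) if allowed]
    if not legal:
        raise ValueError("the mask does not allow any action")
    return rng.choice(legal)


# Nine-cell board (tic-tac-toe style) where cells 0, 4 and 8 are already taken
board_mask = [0, 1, 1, 1, 0, 1, 1, 1, 0]
rng = random.Random(42)
samples = [sample_masked(board_mask, rng) for _ in range(1000)]
print(sorted(set(samples)))
```

Only the free cells 1, 2, 3, 5, 6 and 7 can ever be drawn; the occupied cells have zero probability.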

### Parallel API

For simultaneous actions and observations we use the alternative Parallel API, in which the actions of all agents are submitted at the same time (here generated with the sample function). The example below shows the pistonball environment, in which the agents (pistons) cooperate to move the ball to the other side of the screen.

In [3]:
from pettingzoo.butterfly import pistonball_v6
parallel_env = pistonball_v6.parallel_env(render_mode="human")
observations, infos = parallel_env.reset(seed=42)

while parallel_env.agents:
    # this is where you would insert your policy
    actions = {agent: parallel_env.action_space(agent).sample() for agent in parallel_env.agents}

    observations, rewards, terminations, truncations, infos = parallel_env.step(actions)
parallel_env.close()
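Rock paper scissors is a natural fit for simultaneous play: neither player should see the other's choice before committing. The hypothetical helper below (not part of PettingZoo) shows the data shape of one Parallel API step: a dict of actions keyed by agent name goes in, and a dict of rewards keyed the same way comes out.

```python
BEATS = {"rock": "scissors", "paper": "rock", "scissors": "paper"}


def rps_parallel_step(actions):
    """Resolve one simultaneous round of rock paper scissors (hypothetical helper).

    Like a Parallel API step, it takes one dict of actions keyed by agent
    name and returns one dict of rewards keyed the same way.
    """
    (agent_a, move_a), (agent_b, move_b) = actions.items()
    if move_a == move_b:
        return {agent_a: 0, agent_b: 0}
    if BEATS[move_a] == move_b:
        return {agent_a: 1, agent_b: -1}
    return {agent_a: -1, agent_b: 1}


print(rps_parallel_step({"player_0": "rock", "player_1": "scissors"}))
# {'player_0': 1, 'player_1': -1}
print(rps_parallel_step({"player_0": "paper", "player_1": "scissors"}))
# {'player_0': -1, 'player_1': 1}
```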

### Wrappers

PettingZoo provides several useful wrappers. We can convert AEC environments to Parallel environments and back with ```aec_to_parallel``` and ```parallel_to_aec```. Another useful wrapper is ```TerminateIllegalWrapper```, which ends the game and penalizes the player who makes an illegal move. For Parallel environments, ```BaseParallelWrapper``` is the base class for writing custom wrappers; on its own it simply forwards every call to the wrapped environment. More wrappers are described in the PettingZoo documentation. Keep in mind that most native PettingZoo environments are already wrapped with the appropriate wrappers, so there is no need to apply them again.

In [4]:
from pettingzoo.utils.conversions import aec_to_parallel
from pettingzoo.butterfly import pistonball_v6
env = pistonball_v6.env()
env = aec_to_parallel(env)

In [5]:
from pettingzoo.utils import parallel_to_aec
from pettingzoo.butterfly import pistonball_v6
env = pistonball_v6.parallel_env()
env = parallel_to_aec(env)
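The idea behind turning a Parallel environment into a turn-based one can be sketched as follows: each agent's action is buffered in turn, and the underlying simultaneous step runs only once the last agent of the round has chosen. `ToyParallelToAEC` and `fake_parallel_step` are hypothetical names used for illustration; this is our simplified reading of the concept, not the PettingZoo source code.

```python
class ToyParallelToAEC:
    """Conceptual sketch of running a Parallel env through a turn-based interface.

    Hypothetical class: actions are buffered one agent at a time, and the
    underlying parallel step runs only once every agent in the round has
    chosen. It illustrates the idea behind parallel_to_aec; it is not the
    PettingZoo implementation.
    """

    def __init__(self, parallel_step, agents):
        self.parallel_step = parallel_step  # callable: dict of actions -> dict of rewards
        self.agents = list(agents)
        self._actions = {}
        self._turn = 0
        self.last_rewards = None

    @property
    def agent_selection(self):
        return self.agents[self._turn]

    def step(self, action):
        self._actions[self.agent_selection] = action
        self._turn += 1
        if self._turn == len(self.agents):
            # Every agent has chosen: run one simultaneous step, then start a new round
            self.last_rewards = self.parallel_step(self._actions)
            self._actions = {}
            self._turn = 0


calls = []


def fake_parallel_step(actions):
    """Hypothetical stand-in for a Parallel env's step; records what it receives."""
    calls.append(dict(actions))
    return {agent: 0.0 for agent in actions}


wrapped = ToyParallelToAEC(fake_parallel_step, ["agent_0", "agent_1", "agent_2"])
wrapped.step(1)
wrapped.step(0)
print(len(calls))  # 0: the round is not complete yet
wrapped.step(1)
print(calls)       # [{'agent_0': 1, 'agent_1': 0, 'agent_2': 1}]
```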

In [6]:
from pettingzoo.utils import TerminateIllegalWrapper
from pettingzoo.classic import tictactoe_v3
env = tictactoe_v3.env()
env = TerminateIllegalWrapper(env, illegal_reward=-1)

env.reset()
for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    if termination or truncation:
        action = None
    else:
        # this is where you would insert your policy
        action = env.action_space(agent).sample()
    env.step(action)
env.close()
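As far as we can tell from the PettingZoo source, `TerminateIllegalWrapper` applies roughly the rule sketched below (details may differ between versions): if the chosen action is not allowed by the action mask, the game ends for every agent, the offending agent receives `illegal_reward`, and all other agents receive 0. `handle_move` is a hypothetical helper written for illustration.

```python
def handle_move(agent, action, action_mask, agents, illegal_reward=-1.0):
    """Sketch of the rule TerminateIllegalWrapper enforces (hypothetical helper).

    Returns None when the move is legal (the wrapped environment handles it
    normally); otherwise returns (terminations, rewards) for ending the game.
    """
    if action_mask[action]:
        return None
    terminations = {a: True for a in agents}
    rewards = {a: 0.0 for a in agents}
    rewards[agent] = float(illegal_reward)  # only the offender is penalized
    return terminations, rewards


players = ["player_1", "player_2"]
mask = [0, 1, 1, 1, 1, 1, 1, 1, 1]  # cell 0 is already occupied
print(handle_move("player_1", 4, mask, players))  # None (legal move)
print(handle_move("player_1", 0, mask, players))
# ({'player_1': True, 'player_2': True}, {'player_1': -1.0, 'player_2': 0.0})
```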

In [7]:
from pettingzoo.utils import BaseParallelWrapper
from pettingzoo.butterfly import pistonball_v6

parallel_env = pistonball_v6.parallel_env(render_mode="human")
parallel_env = BaseParallelWrapper(parallel_env)

observations, infos = parallel_env.reset()

while parallel_env.agents:
    # this is where you would insert your policy
    actions = {agent: parallel_env.action_space(agent).sample() for agent in parallel_env.agents}
    observations, rewards, terminations, truncations, infos = parallel_env.step(actions)

parallel_env.close()

# Training and evaluation

Import libraries

In [8]:
from __future__ import annotations

import glob
import os
import time

import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

from pettingzoo.sisl import waterworld_v4

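Create and wrap the environment.

- ```ss.pettingzoo_env_to_vec_env_v1``` turns the Parallel environment into a vectorized environment in which every agent is a separate sub-environment, making it compatible with Stable Baselines3 (a single policy is shared by all agents)
- ```ss.concat_vec_envs_v1``` runs n copies of that vectorized environment at the same time (here 8 copies on 2 CPUs)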

In [9]:
env = waterworld_v4.parallel_env()
env.reset(seed=42)

env = ss.pettingzoo_env_to_vec_env_v1(env)
env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3")
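To see how this wrapping shapes training, the arithmetic below reproduces the rollout size reported as `total_timesteps` after each iteration in the training log. It assumes waterworld_v4's default of two pursuers (matching the `pursuer_0` and `pursuer_1` keys in the evaluation output) and SB3 PPO's default `n_steps=2048`, since the PPO call does not set it.

```python
# Values from this notebook; n_steps is assumed to be the SB3 PPO default
n_pursuers = 2      # waterworld_v4 default: pursuer_0 and pursuer_1
n_copies = 8        # second argument of ss.concat_vec_envs_v1
n_steps = 2048      # PPO rollout length per sub-environment (default)
total_timesteps = 10 * (2 ** 15)  # value passed to model.learn

n_envs = n_pursuers * n_copies     # each agent of each copy is one sub-environment
rollout_size = n_envs * n_steps    # transitions collected per PPO iteration
iterations = total_timesteps // rollout_size

print(n_envs, rollout_size, iterations)  # 16 32768 10
```

This matches the log: each iteration adds 32,768 timesteps, and training stops after 10 iterations.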

Create the PPO model; you can tune these and many other parameters. This model uses ```MlpPolicy```, which means the policy and value function are multi-layer perceptron networks.

In [10]:
model = PPO(
    MlpPolicy,
    env,
    verbose=3,
    learning_rate=1e-3,
    batch_size=256,
    device='cpu'
)

Using cpu device
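With `batch_size=256`, each collected rollout is split into mini-batches and reused for several epochs. A quick calculation, assuming SB3 PPO's default `n_epochs=10` (not overridden above) and the 32,768-transition rollout reported in the training log:

```python
rollout_size = 32768   # transitions per iteration, as reported in the training log
batch_size = 256       # set in the PPO call above
n_epochs = 10          # SB3 PPO default, assumed unchanged

minibatches_per_epoch = rollout_size // batch_size
gradient_steps_per_iteration = n_epochs * minibatches_per_epoch
print(minibatches_per_epoch, gradient_steps_per_iteration)  # 128 1280
```

The `n_updates` counter in the log grows by 10 per iteration, which is consistent with it counting epochs rather than individual gradient steps.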


Train and save the model

In [11]:
model.learn(total_timesteps=10*(2**15), progress_bar=True)
model.save('waterworld_model')
env.close()



------------------------------
| time/              |       |
|    fps             | 1782  |
|    iterations      | 1     |
|    time_elapsed    | 18    |
|    total_timesteps | 32768 |
------------------------------


------------------------------------------
| time/                   |              |
|    fps                  | 1554         |
|    iterations           | 2            |
|    time_elapsed         | 42           |
|    total_timesteps      | 65536        |
| train/                  |              |
|    approx_kl            | 0.0038197297 |
|    clip_fraction        | 0.0402       |
|    clip_range           | 0.2          |
|    entropy_loss         | -2.8         |
|    explained_variance   | 0.0024       |
|    learning_rate        | 0.001        |
|    loss                 | 5.17         |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.0004      |
|    std                  | 0.973        |
|    value_loss           | 12.2         |
------------------------------------------


------------------------------------------
| time/                   |              |
|    fps                  | 1503         |
|    iterations           | 3            |
|    time_elapsed         | 65           |
|    total_timesteps      | 98304        |
| train/                  |              |
|    approx_kl            | 0.0036872597 |
|    clip_fraction        | 0.0309       |
|    clip_range           | 0.2          |
|    entropy_loss         | -2.77        |
|    explained_variance   | 0.317        |
|    learning_rate        | 0.001        |
|    loss                 | 7.4          |
|    n_updates            | 20           |
|    policy_gradient_loss | -0.000907    |
|    std                  | 0.963        |
|    value_loss           | 15.6         |
------------------------------------------


------------------------------------------
| time/                   |              |
|    fps                  | 1477         |
|    iterations           | 4            |
|    time_elapsed         | 88           |
|    total_timesteps      | 131072       |
| train/                  |              |
|    approx_kl            | 0.0052369395 |
|    clip_fraction        | 0.0665       |
|    clip_range           | 0.2          |
|    entropy_loss         | -2.73        |
|    explained_variance   | 0.383        |
|    learning_rate        | 0.001        |
|    loss                 | 6.74         |
|    n_updates            | 30           |
|    policy_gradient_loss | -0.000769    |
|    std                  | 0.948        |
|    value_loss           | 14.4         |
------------------------------------------


------------------------------------------
| time/                   |              |
|    fps                  | 1463         |
|    iterations           | 5            |
|    time_elapsed         | 111          |
|    total_timesteps      | 163840       |
| train/                  |              |
|    approx_kl            | 0.0064727897 |
|    clip_fraction        | 0.0852       |
|    clip_range           | 0.2          |
|    entropy_loss         | -2.69        |
|    explained_variance   | 0.46         |
|    learning_rate        | 0.001        |
|    loss                 | 7.57         |
|    n_updates            | 40           |
|    policy_gradient_loss | -0.000463    |
|    std                  | 0.927        |
|    value_loss           | 17.2         |
------------------------------------------


-----------------------------------------
| time/                   |             |
|    fps                  | 1454        |
|    iterations           | 6           |
|    time_elapsed         | 135         |
|    total_timesteps      | 196608      |
| train/                  |             |
|    approx_kl            | 0.007325147 |
|    clip_fraction        | 0.0951      |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.66       |
|    explained_variance   | 0.445       |
|    learning_rate        | 0.001       |
|    loss                 | 9.74        |
|    n_updates            | 50          |
|    policy_gradient_loss | 0.000506    |
|    std                  | 0.915       |
|    value_loss           | 19.8        |
-----------------------------------------


-----------------------------------------
| time/                   |             |
|    fps                  | 1448        |
|    iterations           | 7           |
|    time_elapsed         | 158         |
|    total_timesteps      | 229376      |
| train/                  |             |
|    approx_kl            | 0.010654227 |
|    clip_fraction        | 0.109       |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.62       |
|    explained_variance   | 0.445       |
|    learning_rate        | 0.001       |
|    loss                 | 12.6        |
|    n_updates            | 60          |
|    policy_gradient_loss | 0.000134    |
|    std                  | 0.894       |
|    value_loss           | 20.3        |
-----------------------------------------


------------------------------------------
| time/                   |              |
|    fps                  | 1441         |
|    iterations           | 8            |
|    time_elapsed         | 181          |
|    total_timesteps      | 262144       |
| train/                  |              |
|    approx_kl            | 0.0055977553 |
|    clip_fraction        | 0.0748       |
|    clip_range           | 0.2          |
|    entropy_loss         | -2.58        |
|    explained_variance   | 0.374        |
|    learning_rate        | 0.001        |
|    loss                 | 9.86         |
|    n_updates            | 70           |
|    policy_gradient_loss | -0.00161     |
|    std                  | 0.874        |
|    value_loss           | 17.7         |
------------------------------------------


------------------------------------------
| time/                   |              |
|    fps                  | 1436         |
|    iterations           | 9            |
|    time_elapsed         | 205          |
|    total_timesteps      | 294912       |
| train/                  |              |
|    approx_kl            | 0.0073400927 |
|    clip_fraction        | 0.0923       |
|    clip_range           | 0.2          |
|    entropy_loss         | -2.52        |
|    explained_variance   | 0.369        |
|    learning_rate        | 0.001        |
|    loss                 | 6.79         |
|    n_updates            | 80           |
|    policy_gradient_loss | -0.000752    |
|    std                  | 0.848        |
|    value_loss           | 17.3         |
------------------------------------------


-----------------------------------------
| time/                   |             |
|    fps                  | 1433        |
|    iterations           | 10          |
|    time_elapsed         | 228         |
|    total_timesteps      | 327680      |
| train/                  |             |
|    approx_kl            | 0.007446139 |
|    clip_fraction        | 0.0914      |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.46       |
|    explained_variance   | 0.411       |
|    learning_rate        | 0.001       |
|    loss                 | 10          |
|    n_updates            | 90          |
|    policy_gradient_loss | -0.00159    |
|    std                  | 0.826       |
|    value_loss           | 17.4        |
-----------------------------------------


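Evaluate the model by playing 10 games and collecting each agent's rewards. We train with the Parallel API but evaluate with the AEC API, using the same model for every agent. The rewards are summed over all 10 games, so ```avg_reward``` is the mean total reward per agent, not per game.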

In [23]:
env = waterworld_v4.env()
env.reset(seed=42)
model = PPO.load('waterworld_model')

rewards = {agent: 0 for agent in env.possible_agents}

for i in range(10):
    env.reset(seed=i)

    for agent in env.agent_iter():
        obs, reward, termination, truncation, info = env.last()

        for a in env.agents:
            rewards[a] += env.rewards[a]
        if termination or truncation:
            break
        else:
            act = model.predict(obs, deterministic=True)[0]

        env.step(act)
env.close()

avg_reward = sum(rewards.values()) / len(rewards.values())
print("Rewards: ", rewards)
print(f"Avg reward: {avg_reward}")



Rewards:  {'pursuer_0': np.float64(376.1775285425326), 'pursuer_1': np.float64(152.46351438828694)}
Avg reward: 264.3205214654098


You can visualise the trained model by setting the ```render_mode``` argument to ```"human"```.

In [24]:
env = waterworld_v4.env(render_mode='human')
env.metadata['render_fps'] = 60
env.reset(seed=42)
model = PPO.load('waterworld_model')

for agent in env.agent_iter():
    obs, reward, termination, truncation, info = env.last()

    if termination or truncation:
        break
    else:
        act = model.predict(obs, deterministic=True)[0]

    env.step(act)
env.close()

# Training environments with action mask

Following the PettingZoo documentation, to train environments that contain an action mask with Stable Baselines3 we need to define the wrapper below. Feel free to copy and paste it, but let's first explain what happens here.

To use ```MaskablePPO``` from ```Stable Baselines3 - Contrib```, the environment must expose an ```action_masks``` method that returns the current action mask. The environment can define this method itself, or we can wrap it in ```ActionMasker``` and pass a function that takes the environment and returns the action mask, as in the example below (```mask_fn```).

PettingZoo environments handle this as follows: the observation is a dictionary with ```observation``` and ```action_mask``` keys, and the wrapper splits it into these two parts. Here is the wrapper:

In [14]:
import glob
import os
import time

import gymnasium as gym
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker

import pettingzoo.utils
from pettingzoo.classic import connect_four_v3

In [15]:
# To pass into other gymnasium wrappers, we need to ensure that PettingZoo's wrapper
# can also be a gymnasium Env. Thus, we subclass gym.Env as well.
class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper, gym.Env):
    """Wrapper to allow PettingZoo environments to be used with SB3 illegal action masking."""

    def reset(self, seed=None, options=None):
        """Gymnasium-like reset function which assigns obs/action spaces to be the same for each agent.

        This is required as SB3 is designed for single-agent RL and doesn't expect obs/action spaces to be functions
        """
        super().reset(seed, options)

        # Strip the action mask out from the observation space
        self.observation_space = super().observation_space(self.possible_agents[0])[
            "observation"
        ]
        self.action_space = super().action_space(self.possible_agents[0])

        # Return initial observation, info (PettingZoo AEC envs do not by default)
        return self.observe(self.agent_selection), {}

    def step(self, action):
        """Gymnasium-like step function, returning observation, reward, termination, truncation, info.

        The observation is for the next agent (used to determine the next action), while the remaining
        items are for the agent that just acted (used to understand what just happened).
        """
        current_agent = self.agent_selection

        super().step(action)

        next_agent = self.agent_selection
        return (
            self.observe(next_agent),
            self._cumulative_rewards[current_agent],
            self.terminations[current_agent],
            self.truncations[current_agent],
            self.infos[current_agent],
        )

    def observe(self, agent):
        """Return only raw observation, removing action mask."""
        return super().observe(agent)["observation"]

    def action_mask(self):
        """Separate function used in order to access the action mask."""
        return super().observe(self.agent_selection)["action_mask"]

In [16]:
def mask_fn(env):
    # Do whatever you'd like in this function to return the action mask
    # for the current env. In this example, we assume the env has a
    # helpful method we can rely on.
    return env.action_mask()
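The wrapper and `mask_fn` cooperate as sketched below, using pure-Python stand-ins so the data flow can be followed without installing anything. `FakeDictObsEnv`, `SplitObsWrapper`, `MaskerSketch` and `demo_mask_fn` are hypothetical names; `MaskerSketch` is a simplified stand-in for sb3-contrib's `ActionMasker`, not its source, and only reproduces the part that matters here: it stores a mask function and exposes it as `action_masks()`, the method `MaskablePPO` queries before choosing an action.

```python
class FakeDictObsEnv:
    """Hypothetical stand-in for a PettingZoo env whose observations are dicts."""

    def observe(self, agent):
        return {"observation": [0.0, 1.0, 0.0], "action_mask": [1, 0, 1]}


class SplitObsWrapper:
    """Mimics how SB3ActionMaskWrapper.observe / action_mask split the dict."""

    def __init__(self, env):
        self.env = env
        self.agent_selection = "player_0"

    def observe(self, agent):
        return self.env.observe(agent)["observation"]  # raw observation only

    def action_mask(self):
        return self.env.observe(self.agent_selection)["action_mask"]


class MaskerSketch:
    """Simplified stand-in for sb3_contrib's ActionMasker (not its source code)."""

    def __init__(self, env, action_mask_fn):
        self.env = env
        self.action_mask_fn = action_mask_fn

    def action_masks(self):
        # Delegate to the user-supplied function, passing the wrapped env
        return self.action_mask_fn(self.env)


def demo_mask_fn(env):
    return env.action_mask()


masked_env = MaskerSketch(SplitObsWrapper(FakeDictObsEnv()), demo_mask_fn)
print(masked_env.env.observe("player_0"))  # [0.0, 1.0, 0.0]
print(masked_env.action_masks())           # [1, 0, 1]
```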

### Training

In [17]:
env = connect_four_v3.env()
env = SB3ActionMaskWrapper(env)
env.reset(seed=42)
env = ActionMasker(env, mask_fn)

In [18]:
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
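How does the mask influence the policy? As far as we understand sb3-contrib's maskable distributions, the logits of invalid actions are replaced with a very large negative constant before the softmax, so those actions get (numerically) zero probability. The plain-Python sketch below shows that idea for Connect Four's 7 columns; the logits and the choice of full columns are made-up illustration values.

```python
import math

HUGE_NEG = -1e8  # large negative stand-in for "minus infinity"


def masked_softmax(logits, mask):
    """Turn policy logits into action probabilities, giving masked actions ~0 mass."""
    masked = [l if m else HUGE_NEG for l, m in zip(logits, mask)]
    top = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Connect Four has 7 columns; suppose columns 0 and 6 are full
logits = [0.5, 1.0, 2.0, 3.0, 2.0, 1.0, 0.5]
mask = [0, 1, 1, 1, 1, 1, 0]
probs = masked_softmax(logits, mask)
print([round(p, 3) for p in probs])
```

Masked columns end up with probability 0.0, and the remaining probability mass is redistributed among the legal columns in proportion to their exponentiated logits.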


In [19]:
model.learn(total_timesteps=10*(2**12), progress_bar=True)
model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}")
env.close()


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 21.1     |
|    ep_rew_mean     | 1        |
| time/              |          |
|    fps             | 432      |
|    iterations      | 1        |
|    time_elapsed    | 4        |
|    total_timesteps | 2048     |
---------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 20.7        |
|    ep_rew_mean          | 0.99        |
| time/                   |             |
|    fps                  | 351         |
|    iterations           | 2           |
|    time_elapsed         | 11          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.006231459 |
|    clip_fraction        | 0.027       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.9        |
|    explained_variance   | -2.53       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0127     |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0157     |
|    value_loss           | 0.0748      |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 20.2        |
|    ep_rew_mean          | 0.99        |
| time/                   |             |
|    fps                  | 334         |
|    iterations           | 3           |
|    time_elapsed         | 18          |
|    total_timesteps      | 6144        |
| train/                  |             |
|    approx_kl            | 0.007466279 |
|    clip_fraction        | 0.0582      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.9        |
|    explained_variance   | -0.31       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0295     |
|    n_updates            | 20          |
|    policy_gradient_loss | -0.0204     |
|    value_loss           | 0.0143      |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 21.5        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 325         |
|    iterations           | 4           |
|    time_elapsed         | 25          |
|    total_timesteps      | 8192        |
| train/                  |             |
|    approx_kl            | 0.009238483 |
|    clip_fraction        | 0.0651      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.87       |
|    explained_variance   | -0.489      |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0195     |
|    n_updates            | 30          |
|    policy_gradient_loss | -0.0238     |
|    value_loss           | 0.00731     |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 18.5        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 321         |
|    iterations           | 5           |
|    time_elapsed         | 31          |
|    total_timesteps      | 10240       |
| train/                  |             |
|    approx_kl            | 0.010334175 |
|    clip_fraction        | 0.0889      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.86       |
|    explained_variance   | -0.567      |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0461     |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.0266     |
|    value_loss           | 0.00438     |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 18          |
|    ep_rew_mean          | 0.99        |
| time/                   |             |
|    fps                  | 318         |
|    iterations           | 6           |
|    time_elapsed         | 38          |
|    total_timesteps      | 12288       |
| train/                  |             |
|    approx_kl            | 0.011570411 |
|    clip_fraction        | 0.108       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.85       |
|    explained_variance   | -0.27       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0787     |
|    n_updates            | 50          |
|    policy_gradient_loss | -0.0308     |
|    value_loss           | 0.0031      |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 15.7        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 317         |
|    iterations           | 7           |
|    time_elapsed         | 45          |
|    total_timesteps      | 14336       |
| train/                  |             |
|    approx_kl            | 0.011280136 |
|    clip_fraction        | 0.104       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.81       |
|    explained_variance   | -0.392      |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0199     |
|    n_updates            | 60          |
|    policy_gradient_loss | -0.026      |
|    value_loss           | 0.00442     |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 15.9        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 315         |
|    iterations           | 8           |
|    time_elapsed         | 51          |
|    total_timesteps      | 16384       |
| train/                  |             |
|    approx_kl            | 0.013332065 |
|    clip_fraction        | 0.147       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.78       |
|    explained_variance   | 0.0181      |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0441     |
|    n_updates            | 70          |
|    policy_gradient_loss | -0.0332     |
|    value_loss           | 0.002       |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 13.8        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 313         |
|    iterations           | 9           |
|    time_elapsed         | 58          |
|    total_timesteps      | 18432       |
| train/                  |             |
|    approx_kl            | 0.013599146 |
|    clip_fraction        | 0.14        |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.73       |
|    explained_variance   | 0.0693      |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0292     |
|    n_updates            | 80          |
|    policy_gradient_loss | -0.0348     |
|    value_loss           | 0.00173     |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 13.7        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 312         |
|    iterations           | 10          |
|    time_elapsed         | 65          |
|    total_timesteps      | 20480       |
| train/                  |             |
|    approx_kl            | 0.012979541 |
|    clip_fraction        | 0.146       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.67       |
|    explained_variance   | 0.0999      |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0318     |
|    n_updates            | 90          |
|    policy_gradient_loss | -0.0313     |
|    value_loss           | 0.00144     |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 10.9        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 312         |
|    iterations           | 12          |
|    time_elapsed         | 78          |
|    total_timesteps      | 24576       |
| train/                  |             |
|    approx_kl            | 0.015458971 |
|    clip_fraction        | 0.181       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.49       |
|    explained_variance   | 0.278       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0636     |
|    n_updates            | 110         |
|    policy_gradient_loss | -0.0352     |
|    value_loss           | 0.000858    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 9.98        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 312         |
|    iterations           | 13          |
|    time_elapsed         | 85          |
|    total_timesteps      | 26624       |
| train/                  |             |
|    approx_kl            | 0.017691106 |
|    clip_fraction        | 0.167       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.37       |
|    explained_variance   | 0.441       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0493     |
|    n_updates            | 120         |
|    policy_gradient_loss | -0.0345     |
|    value_loss           | 0.00059     |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 8.92        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 312         |
|    iterations           | 14          |
|    time_elapsed         | 91          |
|    total_timesteps      | 28672       |
| train/                  |             |
|    approx_kl            | 0.017308254 |
|    clip_fraction        | 0.169       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.26       |
|    explained_variance   | 0.377       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0408     |
|    n_updates            | 130         |
|    policy_gradient_loss | -0.0347     |
|    value_loss           | 0.000489    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 8.21        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 312         |
|    iterations           | 15          |
|    time_elapsed         | 98          |
|    total_timesteps      | 30720       |
| train/                  |             |
|    approx_kl            | 0.017459856 |
|    clip_fraction        | 0.155       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.15       |
|    explained_variance   | 0.528       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0363     |
|    n_updates            | 140         |
|    policy_gradient_loss | -0.0305     |
|    value_loss           | 0.000318    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 8.01        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 311         |
|    iterations           | 16          |
|    time_elapsed         | 105         |
|    total_timesteps      | 32768       |
| train/                  |             |
|    approx_kl            | 0.015052963 |
|    clip_fraction        | 0.174       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.04       |
|    explained_variance   | 0.638       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0627     |
|    n_updates            | 150         |
|    policy_gradient_loss | -0.0341     |
|    value_loss           | 0.000222    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 7.45        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 310         |
|    iterations           | 17          |
|    time_elapsed         | 112         |
|    total_timesteps      | 34816       |
| train/                  |             |
|    approx_kl            | 0.014838042 |
|    clip_fraction        | 0.15        |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.932      |
|    explained_variance   | 0.737       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0509     |
|    n_updates            | 160         |
|    policy_gradient_loss | -0.0298     |
|    value_loss           | 0.000128    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 7.39        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 310         |
|    iterations           | 18          |
|    time_elapsed         | 118         |
|    total_timesteps      | 36864       |
| train/                  |             |
|    approx_kl            | 0.014503637 |
|    clip_fraction        | 0.156       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.816      |
|    explained_variance   | 0.85        |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0338     |
|    n_updates            | 170         |
|    policy_gradient_loss | -0.0296     |
|    value_loss           | 6.55e-05    |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 7.23       |
|    ep_rew_mean          | 1          |
| time/                   |            |
|    fps                  | 309        |
|    iterations           | 19         |
|    time_elapsed         | 125        |
|    total_timesteps      | 38912      |
| train/                  |            |
|    approx_kl            | 0.01360078 |
|    clip_fraction        | 0.121      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.732     |
|    explained_variance   | 0.877      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0402    |
|    n_updates            | 180        |
|    policy_gradient_loss | -0.0265    |
|    value_loss           | 5.15e-05   |
----------------------------------------


------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 7.11         |
|    ep_rew_mean          | 1            |
| time/                   |              |
|    fps                  | 307          |
|    iterations           | 20           |
|    time_elapsed         | 133          |
|    total_timesteps      | 40960        |
| train/                  |              |
|    approx_kl            | 0.0061922166 |
|    clip_fraction        | 0.0621       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.675       |
|    explained_variance   | 0.912        |
|    learning_rate        | 0.0003       |
|    loss                 | -0.00511     |
|    n_updates            | 190          |
|    policy_gradient_loss | -0.0146      |
|    value_loss           | 3.02e-05     |
------------------------------------------


### Evaluation

This is an example evaluation function from the PettingZoo documentation. It is generalized so it can be used both for evaluating and for rendering. It works in four steps:

1. It creates the environment with the given parameters.
2. It finds the most recently saved policy and loads it.
3. It plays `num_games` games against a random agent and collects the scores.
4. It prints the results and, optionally, renders the games.

The documentation contains more general functions like this one that work with various environments: [SB3: Action Masked PPO for Connect Four](https://pettingzoo.farama.org/tutorials/sb3/connect_four/)

In [20]:
"""
Author: Elliot (https://github.com/elliottower)
"""
def eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs):
    # Evaluate a trained agent vs a random agent
    env = env_fn.env(render_mode=render_mode, **env_kwargs)

    print(
        f"Starting evaluation vs a random agent. Trained agent will play as {env.possible_agents[1]}."
    )

    try:
        latest_policy = max(
            glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime
        )
    except ValueError:
        # max() raises ValueError when glob finds no matching .zip file
        raise FileNotFoundError("Policy not found.")

    model = MaskablePPO.load(latest_policy)

    scores = {agent: 0 for agent in env.possible_agents}
    total_rewards = {agent: 0 for agent in env.possible_agents}
    round_rewards = []

    for i in range(num_games):
        env.reset(seed=i)
        env.action_space(env.possible_agents[0]).seed(i)

        for agent in env.agent_iter():
            obs, reward, termination, truncation, info = env.last()

            # Separate observation and action mask (look them up by key instead of relying on dict order)
            observation, action_mask = obs["observation"], obs["action_mask"]

            if termination or truncation:
                # If there is a winner, keep track, otherwise don't change the scores (tie)
                if (
                    env.rewards[env.possible_agents[0]]
                    != env.rewards[env.possible_agents[1]]
                ):
                    winner = max(env.rewards, key=env.rewards.get)
                    scores[winner] += env.rewards[
                        winner
                    ]  # only tracks the largest reward (winner of game)
                # Also track negative and positive rewards (penalizes illegal moves)
                for a in env.possible_agents:
                    total_rewards[a] += env.rewards[a]
                # List of rewards by round, for reference
                round_rewards.append(env.rewards)
                break
            else:
                if agent == env.possible_agents[0]:
                    act = env.action_space(agent).sample(action_mask)
                else:
                    # Note: PettingZoo expects integer actions, while model.predict() returns a NumPy array
                    act = int(
                        model.predict(
                            observation, action_masks=action_mask, deterministic=True
                        )[0]
                    )
            env.step(act)
    env.close()

    # Avoid dividing by zero
    if sum(scores.values()) == 0:
        winrate = 0
    else:
        winrate = scores[env.possible_agents[1]] / sum(scores.values())
    print("Rewards by round: ", round_rewards)
    print("Total rewards (incl. negative rewards): ", total_rewards)
    print("Winrate: ", winrate)
    print("Final scores: ", scores)
    return round_rewards, total_rewards, winrate, scores
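
The scoring rules of `eval_action_mask` can be checked in isolation. The sketch below recomputes the final scores, total rewards and win rate from a list of per-game reward dicts, without loading a model or an environment. `summarize_rounds` is a hypothetical helper and is not part of PettingZoo or SB3. It assumes two-player games in which the winner receives the larger reward, as in `connect_four_v3`.

```python
from typing import Dict, List, Tuple


def summarize_rounds(
    round_rewards: List[Dict[str, int]], trained: str
) -> Tuple[Dict[str, int], Dict[str, int], float]:
    """Hypothetical helper mirroring the scoring logic of eval_action_mask."""
    agents = list(round_rewards[0]) if round_rewards else []
    scores = {a: 0 for a in agents}
    total_rewards = {a: 0 for a in agents}
    for rewards in round_rewards:
        # A game is decisive only if the players received different rewards (not a tie)
        if len(set(rewards.values())) > 1:
            winner = max(rewards, key=rewards.get)
            scores[winner] += rewards[winner]  # only the winner's reward is counted
        # Totals also include negative rewards (e.g. penalties for illegal moves)
        for a in agents:
            total_rewards[a] += rewards[a]
    decisive = sum(scores.values())
    # Win rate is computed over decisive games only, so ties do not lower it
    winrate = scores[trained] / decisive if decisive else 0.0
    return scores, total_rewards, winrate


games = [{"player_0": 1, "player_1": -1}, {"player_0": -1, "player_1": 1}]
scores, totals, winrate = summarize_rounds(games, trained="player_1")
```

Feeding it the rewards of the two rendered games returned by the second call below gives the same scores, totals and win rate (0.5) that `eval_action_mask` reports. Because ties are excluded from the denominator, a high win rate does not by itself rule out many drawn games.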

In [21]:
env_fn = connect_four_v3

env_kwargs = {}

# Evaluate 100 games against a random agent (winrate should be ~80%)
eval_action_mask(env_fn, num_games=100, render_mode=None, **env_kwargs)

# Watch two games vs a random agent
eval_action_mask(env_fn, num_games=2, render_mode="human", **env_kwargs)

Starting evaluation vs a random agent. Trained agent will play as player_1.
Rewards by round:  [{'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}, {'player_0': -1, 'player_1': 1}, {'player_0': 1, 'player_1': -1}, {'player_0': 1, 'player_1': -1}, {'player_0': 

([{'player_0': 1, 'player_1': -1}, {'player_0': -1, 'player_1': 1}],
 {'player_0': 0, 'player_1': 0},
 0.5,
 {'player_0': 1, 'player_1': 1})

# Conclusions

We learned how to use the PettingZoo library to train Stable Baselines3 models on multi-agent environments predefined in PettingZoo. This configuration has its disadvantages: we cannot train agents that have different observation or action spaces, or different purposes. This is because we are in fact training a single agent that plays from the perspective of each player. PettingZoo is compatible with other RL libraries as well, so feel free to try them! We chose Stable Baselines3 because it is widely used and easy to install. In the next notebook we'll learn how to create custom environments.
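
The toy sketch below makes "one agent trained from different perspectives" concrete. It uses NumPy only, with no PettingZoo or SB3. The board is encoded relative to the acting player, so plane 0 always holds "my" pieces and plane 1 holds the opponent's. As a result, one shared policy can act for both sides. The names `perspective_obs` and `shared_policy`, the board encoding and the toy decision rule are hypothetical stand-ins for the environment's observation and the trained model. PettingZoo's classic board games use a comparable player-relative observation, but check each environment's documentation for the exact layout.

```python
import numpy as np


def perspective_obs(board: np.ndarray, player: int) -> np.ndarray:
    """Encode a board (0 = empty, 1/2 = player pieces) from `player`'s point of view.

    Plane 0 holds the acting player's pieces and plane 1 the opponent's,
    so the shared policy always finds "my pieces" in the same channel.
    """
    opponent = 2 if player == 1 else 1
    return np.stack([board == player, board == opponent], axis=-1).astype(np.int8)


def shared_policy(obs: np.ndarray) -> int:
    """Toy stand-in for the trained model: pick the column where the acting
    player already has the most pieces (ties go to the lowest column index)."""
    return int(obs[:, :, 0].sum(axis=0).argmax())


board = np.zeros((6, 7), dtype=np.int8)
board[5, 2] = 1  # player 1 has one piece in column 2
board[5, 4] = 2  # player 2 has two pieces in column 4
board[4, 4] = 2

# The same policy acts for both players; only the perspective of the observation changes.
move_p1 = shared_policy(perspective_obs(board, 1))
move_p2 = shared_policy(perspective_obs(board, 2))
```

Because both players feed the same network, the network receives experience from both sides of every game. This design is also the reason for the limitation above: agents with different observation or action spaces cannot share one policy.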