**NOTE: This notebook is written for the Google Colab platform, which provides free hardware acceleration. However, it can also be run (possibly with minor modifications) as a standard Jupyter notebook, using a local GPU.**

In [None]:
#@title -- Installation of Packages -- { display-mode: "form" }
import sys
!{sys.executable} -m pip install gym[classic_control]
!{sys.executable} -m pip install pettingzoo
!{sys.executable} -m pip install tianshou
!{sys.executable} -m pip install git+https://github.com/michalgregor/tianshou_agents.git

In [None]:
#@title -- Import of Necessary Packages -- { display-mode: "form" }
from tianshou_agents.methods.ma import ma_default
from tianshou_agents.methods.dqn import dqn_default
from tianshou.env.pettingzoo_env import PettingZooEnv
from pettingzoo.classic import tictactoe_v3
from tianshou.trainer import OffpolicyTrainer

## Using the Multi-Agent Policy with PettingZoo Envs

This notebook gives a brief illustration of how to use the multi-agent policy in Tianshou Agents together with a PettingZoo env. The hyperparameters are from the [Tic-Tac-Toe example in Tianshou](https://github.com/thu-ml/tianshou/blob/master/test/pettingzoo/tic_tac_toe.py).

First, we set up a function that constructs our PettingZoo environment and wraps it in Tianshou's ``PettingZooEnv``. We also define a custom function that extracts the observation shape from the observation space. As Gym environments go, the observation space here is a bit unusual: it is a dictionary space, and the actual observations live under the `'observation'` key, while the `'action_mask'` key marks which moves are currently legal.

In [2]:
def get_env():
    return PettingZooEnv(tictactoe_v3.env())

def extract_obs_shape(observation_space):
    return observation_space['observation'].shape
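To see what `extract_obs_shape` returns, here is a minimal sketch that uses plain Python objects in place of the real Gym spaces. The `SimpleNamespace` stand-ins and the shapes `(3, 3, 2)` and `(9,)` are assumptions about PettingZoo's Tic-Tac-Toe spaces, not values taken from this notebook; only the dictionary lookup mirrors the code above.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for Gym's Box and Dict spaces: only the ``shape``
# attribute and key lookup matter for ``extract_obs_shape``.
observation_space = {
    "observation": SimpleNamespace(shape=(3, 3, 2)),  # assumed board encoding
    "action_mask": SimpleNamespace(shape=(9,)),       # assumed: one flag per cell
}

def extract_obs_shape(observation_space):
    # Same logic as in the notebook: look up the sub-space stored under the
    # 'observation' key and return its shape, ignoring 'action_mask'.
    return observation_space["observation"].shape

print(extract_obs_shape(observation_space))  # (3, 3, 2)
```

Without such a function, a generic shape lookup on the top-level dictionary space would not find a single shape, which is why the notebook supplies its own extractor.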


To get a configuration for our policies, we derive a config from the ``dqn_default`` preset and take its ``'policy'`` section. Then we modify the ``'hidden_sizes'`` of the ``'qnetwork'``.

In [7]:
policy_config = dqn_default.derive_conf()['policy']
policy_config['qnetwork']['hidden_sizes'] = [128, 128, 128, 128]
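To get a feel for what ``'hidden_sizes'`` controls, the sketch below counts the parameters of a plain fully connected network with four hidden layers of 128 units. It assumes that the Q-network is a simple MLP over the flattened observation, that the observation has shape `(3, 3, 2)`, and that there are 9 actions (one per board cell); `mlp_param_count` is a hypothetical helper, not part of Tianshou Agents.

```python
from math import prod

def mlp_param_count(input_dim, hidden_sizes, output_dim):
    """Count the weights and biases of a plain fully connected network."""
    sizes = [input_dim, *hidden_sizes, output_dim]
    # A linear layer from n_in to n_out units has n_in * n_out weights and n_out biases.
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))

obs_shape = (3, 3, 2)   # assumed Tic-Tac-Toe observation shape
n_actions = 9           # assumed: one action per board cell
hidden_sizes = [128, 128, 128, 128]

print(mlp_param_count(prod(obs_shape), hidden_sizes, n_actions))  # 53129
```

Under these assumptions, most of the parameters sit in the three 128-to-128 layers, so adding or removing a hidden layer changes the model size far more than the input or output dimensions do.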

Having prepared a policy config, we now have several options. We could, for instance, just pass ``policies=[policy_config, policy_config]`` to the ``ma_default`` preset and that would automatically construct two policies with the same configuration.

Alternatively, we can define a custom function that returns a list of already-built policies. We take this latter approach here because it is more flexible: we can, for instance, create just a single policy instance and use it to control both agents. This will not necessarily help the learning process, but it illustrates the kind of modeling freedom that we have.

In [4]:
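def make_policy(agent, device, seed, **kwargs):
    # Build a single policy from the modified DQN policy config.
    policy = agent.config_router.policy_builder(
        config=policy_config,
        default_kwargs=dict(kwargs,
            agent=agent,
            device=device,
            seed=seed
        )
    )

    # Return the same instance twice, so that one policy controls both players.
    return [policy, policy]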
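Because `make_policy` returns the same object twice, both list entries refer to a single policy, so any update made while training on one player's experience also changes the policy used by the other player. The sketch below illustrates the difference between a shared instance and two separate ones with a hypothetical `Policy` class that merely counts its updates; it does not model how Tianshou Agents actually schedules learning.

```python
class Policy:
    """Hypothetical stand-in for a policy; it only counts its updates."""

    def __init__(self):
        self.updates = 0

    def learn(self):
        self.updates += 1

def make_shared():
    policy = Policy()
    return [policy, policy]          # one instance controls both agents

def make_separate():
    return [Policy(), Policy()]      # one independent instance per agent

for make in (make_shared, make_separate):
    player_1, player_2 = make()
    player_1.learn()                 # an update triggered by player 1's data...
    print(make.__name__, player_2.updates)  # ...reaches player 2 only if shared
# make_shared 1
# make_separate 0
```

Passing ``policies=[policy_config, policy_config]`` instead would correspond to the `make_separate` case: two policies with the same configuration but independent parameters.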

Having figured out how to construct the policies, we can now use the ``ma_default`` preset to construct our agent. The interface is the usual one: we call ``agent.train()`` to start training.

In [9]:
agent = ma_default(
    'TicTacToe',
    task=get_env,
    trainer_class=OffpolicyTrainer,
    policies=make_policy,
    replay_buffer=20000,
    max_epoch=5,
    step_per_epoch=1000,
    extract_obs_shape=extract_obs_shape
)

agent.train()

obs['action_mask'] contains a mask of all legal moves that can be chosen.

Epoch #1: 1001it [01:16, 13.13it/s, env_step=1000, len=5, n/ep=0, n/st=1, player_1/loss=0.121, player_2/loss=0.130, rew=0.00]                          


Epoch #1: test_reward: 0.000000 ± 1.000000, best_reward: 0.000000 ± 1.000000 in #0


Epoch #2: 1001it [01:45,  9.47it/s, env_step=2000, len=6, n/ep=0, n/st=1, player_1/loss=0.062, player_2/loss=0.080, rew=0.00]                          


Epoch #2: test_reward: 0.000000 ± 1.000000, best_reward: 0.000000 ± 1.000000 in #0


Epoch #3: 1001it [01:21, 12.27it/s, env_step=3000, len=9, n/ep=0, n/st=1, player_1/loss=0.033, player_2/loss=0.058, rew=0.00]                          


Epoch #3: test_reward: 0.000000 ± 1.000000, best_reward: 0.000000 ± 1.000000 in #0


Epoch #4: 1001it [01:08, 14.57it/s, env_step=4000, len=7, n/ep=0, n/st=1, player_1/loss=2.419, player_2/loss=0.080, rew=0.00]                          


Epoch #4: test_reward: 0.000000 ± 1.000000, best_reward: 0.000000 ± 1.000000 in #0


Epoch #5: 1001it [00:59, 16.92it/s, env_step=5000, len=8, n/ep=0, n/st=1, player_1/loss=0.031, player_2/loss=0.048, rew=0.00]                          


Epoch #5: test_reward: 0.000000 ± 1.000000, best_reward: 0.000000 ± 1.000000 in #0


{'duration': '392.33s',
 'train_time/model': '356.88s',
 'test_step': 38,
 'test_episode': 6,
 'test_time': '0.18s',
 'test_speed': '215.91 step/s',
 'best_reward': 0.0,
 'best_result': '0.00 ± 1.00',
 'train_step': 5000,
 'train_episode': 669,
 'train_time/collector': '35.27s',
 'train_speed': '12.75 step/s'}
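The message printed at the start of training refers to the `action_mask` entry of each observation, which flags the moves that are currently legal. A common way for a value-based agent such as DQN to respect such a mask is to rule out illegal actions before taking the argmax over Q-values, as in the minimal sketch below. The Q-values and the mask are made up for illustration, and `masked_argmax` is a hypothetical helper, not the Tianshou implementation.

```python
import math

def masked_argmax(q_values, action_mask):
    """Return the index of the highest Q-value among legal actions only."""
    # Illegal actions get -inf so they can never win the argmax.
    masked = [q if legal else -math.inf for q, legal in zip(q_values, action_mask)]
    return max(range(len(masked)), key=masked.__getitem__)

# Hypothetical Q-values for the 9 cells; cells 0 and 4 are already occupied.
q_values = [0.9, 0.1, 0.3, 0.2, 0.8, 0.4, 0.0, 0.5, 0.6]
action_mask = [0, 1, 1, 1, 0, 1, 1, 1, 1]

print(masked_argmax(q_values, action_mask))  # 8
```

Cell 0 has the highest raw Q-value, but because it is masked out, the greedy choice falls to the best legal cell instead.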

To test the agent, we can call ``agent.test`` just as we would in a regular single-agent setting.

In [14]:
agent.test(episode_per_test=10)

{'n/ep': 10,
 'n/st': 60,
 'rews': array([[-1.,  1.],
        [-1.,  1.],
        [-1.,  1.],
        [-1.,  1.],
        [-1.,  1.],
        [-1.,  1.],
        [-1.,  1.],
        [-1.,  1.],
        [-1.,  1.],
        [-1.,  1.]]),
 'lens': array([6, 6, 6, 6, 6, 6, 6, 6, 6, 6]),
 'idxs': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'rew': 0.0,
 'len': 6.0,
 'rew_std': 1.0,
 'len_std': 0.0}
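The summary values `'rew': 0.0` and `'rew_std': 1.0` can look puzzling next to test games that all ended decisively. Assuming the summary pools both players' rewards into a single mean and population standard deviation (an assumption about Tianshou's internals, but one consistent with the numbers printed above), the arithmetic works out as in the sketch below. The same pooling would also explain the `0.000000 ± 1.000000` test rewards in the training log.

```python
from statistics import mean, pstdev

# Per-episode rewards reported by agent.test above: [player_1, player_2].
rews = [[-1.0, 1.0]] * 10

# Pooling both players' rewards into one list cancels the wins and losses.
pooled = [r for episode in rews for r in episode]
print(mean(pooled), pstdev(pooled))      # 0.0 1.0

# Per-player averages show that player_2 won every test game.
per_player = [mean(column) for column in zip(*rews)]
print(per_player)                        # [-1.0, 1.0]
```

Since Tic-Tac-Toe is zero-sum, a pooled mean of 0.0 is expected regardless of who wins; the per-player columns of `'rews'` are the more informative quantity here.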