In [1]:
"""Uses Stable-Baselines3 to train agents in the Knights-Archers-Zombies environment using SuperSuit vector envs.

This environment requires using SuperSuit's Black Death wrapper, to handle agent death.

For more information, see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html

Author: Elliot (https://github.com/elliottower)
"""
from __future__ import annotations

import glob
import os
import time
import numpy as np
import time

import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.ppo import CnnPolicy, MlpPolicy

from pettingzoo.mpe import simple_tag_v3
# from pettingzoo.butterfly import knights_archers_zombies_v10


def train(env_fn, steps: int = 100, seed: int | None = 0, **env_kwargs):
    # Train a single model to play as each agent in an AEC environment
    env = env_fn.parallel_env(**env_kwargs)

    # Add black death wrapper so the number of agents stays constant
    # MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True
    env = ss.black_death_v3(env)

    # Pre-process using SuperSuit
    visual_observation = False #not env.unwrapped.continuous_actions not env.unwrapped.vector_state
    if visual_observation:
        # If the observation space is visual, reduce the color channels, resize from 512px to 84px, and apply frame stacking
        env = ss.color_reduction_v0(env, mode="B")
        env = ss.resize_v1(env, x_size=84, y_size=84)
        env = ss.frame_stack_v1(env, 3)

    env.reset(seed=seed)

    print(f"Starting training on {str(env.metadata['name'])}.")
    env = ss.multiagent_wrappers.pad_observations_v0(env)
    env = ss.pettingzoo_env_to_vec_env_v1(env)
    env = ss.concat_vec_envs_v1(env, 8, num_cpus=1, base_class="stable_baselines3")

    # Use a CNN policy if the observation space is visual
    model = PPO(
        CnnPolicy if visual_observation else MlpPolicy,
        env,
        verbose=3,
        batch_size=256,
    )

    model.learn(total_timesteps=steps)

    model.save(f"{env.unwrapped.metadata.get('name')}_{time.strftime('%Y%m%d-%H%M%S')}")

    print("Model has been saved.")

    print(f"Finished training on {str(env.unwrapped.metadata['name'])}.")

    env.close()

In [2]:
# Set vector_state to false in order to use visual observations (significantly longer training time)
env_kwargs = dict(num_good=1, num_adversaries=3, num_obstacles=2, max_cycles=25, continuous_actions=False )
# max_cycles=100, max_zombies=4, vector_state=True
# Train a model (takes ~5 minutes on a laptop CPU)
env_fn = simple_tag_v3
train(env_fn, steps=10000, seed=0, **env_kwargs)

Starting training on simple_tag_v3.
Using cuda device




------------------------------
| time/              |       |
|    fps             | 4074  |
|    iterations      | 1     |
|    time_elapsed    | 16    |
|    total_timesteps | 65536 |
------------------------------
Model has been saved.
Finished training on simple_tag_v3.


In [3]:

def eval(env_fn, num_games: int = 10000, render_mode: str | None = None, **env_kwargs):
    # Evaluate a trained agent vs a random agent
    env = env_fn.env(render_mode=render_mode, **env_kwargs)

    # Pre-process using SuperSuit
    # if visual_observation = False #not env.unwrapped.continuous_actions
    if visual_observation:
        # If the observation space is visual, reduce the color channels, resize from 512px to 84px, and apply frame stacking
        env = ss.color_reduction_v0(env, mode="B")
        env = ss.resize_v1(env, x_size=84, y_size=84)
        env = ss.frame_stack_v1(env, 3)

    print(
        f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})"
    )

    try:
        latest_policy = max(
            glob.glob(f"{env.metadata['name']}*.zip"), key=os.path.getctime
        )
    except ValueError:
        print("Policy not found.")
        exit(0)

    model = PPO.load(latest_policy)

    rewards = {agent: 0 for agent in env.possible_agents}

    # Note: we evaluate here using an AEC environments, to allow for easy A/B testing against random policies
    # For example, we can see here that using a random agent for archer_0 results in less points than the trained agent
    for i in range(num_games):
        env.reset(seed=i)
        env.action_space(env.possible_agents[0]).seed(i)
        
        for agent in env.agent_iter():
            obs, reward, termination, truncation, info = env.last()
            if agent == 'agent_0':
                obs=np.append(obs, [0,0])
            #print(obs)
            if render_mode== 'human':
                time.sleep(0.01)
            for agent in env.agents:
                rewards[agent] += env.rewards[agent]

            if termination or truncation:
                break
            else:
                if agent == env.possible_agents[0]:
                    act = env.action_space(agent).sample()
                else:
                    act = model.predict(obs, deterministic=True)[0]
            env.step(act)
    env.close()

    avg_reward = sum(rewards.values()) / len(rewards.values())
    avg_reward_per_agent = {
        agent: rewards[agent] / num_games for agent in env.possible_agents
    }
    print(f"Avg reward: {avg_reward}")
    print("Avg reward per agent, per game: ", avg_reward_per_agent)
    print("Full rewards: ", rewards)
    return avg_reward


SyntaxError: invalid syntax (3142361830.py, line 6)

In [52]:
# Set vector_state to false in order to use visual observations (significantly longer training time)
env_kwargs = dict(num_good=1, num_adversaries=3, num_obstacles=2, max_cycles=25, continuous_actions=False)
eval(env_fn, num_games=10, render_mode=None, **env_kwargs)


Starting evaluation on simple_tag_v3 (num_games=10, render_mode=None)
[ 0.          0.          0.27392337 -0.46042657 -0.19539839  1.243557
  0.29461303 -0.43464413 -1.1919763  -0.5065182   0.3526171   1.2859378
 -0.06065182  0.9194197   0.          0.        ]
[ 0.          0.         -0.918053   -0.96694475  0.9965779   1.7500751
  1.4865893   0.07187403  1.1919763   0.5065182   1.5445935   1.7924559
  1.1313245   1.4259379   0.          0.        ]
[ 0.          0.          0.6265405   0.82551116 -0.5480155  -0.04238079
 -0.05800408 -1.7205819  -0.3526171  -1.2859378  -1.5445935  -1.7924559
 -0.41326892 -0.36651802  0.          0.        ]
[ 0.          0.          0.21327156  0.45899311 -0.13474657  0.32413724
  0.35526484 -1.35406387  0.06065182 -0.91941971 -1.13132453 -1.42593789
  0.41326892  0.36651802  0.          0.        ]
[-0.          0.3         0.27392337 -0.46042657 -0.19539839  1.243557
  0.29461303 -0.43464413 -1.1919763  -0.5065182   0.3526171   1.2859378
 -0.0606

[ 0.28858194 -0.8049862  -0.0277447  -0.10372502  0.11701334 -0.7466686
  0.4840683   0.172383    0.1754419  -0.01577414  0.13819915  0.32256472
 -0.10188723  0.31473142  0.5439189  -0.19400944]
[-0.15193078 -0.6438497   0.11045445  0.21883969 -0.02118581 -1.0692333
  0.34586915 -0.15018173  0.03724276 -0.33833885 -0.13819915 -0.32256472
 -0.24008638 -0.00783329  0.5439189  -0.19400944]
[ 0.54391891 -0.19400944 -0.12963194  0.2110064   0.21890058 -1.06140006
  0.5859555  -0.14234844  0.27732915 -0.33050558  0.10188723 -0.31473142
  0.24008638  0.00783329  0.          0.        ]
[-0.302735    0.5500684   0.14733253 -0.04615672 -0.0580639  -0.8042369
  0.30899104  0.11481468 -0.14621904 -0.13806693 -0.05207117  0.20061144
 -0.22257258  0.23776218  0.00793919 -0.14550708]
[ 0.21643645 -0.3037396   0.0011135  -0.18422364  0.08815514 -0.66616994
  0.4552101   0.25288162  0.14621904  0.13806693  0.09414788  0.33867836
 -0.07635354  0.3758291   0.00793919 -0.14550708]
[ 0.1860519  -0.4828872

[ 0.225      -0.3         0.24471167 -0.24702683  0.42423183  0.32612136
  0.47927547  0.20590317  0.7391814   0.2906319   0.7077758  -0.5613011
  0.35909075 -0.36391753 -0.          0.7       ]
[-0.          0.69999999  0.60380244 -0.61094439  0.06514108  0.69003886
  0.12018473  0.5698207   0.38009068  0.65454942  0.348685   -0.19738358
 -0.35909075  0.36391753  0.          0.        ]
[ 0.450571    0.16273661  1.081674    0.06455503 -0.41273046  0.01453949
 -0.35768682 -0.10567869 -0.15918654 -0.850383   -0.8144623  -0.34158185
 -0.47787154 -0.6054994  -0.4         0.525     ]
[-0.225       0.46875     0.92248744 -0.78582793 -0.2535439   0.86492246
 -0.19850026  0.7447043   0.15918654  0.850383   -0.65527576  0.5088011
 -0.318685    0.24488358 -0.4         0.525     ]
[ 0.46875    -0.225       0.26721168 -0.27702683  0.40173182  0.35612136
  0.4567755   0.23590317  0.8144623   0.34158185  0.65527576 -0.5088011
  0.33659074 -0.26391754 -0.4         0.525     ]
[-0.40000001  0.5249999

[-0.59202325  0.2096038   0.16318913  0.13305959  0.16066265 -0.81170934
 -0.9700765   0.49728483 -0.22116354  0.22902261 -0.09282741 -0.39103249
 -0.06546411  0.01335498  0.          0.        ]
[ 0.35109937 -0.1464515  -0.05116116  0.34255534  0.37501293 -1.0212051
 -0.7557262   0.28778908  0.12891793 -0.5792143   0.14057146 -0.18482342
  0.15514797 -0.18853536  0.526115    0.04089319]
[ 0.05546283  0.45985466  0.07775677 -0.23665893  0.24609502 -0.4419908
 -0.88464415  0.8670034  -0.12891793  0.5792143   0.01165354  0.39439085
  0.02623004  0.3906789   0.526115    0.04089319]
[-0.93249285  0.20118974  0.08941031  0.15773192  0.23444149 -0.8363817
 -0.8962977   0.4726125  -0.14057146  0.18482342 -0.01165354 -0.39439085
  0.01457651 -0.00371195  0.526115    0.04089319]
[ 0.526115    0.04089319  0.10398681  0.15401998  0.21986498 -0.83266968
 -0.91087419  0.47632444 -0.15514797  0.18853536 -0.02623004 -0.39067891
 -0.01457651  0.00371195  0.          0.        ]
[-0.03667547 -0.1098386

[-2.0242462e-01 -3.2121968e-01  2.4761914e-01 -1.4258622e-01
 -1.0997612e+00  2.9632617e-02 -2.7471936e-01 -6.4013624e-01
 -4.3438247e-01 -5.4768234e-04 -1.0052539e-01 -9.1959789e-02
 -2.3102328e-01 -7.3137194e-02  2.4546020e-01 -2.6516011e-01]
[ 4.5369953e-02 -2.6135966e-01 -1.8676333e-01 -1.4313389e-01
 -6.6537875e-01  3.0180300e-02  1.5966313e-01 -6.3958853e-01
  4.3438247e-01  5.4768234e-04  3.3385709e-01 -9.1412112e-02
  2.0335919e-01 -7.2589509e-02  2.4546020e-01 -2.6516011e-01]
[-0.19480418 -0.22393012  0.14709376 -0.234546   -0.9992358   0.12159241
 -0.17419396 -0.5481764   0.10052539  0.09195979 -0.3338571   0.09141211
 -0.13049789  0.0188226   0.2454602  -0.2651601 ]
[ 0.2454602  -0.26516011  0.01659586 -0.21572341 -0.86873794  0.10276981
 -0.04369607 -0.56699902  0.23102328  0.07313719 -0.20335919  0.07258951
  0.13049789 -0.0188226   0.          0.        ]
[-0.05030805 -0.14805387  0.22737667 -0.17470819 -1.0795188   0.06175458
 -0.25447688 -0.6080143  -0.409603    0.00543

344.58035553799584

In [4]:
from __future__ import annotations

import glob
import os
import time
import numpy as np
import time

import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.ppo import CnnPolicy, MlpPolicy

from pettingzoo.mpe import simple_tag_v3
env_fn = simple_tag_v3
print(env_fn)

<module 'pettingzoo.mpe.simple_tag_v3' from 'C:\\Users\\aron_\\.conda\\envs\\deeplearning\\lib\\site-packages\\pettingzoo\\mpe\\simple_tag_v3.py'>


In [6]:
# Set vector_state to false in order to use visual observations (significantly longer training time)
env_kwargs = dict(num_good=1, num_adversaries=3, num_obstacles=2, max_cycles=200, continuous_actions=False )
eval(env_fn, 100, 'human', **env_kwargs)

TypeError: eval() takes no keyword arguments