In [45]:
import time
import enum

import matplotlib.pyplot as pd
import pandas as pd
import numpy as np
import torch
import torch.nn
import torch.nn.functional as F

class AttrDict(dict):
    """Dictionary whose entries are also reachable as attributes.

    ``d.key`` and ``d["key"]`` refer to the same entry, for both reading
    and writing. Construction accepts exactly what ``dict`` accepts.
    """

    def __init__(self, *args, **kwargs):
        # Fill the mapping the same way a plain dict would.
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so attribute
        # access and item access share one storage.
        self.__dict__ = self

from pettingzoo.mpe import simple_tag_v2
from pettingzoo.utils import random_demo

Arguments for instantiating the environment:

- num_good: number of good agents
- num_adversaries: number of adversaries
- num_obstacles: number of obstacles
- max_cycles: number of frames (a step for each agent) until game terminates
- continuous_actions: Whether agent action spaces are discrete(default) or continuous

In [2]:
# Build the simple_tag environment with 3 good agents, 3 adversaries and
# 2 obstacles, using discrete actions and a 100-cycle limit.
# `.unwrapped` is taken so that attributes such as `env.world` (used in
# later cells) are reachable directly.
env = simple_tag_v2.env(
    num_good=3,
    num_adversaries=3,
    num_obstacles=2,
    max_cycles=100,
    continuous_actions=False
).unwrapped
# List every attribute name available on the unwrapped environment.
print("Peek into unwrapped environment:", *dir(env))

Peek into unwrapped environment: __class__ __delattr__ __dict__ __dir__ __doc__ __eq__ __format__ __ge__ __getattribute__ __gt__ __hash__ __init__ __init_subclass__ __le__ __lt__ __module__ __ne__ __new__ __reduce__ __reduce_ex__ __repr__ __setattr__ __sizeof__ __str__ __subclasshook__ __weakref__ _accumulate_rewards _agent_selector _clear_rewards _dones_step_first _execute_world_step _index_map _reset_render _set_action _was_done_step action_space action_spaces agent_iter agents close continuous_actions current_actions last local_ratio max_cycles max_num_agents metadata np_random num_agents observation_space observation_spaces observe possible_agents render reset scenario seed state state_space step steps unwrapped viewer world


Adversaries (red) try to capture non-adversaries (green). The map is a 2D plane and everything is initialized in the region [-1, +1]. There doesn't seem to be any position clipping at the boundary, but non-adversary agents are penalized for going out of bounds.
Agent's observation is a ndarray vector of concatenated data in the following order:

1. current velocity (2,)
2. current position (2,)
3. relative position (2,) of each landmark
4. relative position (2,) of each other agent
5. velocity (2,) of each other non-adversary agent

When there are 3 adversaries and 3 non-adversaries, the adversary observation space is 24-dimensional and the non-adversary observation space is 22-dimensional.

The environment is sequential. Agents move one at a time. Agents are either `adversary_*` for adversary or `agent_*` for non-adversary.

Actions:

- 0 is NOP
- 1 is go left
- 2 is go right
- 3 is go down
- 4 is go up

In [59]:
# Print variables of the environment
# Documentation:   https://www.pettingzoo.ml/api
# reset() must run first so that an agent is selected and the world is populated.
env.reset()
print("State size", env.state_space.shape)
print("Name of current agent", env.agent_selection)
print("Observation space of current agent", env.observation_space(env.agent_selection).shape)
print("Action space of current agent", env.action_space(env.agent_selection))
print("Sample random action from current agent", env.action_space(env.agent_selection).sample())
print("The agent names:", *env.agents)
print()

print()
print()

# select an agent in the environment world, after using env.unwrapped
agent = env.world.agents[0]
print("agent's name is", agent.name)
# The values are printed velocity first, then position; the label states
# them in that same order so the output is not misread.
print("agent's velocity and position coordinates", agent.state.p_vel, agent.state.p_pos)
print("is agent an adversary?", agent.adversary)

# Landmarks live in the same world object as the agents.
landmark = env.world.landmarks[0]
print("landmark's name is", landmark.name)
print("landmark's position coordinates (doesn't move)", landmark.state.p_pos)

State size (138,)
Name of current agent adversary_0
Observation space of current agent (24,)
Action space of current agent Discrete(5)
Sample random action from current agent 1
The agent names: adversary_0 adversary_1 adversary_2 agent_0 agent_1 agent_2

Box([-inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf
 -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf], [inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf], (24,), float32)

agent's name is adversary_0
agent's position and velocity coordinates [0. 0.] [ 0.91148203 -0.05041044]
is agent an adversary? True
landmark's name is landmark 0
landmark's position coordinates (doesn't move) [ 0.73009034 -0.75464927]


In [8]:
# Demo environment with random policy
# Runs 5 rendered episodes through pettingzoo's random_demo helper.
# NOTE(review): judging by the recorded output, random_demo prints the
# average total reward and returns the summed reward — confirm against
# the pettingzoo version in use.
env.reset()
random_demo(env, render=True, episodes=5)

Average total reward -402.272556480898


-2011.36278240449

In [6]:
def hardcode_policy(observation, agent):
    """Pick a fixed discrete action for the given agent.

    Only ``"adversary_0"`` ever moves: it always receives action 4.
    Every other agent, adversary or not, receives action 0.

    Parameters
    ==========
    observation : any
        Current observation of the agent; not used by this policy.
    agent : str
        Name of the agent, e.g. ``"adversary_0"`` or ``"agent_1"``.

    Returns
    =======
    int
        4 for ``"adversary_0"``, otherwise 0.
    """
    idle_action, chaser_action = 0, 4
    if "adversary" in agent and agent == "adversary_0":
        return chaser_action
    # Remaining adversaries and all non-adversaries stay idle.
    return idle_action

# Roll out one episode with the hard-coded policy, rendering every step.
env.reset()
# agent_iter() yields the name of the agent whose turn it is.
for agent in env.agent_iter():
    env.render()
    # last() returns the data belonging to the agent about to act.
    observation, reward, done, info = env.last()
    if done:
        # A finished agent is stepped with None instead of a real action.
        env.step(None)
    else:
        action = hardcode_policy(observation, agent)
        env.step(action)
    # Optional delay to slow the rendering down (currently disabled):
    # time.sleep(0.1)

Messages are of size 5?

Concatenate the messages from all the actors and add them to the message input for the current agent.

The names of agents are: 
adversary_0 adversary_1 adversary_2 agent_0 agent_1 agent_2

In [87]:
# Length of the observation vector of adversary_0 (first entry of the space's shape).
env.observation_space("adversary_0").shape[0]

24

In [None]:
def get_agent_counts():
    """Count adversaries and good agents in the global ``env``.

    Reads ``env.world.agents`` and classifies each entry by the truthiness
    of its ``adversary`` attribute.

    Returns
    =======
    tuple of (int, int)
        ``(number of adversaries, number of good agents)``.
    """
    adversary_flags = [bool(member.adversary) for member in env.world.agents]
    n_adversaries = sum(adversary_flags)
    # Whoever is not an adversary is a good agent.
    return (n_adversaries, len(adversary_flags) - n_adversaries)

def process_config(config):
    """Copy the shared settings into both per-type sections, in place.

    Every key/value pair of ``config.all`` is written into
    ``config.adversary`` and ``config.agent``; keys already present in
    those sections are overwritten. Nothing is returned.

    Parameters
    ==========
    config : object with ``all``, ``adversary`` and ``agent`` mappings
    """
    shared_settings = config.all
    for section in (config.adversary, config.agent):
        for name, value in shared_settings.items():
            section[name] = value

# Build the experiment configuration.
#   config.all       -- settings shared by both agent types
#   config.adversary -- settings specific to adversaries
#   config.agent     -- settings specific to good (non-adversary) agents
n_adversaries, n_good_agents = get_agent_counts()
config = AttrDict(
    all=AttrDict(
        message_size=4,
        hidden_size=128,
        n_actions=env.action_space(env.agent_selection).n,
    ),  # this comma was missing, which made the whole cell a SyntaxError
    adversary=AttrDict(
        n_agents=n_adversaries,
        observation_shape=env.observation_space("adversary_0").shape,
        # Flat length of the observation vector; SimpleTagNet reads this key.
        observation_size=env.observation_space("adversary_0").shape[0],
    ),
    agent=AttrDict(
        n_agents=n_good_agents,
        observation_shape=env.observation_space("agent_0").shape,
        observation_size=env.observation_space("agent_0").shape[0],
    ),
)
# Copy every shared setting into the adversary and agent sections.
process_config(config)

class Container(object):
    """Container of messages and hidden states of agents in environment.

    For each agent type (``"adversary"`` and ``"agent"``) one flat message
    tensor of length ``n_agents * message_size`` is kept; the agent with
    index ``i`` owns the slice ``[i * message_size, (i + 1) * message_size)``.
    For each individual agent a hidden-state tensor of length
    ``hidden_size`` is kept under the key ``"<type>_<index>"``.

    Parameters
    ==========
    config : object
        Must expose ``config.adversary`` and ``config.agent``, each with the
        attributes ``n_agents``, ``message_size`` and ``hidden_size``.
    """

    _AGENT_TYPES = ("adversary", "agent")

    def __init__(self, config):
        self.config = config
        self.__message_d = {}
        self.__hidden_d = {}
        self.reset()

    def reset(self):
        """Discard all stored tensors and recreate them filled with zeros."""
        self.__message_d.clear()
        self.__hidden_d.clear()
        for agent_type in self._AGENT_TYPES:
            type_config = getattr(self.config, agent_type)
            self.__message_d[agent_type] = torch.zeros(
                type_config.n_agents * type_config.message_size,
                dtype=torch.float
            )
            for idx in range(type_config.n_agents):
                self.__hidden_d[f"{agent_type}_{idx}"] = torch.zeros(
                    type_config.hidden_size,
                    dtype=torch.float
                )

    def get_message(self, agent_name):
        """Return the concatenated message tensor of the agent's type.

        Parameters
        ==========
        agent_name : str
            Name such as ``"adversary_0"`` or ``"agent_2"``.

        Raises
        ======
        ValueError
            If the name contains neither ``"adversary"`` nor ``"agent"``.
        """
        if "adversary" in agent_name:
            return self.__message_d["adversary"]
        elif "agent" in agent_name:
            return self.__message_d["agent"]
        else:
            raise ValueError(f"{agent_name} is neither an agent or adversary.")

    def update_message(self, agent_name, message):
        """Write ``message`` into the slice owned by ``agent_name``.

        Parameters
        ==========
        agent_name : str
            Name of the form ``"<type>_<index>"``.
        message : torch.Tensor
            Values to store; must be assignable to a slice of length
            ``message_size`` of the agent's type.

        Raises
        ======
        ValueError
            If the type part of the name is not a known agent type, or the
            name does not have the ``"<type>_<index>"`` form.
        IndexError
            If the index is outside ``[0, n_agents)`` for that type.
        """
        agent_type, agent_idx = agent_name.rsplit("_", 1)
        if agent_type not in self.__message_d:
            raise ValueError(f"{agent_name} is neither an agent or adversary.")
        agent_idx = int(agent_idx)
        type_config = getattr(self.config, agent_type)
        if not 0 <= agent_idx < type_config.n_agents:
            raise IndexError(
                f"{agent_name} is out of range for {type_config.n_agents} agents."
            )
        # Each agent owns a block of message_size entries, so the slice
        # starts at index * message_size rather than at the raw index.
        message_size = type_config.message_size
        start = agent_idx * message_size
        self.__message_d[agent_type][start:start + message_size] = message


In [None]:
# simple_tag_v2

class SimpleTagNet(torch.nn.Module):
    """Layers of a recurrent network for one agent type.

    Only the constructor is defined; there is no ``forward`` method yet, so
    instances cannot be called on inputs.

    Parameters
    ==========
    config : mapping
        ``config[agent_type]`` must expose ``message_size``, ``hidden_size``,
        ``n_agents``, ``n_actions`` and either ``observation_size`` (int) or
        ``observation_shape`` (sequence of ints) as attributes.
    agent_type : str
        Key of the section of ``config`` to read, e.g. ``"adversary"``.
    """

    def __init__(self, config, agent_type):
        super().__init__()
        type_config = config[agent_type]

        # Accept an explicit observation_size; otherwise flatten
        # observation_shape into a single length.
        observation_size = getattr(type_config, "observation_size", None)
        if observation_size is None:
            observation_size = 1
            for dim in type_config.observation_shape:
                observation_size *= int(dim)
        self.observation_size = int(observation_size)
        self.message_size = type_config.message_size
        self.hidden_size = type_config.hidden_size
        self.n_agents = type_config.n_agents
        self.n_actions = type_config.n_actions

        self.agent_lookup  = torch.nn.Embedding(self.n_agents, self.hidden_size)
        self.action_lookup = torch.nn.Embedding(self.n_actions, self.hidden_size)
        self.state_mlp     = torch.nn.Linear(self.observation_size, self.hidden_size)
        # Stored under its own name so the integer message_size is not
        # overwritten by the layer.
        # NOTE(review): in_features is a single agent's message_size; if the
        # input is meant to be the concatenated messages of all agents of
        # this type it should be n_agents * message_size — confirm.
        self.message_mlp   = torch.nn.Linear(self.message_size, self.hidden_size)
        # GRU needs both input_size and hidden_size; both follow the
        # configured hidden size rather than a hard-coded 128.
        self.rnn = torch.nn.GRU(self.hidden_size, self.hidden_size)
        

In [38]:
# Scratch check: an agent name splits on "_" into its type and its index.
agent_name = "adversary_0"
agent_type, agent_idx = agent_name.split("_")
# The index part is a string and must be converted before use as a number.
int(agent_idx)


0

In [72]:
# Scratch check: writing a length-4 tensor into a slice of a flat buffer.
a = torch.zeros(4*3)
b = torch.tensor([1,2,3,4])
# Entries 4..7 are overwritten in place; the integer values are stored as floats.
a[4:8] = b
# Show the buffer and the length of the written tensor.
a, b.size()[0]

(tensor([0., 0., 0., 0., 1., 2., 3., 4., 0., 0., 0., 0.]), 4)

In [84]:
# Scratch check: a tensor can be indexed with a Python range object.
a = torch.arange(10)
# Selects the elements at positions 3, 4 and 5.
a[range(3,6)]

tensor([3, 4, 5])