Simple Tag
https://www.pettingzoo.ml/mpe/simple_tag

> This is a predator-prey environment. Good agents (green) are faster and receive a negative reward for being hit by adversaries (red) (-10 for each collision). Adversaries are slower and are rewarded for hitting good agents (+10 for each collision). Obstacles (large black circles) block the way. By default, there is 1 good agent, 3 adversaries and 2 obstacles.

Testing some hardcoded algorithms

In [47]:
import os
import time
import enum
import math
import random
import collections
import statistics

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
import torch.nn
import torch.nn.functional as F

class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

class TimeDelta(object):
    def __init__(self, delta_time):
        """Convert time difference in seconds to days, hours, minutes, seconds.
        
        Parameters
        ==========
        delta_time : float
            Time difference in seconds.
        """
        self.fractional, seconds = math.modf(delta_time)
        seconds = int(seconds)
        minutes, self.seconds = divmod(seconds, 60)
        hours, self.minutes = divmod(minutes, 60)
        self.days, self.hours = divmod(hours, 24)
    
    def __repr__(self):
        return f"{self.days}-{self.hours:02}:{self.minutes:02}:{self.seconds + self.fractional:05.2f}"

class Normalizer(object):
    def __init__(self, env):
        self.n_landmarks = len(env.world.landmarks)
        self.n_allagents = len(env.world.agents)
        self.n_good = sum(map(lambda a: not a.adversary, env.world.agents))
    
    @staticmethod
    def normalize_abs_pos(s):
        """Clip absolute position and scale to [-1, 1]
        s is a scalar or an ndarray of one dimension."""
        return np.clip(s, -1.5, 1.5) / 1.5

    @staticmethod
    def normalize_rel_pos(s):
        """Clip relative position and scale to [-1, 1]
        s is a scalar or an ndarray of one dimension."""
        return np.clip(s, -3, 3) / 3

    def normalize_obs(self, obs):
        # normalize and clip positions
        norm_obs = obs.copy()
        # normalize velocity of current entity
        norm_obs[:2] = norm_obs[:2] / 1.3
        # clip/scale abs. position of current entity
        norm_obs[2:4] = self.normalize_abs_pos(norm_obs[2:4])
        # clip/scale rel. position of other entities
        n_range = self.n_landmarks + self.n_allagents - 1
        for i in range(n_range):
            # entity i occupies obs[4 + 2*i : 6 + 2*i]
            start, end = 4 + 2*i, 4 + 2*(i + 1)
            norm_obs[start:end] = self.normalize_rel_pos(norm_obs[start:end])
        # normalize velocity of other entities
        norm_obs[4 + (2*n_range):] = norm_obs[4 + (2*n_range):] / 1.3
        return norm_obs

from pettingzoo.mpe import simple_tag_v2
from pettingzoo.utils import random_demo

Arguments for instantiating the environment:

- num_good: number of good agents
- num_adversaries: number of adversaries
- num_obstacles: number of obstacles
- max_cycles: number of frames (one step for each agent) until the game terminates
- continuous_actions: whether agent action spaces are discrete (default) or continuous

In [13]:
env = simple_tag_v2.env(
    num_good=3,
    num_adversaries=3,
    num_obstacles=2,
    max_cycles=300,
    continuous_actions=False
).unwrapped
print("Peek into unwrapped environment:", *dir(env))

Peek into unwrapped environment: __class__ __delattr__ __dict__ __dir__ __doc__ __eq__ __format__ __ge__ __getattribute__ __gt__ __hash__ __init__ __init_subclass__ __le__ __lt__ __module__ __ne__ __new__ __reduce__ __reduce_ex__ __repr__ __setattr__ __sizeof__ __str__ __subclasshook__ __weakref__ _accumulate_rewards _agent_selector _clear_rewards _dones_step_first _execute_world_step _index_map _reset_render _set_action _was_done_step action_space action_spaces agent_iter agents close continuous_actions current_actions last local_ratio max_cycles max_num_agents metadata np_random num_agents observation_space observation_spaces observe possible_agents render reset scenario seed state state_space step steps unwrapped viewer world


### What are the environment parameters?

Adversaries (red) try to capture the non-adversaries (green). The map is a continuous 2D plane and everything is initialized in the region [-1, +1]. Positions do not seem to be clipped when agents go out of bounds, but non-adversary agents are penalized for it.
An agent's observation is an ndarray vector that concatenates the following, in order:

1. current velocity (2,)
2. current position (2,)
3. relative position (2,) of each landmark
4. relative position (2,) of each other agent
5. velocity (2,) of each other non-adversary agent

So observation forms a vector:
`[self_vel, self_pos, landmark_rel_positions, other_agent_rel_positions, other_agent_velocities]`
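A minimal sketch of slicing an observation into the named parts listed above. The helper `split_obs` and its parameters are hypothetical (not part of PettingZoo); it assumes exactly the layout above, in which velocities are appended only for the other non-adversary agents.

```python
import numpy as np


def split_obs(obs, n_landmarks, n_adversaries, n_good, is_adversary):
    """Hypothetical helper: split a simple_tag observation into named parts.

    Assumes the layout
    [self_vel, self_pos, landmark_rel_positions, other_agent_rel_positions,
     other_agent_velocities], where velocities are only included for the
    other non-adversary agents.
    """
    obs = np.asarray(obs)
    n_other = n_adversaries + n_good - 1
    # velocities are only observed for the other good agents
    n_vel = n_good if is_adversary else n_good - 1
    i = 0
    parts = {}
    for name, size in [
        ("self_vel", 2),
        ("self_pos", 2),
        ("landmark_rel_pos", 2 * n_landmarks),
        ("other_rel_pos", 2 * n_other),
        ("other_vel", 2 * n_vel),
    ]:
        parts[name] = obs[i:i + size]
        i += size
    if i != obs.shape[0]:
        raise ValueError(f"expected {i} values, got {obs.shape[0]}")
    # reshape the per-entity blocks into (n_entities, 2) arrays
    for name in ("landmark_rel_pos", "other_rel_pos", "other_vel"):
        parts[name] = parts[name].reshape(-1, 2)
    return parts
```

For `adversary_0` in the 3-vs-3, 2-obstacle setup, `split_obs(obs, 2, 3, 3, True)["other_rel_pos"][2]` is `obs[12:14]`, the offset to `agent_0` that `hardcode_policy_2` reads further down.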

Speed is capped at 1.3 for good agents (adversaries are slower), so each velocity component stays within [-1.3, 1.3]. Agents can move off the arena [-0.9, 0.9], and good agents that do are penalized by an amount that grows with their distance from it.
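The out-of-bounds penalty can be sketched as below. This mirrors my reading of the `bound` helper in the MPE `simple_tag` scenario source (zero inside 0.9, a linear ramp up to 1.0, then exponential growth capped at 10, applied per coordinate); treat the exact constants as an assumption and check them against the installed version.

```python
import math


def bound_penalty(coord):
    """Penalty for one position coordinate (assumed to mirror simple_tag's bound())."""
    x = abs(coord)
    if x < 0.9:
        return 0.0  # inside the arena: no penalty
    if x < 1.0:
        return (x - 0.9) * 10  # linear ramp just outside the arena
    return min(math.exp(2 * x - 2), 10.0)  # exponential growth, capped at 10


def out_of_bounds_penalty(pos):
    """Total penalty for a 2D position; a good agent's reward is reduced by this."""
    return sum(bound_penalty(c) for c in pos)
```

Because of the cap, the penalty stops growing beyond roughly |x| = 2.15, so an agent far off the map is penalized at most 10 per coordinate per step.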

The farthest an agent can get by moving in one direction is around 40.
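A rough check of that figure for a good agent, assuming the MPE default world timestep $\Delta t = 0.1$ (an assumption; the timestep is not printed in this notebook) and the `max_cycles=300` used above. Each world step moves an agent by at most $v_{\max}\,\Delta t$, so

$$
d_{\max} \le T\,\Delta t\,v_{\max} = 300 \times 0.1 \times 1.3 = 39
$$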

With 3 adversaries, 3 non-adversaries and 2 obstacles, the adversary observation space is 24-dimensional and the non-adversary observation space is 22-dimensional.
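The arithmetic behind those numbers, as a sketch following the observation layout listed above. `obs_dim` and `state_dim` are illustrative helpers, not PettingZoo functions; `state_dim` assumes the global state is the concatenation of every agent's observation, which is consistent with the `(138,)` state size printed further down (3 × 24 + 3 × 22 = 138).

```python
def obs_dim(n_good, n_adversaries, n_landmarks, is_adversary):
    """Observation length for one agent in simple_tag (layout described above)."""
    n_other = n_good + n_adversaries - 1
    # velocities are included only for the *other* good agents
    n_vel = n_good if is_adversary else n_good - 1
    return 2 + 2 + 2 * n_landmarks + 2 * n_other + 2 * n_vel


def state_dim(n_good, n_adversaries, n_landmarks):
    """Sum of all agents' observation lengths (assumed to equal the state size)."""
    return (n_adversaries * obs_dim(n_good, n_adversaries, n_landmarks, True)
            + n_good * obs_dim(n_good, n_adversaries, n_landmarks, False))
```

With the default configuration (1 good agent, 3 adversaries, 2 obstacles) the same formula gives 16 for adversaries and 14 for the good agent.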

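The environment is sequential (PettingZoo's AEC API): agents choose actions one at a time, but the actions are buffered in `current_actions` and the world physics only advances after the last agent in the cycle has acted. Agents are named `adversary_*` for adversaries and `agent_*` for non-adversaries.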
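In PettingZoo's AEC API agents pick actions in turn, but in the MPE environments those actions are buffered and the world is advanced once per cycle, after the last agent acts (the `dir(env)` output above lists `current_actions` and `_execute_world_step`). That is why an `agent_iter()` loop runs about six times as many iterations as there are cycles with six agents. `ToyAEC` below is a hypothetical stand-in that only imitates this ordering; it does no physics and is not the PettingZoo implementation.

```python
class ToyAEC:
    """Hypothetical toy: agents act in turn, the world advances after the last one."""

    def __init__(self, agents, max_cycles):
        self.agents = list(agents)
        self.max_cycles = max_cycles
        self.current_actions = [None] * len(self.agents)
        self.world_steps = 0
        self.agent_steps = 0

    def step(self, idx, action):
        self.current_actions[idx] = action  # buffer this agent's action
        self.agent_steps += 1
        if idx == len(self.agents) - 1:  # last agent in the cycle
            self.world_steps += 1  # stand-in for the world/physics update
            self.current_actions = [None] * len(self.agents)

    def run(self):
        while self.world_steps < self.max_cycles:
            for idx, _name in enumerate(self.agents):
                self.step(idx, 0)
        return self.agent_steps, self.world_steps
```

With the six agents used here and `max_cycles=300`, `run()` returns `(1800, 300)`; the real loop additionally passes `None` for each agent once it is done.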

Actions:

- 0 is NOP
- 1 is go left
- 2 is go right
- 3 is go down
- 4 is go up
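A small sketch of the action encoding above, plus a greedy "move toward a target" rule of the kind `hardcode_policy_2` below uses. `ACTION_TO_DIRECTION` and `greedy_action` are illustrative names; `rel_pos` is assumed to be target minus self, as in the relative positions of the observation, and `eps` plays the same role as in that policy.

```python
# Discrete action -> unit direction (x, y), following the list above
ACTION_TO_DIRECTION = {
    0: (0.0, 0.0),   # no-op
    1: (-1.0, 0.0),  # left
    2: (1.0, 0.0),   # right
    3: (0.0, -1.0),  # down
    4: (0.0, 1.0),   # up
}


def greedy_action(rel_pos, eps=0.3):
    """Move along the axis with the larger offset toward the target.

    rel_pos is the target's position relative to the agent (target - self).
    Returns 0 (no-op) when both offsets are within eps.
    """
    dx, dy = rel_pos
    if max(abs(dx), abs(dy)) <= eps:
        return 0
    if abs(dx) >= abs(dy):
        return 2 if dx > 0 else 1
    return 4 if dy > 0 else 3
```

This differs slightly from `hardcode_policy_2`, which checks x before y; picking the larger offset is just one possible tie-break.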

In [3]:
# Print variables of the environment
# Documentation:   https://www.pettingzoo.ml/api
env.reset()
print("State size", env.state_space.shape)
print("Name of current agent", env.agent_selection)
print("Observation space of current agent", env.observation_space(env.agent_selection).shape)
print("Action space of current agent", env.action_space(env.agent_selection))
print("Sample random action from current agent", env.action_space(env.agent_selection).sample())
print("The agent names:", *env.agents)
print()

# select an agent in the environment world, after using env.unwrapped
agent = env.world.agents[0]
print("agent's name is", agent.name)
print("agent's velocity and position coordinates", agent.state.p_vel, agent.state.p_pos)
print("is agent an adversary?", agent.adversary)

landmark = env.world.landmarks[0]
print("landmark's name is", landmark.name)
print("landmark's position coordinates (doesn't move)", landmark.state.p_pos)

State size (138,)
Name of current agent adversary_0
Observation space of current agent (24,)
Action space of current agent Discrete(5)
Sample random action from current agent 4
The agent names: adversary_0 adversary_1 adversary_2 agent_0 agent_1 agent_2

agent's name is adversary_0
agent's velocity and position coordinates [0. 0.] [-0.84429837  0.7186429 ]
is agent an adversary? True
landmark's name is landmark 0
landmark's position coordinates (doesn't move) [ 0.34800373 -0.42134618]


In [3]:
# Demo environment with random policy
env.reset()
random_demo(env, render=True, episodes=5)

Average total reward -3542.3399623641044


-17711.699811820523

In [14]:
# Demo environment with hardcoded policies
eps = 0.3

def hardcode_policy_1(observation, agent_name):
    """
    Parameters
    ==========
    observation : ndarray
    agent_name : str
    """
    if "adversary" in agent_name:
        # adversary
        if agent_name == "adversary_0":
            return np.random.binomial(1, 0.3) + 3  # 3 (down) or 4 (up); valid actions are 0..4
    elif "agent" in agent_name:
        # non-adversary
        if agent_name == "agent_0":
            pass
    return 0

def hardcode_policy_2(observation, agent_name):
    """
    Parameters
    ==========
    observation : ndarray
    agent_name : str
    """
    if "adversary" in agent_name:
        # adversary
        if agent_name == "adversary_0":
            # relative position of agent_0 (obs[12:14] with 2 landmarks and 3 adversaries)
            x, y = observation[12:14]
            if x < -eps: # go left
                return 1
            elif x > eps: # go right
                return 2
            elif y < -eps: # go down
                return 3
            elif y > eps: # go up
                return 4
            else:
                return random.randint(0, 4)
    elif "agent" in agent_name:
        # non-adversary
        if agent_name == "agent_0":
            return 0
            # return random.randint(0, 4)
    return 0

env.reset()
agent_rewards = 0
adversary_rewards = 0
for agent_step_idx, agent_name in enumerate(env.agent_iter()):
    env.render()
    observation, reward, done, info = env.last()
    if done:
        env.step(None)
    else:
        action = hardcode_policy_2(observation, agent_name)
        print("obs", np.round(observation, 2))
        env.step(action)
    if "adversary" in agent_name:
        adversary_rewards += reward
    if "agent" in agent_name:
        agent_rewards += reward
    # time.sleep(0.1)

print(f"episode ran for {agent_step_idx + 1} agent steps")
print("agent_rewards", agent_rewards)
print("adversary_rewards", adversary_rewards)

obs [ 0.    0.    0.6   0.93 -0.02 -0.63  0.12 -0.38 -1.56 -1.26 -0.04 -0.14
 -0.63 -1.34 -1.34 -0.68 -0.53 -0.26  0.    0.    0.    0.    0.    0.  ]
obs [ 0.    0.   -0.96 -0.34  1.54  0.63  1.68  0.89  1.56  1.26  1.52  1.12
  0.93 -0.08  0.22  0.58  1.04  1.01  0.    0.    0.    0.    0.    0.  ]
obs [ 0.    0.    0.56  0.79  0.02 -0.49  0.16 -0.24  0.04  0.14 -1.52 -1.12
 -0.59 -1.2  -1.3  -0.54 -0.48 -0.12  0.    0.    0.    0.    0.    0.  ]
obs [ 0.    0.   -0.03 -0.42  0.61  0.71  0.75  0.97  0.63  1.34 -0.93  0.08
  0.59  1.2  -0.7   0.66  0.11  1.09  0.    0.    0.    0.  ]
obs [ 0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  1.34  0.68 -0.22 -0.58
  1.3   0.54  0.7  -0.66  0.81  0.43  0.    0.    0.    0.  ]
obs [ 0.    0.    0.08  0.67  0.5  -0.37  0.64 -0.12  0.53  0.26 -1.04 -1.01
  0.48  0.12 -0.11 -1.09 -0.81 -0.43  0.    0.    0.    0.  ]
obs [-0.29  0.03  0.58  0.93  0.   -0.63  0.15 -0.38 -1.53 -1.27 -0.02 -0.15
 -0.61 -1.35 -1.31 -0.69 -0.5  -0.26 -0.   -0.    0. 

obs [-0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  0.57 -0.4  -0.22 -0.58
  1.28  0.53  0.7  -0.66  0.88  0.39  0.   -0.    0.02 -0.01]
obs [ 0.02 -0.01  0.14  0.64  0.44 -0.34  0.58 -0.09 -0.3  -0.79 -1.1  -0.97
  0.41  0.13 -0.17 -1.05 -0.88 -0.39  0.   -0.   -0.    0.  ]
obs [-0.02 -0.45 -0.16 -0.2   0.74  0.5   0.88  0.75 -0.8  -0.14  0.71  0.97
  0.13 -0.22 -0.57  0.44  0.31  0.84  0.   -0.   -0.    0.    0.01 -0.01]
obs [-0.   -0.   -0.96 -0.34  1.54  0.63  1.68  0.89  0.8   0.14  1.51  1.11
  0.93 -0.08  0.22  0.58  1.1   0.97  0.   -0.   -0.    0.    0.01 -0.01]
obs [-0.    0.    0.55  0.77  0.03 -0.47  0.17 -0.22 -0.71 -0.97 -1.51 -1.11
 -0.58 -1.19 -1.28 -0.53 -0.4  -0.14  0.   -0.   -0.    0.    0.01 -0.01]
obs [ 0.   -0.   -0.03 -0.42  0.61  0.71  0.75  0.97 -0.13  0.22 -0.93  0.08
  0.58  1.19 -0.7   0.66  0.18  1.05 -0.    0.    0.01 -0.01]
obs [-0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  0.57 -0.44 -0.22 -0.58
  1.28  0.53  0.7  -0.66  0.88  0.39  0.   -0.    0.0

obs [-0.   -0.01 -0.05 -0.52  0.63  0.82  0.77  1.07  0.27  0.07 -0.91  0.18
  0.6   1.29 -0.68  0.77  0.2   1.16 -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  0.95 -0.7  -0.22 -0.58
  1.28  0.53  0.68 -0.77  0.88  0.39 -0.   -0.01  0.   -0.  ]
obs [ 0.   -0.    0.15  0.64  0.43 -0.34  0.57 -0.09  0.07 -1.09 -1.11 -0.97
  0.4   0.14 -0.2  -1.16 -0.88 -0.39 -0.   -0.01 -0.    0.  ]
obs [-0.01 -0.02  0.22 -0.46  0.36  0.76  0.5   1.01 -1.18  0.12  0.33  1.23
 -0.27 -0.06 -0.95  0.7  -0.07  1.09 -0.   -0.01 -0.    0.    0.   -0.  ]
obs [-0.   -0.   -0.96 -0.34  1.54  0.63  1.68  0.89  1.18 -0.12  1.51  1.11
  0.91 -0.19  0.22  0.58  1.11  0.97 -0.   -0.01 -0.    0.    0.   -0.  ]
obs [-0.    0.    0.55  0.77  0.03 -0.48  0.17 -0.22 -0.33 -1.23 -1.51 -1.11
 -0.6  -1.29 -1.28 -0.53 -0.4  -0.14 -0.   -0.01 -0.    0.    0.   -0.  ]
obs [-0.   -0.01 -0.05 -0.52  0.63  0.82  0.77  1.07  0.27  0.06 -0.91  0.19
  0.6   1.29 -0.68  0.77  0.2   1.16 -0.    0.    0. 

obs [-0.   -0.   -0.96 -0.34  1.54  0.63  1.68  0.89  1.17 -0.22  1.51  1.11
  0.76 -0.26  0.22  0.58  1.11  0.97 -0.31 -0.15 -0.    0.    0.   -0.  ]
obs [-0.    0.    0.55  0.77  0.03 -0.48  0.17 -0.22 -0.34 -1.33 -1.51 -1.11
 -0.75 -1.37 -1.28 -0.53 -0.4  -0.14 -0.31 -0.15 -0.    0.    0.   -0.  ]
obs [-0.31 -0.15 -0.2  -0.6   0.78  0.9   0.92  1.15  0.41  0.05 -0.76  0.26
  0.75  1.37 -0.53  0.84  0.35  1.24 -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  0.94 -0.8  -0.22 -0.58
  1.28  0.53  0.53 -0.84  0.88  0.39 -0.31 -0.15  0.   -0.  ]
obs [ 0.   -0.    0.15  0.64  0.43 -0.34  0.57 -0.09  0.06 -1.19 -1.11 -0.97
  0.4   0.14 -0.35 -1.24 -0.88 -0.39 -0.31 -0.15 -0.    0.  ]
obs [ 0.32 -0.08  0.24 -0.56  0.34  0.86  0.48  1.11 -1.2   0.22  0.31  1.33
 -0.46 -0.05 -0.97  0.81 -0.09  1.2  -0.23 -0.11 -0.    0.    0.   -0.  ]
obs [-0.   -0.   -0.96 -0.34  1.54  0.63  1.68  0.89  1.2  -0.22  1.51  1.11
  0.73 -0.27  0.22  0.58  1.11  0.97 -0.23 -0.11 -0. 

obs [-0.    0.   -0.37 -0.64  0.95  0.93  1.09  1.19 -0.28 -0.23 -0.59  0.3
  0.92  1.41 -0.36  0.88  0.52  1.27 -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  0.08 -1.11 -0.22 -0.58
  1.28  0.53  0.36 -0.88  0.88  0.39 -0.    0.    0.   -0.  ]
obs [ 0.   -0.    0.15  0.64  0.43 -0.34  0.57 -0.09 -0.8  -1.5  -1.11 -0.97
  0.4   0.14 -0.52 -1.27 -0.88 -0.39 -0.    0.   -0.    0.  ]
obs [-0.04 -0.21 -0.66 -0.89  1.24  1.19  1.38  1.44 -0.3   0.55  1.2   1.66
  0.29  0.25 -0.08  1.13  0.81  1.52 -0.    0.   -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.96 -0.34  1.54  0.63  1.68  0.89  0.3  -0.55  1.51  1.11
  0.59 -0.3   0.22  0.58  1.11  0.97 -0.    0.   -0.    0.    0.   -0.  ]
obs [-0.    0.    0.55  0.77  0.03 -0.48  0.17 -0.22 -1.2  -1.66 -1.51 -1.11
 -0.92 -1.41 -1.28 -0.53 -0.4  -0.14 -0.    0.   -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.37 -0.64  0.95  0.93  1.09  1.19 -0.29 -0.25 -0.59  0.3
  0.92  1.41 -0.36  0.88  0.52  1.27 -0.    0.    0.   

obs [-0.    0.    0.55  0.77  0.03 -0.48  0.17 -0.22 -0.95 -1.52 -1.51 -1.11
 -0.91 -1.4  -1.28 -0.53 -0.4  -0.14  0.05  0.1   0.    0.    0.   -0.  ]
obs [ 0.05  0.1  -0.37 -0.63  0.95  0.92  1.09  1.18 -0.03 -0.12 -0.59  0.29
  0.91  1.4  -0.37  0.87  0.51  1.26  0.    0.    0.   -0.  ]
obs [ 0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  0.33 -0.99 -0.22 -0.58
  1.28  0.53  0.37 -0.87  0.88  0.39  0.05  0.1   0.   -0.  ]
obs [ 0.   -0.    0.15  0.64  0.43 -0.34  0.57 -0.09 -0.55 -1.38 -1.11 -0.97
  0.4   0.14 -0.51 -1.26 -0.88 -0.39  0.05  0.1   0.    0.  ]
obs [ 0.16  0.22 -0.38 -0.72  0.96  1.02  1.1   1.27 -0.58  0.38  0.93  1.49
  0.02  0.11 -0.35  0.96  0.53  1.36  0.05  0.12  0.    0.    0.   -0.  ]
obs [-0.    0.   -0.96 -0.34  1.54  0.63  1.68  0.89  0.58 -0.38  1.51  1.11
  0.6  -0.28  0.22  0.58  1.11  0.97  0.05  0.12  0.    0.    0.   -0.  ]
obs [-0.    0.    0.55  0.77  0.03 -0.48  0.17 -0.22 -0.93 -1.49 -1.51 -1.11
 -0.91 -1.39 -1.28 -0.53 -0.4  -0.14  0.05  0.12  0. 

obs [ 0.38  0.27 -0.34 -0.2   0.92  0.5   1.07  0.75 -0.62 -0.14  0.89  0.97
  0.01 -0.31 -0.39  0.44  0.49  0.84  0.    0.   -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.96 -0.34  1.54  0.63  1.68  0.89  0.62  0.14  1.51  1.11
  0.63 -0.17  0.22  0.58  1.11  0.97  0.    0.   -0.    0.    0.   -0.  ]
obs [-0.    0.    0.55  0.77  0.03 -0.48  0.17 -0.22 -0.89 -0.97 -1.51 -1.11
 -0.88 -1.28 -1.28 -0.53 -0.4  -0.14  0.    0.   -0.    0.    0.   -0.  ]
obs [ 0.    0.   -0.33 -0.51  0.91  0.81  1.05  1.06 -0.01  0.31 -0.63  0.17
  0.88  1.28 -0.41  0.75  0.48  1.15 -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  0.39 -0.44 -0.22 -0.58
  1.28  0.53  0.41 -0.75  0.88  0.39  0.    0.    0.   -0.  ]
obs [ 0.   -0.    0.15  0.64  0.43 -0.34  0.57 -0.09 -0.49 -0.84 -1.11 -0.97
  0.4   0.14 -0.48 -1.15 -0.88 -0.39  0.    0.   -0.    0.  ]
obs [ 0.28 -0.1  -0.32 -0.21  0.9   0.51  1.04  0.76 -0.64 -0.13  0.86  0.98
 -0.01 -0.3  -0.42  0.45  0.46  0.85  0.    0.   -0. 

obs [ 0.01 -0.02 -0.31 -0.58  0.89  0.88  1.03  1.13 -0.17  0.08 -0.65  0.24
  0.86  1.35 -0.42  0.82  0.46  1.21 -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  0.26 -0.74 -0.22 -0.58
  1.28  0.53  0.42 -0.82  0.88  0.39  0.01 -0.02  0.   -0.  ]
obs [ 0.   -0.    0.15  0.64  0.43 -0.34  0.57 -0.09 -0.63 -1.13 -1.11 -0.97
  0.4   0.14 -0.46 -1.21 -0.88 -0.39  0.01 -0.02 -0.    0.  ]
obs [-0.03 -0.38 -0.48 -0.53  1.06  0.83  1.2   1.08 -0.48  0.19  1.03  1.3
  0.17 -0.05 -0.25  0.78  0.63  1.17  0.   -0.02 -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.96 -0.34  1.54  0.63  1.68  0.89  0.48 -0.19  1.51  1.11
  0.65 -0.24  0.22  0.58  1.11  0.97  0.   -0.02 -0.    0.    0.   -0.  ]
obs [-0.    0.    0.55  0.77  0.03 -0.48  0.17 -0.22 -1.03 -1.3  -1.51 -1.11
 -0.86 -1.35 -1.28 -0.53 -0.4  -0.14  0.   -0.02 -0.    0.    0.   -0.  ]
obs [ 0.   -0.02 -0.31 -0.58  0.89  0.88  1.03  1.13 -0.17  0.05 -0.65  0.24
  0.86  1.35 -0.42  0.82  0.46  1.22 -0.    0.    0.  

obs [-0.    0.   -0.96 -0.34  1.54  0.63  1.68  0.89  0.64 -0.33  1.51  1.11
  0.7  -0.22  0.22  0.58  1.11  0.97  0.    0.   -0.    0.    0.   -0.  ]
obs [-0.    0.    0.55  0.77  0.03 -0.48  0.17 -0.22 -0.87 -1.44 -1.51 -1.11
 -0.81 -1.33 -1.28 -0.53 -0.4  -0.14  0.    0.   -0.    0.    0.   -0.  ]
obs [ 0.    0.   -0.26 -0.55  0.84  0.85  0.98  1.1  -0.06 -0.11 -0.7   0.22
  0.81  1.33 -0.47  0.8   0.41  1.19 -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  0.42 -0.91 -0.22 -0.58
  1.28  0.53  0.47 -0.8   0.88  0.39  0.    0.    0.   -0.  ]
obs [ 0.   -0.    0.15  0.64  0.43 -0.34  0.57 -0.09 -0.47 -1.3  -1.11 -0.97
  0.4   0.14 -0.41 -1.19 -0.88 -0.39  0.    0.   -0.    0.  ]
obs [ 0.63 -0.05 -0.25 -0.67  0.84  0.97  0.98  1.22 -0.7   0.33  0.8   1.44
 -0.    0.11 -0.48  0.91  0.4   1.3   0.01  0.02 -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.96 -0.34  1.54  0.63  1.68  0.89  0.7  -0.33  1.51  1.11
  0.7  -0.22  0.22  0.58  1.11  0.97  0.01  0.02 -0. 

obs [-0.31 -0.02  0.03 -0.59  0.55  0.89  0.69  1.14 -0.99  0.26  0.52  1.37
 -0.29  0.09 -0.77  0.84  0.12  1.23  0.    0.   -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.96 -0.34  1.54  0.63  1.68  0.89  0.99 -0.26  1.51  1.11
  0.7  -0.17  0.22  0.58  1.11  0.97  0.    0.   -0.    0.    0.   -0.  ]
obs [-0.    0.    0.55  0.77  0.03 -0.48  0.17 -0.22 -0.52 -1.37 -1.51 -1.11
 -0.81 -1.28 -1.28 -0.53 -0.4  -0.14  0.    0.   -0.    0.    0.   -0.  ]
obs [ 0.    0.   -0.26 -0.51  0.84  0.81  0.98  1.06  0.29 -0.09 -0.7   0.17
  0.81  1.28 -0.48  0.75  0.41  1.14 -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  0.77 -0.84 -0.22 -0.58
  1.28  0.53  0.48 -0.75  0.88  0.39  0.    0.    0.   -0.  ]
obs [ 0.   -0.    0.15  0.64  0.43 -0.34  0.57 -0.09 -0.12 -1.23 -1.11 -0.97
  0.4   0.14 -0.41 -1.14 -0.88 -0.39  0.    0.   -0.    0.  ]
obs [-0.23  0.28  0.01 -0.56  0.57  0.86  0.71  1.12 -0.97  0.23  0.54  1.34
 -0.26  0.06 -0.74  0.81  0.14  1.2   0.    0.   -0. 

obs [-0.08 -0.09 -0.04 -0.81  0.62  1.11  0.76  1.36 -0.92  0.48  0.59  1.59
 -0.22  0.31 -0.7   1.06  0.19  1.45  0.    0.   -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.96 -0.34  1.54  0.63  1.68  0.89  0.92 -0.48  1.51  1.11
  0.7  -0.17  0.22  0.58  1.11  0.97  0.    0.   -0.    0.    0.   -0.  ]
obs [-0.    0.    0.55  0.77  0.03 -0.48  0.17 -0.22 -0.59 -1.59 -1.51 -1.11
 -0.81 -1.28 -1.28 -0.53 -0.4  -0.14  0.    0.   -0.    0.    0.   -0.  ]
obs [ 0.    0.   -0.26 -0.51  0.84  0.81  0.98  1.06  0.22 -0.31 -0.7   0.17
  0.81  1.28 -0.48  0.75  0.41  1.14 -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  0.7  -1.06 -0.22 -0.58
  1.28  0.53  0.48 -0.75  0.88  0.39  0.    0.    0.   -0.  ]
obs [ 0.   -0.    0.15  0.64  0.43 -0.34  0.57 -0.09 -0.19 -1.45 -1.11 -0.97
  0.4   0.14 -0.41 -1.14 -0.88 -0.39  0.    0.   -0.    0.  ]
obs [-0.06  0.23 -0.04 -0.79  0.62  1.09  0.76  1.34 -0.92  0.45  0.59  1.56
 -0.21  0.28 -0.69  1.04  0.19  1.43  0.    0.   -0. 

obs [-0.    0.   -0.96 -0.34  1.54  0.63  1.68  0.89  0.88 -0.46  1.51  1.11
  0.7  -0.17  0.22  0.58  1.11  0.97  0.    0.   -0.    0.    0.   -0.  ]
obs [-0.    0.    0.55  0.77  0.03 -0.48  0.17 -0.22 -0.63 -1.57 -1.51 -1.11
 -0.81 -1.28 -1.28 -0.53 -0.4  -0.14  0.    0.   -0.    0.    0.   -0.  ]
obs [ 0.    0.   -0.26 -0.51  0.84  0.81  0.98  1.06  0.18 -0.29 -0.7   0.17
  0.81  1.28 -0.48  0.75  0.41  1.14 -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.73  0.25  1.32  0.05  1.46  0.31  0.66 -1.04 -0.22 -0.58
  1.28  0.53  0.48 -0.75  0.88  0.39  0.    0.    0.   -0.  ]
obs [ 0.   -0.    0.15  0.64  0.43 -0.34  0.57 -0.09 -0.23 -1.43 -1.11 -0.97
  0.4   0.14 -0.41 -1.14 -0.88 -0.39  0.    0.   -0.    0.  ]
obs [ 0.04 -0.41 -0.07 -0.84  0.65  1.14  0.79  1.39 -0.89  0.5   0.62  1.61
 -0.18  0.33 -0.66  1.08  0.22  1.47  0.    0.   -0.    0.    0.   -0.  ]
obs [-0.    0.   -0.96 -0.34  1.54  0.63  1.68  0.89  0.89 -0.5   1.51  1.11
  0.7  -0.17  0.22  0.58  1.11  0.97  0.    0.   -0. 

In [44]:
def hardcode_policy_1(observation, agent_name):
    """
    Parameters
    ==========
    observation : ndarray
    agent_name : str
    """
    if "adversary" in agent_name:
        # adversary
        if agent_name == "adversary_0":
            pass
    elif "agent" in agent_name:
        # non-adversary
        if agent_name == "agent_0":
            return 3
    return 0

env.reset()
agent_rewards = 0
adversary_rewards = 0
for agent_step_idx, agent_name in enumerate(env.agent_iter()):
    # env.render()
    observation, reward, done, info = env.last()
    if done:
        env.step(None)
    else:
        action = hardcode_policy_1(observation, agent_name)
        env.step(action)
    if "adversary" in agent_name:
        adversary_rewards += reward
    if "agent" in agent_name:
        agent_rewards += reward
    
    if agent_name == "agent_0":
#         print("obs", np.round(observation, 2))
#         print("obs[vel]", np.round(observation[:2], 2))
#         print("obs[pos]", np.round(observation[2:4], 2))
#         print("obs[landmark0]", np.round(observation[4:6], 2))
#         print("obs[landmark1]", np.round(observation[6:8], 2))
        pass
    elif agent_name == "adversary_0":
        print(observation.shape)
    # time.sleep(0.1)

print(f"episode ran for {agent_step_idx + 1} agent steps")
print("agent_rewards", agent_rewards)
print("adversary_rewards", adversary_rewards)

(24,)

In [37]:
# Agent and adversary observations:

IndentationError: expected an indented block (2903398728.py, line 25)

In [46]:
1 + (2 == 2)

2

### How to train the agents?

- Use the differentiable inter-agent learning (DIAL) algorithm.
- Use parameter sharing among DIAL agents, with separate parameter sets for adversaries and good agents.
- It's not entirely clear whether the authors accumulate gradients for differentiable communication, but it
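One way to set up that parameter sharing, as a sketch: one network per team, shared by every agent on that team, with a one-hot agent index appended to the input so the shared weights can still tell team members apart (a common trick with parameter sharing). `TeamNet`, `nets`, `q_values`, the layer sizes, and the 5 Q-value outputs are all assumptions for illustration, not the DIAL authors' architecture; input sizes match the 3-vs-3, 2-obstacle setup used in this notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TeamNet(nn.Module):
    """Hypothetical shared network for all agents on one team."""

    def __init__(self, obs_dim, n_team_agents, n_actions=5, hidden=64):
        super().__init__()
        self.n_team_agents = n_team_agents
        self.body = nn.Sequential(
            nn.Linear(obs_dim + n_team_agents, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, obs, agent_idx):
        # one-hot agent index lets the shared weights distinguish team members
        one_hot = F.one_hot(agent_idx, self.n_team_agents).float()
        return self.body(torch.cat([obs, one_hot], dim=-1))


# one parameter set per team
nets = {
    "adversary": TeamNet(obs_dim=24, n_team_agents=3),
    "agent": TeamNet(obs_dim=22, n_team_agents=3),
}


def q_values(agent_name, obs):
    """Route e.g. 'adversary_1' to the adversary net with index 1."""
    team, idx = agent_name.rsplit("_", 1)
    x = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
    return nets[team](x, torch.tensor([int(idx)]))
```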

Messages are vectors; a length of 4 or 5 should work.

Concatenate the messages from all the actors and add them to the message input for the current agent.
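A sketch of the message path under stated assumptions: each agent emits a real-valued message of length 4, messages pass through a DIAL-style discretise/regularise unit (DRU), and the other agents' processed messages are concatenated onto the current agent's observation. The DRU form used here (Gaussian noise plus a sigmoid during centralised training, a hard threshold at execution) follows the DIAL paper as I understand it; `dru`, `build_input`, the noise scale `sigma`, and the choice to exclude an agent's own message are all assumptions.

```python
import torch


def dru(message, sigma=2.0, training=True):
    """Discretise/regularise unit (DIAL-style, as assumed here)."""
    if training:
        # noise + sigmoid keeps the channel differentiable during training
        return torch.sigmoid(message + sigma * torch.randn_like(message))
    # hard threshold at execution: messages become bits
    return (message > 0).float()


def build_input(obs, messages, agent_idx):
    """Append every other agent's (processed) message to this agent's observation.

    obs: (obs_dim,) tensor; messages: (n_agents, msg_len) tensor.
    """
    others = torch.cat([messages[:agent_idx], messages[agent_idx + 1:]], dim=0)
    return torch.cat([obs, others.flatten()], dim=0)
```

Because the training-time DRU is differentiable, gradients from the receiving agent's loss flow back into the senders' messages, which is the point of DIAL over treating messages as discrete actions.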

The names of agents are: 
adversary_0 adversary_1 adversary_2 agent_0 agent_1 agent_2

## Scratch work

In [15]:
a = torch.tensor([1,3,2,0])
torch.argmax(a).item(), torch.max(a), a[2]

(1, tensor(3), tensor(2))

In [3]:
d = {1: 'a', 2: 'b', 3: 'c'}
for i in d:
    print(i , end=' ')

1 2 3 

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
a = torch.tensor(2, device=device)
b = torch.tensor(3)
a*b

tensor(6, device='cuda:0')

In [9]:
v = torch.arange(6)
a = torch.tensor([9, 8])

idx = 4

torch.hstack((v[:idx], a, v[idx + 2:]))


tensor([0, 1, 2, 3, 9, 8])

In [14]:
w = torch.tensor([0,1,2])
w.device
w.to(device)  # .to() returns a new tensor; w itself is not moved
w.device

device(type='cpu')