Pettingzoo env (#205)
* added pettingzooenv

* improved test speed, made name required for pettingzoo envs

* added duplicate method to pettingzoo env

* fixed duplicate method in pettingzoo env

* fixed duplicate method in pettingzoo env

* fixed observation dtype for pettingzoo env

* fixed action space conversion issue

* fixed existing tests

* rename test and bump supersuit version

* fix MultiagentEnvTest

* add additional pettingzoo tests

* run autoformatter

Co-authored-by: Chris Nota <cpnota@gmail.com>
benblack769 and cpnota committed Jan 21, 2021
1 parent a37ed4f commit fad9f4a
Showing 7 changed files with 244 additions and 115 deletions.
2 changes: 2 additions & 0 deletions all/environments/__init__.py
@@ -3,11 +3,13 @@
from .gym import GymEnvironment
from .atari import AtariEnvironment
from .multiagent_atari import MultiagentAtariEnv
from .multiagent_pettingzoo import MultiagentPettingZooEnv

__all__ = [
"Environment",
"MultiagentEnvironment",
"GymEnvironment",
"AtariEnvironment",
"MultiagentAtariEnv",
"MultiagentPettingZooEnv",
]
105 changes: 6 additions & 99 deletions all/environments/multiagent_atari.py
@@ -4,9 +4,10 @@
import gym
from all.core import MultiAgentState
from ._multiagent_environment import MultiagentEnvironment
from .multiagent_pettingzoo import MultiagentPettingZooEnv


class MultiagentAtariEnv(MultiagentEnvironment):
class MultiagentAtariEnv(MultiagentPettingZooEnv):
'''
A wrapper for PettingZoo Atari environments (see: https://www.pettingzoo.ml/atari).
@@ -20,108 +21,14 @@ class MultiagentAtariEnv(MultiagentEnvironment):

def __init__(self, env_name, device='cuda'):
env = self._load_env(env_name)
env.reset()
self._env = env
self._name = env_name
self._device = device
self.agents = self._env.agents
self.subenvs = {
agent: SubEnv(agent, device, self.state_spaces[agent], self.action_spaces[agent])
for agent in self.agents
}

'''
Reset the environment and return a new initial state.
Returns:
An initial MultiagentState object.
'''

def reset(self):
self._env.reset()
return self.last()

'''
Apply an action for the current agent and return the next state.
Args:
action (int): An int or tensor containing a single integer representing the action.
Returns:
The MultiagentState object for the next agent
'''

def step(self, action):
if action is None:
self._env.step(action)
return
if torch.is_tensor(action):
self._env.step(action.item())
else:
self._env.step(action)
return self.last()

def render(self, mode='human'):
return self._env.render(mode=mode)

def close(self):
self._env.close()

def agent_iter(self):
return self._env.agent_iter()

def is_done(self, agent):
return self._env.dones[agent]

def last(self):
observation, reward, done, info = self._env.last()
observation = np.expand_dims(observation, 0)
return MultiAgentState.from_zoo(
self._env.agent_selection,
(observation, reward, done, info),
device=self._device,
dtype=np.uint8
)

def seed(self, seed):
self._env.seed(seed)

@property
def name(self):
return self._name

@property
def device(self):
return self._device

@property
def state_spaces(self):
return {agent: gym.spaces.Box(0, 255, (1, 84, 84), np.uint8) for agent in self._env.possible_agents}

@property
def observation_spaces(self):
return self.state_spaces

@property
def action_spaces(self):
return self._env.action_spaces
super().__init__(env, name=env_name, device=device)

def _load_env(self, env_name):
from pettingzoo import atari
from supersuit import resize_v0, frame_skip_v0
from supersuit import resize_v0, frame_skip_v0, reshape_v0, max_observation_v0
env = importlib.import_module('pettingzoo.atari.{}'.format(env_name)).env(obs_type='grayscale_image')
env = max_observation_v0(env, 2)  # max over the last 2 frames to remove Atari sprite flicker
env = frame_skip_v0(env, 4)  # repeat each action for 4 frames
env = resize_v0(env, 84, 84)  # downsample observations to 84x84
env = reshape_v0(env, (1, 84, 84))  # add a leading channel dimension
return env


class SubEnv():
def __init__(self, name, device, state_space, action_space):
self.name = name
self.device = device
self.state_space = state_space
self.action_space = action_space

@property
def observation_space(self):
return self.state_space
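A minimal usage sketch of the refactored wrapper, assuming pettingzoo[atari] and supersuit are installed and the corresponding ROM is available; 'pong_v1' is an illustrative name, since version suffixes depend on the installed pettingzoo release:

from all.environments import MultiagentAtariEnv

env = MultiagentAtariEnv('pong_v1', device='cpu')  # 'pong_v1' is illustrative
env.reset()
for agent in env.agent_iter():
    state = env.last()  # MultiAgentState for the agent about to act
    # PettingZoo requires stepping with None once an agent is done
    action = None if state.done else env.action_spaces[agent].sample()
    env.step(action)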
130 changes: 130 additions & 0 deletions all/environments/multiagent_pettingzoo.py
@@ -0,0 +1,130 @@
import importlib
import numpy as np
import torch
import cloudpickle
import gym
from all.core import MultiAgentState
from ._multiagent_environment import MultiagentEnvironment


class MultiagentPettingZooEnv(MultiagentEnvironment):
'''
A wrapper for general PettingZoo environments (see: https://www.pettingzoo.ml/).
This wrapper converts the output of the PettingZoo environment to PyTorch tensors,
and wraps them in a State object that can be passed to an Agent.
Args:
zoo_env (AECEnv): A PettingZoo AECEnv environment (e.g. pettingzoo.mpe.simple_push_v2)
name (str): A name for the environment (required)
device (optional): the device on which tensors will be stored
'''

def __init__(self, zoo_env, name, device='cuda'):
env = zoo_env
env.reset()
self._env = env
self._name = name
self._device = device
self.agents = self._env.agents
self.subenvs = {
agent: SubEnv(agent, device, self.state_spaces[agent], self.action_spaces[agent])
for agent in self.agents
}

'''
Reset the environment and return a new initial state.
Returns:
An initial MultiagentState object.
'''

def reset(self):
self._env.reset()
return self.last()

'''
Apply an action for the current agent and return the next state.
Args:
action (int): An int or tensor containing a single integer representing the action.
Returns:
The MultiagentState object for the next agent
'''

def step(self, action):
if action is None:
self._env.step(action)
return
self._env.step(self._convert(action))
return self.last()

def seed(self, seed):
self._env.seed(seed)

def render(self, mode='human'):
return self._env.render(mode=mode)

def close(self):
self._env.close()

def agent_iter(self):
return self._env.agent_iter()

def is_done(self, agent):
return self._env.dones[agent]

def duplicate(self, n):
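# deep-copy the wrapped env via a cloudpickle round-trip so each duplicate has independent state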
return [MultiagentPettingZooEnv(cloudpickle.loads(cloudpickle.dumps(self._env)), self._name, device=self.device) for _ in range(n)]

def last(self):
observation, reward, done, info = self._env.last()
selected_obs_space = self._env.observation_spaces[self._env.agent_selection]
return MultiAgentState.from_zoo(self._env.agent_selection, (observation, reward, done, info), device=self._device, dtype=selected_obs_space.dtype)

@property
def name(self):
return self._name

@property
def device(self):
return self._device

@property
def agent_selection(self):
return self._env.agent_selection

@property
def state_spaces(self):
return self._env.observation_spaces

@property
def observation_spaces(self):
return self._env.observation_spaces

@property
def action_spaces(self):
return self._env.action_spaces

def _convert(self, action):
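# convert a tensor action to the native type expected by the selected agent's action space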
agent = self._env.agent_selection
action_space = self._env.action_spaces[agent]
if torch.is_tensor(action):
if isinstance(action_space, gym.spaces.Discrete):
return action.item()
if isinstance(action_space, gym.spaces.Box):
return action.cpu().detach().numpy().reshape(-1)
raise TypeError("Unknown action space type")
return action


class SubEnv():
def __init__(self, name, device, state_space, action_space):
self.name = name
self.device = device
self.state_space = state_space
self.action_space = action_space

@property
def observation_space(self):
return self.state_space
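A minimal usage sketch for the new wrapper, assuming pettingzoo's MPE environments are installed (simple_push_v2 is the example named in the docstring above):

from pettingzoo.mpe import simple_push_v2
from all.environments import MultiagentPettingZooEnv

env = MultiagentPettingZooEnv(simple_push_v2.env(), name="simple_push_v2", device='cpu')
env.reset()
for agent in env.agent_iter():
    state = env.last()  # MultiAgentState for the agent about to act
    action = None if state.done else env.action_spaces[agent].sample()
    env.step(action)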
95 changes: 95 additions & 0 deletions all/environments/multiagent_pettingzoo_test.py
@@ -0,0 +1,95 @@
import unittest
import torch
from all.environments import MultiagentPettingZooEnv
from pettingzoo.mpe import simple_world_comm_v2


class MultiagentPettingZooEnvTest(unittest.TestCase):
def test_init(self):
self._make_env()

def test_reset(self):
env = self._make_env()
state = env.reset()
self.assertEqual(state.observation.shape, (34,))
self.assertEqual(state.reward, 0)
self.assertEqual(state.done, False)
self.assertEqual(state.mask, 1.)
self.assertEqual(state['agent'], 'leadadversary_0')

def test_step(self):
env = self._make_env()
env.reset()
state = env.step(0)
self.assertEqual(state.observation.shape, (34,))
self.assertEqual(state.reward, 0)
self.assertEqual(state.done, False)
self.assertEqual(state.mask, 1.)
self.assertEqual(state['agent'], 'adversary_0')

def test_step_tensor(self):
env = self._make_env()
env.reset()
state = env.step(torch.tensor(0))  # exercise the tensor path in _convert
self.assertEqual(state.observation.shape, (34,))
self.assertEqual(state.reward, 0)
self.assertEqual(state.done, False)
self.assertEqual(state.mask, 1.)
self.assertEqual(state['agent'], 'adversary_0')

def test_name(self):
env = self._make_env()
self.assertEqual(env.name, 'simple_world_comm_v2')

def test_agent_iter(self):
env = self._make_env()
env.reset()
it = iter(env.agent_iter())
self.assertEqual(next(it), 'leadadversary_0')

def test_state_spaces(self):
state_spaces = self._make_env().state_spaces
self.assertEqual(state_spaces['leadadversary_0'].shape, (34,))
self.assertEqual(state_spaces['adversary_0'].shape, (34,))

def test_action_spaces(self):
action_spaces = self._make_env().action_spaces
self.assertEqual(action_spaces['leadadversary_0'].n, 20)
self.assertEqual(action_spaces['adversary_0'].n, 5)

def test_list_agents(self):
env = self._make_env()
self.assertEqual(env.agents, ['leadadversary_0', 'adversary_0', 'adversary_1', 'adversary_2', 'agent_0', 'agent_1'])

def test_is_done(self):
env = self._make_env()
env.reset()
self.assertFalse(env.is_done('leadadversary_0'))
self.assertFalse(env.is_done('adversary_0'))

def test_last(self):
env = self._make_env()
env.reset()
state = env.last()
self.assertEqual(state.observation.shape, (34,))
self.assertEqual(state.reward, 0)
self.assertEqual(state.done, False)
self.assertEqual(state.mask, 1.)
self.assertEqual(state['agent'], 'leadadversary_0')

def test_variable_spaces(self):
env = MultiagentPettingZooEnv(simple_world_comm_v2.env(), name="simple_world_comm_v2", device='cpu')
state = env.reset()
# check that each agent's observation lies in its observation space and that sampled actions step correctly
for agent in env.agents:
state = env.last()
self.assertTrue(env.observation_spaces[agent].contains(state['observation'].cpu().detach().numpy()))
env.step(env.action_spaces[env.agent_selection].sample())

def _make_env(self):
return MultiagentPettingZooEnv(simple_world_comm_v2.env(), name="simple_world_comm_v2", device='cpu')


if __name__ == "__main__":
unittest.main()
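These tests assume pettingzoo's MPE environments are installed; they can be run with, for example, python -m unittest all.environments.multiagent_pettingzoo_test.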
