Commit
* added pettingzooenv
* improved test speed, made name required for pettingzoo envs
* added duplicate method to pettingzoo env
* fixed duplicate method in pettingzoo env
* fixed duplicate method in pettingzoo env
* fixed observation dtype for pettingzoo env
* fixed action space conversion issue
* fixed existing tests
* rename test and bump supersuit version
* fix MultiagentEnvTest
* add additional pettingzoo tests
* run autoformatter

Co-authored-by: Chris Nota <cpnota@gmail.com>
1 parent a37ed4f · commit fad9f4a · showing 7 changed files with 244 additions and 115 deletions.
@@ -0,0 +1,130 @@
import importlib
import numpy as np
import torch
import cloudpickle
import gym
from all.core import MultiAgentState
from ._multiagent_environment import MultiagentEnvironment


class MultiagentPettingZooEnv(MultiagentEnvironment):
    '''
    A wrapper for general PettingZoo environments (see: https://www.pettingzoo.ml/).

    This wrapper converts the output of the PettingZoo environment to PyTorch tensors,
    and wraps them in a State object that can be passed to an Agent.

    Args:
        zoo_env (AECEnv): A PettingZoo AECEnv environment (e.g. pettingzoo.mpe.simple_push_v2)
        name (str): A name for the environment.
        device (optional): the device on which tensors will be stored
    '''

    def __init__(self, zoo_env, name, device='cuda'):
        env = zoo_env
        env.reset()
        self._env = env
        self._name = name
        self._device = device
        self.agents = self._env.agents
        self.subenvs = {
            agent: SubEnv(agent, device, self.state_spaces[agent], self.action_spaces[agent])
            for agent in self.agents
        }

    def reset(self):
        '''
        Reset the environment and return a new initial state.

        Returns:
            An initial MultiagentState object.
        '''
        self._env.reset()
        return self.last()

    def step(self, action):
        '''
        Apply the action for the currently selected agent and advance the environment.

        Args:
            action: An int or tensor containing a single integer representing the action.

        Returns:
            The MultiagentState object for the next agent.
        '''
        if action is None:
            self._env.step(action)
            return
        self._env.step(self._convert(action))
        return self.last()

    def seed(self, seed):
        self._env.seed(seed)

    def render(self, mode='human'):
        return self._env.render(mode=mode)

    def close(self):
        self._env.close()

    def agent_iter(self):
        return self._env.agent_iter()

    def is_done(self, agent):
        return self._env.dones[agent]

    def duplicate(self, n):
        # Deep-copy the underlying environment via a cloudpickle round-trip so that
        # each copy can be stepped independently of the original.
        return [MultiagentPettingZooEnv(cloudpickle.loads(cloudpickle.dumps(self._env)), self._name, device=self.device) for _ in range(n)]

    def last(self):
        observation, reward, done, info = self._env.last()
        selected_obs_space = self._env.observation_spaces[self._env.agent_selection]
        return MultiAgentState.from_zoo(self._env.agent_selection, (observation, reward, done, info), device=self._device, dtype=selected_obs_space.dtype)

    @property
    def name(self):
        return self._name

    @property
    def device(self):
        return self._device

    @property
    def agent_selection(self):
        return self._env.agent_selection

    @property
    def state_spaces(self):
        return self._env.observation_spaces

    @property
    def observation_spaces(self):
        return self._env.observation_spaces

    @property
    def action_spaces(self):
        return self._env.action_spaces

    def _convert(self, action):
        # Convert a PyTorch action tensor into the native type expected by the
        # underlying PettingZoo environment for the currently selected agent.
        agent = self._env.agent_selection
        action_space = self._env.action_spaces[agent]
        if torch.is_tensor(action):
            if isinstance(action_space, gym.spaces.Discrete):
                return action.item()
            if isinstance(action_space, gym.spaces.Box):
                return action.cpu().detach().numpy().reshape(-1)
            raise TypeError("Unknown action space type")
        return action


class SubEnv():
    def __init__(self, name, device, state_space, action_space):
        self.name = name
        self.device = device
        self.state_space = state_space
        self.action_space = action_space

    @property
    def observation_space(self):
        return self.state_space
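For orientation, here is a minimal usage sketch of the new wrapper. It is not part of the commit; it assumes pettingzoo and the all package are installed, and it mirrors the random-action loop used in the new test_variable_spaces test below.

from pettingzoo.mpe import simple_world_comm_v2
from all.environments import MultiagentPettingZooEnv

# name is now a required argument
env = MultiagentPettingZooEnv(simple_world_comm_v2.env(), name="simple_world_comm_v2", device='cpu')
state = env.reset()

# Step each agent in turn with a random action until the episode ends.
for agent in env.agent_iter():
    state = env.last()
    if state.done:
        env.step(None)  # done agents must be stepped with None under the AEC API
    else:
        env.step(env.action_spaces[agent].sample())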
@@ -0,0 +1,95 @@
import unittest
import torch
from all.environments import MultiagentPettingZooEnv
from pettingzoo.mpe import simple_world_comm_v2


class MultiagentPettingZooEnvTest(unittest.TestCase):
    def test_init(self):
        self._make_env()

    def test_reset(self):
        env = self._make_env()
        state = env.reset()
        self.assertEqual(state.observation.shape, (34,))
        self.assertEqual(state.reward, 0)
        self.assertEqual(state.done, False)
        self.assertEqual(state.mask, 1.)
        self.assertEqual(state['agent'], 'leadadversary_0')

    def test_step(self):
        env = self._make_env()
        env.reset()
        state = env.step(0)
        self.assertEqual(state.observation.shape, (34,))
        self.assertEqual(state.reward, 0)
        self.assertEqual(state.done, False)
        self.assertEqual(state.mask, 1.)
        self.assertEqual(state['agent'], 'adversary_0')

    def test_step_tensor(self):
        env = self._make_env()
        env.reset()
        # Pass the action as a tensor to exercise the tensor-to-int conversion path.
        state = env.step(torch.tensor(0))
        self.assertEqual(state.observation.shape, (34,))
        self.assertEqual(state.reward, 0)
        self.assertEqual(state.done, False)
        self.assertEqual(state.mask, 1.)
        self.assertEqual(state['agent'], 'adversary_0')

    def test_name(self):
        env = self._make_env()
        self.assertEqual(env.name, 'simple_world_comm_v2')

    def test_agent_iter(self):
        env = self._make_env()
        env.reset()
        it = iter(env.agent_iter())
        self.assertEqual(next(it), 'leadadversary_0')

    def test_state_spaces(self):
        state_spaces = self._make_env().state_spaces
        self.assertEqual(state_spaces['leadadversary_0'].shape, (34,))
        self.assertEqual(state_spaces['adversary_0'].shape, (34,))

    def test_action_spaces(self):
        action_spaces = self._make_env().action_spaces
        self.assertEqual(action_spaces['leadadversary_0'].n, 20)
        self.assertEqual(action_spaces['adversary_0'].n, 5)

    def test_list_agents(self):
        env = self._make_env()
        self.assertEqual(env.agents, ['leadadversary_0', 'adversary_0', 'adversary_1', 'adversary_2', 'agent_0', 'agent_1'])

    def test_is_done(self):
        env = self._make_env()
        env.reset()
        self.assertFalse(env.is_done('leadadversary_0'))
        self.assertFalse(env.is_done('adversary_0'))

    def test_last(self):
        env = self._make_env()
        env.reset()
        state = env.last()
        self.assertEqual(state.observation.shape, (34,))
        self.assertEqual(state.reward, 0)
        self.assertEqual(state.done, False)
        self.assertEqual(state.mask, 1.)
        self.assertEqual(state['agent'], 'leadadversary_0')

    def test_variable_spaces(self):
        env = MultiagentPettingZooEnv(simple_world_comm_v2.env(), name="simple_world_comm_v2", device='cpu')
        state = env.reset()
        # Check that each agent's observation lies in its declared observation space.
        for agent in env.agents:
            state = env.last()
            self.assertTrue(env.observation_spaces[agent].contains(state['observation'].cpu().detach().numpy()))
            env.step(env.action_spaces[env.agent_selection].sample())

    def _make_env(self):
        return MultiagentPettingZooEnv(simple_world_comm_v2.env(), name="simple_world_comm_v2", device='cpu')


if __name__ == "__main__":
    unittest.main()
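As a rough illustration of the new duplicate method (again assuming the simple_world_comm_v2 setup used in the tests above), each copy holds its own cloudpickled environment, so the copies can be reset and stepped independently:

from pettingzoo.mpe import simple_world_comm_v2
from all.environments import MultiagentPettingZooEnv

env = MultiagentPettingZooEnv(simple_world_comm_v2.env(), name="simple_world_comm_v2", device='cpu')

# duplicate(n) pickles and unpickles the underlying environment n times
first, second = env.duplicate(2)
first.reset()
first.step(0)   # advances only the first copy
second.reset()  # the second copy is unaffected by the step above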