Skip to content

Commit

Permalink
fixes CI test by using new version of pong (#251)
Browse files Browse the repository at this point in the history
  • Loading branch information
benblack769 committed Jun 14, 2021
1 parent 4122e0a commit 3cf3d5e
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
24 changes: 12 additions & 12 deletions all/environments/multiagent_atari_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

class MultiagentAtariEnvTest(unittest.TestCase):
def test_init(self):
MultiagentAtariEnv('pong_v1', device='cpu')
MultiagentAtariEnv('pong_v2', device='cpu')
MultiagentAtariEnv('mario_bros_v2', device='cpu')
MultiagentAtariEnv('entombed_cooperative_v2', device='cpu')

def test_reset(self):
env = MultiagentAtariEnv('pong_v1', device='cpu')
env = MultiagentAtariEnv('pong_v2', device='cpu')
state = env.reset()
self.assertEqual(state.observation.shape, (1, 84, 84))
self.assertEqual(state.reward, 0)
Expand All @@ -19,7 +19,7 @@ def test_reset(self):
self.assertEqual(state['agent'], 'first_0')

def test_step(self):
env = MultiagentAtariEnv('pong_v1', device='cpu')
env = MultiagentAtariEnv('pong_v2', device='cpu')
env.reset()
state = env.step(0)
self.assertEqual(state.observation.shape, (1, 84, 84))
Expand All @@ -29,7 +29,7 @@ def test_step(self):
self.assertEqual(state['agent'], 'second_0')

def test_step_tensor(self):
env = MultiagentAtariEnv('pong_v1', device='cpu')
env = MultiagentAtariEnv('pong_v2', device='cpu')
env.reset()
state = env.step(torch.tensor([0]))
self.assertEqual(state.observation.shape, (1, 84, 84))
Expand All @@ -39,37 +39,37 @@ def test_step_tensor(self):
self.assertEqual(state['agent'], 'second_0')

def test_name(self):
env = MultiagentAtariEnv('pong_v1', device='cpu')
self.assertEqual(env.name, 'pong_v1')
env = MultiagentAtariEnv('pong_v2', device='cpu')
self.assertEqual(env.name, 'pong_v2')

def test_agent_iter(self):
env = MultiagentAtariEnv('pong_v1', device='cpu')
env = MultiagentAtariEnv('pong_v2', device='cpu')
env.reset()
it = iter(env.agent_iter())
self.assertEqual(next(it), 'first_0')

def test_state_spaces(self):
state_spaces = MultiagentAtariEnv('pong_v1', device='cpu').state_spaces
state_spaces = MultiagentAtariEnv('pong_v2', device='cpu').state_spaces
self.assertEqual(state_spaces['first_0'].shape, (1, 84, 84))
self.assertEqual(state_spaces['second_0'].shape, (1, 84, 84))

def test_action_spaces(self):
action_spaces = MultiagentAtariEnv('pong_v1', device='cpu').action_spaces
action_spaces = MultiagentAtariEnv('pong_v2', device='cpu').action_spaces
self.assertEqual(action_spaces['first_0'].n, 18)
self.assertEqual(action_spaces['second_0'].n, 18)

def test_list_agents(self):
env = MultiagentAtariEnv('pong_v1', device='cpu')
env = MultiagentAtariEnv('pong_v2', device='cpu')
self.assertEqual(env.agents, ['first_0', 'second_0'])

def test_is_done(self):
env = MultiagentAtariEnv('pong_v1', device='cpu')
env = MultiagentAtariEnv('pong_v2', device='cpu')
env.reset()
self.assertFalse(env.is_done('first_0'))
self.assertFalse(env.is_done('second_0'))

def test_last(self):
env = MultiagentAtariEnv('pong_v1', device='cpu')
env = MultiagentAtariEnv('pong_v2', device='cpu')
env.reset()
state = env.last()
self.assertEqual(state.observation.shape, (1, 84, 84))
Expand Down
4 changes: 2 additions & 2 deletions all/presets/multiagent_atari_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@

class TestMultiagentAtariPresets(unittest.TestCase):
def setUp(self):
self.env = MultiagentAtariEnv('pong_v1', device='cpu')
self.env = MultiagentAtariEnv('pong_v2', device='cpu')
self.env.reset()

def tearDown(self):
if os.path.exists('test_preset.pt'):
os.remove('test_preset.pt')

def test_independent(self):
env = MultiagentAtariEnv('pong_v1', device='cpu')
env = MultiagentAtariEnv('pong_v2', device='cpu')
presets = {
agent_id: dqn.device('cpu').env(env.subenvs[agent_id]).build()
for agent_id in env.agents
Expand Down
4 changes: 2 additions & 2 deletions integration/multiagent_atari_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@

class TestMultiagentAtariPresets(unittest.TestCase):
def test_independent(self):
env = MultiagentAtariEnv('pong_v1', max_cycles=1000, device=CPU)
env = MultiagentAtariEnv('pong_v2', max_cycles=1000, device=CPU)
presets = {
agent_id: dqn.device(CPU).env(env.subenvs[agent_id]).build()
for agent_id in env.agents
}
validate_multiagent(IndependentMultiagentPreset('independent', CPU, presets), env)

def test_independent_cuda(self):
env = MultiagentAtariEnv('pong_v1', max_cycles=1000, device=CUDA)
env = MultiagentAtariEnv('pong_v2', max_cycles=1000, device=CUDA)
presets = {
agent_id: dqn.device(CUDA).env(env.subenvs[agent_id]).build()
for agent_id in env.agents
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"pybullet>=3.0.6", # open-source robotics environments
],
"ma-atari": [
"PettingZoo[atari]>=1.5.0", # Multiagent atari environments
"PettingZoo[atari]>=1.9.0", # Multiagent atari environments
"supersuit>=2.4.0", # Multiagent env wrappers
"AutoROM>=0.1.19", # Tool for downloading ROMs
],
Expand Down

0 comments on commit 3cf3d5e

Please sign in to comment.