Permalink
Browse files

switch to pytest (#495)

* switch to pytest

* remove observation space sampling

* fix test
  • Loading branch information...
joschu committed Feb 12, 2017
1 parent 33a6112 commit 6f4f5653de3cee3d92bcbaa08ebd3a102a864299
View
@@ -249,13 +249,11 @@ See the ``examples`` directory.
Testing
=======
We are using `nose2 <https://github.com/nose-devs/nose2>`_ for tests. You can run them via:
We are using `pytest <http://doc.pytest.org>`_ for tests. You can run them via:
.. code:: shell
nose2
You can also run tests in a specific directory by using the ``-s`` option, or by passing in the specific name of the test. See the `nose2 docs <http://nose2.readthedocs.org/en/latest/usage.html#naming-tests>`_ for more details.
pytest
What's new
----------
@@ -254,6 +254,8 @@ def _get_obs(self, pos=None):
pos = self.read_head_position
if pos < 0:
return self.base
if isinstance(pos, np.ndarray):
pos = pos.item()
try:
return self.input_data[pos]
except IndexError:
@@ -0,0 +1,25 @@
from gym import envs
import os
import logging
logger = logging.getLogger(__name__)
def should_skip_env_spec_for_tests(spec):
# We skip tests for envs that require dependencies or are otherwise
# troublesome to run frequently
ep = spec._entry_point
# Skip mujoco tests for pull request CI
skip_mujoco = not (os.environ.get('MUJOCO_KEY_BUNDLE') or os.path.exists(os.path.expanduser('~/.mujoco')))
if skip_mujoco and ep.startswith('gym.envs.mujoco:'):
return True
if ( spec.id.startswith("Go") or
spec.id.startswith("Hex") or
ep.startswith('gym.envs.box2d:') or
ep.startswith('gym.envs.parameter_tuning:') or
ep.startswith('gym.envs.safety:Semisuper') or
(ep.startswith("gym.envs.atari") and not spec.id.startswith("Pong"))
):
logger.warning("Skipping tests for env {}".format(ep))
return True
return False
spec_list = [spec for spec in sorted(envs.registry.all(), key=lambda x: x.id) if spec._entry_point is not None and not should_skip_env_spec_for_tests(spec)]
@@ -1,20 +1,14 @@
import numpy as np
from nose2 import tools
import pytest
import os
import logging
logger = logging.getLogger(__name__)
import gym
from gym import envs, spaces
from gym.envs.tests.spec_list import spec_list
from gym.envs.tests.test_envs import should_skip_env_spec_for_tests
specs = [spec for spec in sorted(envs.registry.all(), key=lambda x: x.id) if spec._entry_point is not None]
@tools.params(*specs)
@pytest.mark.parametrize("spec", spec_list)
def test_env(spec):
if should_skip_env_spec_for_tests(spec):
return
# Note that this precludes running this test in multiple
# threads. However, we probably already can't do multithreading
@@ -24,7 +18,6 @@ def test_env(spec):
env1 = spec.make()
env1.seed(0)
action_samples1 = [env1.action_space.sample() for i in range(4)]
observation_samples1 = [env1.observation_space.sample() for i in range(4)]
initial_observation1 = env1.reset()
step_responses1 = [env1.step(action) for action in action_samples1]
env1.close()
@@ -34,17 +27,13 @@ def test_env(spec):
env2 = spec.make()
env2.seed(0)
action_samples2 = [env2.action_space.sample() for i in range(4)]
observation_samples2 = [env2.observation_space.sample() for i in range(4)]
initial_observation2 = env2.reset()
step_responses2 = [env2.step(action) for action in action_samples2]
env2.close()
for i, (action_sample1, action_sample2) in enumerate(zip(action_samples1, action_samples2)):
assert_equals(action_sample1, action_sample2), '[{}] action_sample1: {}, action_sample2: {}'.format(i, action_sample1, action_sample2)
for (observation_sample1, observation_sample2) in zip(observation_samples1, observation_samples2):
assert_equals(observation_sample1, observation_sample2)
# Don't check rollout equality if it's a a nondeterministic
# environment.
if spec.nondeterministic:
@@ -1,46 +1,18 @@
import numpy as np
from nose2 import tools
import pytest
import os
import logging
logger = logging.getLogger(__name__)
import gym
from gym import envs
def should_skip_env_spec_for_tests(spec):
# We skip tests for envs that require dependencies or are otherwise
# troublesome to run frequently
ep = spec._entry_point
# Skip mujoco tests for pull request CI
skip_mujoco = not (os.environ.get('MUJOCO_KEY_BUNDLE') or os.path.exists(os.path.expanduser('~/.mujoco')))
if skip_mujoco and ep.startswith('gym.envs.mujoco:'):
return True
if ( spec.id.startswith("Go") or
spec.id.startswith("Hex") or
ep.startswith('gym.envs.box2d:') or
ep.startswith('gym.envs.parameter_tuning:') or
ep.startswith('gym.envs.safety:Semisuper') or
(ep.startswith("gym.envs.atari") and not spec.id.startswith("Pong"))
):
logger.warning("Skipping tests for env {}".format(ep))
return True
return False
from gym.envs.tests.spec_list import spec_list
# This runs a smoketest on each official registered env. We may want
# to try also running environments which are not officially registered
# envs.
specs = [spec for spec in sorted(envs.registry.all(), key=lambda x: x.id) if spec._entry_point is not None]
@tools.params(*specs)
@pytest.mark.parametrize("spec", spec_list)
def test_env(spec):
if should_skip_env_spec_for_tests(spec):
return
env = spec.make()
ob_space = env.observation_space
act_space = env.action_space
@@ -3,14 +3,11 @@
import hashlib
import os
import sys
from nose2 import tools
import logging
import pytest
logger = logging.getLogger(__name__)
from gym import envs, spaces
from gym.envs.tests.test_envs import should_skip_env_spec_for_tests
from gym.envs.tests.spec_list import spec_list
DATA_DIR = os.path.dirname(__file__)
ROLLOUT_STEPS = 100
@@ -62,14 +59,13 @@ def generate_rollout_hash(spec):
return observations_hash, actions_hash, rewards_hash, dones_hash
specs = [spec for spec in sorted(envs.registry.all(), key=lambda x: x.id) if spec._entry_point is not None]
@tools.params(*specs)
@pytest.mark.parametrize("spec", spec_list)
def test_env_semantics(spec):
with open(ROLLOUT_FILE) as data_file:
rollout_dict = json.load(data_file)
if spec.id not in rollout_dict or should_skip_env_spec_for_tests(spec):
if not spec.nondeterministic or should_skip_env_spec_for_tests(spec):
if spec.id not in rollout_dict:
if not spec.nondeterministic:
logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
return
@@ -2,9 +2,7 @@
import os
import shutil
import tempfile
import numpy as np
from nose2 import tools
import gym
from gym.monitoring import VideoRecorder
@@ -1,16 +1,16 @@
import json # note: ujson fails this test due to float equality
import numpy as np
from nose2 import tools
import pytest
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete
@tools.params(Discrete(3),
@pytest.mark.parametrize("space", [
Discrete(3),
Tuple([Discrete(5), Discrete(10)]),
Tuple([Discrete(5), Box(np.array([0,0]),np.array([1,5]))]),
Tuple((Discrete(5), Discrete(2), Discrete(2))),
MultiDiscrete([ [0, 1], [0, 1], [0, 100] ]),
)
MultiDiscrete([ [0, 1], [0, 1], [0, 100] ])
])
def test_roundtripping(space):
sample_1 = space.sample()
sample_2 = space.sample()
View
@@ -1,5 +1,5 @@
# Testing
nose2
pytest
mock
-e .[all]
View
@@ -36,5 +36,5 @@
],
extras_require=extras,
package_data={'gym': ['envs/mujoco/assets/*.xml', 'envs/classic_control/assets/*.png']},
tests_require=['nose2', 'mock'],
tests_require=['pytest', 'mock'],
)
View
@@ -10,7 +10,7 @@ envlist = py27, py34
whitelist_externals=make
passenv=DISPLAY TRAVIS*
deps =
nose2
pytest
mock
atari_py>=0.0.17
Pillow
@@ -27,13 +27,13 @@ deps =
six
pyglet>=1.2.0
commands =
nose2 {posargs}
pytest {posargs}
[testenv:py27]
whitelist_externals=make
passenv=DISPLAY TRAVIS*
deps =
nose2
pytest
mock
atari_py>=0.0.17
Pillow
@@ -50,4 +50,4 @@ deps =
six
pyglet>=1.2.0
commands =
nose2 {posargs}
pytest {posargs}

0 comments on commit 6f4f565

Please sign in to comment.