Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@ omit =
*/python?.?/*
*/site-packages/nose/*
*__init__*
*/setup.py
*rllab*

exclude_lines =
if __name__ == .__main__.:
raise NotImplementedError
13 changes: 3 additions & 10 deletions tests/fast_tests/test_environment_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,11 +442,7 @@ def tearDown(self):

def test_get_state(self):
"""Checks that get_state raises an error."""
try:
self.env.get_state()
raise AssertionError
except NotImplementedError:
return
self.assertRaises(NotImplementedError, self.env.get_state)

def test_action_space(self):
try:
Expand All @@ -467,11 +463,8 @@ def test_compute_reward(self):
self.assertEqual(self.env.compute_reward([]), 0)

def test__apply_rl_actions(self):
try:
self.env._apply_rl_actions(None)
raise AssertionError
except NotImplementedError:
return
self.assertRaises(NotImplementedError, self.env._apply_rl_actions,
rl_actions=None)


class TestVehicleColoring(unittest.TestCase):
Expand Down
14 changes: 5 additions & 9 deletions tests/fast_tests/test_scenario_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,7 @@ def test_edges_distribution(self):
# check that all vehicles are only placed in edges specified in the
# edges_distribution term
for veh_id in self.env.vehicles.get_ids():
if self.env.vehicles.get_edge(veh_id) not in edges:
raise AssertionError
self.assertTrue(self.env.vehicles.get_edge(veh_id) in edges)

def test_num_vehicles(self):
"""
Expand Down Expand Up @@ -487,9 +486,8 @@ def test_lanes_distribution(self):

# verify that all vehicles are located in the number of allocated lanes
for veh_id in self.env.vehicles.get_ids():
if self.env.vehicles.get_lane(veh_id) >= \
initial_config.lanes_distribution:
raise AssertionError
self.assertLess(self.env.vehicles.get_lane(veh_id),
initial_config.lanes_distribution)

def test_edges_distribution(self):
"""
Expand All @@ -507,8 +505,7 @@ def test_edges_distribution(self):
# check that all vehicles are only placed in edges specified in the
# edges_distribution term
for veh_id in self.env.vehicles.get_ids():
if self.env.vehicles.get_edge(veh_id) not in edges:
raise AssertionError
self.assertTrue(self.env.vehicles.get_edge(veh_id) in edges)


class TestEvenStartPosVariableLanes(unittest.TestCase):
Expand Down Expand Up @@ -549,8 +546,7 @@ def test_even_start_pos_coverage(self):

# check that all possible lanes are covered
lanes = self.env.vehicles.get_lane(self.env.vehicles.get_ids())
if any(i not in lanes for i in range(4)):
raise AssertionError
self.assertFalse(any(i not in lanes for i in range(4)))


class TestRandomStartPosVariableLanes(TestEvenStartPosVariableLanes):
Expand Down
5 changes: 2 additions & 3 deletions tests/fast_tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,8 @@ def search_dicts(obj1, obj2):

# make sure that the Vehicles class that was imported matches the
# original one
if not search_dicts(imported_flow_params["veh"].__dict__,
flow_params["veh"].__dict__):
raise AssertionError
self.assertTrue(search_dicts(imported_flow_params["veh"].__dict__,
flow_params["veh"].__dict__))


if __name__ == '__main__':
Expand Down
30 changes: 15 additions & 15 deletions tests/fast_tests/test_vehicles.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class TestVehiclesClass(unittest.TestCase):
Tests various functions in the vehicles class
"""

def runSpeedLaneChangeModes(self):
def test_speed_lane_change_modes(self):
"""
Check to make sure vehicle class correctly specifies lane change and
speed modes
Expand All @@ -33,7 +33,7 @@ def runSpeedLaneChangeModes(self):
lane_change_mode="no_lat_collide")

self.assertEqual(vehicles.get_speed_mode("typeA_0"), 1)
self.assertEqual(vehicles.get_lane_change_mode("typeA_0"), 256)
self.assertEqual(vehicles.get_lane_change_mode("typeA_0"), 512)

vehicles.add(
"typeB",
Expand All @@ -42,7 +42,7 @@ def runSpeedLaneChangeModes(self):
lane_change_mode="strategic")

self.assertEqual(vehicles.get_speed_mode("typeB_0"), 0)
self.assertEqual(vehicles.get_lane_change_mode("typeB_0"), 853)
self.assertEqual(vehicles.get_lane_change_mode("typeB_0"), 1621)

vehicles.add(
"typeC",
Expand Down Expand Up @@ -145,18 +145,18 @@ def test_remove(self):
vehicles.remove("test_rl_0")

# ensure that the removed vehicle's ID is not in any lists of vehicles
if "test_0" in vehicles.get_ids():
raise AssertionError("vehicle still in get_ids()")
if "test_0" in vehicles.get_human_ids():
raise AssertionError("vehicle still in get_controlled_lc_ids()")
if "test_0" in vehicles.get_controlled_lc_ids():
raise AssertionError("vehicle still in get_controlled_lc_ids()")
if "test_0" in vehicles.get_controlled_ids():
raise AssertionError("vehicle still in get_controlled_ids()")
if "test_rl_0" in vehicles.get_ids():
raise AssertionError("RL vehicle still in get_ids()")
if "test_rl_0" in vehicles.get_rl_ids():
raise AssertionError("RL vehicle still in get_rl_ids()")
self.assertTrue("test_0" not in vehicles.get_ids(),
msg="vehicle still in get_ids()")
self.assertTrue("test_0" not in vehicles.get_human_ids(),
msg="vehicle still in get_controlled_lc_ids()")
self.assertTrue("test_0" not in vehicles.get_controlled_lc_ids(),
msg="vehicle still in get_controlled_lc_ids()")
self.assertTrue("test_0" not in vehicles.get_controlled_ids(),
msg="vehicle still in get_controlled_ids()")
self.assertTrue("test_rl_0" not in vehicles.get_ids(),
msg="RL vehicle still in get_ids()")
self.assertTrue("test_rl_0" not in vehicles.get_rl_ids(),
msg="RL vehicle still in get_rl_ids()")

# ensure that the vehicles are not storing extra information in the
# vehicles.__vehicles dict
Expand Down