diff --git a/.coveragerc b/.coveragerc index a5f7fcee8..24dcb8d27 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,3 +3,9 @@ omit = */python?.?/* */site-packages/nose/* *__init__* + */setup.py + *rllab* + +exclude_lines = + if __name__ == .__main__.: + raise NotImplementedError diff --git a/tests/fast_tests/test_environment_base_class.py b/tests/fast_tests/test_environment_base_class.py index b10aab9cd..ee61c873d 100644 --- a/tests/fast_tests/test_environment_base_class.py +++ b/tests/fast_tests/test_environment_base_class.py @@ -442,11 +442,7 @@ def tearDown(self): def test_get_state(self): """Checks that get_state raises an error.""" - try: - self.env.get_state() - raise AssertionError - except NotImplementedError: - return + self.assertRaises(NotImplementedError, self.env.get_state) def test_action_space(self): try: @@ -467,11 +463,8 @@ def test_compute_reward(self): self.assertEqual(self.env.compute_reward([]), 0) def test__apply_rl_actions(self): - try: - self.env._apply_rl_actions(None) - raise AssertionError - except NotImplementedError: - return + self.assertRaises(NotImplementedError, self.env._apply_rl_actions, + rl_actions=None) class TestVehicleColoring(unittest.TestCase): diff --git a/tests/fast_tests/test_scenario_base_class.py b/tests/fast_tests/test_scenario_base_class.py index 785c388ca..7a05b5302 100644 --- a/tests/fast_tests/test_scenario_base_class.py +++ b/tests/fast_tests/test_scenario_base_class.py @@ -348,8 +348,7 @@ def test_edges_distribution(self): # check that all vehicles are only placed in edges specified in the # edges_distribution term for veh_id in self.env.vehicles.get_ids(): - if self.env.vehicles.get_edge(veh_id) not in edges: - raise AssertionError + self.assertTrue(self.env.vehicles.get_edge(veh_id) in edges) def test_num_vehicles(self): """ @@ -487,9 +486,8 @@ def test_lanes_distribution(self): # verify that all vehicles are located in the number of allocated lanes for veh_id in self.env.vehicles.get_ids(): - if self.env.vehicles.get_lane(veh_id) >= \ - initial_config.lanes_distribution: - raise AssertionError + self.assertLess(self.env.vehicles.get_lane(veh_id), + initial_config.lanes_distribution) def test_edges_distribution(self): """ @@ -507,8 +505,7 @@ def test_edges_distribution(self): # check that all vehicles are only placed in edges specified in the # edges_distribution term for veh_id in self.env.vehicles.get_ids(): - if self.env.vehicles.get_edge(veh_id) not in edges: - raise AssertionError + self.assertTrue(self.env.vehicles.get_edge(veh_id) in edges) class TestEvenStartPosVariableLanes(unittest.TestCase): @@ -549,8 +546,7 @@ def test_even_start_pos_coverage(self): # check that all possible lanes are covered lanes = self.env.vehicles.get_lane(self.env.vehicles.get_ids()) - if any(i not in lanes for i in range(4)): - raise AssertionError + self.assertFalse(any(i not in lanes for i in range(4))) class TestRandomStartPosVariableLanes(TestEvenStartPosVariableLanes): diff --git a/tests/fast_tests/test_util.py b/tests/fast_tests/test_util.py index 0fc8716a8..1406ecf9e 100644 --- a/tests/fast_tests/test_util.py +++ b/tests/fast_tests/test_util.py @@ -327,9 +327,8 @@ def search_dicts(obj1, obj2): # make sure that the Vehicles class that was imported matches the # original one - if not search_dicts(imported_flow_params["veh"].__dict__, - flow_params["veh"].__dict__): - raise AssertionError + self.assertTrue(search_dicts(imported_flow_params["veh"].__dict__, + flow_params["veh"].__dict__)) if __name__ == '__main__': diff --git a/tests/fast_tests/test_vehicles.py b/tests/fast_tests/test_vehicles.py index 59dab4229..39ee568f2 100644 --- a/tests/fast_tests/test_vehicles.py +++ b/tests/fast_tests/test_vehicles.py @@ -20,7 +20,7 @@ class TestVehiclesClass(unittest.TestCase): Tests various functions in the vehicles class """ - def runSpeedLaneChangeModes(self): + def test_speed_lane_change_modes(self): """ Check to make sure vehicle class correctly specifies lane change and speed modes @@ -33,7 +33,7 @@ def runSpeedLaneChangeModes(self): lane_change_mode="no_lat_collide") self.assertEqual(vehicles.get_speed_mode("typeA_0"), 1) - self.assertEqual(vehicles.get_lane_change_mode("typeA_0"), 256) + self.assertEqual(vehicles.get_lane_change_mode("typeA_0"), 512) vehicles.add( "typeB", @@ -42,7 +42,7 @@ def runSpeedLaneChangeModes(self): lane_change_mode="strategic") self.assertEqual(vehicles.get_speed_mode("typeB_0"), 0) - self.assertEqual(vehicles.get_lane_change_mode("typeB_0"), 853) + self.assertEqual(vehicles.get_lane_change_mode("typeB_0"), 1621) vehicles.add( "typeC", @@ -145,18 +145,18 @@ def test_remove(self): vehicles.remove("test_rl_0") # ensure that the removed vehicle's ID is not in any lists of vehicles - if "test_0" in vehicles.get_ids(): - raise AssertionError("vehicle still in get_ids()") - if "test_0" in vehicles.get_human_ids(): - raise AssertionError("vehicle still in get_controlled_lc_ids()") - if "test_0" in vehicles.get_controlled_lc_ids(): - raise AssertionError("vehicle still in get_controlled_lc_ids()") - if "test_0" in vehicles.get_controlled_ids(): - raise AssertionError("vehicle still in get_controlled_ids()") - if "test_rl_0" in vehicles.get_ids(): - raise AssertionError("RL vehicle still in get_ids()") - if "test_rl_0" in vehicles.get_rl_ids(): - raise AssertionError("RL vehicle still in get_rl_ids()") + self.assertTrue("test_0" not in vehicles.get_ids(), + msg="vehicle still in get_ids()") + self.assertTrue("test_0" not in vehicles.get_human_ids(), + msg="vehicle still in get_controlled_lc_ids()") + self.assertTrue("test_0" not in vehicles.get_controlled_lc_ids(), + msg="vehicle still in get_controlled_lc_ids()") + self.assertTrue("test_0" not in vehicles.get_controlled_ids(), + msg="vehicle still in get_controlled_ids()") + self.assertTrue("test_rl_0" not in vehicles.get_ids(), + msg="RL vehicle still in get_ids()") + self.assertTrue("test_rl_0" not in vehicles.get_rl_ids(), + msg="RL vehicle still in get_rl_ids()") # ensure that the vehicles are not storing extra information in the # vehicles.__vehicles dict