From 6afafeafa284e55cd154aac5761245803c7cd163 Mon Sep 17 00:00:00 2001 From: aptsunny Date: Tue, 18 Oct 2022 16:56:30 +0800 Subject: [PATCH] redesign candidate --- mmrazor/engine/runner/evolution_search_loop.py | 5 ----- mmrazor/engine/runner/utils/check.py | 2 +- mmrazor/structures/subnet/candidate.py | 10 +++++++--- tests/test_runners/test_utils/test_check.py | 5 ----- 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index 6107a3fb6..4aea0f94d 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -348,8 +348,3 @@ def _check_constraints( constraints_range=self.constraints_range) return is_pass, results - - -if __name__ == '__main__': - import unittest - unittest.main() diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py index d094b0a47..203abfd88 100644 --- a/mmrazor/engine/runner/utils/check.py +++ b/mmrazor/engine/runner/utils/check.py @@ -22,7 +22,7 @@ def check_subnet_resources( """Check whether is beyond resources constraints. Returns: - bool: The result of checking. + bool, result: The result of checking. """ if constraints_range is None: return True, dict() diff --git a/mmrazor/structures/subnet/candidate.py b/mmrazor/structures/subnet/candidate.py index bba75c242..ca0c8236e 100644 --- a/mmrazor/structures/subnet/candidate.py +++ b/mmrazor/structures/subnet/candidate.py @@ -96,13 +96,13 @@ def _format(self, data: _format_input) -> _format_return: def _format_item( cond: Union[Dict, Dict[str, Dict]]) -> Dict[str, Dict]: """Transform Dict to Dict[str, Dict].""" - if isinstance(list(cond.values())[0], str): - return {str(cond): {}.fromkeys(self._indicators, -1)} - else: + if isinstance(list(cond.values())[0], dict): for value in list(cond.values()): for key in list(self._indicators): value.setdefault(key, 0.) return cond + else: + return {str(cond): {}.fromkeys(self._indicators, -1)} if isinstance(data, UserList): return [_format_item(i) for i in data.data] @@ -134,6 +134,10 @@ def extend(self, other: Any) -> None: else: self.data.extend([other]) + def set_score(self, i: int, score: float) -> None: + """Set score to the specified subnet by index.""" + self.set_resource(i, score, 'score') + def set_resource(self, i: int, resources: float, diff --git a/tests/test_runners/test_utils/test_check.py b/tests/test_runners/test_utils/test_check.py index 12d97c053..2f3a80eaa 100644 --- a/tests/test_runners/test_utils/test_check.py +++ b/tests/test_runners/test_utils/test_check.py @@ -42,8 +42,3 @@ def test_check_subnet_resources(mock_model, mock_estimator): is_pass, _ = check_subnet_resources(mock_model, fake_subnet, mock_estimator, constraints_range) assert is_pass is False - - -if __name__ == '__main__': - import unittest - unittest.main()