diff --git a/tests/test_metafiles.py b/.dev_scripts/meta_files_test.py similarity index 94% rename from tests/test_metafiles.py rename to .dev_scripts/meta_files_test.py index 8cf8d3a93..92c0f2f0d 100644 --- a/tests/test_metafiles.py +++ b/.dev_scripts/meta_files_test.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os +import unittest from pathlib import Path import requests @@ -8,7 +9,7 @@ MMRAZOR_ROOT = Path(__file__).absolute().parents[1] -class TestMetafiles: +class TestMetafiles(unittest.TestCase): def get_metafiles(self, code_path): """ @@ -51,3 +52,7 @@ def test_metafiles(self): assert model['Name'] == correct_name, \ f'name error in {metafile}, correct name should ' \ f'be {correct_name}' + + +if __name__ == '__main__': + unittest.main() diff --git a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py index 88d2e6067..057422290 100644 --- a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py +++ b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py @@ -204,23 +204,24 @@ def forward(self, data_samples: Optional[List[BaseDataElement]] = None, mode: str = 'tensor') -> ForwardResults: """Forward.""" - if not hasattr(self, 'prune_config_manager'): - # self._iters_per_epoch() only available after initiation - self.prune_config_manager = self._init_prune_config_manager() - if self.prune_config_manager.is_prune_time(self._iter): + if self.training: + if not hasattr(self, 'prune_config_manager'): + # self._iters_per_epoch() only available after initiation + self.prune_config_manager = self._init_prune_config_manager() + if self.prune_config_manager.is_prune_time(self._iter): - config = self.prune_config_manager.prune_at(self._iter) + config = self.prune_config_manager.prune_at(self._iter) - self.mutator.set_choices(config) + self.mutator.set_choices(config) - logger = MMLogger.get_current_instance() - if (self.by_epoch): - logger.info( - f'The model is pruned at {self._epoch}th epoch once.') - else: - logger.info( - f'The model is pruned at {self._iter}th iter once.') + logger = MMLogger.get_current_instance() + if (self.by_epoch): + logger.info( + f'The model is pruned at {self._epoch}th epoch once.') + else: + logger.info( + f'The model is pruned at {self._iter}th iter once.') return super().forward(inputs, data_samples, mode) diff --git a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py index 07b85f6c6..c2b4f9291 100644 --- a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py @@ -27,7 +27,6 @@ def __init__(self, num_channels: int, choice_mode='number', **kwargs): super().__init__(num_channels, **kwargs) assert choice_mode in ['ratio', 'number'] self.choice_mode = choice_mode - self.mask = torch.ones([self.num_channels]).bool() @property def is_num_mode(self): @@ -50,14 +49,13 @@ def current_choice(self, choice: Union[int, float]): int_choice = self._ratio2num(choice) else: int_choice = choice - mask = torch.zeros([self.num_channels], device=self.mask.device) - mask[0:int_choice] = 1 - self.mask = mask.bool() + self.mask.fill_(0.0) + self.mask[0:int_choice] = 1.0 @property def current_mask(self) -> torch.Tensor: """Return current mask.""" - return self.mask + return self.mask.bool() # methods for diff --git a/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py index dd3057fed..9e85f81a3 100644 --- a/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py @@ -20,7 +20,10 @@ class SimpleMutableChannel(BaseMutableChannel): def __init__(self, num_channels: int, **kwargs) -> None: super().__init__(num_channels, **kwargs) - self.mask = torch.ones(num_channels).bool() + mask = torch.ones([self.num_channels + ]) # save bool as float for dist training + self.register_buffer('mask', mask) + self.mask: torch.Tensor # choice @@ -32,7 +35,7 @@ def current_choice(self) -> torch.Tensor: @current_choice.setter def current_choice(self, choice: torch.Tensor): """Set current choice.""" - self.mask = choice.to(self.mask.device).bool() + self.mask = choice.to(self.mask.device).float() @property def current_mask(self) -> torch.Tensor: diff --git a/requirements/tests.txt b/requirements/tests.txt index 5dd7d144b..8763670ef 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -4,6 +4,7 @@ interrogate isort==4.3.21 nbconvert nbformat +numpy < 1.24.0 # A temporary solution for tests with mmdet. pytest xdoctest >= 0.10.0 yapf diff --git a/tests/test_models/test_algorithms/test_prune_algorithm.py b/tests/test_models/test_algorithms/test_prune_algorithm.py index d1524ead4..536da67fd 100644 --- a/tests/test_models/test_algorithms/test_prune_algorithm.py +++ b/tests/test_models/test_algorithms/test_prune_algorithm.py @@ -261,3 +261,30 @@ def test_dist_init(self): algorithm.forward( data['inputs'], data['data_samples'], mode='loss') self.assertEqual(algorithm.step_freq, epoch_step * iter_per_epoch) + + def test_resume(self): + algorithm: ItePruneAlgorithm = ItePruneAlgorithm( + MODEL_CFG, + mutator_cfg=MUTATOR_CONFIG_NUM, + target_pruning_ratio=None, + step_freq=1, + prune_times=1, + ).to(DEVICE) + algorithm.mutator.set_choices(algorithm.mutator.sample_choices()) + state_dict = algorithm.state_dict() + print(state_dict.keys()) + + algorithm2: ItePruneAlgorithm = ItePruneAlgorithm( + MODEL_CFG, + mutator_cfg=MUTATOR_CONFIG_NUM, + target_pruning_ratio=None, + step_freq=1, + prune_times=1, + ).to(DEVICE) + + algorithm2.load_state_dict(state_dict) + + print(algorithm.mutator.current_choices) + print(algorithm2.mutator.current_choices) + self.assertDictEqual(algorithm.mutator.current_choices, + algorithm2.mutator.current_choices)