Skip to content

Commit

Permalink
add choice and mask of units to checkpoint (#397)
Browse files Browse the repository at this point in the history
* add choice and mask of units to checkpoint

* update

* fix bug

* remove device operation

* fix bug

* fix circle ci error

* fix error in numpy for circle ci

* fix bug in requirements

* restore

* add a note

* a new solution

* save mutable_channel.mask as float for dist training

* refine

* mv meta file test

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: jacky <jacky@xx.com>
  • Loading branch information
3 people committed Dec 21, 2022
1 parent b2d15ec commit ae1af1d
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 21 deletions.
7 changes: 6 additions & 1 deletion tests/test_metafiles.py → .dev_scripts/meta_files_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from pathlib import Path

import requests
Expand All @@ -8,7 +9,7 @@
MMRAZOR_ROOT = Path(__file__).absolute().parents[1]


class TestMetafiles:
class TestMetafiles(unittest.TestCase):

def get_metafiles(self, code_path):
"""
Expand Down Expand Up @@ -51,3 +52,7 @@ def test_metafiles(self):
assert model['Name'] == correct_name, \
f'name error in {metafile}, correct name should ' \
f'be {correct_name}'


if __name__ == '__main__':
unittest.main()
27 changes: 14 additions & 13 deletions mmrazor/models/algorithms/pruning/ite_prune_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,23 +204,24 @@ def forward(self,
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'tensor') -> ForwardResults:
"""Forward."""
if not hasattr(self, 'prune_config_manager'):
# self._iters_per_epoch() only available after initiation
self.prune_config_manager = self._init_prune_config_manager()

if self.prune_config_manager.is_prune_time(self._iter):
if self.training:
if not hasattr(self, 'prune_config_manager'):
# self._iters_per_epoch() only available after initiation
self.prune_config_manager = self._init_prune_config_manager()
if self.prune_config_manager.is_prune_time(self._iter):

config = self.prune_config_manager.prune_at(self._iter)
config = self.prune_config_manager.prune_at(self._iter)

self.mutator.set_choices(config)
self.mutator.set_choices(config)

logger = MMLogger.get_current_instance()
if (self.by_epoch):
logger.info(
f'The model is pruned at {self._epoch}th epoch once.')
else:
logger.info(
f'The model is pruned at {self._iter}th iter once.')
logger = MMLogger.get_current_instance()
if (self.by_epoch):
logger.info(
f'The model is pruned at {self._epoch}th epoch once.')
else:
logger.info(
f'The model is pruned at {self._iter}th iter once.')

return super().forward(inputs, data_samples, mode)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(self, num_channels: int, choice_mode='number', **kwargs):
super().__init__(num_channels, **kwargs)
assert choice_mode in ['ratio', 'number']
self.choice_mode = choice_mode
self.mask = torch.ones([self.num_channels]).bool()

@property
def is_num_mode(self):
Expand All @@ -50,14 +49,13 @@ def current_choice(self, choice: Union[int, float]):
int_choice = self._ratio2num(choice)
else:
int_choice = choice
mask = torch.zeros([self.num_channels], device=self.mask.device)
mask[0:int_choice] = 1
self.mask = mask.bool()
self.mask.fill_(0.0)
self.mask[0:int_choice] = 1.0

@property
def current_mask(self) -> torch.Tensor:
"""Return current mask."""
return self.mask
return self.mask.bool()

# methods for

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ class SimpleMutableChannel(BaseMutableChannel):

def __init__(self, num_channels: int, **kwargs) -> None:
super().__init__(num_channels, **kwargs)
self.mask = torch.ones(num_channels).bool()
mask = torch.ones([self.num_channels
]) # save bool as float for dist training
self.register_buffer('mask', mask)
self.mask: torch.Tensor

# choice

Expand All @@ -32,7 +35,7 @@ def current_choice(self) -> torch.Tensor:
@current_choice.setter
def current_choice(self, choice: torch.Tensor):
"""Set current choice."""
self.mask = choice.to(self.mask.device).bool()
self.mask = choice.to(self.mask.device).float()

@property
def current_mask(self) -> torch.Tensor:
Expand Down
1 change: 1 addition & 0 deletions requirements/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ interrogate
isort==4.3.21
nbconvert
nbformat
numpy < 1.24.0 # A temporary solution for tests with mmdet.
pytest
xdoctest >= 0.10.0
yapf
27 changes: 27 additions & 0 deletions tests/test_models/test_algorithms/test_prune_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,30 @@ def test_dist_init(self):
algorithm.forward(
data['inputs'], data['data_samples'], mode='loss')
self.assertEqual(algorithm.step_freq, epoch_step * iter_per_epoch)

def test_resume(self):
    """Check that pruning choices survive a state_dict save/load round-trip.

    This is the regression test for this commit: the mutator's channel
    masks/choices are now registered as buffers, so they must appear in
    ``state_dict()`` and be restored by ``load_state_dict()``.
    """
    # Build a fresh algorithm and give it a concrete (random) set of
    # pruning choices so the saved state is non-trivial.
    # NOTE(review): MODEL_CFG / MUTATOR_CONFIG_NUM / DEVICE are test-module
    # fixtures defined outside this diff — confirm against the full file.
    algorithm: ItePruneAlgorithm = ItePruneAlgorithm(
        MODEL_CFG,
        mutator_cfg=MUTATOR_CONFIG_NUM,
        target_pruning_ratio=None,
        step_freq=1,
        prune_times=1,
    ).to(DEVICE)
    algorithm.mutator.set_choices(algorithm.mutator.sample_choices())
    state_dict = algorithm.state_dict()
    print(state_dict.keys())

    # A second, independently-initialized instance: its choices start out
    # different from the first one's sampled choices.
    algorithm2: ItePruneAlgorithm = ItePruneAlgorithm(
        MODEL_CFG,
        mutator_cfg=MUTATOR_CONFIG_NUM,
        target_pruning_ratio=None,
        step_freq=1,
        prune_times=1,
    ).to(DEVICE)

    # Restoring the checkpoint must carry the choices over; before this
    # commit the masks were plain attributes and were silently dropped.
    algorithm2.load_state_dict(state_dict)

    print(algorithm.mutator.current_choices)
    print(algorithm2.mutator.current_choices)
    self.assertDictEqual(algorithm.mutator.current_choices,
                         algorithm2.mutator.current_choices)

0 comments on commit ae1af1d

Please sign in to comment.