From b897cfdfcb2737653d15155979701271a8d2c80e Mon Sep 17 00:00:00 2001 From: wutongshenqiu <44188071+wutongshenqiu@users.noreply.github.com> Date: Mon, 8 Aug 2022 09:21:19 +0800 Subject: [PATCH 1/9] fix lint --- mmrazor/models/mutables/__init__.py | 5 +- mmrazor/models/mutables/derived_mutable.py | 196 ++++++++++++++++++ .../mutable_channel/mutable_channel.py | 7 +- .../one_shot_mutable_channel.py | 95 ++++++++- .../models/mutables/mutable_value/__init__.py | 4 + .../mutables/mutable_value/mutable_value.py | 126 +++++++++++ .../test_mutables/test_derived_mutable.py | 81 ++++++++ 7 files changed, 503 insertions(+), 11 deletions(-) create mode 100644 mmrazor/models/mutables/derived_mutable.py create mode 100644 mmrazor/models/mutables/mutable_value/__init__.py create mode 100644 mmrazor/models/mutables/mutable_value/mutable_value.py create mode 100644 tests/test_models/test_mutables/test_derived_mutable.py diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 94dce2a7d..123e597ae 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -1,12 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .derived_mutable import DerivedMutable from .mutable_channel import (MutableChannel, OneShotMutableChannel, SlimmableMutableChannel) from .mutable_manage_mixin import MutableManageMixIn from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP, OneShotMutableModule, OneShotMutableOP) +from .mutable_value import MutableValue, OneShotMutableValue __all__ = [ 'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP', 'DiffChoiceRoute', 'DiffMutableModule', 'MutableManageMixIn', - 'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel' + 'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel', + 'DerivedMutable', 'MutableValue', 'OneShotMutableValue' ] diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py new file mode 100644 index 000000000..6c75ce015 --- /dev/null +++ b/mmrazor/models/mutables/derived_mutable.py @@ -0,0 +1,196 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Callable, Dict, Iterable, Optional, Protocol + +import torch +from mmcls.models.utils import make_divisible +from torch import Tensor + +from .base_mutable import CHOICE_TYPE, BaseMutable + + +class MutableProtocol(Protocol): + + @property + def current_choice(self) -> Any: + ... + + def derive_expand_mutable(self, expand_ratio: int) -> Any: + ... + + def derive_divide_mutable(self, ratio: int, divisor: int) -> Any: + ... + + +class ChannelMutableProtocol(MutableProtocol): + + @property + def current_mask(self) -> Tensor: + ... + + +def _expand_choice_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable: + + def fn(): + return mutable.current_choice * expand_ratio + + return fn + + +def _expand_mask_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable: + if not hasattr(mutable, 'current_mask'): + raise ValueError('mutable must have attribute `currnet_mask`') + + def fn(): + mask = mutable.current_mask + expand_num_channels = mask.size(0) * expand_ratio + expand_choice = mutable.current_choice * expand_ratio + expand_mask = torch.zeros(expand_num_channels).bool() + expand_mask[:expand_choice] = True + + return expand_mask + + return fn + + +def _divide_and_divise(x: int, ratio: int, divisor: int = 8) -> int: + new_x = x // ratio + + return make_divisible(new_x, divisor) + + +def _divide_choice_fn(mutable: MutableProtocol, + ratio: int, + divisor: int = 8) -> Callable: + + def fn(): + return _divide_and_divise(mutable.current_choice, ratio, divisor) + + return fn + + +def _divide_mask_fn(mutable: MutableProtocol, + ratio: int, + divisor: int = 8) -> Callable: + if not hasattr(mutable, 'current_mask'): + raise ValueError('mutable must have attribute `currnet_mask`') + + def fn(): + mask = mutable.current_mask + divide_num_channels = _divide_and_divise(mask.size(0), ratio, divisor) + divide_choice = _divide_and_divise(mutable.current_choice, ratio, + divisor) + divide_mask = torch.zeros(divide_num_channels).bool() + divide_mask[:divide_choice] = True + + return divide_mask + + return fn + + +def _concat_choice_fn(mutables: Iterable[ChannelMutableProtocol]) -> Callable: + + def fn(): + return sum((m.current_choice for m in mutables)) + + return fn + + +def _concat_mask_fn(mutables: Iterable[ChannelMutableProtocol]) -> Callable: + for mutable in mutables: + if not hasattr(mutable, 'current_mask'): + raise ValueError('mutable must have attribute `currnet_mask`') + + def fn(): + return torch.cat([m.current_mask for m in mutables]) + + return fn + + +class DerivedMethodMixin: + + def derive_same_mutable(self: MutableProtocol) -> 'DerivedMutable': + return self.derive_expand_mutable(expand_ratio=1) + + def derive_expand_mutable(self: MutableProtocol, + expand_ratio: int) -> 'DerivedMutable': + choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio) + + mask_fn: Optional[Callable] = None + if hasattr(self, 'current_mask'): + mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio) + + return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) + + def derive_divide_mutable(self: MutableProtocol, + ratio: int, + divisor: int = 8) -> 'DerivedMutable': + choice_fn = _divide_choice_fn(self, ratio=ratio, divisor=divisor) + + mask_fn: Optional[Callable] = None + if hasattr(self, 'current_mask'): + mask_fn = _divide_mask_fn(self, ratio=ratio, divisor=divisor) + + return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) + + @staticmethod + def derive_concat_mutable( + mutables: Iterable[ChannelMutableProtocol]) -> 'DerivedMutable': + choice_fn = _concat_choice_fn(mutables) + mask_fn = _concat_mask_fn(mutables) + + return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) + + +class DerivedMutable(BaseMutable[CHOICE_TYPE, Dict], DerivedMethodMixin): + + def __init__(self, + choice_fn: Callable, + mask_fn: Optional[Callable] = None, + alias: Optional[str] = None, + init_cfg: Optional[Dict] = None) -> None: + super().__init__(alias, init_cfg) + + self.choice_fn = choice_fn + self.mask_fn = mask_fn + + # TODO + # has no effect + def fix_chosen(self, chosen: Dict) -> None: + if self.is_fixed: + raise RuntimeError('DerivedMutable can not be fixed twice') + + self.is_fixed = True + + def dump_chosen(self) -> Dict: + return dict(current_choice=self.current_choice) + + @property + def num_choices(self) -> int: + return 1 + + @property + def current_choice(self) -> CHOICE_TYPE: + return self.choice_fn() + + @current_choice.setter + def current_choice(self, choice: CHOICE_TYPE) -> None: + raise RuntimeError('Choice of drived mutable can not be set!') + + @property + def current_mask(self) -> Tensor: + if self.mask_fn is None: + raise RuntimeError( + '`mask_fn` must be set before access `current_mask`') + return self.mask_fn() + + # TODO + # should be __str__? but can not provide info when debug + def __repr__(self) -> str: + s = f'{self.__class__.__name__}(' + if self.choice_fn is not None: + s += f'current_choice={self.current_choice}, ' + if self.mask_fn is not None: + s += f'activated_mask_nums={self.current_mask.sum().item()}, ' + s += f'is_fixed={self.is_fixed})' + + return s diff --git a/mmrazor/models/mutables/mutable_channel/mutable_channel.py b/mmrazor/models/mutables/mutable_channel/mutable_channel.py index e0bbf62d9..a2d301b35 100644 --- a/mmrazor/models/mutables/mutable_channel/mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/mutable_channel.py @@ -5,9 +5,11 @@ import torch from ..base_mutable import CHOICE_TYPE, CHOSEN_TYPE, BaseMutable +from ..derived_mutable import DerivedMethodMixin -class MutableChannel(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE]): +class MutableChannel(BaseMutable[CHOICE_TYPE, CHOSEN_TYPE], + DerivedMethodMixin): """A type of ``MUTABLES`` for single path supernet such as AutoSlim. In single path supernet, each module only has one choice invoked at the same time. A path is obtained by sampling all the available choices. It is the @@ -102,9 +104,6 @@ def fix_chosen(self, chosen: CHOSEN_TYPE) -> None: 'The mode of current MUTABLE is `fixed`. ' 'Please do not call `fix_chosen` function again.') - # TODO - # should fixed op still have candidate_choices? - self._candidate_choices = [chosen] self._chosen = chosen self.is_fixed = True diff --git a/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py index 58327ecdf..810f7fee4 100644 --- a/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py @@ -1,15 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional, Union import numpy as np import torch from mmrazor.registry import MODELS +from ..derived_mutable import DerivedMutable +from ..mutable_value import OneShotMutableValue from .mutable_channel import MutableChannel @MODELS.register_module() -class OneShotMutableChannel(MutableChannel[int, int]): +class OneShotMutableChannel(MutableChannel[int, Dict]): """A type of ``MUTABLES`` for single path supernet such as AutoSlim. In single path supernet, each module only has one choice invoked at the same time. A path is obtained by sampling all the available choices. It is the @@ -36,7 +38,7 @@ class OneShotMutableChannel(MutableChannel[int, int]): def __init__(self, num_channels: int, - candidate_choices: List, + candidate_choices: List[Union[int, float]], candidate_mode: str = 'ratio', init_cfg: Optional[Dict] = None): super(OneShotMutableChannel, self).__init__( @@ -108,7 +110,7 @@ def current_choice(self, choice: int): self._current_choice = choice @property - def choices(self) -> List[int]: + def choices(self) -> List: """list: all choices. """ if self._candidate_mode == 'number': return self._candidate_choices @@ -129,5 +131,86 @@ def convert_choice_to_mask(self, choice: int) -> torch.Tensor: mask[:num_channels] = True return mask - def dump_chosen(self) -> int: - return self.current_choice + def dump_chosen(self) -> Dict: + assert self.current_choice is not None + + return dict( + current_choice=self.current_choice, + origin_channels=self.num_channels) + + def fix_chosen(self, dumped_chosen: Dict) -> None: + if self.is_fixed: + raise RuntimeError('OneShotMutableChannel can not be fixed twice') + + current_choice = dumped_chosen['current_choice'] + origin_channels = dumped_chosen['origin_channels'] + + assert current_choice <= origin_channels + assert origin_channels == self.num_channels + + self.current_choice = current_choice + self.is_fixed = True + + def __repr__(self): + concat_mutable_name = [ + mutable.name for mutable in self.concat_parent_mutables + ] + repr_str = self.__class__.__name__ + repr_str += f'(name={self.name}, ' + repr_str += f'num_channels={self.num_channels}, ' + repr_str += f'current_choice={self.current_choice}, ' + repr_str += f'choices={self.choices}, ' + repr_str += f'current_mask_shape={self.current_mask.shape}, ' + repr_str += f'concat_mutable_name={concat_mutable_name})' + return repr_str + + def __rmul__(self, other) -> DerivedMutable: + return self * other + + def __mul__(self, other) -> DerivedMutable: + if isinstance(other, int): + return self.derive_expand_mutable(other) + + def expand_choice_fn(mutable1: 'OneShotMutableChannel', + mutable2: OneShotMutableValue) -> Callable: + + def fn(): + return mutable1.current_choice * mutable2.current_choice + + return fn + + def expand_mask_fn(mutable1: 'OneShotMutableChannel', + mutable2: OneShotMutableValue) -> Callable: + + def fn(): + mask = mutable1.current_mask + max_expand_ratio = mutable2.max_choice + current_expand_ratio = mutable2.current_choice + expand_num_channels = mask.size(0) * max_expand_ratio + + expand_choice = mutable1.current_choice * current_expand_ratio + expand_mask = torch.zeros(expand_num_channels).bool() + expand_mask[:expand_choice] = True + + return expand_mask + + return fn + + if isinstance(other, OneShotMutableValue): + return DerivedMutable( + choice_fn=expand_choice_fn(self, other), + mask_fn=expand_mask_fn(self, other)) + + raise TypeError(f'Unsupported type {type(other)} for mul!') + + def __rdiv__(self, other) -> DerivedMutable: + return self / other + + def __div__(self, other) -> DerivedMutable: + if isinstance(other, int): + return self.derive_divide_mutable(other) + if isinstance(other, tuple): + assert len(other) == 2 + return self.derive_divide_mutable(*other) + + raise TypeError(f'Unsupported type {type(other)} for div!') diff --git a/mmrazor/models/mutables/mutable_value/__init__.py b/mmrazor/models/mutables/mutable_value/__init__.py new file mode 100644 index 000000000..f83c93fe9 --- /dev/null +++ b/mmrazor/models/mutables/mutable_value/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .mutable_value import MutableValue, OneShotMutableValue + +__all__ = ['MutableValue', 'OneShotMutableValue'] diff --git a/mmrazor/models/mutables/mutable_value/mutable_value.py b/mmrazor/models/mutables/mutable_value/mutable_value.py new file mode 100644 index 000000000..aed0f0ff9 --- /dev/null +++ b/mmrazor/models/mutables/mutable_value/mutable_value.py @@ -0,0 +1,126 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from typing import Any, Dict, List, Optional + +from mmrazor.registry import MODELS +from ..base_mutable import BaseMutable +from ..derived_mutable import DerivedMethodMixin, DerivedMutable + + +@MODELS.register_module() +class MutableValue(BaseMutable[Any, Dict], DerivedMethodMixin): + + def __init__(self, + value_list: List[Any], + default_value: Optional[Any] = None, + alias: Optional[str] = None, + init_cfg: Optional[Dict] = None) -> None: + super().__init__(alias, init_cfg) + + self._check_is_same_type(value_list) + self._value_list = value_list + + if default_value is None: + default_value = value_list[0] + self.current_choice = default_value + + @staticmethod + def _check_is_same_type(value_list: List[Any]) -> None: + if len(value_list) == 1: + return + + for i in range(1, len(value_list)): + is_same_type = type(value_list[i - 1]) is \ + type(value_list[i]) # noqa: E721 + if not is_same_type: + raise TypeError( + 'All elements in `value_list` must have same ' + f'type, but both types {type(value_list[i-1])} ' + f'and type {type(value_list[i])} exist.') + + @property + def choices(self) -> List[Any]: + return self._value_list + + def fix_chosen(self, chosen: Dict[str, Any]) -> None: + if self.is_fixed: + raise RuntimeError('MutableValue can not be fixed twice') + + all_choices = chosen['all_choices'] + current_choice = chosen['current_choice'] + + assert all_choices == self.choices, \ + f'Expect choices to be: {self.choices}, but got: {all_choices}' + assert current_choice in self.choices + + self.current_choice = current_choice + self.is_fixed = True + + def dump_chosen(self) -> Dict[str, Any]: + return dict( + current_choice=self.current_choice, all_choices=self.choices) + + def num_choices(self) -> int: + return len(self.choices) + + @property + def current_choice(self) -> Optional[Any]: + return self._current_choice + + @current_choice.setter + def current_choice(self, choice: Any) -> Any: + if choice not in self.choices: + raise ValueError(f'Expected choice in: {self.choices}, ' + f'but got: {choice}') + + self._current_choice = choice + + def __repr__(self) -> str: + s = self.__class__.__name__ + s += f'(value_list={self._value_list}, ' + s += f'current_choice={self.current_choice})' + + return s + + def __rmul__(self, other) -> DerivedMutable: + return self * other + + def __mul__(self, other) -> DerivedMutable: + if isinstance(other, int): + return self.derive_expand_mutable(other) + + raise TypeError(f'Unsupported type {type(other)} for mul!') + + +# TODO +# 1. use comparable for type hint +# 2. use mixin +@MODELS.register_module() +class OneShotMutableValue(MutableValue): + + def __init__(self, + value_list: List[Any], + default_value: Optional[Any] = None, + alias: Optional[str] = None, + init_cfg: Optional[Dict] = None) -> None: + value_list = sorted(value_list) + # set default value as max value + if default_value is None: + default_value = value_list[-1] + + super().__init__( + value_list=value_list, + default_value=default_value, + alias=alias, + init_cfg=init_cfg) + + def sample_choice(self) -> Any: + return random.choice(self.choices) + + @property + def max_choice(self) -> Any: + return self.choices[-1] + + @property + def min_choice(self) -> Any: + return self.choices[0] diff --git a/tests/test_models/test_mutables/test_derived_mutable.py b/tests/test_models/test_mutables/test_derived_mutable.py new file mode 100644 index 000000000..6885d47a6 --- /dev/null +++ b/tests/test_models/test_mutables/test_derived_mutable.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import pytest +import torch + +from mmrazor.models.mutables import (DerivedMutable, OneShotMutableChannel, + OneShotMutableValue) +from mmrazor.models.mutables.base_mutable import BaseMutable + + +class TestDerivedMutable(TestCase): + + def test_mutable_drived(self) -> None: + mv = OneShotMutableValue(value_list=[3, 5, 7]) + + mv_derived = mv * 4 + assert isinstance(mv_derived, BaseMutable) + assert isinstance(mv_derived, DerivedMutable) + assert not mv_derived.is_fixed + assert mv_derived.num_choices == 1 + + mv.current_choice = mv.max_choice + assert mv_derived.current_choice == 28 + mv.current_choice = mv.min_choice + assert mv_derived.current_choice == 12 + + with pytest.raises(RuntimeError): + mv_derived.current_choice = 123 + with pytest.raises(RuntimeError): + _ = mv_derived.current_mask + + chosen = mv_derived.dump_chosen() + assert chosen == {'current_choice': 12} + mv_derived.fix_chosen(chosen) + assert mv_derived.is_fixed + + mv.current_choice = 5 + assert mv_derived.current_choice == 20 + + def test_mutable_concat_derived(self) -> None: + mc1 = OneShotMutableChannel( + num_channels=3, candidate_choices=[1, 3], candidate_mode='number') + mc2 = OneShotMutableChannel( + num_channels=4, candidate_choices=[1, 4], candidate_mode='number') + ms = [mc1, mc2] + + mc_derived = DerivedMutable.derive_concat_mutable(ms) + + mc1.current_choice = 1 + mc2.current_choice = 4 + assert mc_derived.current_choice == 5 + assert torch.equal( + mc_derived.current_mask, + torch.tensor([1, 0, 0, 1, 1, 1, 1], dtype=torch.bool)) + + mc1.current_choice = 1 + mc2.current_choice = 1 + assert mc_derived.current_choice == 2 + assert torch.equal( + mc_derived.current_mask, + torch.tensor([1, 0, 0, 1, 0, 0, 0], dtype=torch.bool)) + + def test_mutable_channel_derived(self) -> None: + mc = OneShotMutableChannel( + num_channels=3, + candidate_choices=[1, 2, 3], + candidate_mode='number') + mc_derived = mc * 3 + + mc.current_choice = 1 + assert mc_derived.current_choice == 3 + assert torch.equal( + mc_derived.current_mask, + torch.tensor([1, 1, 1, 0, 0, 0, 0, 0, 0], dtype=torch.bool)) + + mc.current_choice = 2 + assert mc_derived.current_choice == 6 + assert torch.equal( + mc_derived.current_mask, + torch.tensor([1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=torch.bool)) From 2313ba262e68053aba655a1ee9092835bd1e4c71 Mon Sep 17 00:00:00 2001 From: wutongshenqiu <44188071+wutongshenqiu@users.noreply.github.com> Date: Mon, 8 Aug 2022 13:59:31 +0800 Subject: [PATCH 2/9] complement unittest for derived mutable --- mmrazor/models/mutables/derived_mutable.py | 76 ++++++++- .../mutable_channel/mutable_channel.py | 3 + .../one_shot_mutable_channel.py | 7 +- .../slimmable_mutable_channel.py | 1 + .../mutables/mutable_value/mutable_value.py | 9 + .../test_mutables/test_derived_mutable.py | 157 +++++++++++++++--- 6 files changed, 215 insertions(+), 38 deletions(-) diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py index 6c75ce015..73286ccd0 100644 --- a/mmrazor/models/mutables/derived_mutable.py +++ b/mmrazor/models/mutables/derived_mutable.py @@ -1,14 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Callable, Dict, Iterable, Optional, Protocol +import inspect +from collections.abc import Iterable +from typing import Any, Callable, Dict, Optional, Protocol, Set, Union import torch from mmcls.models.utils import make_divisible +from mmengine.logging import MMLogger from torch import Tensor from .base_mutable import CHOICE_TYPE, BaseMutable +logger = MMLogger.get_current_instance() -class MutableProtocol(Protocol): + +class MutableProtocol(Protocol): # pragma: no cover @property def current_choice(self) -> Any: @@ -21,7 +26,7 @@ def derive_divide_mutable(self, ratio: int, divisor: int) -> Any: ... -class ChannelMutableProtocol(MutableProtocol): +class ChannelMutableProtocol(MutableProtocol): # pragma: no cover @property def current_mask(self) -> Tensor: @@ -36,7 +41,8 @@ def fn(): return fn -def _expand_mask_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable: +def _expand_mask_fn(mutable: MutableProtocol, + expand_ratio: int) -> Callable: # pragma: no cover if not hasattr(mutable, 'current_mask'): raise ValueError('mutable must have attribute `currnet_mask`') @@ -70,7 +76,7 @@ def fn(): def _divide_mask_fn(mutable: MutableProtocol, ratio: int, - divisor: int = 8) -> Callable: + divisor: int = 8) -> Callable: # pragma: no cover if not hasattr(mutable, 'current_mask'): raise ValueError('mutable must have attribute `currnet_mask`') @@ -98,7 +104,7 @@ def fn(): def _concat_mask_fn(mutables: Iterable[ChannelMutableProtocol]) -> Callable: for mutable in mutables: if not hasattr(mutable, 'current_mask'): - raise ValueError('mutable must have attribute `currnet_mask`') + raise RuntimeError('mutable must have attribute `currnet_mask`') def fn(): return torch.cat([m.current_mask for m in mutables]) @@ -146,6 +152,7 @@ class DerivedMutable(BaseMutable[CHOICE_TYPE, Dict], DerivedMethodMixin): def __init__(self, choice_fn: Callable, mask_fn: Optional[Callable] = None, + source_mutables: Optional[Iterable[BaseMutable]] = None, alias: Optional[str] = None, init_cfg: Optional[Dict] = None) -> None: super().__init__(alias, init_cfg) @@ -153,6 +160,20 @@ def __init__(self, self.choice_fn = choice_fn self.mask_fn = mask_fn + if source_mutables is None: + source_mutables = self._find_source_mutables() + if len(source_mutables) == 0: + # TODO + # warning or raise error? + logger.warning('Can not find source mutables automatically') + else: + source_mutables = set(source_mutables) + for mutable in source_mutables: + if not self.is_source_mutable(mutable): + raise ValueError('Expect all mutable to be source mutable, ' + f'but {mutable} is not') + self.source_mutables = source_mutables + # TODO # has no effect def fix_chosen(self, chosen: Dict) -> None: @@ -183,14 +204,53 @@ def current_mask(self) -> Tensor: '`mask_fn` must be set before access `current_mask`') return self.mask_fn() + @staticmethod + def _extract_source_mutables_from_fn(fn: Callable) -> Set[BaseMutable]: + source_mutables: Set[BaseMutable] = set() + + def add_mutables_dfs( + mutable: Union[Iterable, BaseMutable, Dict]) -> None: + nonlocal source_mutables + if isinstance(mutable, BaseMutable): + if isinstance(mutable, DerivedMutable): + source_mutables |= mutable.source_mutables + else: + source_mutables.add(mutable) + # dict is also iterable, should parse first + elif isinstance(mutable, dict): + add_mutables_dfs(mutable.values()) + add_mutables_dfs(mutable.keys()) + elif isinstance(mutable, Iterable): + for m in mutable: + add_mutables_dfs(m) + + noncolcal_pars = inspect.getclosurevars(fn).nonlocals + add_mutables_dfs(noncolcal_pars.values()) + + return source_mutables + + def _find_source_mutables(self) -> Set[BaseMutable]: + source_mutables = self._extract_source_mutables_from_fn(self.choice_fn) + if self.mask_fn is not None: + source_mutables |= self._extract_source_mutables_from_fn( + self.mask_fn) + + return source_mutables + + @staticmethod + def is_source_mutable(mutable: BaseMutable) -> bool: + return isinstance(mutable, BaseMutable) and \ + not isinstance(mutable, DerivedMutable) + # TODO # should be __str__? but can not provide info when debug - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover s = f'{self.__class__.__name__}(' if self.choice_fn is not None: s += f'current_choice={self.current_choice}, ' if self.mask_fn is not None: - s += f'activated_mask_nums={self.current_mask.sum().item()}, ' + s += f'activated_channels={self.current_mask.sum().item()}, ' + s += f'source_mutables={self.source_mutables}, ' s += f'is_fixed={self.is_fixed})' return s diff --git a/mmrazor/models/mutables/mutable_channel/mutable_channel.py b/mmrazor/models/mutables/mutable_channel/mutable_channel.py index a2d301b35..f3ba2063e 100644 --- a/mmrazor/models/mutables/mutable_channel/mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/mutable_channel.py @@ -33,6 +33,7 @@ def __init__(self, num_channels: int, **kwargs): # outputs, we add the mutable out of these modules to the # `concat_parent_mutables` of this module. self.concat_parent_mutables: List[MutableChannel] = list() + self.name = 'unbind' @property def same_mutables(self): @@ -104,6 +105,8 @@ def fix_chosen(self, chosen: CHOSEN_TYPE) -> None: 'The mode of current MUTABLE is `fixed`. ' 'Please do not call `fix_chosen` function again.') + # TODO + # should fixed op still have candidate_choices? self._chosen = chosen self.is_fixed = True diff --git a/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py index 810f7fee4..8fde79a95 100644 --- a/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py @@ -160,7 +160,7 @@ def __repr__(self): repr_str += f'num_channels={self.num_channels}, ' repr_str += f'current_choice={self.current_choice}, ' repr_str += f'choices={self.choices}, ' - repr_str += f'current_mask_shape={self.current_mask.shape}, ' + repr_str += f'activated_channels={self.current_mask.sum().item()}, ' repr_str += f'concat_mutable_name={concat_mutable_name})' return repr_str @@ -203,10 +203,7 @@ def fn(): raise TypeError(f'Unsupported type {type(other)} for mul!') - def __rdiv__(self, other) -> DerivedMutable: - return self / other - - def __div__(self, other) -> DerivedMutable: + def __floordiv__(self, other) -> DerivedMutable: if isinstance(other, int): return self.derive_divide_mutable(other) if isinstance(other, tuple): diff --git a/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py index dda61814a..ebf8b41ef 100644 --- a/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py @@ -73,6 +73,7 @@ def fix_chosen(self, dumped_chosen: Dict) -> None: # TODO # remove after remove `current_choice` self.current_choice = self._candidate_choices.index(chosen) + self._candidate_choices = [chosen] super().fix_chosen(chosen) diff --git a/mmrazor/models/mutables/mutable_value/mutable_value.py b/mmrazor/models/mutables/mutable_value/mutable_value.py index aed0f0ff9..55e7a8801 100644 --- a/mmrazor/models/mutables/mutable_value/mutable_value.py +++ b/mmrazor/models/mutables/mutable_value/mutable_value.py @@ -91,6 +91,15 @@ def __mul__(self, other) -> DerivedMutable: raise TypeError(f'Unsupported type {type(other)} for mul!') + def __floordiv__(self, other) -> DerivedMutable: + if isinstance(other, int): + return self.derive_divide_mutable(other) + if isinstance(other, tuple): + assert len(other) == 2 + return self.derive_divide_mutable(*other) + + raise TypeError(f'Unsupported type {type(other)} for div!') + # TODO # 1. use comparable for type hint diff --git a/tests/test_models/test_mutables/test_derived_mutable.py b/tests/test_models/test_mutables/test_derived_mutable.py index 6885d47a6..efd32606e 100644 --- a/tests/test_models/test_mutables/test_derived_mutable.py +++ b/tests/test_models/test_mutables/test_derived_mutable.py @@ -11,32 +11,18 @@ class TestDerivedMutable(TestCase): - def test_mutable_drived(self) -> None: - mv = OneShotMutableValue(value_list=[3, 5, 7]) - - mv_derived = mv * 4 - assert isinstance(mv_derived, BaseMutable) - assert isinstance(mv_derived, DerivedMutable) - assert not mv_derived.is_fixed - assert mv_derived.num_choices == 1 - - mv.current_choice = mv.max_choice - assert mv_derived.current_choice == 28 - mv.current_choice = mv.min_choice - assert mv_derived.current_choice == 12 - - with pytest.raises(RuntimeError): - mv_derived.current_choice = 123 - with pytest.raises(RuntimeError): - _ = mv_derived.current_mask - - chosen = mv_derived.dump_chosen() - assert chosen == {'current_choice': 12} - mv_derived.fix_chosen(chosen) - assert mv_derived.is_fixed + def test_derived_same_mutable(self) -> None: + mc = OneShotMutableChannel( + num_channels=3, + candidate_choices=[1, 2, 3], + candidate_mode='number') + mc_derived = mc.derive_same_mutable() + assert mc_derived.source_mutables == {mc} - mv.current_choice = 5 - assert mv_derived.current_choice == 20 + mc.current_choice = 2 + assert mc_derived.current_choice == 2 + assert torch.equal(mc_derived.current_mask, + torch.tensor([1, 1, 0], dtype=torch.bool)) def test_mutable_concat_derived(self) -> None: mc1 = OneShotMutableChannel( @@ -46,6 +32,7 @@ def test_mutable_concat_derived(self) -> None: ms = [mc1, mc2] mc_derived = DerivedMutable.derive_concat_mutable(ms) + assert mc_derived.source_mutables == set(ms) mc1.current_choice = 1 mc2.current_choice = 4 @@ -61,12 +48,18 @@ def test_mutable_concat_derived(self) -> None: mc_derived.current_mask, torch.tensor([1, 0, 0, 1, 0, 0, 0], dtype=torch.bool)) + mv = OneShotMutableValue(value_list=[1, 2, 3]) + ms = [mc1, mv] + with pytest.raises(RuntimeError): + _ = DerivedMutable.derive_concat_mutable(ms) + def test_mutable_channel_derived(self) -> None: mc = OneShotMutableChannel( num_channels=3, candidate_choices=[1, 2, 3], candidate_mode='number') mc_derived = mc * 3 + assert mc_derived.source_mutables == {mc} mc.current_choice = 1 assert mc_derived.current_choice == 3 @@ -79,3 +72,117 @@ def test_mutable_channel_derived(self) -> None: assert torch.equal( mc_derived.current_mask, torch.tensor([1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=torch.bool)) + + def test_mutable_divide(self) -> None: + mc = OneShotMutableChannel( + num_channels=128, + candidate_choices=[112, 120, 128], + candidate_mode='number') + mc_derived = mc // 8 + assert mc_derived.source_mutables == {mc} + + mc.current_choice = 128 + assert mc_derived.current_choice == 16 + assert torch.equal(mc_derived.current_mask, + torch.ones(16, dtype=torch.bool)) + mc.current_choice = 120 + assert mc_derived.current_choice == 16 + assert torch.equal(mc_derived.current_mask, + torch.ones(16, dtype=torch.bool)) + + mv = OneShotMutableValue(value_list=[112, 120, 128]) + mv_derived = mv // 8 + assert mv_derived.source_mutables == {mv} + + mv.current_choice == 128 + assert mv_derived.current_choice == 16 + mv.current_choice == 120 + assert mv_derived.current_choice == 16 + + def test_double_fixed(self) -> None: + choice_fn = lambda x: x # noqa: E731 + derived_mutable = DerivedMutable(choice_fn) + derived_mutable.fix_chosen({}) + + with pytest.raises(RuntimeError): + derived_mutable.fix_chosen({}) + + def test_source_mutables(self) -> None: + mc1 = OneShotMutableChannel( + num_channels=3, candidate_choices=[1, 3], candidate_mode='number') + mc2 = OneShotMutableChannel( + num_channels=4, candidate_choices=[1, 4], candidate_mode='number') + ms = [mc1, mc2] + + mc_derived1 = DerivedMutable.derive_concat_mutable(ms) + + from mmrazor.models.mutables.derived_mutable import (_concat_choice_fn, + _concat_mask_fn) + mc_derived2 = DerivedMutable( + choice_fn=_concat_choice_fn(ms), + mask_fn=_concat_mask_fn(ms), + source_mutables=ms) + assert mc_derived1.source_mutables == mc_derived2.source_mutables + + dd_mutable = mc_derived1.derive_same_mutable() + assert dd_mutable.source_mutables == mc_derived1.source_mutables + + with pytest.raises(ValueError): + _ = DerivedMutable( + choice_fn=lambda x: x, source_mutables=[mc_derived1]) + + def dict_closure_fn(x, y): + + def fn(): + nonlocal x, y + + return fn + + ddd_mutable = DerivedMutable( + choice_fn=dict_closure_fn({ + mc1: [2, 3], + mc2: 2 + }, None), + mask_fn=dict_closure_fn({2: [mc1, mc2]}, {3: dd_mutable})) + assert ddd_mutable.source_mutables == mc_derived1.source_mutables + + mc3 = OneShotMutableChannel( + num_channels=4, candidate_choices=[2, 4], candidate_mode='number') + dddd_mutable = DerivedMutable( + choice_fn=dict_closure_fn({ + mc1: [2, 3], + mc2: 2 + }, []), + mask_fn=dict_closure_fn({2: [mc1, mc2, mc3]}, {3: dd_mutable})) + assert dddd_mutable.source_mutables == {mc1, mc2, mc3} + + +@pytest.mark.parametrize('expand_ratio', [1, 2, 3]) +def test_derived_expand_mutable(expand_ratio: int) -> None: + mv = OneShotMutableValue(value_list=[3, 5, 7]) + + mv_derived = mv * expand_ratio + assert mv_derived.source_mutables == {mv} + + assert isinstance(mv_derived, BaseMutable) + assert isinstance(mv_derived, DerivedMutable) + assert not mv_derived.is_fixed + assert mv_derived.num_choices == 1 + + mv.current_choice = mv.max_choice + assert mv_derived.current_choice == mv.current_choice * expand_ratio + mv.current_choice = mv.min_choice + assert mv_derived.current_choice == mv.current_choice * expand_ratio + + with pytest.raises(RuntimeError): + mv_derived.current_choice = 123 + with pytest.raises(RuntimeError): + _ = mv_derived.current_mask + + chosen = mv_derived.dump_chosen() + assert chosen == {'current_choice': mv.current_choice * expand_ratio} + mv_derived.fix_chosen(chosen) + assert mv_derived.is_fixed + + mv.current_choice = 5 + assert mv_derived.current_choice == 5 * expand_ratio From 20a06cedd94efbe1754051cf4b213dd82a81c45c Mon Sep 17 00:00:00 2001 From: wutongshenqiu <44188071+wutongshenqiu@users.noreply.github.com> Date: Mon, 8 Aug 2022 14:57:16 +0800 Subject: [PATCH 3/9] add docstring for derived mutable --- mmrazor/models/mutables/derived_mutable.py | 113 +++++++++++++++++++-- 1 file changed, 102 insertions(+), 11 deletions(-) diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py index 73286ccd0..9267ceecc 100644 --- a/mmrazor/models/mutables/derived_mutable.py +++ b/mmrazor/models/mutables/derived_mutable.py @@ -14,6 +14,7 @@ class MutableProtocol(Protocol): # pragma: no cover + """Protocol for Mutable.""" @property def current_choice(self) -> Any: @@ -26,7 +27,8 @@ def derive_divide_mutable(self, ratio: int, divisor: int) -> Any: ... -class ChannelMutableProtocol(MutableProtocol): # pragma: no cover +class MutableChannelProtocol(MutableProtocol): # pragma: no cover + """Protocol for MutableChannel.""" @property def current_mask(self) -> Tensor: @@ -34,6 +36,7 @@ def current_mask(self) -> Tensor: def _expand_choice_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable: + """Helper function to build `choice_fn` for expand derived mutable.""" def fn(): return mutable.current_choice * expand_ratio @@ -43,6 +46,7 @@ def fn(): def _expand_mask_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable: # pragma: no cover + """Helper function to build `mask_fn` for expand derived mutable.""" if not hasattr(mutable, 'current_mask'): raise ValueError('mutable must have attribute `currnet_mask`') @@ -59,6 +63,7 @@ def fn(): def _divide_and_divise(x: int, ratio: int, divisor: int = 8) -> int: + """Helper function for divide and divise.""" new_x = x // ratio return make_divisible(new_x, divisor) @@ -67,6 +72,7 @@ def _divide_and_divise(x: int, ratio: int, divisor: int = 8) -> int: def _divide_choice_fn(mutable: MutableProtocol, ratio: int, divisor: int = 8) -> Callable: + """Helper function to build `choice_fn` for divide derived mutable.""" def fn(): return _divide_and_divise(mutable.current_choice, ratio, divisor) @@ -77,6 +83,7 @@ def fn(): def _divide_mask_fn(mutable: MutableProtocol, ratio: int, divisor: int = 8) -> Callable: # pragma: no cover + """Helper function to build `mask_fn` for divide derived mutable.""" if not hasattr(mutable, 'current_mask'): raise ValueError('mutable must have attribute `currnet_mask`') @@ -93,7 +100,8 @@ def fn(): return fn -def _concat_choice_fn(mutables: Iterable[ChannelMutableProtocol]) -> Callable: +def _concat_choice_fn(mutables: Iterable[MutableChannelProtocol]) -> Callable: + """Helper function to build `choice_fn` for concat derived mutable.""" def fn(): return sum((m.current_choice for m in mutables)) @@ -101,7 +109,8 @@ def fn(): return fn -def _concat_mask_fn(mutables: Iterable[ChannelMutableProtocol]) -> Callable: +def _concat_mask_fn(mutables: Iterable[MutableChannelProtocol]) -> Callable: + """Helper function to build `mask_fn` for concat derived mutable.""" for mutable in mutables: if not hasattr(mutable, 'current_mask'): raise RuntimeError('mutable must have attribute `currnet_mask`') @@ -140,7 +149,7 @@ def derive_divide_mutable(self: MutableProtocol, @staticmethod def derive_concat_mutable( - mutables: Iterable[ChannelMutableProtocol]) -> 'DerivedMutable': + mutables: Iterable[MutableChannelProtocol]) -> 'DerivedMutable': choice_fn = _concat_choice_fn(mutables) mask_fn = _concat_mask_fn(mutables) @@ -148,6 +157,49 @@ def derive_concat_mutable( class DerivedMutable(BaseMutable[CHOICE_TYPE, Dict], DerivedMethodMixin): + """Class for derived mutable. + + A derived mutable is a mutable derived from other mutables that has + `current_choice` and `current_mask` attributes (if any). + + Note: + A derived mutable does not have its own search space, so it is + not legal to modify its `current_choice` or `current_mask` directly. + And the only way to modify them is by modifying `current_choice` or + `current_mask` in corresponding source mutables. + + Args: + choice_fn (callable): A closure that controls how to generate + `current_choice`. + mask_fn (callable, optional): A closure that controls how to generate + `current_mask`. Defaults to None. + source_mutables (iterable, optional): Specify source mutables for this + derived mutable. If the argument is None, source mutables will be + traced automatically by parsing mutables in closure variables. + Defaults to None. + alias (str, optional): alias of the `MUTABLE`. Defaults to None. + init_cfg (dict, optional): initialization configuration dict for + ``BaseModule``. OpenMMLab has implement 5 initializer including + `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, + and `Pretrained`. Defaults to None. + + Examples: + >>> from mmrazor.models.mutables import OneShotMutableChannel + >>> mutable_channel = OneShotMutableChannel( + ... num_channels=3, + ... candidate_choices=[1, 2, 3], + ... candidate_mode='number') + >>> # derive expand mutable + >>> derived_mutable_channel = mutable_channel * 2 + >>> # source mutables will be traced automatically + >>> derived_mutable_channel.source_mutables + {OneShotMutableChannel(name=unbind, num_channels=3, current_choice=3, choices=[1, 2, 3], activated_channels=3, concat_mutable_name=[])} # noqa: E501 + >>> # modify `current_choice` of `mutable_channel` + >>> mutable_channel.current_choice = 2 + >>> # `current_choice` and `current_mask` of derived mutable will be modified automatically # noqa: E501 + >>> derived_mutable_channel + DerivedMutable(current_choice=4, activated_channels=4, source_mutables={OneShotMutableChannel(name=unbind, num_channels=3, current_choice=2, choices=[1, 2, 3], activated_channels=2, concat_mutable_name=[])}, is_fixed=False) # noqa: E501 + """ def __init__(self, choice_fn: Callable, @@ -161,7 +213,7 @@ def __init__(self, self.mask_fn = mask_fn if source_mutables is None: - source_mutables = self._find_source_mutables() + source_mutables = self._trace_source_mutables() if len(source_mutables) == 0: # TODO # warning or raise error? @@ -177,35 +229,64 @@ def __init__(self, # TODO # has no effect def fix_chosen(self, chosen: Dict) -> None: + """Fix mutable with subnet config. + + Warning: + Fix derived mutable will have no actually effect. + """ if self.is_fixed: raise RuntimeError('DerivedMutable can not be fixed twice') self.is_fixed = True def dump_chosen(self) -> Dict: + """Dump information of chosen. + + Returns: + Dict: Dumped information. + """ return dict(current_choice=self.current_choice) @property def num_choices(self) -> int: + """Number of all choices. + + Note: + Since derive mutable does not have its own search space, the number + of choices will always be `1`. + + Returns: + int: Number of choices. + """ return 1 @property def current_choice(self) -> CHOICE_TYPE: + """Current choice of derived mutable.""" return self.choice_fn() @current_choice.setter def current_choice(self, choice: CHOICE_TYPE) -> None: + """Setter of current choice. + + Raises: + RuntimeError: Error when `current_choice` of derived mutable + is modified directly. + """ raise RuntimeError('Choice of drived mutable can not be set!') @property def current_mask(self) -> Tensor: + """Current mask of derived mutable.""" if self.mask_fn is None: raise RuntimeError( '`mask_fn` must be set before access `current_mask`') return self.mask_fn() @staticmethod - def _extract_source_mutables_from_fn(fn: Callable) -> Set[BaseMutable]: + def _trace_source_mutables_from_closure( + closure: Callable) -> Set[BaseMutable]: + """Trace source mutables from closure.""" source_mutables: Set[BaseMutable] = set() def add_mutables_dfs( @@ -224,21 +305,31 @@ def add_mutables_dfs( for m in mutable: add_mutables_dfs(m) - noncolcal_pars = inspect.getclosurevars(fn).nonlocals + noncolcal_pars = inspect.getclosurevars(closure).nonlocals add_mutables_dfs(noncolcal_pars.values()) return source_mutables - def _find_source_mutables(self) -> Set[BaseMutable]: - source_mutables = self._extract_source_mutables_from_fn(self.choice_fn) + def _trace_source_mutables(self) -> Set[BaseMutable]: + """Trace source mutables.""" + source_mutables = self._trace_source_mutables_from_closure( + self.choice_fn) if self.mask_fn is not None: - source_mutables |= self._extract_source_mutables_from_fn( + source_mutables |= self._trace_source_mutables_from_closure( self.mask_fn) return source_mutables @staticmethod - def is_source_mutable(mutable: BaseMutable) -> bool: + def is_source_mutable(mutable: object) -> bool: + """Judge whether an object is source mutable(not derived mutable). + + Args: + mutable (object): An object. + + Returns: + bool: Indicate whether the object is source mutable or not. + """ return isinstance(mutable, BaseMutable) and \ not isinstance(mutable, DerivedMutable) From 1504bf91c2bcfbaeb0f961da6e1ef68d89021085 Mon Sep 17 00:00:00 2001 From: wutongshenqiu <44188071+wutongshenqiu@users.noreply.github.com> Date: Mon, 8 Aug 2022 16:12:36 +0800 Subject: [PATCH 4/9] add unittest for mutable value --- mmrazor/models/mutables/derived_mutable.py | 13 ++- .../mutables/mutable_value/mutable_value.py | 104 +++++++++++++++-- .../test_mutables/test_mutable_value.py | 110 ++++++++++++++++++ 3 files changed, 213 insertions(+), 14 deletions(-) create mode 100644 tests/test_models/test_mutables/test_mutable_value.py diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py index 9267ceecc..180b8c1c4 100644 --- a/mmrazor/models/mutables/derived_mutable.py +++ b/mmrazor/models/mutables/derived_mutable.py @@ -18,13 +18,13 @@ class MutableProtocol(Protocol): # pragma: no cover @property def current_choice(self) -> Any: - ... + """Current choice.""" def derive_expand_mutable(self, expand_ratio: int) -> Any: - ... + """Derive expand mutable.""" def derive_divide_mutable(self, ratio: int, divisor: int) -> Any: - ... + """Derive divide mutable.""" class MutableChannelProtocol(MutableProtocol): # pragma: no cover @@ -32,7 +32,7 @@ class MutableChannelProtocol(MutableProtocol): # pragma: no cover @property def current_mask(self) -> Tensor: - ... + """Current mask.""" def _expand_choice_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable: @@ -122,12 +122,15 @@ def fn(): class DerivedMethodMixin: + """A mixin that provides some useful method to derive mutable.""" def derive_same_mutable(self: MutableProtocol) -> 'DerivedMutable': + """Derive same mutable as the source.""" return self.derive_expand_mutable(expand_ratio=1) def derive_expand_mutable(self: MutableProtocol, expand_ratio: int) -> 'DerivedMutable': + """Derive expand mutable, usually used with `expand_ratio`.""" choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio) mask_fn: Optional[Callable] = None @@ -139,6 +142,7 @@ def derive_expand_mutable(self: MutableProtocol, def derive_divide_mutable(self: MutableProtocol, ratio: int, divisor: int = 8) -> 'DerivedMutable': + """Derive divide mutable, usually used with `make_divisable`.""" choice_fn = _divide_choice_fn(self, ratio=ratio, divisor=divisor) mask_fn: Optional[Callable] = None @@ -150,6 +154,7 @@ def derive_divide_mutable(self: MutableProtocol, @staticmethod def derive_concat_mutable( mutables: Iterable[MutableChannelProtocol]) -> 'DerivedMutable': + """Derive concat mutable, usually used with `torch.cat`.""" choice_fn = _concat_choice_fn(mutables) mask_fn = _concat_mask_fn(mutables) diff --git a/mmrazor/models/mutables/mutable_value/mutable_value.py b/mmrazor/models/mutables/mutable_value/mutable_value.py index 55e7a8801..22b7b8d21 100644 --- a/mmrazor/models/mutables/mutable_value/mutable_value.py +++ b/mmrazor/models/mutables/mutable_value/mutable_value.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import random -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union from mmrazor.registry import MODELS from ..base_mutable import BaseMutable @@ -9,6 +9,21 @@ @MODELS.register_module() class MutableValue(BaseMutable[Any, Dict], DerivedMethodMixin): + """Base class for mutable value. + + A mutable value is actually a mutable that adds some functionality to a + list containing objects of the same type. + + Args: + value_list (list): List of value, each value must have the same type. + default_value (any, optional): Default value, must be one in + `value_list`. Default to None. + alias (str, optional): alias of the `MUTABLE`. + init_cfg (dict, optional): initialization configuration dict for + ``BaseModule``. OpenMMLab has implement 5 initializer including + `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, + and `Pretrained`. + """ def __init__(self, value_list: List[Any], @@ -26,6 +41,7 @@ def __init__(self, @staticmethod def _check_is_same_type(value_list: List[Any]) -> None: + """Check whether value in `value_list` has the same type.""" if len(value_list) == 1: return @@ -40,9 +56,15 @@ def _check_is_same_type(value_list: List[Any]) -> None: @property def choices(self) -> List[Any]: + """List of choices.""" return self._value_list def fix_chosen(self, chosen: Dict[str, Any]) -> None: + """Fix mutable value with subnet config. + + Args: + chosen (dict): the information of chosen. + """ if self.is_fixed: raise RuntimeError('MutableValue can not be fixed twice') @@ -57,41 +79,66 @@ def fix_chosen(self, chosen: Dict[str, Any]) -> None: self.is_fixed = True def dump_chosen(self) -> Dict[str, Any]: + """Dump information of chosen. + + Returns: + Dict[str, Any]: Dumped information. + """ return dict( current_choice=self.current_choice, all_choices=self.choices) + @property def num_choices(self) -> int: + """Number of all choices. + + Returns: + int: Number of choices. + """ return len(self.choices) @property def current_choice(self) -> Optional[Any]: + """Current choice of mutable value.""" return self._current_choice @current_choice.setter def current_choice(self, choice: Any) -> Any: + """Setter of current choice.""" if choice not in self.choices: raise ValueError(f'Expected choice in: {self.choices}, ' f'but got: {choice}') self._current_choice = choice - def __repr__(self) -> str: - s = self.__class__.__name__ - s += f'(value_list={self._value_list}, ' - s += f'current_choice={self.current_choice})' + def __rmul__(self, other: int) -> DerivedMutable: + """Please refer to method :func:`__mul__`.""" + return self * other - return s + def __mul__(self, other: int) -> DerivedMutable: + """Overload `*` operator. - def __rmul__(self, other) -> DerivedMutable: - return self * other + Args: + other (int): Expand ratio. - def __mul__(self, other) -> DerivedMutable: + Returns: + DerivedMutable: Derived expand mutable. + """ if isinstance(other, int): return self.derive_expand_mutable(other) raise TypeError(f'Unsupported type {type(other)} for mul!') - def __floordiv__(self, other) -> DerivedMutable: + def __floordiv__(self, other: Union[int, Tuple[int, + int]]) -> DerivedMutable: + """Overload `//` operator. + + Args: + other: (int, tuple): divide ratio for int or + (divide ratio, divisor) for tuple. + + Returns: + DerivedMutable: Derived divide mutable. + """ if isinstance(other, int): return self.derive_divide_mutable(other) if isinstance(other, tuple): @@ -100,12 +147,34 @@ def __floordiv__(self, other) -> DerivedMutable: raise TypeError(f'Unsupported type {type(other)} for div!') + def __repr__(self) -> str: + s = self.__class__.__name__ + s += f'(value_list={self._value_list}, ' + s += f'current_choice={self.current_choice})' + + return s + # TODO # 1. use comparable for type hint # 2. use mixin @MODELS.register_module() class OneShotMutableValue(MutableValue): + """Class for one-shot mutable value. + + one-shot mutable value provides `sample_choice` method and `min_choice`, + `max_choice` properties on the top of mutable value. + + Args: + value_list (list): List of value, each value must have the same type. + default_value (any, optional): Default value, must be one in + `value_list`. Default to None. + alias (str, optional): alias of the `MUTABLE`. + init_cfg (dict, optional): initialization configuration dict for + ``BaseModule``. OpenMMLab has implement 5 initializer including + `Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`, + and `Pretrained`. + """ def __init__(self, value_list: List[Any], @@ -124,12 +193,27 @@ def __init__(self, init_cfg=init_cfg) def sample_choice(self) -> Any: + """Random sampling from choices. + + Returns: + Any: Selected choice. + """ return random.choice(self.choices) @property def max_choice(self) -> Any: + """Max choice of all choices. + + Returns: + Any: Max choice. + """ return self.choices[-1] @property def min_choice(self) -> Any: + """Min choice of all choices. + + Returns: + Any: Min choice. + """ return self.choices[0] diff --git a/tests/test_models/test_mutables/test_mutable_value.py b/tests/test_models/test_mutables/test_mutable_value.py new file mode 100644 index 000000000..9003ed693 --- /dev/null +++ b/tests/test_models/test_mutables/test_mutable_value.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from unittest import TestCase + +import pytest + +from mmrazor.models.mutables import MutableValue, OneShotMutableValue + + +class TestMutableValue(TestCase): + + def test_init_mutable_value(self) -> None: + value_list = [2, 4, 6] + mv = MutableValue(value_list=value_list) + assert mv.current_choice == 2 + assert mv.num_choices == 3 + + mv = MutableValue(value_list=value_list, default_value=4) + assert mv.current_choice == 4 + + with pytest.raises(ValueError): + mv = MutableValue(value_list=value_list, default_value=5) + + mv = MutableValue(value_list=[2]) + assert mv.current_choice == 2 + assert mv.choices == [2] + + with pytest.raises(TypeError): + mv = MutableValue(value_list=[2, 3.2]) + + def test_init_one_shot_mutable_value(self) -> None: + value_list = [6, 4, 2] + mv = OneShotMutableValue(value_list=value_list) + assert mv.current_choice == 6 + assert mv.choices == [2, 4, 6] + + mv = OneShotMutableValue(value_list=value_list, default_value=4) + assert mv.current_choice == 4 + + def test_fix_chosen(self) -> None: + mv = MutableValue([2, 3, 4]) + chosen = mv.dump_chosen() + assert chosen == { + 'current_choice': mv.current_choice, + 'all_choices': mv.choices + } + + chosen['current_choice'] = 5 + with pytest.raises(AssertionError): + mv.fix_chosen(chosen) + + chosen_copied = copy.deepcopy(chosen) + chosen_copied['all_choices'] = [1, 2, 3] + with pytest.raises(AssertionError): + mv.fix_chosen(chosen_copied) + + chosen['current_choice'] = 3 + mv.fix_chosen(chosen) + assert mv.current_choice == 3 + + with pytest.raises(RuntimeError): + mv.fix_chosen(chosen) + + def test_one_shot_mutable_value_sample(self) -> None: + mv = OneShotMutableValue(value_list=[2, 3, 4]) + assert mv.max_choice == 4 + assert mv.min_choice == 2 + + for _ in range(100): + assert mv.sample_choice() in mv.choices + + def test_mul(self) -> None: + mv = MutableValue(value_list=[1, 2, 3], default_value=3) + mul_derived_mv = mv * 2 + rmul_derived_mv = 2 * mv + + assert mul_derived_mv.current_choice == 6 + assert rmul_derived_mv.current_choice == 6 + + mv.current_choice = 2 + assert mul_derived_mv.current_choice == 4 + assert rmul_derived_mv.current_choice == 4 + + with pytest.raises(TypeError): + _ = mv * 1.2 + + def test_floordiv(self) -> None: + mv = MutableValue(value_list=[120, 128, 136]) + derived_mv = mv // 8 + + mv.current_choice = 120 + assert derived_mv.current_choice == 16 + mv.current_choice = 128 + assert derived_mv.current_choice == 16 + + derived_mv = mv // (8, 3) + mv.current_choice = 120 + assert derived_mv.current_choice == 15 + mv.current_choice = 136 + assert derived_mv.current_choice == 18 + + with pytest.raises(TypeError): + _ = mv // 1.2 + + def test_repr(self) -> None: + value_list = [2, 4, 6] + mv = MutableValue(value_list=value_list) + + assert repr(mv) == \ + f'MutableValue(value_list={value_list}, current_choice=2)' From eaf57a1e6752e092fbbdf1622e94d17cca27230d Mon Sep 17 00:00:00 2001 From: wutongshenqiu <44188071+wutongshenqiu@users.noreply.github.com> Date: Tue, 9 Aug 2022 13:30:08 +0800 Subject: [PATCH 5/9] fix logger error --- mmrazor/models/mutables/derived_mutable.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py index 180b8c1c4..96e849a74 100644 --- a/mmrazor/models/mutables/derived_mutable.py +++ b/mmrazor/models/mutables/derived_mutable.py @@ -1,17 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect +import logging from collections.abc import Iterable from typing import Any, Callable, Dict, Optional, Protocol, Set, Union import torch from mmcls.models.utils import make_divisible -from mmengine.logging import MMLogger +from mmengine.logging import print_log from torch import Tensor from .base_mutable import CHOICE_TYPE, BaseMutable -logger = MMLogger.get_current_instance() - class MutableProtocol(Protocol): # pragma: no cover """Protocol for Mutable.""" @@ -222,7 +221,9 @@ def __init__(self, if len(source_mutables) == 0: # TODO # warning or raise error? - logger.warning('Can not find source mutables automatically') + print_log( + 'Can not find source mutables automatically', + level=logging.WARNING) else: source_mutables = set(source_mutables) for mutable in source_mutables: From 953dfb705856c7300559b1c3aca7aa11e1d7df8d Mon Sep 17 00:00:00 2001 From: wutongshenqiu <44188071+wutongshenqiu@users.noreply.github.com> Date: Wed, 10 Aug 2022 10:17:24 +0800 Subject: [PATCH 6/9] fix according to comments --- mmrazor/models/mutables/derived_mutable.py | 46 +++++++++++-------- mmrazor/models/utils/__init__.py | 5 +- mmrazor/models/utils/make_divisible.py | 31 +++++++++++++ .../test_mutables/test_derived_mutable.py | 45 +++++++++++++++++- 4 files changed, 105 insertions(+), 22 deletions(-) create mode 100644 mmrazor/models/utils/make_divisible.py diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py index 96e849a74..cf944fdb6 100644 --- a/mmrazor/models/mutables/derived_mutable.py +++ b/mmrazor/models/mutables/derived_mutable.py @@ -1,14 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect -import logging from collections.abc import Iterable from typing import Any, Callable, Dict, Optional, Protocol, Set, Union import torch -from mmcls.models.utils import make_divisible -from mmengine.logging import print_log from torch import Tensor +from ..utils import make_divisible from .base_mutable import CHOICE_TYPE, BaseMutable @@ -110,9 +108,6 @@ def fn(): def _concat_mask_fn(mutables: Iterable[MutableChannelProtocol]) -> Callable: """Helper function to build `mask_fn` for concat derived mutable.""" - for mutable in mutables: - if not hasattr(mutable, 'current_mask'): - raise RuntimeError('mutable must have attribute `currnet_mask`') def fn(): return torch.cat([m.current_mask for m in mutables]) @@ -154,13 +149,19 @@ def derive_divide_mutable(self: MutableProtocol, def derive_concat_mutable( mutables: Iterable[MutableChannelProtocol]) -> 'DerivedMutable': """Derive concat mutable, usually used with `torch.cat`.""" + for mutable in mutables: + if not hasattr(mutable, 'current_mask'): + raise RuntimeError('Source mutable of concat derived mutable ' + 'must have attribute `currnet_mask`') + choice_fn = _concat_choice_fn(mutables) mask_fn = _concat_mask_fn(mutables) return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn) -class DerivedMutable(BaseMutable[CHOICE_TYPE, Dict], DerivedMethodMixin): +class DerivedMutable(BaseMutable[CHOICE_TYPE, CHOICE_TYPE], + DerivedMethodMixin): """Class for derived mutable. A derived mutable is a mutable derived from other mutables that has @@ -219,11 +220,9 @@ def __init__(self, if source_mutables is None: source_mutables = self._trace_source_mutables() if len(source_mutables) == 0: - # TODO - # warning or raise error? - print_log( - 'Can not find source mutables automatically', - level=logging.WARNING) + raise RuntimeError( + 'Can not find source mutables automatically, ' + 'please provide manually.') else: source_mutables = set(source_mutables) for mutable in source_mutables: @@ -234,7 +233,7 @@ def __init__(self, # TODO # has no effect - def fix_chosen(self, chosen: Dict) -> None: + def fix_chosen(self, chosen: CHOICE_TYPE) -> None: """Fix mutable with subnet config. Warning: @@ -245,13 +244,13 @@ def fix_chosen(self, chosen: Dict) -> None: self.is_fixed = True - def dump_chosen(self) -> Dict: + def dump_chosen(self) -> CHOICE_TYPE: """Dump information of chosen. Returns: Dict: Dumped information. """ - return dict(current_choice=self.current_choice) + return self.current_choice @property def num_choices(self) -> int: @@ -279,16 +278,26 @@ def current_choice(self, choice: CHOICE_TYPE) -> None: RuntimeError: Error when `current_choice` of derived mutable is modified directly. """ - raise RuntimeError('Choice of drived mutable can not be set!') + raise RuntimeError('Choice of drived mutable can not be set.') @property def current_mask(self) -> Tensor: """Current mask of derived mutable.""" if self.mask_fn is None: raise RuntimeError( - '`mask_fn` must be set before access `current_mask`') + '`mask_fn` must be set before access `current_mask`.') return self.mask_fn() + @current_mask.setter + def current_mask(self, mask: Tensor) -> None: + """Setter of current mask. + + Raises: + RuntimeError: Error when `current_mask` of derived mutable + is modified directly. + """ + raise RuntimeError('Mask of drived mutable can not be set.') + @staticmethod def _trace_source_mutables_from_closure( closure: Callable) -> Set[BaseMutable]: @@ -343,8 +352,7 @@ def is_source_mutable(mutable: object) -> bool: # should be __str__? but can not provide info when debug def __repr__(self) -> str: # pragma: no cover s = f'{self.__class__.__name__}(' - if self.choice_fn is not None: - s += f'current_choice={self.current_choice}, ' + s += f'current_choice={self.current_choice}, ' if self.mask_fn is not None: s += f'activated_channels={self.current_mask.sum().item()}, ' s += f'source_mutables={self.source_mutables}, ' diff --git a/mmrazor/models/utils/__init__.py b/mmrazor/models/utils/__init__.py index 7a477f6dd..fd83be434 100644 --- a/mmrazor/models/utils/__init__.py +++ b/mmrazor/models/utils/__init__.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .make_divisible import make_divisible from .misc import add_prefix from .optim_wrapper import reinitialize_optim_wrapper_count_status -__all__ = ['add_prefix', 'reinitialize_optim_wrapper_count_status'] +__all__ = [ + 'add_prefix', 'reinitialize_optim_wrapper_count_status', 'make_divisible' +] diff --git a/mmrazor/models/utils/make_divisible.py b/mmrazor/models/utils/make_divisible.py new file mode 100644 index 000000000..5056aeb15 --- /dev/null +++ b/mmrazor/models/utils/make_divisible.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + + +def make_divisible(value: int, + divisor: int, + min_value: Optional[int] = None, + min_ratio: float = 0.9) -> int: + """Make divisible function. + + This function rounds the channel number down to the nearest value that can + be divisible by the divisor. + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int, optional): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float): The minimum ratio of the rounded channel + number to the original channel number. Default: 0.9. + Returns: + int: The modified output channel number + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/tests/test_models/test_mutables/test_derived_mutable.py b/tests/test_models/test_mutables/test_derived_mutable.py index efd32606e..d491a4f81 100644 --- a/tests/test_models/test_mutables/test_derived_mutable.py +++ b/tests/test_models/test_mutables/test_derived_mutable.py @@ -73,6 +73,10 @@ def test_mutable_channel_derived(self) -> None: mc_derived.current_mask, torch.tensor([1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=torch.bool)) + with pytest.raises(RuntimeError): + mc_derived.current_mask = torch.ones( + mc_derived.current_mask.size()) + def test_mutable_divide(self) -> None: mc = OneShotMutableChannel( num_channels=128, @@ -101,13 +105,17 @@ def test_mutable_divide(self) -> None: def test_double_fixed(self) -> None: choice_fn = lambda x: x # noqa: E731 - derived_mutable = DerivedMutable(choice_fn) + derived_mutable = DerivedMutable(choice_fn, source_mutables=[]) derived_mutable.fix_chosen({}) with pytest.raises(RuntimeError): derived_mutable.fix_chosen({}) def test_source_mutables(self) -> None: + useless_fn = lambda x: x # noqa: E731 + with pytest.raises(RuntimeError): + _ = DerivedMutable(choice_fn=useless_fn) + mc1 = OneShotMutableChannel( num_channels=3, candidate_choices=[1, 3], candidate_mode='number') mc2 = OneShotMutableChannel( @@ -156,6 +164,39 @@ def fn(): mask_fn=dict_closure_fn({2: [mc1, mc2, mc3]}, {3: dd_mutable})) assert dddd_mutable.source_mutables == {mc1, mc2, mc3} + def test_nested_mutables(self) -> None: + source_a = OneShotMutableChannel( + num_channels=2, candidate_choices=[1, 2], candidate_mode='number') + source_b = OneShotMutableChannel( + num_channels=3, candidate_choices=[2, 3], candidate_mode='number') + + # derive from + derived_c = source_a * 1 + concat_mutables = [source_b, derived_c] + derived_d = DerivedMutable.derive_concat_mutable(concat_mutables) + concat_mutables = [derived_c, derived_d] + derived_e = DerivedMutable.derive_concat_mutable(concat_mutables) + + assert derived_c.source_mutables == {source_a} + assert derived_d.source_mutables == {source_a, source_b} + assert derived_e.source_mutables == {source_a, source_b} + + source_a.current_choice = 1 + source_b.current_choice = 3 + + assert derived_c.current_choice == 1 + assert torch.equal(derived_c.current_mask, + torch.tensor([1, 0], dtype=torch.bool)) + + assert derived_d.current_choice == 4 + assert torch.equal(derived_d.current_mask, + torch.tensor([1, 1, 1, 1, 0], dtype=torch.bool)) + + assert derived_e.current_choice == 5 + assert torch.equal( + derived_e.current_mask, + torch.tensor([1, 0, 1, 1, 1, 1, 0], dtype=torch.bool)) + @pytest.mark.parametrize('expand_ratio', [1, 2, 3]) def test_derived_expand_mutable(expand_ratio: int) -> None: @@ -180,7 +221,7 @@ def test_derived_expand_mutable(expand_ratio: int) -> None: _ = mv_derived.current_mask chosen = mv_derived.dump_chosen() - assert chosen == {'current_choice': mv.current_choice * expand_ratio} + assert chosen == mv.current_choice * expand_ratio mv_derived.fix_chosen(chosen) assert mv_derived.is_fixed From fd7e4f9df611d42993da75c5202cfc33ae079793 Mon Sep 17 00:00:00 2001 From: wutongshenqiu <44188071+wutongshenqiu@users.noreply.github.com> Date: Wed, 10 Aug 2022 13:04:34 +0800 Subject: [PATCH 7/9] not dump derived mutable when export --- mmrazor/models/mutables/derived_mutable.py | 29 ++++++++++-- mmrazor/structures/subnet/fix_subnet.py | 16 +++++-- .../test_mutables/test_derived_mutable.py | 47 ++++++++++++++----- .../test_subnet/test_fix_subnet.py | 38 ++++++++++++++- 4 files changed, 108 insertions(+), 22 deletions(-) diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py index cf944fdb6..7cef44dcd 100644 --- a/mmrazor/models/mutables/derived_mutable.py +++ b/mmrazor/models/mutables/derived_mutable.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect +import logging from collections.abc import Iterable from typing import Any, Callable, Dict, Optional, Protocol, Set, Union import torch +from mmengine.logging import print_log from torch import Tensor from ..utils import make_divisible @@ -239,10 +241,10 @@ def fix_chosen(self, chosen: CHOICE_TYPE) -> None: Warning: Fix derived mutable will have no actually effect. """ - if self.is_fixed: - raise RuntimeError('DerivedMutable can not be fixed twice') - - self.is_fixed = True + print_log( + 'Trying to fix chosen for derived mutable, ' + 'which will have no effect.', + level=logging.WARNING) def dump_chosen(self) -> CHOICE_TYPE: """Dump information of chosen. @@ -250,8 +252,27 @@ def dump_chosen(self) -> CHOICE_TYPE: Returns: Dict: Dumped information. """ + print_log( + 'Trying to dump chosen for derived mutable, ' + 'but its value depend on the source mutables.', + level=logging.WARNING) return self.current_choice + @property + def is_fixed(self) -> bool: + """Whether the derived mutable is fixed. + + Note: + Depends on whether all source mutables are already fixed. + """ + return all(m.is_fixed for m in self.source_mutables) + + @is_fixed.setter + def is_fixed(self, is_fixed: bool) -> bool: + """Setter of is fixed.""" + raise RuntimeError( + '`is_fixed` of derived mutable should not be modified directly') + @property def num_choices(self) -> int: """Number of all choices. diff --git a/mmrazor/structures/subnet/fix_subnet.py b/mmrazor/structures/subnet/fix_subnet.py index 147a26e7b..708da5ad6 100644 --- a/mmrazor/structures/subnet/fix_subnet.py +++ b/mmrazor/structures/subnet/fix_subnet.py @@ -31,6 +31,7 @@ def load_fix_subnet(model: nn.Module, raise TypeError('fix_mutable should be a `str` or `dict`' f'but got {type(fix_mutable)}') # Avoid circular import + from mmrazor.models.mutables import DerivedMutable from mmrazor.models.mutables.base_mutable import BaseMutable for name, module in model.named_modules(): @@ -46,9 +47,11 @@ def load_fix_subnet(model: nn.Module, chosen = fix_mutable.get(alias, None) else: mutable_name = name.lstrip(prefix) - assert mutable_name in fix_mutable, \ - f'The module name {mutable_name} is not in ' \ - 'fix_mutable, please check your `fix_mutable`.' + if mutable_name not in fix_mutable and \ + not isinstance(module, DerivedMutable): + raise RuntimeError( + f'The module name {mutable_name} is not in ' + 'fix_mutable, please check your `fix_mutable`.') chosen = fix_mutable.get(mutable_name, None) module.fix_chosen(chosen) @@ -56,15 +59,20 @@ def load_fix_subnet(model: nn.Module, _dynamic_to_static(model) -def export_fix_subnet(model: nn.Module) -> FixMutable: +def export_fix_subnet(model: nn.Module, + dump_derived_mutable: bool = False) -> FixMutable: """Export subnet that can be loaded by :func:`load_fix_subnet`.""" # Avoid circular import + from mmrazor.models.mutables import DerivedMutable from mmrazor.models.mutables.base_mutable import BaseMutable fix_subnet = dict() for name, module in model.named_modules(): if isinstance(module, BaseMutable): + if isinstance(module, DerivedMutable) and not dump_derived_mutable: + continue + assert not module.is_fixed if module.alias: fix_subnet[module.alias] = module.dump_chosen() diff --git a/tests/test_models/test_mutables/test_derived_mutable.py b/tests/test_models/test_mutables/test_derived_mutable.py index d491a4f81..99da8dc71 100644 --- a/tests/test_models/test_mutables/test_derived_mutable.py +++ b/tests/test_models/test_mutables/test_derived_mutable.py @@ -11,6 +11,40 @@ class TestDerivedMutable(TestCase): + def test_is_fixed(self) -> None: + mc = OneShotMutableChannel( + num_channels=10, + candidate_choices=[2, 8, 10], + candidate_mode='number') + mc.current_choice = 2 + + mv = OneShotMutableValue(value_list=[2, 3, 4]) + mv.current_choice = 3 + + derived_mutable = mc * mv + assert not derived_mutable.is_fixed + + with pytest.raises(RuntimeError): + derived_mutable.is_fixed = True + + mc.fix_chosen(mc.dump_chosen()) + assert not derived_mutable.is_fixed + mv.fix_chosen(mv.dump_chosen()) + assert derived_mutable.is_fixed + + def test_fix_dump_chosen(self) -> None: + mv = OneShotMutableValue(value_list=[2, 3, 4]) + mv.current_choice = 3 + + derived_mutable = mv * 2 + assert derived_mutable.dump_chosen() == 6 + + mv.current_choice = 4 + assert derived_mutable.dump_chosen() == 8 + + # nothing will happen + derived_mutable.fix_chosen(derived_mutable.dump_chosen()) + def test_derived_same_mutable(self) -> None: mc = OneShotMutableChannel( num_channels=3, @@ -103,14 +137,6 @@ def test_mutable_divide(self) -> None: mv.current_choice == 120 assert mv_derived.current_choice == 16 - def test_double_fixed(self) -> None: - choice_fn = lambda x: x # noqa: E731 - derived_mutable = DerivedMutable(choice_fn, source_mutables=[]) - derived_mutable.fix_chosen({}) - - with pytest.raises(RuntimeError): - derived_mutable.fix_chosen({}) - def test_source_mutables(self) -> None: useless_fn = lambda x: x # noqa: E731 with pytest.raises(RuntimeError): @@ -220,10 +246,5 @@ def test_derived_expand_mutable(expand_ratio: int) -> None: with pytest.raises(RuntimeError): _ = mv_derived.current_mask - chosen = mv_derived.dump_chosen() - assert chosen == mv.current_choice * expand_ratio - mv_derived.fix_chosen(chosen) - assert mv_derived.is_fixed - mv.current_choice = 5 assert mv_derived.current_choice == 5 * expand_ratio diff --git a/tests/test_models/test_subnet/test_fix_subnet.py b/tests/test_models/test_subnet/test_fix_subnet.py index 28e691bd7..010372212 100644 --- a/tests/test_models/test_subnet/test_fix_subnet.py +++ b/tests/test_models/test_subnet/test_fix_subnet.py @@ -5,7 +5,7 @@ import torch.nn as nn from mmrazor.models import * # noqa:F403,F401 -from mmrazor.models.mutables import OneShotMutableOP +from mmrazor.models.mutables import OneShotMutableOP, OneShotMutableValue from mmrazor.registry import MODELS from mmrazor.structures import export_fix_subnet, load_fix_subnet from mmrazor.utils import FixMutable @@ -37,6 +37,15 @@ def forward(self, x): return x +class MockModelWithDerivedMutable(nn.Module): + + def __init__(self) -> None: + super().__init__() + + self.source_mutable = OneShotMutableValue([2, 3, 4], default_value=3) + self.derived_mutable = self.source_mutable * 2 + + class TestFixSubnet(TestCase): def test_load_fix_subnet(self): @@ -63,6 +72,11 @@ def test_load_fix_subnet(self): model = MockModel() load_fix_subnet(model, fix_subnet=10) + model = MockModel() + fix_subnet.pop('mutable1') + with pytest.raises(RuntimeError): + load_fix_subnet(model, fix_subnet) + def test_export_fix_subnet(self): # get FixSubnet fix_subnet = { @@ -82,3 +96,25 @@ def test_export_fix_subnet(self): exported_fix_subnet = export_fix_subnet(model) self.assertDictEqual(fix_subnet, exported_fix_subnet) + + def test_export_fix_subnet_with_derived_mutable(self) -> None: + model = MockModelWithDerivedMutable() + fix_subnet = export_fix_subnet(model) + self.assertDictEqual( + fix_subnet, {'source_mutable': model.source_mutable.dump_chosen()}) + fix_subnet['source_mutable']['current_choice'] = 4 + load_fix_subnet(model, fix_subnet) + assert model.source_mutable.current_choice == 4 + assert model.derived_mutable.current_choice == 8 + + model = MockModelWithDerivedMutable() + fix_subnet = export_fix_subnet(model, dump_derived_mutable=True) + self.assertDictEqual( + fix_subnet, { + 'source_mutable': model.source_mutable.dump_chosen(), + 'derived_mutable': model.derived_mutable.dump_chosen() + }) + fix_subnet['source_mutable']['current_choice'] = 2 + load_fix_subnet(model, fix_subnet) + assert model.source_mutable.current_choice == 2 + assert model.derived_mutable.current_choice == 4 From e9d10644cfd49b00084847f706d17991d86f4756 Mon Sep 17 00:00:00 2001 From: wutongshenqiu <44188071+wutongshenqiu@users.noreply.github.com> Date: Wed, 10 Aug 2022 13:10:02 +0800 Subject: [PATCH 8/9] add warning in `export_fix_subnet` --- mmrazor/structures/subnet/fix_subnet.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mmrazor/structures/subnet/fix_subnet.py b/mmrazor/structures/subnet/fix_subnet.py index 708da5ad6..ead21413e 100644 --- a/mmrazor/structures/subnet/fix_subnet.py +++ b/mmrazor/structures/subnet/fix_subnet.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +import logging + import mmcv +from mmengine.logging import print_log from torch import nn from mmrazor.utils import FixMutable, ValidFixMutable @@ -62,6 +65,11 @@ def load_fix_subnet(model: nn.Module, def export_fix_subnet(model: nn.Module, dump_derived_mutable: bool = False) -> FixMutable: """Export subnet that can be loaded by :func:`load_fix_subnet`.""" + if dump_derived_mutable: + print_log( + 'Trying to dump information of all derived mutables, ' + 'this might harm readability of the exported configurations.', + level=logging.WARNING) # Avoid circular import from mmrazor.models.mutables import DerivedMutable From dff284e5e71c1abbe5da801734d12031b03b6f0b Mon Sep 17 00:00:00 2001 From: wutongshenqiu <44188071+wutongshenqiu@users.noreply.github.com> Date: Wed, 10 Aug 2022 15:41:08 +0800 Subject: [PATCH 9/9] fix __mul__ in mutable value --- .../one_shot_mutable_channel.py | 3 +- .../mutables/mutable_value/mutable_value.py | 19 ++++++++++++- .../test_mutables/test_mutable_value.py | 28 ++++++++++++++++++- 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py index 8fde79a95..7f6eea3ad 100644 --- a/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/one_shot_mutable_channel.py @@ -6,7 +6,6 @@ from mmrazor.registry import MODELS from ..derived_mutable import DerivedMutable -from ..mutable_value import OneShotMutableValue from .mutable_channel import MutableChannel @@ -171,6 +170,8 @@ def __mul__(self, other) -> DerivedMutable: if isinstance(other, int): return self.derive_expand_mutable(other) + from ..mutable_value import OneShotMutableValue + def expand_choice_fn(mutable1: 'OneShotMutableChannel', mutable2: OneShotMutableValue) -> Callable: diff --git a/mmrazor/models/mutables/mutable_value/mutable_value.py b/mmrazor/models/mutables/mutable_value/mutable_value.py index 22b7b8d21..748d83e78 100644 --- a/mmrazor/models/mutables/mutable_value/mutable_value.py +++ b/mmrazor/models/mutables/mutable_value/mutable_value.py @@ -110,7 +110,7 @@ def current_choice(self, choice: Any) -> Any: self._current_choice = choice - def __rmul__(self, other: int) -> DerivedMutable: + def __rmul__(self, other) -> DerivedMutable: """Please refer to method :func:`__mul__`.""" return self * other @@ -217,3 +217,20 @@ def min_choice(self) -> Any: Any: Min choice. """ return self.choices[0] + + def __mul__(self, other) -> DerivedMutable: + """Overload `*` operator. + + Args: + other (int, OneShotMutableChannel): Expand ratio or + OneShotMutableChannel. + + Returns: + DerivedMutable: Derived expand mutable. + """ + from ..mutable_channel import OneShotMutableChannel + + if isinstance(other, OneShotMutableChannel): + return other * self + + return super().__mul__(other) diff --git a/tests/test_models/test_mutables/test_mutable_value.py b/tests/test_models/test_mutables/test_mutable_value.py index 9003ed693..0b5ed7947 100644 --- a/tests/test_models/test_mutables/test_mutable_value.py +++ b/tests/test_models/test_mutables/test_mutable_value.py @@ -3,8 +3,10 @@ from unittest import TestCase import pytest +import torch -from mmrazor.models.mutables import MutableValue, OneShotMutableValue +from mmrazor.models.mutables import (MutableValue, OneShotMutableChannel, + OneShotMutableValue) class TestMutableValue(TestCase): @@ -84,6 +86,30 @@ def test_mul(self) -> None: with pytest.raises(TypeError): _ = mv * 1.2 + mv = MutableValue(value_list=[1, 2, 3], default_value=3) + mc = OneShotMutableChannel( + num_channels=4, candidate_choices=[2, 4], candidate_mode='number') + + with pytest.raises(TypeError): + _ = mc * mv + with pytest.raises(TypeError): + _ = mv * mc + + mv = OneShotMutableValue(value_list=[1, 2, 3], default_value=3) + mc.current_choice = 2 + + derived1 = mc * mv + derived2 = mv * mc + + assert derived1.current_choice == 6 + assert derived2.current_choice == 6 + assert torch.equal(derived1.current_mask, derived2.current_mask) + + mv.current_choice = 2 + assert derived1.current_choice == 4 + assert derived2.current_choice == 4 + assert torch.equal(derived1.current_mask, derived2.current_mask) + def test_floordiv(self) -> None: mv = MutableValue(value_list=[120, 128, 136]) derived_mv = mv // 8