diff --git a/.circleci/test.yml b/.circleci/test.yml index 2fc0270a2..5da20de36 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -119,7 +119,7 @@ jobs: docker exec mmrazor pip install -e /mmdetection docker exec mmrazor pip install -e /mmclassification docker exec mmrazor pip install -e /mmsegmentation - docker exec mmrazor pip install -r requirements/tests.txt + docker exec mmrazor pip install -r requirements.txt - run: name: Build and install command: | diff --git a/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_supernet_8xb256_in1k.py b/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_supernet_8xb256_in1k.py index 5ff172e0c..6249b0160 100644 --- a/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_supernet_8xb256_in1k.py +++ b/configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_supernet_8xb256_in1k.py @@ -47,7 +47,7 @@ type='OneShotMutableChannelUnit', default_args=dict( candidate_choices=list(i / 12 for i in range(2, 13)), - candidate_mode='ratio', + choice_mode='ratio', divisor=8)), parse_cfg=dict( type='BackwardTracer', diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index 712315105..137414cbb 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -13,5 +13,7 @@ 'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation', 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'ItePruneAlgorithm', 'DAFLDataFreeDistillation', - 'OverhaulFeatureDistillation', 'Dsnas', 'DsnasDDP' + 'OverhaulFeatureDistillation', 'Dsnas', 'DsnasDDP', + 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'Dsnas', + 'DsnasDDP' ] diff --git a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py index 5c653b557..e3e795fa4 100644 --- a/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py +++ b/mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py @@ -152,8 +152,7 @@ def init_candidates(self, candidates: List): for num in candidates: self.candidate_bn[str(num)] = nn.BatchNorm2d( num, self.eps, self.momentum, self.affine, - self.track_running_stats, self.weight.device, - self.weight.dtype) + self.track_running_stats) def forward(self, input: Tensor) -> Tensor: """Forward.""" diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 23f1fea3c..fb64d80dd 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -2,7 +2,8 @@ from .base_mutable import BaseMutable from .derived_mutable import DerivedMutable from .mutable_channel import (BaseMutableChannel, MutableChannelContainer, - SimpleMutableChannel, SquentialMutableChannel) + OneShotMutableChannel, SimpleMutableChannel, + SquentialMutableChannel) from .mutable_channel.units import (ChannelUnitType, L1MutableChannelUnit, MutableChannelUnit, OneShotMutableChannelUnit, @@ -22,5 +23,7 @@ 'BaseMutableChannel', 'MutableChannelContainer', 'ChannelUnitType', 'SquentialMutableChannel', 'BaseMutable', 'DiffChoiceRoute', 'DiffMutableModule', 'DerivedMutable', 'MutableValue', - 'OneShotMutableValue', 'OneHotMutableOP' + 'OneShotMutableValue', 'OneHotMutableOP', 'OneShotMutableChannel', + 'DiffChoiceRoute', 'DiffMutableModule', 'OneShotMutableChannel', + 'DerivedMutable', 'MutableValue', 'OneShotMutableValue', 'OneHotMutableOP' ] diff --git a/mmrazor/models/mutables/mutable_channel/__init__.py b/mmrazor/models/mutables/mutable_channel/__init__.py index 108a5d8ca..0ef09dc78 100644 --- a/mmrazor/models/mutables/mutable_channel/__init__.py +++ b/mmrazor/models/mutables/mutable_channel/__init__.py @@ -1,16 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base_mutable_channel import BaseMutableChannel -from .units import (ChannelUnitType, L1MutableChannelUnit, - MutableChannelUnit, OneShotMutableChannelUnit, - SequentialMutableChannelUnit, SlimmableChannelUnit) from .mutable_channel_container import MutableChannelContainer +from .oneshot_mutalbe_channel import OneShotMutableChannel from .sequential_mutable_channel import SquentialMutableChannel from .simple_mutable_channel import SimpleMutableChannel +from .units import (ChannelUnitType, L1MutableChannelUnit, MutableChannelUnit, + OneShotMutableChannelUnit, SequentialMutableChannelUnit, + SlimmableChannelUnit) __all__ = [ 'SimpleMutableChannel', 'L1MutableChannelUnit', 'SequentialMutableChannelUnit', 'MutableChannelUnit', - 'OneShotMutableChannelUnit', 'SlimmableChannelUnit', - 'BaseMutableChannel', 'MutableChannelContainer', 'SquentialMutableChannel', - 'ChannelUnitType' + 'OneShotMutableChannelUnit', 'SlimmableChannelUnit', 'BaseMutableChannel', + 'MutableChannelContainer', 'SquentialMutableChannel', 'ChannelUnitType', + 'OneShotMutableChannel' ] diff --git a/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py index 150376f5e..9292d64c8 100644 --- a/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py +++ b/mmrazor/models/mutables/mutable_channel/mutable_channel_container.py @@ -30,7 +30,7 @@ class MutableChannelContainer(BaseMutableChannel): def __init__(self, num_channels: int, **kwargs): super().__init__(num_channels, **kwargs) - self.mutable_channels: IndexDict[BaseMutableChannel] = IndexDict() + self.mutable_channels = IndexDict() # choice diff --git a/mmrazor/models/mutables/mutable_channel/oneshot_mutalbe_channel.py b/mmrazor/models/mutables/mutable_channel/oneshot_mutalbe_channel.py new file mode 100644 index 000000000..61d36fd18 --- /dev/null +++ b/mmrazor/models/mutables/mutable_channel/oneshot_mutalbe_channel.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Union + +from .sequential_mutable_channel import SquentialMutableChannel + + +class OneShotMutableChannel(SquentialMutableChannel): + """OneShotMutableChannel is a subclass of SquentialMutableChannel. The + difference is that a OneShotMutableChannel limits the candidates of the + choice. + + Args: + num_channels (int): number of channels. + candidate_choices (List[Union[float, int]], optional): A list of + candidate width ratios. Each candidate indicates how many + channels to be reserved. Defaults to []. + choice_mode (str, optional): Mode of choices. Defaults to 'number'. + """ + + def __init__(self, + num_channels: int, + candidate_choices: List[Union[float, int]] = [], + choice_mode='number', + **kwargs): + super().__init__(num_channels, choice_mode, **kwargs) + self.candidate_choices = candidate_choices + if candidate_choices == []: + candidate_choices.append(num_channels if self.is_num_mode else 1.0) + + @property + def current_choice(self) -> Union[int, float]: + """Get current choice.""" + return super().current_choice + + @current_choice.setter + def current_choice(self, choice: Union[int, float]): + """Set current choice.""" + assert choice in self.candidate_choices + SquentialMutableChannel.current_choice.fset( # type: ignore + self, # type: ignore + choice) # type: ignore diff --git a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py index 1bcd00df3..eae559d41 100644 --- a/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py @@ -1,17 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Callable +from typing import Callable, Union import torch from mmrazor.registry import MODELS from ..derived_mutable import DerivedMutable -from .base_mutable_channel import BaseMutableChannel +from .simple_mutable_channel import SimpleMutableChannel # TODO discuss later @MODELS.register_module() -class SquentialMutableChannel(BaseMutableChannel): +class SquentialMutableChannel(SimpleMutableChannel): """SquentialMutableChannel defines a BaseMutableChannel which switch off channel mask from right to left sequentially, like '11111000'. @@ -22,21 +22,36 @@ class SquentialMutableChannel(BaseMutableChannel): num_channels (int): number of channels. """ - def __init__(self, num_channels: int, **kwargs): + def __init__(self, num_channels: int, choice_mode='number', **kwargs): super().__init__(num_channels, **kwargs) + assert choice_mode in ['ratio', 'number'] + self.choice_mode = choice_mode self.mask = torch.ones([self.num_channels]).bool() @property - def current_choice(self) -> int: + def is_num_mode(self): + """Get if the choice is number mode.""" + return self.choice_mode == 'number' + + @property + def current_choice(self) -> Union[int, float]: """Get current choice.""" - return (self.mask == 1).sum().item() + int_choice = (self.mask == 1).sum().item() + if self.is_num_mode: + return int_choice + else: + return self._num2ratio(int_choice) @current_choice.setter - def current_choice(self, choice: int): + def current_choice(self, choice: Union[int, float]): """Set choice.""" + if isinstance(choice, float): + int_choice = self._ratio2num(choice) + else: + int_choice = choice mask = torch.zeros([self.num_channels], device=self.mask.device) - mask[0:choice] = 1 + mask[0:int_choice] = 1 self.mask = mask.bool() @property @@ -58,20 +73,6 @@ def dump_chosen(self): """Dump chosen.""" return self.current_choice - # def __mul__(self, other): - # """multiplication.""" - # if isinstance(other, int): - # return self.derive_expand_mutable(other) - # else: - # return None - - # def __floordiv__(self, other): - # """division.""" - # if isinstance(other, int): - # return self.derive_divide_mutable(other) - # else: - # return None - def __rmul__(self, other) -> DerivedMutable: return self * other @@ -121,3 +122,17 @@ def __floordiv__(self, other) -> DerivedMutable: return self.derive_divide_mutable(*other) raise TypeError(f'Unsupported type {type(other)} for div!') + + def _num2ratio(self, choice: Union[int, float]) -> float: + """Convert the a number choice to a ratio choice.""" + if isinstance(choice, float): + return choice + else: + return choice / self.num_channels + + def _ratio2num(self, choice: Union[int, float]) -> int: + """Convert the a ratio choice to a number choice.""" + if isinstance(choice, int): + return choice + else: + return max(1, int(self.num_channels * choice)) diff --git a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py index cba91f810..576412ec0 100644 --- a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py @@ -113,7 +113,7 @@ def is_mutable(self) -> bool: def __repr__(self) -> str: return (f'{self.__class__.__name__}(' - f'{self.name}, index=({self.index}), ' + f'{self.name}, index={self.index}, ' f'is_output_channel=' f'{"true" if self.is_output_channel else "false"}, ' f'expand_ratio={self.expand_ratio}' diff --git a/mmrazor/models/mutables/mutable_channel/units/l1_mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/l1_mutable_channel_unit.py index 72c845a1d..8b3c258ad 100644 --- a/mmrazor/models/mutables/mutable_channel/units/l1_mutable_channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/l1_mutable_channel_unit.py @@ -5,6 +5,7 @@ import torch.nn as nn from mmrazor.registry import MODELS +from ..simple_mutable_channel import SimpleMutableChannel from .sequential_mutable_channel_unit import SequentialMutableChannelUnit @@ -25,6 +26,23 @@ def __init__(self, min_ratio=0.9) -> None: super().__init__(num_channels, choice_mode, divisor, min_value, min_ratio) + self.mutable_channel = SimpleMutableChannel(num_channels) + + # choices + + @property + def current_choice(self) -> Union[int, float]: + num = self.mutable_channel.activated_channels + if self.is_num_mode: + return num + else: + return self._num2ratio(num) + + @current_choice.setter + def current_choice(self, choice: Union[int, float]): + int_choice = self._get_valid_int_choice(choice) + mask = self._generate_mask(int_choice).bool() + self.mutable_channel.current_choice = mask # private methods diff --git a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py index 54130eb11..59039cd83 100644 --- a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py @@ -1,16 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. """This module defines MutableChannelUnit.""" import abc +from collections import Set from typing import Dict, List, Type, TypeVar import torch.nn as nn -import mmrazor.models.architectures.dynamic_ops as dynamic_ops +from mmrazor.models.architectures import dynamic_ops from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin from mmrazor.models.mutables import DerivedMutable -from mmrazor.models.mutables.mutable_channel.base_mutable_channel import \ - BaseMutableChannel -from ..mutable_channel_container import MutableChannelContainer +from mmrazor.models.mutables.mutable_channel import (BaseMutableChannel, + MutableChannelContainer) from .channel_unit import Channel, ChannelUnit @@ -18,8 +18,8 @@ class MutableChannelUnit(ChannelUnit): # init methods def __init__(self, num_channels: int, **kwargs) -> None: - """MutableChannelUnit inherits from ChannelUnit, which manages - channels with channel-dependency. + """MutableChannelUnit inherits from ChannelUnit, which manages channels + with channel-dependency. Compared with ChannelUnit, MutableChannelUnit defines the core interfaces for pruning. By inheriting MutableChannelUnit, @@ -44,6 +44,70 @@ def __init__(self, num_channels: int, **kwargs) -> None: super().__init__(num_channels) + @classmethod + def init_from_mutable_channel(cls, mutable_channel: BaseMutableChannel): + unit = cls(mutable_channel.num_channels) + return unit + + @classmethod + def init_from_predefined_model(cls, model: nn.Module): + """Initialize units using the model with pre-defined dynamicops and + mutable-channels.""" + + def process_container(contanier: MutableChannelContainer, + module, + module_name, + mutable2units, + is_output=True): + for index, mutable in contanier.mutable_channels.items(): + if isinstance(mutable, DerivedMutable): + source_mutables: Set = \ + mutable._trace_source_mutables() + source_channel_mutables = [ + mutable for mutable in source_mutables + if isinstance(mutable, BaseMutableChannel) + ] + assert len(source_channel_mutables) == 1, ( + 'only support one mutable channel ' + 'used in DerivedMutable') + mutable = list(source_channel_mutables)[0] + + if mutable not in mutable2units: + mutable2units[mutable] = cls.init_from_mutable_channel( + mutable) + + unit: MutableChannelUnit = mutable2units[mutable] + if is_output: + unit.add_ouptut_related( + Channel( + module_name, + module, + index, + is_output_channel=is_output)) + else: + unit.add_input_related( + Channel( + module_name, + module, + index, + is_output_channel=is_output)) + + mutable2units: Dict = {} + for name, module in model.named_modules(): + if isinstance(module, DynamicChannelMixin): + in_container: MutableChannelContainer = \ + module.get_mutable_attr( + 'in_channels') + out_container: MutableChannelContainer = \ + module.get_mutable_attr( + 'out_channels') + process_container(in_container, module, name, mutable2units, + False) + process_container(out_container, module, name, mutable2units, + True) + units = list(mutable2units.values()) + return units + # properties @property @@ -97,7 +161,7 @@ def prepare_for_pruning(self, model): For example, we need to register mutables to dynamic-ops. """ - raise not NotImplementedError + raise NotImplementedError # pruning: choice-related diff --git a/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py index 23825756d..235978cfa 100644 --- a/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/one_shot_mutable_channel_unit.py @@ -6,6 +6,7 @@ import torch.nn as nn from mmrazor.registry import MODELS +from ..oneshot_mutalbe_channel import OneShotMutableChannel from .sequential_mutable_channel_unit import SequentialMutableChannelUnit @@ -21,8 +22,8 @@ class OneShotMutableChannelUnit(SequentialMutableChannelUnit): candidate_choices (List[Union[int, float]], optional): A list of candidate width ratios. Each candidate indicates how many channels to be reserved. - Defaults to [0.5, 1.0](candidate_mode='ratio'). - candidate_mode (str, optional): Mode of candidates. + Defaults to [0.5, 1.0](choice_mode='ratio'). + choice_mode (str, optional): Mode of candidates. One of "ratio" or "number". Defaults to 'ratio'. divisor (int): Used to make choice divisible. min_value (int): the minimal value used when make divisible. @@ -32,20 +33,31 @@ class OneShotMutableChannelUnit(SequentialMutableChannelUnit): def __init__(self, num_channels: int, candidate_choices: List[Union[int, float]] = [0.5, 1.0], - candidate_mode='ratio', + choice_mode='ratio', divisor=1, min_value=1, min_ratio=0.9) -> None: - super().__init__(num_channels, candidate_mode, divisor, min_value, + super().__init__(num_channels, choice_mode, divisor, min_value, min_ratio) candidate_choices = copy.copy(candidate_choices) if candidate_choices == []: candidate_choices.append( self.num_channels if self.is_num_mode else 1.0) self.candidate_choices = self._prepare_candidate_choices( - candidate_choices, candidate_mode) + candidate_choices, choice_mode) - self._choice = self.max_choice + self.mutable_channel = OneShotMutableChannel(num_channels, + self.candidate_choices, + choice_mode) + + @classmethod + def init_from_mutable_channel(cls, mutable_channel: OneShotMutableChannel): + unit = cls(mutable_channel.num_channels, + mutable_channel.candidate_choices, + mutable_channel.choice_mode) + mutable_channel.candidate_choices = unit.candidate_choices + unit.mutable_channel = mutable_channel + return unit def prepare_for_pruning(self, model: nn.Module): """Prepare for pruning.""" @@ -64,7 +76,7 @@ def config_template(self, init_cfg.pop('choice_mode') init_cfg.update({ 'candidate_choices': self.candidate_choices, - 'candidate_mode': self.choice_mode + 'choice_mode': self.choice_mode }) return config @@ -79,9 +91,10 @@ def current_choice(self) -> Union[int, float]: def current_choice(self, choice: Union[int, float]): """Set current choice.""" assert choice in self.candidate_choices - SequentialMutableChannelUnit.current_choice.fset( # type: ignore - self, # type: ignore - choice) # type: ignore + int_choice = self._get_valid_int_choice(choice) + choice_ = int_choice if self.is_num_mode else self._num2ratio( + int_choice) + self.mutable_channel.current_choice = choice_ def sample_choice(self) -> Union[int, float]: """Sample a valid choice.""" @@ -101,9 +114,9 @@ def max_choice(self) -> Union[int, float]: # private methods def _prepare_candidate_choices(self, candidate_choices: List, - candidate_mode) -> List: + choice_mode) -> List: """Process candidate_choices.""" - choice_type = int if candidate_mode == 'number' else float + choice_type = int if choice_mode == 'number' else float for choice in candidate_choices: assert isinstance(choice, choice_type) if self.is_num_mode: diff --git a/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py index ef4beda74..89a25d236 100644 --- a/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/sequential_mutable_channel_unit.py @@ -2,24 +2,23 @@ import random from typing import Dict, Union -import torch import torch.nn as nn from mmengine import MMLogger -import mmrazor.models.architectures.dynamic_ops as dynamic_ops +from mmrazor.models.architectures import dynamic_ops from mmrazor.models.utils import make_divisible from mmrazor.registry import MODELS from ..mutable_channel_container import MutableChannelContainer -from ..simple_mutable_channel import SimpleMutableChannel +from ..sequential_mutable_channel import SquentialMutableChannel from .mutable_channel_unit import MutableChannelUnit # TODO change the name of SequentialMutableChannelUnit @MODELS.register_module() class SequentialMutableChannelUnit(MutableChannelUnit): - """SequentialMutableChannelUnit accepts a intger(number) or float(ratio) - as the choice, which indicates how many of the channels are remained from - left to right, like 11110000. + """SequentialMutableChannelUnit accepts a intger(number) or float(ratio) as + the choice, which indicates how many of the channels are remained from left + to right, like 11110000. Args: num_channels (int): number of channels. @@ -38,15 +37,24 @@ def __init__( min_value=1, min_ratio=0.9) -> None: super().__init__(num_channels) - self.mutable_channel: SimpleMutableChannel = SimpleMutableChannel( - self.num_channels) assert choice_mode in ['ratio', 'number'] self.choice_mode = choice_mode + + self.mutable_channel: SquentialMutableChannel = \ + SquentialMutableChannel(num_channels, choice_mode=choice_mode) + # for make_divisible self.divisor = divisor self.min_value = min_value self.min_ratio = min_ratio + @classmethod + def init_from_mutable_channel(cls, + mutable_channel: SquentialMutableChannel): + unit = cls(mutable_channel.num_channels, mutable_channel.choice_mode) + unit.mutable_channel = mutable_channel + return unit + def prepare_for_pruning(self, model: nn.Module): """Prepare for pruning, including register mutable channels.""" # register MutableMask @@ -90,21 +98,13 @@ def config_template(self, @property def current_choice(self) -> Union[int, float]: """return current choice.""" - if self.is_num_mode: - return self.mutable_channel.activated_channels - else: - return self._num2ratio(self.mutable_channel.activated_channels) + return self.mutable_channel.current_choice @current_choice.setter def current_choice(self, choice: Union[int, float]): """set choice.""" - choice_num = self._ratio2num(choice) - choice_num_ = self._make_divisible(choice_num) - - mask = self._generate_mask(choice_num_) - self.mutable_channel.current_choice = mask - if choice_num != choice_num_: - self._make_divisible_info(choice, self.current_choice) + choice_num_ = self._get_valid_int_choice(choice) + self.mutable_channel.current_choice = choice_num_ def sample_choice(self) -> Union[int, float]: """Sample a choice in (0,1]""" @@ -116,6 +116,12 @@ def sample_choice(self) -> Union[int, float]: return self._num2ratio(num_choice) # private methods + def _get_valid_int_choice(self, choice: Union[float, int]) -> int: + choice_num = self._ratio2num(choice) + choice_num_ = self._make_divisible(choice_num) + if choice_num != choice_num_: + self._make_divisible_info(choice, self.current_choice) + return choice_num_ def _make_divisible(self, choice_int: int): """Make the choice divisible.""" @@ -136,12 +142,6 @@ def _ratio2num(self, choice: Union[int, float]) -> int: else: return max(1, int(self.num_channels * choice)) - def _generate_mask(self, choice: int) -> torch.Tensor: - """torch.Tesnor: generate mask for pruning""" - mask = torch.zeros([self.num_channels]) - mask[0:choice] = 1 - return mask - def _make_divisible_info(self, choice, new_choice): logger = MMLogger.get_current_instance() logger.info(f'The choice={choice}, which is set to {self.name}, ' diff --git a/mmrazor/models/mutables/mutable_channel/units/slimmable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/slimmable_channel_unit.py index 6b6eaa96e..a51dce80b 100644 --- a/mmrazor/models/mutables/mutable_channel/units/slimmable_channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/slimmable_channel_unit.py @@ -4,7 +4,7 @@ import torch.nn as nn -import mmrazor.models.architectures.dynamic_ops as dynamic_ops +from mmrazor.models.architectures import dynamic_ops from mmrazor.registry import MODELS from ..mutable_channel_container import MutableChannelContainer from .one_shot_mutable_channel_unit import OneShotMutableChannelUnit @@ -19,8 +19,8 @@ class SlimmableChannelUnit(OneShotMutableChannelUnit): candidate_choices (List[Union[int, float]], optional): A list of candidate width ratios. Each candidate indicates how many channels to be reserved. - Defaults to [0.5, 1.0](candidate_mode='ratio'). - candidate_mode (str, optional): Mode of candidates. + Defaults to [0.5, 1.0](choice_mode='ratio'). + choice_mode (str, optional): Mode of candidates. One of 'ratio' or 'number'. Defaults to 'number'. divisor (int, optional): Used to make choice divisible. min_value (int, optional): The minimal value used when make divisible. @@ -31,12 +31,12 @@ class SlimmableChannelUnit(OneShotMutableChannelUnit): def __init__(self, num_channels: int, candidate_choices: List[Union[int, float]] = [], - candidate_mode='number', + choice_mode='number', divisor=1, min_value=1, min_ratio=0.9) -> None: - super().__init__(num_channels, candidate_choices, candidate_mode, - divisor, min_value, min_ratio) + super().__init__(num_channels, candidate_choices, choice_mode, divisor, + min_value, min_ratio) def prepare_for_pruning(self, model: nn.Module): """Prepare for pruning.""" diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py index 3e45a86d0..c4ce92e96 100644 --- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py @@ -1,17 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy -from typing import Dict, Generic, List, Optional, Set, Tuple, Type, Union +from typing import Dict, Generic, List, Optional, Tuple, Type, Union from mmengine import fileio from torch.nn import Module from mmrazor.models.architectures.dynamic_ops import DynamicChannelMixin -from mmrazor.models.mutables import (BaseMutableChannel, ChannelUnitType, - DerivedMutable, MutableChannelContainer, - MutableChannelUnit, +from mmrazor.models.mutables import (ChannelUnitType, MutableChannelUnit, SequentialMutableChannelUnit) -from mmrazor.models.mutables.mutable_channel.units.channel_unit import ( - Channel, ChannelUnit) +from mmrazor.models.mutables.mutable_channel.units.channel_unit import \ + ChannelUnit from mmrazor.registry import MODELS from mmrazor.structures.graph import ModuleGraph from ..base_mutator import BaseMutator @@ -65,15 +63,14 @@ class ChannelMutator(BaseMutator, Generic[ChannelUnitType]): # init - def __init__( - self, - channel_unit_cfg: Union[ - dict, - Type[MutableChannelUnit]] = SequentialMutableChannelUnit, - parse_cfg: Dict = dict( - type='BackwardTracer', - loss_calculator=dict(type='ImageClassifierPseudoLoss')), - init_cfg: Optional[Dict] = None) -> None: + def __init__(self, + channel_unit_cfg: Union[ + dict, + Type[MutableChannelUnit]] = SequentialMutableChannelUnit, + parse_cfg: Dict = dict( + type='BackwardTracer', + loss_calculator=dict(type='ImageClassifierPseudoLoss')), + init_cfg: Optional[Dict] = None) -> None: super().__init__(init_cfg) @@ -244,8 +241,7 @@ def _convert_channel_unit_to_mutable(self, units: List[ChannelUnit]): if unit.name in self.units_cfg and \ 'init_args' in self.units_cfg[unit.name]: args = self.units_cfg[unit.name]['init_args'] - mutable_unit = self.unit_class.init_from_channel_unit( - unit, args) + mutable_unit = self.unit_class.init_from_channel_unit(unit, args) mutable_units.append(mutable_unit) return mutable_units @@ -314,69 +310,6 @@ def _prepare_from_predefined_model(self, model: Module): """Initialize units using the model with pre-defined dynamicops and mutable-channels.""" - def process_container(contanier: MutableChannelContainer, - module, - module_name, - mutable2units, - is_output=True): - for index, mutable in contanier.mutable_channels.items(): - if isinstance(mutable, DerivedMutable): - source_mutables: Set = \ - mutable._trace_source_mutables() - source_channel_mutables = [ - mutable for mutable in source_mutables - if isinstance(mutable, BaseMutableChannel) - ] - assert len(source_channel_mutables) == 1, ( - 'only support one mutable channel ' - 'used in DerivedMutable') - mutable = list(source_channel_mutables)[0] - - if mutable not in mutable2units: - if hasattr(self.unit_class, 'init_from_mutable_channel'): - mutable2units[ - mutable] = \ - self.unit_class.init_from_mutable_channel( - mutable) - else: - mutable2units[mutable] = self.unit_class( - mutable.num_channels, **self.unit_default_args) - - unit: MutableChannelUnit = mutable2units[mutable] - if is_output: - unit.add_ouptut_related( - Channel( - module_name, - module, - index, - is_output_channel=is_output)) - else: - unit.add_input_related( - Channel( - module_name, - module, - index, - is_output_channel=is_output)) - - mutable2units: Dict = {} - for name, module in model.named_modules(): - if isinstance(module, DynamicChannelMixin): - in_container: MutableChannelContainer = \ - module.get_mutable_attr( - 'in_channels') - out_container: MutableChannelContainer = \ - module.get_mutable_attr( - 'out_channels') - process_container(in_container, module, name, mutable2units, - False) - process_container(out_container, module, name, mutable2units, - True) - for mutable, unit in mutable2units.items(): - if isinstance(mutable, DerivedMutable): - continue - else: - unit.mutable_channel = mutable - units = list(mutable2units.values()) - for unit in units: - self._name2unit[unit.name] = unit + units = self.unit_class.init_from_predefined_model(model) + return units diff --git a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py index 7f2a28916..a5350ab2b 100644 --- a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py @@ -3,7 +3,7 @@ from mmrazor.models.mutables import OneShotMutableChannelUnit from mmrazor.registry import MODELS -from .channel_mutator import ChannelUnitType, ChannelMutator +from .channel_mutator import ChannelMutator, ChannelUnitType @MODELS.register_module() diff --git a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py index 7c3165c19..7c0d24fa6 100644 --- a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py @@ -22,8 +22,7 @@ class SlimmableChannelMutator(ChannelMutator[SlimmableChannelUnit]): """ def __init__(self, - channel_unit_cfg=dict( - type='SlimmableChannelUnit', units={}), + channel_unit_cfg=dict(type='SlimmableChannelUnit', units={}), parse_cfg=dict( type='BackwardTracer', loss_calculator=dict(type='ImageClassifierPseudoLoss')), diff --git a/mmrazor/structures/graph/channel_nodes.py b/mmrazor/structures/graph/channel_nodes.py index bababa8aa..1749b5875 100644 --- a/mmrazor/structures/graph/channel_nodes.py +++ b/mmrazor/structures/graph/channel_nodes.py @@ -196,8 +196,7 @@ def channel_forward(self, *in_channel_tensor: ChannelTensor): channel_lis.unit_dict for channel_lis in in_channel_tensor ] for key in node_units[0]: - BaseChannelUnit.union_units( - [units[key] for units in node_units]) + BaseChannelUnit.union_units([units[key] for units in node_units]) super().channel_forward(in_channel_tensor[0]) def __repr__(self) -> str: @@ -208,9 +207,8 @@ class CatChannelNode(ChannelNode): """A CatChannelNode cat all input channels.""" def channel_forward(self, *in_channel_tensors: ChannelTensor): - BaseChannelUnit.union_two_units( - self.in_channel_tensor.unit_list[0], - self.out_channel_tensor.unit_list[0]) + BaseChannelUnit.union_two_units(self.in_channel_tensor.unit_list[0], + self.out_channel_tensor.unit_list[0]) num_ch = [] for in_ch_tensor in in_channel_tensors: for start, end in in_ch_tensor.unit_dict: diff --git a/mmrazor/utils/index_dict.py b/mmrazor/utils/index_dict.py index 64d986453..8ac3661c2 100644 --- a/mmrazor/utils/index_dict.py +++ b/mmrazor/utils/index_dict.py @@ -5,7 +5,7 @@ VT = TypeVar('VT') # Value type -class IndexDict(OrderedDict[Tuple[int, int], VT]): +class IndexDict(OrderedDict): """IndexDict inherents from OrderedDict[Tuple[int, int], VT]. Each IndexDict object is a OrderDict object which using index(Tuple[int,int]) as key and Any as value. diff --git a/tests/data/MBV2_slimmable_config.json b/tests/data/MBV2_slimmable_config.json index a2c475918..f63029872 100644 --- a/tests/data/MBV2_slimmable_config.json +++ b/tests/data/MBV2_slimmable_config.json @@ -10,7 +10,7 @@ 8, 32 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 32 }, @@ -25,7 +25,7 @@ 8, 16 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 16 }, @@ -40,7 +40,7 @@ 96, 144 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 144 }, @@ -55,7 +55,7 @@ 16, 24 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 24 }, @@ -70,7 +70,7 @@ 96, 176 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 176 }, @@ -85,7 +85,7 @@ 96, 192 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 192 }, @@ -100,7 +100,7 @@ 24, 48 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 48 }, @@ -115,7 +115,7 @@ 144, 240 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 240 }, @@ -130,7 +130,7 @@ 144, 144 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 144 }, @@ -145,7 +145,7 @@ 144, 264 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 264 }, @@ -160,7 +160,7 @@ 56, 88 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 88 }, @@ -175,7 +175,7 @@ 288, 288 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 288 }, @@ -190,7 +190,7 @@ 288, 336 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 336 }, @@ -205,7 +205,7 @@ 288, 432 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 432 }, @@ -220,7 +220,7 @@ 288, 576 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 576 }, @@ -235,7 +235,7 @@ 96, 144 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 144 }, @@ -250,7 +250,7 @@ 432, 576 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 576 }, @@ -265,7 +265,7 @@ 432, 648 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 648 }, @@ -280,7 +280,7 @@ 864, 864 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 864 }, @@ -295,7 +295,7 @@ 240, 240 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 240 }, @@ -310,7 +310,7 @@ 1440, 1440 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 1440 }, @@ -325,7 +325,7 @@ 960, 1440 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 1440 }, @@ -340,7 +340,7 @@ 1440, 1440 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 1440 }, @@ -355,7 +355,7 @@ 480, 480 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 480 }, @@ -370,7 +370,7 @@ 1920, 1920 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 1920 }, @@ -385,7 +385,7 @@ 1000, 1000 ], - "candidate_mode": "number" + "choice_mode": "number" }, "choice": 1000 } diff --git a/tests/data/models.py b/tests/data/models.py index 8a24de078..60c8a7058 100644 --- a/tests/data/models.py +++ b/tests/data/models.py @@ -8,7 +8,7 @@ from mmrazor.models.mutables import MutableChannelUnit from mmrazor.models.mutables import DerivedMutable from mmrazor.models.mutables import BaseMutable -from mmrazor.models.mutables import OneShotMutableChannelUnit, SimpleMutableChannel +from mmrazor.models.mutables import OneShotMutableChannelUnit, SquentialMutableChannel, OneShotMutableChannel from mmrazor.registry import MODELS from mmengine.model import BaseModel # this file includes models for tesing. @@ -497,7 +497,7 @@ def __init__(self, expand_ratio=1) -> None: self.ratio = expand_ratio def __mul__(self, other): - if isinstance(other, SampleOneshotMutableChannel): + if isinstance(other, OneShotMutableChannel): def _expand_mask(): mask = other.current_mask @@ -520,26 +520,6 @@ def num_choices(self) -> int: return super().num_choices -class SampleOneshotMutableChannel(SimpleMutableChannel): - - def __init__(self, num_channels: int, choices=[2, 4], **kwargs): - super().__init__(num_channels, **kwargs) - self.choices = choices - - -@MODELS.register_module() -class SampleOneshotMutableChannelUnit(OneShotMutableChannelUnit): - - @classmethod - def init_from_mutable_channel( - cls, mutable_channel: SampleOneshotMutableChannel): - return cls( - mutable_channel.num_channels, - candidate_choices=mutable_channel.choices, - candidate_mode='ratio' - if isinstance(mutable_channel.choices[0], float) else 'number') - - class DynamicLinearModel(nn.Module): """ x @@ -569,8 +549,8 @@ def forward(self, x): return self.linear(x1) def _register_mutable(self): - mutable1 = SampleOneshotMutableChannel(8, choices=[1, 4, 8]) - mutable2 = SampleOneshotMutableChannel(16, choices=[2, 8, 16]) + mutable1 = OneShotMutableChannel(8, candidate_choices=[1, 4, 8]) + mutable2 = OneShotMutableChannel(16, candidate_choices=[2, 8, 16]) mutable_value = SampleExpandDerivedMutable(1) MutableChannelContainer.register_mutable_channel_to_module( diff --git a/tests/test_core/test_graph/test_channel_graph.py b/tests/test_core/test_graph/test_channel_graph.py index 263688238..6eb3e1454 100644 --- a/tests/test_core/test_graph/test_channel_graph.py +++ b/tests/test_core/test_graph/test_channel_graph.py @@ -100,7 +100,7 @@ def test_split(self): channel_tensor1 = ChannelTensor(8) channel_tensor2 = ChannelTensor(8) BaseChannelUnit.union_two_units(channel_tensor1.unit_dict[(0, 8)], - channel_tensor2.unit_dict[(0, 8)]) + channel_tensor2.unit_dict[(0, 8)]) unit1 = channel_tensor1.unit_dict[(0, 8)] BaseChannelUnit.split_unit(unit1, [2, 6]) diff --git a/tests/test_models/test_algorithms/test_autoslim.py b/tests/test_models/test_algorithms/test_autoslim.py index e858a8e7d..79169b3cf 100644 --- a/tests/test_models/test_algorithms/test_autoslim.py +++ b/tests/test_models/test_algorithms/test_autoslim.py @@ -35,7 +35,7 @@ type='OneShotMutableChannelUnit', default_args=dict( candidate_choices=list(i / 12 for i in range(2, 13)), - candidate_mode='ratio')), + choice_mode='ratio')), parse_cfg=dict( type='BackwardTracer', loss_calculator=dict(type='ImageClassifierPseudoLoss'))) diff --git a/tests/test_models/test_algorithms/test_prune_algorithm.py b/tests/test_models/test_algorithms/test_prune_algorithm.py index 557cd4465..519407772 100644 --- a/tests/test_models/test_algorithms/test_prune_algorithm.py +++ b/tests/test_models/test_algorithms/test_prune_algorithm.py @@ -59,7 +59,10 @@ def __call__(self, model) -> torch.Tensor: } }) -DEVICE = torch.device('cuda:0') +if torch.cuda.is_available(): + DEVICE = torch.device('cuda:0') +else: + DEVICE = torch.device('cpu') class TestItePruneAlgorithm(unittest.TestCase): diff --git a/tests/test_models/test_mutables/test_units/__init__.py b/tests/test_models/test_mutables/test_mutable_channel/__init__.py similarity index 100% rename from tests/test_models/test_mutables/test_units/__init__.py rename to tests/test_models/test_mutables/test_mutable_channel/__init__.py diff --git a/tests/test_models/test_mutables/test_units/test_mutable_channels.py b/tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py similarity index 100% rename from tests/test_models/test_mutables/test_units/test_mutable_channels.py rename to tests/test_models/test_mutables/test_mutable_channel/test_mutable_channels.py diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_sequential_mutable_channel.py b/tests/test_models/test_mutables/test_mutable_channel/test_sequential_mutable_channel.py new file mode 100644 index 000000000..253084d07 --- /dev/null +++ b/tests/test_models/test_mutables/test_mutable_channel/test_sequential_mutable_channel.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmrazor.models.mutables import SquentialMutableChannel + + +class TestSquentialMutableChannel(TestCase): + + def _test_mutable(self, + mutable: SquentialMutableChannel, + set_choice, + get_choice, + activate_channels, + mask=None): + mutable.current_choice = set_choice + assert mutable.current_choice == get_choice + assert mutable.activated_channels == activate_channels + if mask is not None: + assert (mutable.current_mask == mask).all() + + def _generate_mask(self, num: int, all: int): + mask = torch.zeros([all]) + mask[0:num] = 1 + return mask.bool() + + def test_mul_float(self): + channel = SquentialMutableChannel(10) + new_channel = channel * 0.5 + self.assertEqual(new_channel.current_choice, 5) + channel.current_choice = 5 + self.assertEqual(new_channel.current_choice, 2) + + def test_int_choice(self): + channel = SquentialMutableChannel(10) + self._test_mutable(channel, 5, 5, 5, self._generate_mask(5, 10)) + self._test_mutable(channel, 0.2, 2, 2, self._generate_mask(2, 10)) + + def test_float_choice(self): + channel = SquentialMutableChannel(10, choice_mode='ratio') + self._test_mutable(channel, 0.5, 0.5, 5, self._generate_mask(5, 10)) + self._test_mutable(channel, 2, 0.2, 2, self._generate_mask(2, 10)) diff --git a/tests/test_models/test_mutables/test_mutable_channel/test_units/__init__.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_mutables/test_units/test_l1_mutable_channel_unit.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py similarity index 95% rename from tests/test_models/test_mutables/test_units/test_l1_mutable_channel_unit.py rename to tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py index 053e906da..f1a0d8529 100644 --- a/tests/test_models/test_mutables/test_units/test_l1_mutable_channel_unit.py +++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py @@ -5,7 +5,7 @@ from mmrazor.models.mutables import L1MutableChannelUnit from mmrazor.models.mutators import ChannelMutator -from ....data.models import LineModel +from .....data.models import LineModel class TestL1MutableChannelUnit(TestCase): diff --git a/tests/test_models/test_mutables/test_units/test_mutable_channel_units.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_mutable_channel_units.py similarity index 95% rename from tests/test_models/test_mutables/test_units/test_mutable_channel_units.py rename to tests/test_models/test_mutables/test_mutable_channel/test_units/test_mutable_channel_units.py index a0c97b6a1..ad5b5e56b 100644 --- a/tests/test_models/test_mutables/test_units/test_mutable_channel_units.py +++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_mutable_channel_units.py @@ -11,8 +11,8 @@ from mmrazor.models.mutables.mutable_channel.units.channel_unit import ( # noqa Channel, ChannelUnit) from mmrazor.structures.graph import ModuleGraph as ModuleGraph -from ....data.models import LineModel -from ....test_core.test_graph.test_graph import TestGraph +from .....data.models import LineModel +from .....test_core.test_graph.test_graph import TestGraph MUTABLE_CFG = dict(type='SimpleMutablechannel') PARSE_CFG = dict( @@ -84,8 +84,7 @@ def test_init_from_channel_unit(self): graph = ModuleGraph.init_from_backward_tracer(model) units: List[ChannelUnit] = ChannelUnit.init_from_graph(graph) mutable_units = [ - DefaultChannelUnit.init_from_channel_unit(unit) - for unit in units + DefaultChannelUnit.init_from_channel_unit(unit) for unit in units ] self._test_units(mutable_units, model) @@ -123,8 +122,7 @@ def test_replace_with_dynamic_ops(self): model: nn.Module = model_data() graph = ModuleGraph.init_from_backward_tracer(model) units: List[ - MutableChannelUnit] = unit_type.init_from_graph( - graph) + MutableChannelUnit] = unit_type.init_from_graph(graph) for unit in units: unit.prepare_for_pruning(model) diff --git a/tests/test_models/test_mutables/test_units/test_one_shot_mutable_channel_unit.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_one_shot_mutable_channel_unit.py similarity index 77% rename from tests/test_models/test_mutables/test_units/test_one_shot_mutable_channel_unit.py rename to tests/test_models/test_mutables/test_mutable_channel/test_units/test_one_shot_mutable_channel_unit.py index 97e84b474..690382596 100644 --- a/tests/test_models/test_mutables/test_units/test_one_shot_mutable_channel_unit.py +++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_one_shot_mutable_channel_unit.py @@ -8,9 +8,9 @@ class TestSequentialMutableChannelUnit(TestCase): def test_init(self): unit = OneShotMutableChannelUnit( - 48, [20, 30, 40], candidate_mode='number', divisor=8) + 48, [20, 30, 40], choice_mode='number', divisor=8) self.assertSequenceEqual(unit.candidate_choices, [24, 32, 40]) unit = OneShotMutableChannelUnit( - 48, [0.3, 0.5, 0.7], candidate_mode='ratio', divisor=8) + 48, [0.3, 0.5, 0.7], choice_mode='ratio', divisor=8) self.assertSequenceEqual(unit.candidate_choices, [1 / 3, 0.5, 2 / 3]) diff --git a/tests/test_models/test_mutables/test_units/test_sequential_mutable_channel_unit.py b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_sequential_mutable_channel_unit.py similarity index 87% rename from tests/test_models/test_mutables/test_units/test_sequential_mutable_channel_unit.py rename to tests/test_models/test_mutables/test_mutable_channel/test_units/test_sequential_mutable_channel_unit.py index d165ee3f4..8981a8a21 100644 --- a/tests/test_models/test_mutables/test_units/test_sequential_mutable_channel_unit.py +++ b/tests/test_models/test_mutables/test_mutable_channel/test_units/test_sequential_mutable_channel_unit.py @@ -28,14 +28,12 @@ def test_divisor(self): self.assertEqual(unit.current_choice, 24) self.assertTrue(unit.sample_choice() % 8 == 0) - unit = SequentialMutableChannelUnit( - 48, choice_mode='ratio', divisor=8) + unit = SequentialMutableChannelUnit(48, choice_mode='ratio', divisor=8) unit.current_choice = 0.3 self.assertEqual(unit.current_choice, 1 / 3) def test_config_template(self): - unit = SequentialMutableChannelUnit( - 48, choice_mode='ratio', divisor=8) + unit = SequentialMutableChannelUnit(48, choice_mode='ratio', divisor=8) config = unit.config_template(with_init_args=True) unit2 = SequentialMutableChannelUnit.init_from_cfg(None, config) self.assertDictEqual( diff --git a/tests/test_models/test_mutables/test_sequential_mutable_channel.py b/tests/test_models/test_mutables/test_sequential_mutable_channel.py deleted file mode 100644 index f7f4bb91e..000000000 --- a/tests/test_models/test_mutables/test_sequential_mutable_channel.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest import TestCase - -from mmrazor.models.mutables import SquentialMutableChannel - - -class TestSquentialMutableChannel(TestCase): - - def test_mul_float(self): - channel = SquentialMutableChannel(10) - new_channel = channel * 0.5 - self.assertEqual(new_channel.current_choice, 5) - channel.current_choice = 5 - self.assertEqual(new_channel.current_choice, 2) diff --git a/tests/test_models/test_mutators/test_channel_mutator.py b/tests/test_models/test_mutators/test_channel_mutator.py index fbd2dad66..96908d807 100644 --- a/tests/test_models/test_mutators/test_channel_mutator.py +++ b/tests/test_models/test_mutators/test_channel_mutator.py @@ -6,7 +6,6 @@ import torch -# from mmrazor.models.mutables import MutableChannelUnit from mmrazor.models.mutables.mutable_channel import ( L1MutableChannelUnit, SequentialMutableChannelUnit) from mmrazor.models.mutators.channel_mutator import ChannelMutator @@ -129,7 +128,7 @@ def test_models_with_predefined_dynamic_op(self): model = Model() mutator = ChannelMutator( channel_unit_cfg={ - 'type': 'SampleOneshotMutableChannelUnit', + 'type': 'OneShotMutableChannelUnit', 'default_args': {} }, parse_cfg={'type': 'Predefined'})