
[Refactor] Refactor Mutables and Mutators #324

Merged
merged 17 commits on Nov 1, 2022
33 changes: 28 additions & 5 deletions mmrazor/models/algorithms/pruning/ite_prune_algorithm.py
@@ -107,7 +107,7 @@ def __init__(self,
channel_unit_cfg=dict(
type='SequentialMutableChannelUnit')),
data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
target_pruning_ratio={},
target_pruning_ratio: Optional[Dict[str, float]] = None,
step_epoch=1,
prune_times=1,
init_cfg: Optional[Dict] = None) -> None:
@@ -119,14 +119,35 @@ def __init__(self,
self.mutator.prepare_from_supernet(self.architecture)

# config_manager
self.check_prune_targe(target_pruning_ratio)
if target_pruning_ratio is None:
group_target_ratio = self.mutator.current_choices
else:
group_target_ratio = self.group_target_pruning_ratio(
target_pruning_ratio, self.mutator.search_groups)

self.prune_config_manager = ItePruneConfigManager(
target_pruning_ratio,
self.mutator.choice_template,
group_target_ratio,
self.mutator.current_choices,
step_epoch,
times=prune_times)

def check_prune_targe(self, config: Dict):
def group_target_pruning_ratio(self, target, search_groups):

group_target = dict()
for group_id, units in search_groups.items():
for unit in units:
unit_name = unit.name
if group_id in group_target:
unit_target = target[unit_name]
assert unit_target == group_target[group_id]
else:
unit_target = target[unit_name]
assert isinstance(unit_target, (float, int))
group_target[group_id] = unit_target

return group_target

def check_prune_target(self, config: Dict):
"""Check if the prune-target is supported."""
for value in config.values():
assert isinstance(value, int) or isinstance(value, float)
@@ -141,7 +162,9 @@ def forward(self,
self._iteration):

config = self.prune_config_manager.prune_at(self._epoch)

self.mutator.set_choices(config)

logger = MMLogger.get_current_instance()
logger.info(f'The model is pruned at {self._epoch}th epoch once.')

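The new `group_target_pruning_ratio` helper collapses the per-unit ratios into one ratio per search group, asserting that units tied together in a group request identical targets. Below is a minimal standalone sketch of that grouping logic; the unit names and the mocked `search_groups` mapping are hypothetical stand-ins, not the real mutator objects:

```python
from types import SimpleNamespace


def group_target_pruning_ratio(target, search_groups):
    """Collapse per-unit pruning ratios into one ratio per group.

    Units that share a search group must request the same ratio,
    mirroring the assertions in the PR.
    """
    group_target = dict()
    for group_id, units in search_groups.items():
        for unit in units:
            unit_target = target[unit.name]
            assert isinstance(unit_target, (float, int))
            if group_id in group_target:
                assert unit_target == group_target[group_id]
            else:
                group_target[group_id] = unit_target
    return group_target


# Two hypothetical units tied together in search group 0.
search_groups = {
    0: [SimpleNamespace(name='backbone.conv1'),
        SimpleNamespace(name='backbone.conv2')]
}
target = {'backbone.conv1': 0.5, 'backbone.conv2': 0.5}
print(group_target_pruning_ratio(target, search_groups))  # {0: 0.5}
```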
44 changes: 23 additions & 21 deletions mmrazor/models/mutables/base_mutable.py
@@ -4,11 +4,13 @@

from mmengine.model import BaseModule

CHOICE_TYPE = TypeVar('CHOICE_TYPE')
CHOSEN_TYPE = TypeVar('CHOSEN_TYPE')
from mmrazor.utils.typing import DumpChosen

Chosen = TypeVar('Chosen')
Choice = TypeVar('Choice')

class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]):

class BaseMutable(BaseModule, ABC, Generic[Choice, Chosen]):
"""Base Class for mutables. Mutable means a searchable module widely used
in Neural Architecture Search (NAS).

@@ -17,13 +19,12 @@ class BaseMutable(BaseModule, ABC, Generic[CHOICE_TYPE, CHOSEN_TYPE]):

All subclass should implement the following APIs:

- ``forward()``
- ``fix_chosen()``
- ``choices()``
- ``dump_chosen()``
- ``current_choice.setter()``
- ``current_choice.getter()``

Args:
module_kwargs (dict[str, dict], optional): Module initialization named
arguments. Defaults to None.
alias (str, optional): alias of the `MUTABLE`.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implemented 5 initializers, including
@@ -38,19 +39,18 @@ def __init__(self,

self.alias = alias
self._is_fixed = False
self._current_choice: Optional[CHOICE_TYPE] = None

@property
def current_choice(self) -> Optional[CHOICE_TYPE]:
@property # type: ignore
@abstractmethod
def current_choice(self) -> Choice:
"""Current choice will affect :meth:`forward` and will be used in
:func:`mmrazor.core.subnet.utils.export_fix_subnet` or by the mutator.
"""
return self._current_choice

@current_choice.setter
def current_choice(self, choice: Optional[CHOICE_TYPE]) -> None:
@current_choice.setter # type: ignore
@abstractmethod
def current_choice(self, choice) -> None:
"""Current choice setter will be executed in mutator."""
self._current_choice = choice

@property
def is_fixed(self) -> bool:
@@ -76,22 +76,24 @@ def is_fixed(self, is_fixed: bool) -> None:
self._is_fixed = is_fixed

@abstractmethod
def fix_chosen(self, chosen: CHOSEN_TYPE) -> None:
def fix_chosen(self, chosen: Chosen) -> None:
"""Fix mutable with choice. This function would fix the choice of
Mutable. The :attr:`is_fixed` will be set to True and only the selected
operations can be retained. All subclasses must implement this method.

Note:
This operation is irreversible.
"""
raise NotImplementedError()

# TODO
# type hint
@abstractmethod
def dump_chosen(self) -> CHOSEN_TYPE:
...
def dump_chosen(self) -> DumpChosen:
"""Save the current state of the mutable as a dictionary.

@property
@abstractmethod
def num_choices(self) -> int:
pass
``DumpChosen`` has ``chosen`` and ``meta`` fields. ``chosen`` is
required and is what ``fix_chosen`` consumes; ``meta`` stores
non-essential information.
"""
raise NotImplementedError()
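After this refactor, every concrete mutable must implement the `current_choice` getter/setter pair, `fix_chosen`, and a `dump_chosen` that returns a `DumpChosen` record. A minimal sketch of that contract, assuming `DumpChosen` is a simple named tuple with `chosen` and `meta` fields (the diff only shows its usage); `ToyMutable` is hypothetical and not part of mmrazor:

```python
from typing import Any, Dict, List, NamedTuple, Optional


class DumpChosen(NamedTuple):
    """Stub of mmrazor.utils.typing.DumpChosen, as suggested by the diff."""
    chosen: Any
    meta: Optional[Dict] = None


class ToyMutable:
    """Hypothetical sketch of the new BaseMutable contract."""

    def __init__(self, choices: List[str]) -> None:
        self._choices = choices
        self._chosen = choices[0]
        self._is_fixed = False

    @property
    def current_choice(self) -> str:
        """The getter is now abstract; each subclass owns its own storage."""
        return self._chosen

    @current_choice.setter
    def current_choice(self, choice: str) -> None:
        assert choice in self._choices
        self._chosen = choice

    def fix_chosen(self, chosen: str) -> None:
        """Irreversible: keep only the selected operation."""
        self.current_choice = chosen
        self._is_fixed = True

    def dump_chosen(self) -> DumpChosen:
        return DumpChosen(chosen=self.current_choice, meta=None)


m = ToyMutable(['conv3x3', 'conv5x5'])
m.current_choice = 'conv5x5'
print(m.dump_chosen())  # DumpChosen(chosen='conv5x5', meta=None)
```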
17 changes: 10 additions & 7 deletions mmrazor/models/mutables/derived_mutable.py
@@ -15,8 +15,9 @@
from mmengine.logging import print_log
from torch import Tensor

from mmrazor.utils.typing import DumpChosen
from ..utils import make_divisible
from .base_mutable import CHOICE_TYPE, BaseMutable
from .base_mutable import BaseMutable


class MutableProtocol(Protocol): # pragma: no cover
@@ -172,8 +173,7 @@ def derive_concat_mutable(
return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)


class DerivedMutable(BaseMutable[CHOICE_TYPE, CHOICE_TYPE],
DerivedMethodMixin):
class DerivedMutable(BaseMutable, DerivedMethodMixin):
"""Class for derived mutable.

A derived mutable is a mutable derived from other mutables that has
@@ -242,7 +242,7 @@ def __init__(self,

# TODO
# has no effect
def fix_chosen(self, chosen: CHOICE_TYPE) -> None:
def fix_chosen(self, chosen) -> None:
"""Fix mutable with subnet config.

Warning:
@@ -253,7 +253,7 @@ def fix_chosen(self, chosen: CHOICE_TYPE) -> None:
'which will have no effect.',
level=logging.WARNING)

def dump_chosen(self) -> CHOICE_TYPE:
def dump_chosen(self) -> DumpChosen:
"""Dump information of chosen.

Returns:
@@ -263,6 +263,9 @@ def dump_chosen(self) -> CHOICE_TYPE:
'Trying to dump chosen for derived mutable, '
'but its value depend on the source mutables.',
level=logging.WARNING)
return DumpChosen(chosen=self.export_chosen(), meta=None)

def export_chosen(self):
return self.current_choice

@property
@@ -314,12 +317,12 @@ def num_choices(self) -> int:
return 1

@property
def current_choice(self) -> CHOICE_TYPE:
def current_choice(self):
"""Current choice of derived mutable."""
return self.choice_fn()

@current_choice.setter
def current_choice(self, choice: CHOICE_TYPE) -> None:
def current_choice(self, choice) -> None:
"""Setter of current choice.

Raises:
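A `DerivedMutable` never owns a value: `current_choice` is recomputed through `choice_fn`, a closure over the source mutables, which is why `fix_chosen` only warns that fixing has no effect. A rough illustration of the closure idea with plain Python objects (hypothetical names, not the mmrazor API):

```python
class ToySource:
    """Stand-in for a source mutable holding a channel count."""
    current_choice = 8


def derive_expanded(source, expand_ratio):
    # The derived choice is a function of the source, evaluated lazily.
    return lambda: source.current_choice * expand_ratio


src = ToySource()
choice_fn = derive_expanded(src, expand_ratio=4)
print(choice_fn())        # 32
src.current_choice = 16   # changing the source changes the derived value
print(choice_fn())        # 64 -- nothing to "fix" on the derived side
```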
27 changes: 10 additions & 17 deletions mmrazor/models/mutables/mutable_channel/base_mutable_channel.py
@@ -4,6 +4,7 @@

import torch

from mmrazor.utils.typing import DumpChosen
from ..base_mutable import BaseMutable
from ..derived_mutable import DerivedMethodMixin

@@ -34,20 +35,6 @@ def __init__(self, num_channels: int, **kwargs):
self.name = ''
self.num_channels = num_channels

# choice

@property # type: ignore
@abstractmethod
def current_choice(self):
"""get current choice."""
raise NotImplementedError()

@current_choice.setter # type: ignore
@abstractmethod
def current_choice(self):
"""set current choice."""
raise NotImplementedError()

@property # type: ignore
@abstractmethod
def current_mask(self) -> torch.Tensor:
@@ -73,9 +60,15 @@ def fix_chosen(self, chosen=None):

self.is_fixed = True

def dump_chosen(self):
"""dump current choice to a dict."""
raise NotImplementedError()
def dump_chosen(self) -> DumpChosen:
"""Dump chosen."""
meta = dict(max_channels=self.mask.size(0))
chosen = self.export_chosen()

return DumpChosen(chosen=chosen, meta=meta)

def export_chosen(self) -> int:
return self.activated_channels

def num_choices(self) -> int:
"""Number of available choices."""
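With the new `dump_chosen`, a channel mutable exports the number of activated channels as `chosen` and records the full width in `meta`. A rough illustration with a raw mask tensor, assuming the mask semantics shown above:

```python
import torch

mask = torch.tensor([1, 1, 1, 0, 0], dtype=torch.bool)  # 3 of 5 channels kept
chosen = int(mask.sum())                # exported choice -> 3
meta = dict(max_channels=mask.size(0))  # auxiliary info -> {'max_channels': 5}
print(chosen, meta)
```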
@@ -69,10 +69,6 @@ def fix_chosen(self, chosen=...):
self.current_choice = chosen
self.is_fixed = True

def dump_chosen(self):
"""Dump chosen."""
return self.current_choice

def __rmul__(self, other) -> DerivedMutable:
return self * other

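The retained `__rmul__` simply delegates to `__mul__`, so `2 * mutable` and `mutable * 2` build the same `DerivedMutable`. The delegation pattern in isolation (a toy class, not the real one):

```python
class ToyChannels:
    """Toy example of delegating __rmul__ to __mul__."""

    def __init__(self, num: int) -> None:
        self.num = num

    def __mul__(self, other: int) -> 'ToyChannels':
        return ToyChannels(self.num * other)

    def __rmul__(self, other: int) -> 'ToyChannels':
        return self * other  # 2 * c takes the same path as c * 2


print((2 * ToyChannels(8)).num)  # 16
```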
11 changes: 11 additions & 0 deletions mmrazor/models/mutables/mutable_channel/simple_mutable_channel.py
@@ -3,6 +3,7 @@
import torch

from mmrazor.registry import MODELS
from mmrazor.utils.typing import DumpChosen
from ..derived_mutable import DerivedMutable
from .base_mutable_channel import BaseMutableChannel

@@ -37,6 +38,16 @@ def current_mask(self) -> torch.Tensor:
"""Get current mask."""
return self.current_choice.bool()

def dump_chosen(self) -> DumpChosen:
"""Dump chosen."""
meta = dict(max_channels=self.mask.size(0))
chosen = self.export_chosen()

return DumpChosen(chosen=chosen, meta=meta)

def export_chosen(self) -> int:
return self.activated_channels

# basic extension

def expand_mutable_channel(self, expand_ratio: int) -> DerivedMutable:
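The section ends at `expand_mutable_channel`, which derives a wider mutable from an expand ratio. One plausible mask-expansion rule, sketched with a raw tensor; the method body is truncated here, so this is an assumption about the rule, not the PR's code:

```python
import torch

mask = torch.tensor([1, 0], dtype=torch.bool)
expand_ratio = 3
# Each source channel fans out to `expand_ratio` derived channels.
expanded = mask.repeat_interleave(expand_ratio)
print(expanded)  # tensor([ True,  True,  True, False, False, False])
```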
13 changes: 7 additions & 6 deletions mmrazor/models/mutables/mutable_channel/units/channel_unit.py
@@ -149,9 +149,10 @@ class ChannelUnit(BaseModule):

def __init__(self, num_channels: int, **kwargs):
super().__init__()
self.alias = None
self.num_channels = num_channels
self.output_related: nn.ModuleList = nn.ModuleList()
self.input_related: nn.ModuleList = nn.ModuleList()
self.output_related: List[nn.Module] = list()
self.input_related: List[nn.Module] = list()
self.init_args: Dict = {
} # is used to generate new channel unit with same args

@@ -208,14 +209,14 @@ def init_from_graph(cls,

def init_from_base_channel_unit(base_channel_unit: BaseChannelUnit):
unit = cls(len(base_channel_unit.channel_elems), **unit_args)
unit.input_related = nn.ModuleList([
unit.input_related = [
Channel.init_from_base_channel(channel)
for channel in base_channel_unit.input_related
])
unit.output_related = nn.ModuleList([
]
unit.output_related = [
Channel.init_from_base_channel(channel)
for channel in base_channel_unit.output_related
])
]
return unit

unit_graph = ChannelGraph.copy_from(graph,
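This hunk swaps `nn.ModuleList` for plain Python lists on `input_related`/`output_related`. One plausible motivation (my reading; the PR text does not state it) is that a plain list keeps the tracked `Channel` objects out of the unit's registered submodules, parameters, and `state_dict`. A small demonstration of that difference:

```python
import torch.nn as nn


class Holder(nn.Module):

    def __init__(self, use_modulelist: bool):
        super().__init__()
        layers = [nn.Conv2d(3, 3, 1)]
        # nn.ModuleList registers children with the module system;
        # a plain list leaves them untracked.
        self.related = nn.ModuleList(layers) if use_modulelist else layers


print(len(list(Holder(True).parameters())))   # 2 -- weight and bias registered
print(len(list(Holder(False).parameters())))  # 0 -- plain list is invisible
```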