From d190037a5e351c1880cd1929fb23a0fcc28d446e Mon Sep 17 00:00:00 2001 From: qiufeng <44188071+wutongshenqiu@users.noreply.github.com> Date: Fri, 19 Aug 2022 15:15:38 +0800 Subject: [PATCH] [Feature] Add dynamic bricks (#228) * add dynamic bricks * add dynamic conv2d test * add tests for dynamic linear and dynamic norm * add docstring for dynamic conv2d * add docstring for dynamic linear * add docstring for dynamic batchnorm * Refactor the dynamic op ( put more logic into the mixin ) * fix UT * Fix UT ( fileio was moved to mmengine) * derived mutable adds choices property * Unify the register interface of mutable in dynamic op * Unified getter interface of mutable in dynamic op Co-authored-by: gaojianfei Co-authored-by: pppppM --- .../engine/runner/evolution_search_loop.py | 12 +- mmrazor/engine/runner/subnet_sampler_loop.py | 4 +- .../algorithms/pruning/slimmable_network.py | 5 +- .../architectures/dynamic_op/__init__.py | 4 +- .../models/architectures/dynamic_op/base.py | 93 +++- .../dynamic_op/bricks/__init__.py | 14 + .../dynamic_op/bricks/dynamic_conv.py | 176 ++++++++ .../dynamic_op/bricks/dynamic_conv_mixins.py | 406 ++++++++++++++++++ .../dynamic_op/bricks/dynamic_linear.py | 53 +++ .../dynamic_op/bricks/dynamic_mixins.py | 399 +++++++++++++++++ .../dynamic_op/bricks/dynamic_norm.py | 130 ++++++ .../dynamic_op/default_dynamic_ops.py | 30 +- .../architectures/dynamic_op/head/__init__.py | 1 + mmrazor/models/mutables/__init__.py | 7 +- mmrazor/models/mutables/derived_mutable.py | 21 + .../mutable_channel/mutable_channel.py | 9 +- .../slimmable_mutable_channel.py | 8 +- .../models/mutables/mutable_manage_mixin.py | 9 - mmrazor/structures/subnet/fix_subnet.py | 16 +- .../test_algorithms/test_slimmable_network.py | 2 +- .../test_dynamic_op/__init__.py | 1 + .../test_dynamic_op/test_bricks/__init__.py | 1 + .../test_bricks/test_dynamic_conv.py | 266 ++++++++++++ .../test_bricks/test_dynamic_linear.py | 115 +++++ .../test_bricks/test_dynamic_norm.py | 119 +++++ .../test_default_dynamic_op.py | 13 +- .../test_dynamic_op/utils.py | 15 + tests/test_models/test_mutators/utils.py | 2 +- .../test_evolution_search_loop.py | 4 +- 29 files changed, 1876 insertions(+), 59 deletions(-) create mode 100644 mmrazor/models/architectures/dynamic_op/bricks/__init__.py create mode 100644 mmrazor/models/architectures/dynamic_op/bricks/dynamic_conv.py create mode 100644 mmrazor/models/architectures/dynamic_op/bricks/dynamic_conv_mixins.py create mode 100644 mmrazor/models/architectures/dynamic_op/bricks/dynamic_linear.py create mode 100644 mmrazor/models/architectures/dynamic_op/bricks/dynamic_mixins.py create mode 100644 mmrazor/models/architectures/dynamic_op/bricks/dynamic_norm.py create mode 100644 mmrazor/models/architectures/dynamic_op/head/__init__.py delete mode 100644 mmrazor/models/mutables/mutable_manage_mixin.py create mode 100644 tests/test_models/test_architectures/test_dynamic_op/__init__.py create mode 100644 tests/test_models/test_architectures/test_dynamic_op/test_bricks/__init__.py create mode 100644 tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py create mode 100644 tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_linear.py create mode 100644 tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py create mode 100644 tests/test_models/test_architectures/test_dynamic_op/utils.py diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py index 
e9a29a6db..a6636c238 100644 --- a/mmrazor/engine/runner/evolution_search_loop.py +++ b/mmrazor/engine/runner/evolution_search_loop.py @@ -5,8 +5,8 @@ import warnings from typing import Dict, List, Optional, Tuple, Union -import mmcv import torch +from mmengine import fileio from mmengine.dist import broadcast_object_list from mmengine.evaluator import Evaluator from mmengine.runner import EpochBasedTrainLoop @@ -90,7 +90,7 @@ def __init__(self, if init_candidates is None: self.candidates = Candidates() else: - self.candidates = mmcv.fileio.load(init_candidates) + self.candidates = fileio.load(init_candidates) assert isinstance(self.candidates, Candidates), 'please use the \ correct init candidates file' @@ -228,7 +228,7 @@ def _crossover(self) -> SupportRandomSubnet: def _resume(self): """Resume searching.""" if self.runner.rank == 0: - searcher_resume = mmcv.fileio.load(self.resume_from) + searcher_resume = fileio.load(self.resume_from) for k in searcher_resume.keys(): setattr(self, k, searcher_resume[k]) epoch_start = int(searcher_resume['_epoch']) @@ -244,8 +244,8 @@ def _save_best_fix_subnet(self): self.model.set_subnet(best_random_subnet) best_fix_subnet = export_fix_subnet(self.model) save_name = 'best_fix_subnet.yaml' - mmcv.fileio.dump(best_fix_subnet, - osp.join(self.runner.work_dir, save_name)) + fileio.dump(best_fix_subnet, + osp.join(self.runner.work_dir, save_name)) self.runner.logger.info( 'Search finished and ' f'{save_name} saved in {self.runner.work_dir}.') @@ -271,7 +271,7 @@ def _save_searcher_ckpt(self) -> None: save_for_resume['_epoch'] = self.runner.epoch for k in ['candidates', 'top_k_candidates']: save_for_resume[k] = getattr(self, k) - mmcv.fileio.dump( + fileio.dump( save_for_resume, osp.join(self.runner.work_dir, f'search_epoch_{self.runner.epoch}.pkl')) diff --git a/mmrazor/engine/runner/subnet_sampler_loop.py b/mmrazor/engine/runner/subnet_sampler_loop.py index 603cd04b9..b6cd3be59 100644 --- a/mmrazor/engine/runner/subnet_sampler_loop.py +++ b/mmrazor/engine/runner/subnet_sampler_loop.py @@ -5,8 +5,8 @@ from abc import abstractmethod from typing import Dict, List, Optional, Sequence, Tuple, Union -import mmcv import torch +from mmengine import fileio from mmengine.evaluator import Evaluator from mmengine.runner import IterBasedTrainLoop from mmengine.utils import is_list_of @@ -326,6 +326,6 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool: def _save_candidates(self) -> None: """Save the candidates to init the next searching.""" save_path = os.path.join(self.runner.work_dir, 'candidates.pkl') - mmcv.fileio.dump(self.candidates, save_path) + fileio.dump(self.candidates, save_path) self.runner.logger.info(f'candidates.pkl saved in ' f'{self.runner.work_dir}') diff --git a/mmrazor/models/algorithms/pruning/slimmable_network.py b/mmrazor/models/algorithms/pruning/slimmable_network.py index faf30649a..b8cebf1da 100644 --- a/mmrazor/models/algorithms/pruning/slimmable_network.py +++ b/mmrazor/models/algorithms/pruning/slimmable_network.py @@ -4,9 +4,8 @@ from pathlib import Path from typing import Dict, List, Optional, Union -import mmcv import torch -from mmengine import BaseDataElement +from mmengine import BaseDataElement, fileio from mmengine.model import BaseModel, MMDistributedDataParallel from mmengine.optim import OptimWrapper from torch import nn @@ -86,7 +85,7 @@ def _load_and_merge_channel_cfgs( """Load and merge channel config.""" channel_cfgs = list() for channel_cfg_path in channel_cfg_paths: - channel_cfg = 
mmcv.fileio.load(channel_cfg_path) + channel_cfg = fileio.load(channel_cfg_path) channel_cfgs.append(channel_cfg) return self.merge_channel_cfgs(channel_cfgs) diff --git a/mmrazor/models/architectures/dynamic_op/__init__.py b/mmrazor/models/architectures/dynamic_op/__init__.py index b6fa0ee43..6b5796688 100644 --- a/mmrazor/models/architectures/dynamic_op/__init__.py +++ b/mmrazor/models/architectures/dynamic_op/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .base import DynamicOP from .default_dynamic_ops import (DynamicBatchNorm, DynamicConv2d, DynamicGroupNorm, DynamicInstanceNorm, DynamicLinear) @@ -6,5 +7,6 @@ __all__ = [ 'DynamicConv2d', 'DynamicLinear', 'DynamicBatchNorm', - 'DynamicInstanceNorm', 'DynamicGroupNorm', 'SwitchableBatchNorm2d' + 'DynamicInstanceNorm', 'DynamicGroupNorm', 'SwitchableBatchNorm2d', + 'DynamicOP' ] diff --git a/mmrazor/models/architectures/dynamic_op/base.py b/mmrazor/models/architectures/dynamic_op/base.py index b9d9ea56f..2a1720ea2 100644 --- a/mmrazor/models/architectures/dynamic_op/base.py +++ b/mmrazor/models/architectures/dynamic_op/base.py @@ -1,25 +1,106 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod +from typing import Any, Optional, Set from torch import nn -from mmrazor.models.mutables.mutable_channel import MutableChannel +from mmrazor.models.mutables.base_mutable import BaseMutable class DynamicOP(ABC): + """Base class for dynamic OP. A dynamic OP usually consists of a normal + static OP and mutables, where mutables are used to control the searchable + (mutable) part of the dynamic OP. + + Note: + When the dynamic OP has just been initialized, its forward propagation + logic should be the same as the corresponding static OP. Only after + the searchable part accepts the specific mutable through the + corresponding interface does the part really become dynamic. + + Note: + All subclass should implement ``to_static_op`` API. + + Args: + accepted_mutables (set): The string set of all accepted mutables. + """ + accepted_mutables: Set[str] = set() @abstractmethod def to_static_op(self) -> nn.Module: - ... + """Convert dynamic OP to static OP. + + Note: + The forward result for the same input between dynamic OP and its + corresponding static OP must be same. + + Returns: + nn.Module: Corresponding static OP. + """ + + def check_if_mutables_fixed(self) -> None: + """Check if all mutables are fixed. + + Raises: + RuntimeError: Error if a existing mutable is not fixed. + """ + + def check_fixed(mutable: Optional[BaseMutable]) -> None: + if mutable is not None and not mutable.is_fixed: + raise RuntimeError(f'Mutable {type(mutable)} is not fixed.') + + for mutable in self.accepted_mutables: + check_fixed(getattr(self, f'{mutable}')) + + @staticmethod + def get_current_choice(mutable: BaseMutable) -> Any: + """Get current choice of given mutable. + + Args: + mutable (BaseMutable): Given mutable. + + Raises: + RuntimeError: Error if `current_choice` is None. + + Returns: + Any: Current choice of given mutable. + """ + current_choice = mutable.current_choice + if current_choice is None: + raise RuntimeError(f'current choice of mutable {type(mutable)} ' + 'can not be None at runtime') + + return current_choice class ChannelDynamicOP(DynamicOP): + """Base class for dynamic OP with mutable channels. + + Note: + All subclass should implement ``mutable_in`` and ``mutable_out`` APIs. + """ @property @abstractmethod - def mutable_in(self) -> MutableChannel: - ... 
+ def mutable_in(self) -> Optional[BaseMutable]: + """Mutable related to input.""" @property - def mutable_out(self) -> MutableChannel: - ... + @abstractmethod + def mutable_out(self) -> Optional[BaseMutable]: + """Mutable related to output.""" + + @staticmethod + def check_mutable_channels(mutable_channels: BaseMutable) -> None: + """Check if mutable has `currnet_mask` attribute. + + Args: + mutable_channels (BaseMutable): Mutable to be checked. + + Raises: + ValueError: Error if mutable does not have `current_mask` + attribute. + """ + if not hasattr(mutable_channels, 'current_mask'): + raise ValueError( + 'channel mutable must have attribute `current_mask`') diff --git a/mmrazor/models/architectures/dynamic_op/bricks/__init__.py b/mmrazor/models/architectures/dynamic_op/bricks/__init__.py new file mode 100644 index 000000000..7eaad52fa --- /dev/null +++ b/mmrazor/models/architectures/dynamic_op/bricks/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d +from .dynamic_linear import DynamicLinear +from .dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin, + DynamicLinearMixin, DynamicMixin) +from .dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d, + DynamicBatchNorm3d) + +__all__ = [ + 'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear', + 'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d', + 'DynamicMixin', 'DynamicChannelMixin', 'DynamicBatchNormMixin', + 'DynamicLinearMixin' +] diff --git a/mmrazor/models/architectures/dynamic_op/bricks/dynamic_conv.py b/mmrazor/models/architectures/dynamic_op/bricks/dynamic_conv.py new file mode 100644 index 000000000..66fcc5981 --- /dev/null +++ b/mmrazor/models/architectures/dynamic_op/bricks/dynamic_conv.py @@ -0,0 +1,176 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, Dict + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks.registry import CONV_LAYERS +from torch import Tensor + +from mmrazor.models.mutables.base_mutable import BaseMutable +from mmrazor.registry import MODELS +from .dynamic_conv_mixins import (BigNasConvMixin, DynamicConvMixin, + OFAConvMixin) + + +@CONV_LAYERS.register_module() +@MODELS.register_module() +class DynamicConv2d(nn.Conv2d, DynamicConvMixin): + """Dynamic Conv2d OP. + + Note: + Arguments for ``__init__`` of ``DynamicConv2d`` is totally same as + :obj:`torch.nn.Conv2d`. + + Attributes: + mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes, + such as `in_channels`. The key of the dict must in + ``accepted_mutable_attrs``. 
+ """ + accepted_mutable_attrs = {'in_channels', 'out_channels'} + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # TODO + # https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv2d + assert self.padding_mode == 'zeros' + self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict() + + @classmethod + def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d': + """Convert an instance of nn.Conv2d to a new instance of + DynamicConv2d.""" + return cls( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=True if module.bias is not None else False, + padding_mode=module.padding_mode) + + @property + def conv_func(self) -> Callable: + """The function that will be used in ``forward_mixin``.""" + return F.conv2d + + @property + def static_op_factory(self): + """Corresponding Pytorch OP.""" + return nn.Conv2d + + def forward(self, x: Tensor) -> Tensor: + """Forward of dynamic conv2d OP.""" + return self.forward_mixin(x) + + +@CONV_LAYERS.register_module() +@MODELS.register_module() +class BigNasConv2d(nn.Conv2d, BigNasConvMixin): + """Conv2d used in BigNas. + + Note: + Arguments for ``__init__`` of ``DynamicConv2d`` is totally same as + :obj:`torch.nn.Conv2d`. + + Attributes: + mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes, + such as `in_channels`. The key of the dict must in + ``accepted_mutable_attrs``. + """ + accepted_mutable_attrs = {'in_channels', 'out_channels', 'kernel_size'} + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # TODO + # https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv2d + assert self.padding_mode == 'zeros' + self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict() + + @classmethod + def convert_from(cls, module: nn.Conv2d) -> 'BigNasConv2d': + """Convert an instance of `nn.Conv2d` to a new instance of + `BigNasConv2d`.""" + return cls( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=True if module.bias is not None else False, + padding_mode=module.padding_mode) + + @property + def conv_func(self) -> Callable: + """The function that will be used in ``forward_mixin``.""" + return F.conv2d + + @property + def static_op_factory(self): + """Corresponding Pytorch OP.""" + return nn.Conv2d + + def forward(self, x: Tensor) -> Tensor: + """Forward of bignas' conv2d.""" + return self.forward_mixin(x) + + +@CONV_LAYERS.register_module() +@MODELS.register_module() +class OFAConv2d(nn.Conv2d, OFAConvMixin): + """Conv2d used in `Once-for-All`. + + Refers to `Once-for-All: Train One Network and Specialize it for Efficient + Deployment `_. + """ + """Dynamic Conv2d OP. + + Note: + Arguments for ``__init__`` of ``OFAConv2d`` is totally same as + :obj:`torch.nn.Conv2d`. + + Attributes: + mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes, + such as `in_channels`. The key of the dict must in + ``accepted_mutable_attrs``. 
+ """ + accepted_mutable_attrs = {'in_channels', 'out_channels'} + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # TODO + # https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv2d + assert self.padding_mode == 'zeros' + self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict() + + @classmethod + def convert_from(cls, module: nn.Conv2d) -> 'OFAConv2d': + """Convert an instance of `nn.Conv2d` to a new instance of + `OFAConv2d`.""" + return cls( + in_channels=module.in_channels, + out_channels=module.out_channels, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=True if module.bias is not None else False, + padding_mode=module.padding_mode) + + @property + def conv_func(self) -> Callable: + """The function that will be used in ``forward_mixin``.""" + return F.conv2d + + @property + def static_op_factory(self): + """Corresponding Pytorch OP.""" + return nn.Conv2d + + def forward(self, x: Tensor) -> Tensor: + """Forward of OFA's conv2d.""" + return self.forward_mixin(x) diff --git a/mmrazor/models/architectures/dynamic_op/bricks/dynamic_conv_mixins.py b/mmrazor/models/architectures/dynamic_op/bricks/dynamic_conv_mixins.py new file mode 100644 index 000000000..e3ed46ded --- /dev/null +++ b/mmrazor/models/architectures/dynamic_op/bricks/dynamic_conv_mixins.py @@ -0,0 +1,406 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from itertools import repeat +from typing import Callable, Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.modules.conv import _ConvNd + +from mmrazor.models.mutables.base_mutable import BaseMutable +from .dynamic_mixins import DynamicChannelMixin + + +def _ntuple(n: int) -> Callable: # pragma: no cover + """Repeat a number n times.""" + + def parse(x): + if isinstance(x, Iterable): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def _get_current_kernel_pos(source_kernel_size: int, + target_kernel_size: int) -> Tuple[int, int]: + """Get position of current kernel size. + + Returns: + Tuple[int, int]: (upper left position, bottom right position) + """ + assert source_kernel_size >= target_kernel_size, \ + '`source_kernel_size` must greater or equal than `target_kernel_size`' + + center = source_kernel_size >> 1 + current_offset = target_kernel_size >> 1 + + start_offset = center - current_offset + end_offset = center + current_offset + 1 + + return start_offset, end_offset + + +def _get_same_padding(kernel_size: int, n_dims: int) -> Tuple[int]: + """Get same padding according to kernel size.""" + assert kernel_size & 1 + _pair = _ntuple(n_dims) + + return _pair(kernel_size >> 1) + + +class DynamicConvMixin(DynamicChannelMixin): + """A mixin class for Pytorch conv, which can mutate ``in_channels`` and + ``out_channels``. + + Note: + All subclass should implement ``conv_func``API. + """ + + @property + @abstractmethod + def conv_func(self: _ConvNd): + """The function that will be used in ``forward_mixin``.""" + pass + + def register_mutable_attr(self, attr, mutable): + + if attr == 'in_channels': + self._register_mutable_in_channels(mutable) + elif attr == 'out_channels': + self._register_mutable_out_channels(mutable) + else: + raise NotImplementedError + + def _register_mutable_in_channels( + self: _ConvNd, mutable_in_channels: BaseMutable) -> None: + """Mutate ``in_channels`` with given mutable. 
+ + Args: + mutable_in_channels (BaseMutable): Mutable for controlling + ``in_channels``. + + Raises: + ValueError: Error if size of mask if not same as ``in_channels``. + """ + assert hasattr(self, 'mutable_attrs') + self.check_mutable_channels(mutable_in_channels) + mask_size = mutable_in_channels.current_mask.size(0) + if mask_size != self.in_channels: + raise ValueError( + f'Expect mask size of mutable to be {self.in_channels} as ' + f'`in_channels`, but got: {mask_size}.') + + self.mutable_attrs['in_channels'] = mutable_in_channels + + def _register_mutable_out_channels( + self: _ConvNd, mutable_out_channels: BaseMutable) -> None: + """Mutate ``out_channels`` with given mutable. + + Args: + mutable_out_channels (BaseMutable): Mutable for controlling + ``out_channels``. + + Raises: + ValueError: Error if size of mask if not same as ``out_channels``. + """ + assert hasattr(self, 'mutable_attrs') + self.check_mutable_channels(mutable_out_channels) + mask_size = mutable_out_channels.current_mask.size(0) + if mask_size != self.out_channels: + raise ValueError( + f'Expect mask size of mutable to be {self.out_channels} as ' + f'`out_channels`, but got: {mask_size}.') + + self.mutable_attrs['out_channels'] = mutable_out_channels + + @property + def mutable_in_channels(self: _ConvNd) -> Optional[BaseMutable]: + """Mutable related to input.""" + assert hasattr(self, 'mutable_attrs') + return getattr(self.mutable_attrs, 'in_channels', None) # type:ignore + + @property + def mutable_out_channels(self: _ConvNd) -> Optional[BaseMutable]: + """Mutable related to output.""" + assert hasattr(self, 'mutable_attrs') + return getattr(self.mutable_attrs, 'out_channels', None) # type:ignore + + def get_dynamic_params( + self: _ConvNd) -> Tuple[Tensor, Optional[Tensor], Tuple[int]]: + """Get dynamic parameters that will be used in forward process. + + Returns: + Tuple[Tensor, Optional[Tensor], Tuple[int]]: Sliced weight, bias + and padding. + """ + # slice in/out channel of weight according to + # mutable in_channels/out_channels + weight, bias = self._get_dynamic_params_by_mutable_channels( + self.weight, self.bias) + return weight, bias, self.padding + + def _get_dynamic_params_by_mutable_channels( + self: _ConvNd, weight: Tensor, + bias: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]: + """Get sliced weight and bias according to ``mutable_in_channels`` and + ``mutable_out_channels``. + + Returns: + Tuple[Tensor, Optional[Tensor]]: Sliced weight and bias. 
+ """ + if 'in_channels' not in self.mutable_attrs and \ + 'out_channels' not in self.mutable_attrs: + return weight, bias + + if 'in_channels' in self.mutable_attrs: + mutable_in_channels = self.mutable_attrs['in_channels'] + in_mask = mutable_in_channels.current_mask.to(weight.device) + else: + in_mask = torch.ones(weight.size(1)).bool().to(weight.device) + + if 'out_channels' in self.mutable_attrs: + mutable_out_channels = self.mutable_attrs['out_channels'] + out_mask = mutable_out_channels.current_mask.to(weight.device) + else: + out_mask = torch.ones(weight.size(0)).bool().to(weight.device) + + if self.groups == 1: + weight = weight[out_mask][:, in_mask] + elif self.groups == self.in_channels == self.out_channels: + # depth-wise conv + weight = weight[out_mask] + else: + raise NotImplementedError( + 'Current `ChannelMutator` only support pruning the depth-wise ' + '`nn.Conv2d` or `nn.Conv2d` module whose group number equals ' + f'to one, but got {self.groups}.') + bias = self.bias[out_mask] if self.bias is not None else None + return weight, bias + + def forward_mixin(self: _ConvNd, x: Tensor) -> Tensor: + """Forward of dynamic conv2d OP.""" + groups = self.groups + if self.groups == self.in_channels == self.out_channels: + groups = x.size(1) + weight, bias, padding = self.get_dynamic_params() + + return self.conv_func(x, weight, bias, self.stride, padding, + self.dilation, groups) + + def to_static_op(self: _ConvNd) -> nn.Conv2d: + """Convert dynamic conv2d to :obj:`torch.nn.Conv2d`. + + Returns: + torch.nn.Conv2d: :obj:`torch.nn.Conv2d` with sliced parameters. + """ + self.check_if_mutables_fixed() + + weight, bias, padding = self.get_dynamic_params() + groups = self.groups + if groups == self.in_channels == self.out_channels and \ + self.mutable_in_channels is not None: + mutable_in_channels = self.mutable_attrs['in_channels'] + groups = mutable_in_channels.current_mask.sum().item() + out_channels = weight.size(0) + in_channels = weight.size(1) * groups + + kernel_size = tuple(weight.shape[2:]) + + static_conv = self.static_op_factory( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=self.stride, + padding=padding, + padding_mode=self.padding_mode, + dilation=self.dilation, + groups=groups, + bias=True if bias is not None else False) + + static_conv.weight = nn.Parameter(weight) + if bias is not None: + static_conv.bias = nn.Parameter(bias) + + return static_conv + + +class BigNasConvMixin(DynamicConvMixin): + """A mixin class for Pytorch conv, which can mutate ``in_channels``, + ``out_channels`` and ``kernel_size``.""" + + def register_mutable_attr(self, attr, mutable): + + if attr == 'in_channels': + self._register_mutable_in_channels(mutable) + elif attr == 'out_channels': + self._register_mutable_out_channels(mutable) + elif attr == 'kernel_size': + self._register_mutable_kernel_size(mutable) + else: + raise NotImplementedError + + def _register_mutable_kernel_size( + self: _ConvNd, mutable_kernel_size: BaseMutable) -> None: + """Mutate ``kernel_size`` with given mutable. + + Args: + mutable_kernel_size (BaseMutable): Mutable for controlling + ``kernel_size``. + + Note: + ``kernel_size_seq`` must be provided if ``mutable_kernel_size`` + does not have ``choices`` attribute. + + Raises: + ValueError: Error if max choice of ``kernel_size_list`` + not same as ``kernel_size``. 
+ """ + + kernel_size_seq = getattr(mutable_kernel_size, 'choices', None) + if kernel_size_seq is None or len(kernel_size_seq) == 0: + raise ValueError('kernel size sequence must be provided') + kernel_size_list = list(sorted(kernel_size_seq)) + + _pair = _ntuple(len(self.weight.shape) - 2) + max_kernel_size = _pair(kernel_size_list[-1]) + if max_kernel_size != self.kernel_size: + raise ValueError( + f'Expect max kernel size to be: {self.kernel_size}, ' + f'but got: {max_kernel_size}') + + self.kernel_size_list = kernel_size_list + self.mutable_attrs['kernel_size'] = mutable_kernel_size + + def get_dynamic_params( + self: _ConvNd) -> Tuple[Tensor, Optional[Tensor], Tuple[int]]: + """Get dynamic parameters that will be used in forward process. + + Returns: + Tuple[Tensor, Optional[Tensor], Tuple[int]]: Sliced weight, bias + and padding. + """ + # 1. slice kernel size of weight according to kernel size mutable + weight, padding = self._get_dynamic_params_by_mutable_kernel_size( + self.weight) + + # 2. slice in/out channel of weight according to mutable in_channels + # and mutable out channels. + weight, bias = self._get_dynamic_params_by_mutable_channels( + weight, self.bias) + return weight, bias, padding + + def _get_dynamic_params_by_mutable_kernel_size( + self: _ConvNd, weight: Tensor) -> Tuple[Tensor, Tuple[int]]: + """Get sliced weight and bias according to ``mutable_in_channels`` and + ``mutable_out_channels``.""" + if 'kernel_size' not in self.mutable_attrs \ + or self.kernel_size_list is None: + return weight, self.padding + + mutable_kernel_size = self.mutable_attrs['kernel_size'] + current_kernel_size = self.get_current_choice(mutable_kernel_size) + + n_dims = len(self.weight.shape) - 2 + current_padding = _get_same_padding(current_kernel_size, n_dims) + + _pair = _ntuple(len(self.weight.shape) - 2) + if _pair(current_kernel_size) == self.kernel_size: + return weight, current_padding + + start_offset, end_offset = _get_current_kernel_pos( + source_kernel_size=self.kernel_size[0], + target_kernel_size=current_kernel_size) + current_weight = \ + weight[:, :, start_offset:end_offset, start_offset:end_offset] + + return current_weight, current_padding + + def forward_mixin(self: _ConvNd, x: Tensor) -> Tensor: + """Forward of dynamic conv2d OP.""" + groups = self.groups + if self.groups == self.in_channels == self.out_channels: + groups = x.size(1) + weight, bias, padding = self.get_dynamic_params() + + return self.conv_func(x, weight, bias, self.stride, padding, + self.dilation, groups) + + +class OFAConvMixin(BigNasConvMixin): + """A mixin class for Pytorch conv, which can mutate ``in_channels``, + ``out_channels`` and ``kernel_size``.""" + + def _register_mutable_kernel_size( + self: _ConvNd, mutable_kernel_size: BaseMutable) -> None: + """Mutate ``kernel_size`` with given mutable and register + transformation matrix.""" + super()._register_mutable_kernel_size(mutable_kernel_size) + self._register_trans_matrix() + + def _register_trans_matrix(self: _ConvNd) -> None: + """Register transformation matrix that used in progressive + shrinking.""" + assert self.kernel_size_list is not None + + trans_matrix_names = [] + for i in range(len(self.kernel_size_list) - 1, 0, -1): + source_kernel_size = self.kernel_size_list[i] + target_kernel_size = self.kernel_size_list[i - 1] + trans_matrix_name = self._get_trans_matrix_name( + src=source_kernel_size, tar=target_kernel_size) + trans_matrix_names.append(trans_matrix_name) + # TODO support conv1d & conv3d + trans_matrix = 
nn.Parameter(torch.eye(target_kernel_size**2)) + self.register_parameter(name=trans_matrix_name, param=trans_matrix) + self._trans_matrix_names = trans_matrix_names + + @staticmethod + def _get_trans_matrix_name(src: int, tar: int) -> str: + """Get name of trans matrix.""" + return f'trans_matrix_{src}to{tar}' + + def _get_dynamic_params_by_mutable_kernel_size( + self: _ConvNd, weight: Tensor) -> Tuple[Tensor, Tuple[int]]: + """Get sliced weight and bias according to ``mutable_in_channels`` and + ``mutable_out_channels``.""" + + if 'kernel_size' not in self.mutable_attrs: + return weight, self.padding + + mutable_kernel_size = self.mutable_attrs['kernel_size'] + current_kernel_size = self.get_current_choice(mutable_kernel_size) + + n_dims = len(self.weight.shape) - 2 + current_padding = _get_same_padding(current_kernel_size, n_dims) + + _pair = _ntuple(len(self.weight.shape) - 2) + if _pair(current_kernel_size) == self.kernel_size: + return weight, current_padding + + current_weight = weight[:, :, :, :] + for i in range(len(self.kernel_size_list) - 1, 0, -1): + source_kernel_size = self.kernel_size_list[i] + if source_kernel_size <= current_kernel_size: + break + target_kernel_size = self.kernel_size_list[i - 1] + trans_matrix = getattr( + self, + self._get_trans_matrix_name( + src=source_kernel_size, tar=target_kernel_size)) + + start_offset, end_offset = _get_current_kernel_pos( + source_kernel_size=source_kernel_size, + target_kernel_size=target_kernel_size) + target_weight = current_weight[:, :, start_offset:end_offset, + start_offset:end_offset] + target_weight = target_weight.reshape(-1, target_kernel_size**2) + target_weight = F.linear(target_weight, trans_matrix) + target_weight = target_weight.reshape( + weight.size(0), weight.size(1), target_kernel_size, + target_kernel_size) + + current_weight = target_weight + + return current_weight, current_padding diff --git a/mmrazor/models/architectures/dynamic_op/bricks/dynamic_linear.py b/mmrazor/models/architectures/dynamic_op/bricks/dynamic_linear.py new file mode 100644 index 000000000..aa7bcbccc --- /dev/null +++ b/mmrazor/models/architectures/dynamic_op/bricks/dynamic_linear.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmrazor.models.mutables.base_mutable import BaseMutable +from .dynamic_mixins import DynamicLinearMixin + + +class DynamicLinear(nn.Linear, DynamicLinearMixin): + """Dynamic Linear OP. + + Note: + Arguments for ``__init__`` of ``DynamicLinear`` is totally same as + :obj:`torch.nn.Linear`. + + Attributes: + mutable_in_features (BaseMutable, optional): Mutable for controlling + ``in_features``. + mutable_out_features (BaseMutable, optional): Mutable for controlling + ``out_features``. + """ + accepted_mutable_attrs = {'in_features', 'out_features'} + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict() + + @property + def static_op_factory(self): + return nn.Linear + + @classmethod + def convert_from(cls, module): + """Convert a nn.Linear module to a DynamicLinear. + + Args: + module (:obj:`torch.nn.Linear`): The original Linear module. 
+ """ + dynamic_linear = cls( + in_features=module.in_features, + out_features=module.out_features, + bias=True if module.bias is not None else False) + return dynamic_linear + + def forward(self, input: Tensor) -> Tensor: + """Forward of dynamic linear OP.""" + weight, bias = self.get_dynamic_params() + + return F.linear(input, weight, bias) diff --git a/mmrazor/models/architectures/dynamic_op/bricks/dynamic_mixins.py b/mmrazor/models/architectures/dynamic_op/bricks/dynamic_mixins.py new file mode 100644 index 000000000..4837bb12b --- /dev/null +++ b/mmrazor/models/architectures/dynamic_op/bricks/dynamic_mixins.py @@ -0,0 +1,399 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Set, Tuple + +import torch +from mmengine import print_log +from torch import Tensor, nn +from torch.nn.modules.batchnorm import _BatchNorm + +from mmrazor.models.mutables.base_mutable import BaseMutable + + +class DynamicMixin(ABC): + """Base class for dynamic OP. A dynamic OP usually consists of a normal + static OP and mutables, where mutables are used to control the searchable + (mutable) part of the dynamic OP. + + Note: + When the dynamic OP has just been initialized, its forward propagation + logic should be the same as the corresponding static OP. Only after + the searchable part accepts the specific mutable through the + corresponding interface does the part really become dynamic. + + Note: + All subclass should implement ``to_static_op`` and + ``static_op_factory`` APIs. + + Args: + accepted_mutables (set): The string set of all accepted mutables. + """ + accepted_mutable_attrs: Set[str] = set() + attr_mappings: Dict[str, str] = dict() + + @abstractmethod + def register_mutable_attr(self, attr: str, mutable: BaseMutable): + pass + + def get_mutable_attr(self, attr: str) -> BaseMutable: + + self.check_mutable_attr_valid(attr) + if attr in self.attr_mappings: + attr_map = self.attr_mappings[attr] + return getattr(self.mutable_attrs, attr_map, None) # type:ignore + else: + return getattr(self.mutable_attrs, attr, None) # type:ignore + + @classmethod + @abstractmethod + def convert_from(cls, module): + """Convert an instance of Pytorch module to a new instance of Dynamic + module.""" + + @property + @abstractmethod + def static_op_factory(self): + """Corresponding Pytorch OP.""" + + @abstractmethod + def to_static_op(self) -> nn.Module: + """Convert dynamic OP to static OP. + + Note: + The forward result for the same input between dynamic OP and its + corresponding static OP must be same. + + Returns: + nn.Module: Corresponding static OP. + """ + + def check_if_mutables_fixed(self) -> None: + """Check if all mutables are fixed. + + Raises: + RuntimeError: Error if a existing mutable is not fixed. + """ + + def check_fixed(mutable: Optional[BaseMutable]) -> None: + if mutable is not None and not mutable.is_fixed: + raise RuntimeError(f'Mutable {type(mutable)} is not fixed.') + + for mutable in self.mutable_attrs.values(): # type: ignore + check_fixed(mutable) + + def check_mutable_attr_valid(self, attr): + assert attr in self.attr_mappings or \ + attr in self.accepted_mutable_attrs + + @staticmethod + def get_current_choice(mutable: BaseMutable) -> Any: + """Get current choice of given mutable. + + Args: + mutable (BaseMutable): Given mutable. + + Raises: + RuntimeError: Error if `current_choice` is None. + + Returns: + Any: Current choice of given mutable. 
+ """ + current_choice = mutable.current_choice + if current_choice is None: + raise RuntimeError(f'current choice of mutable {type(mutable)} ' + 'can not be None at runtime') + + return current_choice + + +class DynamicChannelMixin(DynamicMixin): + """Base class for dynamic OP with mutable channels. + + Note: + All subclass should implement ``mutable_in_channels`` and + ``mutable_out_channels`` APIs. + """ + + @staticmethod + def check_mutable_channels(mutable_channels: BaseMutable) -> None: + """Check if mutable has `currnet_mask` attribute. + + Args: + mutable_channels (BaseMutable): Mutable to be checked. + + Raises: + ValueError: Error if mutable does not have `current_mask` + attribute. + """ + if not hasattr(mutable_channels, 'current_mask'): + raise ValueError( + 'channel mutable must have attribute `current_mask`') + + +class DynamicBatchNormMixin(DynamicChannelMixin): + """A mixin class for Pytorch BatchNorm, which can mutate + ``num_features``.""" + accepted_mutable_attrs: Set[str] = {'num_features'} + attr_mappings: Dict[str, str] = { + 'in_channels': 'num_features', + 'out_channels': 'num_features', + } + + def register_mutable_attr(self, attr, mutable): + self.check_mutable_attr_valid(attr) + if attr in self.attr_mappings: + attr_map = self.attr_mappings[attr] + assert attr_map in self.accepted_mutable_attrs + if attr_map in self.mutable_attrs: + print_log( + f'{attr_map}({attr}) is already in `mutable_attrs`', + level=logging.WARNING) + else: + self._register_mutable_attr(attr_map, mutable) + elif attr in self.accepted_mutable_attrs: + self._register_mutable_attr(attr, mutable) + else: + raise NotImplementedError + + def _register_mutable_attr(self, attr, mutable): + + if attr == 'num_features': + self._register_mutable_num_features(mutable) + else: + raise NotImplementedError + + def _register_mutable_num_features( + self: _BatchNorm, mutable_num_features: BaseMutable) -> None: + """Mutate ``num_features`` with given mutable. + + Args: + mutable_num_features (BaseMutable): Mutable for controlling + ``num_features``. + + Raises: + RuntimeError: Error if both ``affine`` and + ``tracking_running_stats`` are False. + ValueError: Error if size of mask if not same as ``num_features``. + """ + if not self.affine and not self.track_running_stats: + raise RuntimeError( + 'num_features can not be mutated if both `affine` and ' + '`tracking_running_stats` are False') + + self.check_mutable_channels(mutable_num_features) + mask_size = mutable_num_features.current_mask.size(0) + if mask_size != self.num_features: + raise ValueError( + f'Expect mask size of mutable to be {self.num_features} as ' + f'`num_features`, but got: {mask_size}.') + + self.mutable_attrs['num_features'] = mutable_num_features + + def _get_num_features_mask(self: _BatchNorm) -> Optional[torch.Tensor]: + """Get mask of ``num_features``""" + if self.affine: + refer_tensor = self.weight + elif self.track_running_stats: + refer_tensor = self.running_mean + else: + return None + + if 'num_features' in self.mutable_attrs: + out_mask = self.mutable_attrs['num_features'].current_mask.to( + refer_tensor.device) + else: + out_mask = torch.ones_like(refer_tensor).bool() + + return out_mask + + def get_dynamic_params( + self: _BatchNorm + ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], + Optional[Tensor]]: + """Get dynamic parameters that will be used in forward process. 
+ + Returns: + Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], + Optional[Tensor]]: Sliced running_mean, running_var, weight and + bias. + """ + out_mask = self._get_num_features_mask() + + if self.affine: + weight = self.weight[out_mask] + bias = self.bias[out_mask] + else: + weight, bias = self.weight, self.bias + + if self.track_running_stats: + running_mean = self.running_mean[out_mask] \ + if not self.training or self.track_running_stats else None + running_var = self.running_var[out_mask] \ + if not self.training or self.track_running_stats else None + else: + running_mean, running_var = self.running_mean, self.running_var + + return running_mean, running_var, weight, bias + + def to_static_op(self: _BatchNorm) -> nn.Module: + """Convert dynamic BatchNormxd to :obj:`torch.nn.BatchNormxd`. + + Returns: + torch.nn.BatchNormxd: :obj:`torch.nn.BatchNormxd` with sliced + parameters. + """ + self.check_if_mutables_fixed() + + running_mean, running_var, weight, bias = self.get_dynamic_params() + if 'num_features' in self.mutable_attrs: + num_features = self.mutable_attrs['num_features'].current_mask.sum( + ).item() + else: + num_features = self.num_features + + static_bn = self.static_op_factory( + num_features=num_features, + eps=self.eps, + momentum=self.momentum, + affine=self.affine, + track_running_stats=self.track_running_stats) + + if running_mean is not None: + static_bn.running_mean.copy_(running_mean) + if running_var is not None: + static_bn.running_var.copy_(running_var) + if weight is not None: + static_bn.weight = nn.Parameter(weight) + if bias is not None: + static_bn.bias = nn.Parameter(bias) + + return static_bn + + +class DynamicLinearMixin(DynamicChannelMixin): + """A mixin class for Pytorch Linear, which can mutate ``in_features`` and + ``out_features``.""" + + accepted_mutable_attrs: Set[str] = {'in_features', 'out_features'} + attr_mappings: Dict[str, str] = { + 'in_channels': 'in_features', + 'out_channels': 'out_features', + } + + def register_mutable_attr(self, attr, mutable): + self.check_mutable_attr_valid(attr) + if attr in self.attr_mappings: + attr_map = self.attr_mappings[attr] + assert attr_map in self.accepted_mutable_attrs + if attr_map in self.mutable_attrs: + print_log( + f'{attr_map}({attr}) is already in `mutable_attrs`', + level=logging.WARNING) + else: + self._register_mutable_attr(attr_map, mutable) + elif attr in self.accepted_mutable_attrs: + self._register_mutable_attr(attr, mutable) + else: + raise NotImplementedError + + def _register_mutable_attr(self, attr, mutable): + + if attr == 'in_features': + self._register_mutable_in_features(mutable) + elif attr == 'out_features': + self._register_mutable_out_features(mutable) + else: + raise NotImplementedError + + def _register_mutable_in_features( + self: nn.Linear, mutable_in_features: BaseMutable) -> None: + """Mutate ``in_features`` with given mutable. + + Args: + mutable_in_features (BaseMutable): Mutable for controlling + ``in_features``. + + Raises: + ValueError: Error if size of mask if not same as ``in_features``. 
+ """ + self.check_mutable_channels(mutable_in_features) + mask_size = mutable_in_features.current_mask.size(0) + if mask_size != self.in_features: + raise ValueError( + f'Expect mask size of mutable to be {self.in_features} as ' + f'`in_features`, but got: {mask_size}.') + + self.mutable_attrs['in_features'] = mutable_in_features + + def _register_mutable_out_features( + self: nn.Linear, mutable_out_features: BaseMutable) -> None: + """Mutate ``out_features`` with given mutable. + + Args: + mutable_out_features (BaseMutable): Mutable for controlling + ``out_features``. + + Raises: + ValueError: Error if size of mask if not same as ``out_features``. + """ + self.check_mutable_channels(mutable_out_features) + mask_size = mutable_out_features.current_mask.size(0) + if mask_size != self.out_features: + raise ValueError( + f'Expect mask size of mutable to be {self.out_features} as ' + f'`in_features`, but got: {mask_size}.') + + self.mutable_attrs['out_features'] = mutable_out_features + + def get_dynamic_params(self: nn.Linear) -> Tuple[Tensor, Optional[Tensor]]: + """Get dynamic parameters that will be used in forward process. + + Returns: + Tuple[Tensor, Optional[Tensor]]: Sliced weight and bias. + """ + if 'in_features' not in self.mutable_attrs and \ + 'out_features' not in self.mutable_attrs: + return self.weight, self.bias + + if 'in_features' in self.mutable_attrs: + in_mask = self.mutable_attrs['in_features'].current_mask.to( + self.weight.device) + else: + in_mask = torch.ones(self.weight.size(1)).bool().to( + self.weight.device) + if 'out_features' in self.mutable_attrs: + + out_mask = self.mutable_attrs['out_features'].current_mask.to( + self.weight.device) + else: + out_mask = torch.ones(self.weight.size(0)).bool().to( + self.weight.device) + + weight = self.weight[out_mask][:, in_mask] + bias = self.bias[out_mask] if self.bias is not None else None + + return weight, bias + + def to_static_op(self: nn.Linear) -> nn.Module: + """Convert to :obj:`torch.nn.Linear`. + + Returns: + nn.Linear: :obj:`torch.nn.Linear` with sliced parameters. + """ + self.check_if_mutables_fixed() + + weight, bias = self.get_dynamic_params() + out_features = weight.size(0) + in_features = weight.size(1) + + static_linear = self.static_op_factory( + in_features=in_features, + out_features=out_features, + bias=True if bias is not None else False) + + static_linear.weight = nn.Parameter(weight) + if bias is not None: + static_linear.bias = nn.Parameter(bias) + + return static_linear diff --git a/mmrazor/models/architectures/dynamic_op/bricks/dynamic_norm.py b/mmrazor/models/architectures/dynamic_op/bricks/dynamic_norm.py new file mode 100644 index 000000000..f88cb563a --- /dev/null +++ b/mmrazor/models/architectures/dynamic_op/bricks/dynamic_norm.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn.bricks.registry import NORM_LAYERS +from torch import Tensor +from torch.nn.modules.batchnorm import _BatchNorm + +from mmrazor.models.mutables.base_mutable import BaseMutable +from .dynamic_mixins import DynamicBatchNormMixin + + +class _DynamicBatchNorm(_BatchNorm, DynamicBatchNormMixin): + """Dynamic BatchNormxd OP. + + Note: + Arguments for ``__init__`` of ``DynamicBatchNormxd`` is totally same as + :obj:`torch.nn.BatchNormxd`. + + Attributes: + mutable_attrs (ModuleDict[str, BaseMutable]): Mutable attributes, + such as `num_features`. 
The key of the dict must in + ``accepted_mutable_attrs``. + """ + accepted_mutable_attrs = {'num_features'} + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.mutable_attrs: Dict[str, Optional[BaseMutable]] = nn.ModuleDict() + + @classmethod + def convert_from(cls, module: _BatchNorm): + """Convert a _BatchNorm module to a DynamicBatchNorm. + + Args: + module (:obj:`torch.nn._BatchNorm`): The original BatchNorm module. + """ + dynamic_bn = cls( + num_features=module.num_features, + eps=module.eps, + momentum=module.momentum, + affine=module.affine, + track_running_stats=module.track_running_stats) + + return dynamic_bn + + def forward(self, input: Tensor) -> Tensor: + """Forward of dynamic BatchNormxd OP.""" + self._check_input_dim(input) + + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + if self.num_batches_tracked is not None: # type: ignore + self.num_batches_tracked = \ + self.num_batches_tracked + 1 # type: ignore + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float( + self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is + None) + + running_mean, running_var, weight, bias = self.get_dynamic_params() + + out = F.batch_norm(input, running_mean, running_var, weight, bias, + bn_training, exponential_average_factor, self.eps) + + # copy changed running statistics + if self.training and self.track_running_stats: + out_mask = self._get_num_features_mask() + self.running_mean.masked_scatter_(out_mask, running_mean) + self.running_var.masked_scatter_(out_mask, running_var) + + return out + + +@NORM_LAYERS.register_module() +class DynamicBatchNorm1d(_DynamicBatchNorm): + """Dynamic BatchNorm1d OP.""" + + @property + def static_op_factory(self): + return nn.BatchNorm1d + + def _check_input_dim(self, input: Tensor) -> None: + """Check if input dimension is valid.""" + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)'.format( + input.dim())) + + +@NORM_LAYERS.register_module() +class DynamicBatchNorm2d(_DynamicBatchNorm): + """Dynamic BatchNorm2d OP.""" + + @property + def static_op_factory(self): + return nn.BatchNorm2d + + def _check_input_dim(self, input: Tensor) -> None: + """Check if input dimension is valid.""" + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)'.format( + input.dim())) + + +@NORM_LAYERS.register_module() +class DynamicBatchNorm3d(_DynamicBatchNorm): + """Dynamic BatchNorm3d OP.""" + + @property + def static_op_factory(self): + return nn.BatchNorm3d + + def _check_input_dim(self, input: Tensor) -> None: + """Check if input dimension is valid.""" + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)'.format( + input.dim())) diff --git a/mmrazor/models/architectures/dynamic_op/default_dynamic_ops.py b/mmrazor/models/architectures/dynamic_op/default_dynamic_ops.py index 1e1bf2cb5..2488a49eb 100644 --- a/mmrazor/models/architectures/dynamic_op/default_dynamic_ops.py +++ b/mmrazor/models/architectures/dynamic_op/default_dynamic_ops.py @@ -11,7 +11,6 @@ from mmrazor.models.mutables.mutable_channel import MutableChannel from mmrazor.registry import MODELS -from ...mutables import MutableManageMixIn 
from .base import ChannelDynamicOP @@ -24,6 +23,7 @@ class DynamicConv2d(nn.Conv2d, ChannelDynamicOP): in_channels_cfg (Dict): Config related to `in_channels`. out_channels_cfg (Dict): Config related to `out_channels`. """ + accepted_mutables = {'mutable_in_channels', 'mutable_out_channels'} def __init__(self, in_channels_cfg, out_channels_cfg, *args, **kwargs): super(DynamicConv2d, self).__init__(*args, **kwargs) @@ -111,7 +111,7 @@ def to_static_op(self) -> nn.Conv2d: return static_conv2d -class DynamicLinear(nn.Linear, MutableManageMixIn): +class DynamicLinear(nn.Linear, ChannelDynamicOP): """Applies a linear transformation to the incoming data according to the `mutable_in_features` and `mutable_out_features` dynamically. @@ -119,6 +119,7 @@ class DynamicLinear(nn.Linear, MutableManageMixIn): in_features_cfg (Dict): Config related to `in_features`. out_features_cfg (Dict): Config related to `out_features`. """ + accepted_mutables = {'mutable_in_features', 'mutable_out_features'} def __init__(self, in_features_cfg, out_features_cfg, *args, **kwargs): super(DynamicLinear, self).__init__(*args, **kwargs) @@ -153,14 +154,19 @@ def forward(self, input: Tensor) -> Tensor: return F.linear(input, weight, bias) + # TODO + def to_static_op(self) -> nn.Module: + return self -class DynamicBatchNorm(_BatchNorm, MutableManageMixIn): + +class DynamicBatchNorm(_BatchNorm, ChannelDynamicOP): """Applies Batch Normalization over an input according to the `mutable_num_features` dynamically. Args: num_features_cfg (Dict): Config related to `num_features`. """ + accepted_mutables = {'mutable_num_features'} def __init__(self, num_features_cfg, *args, **kwargs): super(DynamicBatchNorm, self).__init__(*args, **kwargs) @@ -224,14 +230,19 @@ def forward(self, input: Tensor) -> Tensor: return F.batch_norm(input, running_mean, running_var, weight, bias, bn_training, exponential_average_factor, self.eps) + # TODO + def to_static_op(self) -> nn.Module: + return self + -class DynamicInstanceNorm(_InstanceNorm, MutableManageMixIn): +class DynamicInstanceNorm(_InstanceNorm, ChannelDynamicOP): """Applies Instance Normalization over an input according to the `mutable_num_features` dynamically. Args: num_features_cfg (Dict): Config related to `num_features`. """ + accepted_mutables = {'mutable_num_features'} def __init__(self, num_features_cfg, *args, **kwargs): super(DynamicInstanceNorm, self).__init__(*args, **kwargs) @@ -273,14 +284,19 @@ def forward(self, input: Tensor) -> Tensor: self.training or not self.track_running_stats, self.momentum, self.eps) + # TODO + def to_static_op(self) -> nn.Module: + return self -class DynamicGroupNorm(GroupNorm, MutableManageMixIn): + +class DynamicGroupNorm(GroupNorm, ChannelDynamicOP): """Applies Group Normalization over a mini-batch of inputs according to the `mutable_num_channels` dynamically. Args: num_channels_cfg (Dict): Config related to `num_channels`. 
""" + accepted_mutables = {'mutable_num_features'} def __init__(self, num_channels_cfg, *args, **kwargs): super(DynamicGroupNorm, self).__init__(*args, **kwargs) @@ -311,3 +327,7 @@ def forward(self, input: Tensor) -> Tensor: weight, bias = self.weight, self.bias return F.group_norm(input, self.num_groups, weight, bias, self.eps) + + # TODO + def to_static_op(self) -> nn.Module: + return self diff --git a/mmrazor/models/architectures/dynamic_op/head/__init__.py b/mmrazor/models/architectures/dynamic_op/head/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/mmrazor/models/architectures/dynamic_op/head/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/mmrazor/models/mutables/__init__.py b/mmrazor/models/mutables/__init__.py index 123e597ae..917364607 100644 --- a/mmrazor/models/mutables/__init__.py +++ b/mmrazor/models/mutables/__init__.py @@ -2,14 +2,13 @@ from .derived_mutable import DerivedMutable from .mutable_channel import (MutableChannel, OneShotMutableChannel, SlimmableMutableChannel) -from .mutable_manage_mixin import MutableManageMixIn from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP, OneShotMutableModule, OneShotMutableOP) from .mutable_value import MutableValue, OneShotMutableValue __all__ = [ 'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP', - 'DiffChoiceRoute', 'DiffMutableModule', 'MutableManageMixIn', - 'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel', - 'DerivedMutable', 'MutableValue', 'OneShotMutableValue' + 'DiffChoiceRoute', 'DiffMutableModule', 'OneShotMutableChannel', + 'SlimmableMutableChannel', 'MutableChannel', 'DerivedMutable', + 'MutableValue', 'OneShotMutableValue' ] diff --git a/mmrazor/models/mutables/derived_mutable.py b/mmrazor/models/mutables/derived_mutable.py index 9a4454c0e..5e991e9fe 100644 --- a/mmrazor/models/mutables/derived_mutable.py +++ b/mmrazor/models/mutables/derived_mutable.py @@ -8,6 +8,7 @@ import inspect import logging +from itertools import product from typing import Any, Callable, Dict, Iterable, Optional, Set, Union import torch @@ -279,6 +280,26 @@ def is_fixed(self, is_fixed: bool) -> bool: raise RuntimeError( '`is_fixed` of derived mutable should not be modified directly') + @property + def choices(self): + origin_choices = [m.current_choice for m in self.source_mutables] + + all_choices = [m.choices for m in self.source_mutables] + + product_choices = product(*all_choices) + + derived_choices = list() + for item_choices in product_choices: + for m, choice in zip(self.source_mutables, item_choices): + m.current_choice = choice + + derived_choices.append(self.choice_fn()) + + for m, choice in zip(self.source_mutables, origin_choices): + m.current_choice = choice + + return derived_choices + @property def num_choices(self) -> int: """Number of all choices. 
diff --git a/mmrazor/models/mutables/mutable_channel/mutable_channel.py b/mmrazor/models/mutables/mutable_channel/mutable_channel.py index f3ba2063e..af2bf2188 100644 --- a/mmrazor/models/mutables/mutable_channel/mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/mutable_channel.py @@ -78,11 +78,7 @@ def current_mask(self): mutable.current_mask for mutable in self.concat_parent_mutables ]) else: - # TODO - if self.is_fixed: - return self.convert_choice_to_mask(0) - else: - return self.convert_choice_to_mask(self.current_choice) + return self.convert_choice_to_mask(self.current_choice) def bind_mutable_name(self, name: str) -> None: """Bind a MutableChannel to its name. @@ -105,9 +101,6 @@ def fix_chosen(self, chosen: CHOSEN_TYPE) -> None: 'The mode of current MUTABLE is `fixed`. ' 'Please do not call `fix_chosen` function again.') - # TODO - # should fixed op still have candidate_choices? - self._chosen = chosen self.is_fixed = True def __repr__(self): diff --git a/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py b/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py index ebf8b41ef..5d4dec0e7 100644 --- a/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py +++ b/mmrazor/models/mutables/mutable_channel/slimmable_mutable_channel.py @@ -72,8 +72,8 @@ def fix_chosen(self, dumped_chosen: Dict) -> None: # TODO # remove after remove `current_choice` - self.current_choice = self._candidate_choices.index(chosen) - self._candidate_choices = [chosen] + self.current_choice = self.candidate_choices.index(chosen) + self._chosen = chosen super().fix_chosen(chosen) @@ -83,7 +83,9 @@ def num_choices(self) -> int: def convert_choice_to_mask(self, choice: int) -> torch.Tensor: """Get the mask according to the input choice.""" - if not hasattr(self, '_candidate_choices'): + if self.is_fixed: + num_channels = self._chosen + elif not hasattr(self, '_candidate_choices'): # todo: we trace the supernet before set_candidate_choices. # It's hacky num_channels = self.num_channels diff --git a/mmrazor/models/mutables/mutable_manage_mixin.py b/mmrazor/models/mutables/mutable_manage_mixin.py deleted file mode 100644 index f826dd2a5..000000000 --- a/mmrazor/models/mutables/mutable_manage_mixin.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - - -class MutableManageMixIn: - """Mixin class for determining whether an object is a dynamic layer. - - Note that a dynamic layer manage one or several mutables. - """ - pass diff --git a/mmrazor/structures/subnet/fix_subnet.py b/mmrazor/structures/subnet/fix_subnet.py index ead21413e..b80470d60 100644 --- a/mmrazor/structures/subnet/fix_subnet.py +++ b/mmrazor/structures/subnet/fix_subnet.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import logging -import mmcv +from mmengine import fileio from mmengine.logging import print_log from torch import nn @@ -10,17 +10,20 @@ def _dynamic_to_static(model: nn.Module) -> None: # Avoid circular import - from mmrazor.models.architectures.dynamic_op.base import DynamicOP + from mmrazor.models.architectures.dynamic_op.bricks import DynamicMixin def traverse_children(module: nn.Module) -> None: # TODO # dynamicop must have no dynamic child for name, child in module.named_children(): - if isinstance(child, DynamicOP): + if isinstance(child, DynamicMixin): setattr(module, name, child.to_static_op()) else: traverse_children(child) + if isinstance(model, DynamicMixin): + raise RuntimeError('Root model can not be dynamic op.') + traverse_children(model) @@ -29,10 +32,15 @@ def load_fix_subnet(model: nn.Module, prefix: str = '') -> None: """Load fix subnet.""" if isinstance(fix_mutable, str): - fix_mutable = mmcv.fileio.load(fix_mutable) + fix_mutable = fileio.load(fix_mutable) if not isinstance(fix_mutable, dict): raise TypeError('fix_mutable should be a `str` or `dict`' f'but got {type(fix_mutable)}') + + from mmrazor.models.architectures.dynamic_op.bricks import DynamicMixin + if isinstance(model, DynamicMixin): + raise RuntimeError('Root model can not be dynamic op.') + # Avoid circular import from mmrazor.models.mutables import DerivedMutable from mmrazor.models.mutables.base_mutable import BaseMutable diff --git a/tests/test_models/test_algorithms/test_slimmable_network.py b/tests/test_models/test_algorithms/test_slimmable_network.py index dd7c481ba..792e8e305 100644 --- a/tests/test_models/test_algorithms/test_slimmable_network.py +++ b/tests/test_models/test_algorithms/test_slimmable_network.py @@ -9,7 +9,7 @@ import torch import torch.distributed as dist from mmcls.structures import ClsDataSample -from mmcv import fileio +from mmengine import fileio from mmengine.optim import build_optim_wrapper from mmrazor.models.algorithms import SlimmableNetwork, SlimmableNetworkDDP diff --git a/tests/test_models/test_architectures/test_dynamic_op/__init__.py b/tests/test_models/test_architectures/test_dynamic_op/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/test_architectures/test_dynamic_op/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/__init__.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/__init__.py new file mode 100644 index 000000000..ef101fec6 --- /dev/null +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py new file mode 100644 index 000000000..545686651 --- /dev/null +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_conv.py @@ -0,0 +1,266 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
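For reference, `_dynamic_to_static` in the `fix_subnet.py` hunk above converts a supernet in place: it walks `named_children()` recursively and swaps every `DynamicMixin` child for its static counterpart via `to_static_op()`. The new guard rejects a root module that is itself dynamic, since there is no parent on which `setattr` could install the static replacement. A small sketch of the same traversal pattern with a hypothetical mixin, independent of the mmrazor classes:

    import torch.nn as nn


    class StaticConvertible:
        """Hypothetical marker mixin for modules that can be made static."""

        def to_static_op(self) -> nn.Module:
            raise NotImplementedError


    def dynamic_to_static(model: nn.Module) -> None:
        if isinstance(model, StaticConvertible):
            raise RuntimeError('Root model can not be dynamic op.')

        def traverse(module: nn.Module) -> None:
            for name, child in module.named_children():
                if isinstance(child, StaticConvertible):
                    # Replace the dynamic child on its parent with a static op.
                    setattr(module, name, child.to_static_op())
                else:
                    traverse(child)

        traverse(model)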
+import copy +from typing import Type +from unittest import TestCase +from unittest.mock import MagicMock + +import pytest +import torch +from torch import nn + +from mmrazor.models.architectures.dynamic_op.bricks import (BigNasConv2d, + DynamicConv2d, + OFAConv2d) +from mmrazor.models.mutables import OneShotMutableChannel, OneShotMutableValue +from mmrazor.structures.subnet import export_fix_subnet, load_fix_subnet +from ..utils import fix_dynamic_op + + +class TestDynamicConv2d(TestCase): + + def test_dynamic_conv2d_depthwise(self) -> None: + d_conv2d = DynamicConv2d( + in_channels=10, + out_channels=10, + groups=10, + kernel_size=3, + stride=1, + bias=True) + + mock_mutable = MagicMock() + with pytest.raises(ValueError): + d_conv2d.register_mutable_attr('in_channels', mock_mutable) + with pytest.raises(ValueError): + d_conv2d.register_mutable_attr('out_channels', mock_mutable) + + mock_mutable.current_mask = torch.rand(4) + with pytest.raises(ValueError): + d_conv2d.register_mutable_attr('in_channels', mock_mutable) + with pytest.raises(ValueError): + d_conv2d.register_mutable_attr('out_channels', mock_mutable) + + mutable_in_channels = OneShotMutableChannel( + 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_out_channels = OneShotMutableChannel( + 10, candidate_choices=[4, 8, 10], candidate_mode='number') + + d_conv2d.register_mutable_attr('in_channels', mutable_in_channels) + d_conv2d.register_mutable_attr('out_channels', mutable_out_channels) + + with pytest.raises(RuntimeError): + d_conv2d.to_static_op() + + d_conv2d.get_mutable_attr('in_channels').current_choice = 8 + d_conv2d.get_mutable_attr('out_channels').current_choice = 8 + + x = torch.rand(10, 8, 224, 224) + out1 = d_conv2d(x) + assert out1.size(1) == 8 + + with pytest.raises(RuntimeError): + _ = d_conv2d.to_static_op() + + fix_mutables = export_fix_subnet(d_conv2d) + with pytest.raises(RuntimeError): + load_fix_subnet(d_conv2d, fix_mutables) + fix_dynamic_op(d_conv2d, fix_mutables) + + s_conv2d = d_conv2d.to_static_op() + assert s_conv2d.weight.size(0) == 8 + assert s_conv2d.weight.size(1) == 1 + assert s_conv2d.bias.size(0) == 8 + out2 = s_conv2d(x) + + assert torch.equal(out1, out2) + + +@pytest.mark.parametrize('bias', [True, False]) +def test_dynamic_conv2d(bias: bool) -> None: + d_conv2d = DynamicConv2d( + in_channels=4, out_channels=10, kernel_size=3, stride=1, bias=bias) + + x_max = torch.rand(10, 4, 224, 224) + out_before_mutate = d_conv2d(x_max) + + mutable_in_channels = OneShotMutableChannel( + 4, candidate_choices=[2, 3, 4], candidate_mode='number') + mutable_out_channels = OneShotMutableChannel( + 10, candidate_choices=[4, 8, 10], candidate_mode='number') + d_conv2d.register_mutable_attr('in_channels', mutable_in_channels) + d_conv2d.register_mutable_attr('out_channels', mutable_out_channels) + + with pytest.raises(RuntimeError): + d_conv2d.to_static_op() + + d_conv2d.get_mutable_attr('in_channels').current_choice = 4 + d_conv2d.mutate_out_channels = 10 + + out_max = d_conv2d(x_max) + assert torch.equal(out_before_mutate, out_max) + + d_conv2d.get_mutable_attr('in_channels').current_choice = 3 + d_conv2d.mutable_out_channels.current_choice = 4 + + x = torch.rand(10, 3, 224, 224) + out1 = d_conv2d(x) + assert out1.size(1) == 4 + + fix_mutables = export_fix_subnet(d_conv2d) + with pytest.raises(RuntimeError): + load_fix_subnet(d_conv2d, fix_mutables) + fix_dynamic_op(d_conv2d, fix_mutables) + + s_conv2d = d_conv2d.to_static_op() + assert s_conv2d.weight.size(0) == 4 + assert 
s_conv2d.weight.size(1) == 3 + if bias: + assert s_conv2d.bias.size(0) == 4 + out2 = s_conv2d(x) + + assert torch.equal(out1, out2) + + +@pytest.mark.parametrize( + ['is_mutate_in_channels', 'in_channels', 'out_channels'], [(True, 6, 10), + (False, 10, 4)]) +def test_dynamic_conv2d_mutable_single_channels(is_mutate_in_channels: bool, + in_channels: int, + out_channels: int) -> None: + d_conv2d = DynamicConv2d( + in_channels=10, out_channels=10, kernel_size=3, stride=1, bias=True) + mutable_channels = OneShotMutableChannel( + 10, candidate_choices=[4, 6, 10], candidate_mode='number') + + if is_mutate_in_channels: + d_conv2d.register_mutable_attr('in_channels', mutable_channels) + else: + d_conv2d.register_mutable_attr('out_channels', mutable_channels) + + with pytest.raises(RuntimeError): + d_conv2d.to_static_op() + + if is_mutate_in_channels: + d_conv2d.get_mutable_attr('in_channels').current_choice = in_channels + assert d_conv2d.get_mutable_attr('out_channels') is None + else: + d_conv2d.get_mutable_attr('out_channels').current_choice = out_channels + assert d_conv2d.get_mutable_attr('in_channels') is None + + x = torch.rand(3, in_channels, 224, 224) + out1 = d_conv2d(x) + + assert out1.size(1) == out_channels + + with pytest.raises(RuntimeError): + _ = d_conv2d.to_static_op() + + fix_mutables = export_fix_subnet(d_conv2d) + with pytest.raises(RuntimeError): + load_fix_subnet(d_conv2d, fix_mutables) + fix_dynamic_op(d_conv2d, fix_mutables) + + s_conv2d = d_conv2d.to_static_op() + assert s_conv2d.weight.size(0) == out_channels + assert s_conv2d.weight.size(1) == in_channels + out2 = s_conv2d(x) + + assert torch.equal(out1, out2) + + +@pytest.mark.parametrize('dynamic_class', [OFAConv2d, BigNasConv2d]) +@pytest.mark.parametrize('kernel_size_list', [[5], [3, 5, 7]]) +def test_kernel_dynamic_conv2d(dynamic_class: Type[nn.Conv2d], + kernel_size_list: bool) -> None: + + mutable_in_channels = OneShotMutableChannel( + 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_out_channels = OneShotMutableChannel( + 10, candidate_choices=[4, 8, 10], candidate_mode='number') + + mutable_kernel_size = OneShotMutableValue(value_list=kernel_size_list) + + d_conv2d = dynamic_class( + in_channels=10, + out_channels=10, + groups=1, + kernel_size=3 if kernel_size_list is None else max(kernel_size_list), + stride=1, + bias=True) + d_conv2d.register_mutable_attr('in_channels', mutable_in_channels) + d_conv2d.register_mutable_attr('out_channels', mutable_out_channels) + if kernel_size_list is not None: + copied_mutable_kernel_size = copy.deepcopy(mutable_kernel_size) + copied_d_conv2d = copy.deepcopy(d_conv2d) + + copied_mutable_kernel_size._value_list = [] + with pytest.raises(ValueError): + _ = copied_d_conv2d.register_mutable_attr( + 'kernel_size', copied_mutable_kernel_size) + + d_conv2d.register_mutable_attr('kernel_size', mutable_kernel_size) + assert d_conv2d.kernel_size_list == kernel_size_list + + with pytest.raises(RuntimeError): + d_conv2d.to_static_op() + + d_conv2d.get_mutable_attr('in_channels').current_choice = 8 + d_conv2d.get_mutable_attr('out_channels').current_choice = 8 + if kernel_size_list is not None: + kernel_size = mutable_kernel_size.sample_choice() + d_conv2d.mutable_attrs['kernel_size'].current_choice = kernel_size + + x = torch.rand(3, 8, 224, 224) + out1 = d_conv2d(x) + assert out1.size(1) == 8 + + fix_mutables = export_fix_subnet(d_conv2d) + with pytest.raises(RuntimeError): + load_fix_subnet(d_conv2d, fix_mutables) + fix_dynamic_op(d_conv2d, fix_mutables) + + 
s_conv2d = d_conv2d.to_static_op() + assert s_conv2d.weight.size(0) == 8 + assert s_conv2d.weight.size(1) == 8 + assert s_conv2d.bias.size(0) == 8 + if kernel_size_list is not None: + assert s_conv2d.kernel_size == (kernel_size, kernel_size) + assert tuple(s_conv2d.weight.shape[2:]) == (kernel_size, kernel_size) + out2 = s_conv2d(x) + + assert torch.equal(out1, out2) + + +@pytest.mark.parametrize('dynamic_class', [OFAConv2d, BigNasConv2d]) +def test_mutable_kernel_dynamic_conv2d_grad( + dynamic_class: Type[nn.Conv2d]) -> None: + from mmrazor.models.architectures.dynamic_op.bricks import \ + dynamic_conv_mixins + + kernel_size_list = [3, 5, 7] + d_conv2d = dynamic_class( + in_channels=3, + out_channels=10, + groups=1, + kernel_size=max(kernel_size_list), + stride=1, + bias=False) + + mutable_kernel_size = OneShotMutableValue(value_list=kernel_size_list) + d_conv2d.register_mutable_attr('kernel_size', mutable_kernel_size) + + x = torch.rand(3, 3, 224, 224, requires_grad=True) + + for kernel_size in kernel_size_list: + mutable_kernel_size.current_choice = kernel_size + out = d_conv2d(x).sum() + out.backward() + + start_offset, end_offset = dynamic_conv_mixins._get_current_kernel_pos( + max(kernel_size_list), kernel_size) + + mask = torch.ones_like( + d_conv2d.weight, requires_grad=False, dtype=torch.bool) + mask[:, :, start_offset:end_offset, start_offset:end_offset] = 0 + assert d_conv2d.weight.grad[mask].norm().item() == 0 + + d_conv2d.weight.grad.zero_() diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_linear.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_linear.py new file mode 100644 index 000000000..29665ac8a --- /dev/null +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_linear.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
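`test_mutable_kernel_dynamic_conv2d_grad` above checks that, when a smaller kernel size is active, only the centred `kernel_size x kernel_size` slice of the full weight receives gradient; entries outside that window stay at zero. The offsets come from `_get_current_kernel_pos`; its assumed behaviour (a window centred inside the largest kernel, odd sizes) can be sketched as:

    def get_current_kernel_pos(source_kernel_size: int, target_kernel_size: int):
        # Assumed behaviour of `_get_current_kernel_pos`: start/end offsets of a
        # `target_kernel_size` window centred inside the source kernel.
        assert source_kernel_size >= target_kernel_size
        center = source_kernel_size // 2
        half = target_kernel_size // 2
        return center - half, center + half + 1


    assert get_current_kernel_pos(7, 3) == (2, 5)
    assert get_current_kernel_pos(7, 7) == (0, 7)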
+from typing import Optional +from unittest.mock import MagicMock + +import pytest +import torch +from torch import nn + +from mmrazor.models.architectures.dynamic_op.bricks import (DynamicLinear, + DynamicLinearMixin) +from mmrazor.models.mutables import OneShotMutableChannel +from mmrazor.structures.subnet import export_fix_subnet, load_fix_subnet +from ..utils import fix_dynamic_op + + +@pytest.mark.parametrize('bias', [True, False]) +def test_dynamic_linear(bias) -> None: + mutable_in_features = OneShotMutableChannel( + 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_out_features = OneShotMutableChannel( + 10, candidate_choices=[4, 8, 10], candidate_mode='number') + + d_linear = DynamicLinear(in_features=10, out_features=10, bias=bias) + + mock_mutable = MagicMock() + with pytest.raises(ValueError): + d_linear.register_mutable_attr('in_features', mock_mutable) + with pytest.raises(ValueError): + d_linear.register_mutable_attr('out_features', mock_mutable) + + mock_mutable.current_mask = torch.rand(8) + with pytest.raises(ValueError): + d_linear.register_mutable_attr('in_features', mock_mutable) + with pytest.raises(ValueError): + d_linear.register_mutable_attr('out_features', mock_mutable) + + d_linear.register_mutable_attr('in_features', mutable_in_features) + d_linear.register_mutable_attr('out_features', mutable_out_features) + + with pytest.raises(RuntimeError): + d_linear.to_static_op() + + d_linear.get_mutable_attr('in_channels').current_choice = 8 + d_linear.get_mutable_attr('out_channels').current_choice = 4 + + x = torch.rand(10, 8) + out1 = d_linear(x) + assert out1.size(1) == 4 + + with pytest.raises(RuntimeError): + _ = d_linear.to_static_op() + + fix_mutables = export_fix_subnet(d_linear) + with pytest.raises(RuntimeError): + load_fix_subnet(d_linear, fix_mutables) + fix_dynamic_op(d_linear, fix_mutables) + assert isinstance(d_linear, nn.Linear) + assert isinstance(d_linear, DynamicLinearMixin) + + s_linear = d_linear.to_static_op() + assert s_linear.weight.size(0) == 4 + assert s_linear.weight.size(1) == 8 + if bias: + assert s_linear.bias.size(0) == 4 + assert not isinstance(s_linear, DynamicLinearMixin) + assert isinstance(s_linear, nn.Linear) + out2 = s_linear(x) + + assert torch.equal(out1, out2) + + +@pytest.mark.parametrize( + ['is_mutate_in_features', 'in_features', 'out_features'], [(True, 6, 10), + (False, 10, 4), + (None, 10, 10)]) +def test_dynamic_linear_mutable_single_features( + is_mutate_in_features: Optional[bool], in_features: int, + out_features: int) -> None: + d_linear = DynamicLinear(in_features=10, out_features=10, bias=True) + mutable_channels = OneShotMutableChannel( + 10, candidate_choices=[4, 6, 10], candidate_mode='number') + + if is_mutate_in_features is not None: + if is_mutate_in_features: + d_linear.register_mutable_attr('in_channels', mutable_channels) + else: + d_linear.register_mutable_attr('out_channels', mutable_channels) + + if is_mutate_in_features: + d_linear.get_mutable_attr('in_channels').current_choice = in_features + assert d_linear.get_mutable_attr('out_channels') is None + elif is_mutate_in_features is False: + d_linear.get_mutable_attr('out_channels').current_choice = out_features + assert d_linear.get_mutable_attr('in_channels') is None + + x = torch.rand(3, in_features) + out1 = d_linear(x) + + assert out1.size(1) == out_features + + if is_mutate_in_features is not None: + with pytest.raises(RuntimeError): + _ = d_linear.to_static_op() + + fix_mutables = export_fix_subnet(d_linear) + with 
pytest.raises(RuntimeError): + load_fix_subnet(d_linear, fix_mutables) + fix_dynamic_op(d_linear, fix_mutables) + + s_linear = d_linear.to_static_op() + assert s_linear.weight.size(0) == out_features + assert s_linear.weight.size(1) == in_features + out2 = s_linear(x) + + assert torch.equal(out1, out2) diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py new file mode 100644 index 000000000..3c8d045e5 --- /dev/null +++ b/tests/test_models/test_architectures/test_dynamic_op/test_bricks/test_dynamic_norm.py @@ -0,0 +1,119 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple, Type +from unittest.mock import MagicMock + +import pytest +import torch +from torch import nn + +from mmrazor.models.architectures.dynamic_op.bricks import (DynamicBatchNorm1d, + DynamicBatchNorm2d, + DynamicBatchNorm3d, + DynamicMixin) +from mmrazor.models.mutables import OneShotMutableChannel +from mmrazor.structures.subnet import export_fix_subnet, load_fix_subnet +from ..utils import fix_dynamic_op + + +@pytest.mark.parametrize('dynamic_class,input_shape', + [(DynamicBatchNorm1d, (10, 8, 224)), + (DynamicBatchNorm2d, (10, 8, 224, 224)), + (DynamicBatchNorm3d, (10, 8, 3, 224, 224))]) +@pytest.mark.parametrize('affine', [True, False]) +@pytest.mark.parametrize('track_running_stats', [True, False]) +def test_dynamic_bn(dynamic_class: Type[nn.modules.batchnorm._BatchNorm], + input_shape: Tuple[int], affine: bool, + track_running_stats: bool) -> None: + mutable_num_features = OneShotMutableChannel( + 10, candidate_choices=[4, 8, 10], candidate_mode='number') + + d_bn = dynamic_class( + num_features=10, + affine=affine, + track_running_stats=track_running_stats) + if not affine and not track_running_stats: + with pytest.raises(RuntimeError): + d_bn.register_mutable_attr('num_features', mutable_num_features) + else: + mock_mutable = MagicMock() + with pytest.raises(ValueError): + d_bn.register_mutable_attr('num_features', mock_mutable) + mock_mutable.current_mask = torch.rand(5) + with pytest.raises(ValueError): + d_bn.register_mutable_attr('num_features', mock_mutable) + + d_bn.register_mutable_attr('num_features', mutable_num_features) + assert d_bn.get_mutable_attr('in_channels') is d_bn.get_mutable_attr( + 'out_channels') + + if affine or track_running_stats: + d_bn.get_mutable_attr('in_channels').current_choice = 8 + + with pytest.raises(ValueError): + wrong_shape_x = torch.rand(8) + _ = d_bn(wrong_shape_x) + + x = torch.rand(*input_shape) + out1 = d_bn(x) + assert out1.size(1) == 8 + + fix_mutables = export_fix_subnet(d_bn) + with pytest.raises(RuntimeError): + load_fix_subnet(d_bn, fix_mutables) + fix_dynamic_op(d_bn, fix_mutables) + assert isinstance(d_bn, dynamic_class) + assert isinstance(d_bn, DynamicMixin) + + s_bn = d_bn.to_static_op() + if affine: + assert s_bn.weight.size(0) == 8 + assert s_bn.bias.size(0) == 8 + if track_running_stats: + assert s_bn.running_mean.size(0) == 8 + assert s_bn.running_var.size(0) == 8 + assert not isinstance(s_bn, DynamicMixin) + assert isinstance(s_bn, d_bn.static_op_factory) + out2 = s_bn(x) + + assert torch.equal(out1, out2) + + +@pytest.mark.parametrize(['static_class', 'dynamic_class', 'input_shape'], + [(nn.BatchNorm1d, DynamicBatchNorm1d, (10, 8, 224)), + (nn.BatchNorm2d, DynamicBatchNorm2d, + (10, 8, 224, 224)), + (nn.BatchNorm3d, DynamicBatchNorm3d, + (10, 8, 3, 224, 224))]) +def test_bn_track_running_stats( + 
static_class: Type[nn.modules.batchnorm._BatchNorm], + dynamic_class: Type[nn.modules.batchnorm._BatchNorm], + input_shape: Tuple[int], +) -> None: + mutable_num_features = OneShotMutableChannel( + 10, candidate_choices=[4, 8, 10], candidate_mode='number') + mutable_num_features.current_choice = 8 + d_bn = dynamic_class( + num_features=10, track_running_stats=True, affine=False) + d_bn.register_mutable_attr('num_features', mutable_num_features) + + s_bn = static_class(num_features=8, track_running_stats=True, affine=False) + + d_bn.train() + s_bn.train() + mask = d_bn._get_num_features_mask() + for _ in range(10): + x = torch.rand(*input_shape) + _ = d_bn(x) + _ = s_bn(x) + + d_running_mean = d_bn.running_mean[mask] + d_running_var = d_bn.running_var[mask] + + assert torch.equal(s_bn.running_mean, d_running_mean) + assert torch.equal(s_bn.running_var, d_running_var) + + d_bn.eval() + s_bn.eval() + x = torch.rand(*input_shape) + + assert torch.equal(d_bn(x), s_bn(x)) diff --git a/tests/test_models/test_architectures/test_dynamic_op/test_default_dynamic_op.py b/tests/test_models/test_architectures/test_dynamic_op/test_default_dynamic_op.py index a720aa697..97277d03a 100644 --- a/tests/test_models/test_architectures/test_dynamic_op/test_default_dynamic_op.py +++ b/tests/test_models/test_architectures/test_dynamic_op/test_default_dynamic_op.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. -from unittest import TestCase +"""from unittest import TestCase. import pytest import torch from mmrazor.models.architectures import DynamicConv2d from mmrazor.structures import export_fix_subnet, load_fix_subnet - +from .utils import fix_dynamic_op class TestDefaultDynamicOP(TestCase): @@ -38,7 +38,9 @@ def test_dynamic_conv2d(self) -> None: self.assertEqual(out1.size(1), 4) fix_mutables = export_fix_subnet(d_conv2d) - load_fix_subnet(d_conv2d, fix_mutables) + with pytest.raises(RuntimeError): + load_fix_subnet(d_conv2d, fix_mutables) + fix_dynamic_op(d_conv2d, fix_mutables) out2 = d_conv2d(x) self.assertTrue(torch.equal(out1, out2)) @@ -77,7 +79,9 @@ def test_dynamic_conv2d_depthwise(self) -> None: self.assertEqual(out1.size(1), 8) fix_mutables = export_fix_subnet(d_conv2d) - load_fix_subnet(d_conv2d, fix_mutables) + with pytest.raises(RuntimeError): + load_fix_subnet(d_conv2d, fix_mutables) + fix_dynamic_op(d_conv2d, fix_mutables) out2 = d_conv2d(x) self.assertTrue(torch.equal(out1, out2)) @@ -86,3 +90,4 @@ def test_dynamic_conv2d_depthwise(self) -> None: out3 = s_conv2d(x) self.assertTrue(torch.equal(out1, out3)) +""" diff --git a/tests/test_models/test_architectures/test_dynamic_op/utils.py b/tests/test_models/test_architectures/test_dynamic_op/utils.py new file mode 100644 index 000000000..aaa6bf3c2 --- /dev/null +++ b/tests/test_models/test_architectures/test_dynamic_op/utils.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
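`test_bn_track_running_stats` above asserts that, with `num_features` mutated from 10 down to 8, the masked entries of the dynamic batch norm's running statistics match those of a plain `BatchNorm*d(8)` trained on the same inputs. That is also the slicing a static conversion needs to apply to weights, biases and buffers; a sketch of the idea, assuming the mask marks the active channels as `True`:

    import torch
    from torch import nn

    bn = nn.BatchNorm2d(10)
    mask = torch.zeros(10, dtype=torch.bool)
    mask[:8] = True  # assumed mask: the first 8 channels are active

    static_bn = nn.BatchNorm2d(8)
    with torch.no_grad():
        static_bn.weight.copy_(bn.weight[mask])
        static_bn.bias.copy_(bn.bias[mask])
        static_bn.running_mean.copy_(bn.running_mean[mask])
        static_bn.running_var.copy_(bn.running_var[mask])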
+from typing import Dict, Optional + +from mmrazor.models.architectures.dynamic_op import DynamicOP + + +def fix_dynamic_op(op: DynamicOP, fix_mutables: Optional[Dict] = None) -> None: + for name, mutable in op.mutable_attrs.items(): + + if fix_mutables is not None: + chosen = fix_mutables[f'mutable_attrs.{name}'] + else: + chosen = mutable.dump_chosen() + + mutable.fix_chosen(chosen) diff --git a/tests/test_models/test_mutators/utils.py b/tests/test_models/test_mutators/utils.py index ee8762267..7ddede648 100644 --- a/tests/test_models/test_mutators/utils.py +++ b/tests/test_models/test_mutators/utils.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Dict, List -from mmcv import fileio +from mmengine import fileio from mmrazor.models.algorithms import SlimmableNetwork diff --git a/tests/test_runners/test_evolution_search_loop.py b/tests/test_runners/test_evolution_search_loop.py index 1769401c9..f90e5c026 100644 --- a/tests/test_runners/test_evolution_search_loop.py +++ b/tests/test_runners/test_evolution_search_loop.py @@ -6,8 +6,8 @@ from unittest import TestCase from unittest.mock import MagicMock, patch -import mmcv import torch +from mmengine import fileio from mmengine.config import Config from torch.utils.data import DataLoader, Dataset @@ -104,7 +104,7 @@ def test_init(self): fake_subnet = {'1': 'choice1', '2': 'choice2'} fake_candidates = Candidates((fake_subnet, 0.)) init_candidates_path = os.path.join(self.temp_dir, 'candidates.yaml') - mmcv.fileio.dump(fake_candidates, init_candidates_path) + fileio.dump(fake_candidates, init_candidates_path) loop_cfg.init_candidates = init_candidates_path loop = LOOPS.build(loop_cfg) self.assertIsInstance(loop, EvolutionSearchLoop)
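Taken together, the new brick tests exercise one workflow: build a dynamic op, attach mutables through the unified `register_mutable_attr` interface, pick a sub-network through `get_mutable_attr(...).current_choice`, export and fix the chosen configuration, and finally call `to_static_op()`. Because `load_fix_subnet` refuses a bare dynamic op as the root model, the tests fix the op's mutables directly via the `fix_dynamic_op` helper. A condensed sketch of that workflow, limited to the classes and call signatures used in this patch (the fixing loop mirrors `tests/.../test_dynamic_op/utils.py`); it is illustrative rather than a documented entry point:

    import torch

    from mmrazor.models.architectures.dynamic_op.bricks import DynamicConv2d
    from mmrazor.models.mutables import OneShotMutableChannel
    from mmrazor.structures.subnet import export_fix_subnet

    d_conv2d = DynamicConv2d(
        in_channels=10, out_channels=10, kernel_size=3, stride=1, bias=True)

    # Unified register interface: attach mutables to named attributes.
    d_conv2d.register_mutable_attr(
        'in_channels',
        OneShotMutableChannel(
            10, candidate_choices=[4, 8, 10], candidate_mode='number'))
    d_conv2d.register_mutable_attr(
        'out_channels',
        OneShotMutableChannel(
            10, candidate_choices=[4, 8, 10], candidate_mode='number'))

    # Unified getter interface: select a sub-network.
    d_conv2d.get_mutable_attr('in_channels').current_choice = 8
    d_conv2d.get_mutable_attr('out_channels').current_choice = 4

    x = torch.rand(2, 8, 32, 32)
    out_dynamic = d_conv2d(x)

    # Fix the chosen sub-network in place (mirrors the fix_dynamic_op helper),
    # then convert the op to a plain nn.Conv2d.
    fix_mutables = export_fix_subnet(d_conv2d)
    for name, mutable in d_conv2d.mutable_attrs.items():
        mutable.fix_chosen(fix_mutables[f'mutable_attrs.{name}'])

    s_conv2d = d_conv2d.to_static_op()
    assert torch.equal(out_dynamic, s_conv2d(x))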