Skip to content

Commit

Permalink
[Feature] Add dynamic bricks (#228)
Browse files Browse the repository at this point in the history
* add dynamic bricks

* add dynamic conv2d test

* add tests for dynamic linear and dynamic norm

* add docstring for dynamic conv2d

* add docstring for dynamic linear

* add docstring for dynamic batchnorm

* Refactor the dynamic op ( put more logic into the mixin )

* fix UT

* Fix UT ( fileio was moved to mmengine)

* derived mutable adds choices property

* Unify the register interface of mutable in dynamic op

* Unified getter interface of mutable in dynamic op

Co-authored-by: gaojianfei <gaojianfei@sensetime.com>
Co-authored-by: pppppM <gjf_mail@126.com>
  • Loading branch information
3 people committed Aug 19, 2022
1 parent 8775b03 commit d190037
Show file tree
Hide file tree
Showing 29 changed files with 1,876 additions and 59 deletions.
12 changes: 6 additions & 6 deletions mmrazor/engine/runner/evolution_search_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import warnings
from typing import Dict, List, Optional, Tuple, Union

import mmcv
import torch
from mmengine import fileio
from mmengine.dist import broadcast_object_list
from mmengine.evaluator import Evaluator
from mmengine.runner import EpochBasedTrainLoop
Expand Down Expand Up @@ -90,7 +90,7 @@ def __init__(self,
if init_candidates is None:
self.candidates = Candidates()
else:
self.candidates = mmcv.fileio.load(init_candidates)
self.candidates = fileio.load(init_candidates)
assert isinstance(self.candidates, Candidates), 'please use the \
correct init candidates file'

Expand Down Expand Up @@ -228,7 +228,7 @@ def _crossover(self) -> SupportRandomSubnet:
def _resume(self):
"""Resume searching."""
if self.runner.rank == 0:
searcher_resume = mmcv.fileio.load(self.resume_from)
searcher_resume = fileio.load(self.resume_from)
for k in searcher_resume.keys():
setattr(self, k, searcher_resume[k])
epoch_start = int(searcher_resume['_epoch'])
Expand All @@ -244,8 +244,8 @@ def _save_best_fix_subnet(self):
self.model.set_subnet(best_random_subnet)
best_fix_subnet = export_fix_subnet(self.model)
save_name = 'best_fix_subnet.yaml'
mmcv.fileio.dump(best_fix_subnet,
osp.join(self.runner.work_dir, save_name))
fileio.dump(best_fix_subnet,
osp.join(self.runner.work_dir, save_name))
self.runner.logger.info(
'Search finished and '
f'{save_name} saved in {self.runner.work_dir}.')
Expand All @@ -271,7 +271,7 @@ def _save_searcher_ckpt(self) -> None:
save_for_resume['_epoch'] = self.runner.epoch
for k in ['candidates', 'top_k_candidates']:
save_for_resume[k] = getattr(self, k)
mmcv.fileio.dump(
fileio.dump(
save_for_resume,
osp.join(self.runner.work_dir,
f'search_epoch_{self.runner.epoch}.pkl'))
Expand Down
4 changes: 2 additions & 2 deletions mmrazor/engine/runner/subnet_sampler_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from abc import abstractmethod
from typing import Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import torch
from mmengine import fileio
from mmengine.evaluator import Evaluator
from mmengine.runner import IterBasedTrainLoop
from mmengine.utils import is_list_of
Expand Down Expand Up @@ -326,6 +326,6 @@ def _check_constraints(self, random_subnet: SupportRandomSubnet) -> bool:
def _save_candidates(self) -> None:
"""Save the candidates to init the next searching."""
save_path = os.path.join(self.runner.work_dir, 'candidates.pkl')
mmcv.fileio.dump(self.candidates, save_path)
fileio.dump(self.candidates, save_path)
self.runner.logger.info(f'candidates.pkl saved in '
f'{self.runner.work_dir}')
5 changes: 2 additions & 3 deletions mmrazor/models/algorithms/pruning/slimmable_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from pathlib import Path
from typing import Dict, List, Optional, Union

import mmcv
import torch
from mmengine import BaseDataElement
from mmengine import BaseDataElement, fileio
from mmengine.model import BaseModel, MMDistributedDataParallel
from mmengine.optim import OptimWrapper
from torch import nn
Expand Down Expand Up @@ -86,7 +85,7 @@ def _load_and_merge_channel_cfgs(
"""Load and merge channel config."""
channel_cfgs = list()
for channel_cfg_path in channel_cfg_paths:
channel_cfg = mmcv.fileio.load(channel_cfg_path)
channel_cfg = fileio.load(channel_cfg_path)
channel_cfgs.append(channel_cfg)

return self.merge_channel_cfgs(channel_cfgs)
Expand Down
4 changes: 3 additions & 1 deletion mmrazor/models/architectures/dynamic_op/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import DynamicOP
from .default_dynamic_ops import (DynamicBatchNorm, DynamicConv2d,
DynamicGroupNorm, DynamicInstanceNorm,
DynamicLinear)
from .slimmable_dynamic_ops import SwitchableBatchNorm2d

__all__ = [
'DynamicConv2d', 'DynamicLinear', 'DynamicBatchNorm',
'DynamicInstanceNorm', 'DynamicGroupNorm', 'SwitchableBatchNorm2d'
'DynamicInstanceNorm', 'DynamicGroupNorm', 'SwitchableBatchNorm2d',
'DynamicOP'
]
93 changes: 87 additions & 6 deletions mmrazor/models/architectures/dynamic_op/base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,106 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Any, Optional, Set

from torch import nn

from mmrazor.models.mutables.mutable_channel import MutableChannel
from mmrazor.models.mutables.base_mutable import BaseMutable


class DynamicOP(ABC):
"""Base class for dynamic OP. A dynamic OP usually consists of a normal
static OP and mutables, where mutables are used to control the searchable
(mutable) part of the dynamic OP.
Note:
When the dynamic OP has just been initialized, its forward propagation
logic should be the same as the corresponding static OP. Only after
the searchable part accepts the specific mutable through the
corresponding interface does the part really become dynamic.
Note:
All subclass should implement ``to_static_op`` API.
Args:
accepted_mutables (set): The string set of all accepted mutables.
"""
accepted_mutables: Set[str] = set()

@abstractmethod
def to_static_op(self) -> nn.Module:
...
"""Convert dynamic OP to static OP.
Note:
The forward result for the same input between dynamic OP and its
corresponding static OP must be same.
Returns:
nn.Module: Corresponding static OP.
"""

def check_if_mutables_fixed(self) -> None:
"""Check if all mutables are fixed.
Raises:
RuntimeError: Error if a existing mutable is not fixed.
"""

def check_fixed(mutable: Optional[BaseMutable]) -> None:
if mutable is not None and not mutable.is_fixed:
raise RuntimeError(f'Mutable {type(mutable)} is not fixed.')

for mutable in self.accepted_mutables:
check_fixed(getattr(self, f'{mutable}'))

@staticmethod
def get_current_choice(mutable: BaseMutable) -> Any:
"""Get current choice of given mutable.
Args:
mutable (BaseMutable): Given mutable.
Raises:
RuntimeError: Error if `current_choice` is None.
Returns:
Any: Current choice of given mutable.
"""
current_choice = mutable.current_choice
if current_choice is None:
raise RuntimeError(f'current choice of mutable {type(mutable)} '
'can not be None at runtime')

return current_choice


class ChannelDynamicOP(DynamicOP):
"""Base class for dynamic OP with mutable channels.
Note:
All subclass should implement ``mutable_in`` and ``mutable_out`` APIs.
"""

@property
@abstractmethod
def mutable_in(self) -> MutableChannel:
...
def mutable_in(self) -> Optional[BaseMutable]:
"""Mutable related to input."""

@property
def mutable_out(self) -> MutableChannel:
...
@abstractmethod
def mutable_out(self) -> Optional[BaseMutable]:
"""Mutable related to output."""

@staticmethod
def check_mutable_channels(mutable_channels: BaseMutable) -> None:
"""Check if mutable has `currnet_mask` attribute.
Args:
mutable_channels (BaseMutable): Mutable to be checked.
Raises:
ValueError: Error if mutable does not have `current_mask`
attribute.
"""
if not hasattr(mutable_channels, 'current_mask'):
raise ValueError(
'channel mutable must have attribute `current_mask`')
14 changes: 14 additions & 0 deletions mmrazor/models/architectures/dynamic_op/bricks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d
from .dynamic_linear import DynamicLinear
from .dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin,
DynamicLinearMixin, DynamicMixin)
from .dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d,
DynamicBatchNorm3d)

__all__ = [
'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear',
'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d',
'DynamicMixin', 'DynamicChannelMixin', 'DynamicBatchNormMixin',
'DynamicLinearMixin'
]
Loading

0 comments on commit d190037

Please sign in to comment.