-
Notifications
You must be signed in to change notification settings - Fork 220
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add dynamic bricks * add dynamic conv2d test * add tests for dynamic linear and dynamic norm * add docstring for dynamic conv2d * add docstring for dynamic linear * add docstring for dynamic batchnorm * Refactor the dynamic op ( put more logic into the mixin ) * fix UT * Fix UT ( fileio was moved to mmengine) * derived mutable adds choices property * Unify the register interface of mutable in dynamic op * Unified getter interface of mutable in dynamic op Co-authored-by: gaojianfei <gaojianfei@sensetime.com> Co-authored-by: pppppM <gjf_mail@126.com>
- Loading branch information
1 parent
8775b03
commit d190037
Showing
29 changed files
with
1,876 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,12 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .base import DynamicOP | ||
from .default_dynamic_ops import (DynamicBatchNorm, DynamicConv2d, | ||
DynamicGroupNorm, DynamicInstanceNorm, | ||
DynamicLinear) | ||
from .slimmable_dynamic_ops import SwitchableBatchNorm2d | ||
|
||
__all__ = [ | ||
'DynamicConv2d', 'DynamicLinear', 'DynamicBatchNorm', | ||
'DynamicInstanceNorm', 'DynamicGroupNorm', 'SwitchableBatchNorm2d' | ||
'DynamicInstanceNorm', 'DynamicGroupNorm', 'SwitchableBatchNorm2d', | ||
'DynamicOP' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,106 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from abc import ABC, abstractmethod | ||
from typing import Any, Optional, Set | ||
|
||
from torch import nn | ||
|
||
from mmrazor.models.mutables.mutable_channel import MutableChannel | ||
from mmrazor.models.mutables.base_mutable import BaseMutable | ||
|
||
|
||
class DynamicOP(ABC): | ||
"""Base class for dynamic OP. A dynamic OP usually consists of a normal | ||
static OP and mutables, where mutables are used to control the searchable | ||
(mutable) part of the dynamic OP. | ||
Note: | ||
When the dynamic OP has just been initialized, its forward propagation | ||
logic should be the same as the corresponding static OP. Only after | ||
the searchable part accepts the specific mutable through the | ||
corresponding interface does the part really become dynamic. | ||
Note: | ||
All subclass should implement ``to_static_op`` API. | ||
Args: | ||
accepted_mutables (set): The string set of all accepted mutables. | ||
""" | ||
accepted_mutables: Set[str] = set() | ||
|
||
@abstractmethod | ||
def to_static_op(self) -> nn.Module: | ||
... | ||
"""Convert dynamic OP to static OP. | ||
Note: | ||
The forward result for the same input between dynamic OP and its | ||
corresponding static OP must be same. | ||
Returns: | ||
nn.Module: Corresponding static OP. | ||
""" | ||
|
||
def check_if_mutables_fixed(self) -> None: | ||
"""Check if all mutables are fixed. | ||
Raises: | ||
RuntimeError: Error if a existing mutable is not fixed. | ||
""" | ||
|
||
def check_fixed(mutable: Optional[BaseMutable]) -> None: | ||
if mutable is not None and not mutable.is_fixed: | ||
raise RuntimeError(f'Mutable {type(mutable)} is not fixed.') | ||
|
||
for mutable in self.accepted_mutables: | ||
check_fixed(getattr(self, f'{mutable}')) | ||
|
||
@staticmethod | ||
def get_current_choice(mutable: BaseMutable) -> Any: | ||
"""Get current choice of given mutable. | ||
Args: | ||
mutable (BaseMutable): Given mutable. | ||
Raises: | ||
RuntimeError: Error if `current_choice` is None. | ||
Returns: | ||
Any: Current choice of given mutable. | ||
""" | ||
current_choice = mutable.current_choice | ||
if current_choice is None: | ||
raise RuntimeError(f'current choice of mutable {type(mutable)} ' | ||
'can not be None at runtime') | ||
|
||
return current_choice | ||
|
||
|
||
class ChannelDynamicOP(DynamicOP): | ||
"""Base class for dynamic OP with mutable channels. | ||
Note: | ||
All subclass should implement ``mutable_in`` and ``mutable_out`` APIs. | ||
""" | ||
|
||
@property | ||
@abstractmethod | ||
def mutable_in(self) -> MutableChannel: | ||
... | ||
def mutable_in(self) -> Optional[BaseMutable]: | ||
"""Mutable related to input.""" | ||
|
||
@property | ||
def mutable_out(self) -> MutableChannel: | ||
... | ||
@abstractmethod | ||
def mutable_out(self) -> Optional[BaseMutable]: | ||
"""Mutable related to output.""" | ||
|
||
@staticmethod | ||
def check_mutable_channels(mutable_channels: BaseMutable) -> None: | ||
"""Check if mutable has `currnet_mask` attribute. | ||
Args: | ||
mutable_channels (BaseMutable): Mutable to be checked. | ||
Raises: | ||
ValueError: Error if mutable does not have `current_mask` | ||
attribute. | ||
""" | ||
if not hasattr(mutable_channels, 'current_mask'): | ||
raise ValueError( | ||
'channel mutable must have attribute `current_mask`') |
14 changes: 14 additions & 0 deletions
14
mmrazor/models/architectures/dynamic_op/bricks/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .dynamic_conv import BigNasConv2d, DynamicConv2d, OFAConv2d | ||
from .dynamic_linear import DynamicLinear | ||
from .dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin, | ||
DynamicLinearMixin, DynamicMixin) | ||
from .dynamic_norm import (DynamicBatchNorm1d, DynamicBatchNorm2d, | ||
DynamicBatchNorm3d) | ||
|
||
__all__ = [ | ||
'BigNasConv2d', 'DynamicConv2d', 'OFAConv2d', 'DynamicLinear', | ||
'DynamicBatchNorm1d', 'DynamicBatchNorm2d', 'DynamicBatchNorm3d', | ||
'DynamicMixin', 'DynamicChannelMixin', 'DynamicBatchNormMixin', | ||
'DynamicLinearMixin' | ||
] |
Oops, something went wrong.