fix according to comments
wutongshenqiu committed Aug 10, 2022
1 parent eaf57a1 commit 953dfb7
Showing 4 changed files with 105 additions and 22 deletions.
46 changes: 27 additions & 19 deletions mmrazor/models/mutables/derived_mutable.py
@@ -1,14 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import inspect
-import logging
 from collections.abc import Iterable
 from typing import Any, Callable, Dict, Optional, Protocol, Set, Union
 
 import torch
-from mmcls.models.utils import make_divisible
-from mmengine.logging import print_log
 from torch import Tensor
 
+from ..utils import make_divisible
 from .base_mutable import CHOICE_TYPE, BaseMutable
 
 
@@ -110,9 +108,6 @@ def fn():
 
 def _concat_mask_fn(mutables: Iterable[MutableChannelProtocol]) -> Callable:
     """Helper function to build `mask_fn` for concat derived mutable."""
-    for mutable in mutables:
-        if not hasattr(mutable, 'current_mask'):
-            raise RuntimeError('mutable must have attribute `currnet_mask`')
 
     def fn():
         return torch.cat([m.current_mask for m in mutables])
@@ -154,13 +149,19 @@ def derive_divide_mutable(self: MutableProtocol,
 def derive_concat_mutable(
         mutables: Iterable[MutableChannelProtocol]) -> 'DerivedMutable':
     """Derive concat mutable, usually used with `torch.cat`."""
+    for mutable in mutables:
+        if not hasattr(mutable, 'current_mask'):
+            raise RuntimeError('Source mutable of concat derived mutable '
+                               'must have attribute `current_mask`')
+
     choice_fn = _concat_choice_fn(mutables)
     mask_fn = _concat_mask_fn(mutables)
 
     return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)
 
 
-class DerivedMutable(BaseMutable[CHOICE_TYPE, Dict], DerivedMethodMixin):
+class DerivedMutable(BaseMutable[CHOICE_TYPE, CHOICE_TYPE],
+                     DerivedMethodMixin):
     """Class for derived mutable.
 
     A derived mutable is a mutable derived from other mutables that has
@@ -219,11 +220,9 @@ def __init__(self,
         if source_mutables is None:
             source_mutables = self._trace_source_mutables()
             if len(source_mutables) == 0:
-                # TODO
-                # warning or raise error?
-                print_log(
-                    'Can not find source mutables automatically',
-                    level=logging.WARNING)
+                raise RuntimeError(
+                    'Can not find source mutables automatically, '
+                    'please provide manually.')
         else:
             source_mutables = set(source_mutables)
         for mutable in source_mutables:
@@ -234,7 +233,7 @@ def __init__(self,
 
     # TODO
     # has no effect
-    def fix_chosen(self, chosen: Dict) -> None:
+    def fix_chosen(self, chosen: CHOICE_TYPE) -> None:
         """Fix mutable with subnet config.
 
         Warning:
@@ -245,13 +244,13 @@ def fix_chosen(self, chosen: Dict) -> None:
 
         self.is_fixed = True
 
-    def dump_chosen(self) -> Dict:
+    def dump_chosen(self) -> CHOICE_TYPE:
         """Dump information of chosen.
 
         Returns:
             Dict: Dumped information.
         """
-        return dict(current_choice=self.current_choice)
+        return self.current_choice
 
     @property
     def num_choices(self) -> int:
@@ -279,16 +278,26 @@ def current_choice(self, choice: CHOICE_TYPE) -> None:
             RuntimeError: Error when `current_choice` of derived mutable
                 is modified directly.
         """
-        raise RuntimeError('Choice of drived mutable can not be set!')
+        raise RuntimeError('Choice of derived mutable can not be set.')
 
     @property
     def current_mask(self) -> Tensor:
         """Current mask of derived mutable."""
         if self.mask_fn is None:
             raise RuntimeError(
-                '`mask_fn` must be set before access `current_mask`')
+                '`mask_fn` must be set before accessing `current_mask`.')
         return self.mask_fn()
 
+    @current_mask.setter
+    def current_mask(self, mask: Tensor) -> None:
+        """Setter of current mask.
+
+        Raises:
+            RuntimeError: Error when `current_mask` of derived mutable
+                is modified directly.
+        """
+        raise RuntimeError('Mask of derived mutable can not be set.')
+
     @staticmethod
     def _trace_source_mutables_from_closure(
             closure: Callable) -> Set[BaseMutable]:
@@ -343,8 +352,7 @@ def is_source_mutable(mutable: object) -> bool:
     # should be __str__? but can not provide info when debug
     def __repr__(self) -> str:  # pragma: no cover
         s = f'{self.__class__.__name__}('
-        if self.choice_fn is not None:
-            s += f'current_choice={self.current_choice}, '
+        s += f'current_choice={self.current_choice}, '
         if self.mask_fn is not None:
             s += f'activated_channels={self.current_mask.sum().item()}, '
         s += f'source_mutables={self.source_mutables}, '
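Taken together, the changes above tighten DerivedMutable's contract: construction fails fast when no source mutables can be traced, dump_chosen returns the raw choice instead of a one-key dict, and both current_choice and current_mask reject direct assignment. A minimal sketch of the resulting behavior, assuming the import paths implied by this repo's layout (the sketch itself is not part of the diff):

import pytest
import torch

# Assumed import path; the classes live under mmrazor/models/mutables.
from mmrazor.models.mutables import DerivedMutable, OneShotMutableChannel

# A choice_fn that closes over no mutables leaves nothing to trace,
# so construction now raises instead of printing a warning.
with pytest.raises(RuntimeError):
    DerivedMutable(choice_fn=lambda: 8)

mc = OneShotMutableChannel(
    num_channels=4, candidate_choices=[2, 4], candidate_mode='number')
mc_derived = mc * 2  # expand-ratio derivation, as in the tests below

mc.current_choice = 2
# dump_chosen now returns the raw choice rather than a one-key dict.
assert mc_derived.dump_chosen() == 4

# Both properties are read-only on derived mutables.
with pytest.raises(RuntimeError):
    mc_derived.current_choice = 8
with pytest.raises(RuntimeError):
    mc_derived.current_mask = torch.ones(8, dtype=torch.bool)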
5 changes: 4 additions & 1 deletion mmrazor/models/utils/__init__.py
@@ -1,5 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .make_divisible import make_divisible
 from .misc import add_prefix
 from .optim_wrapper import reinitialize_optim_wrapper_count_status
 
-__all__ = ['add_prefix', 'reinitialize_optim_wrapper_count_status']
+__all__ = [
+    'add_prefix', 'reinitialize_optim_wrapper_count_status', 'make_divisible'
+]
31 changes: 31 additions & 0 deletions mmrazor/models/utils/make_divisible.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+
+def make_divisible(value: int,
+                   divisor: int,
+                   min_value: Optional[int] = None,
+                   min_ratio: float = 0.9) -> int:
+    """Make divisible function.
+
+    This function rounds the channel number to the nearest value that can
+    be divisible by the divisor.
+
+    Args:
+        value (int): The original channel number.
+        divisor (int): The divisor to fully divide the channel number.
+        min_value (int, optional): The minimum value of the output channel.
+            Default: None, means that the minimum value equal to the divisor.
+        min_ratio (float): The minimum ratio of the rounded channel
+            number to the original channel number. Default: 0.9.
+    Returns:
+        int: The modified output channel number.
+    """
+
+    if min_value is None:
+        min_value = divisor
+    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+    # Make sure that rounding down does not go down by more than (1-min_ratio).
+    if new_value < min_ratio * value:
+        new_value += divisor
+    return new_value
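To make the rounding rule concrete: the input is rounded to the nearest multiple of divisor, clamped below by min_value, and bumped up by one divisor whenever plain rounding would fall under min_ratio of the original value. A few hand-checked calls (illustrative only, not part of the diff):

from mmrazor.models.utils import make_divisible

assert make_divisible(30, 8) == 32  # int(30 + 4) // 8 * 8 = 32
assert make_divisible(26, 8) == 24  # 24 >= 0.9 * 26, so no bump is needed
assert make_divisible(2, 8) == 8    # rounded value 0 is clamped to min_value
# int(10 + 4) // 8 * 8 = 8, but 8 < 0.9 * 10 = 9, so one divisor is added.
assert make_divisible(10, 8) == 16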
45 changes: 43 additions & 2 deletions tests/test_models/test_mutables/test_derived_mutable.py
@@ -73,6 +73,10 @@ def test_mutable_channel_derived(self) -> None:
             mc_derived.current_mask,
             torch.tensor([1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=torch.bool))
 
+        with pytest.raises(RuntimeError):
+            mc_derived.current_mask = torch.ones(
+                mc_derived.current_mask.size())
+
     def test_mutable_divide(self) -> None:
         mc = OneShotMutableChannel(
             num_channels=128,
@@ -101,13 +105,17 @@ def test_mutable_divide(self) -> None:
 
     def test_double_fixed(self) -> None:
         choice_fn = lambda x: x  # noqa: E731
-        derived_mutable = DerivedMutable(choice_fn)
+        derived_mutable = DerivedMutable(choice_fn, source_mutables=[])
         derived_mutable.fix_chosen({})
 
         with pytest.raises(RuntimeError):
             derived_mutable.fix_chosen({})
 
     def test_source_mutables(self) -> None:
         useless_fn = lambda x: x  # noqa: E731
+        with pytest.raises(RuntimeError):
+            _ = DerivedMutable(choice_fn=useless_fn)
+
         mc1 = OneShotMutableChannel(
             num_channels=3, candidate_choices=[1, 3], candidate_mode='number')
         mc2 = OneShotMutableChannel(
@@ -156,6 +164,39 @@ def fn():
             mask_fn=dict_closure_fn({2: [mc1, mc2, mc3]}, {3: dd_mutable}))
         assert dddd_mutable.source_mutables == {mc1, mc2, mc3}
 
+    def test_nested_mutables(self) -> None:
+        source_a = OneShotMutableChannel(
+            num_channels=2, candidate_choices=[1, 2], candidate_mode='number')
+        source_b = OneShotMutableChannel(
+            num_channels=3, candidate_choices=[2, 3], candidate_mode='number')
+
+        # derive from
+        derived_c = source_a * 1
+        concat_mutables = [source_b, derived_c]
+        derived_d = DerivedMutable.derive_concat_mutable(concat_mutables)
+        concat_mutables = [derived_c, derived_d]
+        derived_e = DerivedMutable.derive_concat_mutable(concat_mutables)
+
+        assert derived_c.source_mutables == {source_a}
+        assert derived_d.source_mutables == {source_a, source_b}
+        assert derived_e.source_mutables == {source_a, source_b}
+
+        source_a.current_choice = 1
+        source_b.current_choice = 3
+
+        assert derived_c.current_choice == 1
+        assert torch.equal(derived_c.current_mask,
+                           torch.tensor([1, 0], dtype=torch.bool))
+
+        assert derived_d.current_choice == 4
+        assert torch.equal(derived_d.current_mask,
+                           torch.tensor([1, 1, 1, 1, 0], dtype=torch.bool))
+
+        assert derived_e.current_choice == 5
+        assert torch.equal(
+            derived_e.current_mask,
+            torch.tensor([1, 0, 1, 1, 1, 1, 0], dtype=torch.bool))
+
 
 @pytest.mark.parametrize('expand_ratio', [1, 2, 3])
 def test_derived_expand_mutable(expand_ratio: int) -> None:
@@ -180,7 +221,7 @@ def test_derived_expand_mutable(expand_ratio: int) -> None:
         _ = mv_derived.current_mask
 
     chosen = mv_derived.dump_chosen()
-    assert chosen == {'current_choice': mv.current_choice * expand_ratio}
+    assert chosen == mv.current_choice * expand_ratio
    mv_derived.fix_chosen(chosen)
     assert mv_derived.is_fixed
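The new test_nested_mutables case exercises concat derivation end to end: a derived choice is the sum of its sources' choices, and a derived mask is the concatenation of their masks. The same arithmetic as a standalone sketch (imports assumed as in the first sketch above; not part of the diff):

import torch

from mmrazor.models.mutables import DerivedMutable, OneShotMutableChannel

a = OneShotMutableChannel(
    num_channels=2, candidate_choices=[1, 2], candidate_mode='number')
b = OneShotMutableChannel(
    num_channels=3, candidate_choices=[2, 3], candidate_mode='number')

cat_mutable = DerivedMutable.derive_concat_mutable([a, b])
a.current_choice = 1
b.current_choice = 2

# choice: 1 + 2; mask: [1, 0] concatenated with [1, 1, 0]
assert cat_mutable.current_choice == 3
assert torch.equal(cat_mutable.current_mask,
                   torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool))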
