Skip to content

Commit

Permalink
Complete the unit tests for derived mutable
Browse files Browse the repository at this point in the history
  • Loading branch information
wutongshenqiu committed Aug 8, 2022
1 parent 3b19858 commit 8b55218
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 38 deletions.
76 changes: 68 additions & 8 deletions mmrazor/models/mutables/derived_mutable.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Callable, Dict, Iterable, Optional, Protocol
import inspect
from collections.abc import Iterable
from typing import Any, Callable, Dict, Optional, Protocol, Set, Union

import torch
from mmcls.models.utils import make_divisible
from mmengine.logging import MMLogger
from torch import Tensor

from .base_mutable import CHOICE_TYPE, BaseMutable

logger = MMLogger.get_current_instance()

class MutableProtocol(Protocol):

class MutableProtocol(Protocol): # pragma: no cover

@property
def current_choice(self) -> Any:
Expand All @@ -21,7 +26,7 @@ def derive_divide_mutable(self, ratio: int, divisor: int) -> Any:
...


class ChannelMutableProtocol(MutableProtocol):
class ChannelMutableProtocol(MutableProtocol): # pragma: no cover

@property
def current_mask(self) -> Tensor:
Expand All @@ -36,7 +41,8 @@ def fn():
return fn


def _expand_mask_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable:
def _expand_mask_fn(mutable: MutableProtocol,
expand_ratio: int) -> Callable: # pragma: no cover
if not hasattr(mutable, 'current_mask'):
raise ValueError('mutable must have attribute `currnet_mask`')

Expand Down Expand Up @@ -70,7 +76,7 @@ def fn():

def _divide_mask_fn(mutable: MutableProtocol,
ratio: int,
divisor: int = 8) -> Callable:
divisor: int = 8) -> Callable: # pragma: no cover
if not hasattr(mutable, 'current_mask'):
raise ValueError('mutable must have attribute `currnet_mask`')

Expand Down Expand Up @@ -98,7 +104,7 @@ def fn():
def _concat_mask_fn(mutables: Iterable[ChannelMutableProtocol]) -> Callable:
for mutable in mutables:
if not hasattr(mutable, 'current_mask'):
raise ValueError('mutable must have attribute `currnet_mask`')
raise RuntimeError('mutable must have attribute `currnet_mask`')

def fn():
return torch.cat([m.current_mask for m in mutables])
Expand Down Expand Up @@ -146,13 +152,28 @@ class DerivedMutable(BaseMutable[CHOICE_TYPE, Dict], DerivedMethodMixin):
def __init__(self,
             choice_fn: Callable,
             mask_fn: Optional[Callable] = None,
             source_mutables: Optional[Iterable[BaseMutable]] = None,
             alias: Optional[str] = None,
             init_cfg: Optional[Dict] = None) -> None:
    """Build a derived mutable from its derive functions.

    Args:
        choice_fn: Callable producing the derived ``current_choice``.
        mask_fn: Optional callable producing the derived ``current_mask``.
        source_mutables: Mutables this one is derived from. When ``None``,
            they are recovered from the closures of ``choice_fn``/``mask_fn``.
        alias: Optional alias forwarded to ``BaseMutable``.
        init_cfg: Optional init config forwarded to ``BaseMutable``.

    Raises:
        ValueError: If an explicitly given mutable is not a valid source
            (i.e. is itself derived or not a ``BaseMutable``).
    """
    super().__init__(alias, init_cfg)

    self.choice_fn = choice_fn
    self.mask_fn = mask_fn

    if source_mutables is None:
        # Recover sources by inspecting the derive functions' closures.
        source_mutables = self._find_source_mutables()
        if not source_mutables:
            # TODO
            # warning or raise error?
            logger.warning('Can not find source mutables automatically')
    else:
        source_mutables = set(source_mutables)
        for mutable in source_mutables:
            if not self.is_source_mutable(mutable):
                raise ValueError('Expect all mutable to be source mutable, '
                                 f'but {mutable} is not')
    self.source_mutables = source_mutables

# TODO
# has no effect
def fix_chosen(self, chosen: Dict) -> None:
Expand Down Expand Up @@ -183,14 +204,53 @@ def current_mask(self) -> Tensor:
'`mask_fn` must be set before access `current_mask`')
return self.mask_fn()

@staticmethod
def _extract_source_mutables_from_fn(fn: Callable) -> Set[BaseMutable]:
    """Collect source (non-derived) mutables captured by ``fn``'s closure.

    Walks every nonlocal variable captured by ``fn``, recursing through
    dicts (keys and values) and other iterables. A ``DerivedMutable``
    contributes its own ``source_mutables`` instead of itself, so the
    result is always a set of true source mutables.

    Returns:
        Set of source mutables found in the closure (possibly empty).
    """
    source_mutables: Set[BaseMutable] = set()

    def add_mutables_dfs(
            mutable: Union[Iterable, BaseMutable, Dict]) -> None:
        nonlocal source_mutables
        if isinstance(mutable, BaseMutable):
            if isinstance(mutable, DerivedMutable):
                source_mutables |= mutable.source_mutables
            else:
                source_mutables.add(mutable)
        # BUG FIX: str/bytes are Iterable, and iterating a non-empty
        # string yields 1-char strings that iterate to themselves —
        # a captured string previously caused infinite recursion.
        elif isinstance(mutable, (str, bytes)):
            return
        # dict is also iterable, should parse first
        elif isinstance(mutable, dict):
            add_mutables_dfs(mutable.values())
            add_mutables_dfs(mutable.keys())
        elif isinstance(mutable, Iterable):
            for m in mutable:
                add_mutables_dfs(m)

    # fixed typo: `noncolcal_pars` -> `nonlocal_pars`
    nonlocal_pars = inspect.getclosurevars(fn).nonlocals
    add_mutables_dfs(nonlocal_pars.values())

    return source_mutables

def _find_source_mutables(self) -> Set[BaseMutable]:
    """Gather source mutables from the closures of both derive functions."""
    mutables = self._extract_source_mutables_from_fn(self.choice_fn)
    if self.mask_fn is not None:
        mutables |= self._extract_source_mutables_from_fn(self.mask_fn)

    return mutables

@staticmethod
def is_source_mutable(mutable: BaseMutable) -> bool:
    """Whether ``mutable`` can serve as a source: any ``BaseMutable``
    that is not itself a ``DerivedMutable``."""
    if not isinstance(mutable, BaseMutable):
        return False
    return not isinstance(mutable, DerivedMutable)

# TODO
# should be __str__? but can not provide info when debug
def __repr__(self) -> str:  # pragma: no cover
    """Debug summary: derived choice, activated channels, sources, fixity."""
    parts = [f'{self.__class__.__name__}(']
    if self.choice_fn is not None:
        parts.append(f'current_choice={self.current_choice}, ')
    if self.mask_fn is not None:
        parts.append(f'activated_channels={self.current_mask.sum().item()}, ')
    parts.append(f'source_mutables={self.source_mutables}, ')
    parts.append(f'is_fixed={self.is_fixed})')

    return ''.join(parts)
3 changes: 3 additions & 0 deletions mmrazor/models/mutables/mutable_channel/mutable_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, num_channels: int, **kwargs):
# outputs, we add the mutable out of these modules to the
# `concat_parent_mutables` of this module.
self.concat_parent_mutables: List[MutableChannel] = list()
self.name = 'unbind'

@property
def same_mutables(self):
Expand Down Expand Up @@ -104,6 +105,8 @@ def fix_chosen(self, chosen: CHOSEN_TYPE) -> None:
'The mode of current MUTABLE is `fixed`. '
'Please do not call `fix_chosen` function again.')

# TODO
# should fixed op still have candidate_choices?
self._chosen = chosen
self.is_fixed = True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __repr__(self):
repr_str += f'num_channels={self.num_channels}, '
repr_str += f'current_choice={self.current_choice}, '
repr_str += f'choices={self.choices}, '
repr_str += f'current_mask_shape={self.current_mask.shape}, '
repr_str += f'activated_channels={self.current_mask.sum().item()}, '
repr_str += f'concat_mutable_name={concat_mutable_name})'
return repr_str

Expand Down Expand Up @@ -203,10 +203,7 @@ def fn():

raise TypeError(f'Unsupported type {type(other)} for mul!')

def __rdiv__(self, other) -> DerivedMutable:
return self / other

def __div__(self, other) -> DerivedMutable:
def __floordiv__(self, other) -> DerivedMutable:
if isinstance(other, int):
return self.derive_divide_mutable(other)
if isinstance(other, tuple):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def fix_chosen(self, dumped_chosen: Dict) -> None:
# TODO
# remove after remove `current_choice`
self.current_choice = self._candidate_choices.index(chosen)
self._candidate_choices = [chosen]

super().fix_chosen(chosen)

Expand Down
9 changes: 9 additions & 0 deletions mmrazor/models/mutables/mutable_value/mutable_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ def __mul__(self, other) -> DerivedMutable:

raise TypeError(f'Unsupported type {type(other)} for mul!')

def __floordiv__(self, other) -> DerivedMutable:
    """Return a mutable derived by dividing this value.

    ``other`` is either an ``int`` ratio or a ``(ratio, divisor)`` tuple
    forwarded to ``derive_divide_mutable``.
    """
    if isinstance(other, tuple):
        assert len(other) == 2
        ratio, divisor = other
        return self.derive_divide_mutable(ratio, divisor)
    elif isinstance(other, int):
        return self.derive_divide_mutable(other)

    raise TypeError(f'Unsupported type {type(other)} for div!')


# TODO
# 1. use comparable for type hint
Expand Down
157 changes: 132 additions & 25 deletions tests/test_models/test_mutables/test_derived_mutable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,18 @@

class TestDerivedMutable(TestCase):

def test_mutable_drived(self) -> None:
mv = OneShotMutableValue(value_list=[3, 5, 7])

mv_derived = mv * 4
assert isinstance(mv_derived, BaseMutable)
assert isinstance(mv_derived, DerivedMutable)
assert not mv_derived.is_fixed
assert mv_derived.num_choices == 1

mv.current_choice = mv.max_choice
assert mv_derived.current_choice == 28
mv.current_choice = mv.min_choice
assert mv_derived.current_choice == 12

with pytest.raises(RuntimeError):
mv_derived.current_choice = 123
with pytest.raises(RuntimeError):
_ = mv_derived.current_mask

chosen = mv_derived.dump_chosen()
assert chosen == {'current_choice': 12}
mv_derived.fix_chosen(chosen)
assert mv_derived.is_fixed
def test_derived_same_mutable(self) -> None:
    """`derive_same_mutable` should mirror its source's choice and mask."""
    mc = OneShotMutableChannel(
        num_channels=3,
        candidate_choices=[1, 2, 3],
        candidate_mode='number')
    mc_derived = mc.derive_same_mutable()
    # the derived mutable records `mc` as its only source
    assert mc_derived.source_mutables == {mc}

    mc.current_choice = 2
    # choice and mask both track the source
    assert mc_derived.current_choice == 2
    assert torch.equal(mc_derived.current_mask,
                       torch.tensor([1, 1, 0], dtype=torch.bool))

def test_mutable_concat_derived(self) -> None:
mc1 = OneShotMutableChannel(
Expand All @@ -46,6 +32,7 @@ def test_mutable_concat_derived(self) -> None:
ms = [mc1, mc2]

mc_derived = DerivedMutable.derive_concat_mutable(ms)
assert mc_derived.source_mutables == set(ms)

mc1.current_choice = 1
mc2.current_choice = 4
Expand All @@ -61,12 +48,18 @@ def test_mutable_concat_derived(self) -> None:
mc_derived.current_mask,
torch.tensor([1, 0, 0, 1, 0, 0, 0], dtype=torch.bool))

mv = OneShotMutableValue(value_list=[1, 2, 3])
ms = [mc1, mv]
with pytest.raises(RuntimeError):
_ = DerivedMutable.derive_concat_mutable(ms)

def test_mutable_channel_derived(self) -> None:
mc = OneShotMutableChannel(
num_channels=3,
candidate_choices=[1, 2, 3],
candidate_mode='number')
mc_derived = mc * 3
assert mc_derived.source_mutables == {mc}

mc.current_choice = 1
assert mc_derived.current_choice == 3
Expand All @@ -79,3 +72,117 @@ def test_mutable_channel_derived(self) -> None:
assert torch.equal(
mc_derived.current_mask,
torch.tensor([1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=torch.bool))

def test_mutable_divide(self) -> None:
    """`// 8` derives a divided (and divisor-rounded) mutable.

    BUG FIX: the two `mv.current_choice == ...` lines were comparisons,
    not assignments, so the value-mutable path never actually changed the
    source choice before asserting.
    """
    mc = OneShotMutableChannel(
        num_channels=128,
        candidate_choices=[112, 120, 128],
        candidate_mode='number')
    mc_derived = mc // 8
    assert mc_derived.source_mutables == {mc}

    mc.current_choice = 128
    assert mc_derived.current_choice == 16
    assert torch.equal(mc_derived.current_mask,
                       torch.ones(16, dtype=torch.bool))
    # 120 // 8 = 15 but the derived choice is rounded to 16 — presumably
    # via make_divisible with the default divisor=8; confirm in
    # derived_mutable._divide_choice_fn.
    mc.current_choice = 120
    assert mc_derived.current_choice == 16
    assert torch.equal(mc_derived.current_mask,
                       torch.ones(16, dtype=torch.bool))

    mv = OneShotMutableValue(value_list=[112, 120, 128])
    mv_derived = mv // 8
    assert mv_derived.source_mutables == {mv}

    mv.current_choice = 128  # was `==` (no-op)
    assert mv_derived.current_choice == 16
    mv.current_choice = 120  # was `==` (no-op)
    assert mv_derived.current_choice == 16

def test_double_fixed(self) -> None:
    """Fixing an already-fixed derived mutable must raise RuntimeError."""
    identity_fn = lambda x: x  # noqa: E731
    derived_mutable = DerivedMutable(identity_fn)
    derived_mutable.fix_chosen({})

    # second `fix_chosen` on the same mutable is an error
    with pytest.raises(RuntimeError):
        derived_mutable.fix_chosen({})

def test_source_mutables(self) -> None:
    """Source mutables are found identically whether passed explicitly
    or extracted from the derive functions' closures."""
    mc1 = OneShotMutableChannel(
        num_channels=3, candidate_choices=[1, 3], candidate_mode='number')
    mc2 = OneShotMutableChannel(
        num_channels=4, candidate_choices=[1, 4], candidate_mode='number')
    ms = [mc1, mc2]

    # sources discovered automatically from the concat closure ...
    mc_derived1 = DerivedMutable.derive_concat_mutable(ms)

    from mmrazor.models.mutables.derived_mutable import (_concat_choice_fn,
                                                         _concat_mask_fn)
    # ... must equal sources passed explicitly
    mc_derived2 = DerivedMutable(
        choice_fn=_concat_choice_fn(ms),
        mask_fn=_concat_mask_fn(ms),
        source_mutables=ms)
    assert mc_derived1.source_mutables == mc_derived2.source_mutables

    # deriving from a derived mutable flattens to the original sources
    dd_mutable = mc_derived1.derive_same_mutable()
    assert dd_mutable.source_mutables == mc_derived1.source_mutables

    # a DerivedMutable itself is not a valid explicit source
    with pytest.raises(ValueError):
        _ = DerivedMutable(
            choice_fn=lambda x: x, source_mutables=[mc_derived1])

    # helper: a closure capturing arbitrary containers of mutables,
    # used to exercise the dict/iterable recursion in extraction
    def dict_closure_fn(x, y):

        def fn():
            nonlocal x, y

        return fn

    # mutables hidden in dict keys, dict values (lists), and a nested
    # derived mutable are all resolved back to {mc1, mc2}
    ddd_mutable = DerivedMutable(
        choice_fn=dict_closure_fn({
            mc1: [2, 3],
            mc2: 2
        }, None),
        mask_fn=dict_closure_fn({2: [mc1, mc2]}, {3: dd_mutable}))
    assert ddd_mutable.source_mutables == mc_derived1.source_mutables

    # an extra mutable appearing only in the mask closure is also found
    mc3 = OneShotMutableChannel(
        num_channels=4, candidate_choices=[2, 4], candidate_mode='number')
    dddd_mutable = DerivedMutable(
        choice_fn=dict_closure_fn({
            mc1: [2, 3],
            mc2: 2
        }, []),
        mask_fn=dict_closure_fn({2: [mc1, mc2, mc3]}, {3: dd_mutable}))
    assert dddd_mutable.source_mutables == {mc1, mc2, mc3}


@pytest.mark.parametrize('expand_ratio', [1, 2, 3])
def test_derived_expand_mutable(expand_ratio: int) -> None:
    """A `mv * ratio` derived mutable must track its source value mutable."""
    mv = OneShotMutableValue(value_list=[3, 5, 7])

    mv_derived = mv * expand_ratio
    # the multiplied mutable records `mv` as its only source
    assert mv_derived.source_mutables == {mv}

    assert isinstance(mv_derived, BaseMutable)
    assert isinstance(mv_derived, DerivedMutable)
    assert not mv_derived.is_fixed
    # a derived mutable exposes exactly one (computed) choice
    assert mv_derived.num_choices == 1

    # derived choice follows the source choice at both extremes
    mv.current_choice = mv.max_choice
    assert mv_derived.current_choice == mv.current_choice * expand_ratio
    mv.current_choice = mv.min_choice
    assert mv_derived.current_choice == mv.current_choice * expand_ratio

    # derived choices are read-only, and a value mutable has no mask
    with pytest.raises(RuntimeError):
        mv_derived.current_choice = 123
    with pytest.raises(RuntimeError):
        _ = mv_derived.current_mask

    chosen = mv_derived.dump_chosen()
    assert chosen == {'current_choice': mv.current_choice * expand_ratio}
    mv_derived.fix_chosen(chosen)
    assert mv_derived.is_fixed

    # NOTE(review): even after `fix_chosen`, the derived value still follows
    # the source mutable — confirm this is the intended contract.
    mv.current_choice = 5
    assert mv_derived.current_choice == 5 * expand_ratio

0 comments on commit 8b55218

Please sign in to comment.