Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add DerivedMutable & MutableValue #215

Merged
merged 9 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mmrazor/models/mutables/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .derived_mutable import DerivedMutable
from .mutable_channel import (MutableChannel, OneShotMutableChannel,
SlimmableMutableChannel)
from .mutable_manage_mixin import MutableManageMixIn
from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP,
OneShotMutableModule, OneShotMutableOP)
from .mutable_value import MutableValue, OneShotMutableValue

__all__ = [
'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP',
'DiffChoiceRoute', 'DiffMutableModule', 'MutableManageMixIn',
'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel'
'OneShotMutableChannel', 'SlimmableMutableChannel', 'MutableChannel',
'DerivedMutable', 'MutableValue', 'OneShotMutableValue'
]
382 changes: 382 additions & 0 deletions mmrazor/models/mutables/derived_mutable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,382 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import logging
from collections.abc import Iterable
from typing import Any, Callable, Dict, Optional, Protocol, Set, Union

import torch
from mmengine.logging import print_log
from torch import Tensor

from ..utils import make_divisible
from .base_mutable import CHOICE_TYPE, BaseMutable


class MutableProtocol(Protocol): # pragma: no cover
"""Protocol for Mutable."""

@property
def current_choice(self) -> Any:
"""Current choice."""

def derive_expand_mutable(self, expand_ratio: int) -> Any:
"""Derive expand mutable."""

def derive_divide_mutable(self, ratio: int, divisor: int) -> Any:
"""Derive divide mutable."""


class MutableChannelProtocol(MutableProtocol): # pragma: no cover
"""Protocol for MutableChannel."""

@property
def current_mask(self) -> Tensor:
"""Current mask."""


def _expand_choice_fn(mutable: MutableProtocol, expand_ratio: int) -> Callable:
"""Helper function to build `choice_fn` for expand derived mutable."""

def fn():
return mutable.current_choice * expand_ratio

return fn


def _expand_mask_fn(mutable: MutableProtocol,
expand_ratio: int) -> Callable: # pragma: no cover
"""Helper function to build `mask_fn` for expand derived mutable."""
if not hasattr(mutable, 'current_mask'):
raise ValueError('mutable must have attribute `currnet_mask`')

def fn():
mask = mutable.current_mask
expand_num_channels = mask.size(0) * expand_ratio
expand_choice = mutable.current_choice * expand_ratio
expand_mask = torch.zeros(expand_num_channels).bool()
expand_mask[:expand_choice] = True

return expand_mask

return fn


def _divide_and_divise(x: int, ratio: int, divisor: int = 8) -> int:
"""Helper function for divide and divise."""
new_x = x // ratio

return make_divisible(new_x, divisor)


def _divide_choice_fn(mutable: MutableProtocol,
ratio: int,
divisor: int = 8) -> Callable:
"""Helper function to build `choice_fn` for divide derived mutable."""

def fn():
return _divide_and_divise(mutable.current_choice, ratio, divisor)

return fn


def _divide_mask_fn(mutable: MutableProtocol,
ratio: int,
divisor: int = 8) -> Callable: # pragma: no cover
"""Helper function to build `mask_fn` for divide derived mutable."""
if not hasattr(mutable, 'current_mask'):
raise ValueError('mutable must have attribute `currnet_mask`')

def fn():
mask = mutable.current_mask
divide_num_channels = _divide_and_divise(mask.size(0), ratio, divisor)
divide_choice = _divide_and_divise(mutable.current_choice, ratio,
divisor)
divide_mask = torch.zeros(divide_num_channels).bool()
divide_mask[:divide_choice] = True

return divide_mask

return fn


def _concat_choice_fn(mutables: Iterable[MutableChannelProtocol]) -> Callable:
"""Helper function to build `choice_fn` for concat derived mutable."""

def fn():
return sum((m.current_choice for m in mutables))

return fn


def _concat_mask_fn(mutables: Iterable[MutableChannelProtocol]) -> Callable:
"""Helper function to build `mask_fn` for concat derived mutable."""

def fn():
return torch.cat([m.current_mask for m in mutables])

return fn


class DerivedMethodMixin:
"""A mixin that provides some useful method to derive mutable."""

def derive_same_mutable(self: MutableProtocol) -> 'DerivedMutable':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo? ‘DerivedMutable’ -> DerivedMutable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for using str here is that class DeriveMutable is defined after DerivedMethodMixin. To directly use DerivedMutable for type hint, the version of python should >= 3.7. More details can be found at this link

"""Derive same mutable as the source."""
return self.derive_expand_mutable(expand_ratio=1)

def derive_expand_mutable(self: MutableProtocol,
expand_ratio: int) -> 'DerivedMutable':
wutongshenqiu marked this conversation as resolved.
Show resolved Hide resolved
"""Derive expand mutable, usually used with `expand_ratio`."""
choice_fn = _expand_choice_fn(self, expand_ratio=expand_ratio)

mask_fn: Optional[Callable] = None
if hasattr(self, 'current_mask'):
mask_fn = _expand_mask_fn(self, expand_ratio=expand_ratio)

return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)

def derive_divide_mutable(self: MutableProtocol,
ratio: int,
divisor: int = 8) -> 'DerivedMutable':
wutongshenqiu marked this conversation as resolved.
Show resolved Hide resolved
"""Derive divide mutable, usually used with `make_divisable`."""
choice_fn = _divide_choice_fn(self, ratio=ratio, divisor=divisor)

mask_fn: Optional[Callable] = None
if hasattr(self, 'current_mask'):
mask_fn = _divide_mask_fn(self, ratio=ratio, divisor=divisor)

return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)

@staticmethod
def derive_concat_mutable(
mutables: Iterable[MutableChannelProtocol]) -> 'DerivedMutable':
wutongshenqiu marked this conversation as resolved.
Show resolved Hide resolved
"""Derive concat mutable, usually used with `torch.cat`."""
for mutable in mutables:
if not hasattr(mutable, 'current_mask'):
raise RuntimeError('Source mutable of concat derived mutable '
'must have attribute `currnet_mask`')

choice_fn = _concat_choice_fn(mutables)
mask_fn = _concat_mask_fn(mutables)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check the protocal (requires to contain current_mask attr) of all the mutables before assign mask_fn.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


return DerivedMutable(choice_fn=choice_fn, mask_fn=mask_fn)


class DerivedMutable(BaseMutable[CHOICE_TYPE, CHOICE_TYPE],
DerivedMethodMixin):
"""Class for derived mutable.

A derived mutable is a mutable derived from other mutables that has
`current_choice` and `current_mask` attributes (if any).

Note:
A derived mutable does not have its own search space, so it is
not legal to modify its `current_choice` or `current_mask` directly.
And the only way to modify them is by modifying `current_choice` or
`current_mask` in corresponding source mutables.

Args:
choice_fn (callable): A closure that controls how to generate
`current_choice`.
mask_fn (callable, optional): A closure that controls how to generate
`current_mask`. Defaults to None.
source_mutables (iterable, optional): Specify source mutables for this
derived mutable. If the argument is None, source mutables will be
traced automatically by parsing mutables in closure variables.
Defaults to None.
alias (str, optional): alias of the `MUTABLE`. Defaults to None.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implement 5 initializer including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
and `Pretrained`. Defaults to None.

Examples:
>>> from mmrazor.models.mutables import OneShotMutableChannel
>>> mutable_channel = OneShotMutableChannel(
... num_channels=3,
... candidate_choices=[1, 2, 3],
... candidate_mode='number')
>>> # derive expand mutable
>>> derived_mutable_channel = mutable_channel * 2
>>> # source mutables will be traced automatically
>>> derived_mutable_channel.source_mutables
{OneShotMutableChannel(name=unbind, num_channels=3, current_choice=3, choices=[1, 2, 3], activated_channels=3, concat_mutable_name=[])} # noqa: E501
>>> # modify `current_choice` of `mutable_channel`
>>> mutable_channel.current_choice = 2
>>> # `current_choice` and `current_mask` of derived mutable will be modified automatically # noqa: E501
>>> derived_mutable_channel
DerivedMutable(current_choice=4, activated_channels=4, source_mutables={OneShotMutableChannel(name=unbind, num_channels=3, current_choice=2, choices=[1, 2, 3], activated_channels=2, concat_mutable_name=[])}, is_fixed=False) # noqa: E501
"""

def __init__(self,
choice_fn: Callable,
mask_fn: Optional[Callable] = None,
source_mutables: Optional[Iterable[BaseMutable]] = None,
alias: Optional[str] = None,
init_cfg: Optional[Dict] = None) -> None:
super().__init__(alias, init_cfg)

self.choice_fn = choice_fn
self.mask_fn = mask_fn

if source_mutables is None:
source_mutables = self._trace_source_mutables()
if len(source_mutables) == 0:
raise RuntimeError(
'Can not find source mutables automatically, '
'please provide manually.')
else:
source_mutables = set(source_mutables)
for mutable in source_mutables:
if not self.is_source_mutable(mutable):
raise ValueError('Expect all mutable to be source mutable, '
f'but {mutable} is not')
self.source_mutables = source_mutables

# TODO
# has no effect
def fix_chosen(self, chosen: CHOICE_TYPE) -> None:
"""Fix mutable with subnet config.

Warning:
Fix derived mutable will have no actually effect.
"""
print_log(
'Trying to fix chosen for derived mutable, '
'which will have no effect.',
level=logging.WARNING)

def dump_chosen(self) -> CHOICE_TYPE:
"""Dump information of chosen.

Returns:
Dict: Dumped information.
"""
print_log(
'Trying to dump chosen for derived mutable, '
'but its value depend on the source mutables.',
level=logging.WARNING)
return self.current_choice

@property
def is_fixed(self) -> bool:
"""Whether the derived mutable is fixed.

Note:
Depends on whether all source mutables are already fixed.
"""
return all(m.is_fixed for m in self.source_mutables)

@is_fixed.setter
def is_fixed(self, is_fixed: bool) -> bool:
"""Setter of is fixed."""
raise RuntimeError(
'`is_fixed` of derived mutable should not be modified directly')

@property
def num_choices(self) -> int:
"""Number of all choices.

Note:
Since derive mutable does not have its own search space, the number
of choices will always be `1`.

Returns:
int: Number of choices.
"""
return 1

@property
def current_choice(self) -> CHOICE_TYPE:
"""Current choice of derived mutable."""
return self.choice_fn()

@current_choice.setter
def current_choice(self, choice: CHOICE_TYPE) -> None:
"""Setter of current choice.

Raises:
RuntimeError: Error when `current_choice` of derived mutable
is modified directly.
"""
raise RuntimeError('Choice of drived mutable can not be set.')

@property
def current_mask(self) -> Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can current_mask can be set directly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related methods have been overloaded to prohibit direct modification of current_mask

"""Current mask of derived mutable."""
if self.mask_fn is None:
raise RuntimeError(
'`mask_fn` must be set before access `current_mask`.')
return self.mask_fn()

@current_mask.setter
def current_mask(self, mask: Tensor) -> None:
"""Setter of current mask.

Raises:
RuntimeError: Error when `current_mask` of derived mutable
is modified directly.
"""
raise RuntimeError('Mask of drived mutable can not be set.')

@staticmethod
def _trace_source_mutables_from_closure(
closure: Callable) -> Set[BaseMutable]:
"""Trace source mutables from closure."""
source_mutables: Set[BaseMutable] = set()

def add_mutables_dfs(
mutable: Union[Iterable, BaseMutable, Dict]) -> None:
nonlocal source_mutables
if isinstance(mutable, BaseMutable):
if isinstance(mutable, DerivedMutable):
source_mutables |= mutable.source_mutables
else:
source_mutables.add(mutable)
# dict is also iterable, should parse first
elif isinstance(mutable, dict):
add_mutables_dfs(mutable.values())
add_mutables_dfs(mutable.keys())
elif isinstance(mutable, Iterable):
for m in mutable:
add_mutables_dfs(m)

noncolcal_pars = inspect.getclosurevars(closure).nonlocals
add_mutables_dfs(noncolcal_pars.values())

return source_mutables

def _trace_source_mutables(self) -> Set[BaseMutable]:
"""Trace source mutables."""
source_mutables = self._trace_source_mutables_from_closure(
self.choice_fn)
if self.mask_fn is not None:
source_mutables |= self._trace_source_mutables_from_closure(
self.mask_fn)

return source_mutables

@staticmethod
def is_source_mutable(mutable: object) -> bool:
"""Judge whether an object is source mutable(not derived mutable).

Args:
mutable (object): An object.

Returns:
bool: Indicate whether the object is source mutable or not.
"""
return isinstance(mutable, BaseMutable) and \
not isinstance(mutable, DerivedMutable)

# TODO
# should be __str__? but can not provide info when debug
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could provide the link_path from DerivedMutable to Source mutables.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Provide attribute source_mutables to trace source mutables of derived mutable

def __repr__(self) -> str: # pragma: no cover
s = f'{self.__class__.__name__}('
s += f'current_choice={self.current_choice}, '
if self.mask_fn is not None:
s += f'activated_channels={self.current_mask.sum().item()}, '
s += f'source_mutables={self.source_mutables}, '
s += f'is_fixed={self.is_fixed})'

return s
Loading