Refine pruning branch (#307)
* [Feature] Contrastive Representation Distillation with dataset wrapper (#281)

* init

* TD: CRDLoss

* complete UT

* fix docstrings

* fix ci

* update

* fix CI

* DONE

* maintain CRD dataset unique funcs as a mixin

* maintain CRD dataset unique funcs as a mixin

* maintain CRD dataset unique funcs as a mixin

* add UT: CRD_ClsDataset

* init

* TODO: UT test formatting.

* init

* crd dataset wrapper

* update docstring

Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>

* [Improvement] Update estimator with api revision (#277)

* update estimator usage and fix bugs

* refactor api of estimator & add inner check methods

* fix docstrings

* update search loop and config

* fix lint

* update unittest

* decouple mmdet dependency and fix lint

Co-authored-by: humu789 <humu@pjlab.org.cn>

* [Fix] Fix tracer (#273)

* test image_classifier_loss_calculator

* fix backward tracer

* update SingleStageDetectorPseudoLoss

* merge

* [Feature] Add Dsnas Algorithm (#226)

* [tmp] Update Dsnas

* [tmp] refactor arch_loss & flops_loss

* Update Dsnas & MMRAZOR_EVALUATOR:
1. finalized compute_loss & handle_grads in algorithm;
2. add MMRAZOR_EVALUATOR;
3. fix bugs.

* Update lr scheduler & fix a bug:
1. update param_scheduler & lr_scheduler for dsnas;
2. fix a bug of switching to finetune stage.

* remove old evaluators

* remove old evaluators

* update param_scheduler config

* merge dev-1.x into gy/estimator

* add flops_loss in Dsnas using ResourcesEstimator

* get resources before mutator.prepare_from_supernet

* delete unnecessary broadcast api from gml

* broadcast spec_modules_resources when estimating

* update early fix mechanism for Dsnas

* fix merge

* update units in estimator

* minor change

* fix data_preprocessor api

* add flops_loss_coef

* remove DsnasOptimWrapper

* fix bn eps and data_preprocessor

* fix bn weight decay bug

* add betas for mutator optimizer

* set diff_rank_seed=True for dsnas

* fix start_factor of lr when warm up

* remove .module in non-ddp mode

* add GlobalAveragePoolingWithDropout

* add UT for dsnas

* remove unnecessary channel adjustment for shufflenetv2

* update supernet configs

* delete unnecessary dropout

* delete unnecessary part with minor change on dsnas

* minor change on the flag of search stage

* update README and subnet configs

* add UT for OneHotMutableOP

* [Feature] Update train (#279)

* support auto resume

* add enable auto_scale_lr in train.py

* support '--amp' option

* [Fix] Fix darts metafile (#278)

fix darts metafile

* fix ci (#284)

* fix ci for circle ci

* fix bug in test_metafiles

* add pr_stage_test for github ci

* add multiple version

* fix ut

* fix lint

* Temporarily skip dataset UT

* update github ci

* add github lint ci

* install wheel

* remove timm from requirements

* install wheel when test on windows

* fix error

* fix bug

* remove github windows ci

* fix device error of arch_params when DsnasDDP

* fix CRD dataset ut

* fix scope error

* rm test_cuda in workflows of github

* [Doc] fix typos in en/usr_guides

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: SheffieldCao <1751899@tongji.edu.cn>

* fix bug when python=3.6

* fix lint

* fix bug when test using cpu only

* refine ci

* fix error in ci

* try ci

* update repr of Channel

* fix error

* mv init_from_predefined_model to MutableChannelUnit

* move tests

* update SquentialMutableChannel

* update l1 mutable channel unit

* add OneShotMutableChannel

* candidate_mode -> choice_mode

* update docstring

* change ci

Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: humu789 <humu@pjlab.org.cn>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: SheffieldCao <1751899@tongji.edu.cn>
10 people committed Oct 10, 2022
1 parent 8330b62 commit 3715bbc
Showing 35 changed files with 354 additions and 260 deletions.
2 changes: 1 addition & 1 deletion .circleci/test.yml
@@ -119,7 +119,7 @@ jobs:
       docker exec mmrazor pip install -e /mmdetection
       docker exec mmrazor pip install -e /mmclassification
       docker exec mmrazor pip install -e /mmsegmentation
-      docker exec mmrazor pip install -r requirements/tests.txt
+      docker exec mmrazor pip install -r requirements.txt
   - run:
       name: Build and install
       command: |
@@ -47,7 +47,7 @@
         type='OneShotMutableChannelUnit',
         default_args=dict(
             candidate_choices=list(i / 12 for i in range(2, 13)),
-            candidate_mode='ratio',
+            choice_mode='ratio',
             divisor=8)),
     parse_cfg=dict(
         type='BackwardTracer',
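Note: the change above renames the candidate_mode key to choice_mode in pruning configs. A minimal sketch of the renamed key in context; the enclosing channel_unit_cfg name follows the usual MMRazor mutator layout and is an assumption, only the keys visible in this diff are confirmed:

    channel_unit_cfg = dict(
        type='OneShotMutableChannelUnit',
        default_args=dict(
            # eleven candidate width ratios: 2/12, 3/12, ..., 12/12
            candidate_choices=list(i / 12 for i in range(2, 13)),
            choice_mode='ratio',  # formerly `candidate_mode`
            divisor=8))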
4 changes: 3 additions & 1 deletion mmrazor/models/algorithms/__init__.py
@@ -13,5 +13,7 @@
     'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation',
     'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation',
     'ItePruneAlgorithm', 'DAFLDataFreeDistillation',
-    'OverhaulFeatureDistillation', 'Dsnas', 'DsnasDDP'
+    'OverhaulFeatureDistillation', 'Dsnas', 'DsnasDDP',
+    'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'Dsnas',
+    'DsnasDDP'
 ]
@@ -152,8 +152,7 @@ def init_candidates(self, candidates: List):
         for num in candidates:
             self.candidate_bn[str(num)] = nn.BatchNorm2d(
                 num, self.eps, self.momentum, self.affine,
-                self.track_running_stats, self.weight.device,
-                self.weight.dtype)
+                self.track_running_stats)

     def forward(self, input: Tensor) -> Tensor:
         """Forward."""
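Note: the device and dtype constructor arguments dropped above only exist on newer PyTorch releases, which is presumably why they were removed. A minimal sketch of the fixed construction, with illustrative candidate widths and standard BatchNorm hyper-parameters (none of these values come from the diff):

    import torch.nn as nn

    # One BatchNorm2d per candidate channel count; omitting device/dtype
    # keeps the call compatible with older PyTorch versions.
    candidates = [8, 16, 32]
    eps, momentum, affine, track_running_stats = 1e-5, 0.1, True, True
    candidate_bn = nn.ModuleDict({
        str(num): nn.BatchNorm2d(num, eps, momentum, affine,
                                 track_running_stats)
        for num in candidates
    })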
7 changes: 5 additions & 2 deletions mmrazor/models/mutables/__init__.py
@@ -2,7 +2,8 @@
 from .base_mutable import BaseMutable
 from .derived_mutable import DerivedMutable
 from .mutable_channel import (BaseMutableChannel, MutableChannelContainer,
-                              SimpleMutableChannel, SquentialMutableChannel)
+                              OneShotMutableChannel, SimpleMutableChannel,
+                              SquentialMutableChannel)
 from .mutable_channel.units import (ChannelUnitType, L1MutableChannelUnit,
                                     MutableChannelUnit,
                                     OneShotMutableChannelUnit,
@@ -22,5 +23,7 @@
     'BaseMutableChannel', 'MutableChannelContainer', 'ChannelUnitType',
     'SquentialMutableChannel', 'BaseMutable', 'DiffChoiceRoute',
     'DiffMutableModule', 'DerivedMutable', 'MutableValue',
-    'OneShotMutableValue', 'OneHotMutableOP'
+    'OneShotMutableValue', 'OneHotMutableOP', 'OneShotMutableChannel',
+    'DiffChoiceRoute', 'DiffMutableModule', 'OneShotMutableChannel',
+    'DerivedMutable', 'MutableValue', 'OneShotMutableValue', 'OneHotMutableOP'
 ]
13 changes: 7 additions & 6 deletions mmrazor/models/mutables/mutable_channel/__init__.py
@@ -1,16 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base_mutable_channel import BaseMutableChannel
-from .units import (ChannelUnitType, L1MutableChannelUnit,
-                    MutableChannelUnit, OneShotMutableChannelUnit,
-                    SequentialMutableChannelUnit, SlimmableChannelUnit)
 from .mutable_channel_container import MutableChannelContainer
+from .oneshot_mutalbe_channel import OneShotMutableChannel
 from .sequential_mutable_channel import SquentialMutableChannel
 from .simple_mutable_channel import SimpleMutableChannel
+from .units import (ChannelUnitType, L1MutableChannelUnit, MutableChannelUnit,
+                    OneShotMutableChannelUnit, SequentialMutableChannelUnit,
+                    SlimmableChannelUnit)

 __all__ = [
     'SimpleMutableChannel', 'L1MutableChannelUnit',
     'SequentialMutableChannelUnit', 'MutableChannelUnit',
-    'OneShotMutableChannelUnit', 'SlimmableChannelUnit',
-    'BaseMutableChannel', 'MutableChannelContainer', 'SquentialMutableChannel',
-    'ChannelUnitType'
+    'OneShotMutableChannelUnit', 'SlimmableChannelUnit', 'BaseMutableChannel',
+    'MutableChannelContainer', 'SquentialMutableChannel', 'ChannelUnitType',
+    'OneShotMutableChannel'
 ]
@@ -30,7 +30,7 @@ class MutableChannelContainer(BaseMutableChannel):

     def __init__(self, num_channels: int, **kwargs):
         super().__init__(num_channels, **kwargs)
-        self.mutable_channels: IndexDict[BaseMutableChannel] = IndexDict()
+        self.mutable_channels = IndexDict()

     # choice
41 changes: 41 additions & 0 deletions mmrazor/models/mutables/mutable_channel/oneshot_mutalbe_channel.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Union
+
+from .sequential_mutable_channel import SquentialMutableChannel
+
+
+class OneShotMutableChannel(SquentialMutableChannel):
+    """OneShotMutableChannel is a subclass of SquentialMutableChannel. The
+    difference is that a OneShotMutableChannel limits the candidates of the
+    choice.
+
+    Args:
+        num_channels (int): number of channels.
+        candidate_choices (List[Union[float, int]], optional): A list of
+            candidate width ratios. Each candidate indicates how many
+            channels to be reserved. Defaults to [].
+        choice_mode (str, optional): Mode of choices. Defaults to 'number'.
+    """
+
+    def __init__(self,
+                 num_channels: int,
+                 candidate_choices: List[Union[float, int]] = [],
+                 choice_mode='number',
+                 **kwargs):
+        super().__init__(num_channels, choice_mode, **kwargs)
+        self.candidate_choices = candidate_choices
+        if candidate_choices == []:
+            candidate_choices.append(num_channels if self.is_num_mode else 1.0)
+
+    @property
+    def current_choice(self) -> Union[int, float]:
+        """Get current choice."""
+        return super().current_choice
+
+    @current_choice.setter
+    def current_choice(self, choice: Union[int, float]):
+        """Set current choice."""
+        assert choice in self.candidate_choices
+        SquentialMutableChannel.current_choice.fset(  # type: ignore
+            self,  # type: ignore
+            choice)  # type: ignore
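Note: a short usage sketch of the class added above, based only on the API shown in this diff (channel counts are illustrative):

    from mmrazor.models.mutables import OneShotMutableChannel

    # 32 channels restricted to three candidate widths in 'number' mode.
    mutable = OneShotMutableChannel(32, candidate_choices=[8, 16, 32])
    mutable.current_choice = 16    # allowed: 16 is a registered candidate
    print(mutable.current_choice)  # -> 16
    # Setting a non-candidate such as 12 would fail the assert in the setter.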
@@ -1,17 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Callable
+from typing import Callable, Union

 import torch

 from mmrazor.registry import MODELS
 from ..derived_mutable import DerivedMutable
-from .base_mutable_channel import BaseMutableChannel
+from .simple_mutable_channel import SimpleMutableChannel

 # TODO discuss later


 @MODELS.register_module()
-class SquentialMutableChannel(BaseMutableChannel):
+class SquentialMutableChannel(SimpleMutableChannel):
     """SquentialMutableChannel defines a BaseMutableChannel which switch off
     channel mask from right to left sequentially, like '11111000'.
@@ -22,21 +22,36 @@ class SquentialMutableChannel(BaseMutableChannel):
         num_channels (int): number of channels.
     """

-    def __init__(self, num_channels: int, **kwargs):
+    def __init__(self, num_channels: int, choice_mode='number', **kwargs):

         super().__init__(num_channels, **kwargs)
+        assert choice_mode in ['ratio', 'number']
+        self.choice_mode = choice_mode
         self.mask = torch.ones([self.num_channels]).bool()

     @property
-    def current_choice(self) -> int:
+    def is_num_mode(self):
+        """Get if the choice is number mode."""
+        return self.choice_mode == 'number'
+
+    @property
+    def current_choice(self) -> Union[int, float]:
         """Get current choice."""
-        return (self.mask == 1).sum().item()
+        int_choice = (self.mask == 1).sum().item()
+        if self.is_num_mode:
+            return int_choice
+        else:
+            return self._num2ratio(int_choice)

     @current_choice.setter
-    def current_choice(self, choice: int):
+    def current_choice(self, choice: Union[int, float]):
         """Set choice."""
+        if isinstance(choice, float):
+            int_choice = self._ratio2num(choice)
+        else:
+            int_choice = choice
         mask = torch.zeros([self.num_channels], device=self.mask.device)
-        mask[0:choice] = 1
+        mask[0:int_choice] = 1
         self.mask = mask.bool()

     @property
@@ -58,20 +73,6 @@ def dump_chosen(self):
         """Dump chosen."""
         return self.current_choice

-    # def __mul__(self, other):
-    #     """multiplication."""
-    #     if isinstance(other, int):
-    #         return self.derive_expand_mutable(other)
-    #     else:
-    #         return None
-
-    # def __floordiv__(self, other):
-    #     """division."""
-    #     if isinstance(other, int):
-    #         return self.derive_divide_mutable(other)
-    #     else:
-    #         return None
-
     def __rmul__(self, other) -> DerivedMutable:
         return self * other

@@ -121,3 +122,17 @@ def __floordiv__(self, other) -> DerivedMutable:
             return self.derive_divide_mutable(*other)

         raise TypeError(f'Unsupported type {type(other)} for div!')
+
+    def _num2ratio(self, choice: Union[int, float]) -> float:
+        """Convert a number choice to a ratio choice."""
+        if isinstance(choice, float):
+            return choice
+        else:
+            return choice / self.num_channels
+
+    def _ratio2num(self, choice: Union[int, float]) -> int:
+        """Convert a ratio choice to a number choice."""
+        if isinstance(choice, int):
+            return choice
+        else:
+            return max(1, int(self.num_channels * choice))
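Note: a brief sketch of the two choice modes introduced above, assuming the mutable is used standalone (channel counts are illustrative):

    from mmrazor.models.mutables import SquentialMutableChannel

    m_num = SquentialMutableChannel(8)   # default 'number' mode
    m_num.current_choice = 4             # keep the first 4 channels
    print(m_num.current_choice)          # -> 4
    print(m_num.mask.int().tolist())     # -> [1, 1, 1, 1, 0, 0, 0, 0]

    m_ratio = SquentialMutableChannel(8, choice_mode='ratio')
    m_ratio.current_choice = 0.5         # _ratio2num keeps 4 channels
    print(m_ratio.current_choice)        # -> 0.5, reported back as a ratio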
@@ -113,7 +113,7 @@ def is_mutable(self) -> bool:

     def __repr__(self) -> str:
         return (f'{self.__class__.__name__}('
-                f'{self.name}, index=({self.index}), '
+                f'{self.name}, index={self.index}, '
                 f'is_output_channel='
                 f'{"true" if self.is_output_channel else "false"}, '
                 f'expand_ratio={self.expand_ratio}'
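Note: with a hypothetical index of (0, 64), the old f-string printed index=((0, 64)) with doubled parentheses, while the fixed line prints index=(0, 64).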
@@ -5,6 +5,7 @@
 import torch.nn as nn

 from mmrazor.registry import MODELS
+from ..simple_mutable_channel import SimpleMutableChannel
 from .sequential_mutable_channel_unit import SequentialMutableChannelUnit
@@ -25,6 +26,23 @@ def __init__(self,
                  min_ratio=0.9) -> None:
         super().__init__(num_channels, choice_mode, divisor, min_value,
                          min_ratio)
+        self.mutable_channel = SimpleMutableChannel(num_channels)
+
+    # choices
+
+    @property
+    def current_choice(self) -> Union[int, float]:
+        num = self.mutable_channel.activated_channels
+        if self.is_num_mode:
+            return num
+        else:
+            return self._num2ratio(num)
+
+    @current_choice.setter
+    def current_choice(self, choice: Union[int, float]):
+        int_choice = self._get_valid_int_choice(choice)
+        mask = self._generate_mask(int_choice).bool()
+        self.mutable_channel.current_choice = mask

     # private methods
@@ -1,25 +1,25 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 """This module defines MutableChannelUnit."""
 import abc
+from collections import Set
 from typing import Dict, List, Type, TypeVar

 import torch.nn as nn

-import mmrazor.models.architectures.dynamic_ops as dynamic_ops
+from mmrazor.models.architectures import dynamic_ops
 from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin
-from mmrazor.models.mutables.mutable_channel.base_mutable_channel import \
-    BaseMutableChannel
-from ..mutable_channel_container import MutableChannelContainer
+from mmrazor.models.mutables import DerivedMutable
+from mmrazor.models.mutables.mutable_channel import (BaseMutableChannel,
+                                                     MutableChannelContainer)
 from .channel_unit import Channel, ChannelUnit


 class MutableChannelUnit(ChannelUnit):

     # init methods
     def __init__(self, num_channels: int, **kwargs) -> None:
-        """MutableChannelUnit inherits from ChannelUnit, which manages
-        channels with channel-dependency.
+        """MutableChannelUnit inherits from ChannelUnit, which manages channels
+        with channel-dependency.

         Compared with ChannelUnit, MutableChannelUnit defines the core
         interfaces for pruning. By inheriting MutableChannelUnit,
@@ -44,6 +44,70 @@ def __init__(self, num_channels: int, **kwargs) -> None:

         super().__init__(num_channels)

+    @classmethod
+    def init_from_mutable_channel(cls, mutable_channel: BaseMutableChannel):
+        unit = cls(mutable_channel.num_channels)
+        return unit
+
+    @classmethod
+    def init_from_predefined_model(cls, model: nn.Module):
+        """Initialize units using the model with pre-defined dynamicops and
+        mutable-channels."""
+
+        def process_container(contanier: MutableChannelContainer,
+                              module,
+                              module_name,
+                              mutable2units,
+                              is_output=True):
+            for index, mutable in contanier.mutable_channels.items():
+                if isinstance(mutable, DerivedMutable):
+                    source_mutables: Set = \
+                        mutable._trace_source_mutables()
+                    source_channel_mutables = [
+                        mutable for mutable in source_mutables
+                        if isinstance(mutable, BaseMutableChannel)
+                    ]
+                    assert len(source_channel_mutables) == 1, (
+                        'only support one mutable channel '
+                        'used in DerivedMutable')
+                    mutable = list(source_channel_mutables)[0]
+
+                if mutable not in mutable2units:
+                    mutable2units[mutable] = cls.init_from_mutable_channel(
+                        mutable)
+
+                unit: MutableChannelUnit = mutable2units[mutable]
+                if is_output:
+                    unit.add_ouptut_related(
+                        Channel(
+                            module_name,
+                            module,
+                            index,
+                            is_output_channel=is_output))
+                else:
+                    unit.add_input_related(
+                        Channel(
+                            module_name,
+                            module,
+                            index,
+                            is_output_channel=is_output))
+
+        mutable2units: Dict = {}
+        for name, module in model.named_modules():
+            if isinstance(module, DynamicChannelMixin):
+                in_container: MutableChannelContainer = \
+                    module.get_mutable_attr(
+                        'in_channels')
+                out_container: MutableChannelContainer = \
+                    module.get_mutable_attr(
+                        'out_channels')
+                process_container(in_container, module, name, mutable2units,
+                                  False)
+                process_container(out_container, module, name, mutable2units,
+                                  True)
+        units = list(mutable2units.values())
+        return units
+
     # properties

     @property
@@ -97,7 +161,7 @@ def prepare_for_pruning(self, model):

         For example, we need to register mutables to dynamic-ops.
         """
-        raise not NotImplementedError
+        raise NotImplementedError

     # pruning: choice-related
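Note: a hedged sketch of driving the new init_from_predefined_model classmethod from a concrete unit class; build_model_with_dynamic_ops() is a hypothetical stand-in for any model whose layers are DynamicChannelMixin ops with MutableChannelContainer attributes already registered:

    from mmrazor.models.mutables import SequentialMutableChannelUnit

    # Hypothetical helper, not part of this diff.
    model = build_model_with_dynamic_ops()

    # Walks model.named_modules(), groups every mutable channel by its
    # source mutable, and returns one unit per group.
    units = SequentialMutableChannelUnit.init_from_predefined_model(model)
    for unit in units:
        print(unit.num_channels)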