merge pruning into dev-1.x (#312)
* add ChannelGroup (#250)

* rebase new dev-1.x

* modification for adding config_template

* add docstring to channel_group.py

* add docstring to mutable_channel_group.py

* rm channel_group_cfg from Graph2ChannelGroups

* change choice type of SequentialChannelGroup from float to int

* add a warning about group-wise conv

* restore __init__ of dynamic op

* in_channel_mutable  ->  mutable_in_channel

* rm abstractproperty

* add a comment about VT

* rm registry for ChannelGroup

* MUTABLECHANNELGROUP -> ChannelGroupType

* refine docstring of IndexDict

* update docstring

* update docstring

* is_prunable -> is_mutable

* update docstring

* fix error in pre-commit

* update unittest

* add return type

* unify init_xxx api

* add unittest about init of MutableChannelGroup

* update according to reviews

* sequential_channel_group -> sequential_mutable_channel_group

Co-authored-by: liukai <liukai@pjlab.org.cn>

* Add BaseChannelMutator and refactor Autoslim (#289)

* add BaseChannelMutator

* add autoslim

* tmp

* make SequentialMutableChannelGroup accept both num and ratio as choice, and support divisor

* update OneShotMutableChannelGroup

* pass supernet training of autoslim

* refine autoslim

* fix bug in OneShotMutableChannelGroup

* refactor make_divisible

* fix spell error:  channl -> channel

* init_using_backward_tracer -> init_from_backward_tracer
init_using_fx_tracer -> init_from_fx_tracer

* refine SequentialMutableChannelGroup

* let mutator support models with dynamicop

* support defining search space in model

* tracer_cfg -> parse_cfg

* refine

* using -> from

* update docstring

* update docstring

Co-authored-by: liukai <liukai@pjlab.org.cn>

* refactor slimmable and add l1-norm (#291)

* refactor slimmable and add l1-norm

* make l1-norm support convnd

* update get_channel_groups

* add  l1-norm_resnet34_8xb32_in1k.py

* add pretrained to resnet34-l1

* remove old channel mutator

* BaseChannelMutator -> ChannelMutator

* update according to reviews

* add readme to l1-norm

* MBV2_slimmable -> MBV2_slimmable_config

Co-authored-by: liukai <liukai@pjlab.org.cn>

* Clean old codes. (#296)

* remove old dynamic ops

* move dynamic ops

* clean old mutable_channels

* rm OneShotMutableChannel

* rm MutableChannel

* refine

* refine

* use SquentialMutableChannel to replace OneshotMutableChannel

* refactor dynamicops folder

* let SquentialMutableChannel support float

Co-authored-by: liukai <liukai@pjlab.org.cn>

* Add channel-flow (#301)

* base_channel_mutator -> channel_mutator

* init

* update docstring

* allow omitting redundant configs for channel

* add register_mutable_channel_to_a_module to MutableChannelContainer

* update according to reviews 1

* update according to reviews 2

* update according to reviews 3

* remove old docstring

* fix error

* using->from

* update according to reviews

* support self-defined input channel number

* update docstring

* chanenl -> channel_elem

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: jacky <jacky@xx.com>

* Rename: ChannelGroup -> ChannelUnit (#302)

* refine repr of MutableChannelGroup

* rename folder name

* ChannelGroup -> ChannelUnit

* filename in units folder

* channel_group -> channel_unit

* groups -> units

* group -> unit

* update

* get_mutable_channel_groups -> get_mutable_channel_units

* fix bug

* refine docstring

* fix ci

* fix bug in tracer

Co-authored-by: liukai <liukai@pjlab.org.cn>

* Merge dev-1.x to pruning (#311)

* [feature] CONTRASTIVE REPRESENTATION DISTILLATION with dataset wrapper (#281)

* init

* TD: CRDLoss

* complete UT

* fix docstrings

* fix ci

* update

* fix CI

* DONE

* maintain CRD dataset unique funcs as a mixin

* maintain CRD dataset unique funcs as a mixin

* maintain CRD dataset unique funcs as a mixin

* add UT: CRD_ClsDataset

* init

* TODO: UT test formatting.

* init

* crd dataset wrapper

* update docstring

Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>

* [Improvement] Update estimator with api revision (#277)

* update estimator usage and fix bugs

* refactor api of estimator & add inner check methods

* fix docstrings

* update search loop and config

* fix lint

* update unittest

* decouple mmdet dependency and fix lint

Co-authored-by: humu789 <humu@pjlab.org.cn>

* [Fix] Fix tracer (#273)

* test image_classifier_loss_calculator

* fix backward tracer

* update SingleStageDetectorPseudoLoss

* merge

* [Feature] Add Dsnas Algorithm (#226)

* [tmp] Update Dsnas

* [tmp] refactor arch_loss & flops_loss

* Update Dsnas & MMRAZOR_EVALUATOR:
1. finalized compute_loss & handle_grads in algorithm;
2. add MMRAZOR_EVALUATOR;
3. fix bugs.

* Update lr scheduler & fix a bug:
1. update param_scheduler & lr_scheduler for dsnas;
2. fix a bug of switching to finetune stage.

* remove old evaluators

* remove old evaluators

* update param_scheduler config

* merge dev-1.x into gy/estimator

* add flops_loss in Dsnas using ResourcesEstimator

* get resources before mutator.prepare_from_supernet

* delete unnecessary broadcast api from gml

* broadcast spec_modules_resources when estimating

* update early fix mechanism for Dsnas

* fix merge

* update units in estimator

* minor change

* fix data_preprocessor api

* add flops_loss_coef

* remove DsnasOptimWrapper

* fix bn eps and data_preprocessor

* fix bn weight decay bug

* add betas for mutator optimizer

* set diff_rank_seed=True for dsnas

* fix start_factor of lr when warm up

* remove .module in non-ddp mode

* add GlobalAveragePoolingWithDropout

* add UT for dsnas

* remove unnecessary channel adjustment for shufflenetv2

* update supernet configs

* delete unnecessary dropout

* delete unnecessary part with minor change on dsnas

* minor change on the flag of search stage

* update README and subnet configs

* add UT for OneHotMutableOP

* [Feature] Update train (#279)

* support auto resume

* add enable auto_scale_lr in train.py

* support '--amp' option

* [Fix] Fix darts metafile (#278)

fix darts metafile

* fix ci (#284)

* fix ci for circle ci

* fix bug in test_metafiles

* add  pr_stage_test for github ci

* add multiple version

* fix ut

* fix lint

* Temporarily skip dataset UT

* update github ci

* add github lint ci

* install wheel

* remove timm from requirements

* install wheel when test on windows

* fix error

* fix bug

* remove github windows ci

* fix device error of arch_params when DsnasDDP

* fix CRD dataset ut

* fix scope error

* rm test_cuda in workflows of github

* [Doc] fix typos in en/usr_guides

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: SheffieldCao <1751899@tongji.edu.cn>

Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: humu789 <humu@pjlab.org.cn>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: SheffieldCao <1751899@tongji.edu.cn>

* Refine pruning branch (#307)

* [feature] CONTRASTIVE REPRESENTATION DISTILLATION with dataset wrapper (#281)

* init

* TD: CRDLoss

* complete UT

* fix docstrings

* fix ci

* update

* fix CI

* DONE

* maintain CRD dataset unique funcs as a mixin

* maintain CRD dataset unique funcs as a mixin

* maintain CRD dataset unique funcs as a mixin

* add UT: CRD_ClsDataset

* init

* TODO: UT test formatting.

* init

* crd dataset wrapper

* update docstring

Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>

* [Improvement] Update estimator with api revision (#277)

* update estimator usage and fix bugs

* refactor api of estimator & add inner check methods

* fix docstrings

* update search loop and config

* fix lint

* update unittest

* decouple mmdet dependency and fix lint

Co-authored-by: humu789 <humu@pjlab.org.cn>

* [Fix] Fix tracer (#273)

* test image_classifier_loss_calculator

* fix backward tracer

* update SingleStageDetectorPseudoLoss

* merge

* [Feature] Add Dsnas Algorithm (#226)

* [tmp] Update Dsnas

* [tmp] refactor arch_loss & flops_loss

* Update Dsnas & MMRAZOR_EVALUATOR:
1. finalized compute_loss & handle_grads in algorithm;
2. add MMRAZOR_EVALUATOR;
3. fix bugs.

* Update lr scheduler & fix a bug:
1. update param_scheduler & lr_scheduler for dsnas;
2. fix a bug of switching to finetune stage.

* remove old evaluators

* remove old evaluators

* update param_scheduler config

* merge dev-1.x into gy/estimator

* add flops_loss in Dsnas using ResourcesEstimator

* get resources before mutator.prepare_from_supernet

* delete unnecessary broadcast api from gml

* broadcast spec_modules_resources when estimating

* update early fix mechanism for Dsnas

* fix merge

* update units in estimator

* minor change

* fix data_preprocessor api

* add flops_loss_coef

* remove DsnasOptimWrapper

* fix bn eps and data_preprocessor

* fix bn weight decay bug

* add betas for mutator optimizer

* set diff_rank_seed=True for dsnas

* fix start_factor of lr when warm up

* remove .module in non-ddp mode

* add GlobalAveragePoolingWithDropout

* add UT for dsnas

* remove unnecessary channel adjustment for shufflenetv2

* update supernet configs

* delete unnecessary dropout

* delete unnecessary part with minor change on dsnas

* minor change on the flag of search stage

* update README and subnet configs

* add UT for OneHotMutableOP

* [Feature] Update train (#279)

* support auto resume

* add enable auto_scale_lr in train.py

* support '--amp' option

* [Fix] Fix darts metafile (#278)

fix darts metafile

* fix ci (#284)

* fix ci for circle ci

* fix bug in test_metafiles

* add  pr_stage_test for github ci

* add multiple version

* fix ut

* fix lint

* Temporarily skip dataset UT

* update github ci

* add github lint ci

* install wheel

* remove timm from requirements

* install wheel when test on windows

* fix error

* fix bug

* remove github windows ci

* fix device error of arch_params when DsnasDDP

* fix CRD dataset ut

* fix scope error

* rm test_cuda in workflows of github

* [Doc] fix typos in en/usr_guides

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: SheffieldCao <1751899@tongji.edu.cn>

* fix bug when python=3.6

* fix lint

* fix bug when test using cpu only

* refine ci

* fix error in ci

* try ci

* update repr of Channel

* fix error

* mv init_from_predefined_model to MutableChannelUnit

* move tests

* update SquentialMutableChannel

* update l1 mutable channel unit

* add OneShotMutableChannel

* candidate_mode -> choice_mode

* update docstring

* change ci

Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: humu789 <humu@pjlab.org.cn>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: SheffieldCao <1751899@tongji.edu.cn>

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: humu789 <humu@pjlab.org.cn>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: SheffieldCao <1751899@tongji.edu.cn>
11 people committed Oct 10, 2022
1 parent f98ac34 commit b4b7e24
Showing 95 changed files with 4,992 additions and 2,898 deletions.
@@ -21,23 +21,19 @@

# !autoslim algorithm config
# ==========================================================================
channel_cfg_paths = [
'https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-220M_acc-71.4_20220715-9c288f3b_subnet_cfg.yaml', # noqa: E501
'https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-320M_acc-72.73_20220715-9aa8f8ae_subnet_cfg.yaml', # noqa: E501
'https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-530M_acc-74.23_20220715-aa8754fe_subnet_cfg.yaml' # noqa: E501
]

model = dict(
_delete_=True,
_scope_='mmrazor',
type='SlimmableNetwork',
architecture=supernet,
data_preprocessor=data_preprocessor,
channel_cfg_paths=channel_cfg_paths,
mutator=dict(
type='SlimmableChannelMutator',
mutable_cfg=dict(type='SlimmableMutableChannel'),
tracer_cfg=dict(
channel_unit_cfg=dict(
type='SlimmableChannelUnit',
units='tests/data/MBV2_slimmable_config.json'),
parse_cfg=dict(
type='BackwardTracer',
loss_calculator=dict(type='ImageClassifierPseudoLoss'))))

@@ -46,6 +42,6 @@
broadcast_buffers=False,
find_unused_parameters=True)

optim_wrapper = dict(accumulative_counts=len(channel_cfg_paths))
optim_wrapper = dict(accumulative_counts=3)

val_cfg = dict(type='mmrazor.SlimmableValLoop')
@@ -1,4 +1,3 @@
_base_ = 'autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py'

_channel_cfg_paths = 'https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-220M_acc-71.4_20220715-9c288f3b_subnet_cfg.yaml' # noqa: E501
model = dict(channel_cfg_paths=_channel_cfg_paths)
model = dict(deploy_index=0)
@@ -1,4 +1,3 @@
_base_ = 'autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py'

_channel_cfg_paths = 'https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-320M_acc-72.73_20220715-9aa8f8ae_subnet_cfg.yaml' # noqa: E501
model = dict(channel_cfg_paths=_channel_cfg_paths)
model = dict(deploy_index=1)
@@ -1,4 +1,3 @@
_base_ = 'autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py'

_channel_cfg_paths = 'https://download.openmmlab.com/mmrazor/v1/autoslim/autoslim_mbv2_subnet_8xb256_in1k_flops-530M_acc-74.23_20220715-aa8754fe_subnet_cfg.yaml' # noqa: E501
model = dict(channel_cfg_paths=_channel_cfg_paths)
model = dict(deploy_index=2)
@@ -43,11 +43,13 @@
preds_T=dict(recorder='fc', from_student=False)))),
mutator=dict(
type='OneShotChannelMutator',
mutable_cfg=dict(
type='OneShotMutableChannel',
candidate_choices=list(i / 12 for i in range(2, 13)),
candidate_mode='ratio'),
tracer_cfg=dict(
channel_unit_cfg=dict(
type='OneShotMutableChannelUnit',
default_args=dict(
candidate_choices=list(i / 12 for i in range(2, 13)),
choice_mode='ratio',
divisor=8)),
parse_cfg=dict(
type='BackwardTracer',
loss_calculator=dict(type='ImageClassifierPseudoLoss'))))

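Aside: the `divisor=8` argument in the unit config above rounds every candidate channel number to a hardware-friendly multiple of 8. A minimal sketch of the usual rounding rule, assuming the MobileNet-style formulation (the refactored `make_divisible` in this PR may differ in details):

```python
def make_divisible(value: float, divisor: int = 8,
                   min_ratio: float = 0.9) -> int:
    # Round `value` to the nearest multiple of `divisor`.
    new_value = max(divisor, int(value + divisor / 2) // divisor * divisor)
    # Guard: never shrink the layer by more than (1 - min_ratio).
    if new_value < min_ratio * value:
        new_value += divisor
    return new_value

# Example: a 0.4 ratio on a 48-channel layer asks for 19.2 channels;
# the nearest multiple of 8 is 16, but 16 < 0.9 * 19.2, so the guard
# bumps the result to 24.
```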
11 changes: 11 additions & 0 deletions configs/pruning/mmcls/l1-norm/README.md
@@ -0,0 +1,11 @@
# L1-norm pruning

> [Pruning Filters for Efficient ConvNets.](https://arxiv.org/pdf/1608.08710.pdf)
<!-- [ALGORITHM] -->

## Implementation

L1-norm pruning is a classical filter pruning algorithm. It prunes filters (channels) according to the L1-norm of the weights of a conv layer.

We use ItePruneAlgorithm and L1MutableChannelUnit to implement l1-norm pruning, as sketched below. Please refer to xxxx for more configuration details.
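As a hypothetical illustration of the criterion (the helper names below are ours, not mmrazor's API), each output filter of a conv layer is scored by the L1 norm of its weights and only the top-scoring fraction is kept:

```python
import torch

def l1_filter_importance(conv_weight: torch.Tensor) -> torch.Tensor:
    # conv_weight has shape (out_channels, in_channels, *kernel_size);
    # score each output filter by the L1 norm of all of its weights.
    return conv_weight.abs().flatten(1).sum(dim=1)

def filters_to_keep(conv_weight: torch.Tensor, ratio: float) -> torch.Tensor:
    # Indices of the top `ratio` fraction of filters by L1 norm.
    num_keep = max(1, int(conv_weight.size(0) * ratio))
    return torch.topk(l1_filter_importance(conv_weight), num_keep).indices
```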
56 changes: 56 additions & 0 deletions configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k.py
@@ -0,0 +1,56 @@
_base_ = ['mmcls::resnet/resnet34_8xb32_in1k.py']

stage_ratio_1 = 0.7
stage_ratio_2 = 0.7
stage_ratio_3 = 0.7
stage_ratio_4 = 1.0

# the config template of target_pruning_ratio can be obtained by
# python ./tools/get_channel_units.py {config_file} --choice
target_pruning_ratio = {
'backbone.layer1.2.conv2_(0, 64)_64': stage_ratio_1,
'backbone.layer1.0.conv1_(0, 64)_64': stage_ratio_1,
'backbone.layer1.1.conv1_(0, 64)_64': stage_ratio_1,
'backbone.layer1.2.conv1_(0, 64)_64': stage_ratio_1,
'backbone.layer2.0.conv1_(0, 128)_128': stage_ratio_2,
'backbone.layer2.3.conv2_(0, 128)_128': stage_ratio_2,
'backbone.layer2.1.conv1_(0, 128)_128': stage_ratio_2,
'backbone.layer2.2.conv1_(0, 128)_128': stage_ratio_2,
'backbone.layer2.3.conv1_(0, 128)_128': stage_ratio_2,
'backbone.layer3.0.conv1_(0, 256)_256': stage_ratio_3,
'backbone.layer3.5.conv2_(0, 256)_256': stage_ratio_3,
'backbone.layer3.1.conv1_(0, 256)_256': stage_ratio_3,
'backbone.layer3.2.conv1_(0, 256)_256': stage_ratio_3,
'backbone.layer3.3.conv1_(0, 256)_256': stage_ratio_3,
'backbone.layer3.4.conv1_(0, 256)_256': stage_ratio_3,
'backbone.layer3.5.conv1_(0, 256)_256': stage_ratio_3,
'backbone.layer4.0.conv1_(0, 512)_512': stage_ratio_4,
'backbone.layer4.2.conv2_(0, 512)_512': stage_ratio_4,
'backbone.layer4.1.conv1_(0, 512)_512': stage_ratio_4,
'backbone.layer4.2.conv1_(0, 512)_512': stage_ratio_4
}
data_preprocessor = {'type': 'mmcls.ClsDataPreprocessor'}
architecture = _base_.model
architecture.update({
'init_cfg': {
'type':
'Pretrained',
'checkpoint':
'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth' # noqa
}
})

model = dict(
_delete_=True,
_scope_='mmrazor',
type='ItePruneAlgorithm',
architecture=architecture,
mutator_cfg=dict(
type='ChannelMutator',
channel_unit_cfg=dict(
type='L1MutableChannelUnit',
default_args=dict(choice_mode='ratio'))),
target_pruning_ratio=target_pruning_ratio,
step_epoch=1,
prune_times=1,
)
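With `step_epoch=1` and `prune_times=1`, the config above collapses to a single pruning event after the first epoch. For intuition only, a linear iterative schedule could move each unit's ratio from 1.0 to its target across the pruning events; this is an illustrative assumption, not ItePruneAlgorithm's actual code:

```python
def ratio_at(target: float, step: int, prune_times: int) -> float:
    # Hypothetical linear schedule: remaining channel ratio after
    # `step` of `prune_times` pruning events (step counts from 1).
    step = min(max(step, 0), prune_times)
    return 1.0 - (1.0 - target) * step / prune_times

# ratio_at(0.7, 1, 1) -> 0.7: with prune_times=1 the first (and only)
# pruning event jumps straight to the target ratio.
```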
4 changes: 2 additions & 2 deletions mmrazor/engine/runner/slimmable_val_loop.py
@@ -43,10 +43,10 @@ def run(self):
self.runner.call_hook('before_val')

all_metrics = dict()
for subnet_idx in range(self._model.num_subnet):
for subnet_idx, subnet in enumerate(self._model.mutator.subnets):
self.runner.call_hook('before_val_epoch')
self.runner.model.eval()
self._model.mutator.switch_choices(subnet_idx)
self._model.mutator.set_choices(subnet)
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)
# compute student metrics
23 changes: 18 additions & 5 deletions mmrazor/models/algorithms/__init__.py
@@ -5,11 +5,24 @@
SelfDistill, SingleTeacherDistill)
from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP
from .pruning import SlimmableNetwork, SlimmableNetworkDDP
from .pruning.ite_prune_algorithm import ItePruneAlgorithm

__all__ = [
'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation',
'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'Dsnas',
'DsnasDDP'
'SingleTeacherDistill',
'BaseAlgorithm',
'FpnTeacherDistill',
'SPOS',
'SlimmableNetwork',
'SlimmableNetworkDDP',
'AutoSlim',
'AutoSlimDDP',
'Darts',
'DartsDDP',
'SelfDistill',
'DataFreeDistillation',
'DAFLDataFreeDistillation',
'OverhaulFeatureDistillation',
'ItePruneAlgorithm',
'Dsnas',
'DsnasDDP',
]
44 changes: 34 additions & 10 deletions mmrazor/models/algorithms/nas/autoslim.py
@@ -14,7 +14,6 @@
from mmrazor.models.utils import (add_prefix,
reinitialize_optim_wrapper_count_status)
from mmrazor.registry import MODEL_WRAPPERS, MODELS
from mmrazor.utils import SingleMutatorRandomSubnet
from ..base import BaseAlgorithm

VALID_MUTATOR_TYPE = Union[OneShotChannelMutator, Dict]
@@ -33,10 +32,24 @@ def __init__(self,
data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
init_cfg: Optional[Dict] = None,
num_samples: int = 2) -> None:
"""Implementation of Autoslim algorithm. Please refer to
https://arxiv.org/abs/1903.11728 for more details.
Args:
mutator (VALID_MUTATOR_TYPE): config of mutator.
distiller (VALID_DISTILLER_TYPE): config of distiller.
architecture (Union[BaseModel, Dict]): the model to be searched.
data_preprocessor (Optional[Union[Dict, nn.Module]], optional):
data preprocessor. Defaults to None.
init_cfg (Optional[Dict], optional): config of initialization.
Defaults to None.
num_samples (int, optional): number of sample subnets.
Defaults to 2.
"""
super().__init__(architecture, data_preprocessor, init_cfg)

self.mutator = self._build_mutator(mutator)
# `prepare_from_supernet` must be called before distiller initialized
self.mutator: OneShotChannelMutator = MODELS.build(mutator)
# prepare_from_supernet` must be called before distiller initialized
self.mutator.prepare_from_supernet(self.architecture)

self.distiller = self._build_distiller(distiller)
@@ -49,7 +62,7 @@ def __init__(self,

def _build_mutator(self,
mutator: VALID_MUTATOR_TYPE) -> OneShotChannelMutator:
"""build mutator."""
"""Build mutator."""
if isinstance(mutator, dict):
mutator = MODELS.build(mutator)
if not isinstance(mutator, OneShotChannelMutator):
@@ -61,6 +74,7 @@ def _build_mutator(self,

def _build_distiller(
self, distiller: VALID_DISTILLER_TYPE) -> ConfigurableDistiller:
"""Build distiller."""
if isinstance(distiller, dict):
distiller = MODELS.build(distiller)
if not isinstance(distiller, ConfigurableDistiller):
@@ -70,20 +84,25 @@ def _build_distiller(

return distiller

def sample_subnet(self) -> SingleMutatorRandomSubnet:
def sample_subnet(self) -> Dict:
"""Sample a subnet."""
return self.mutator.sample_choices()

def set_subnet(self, subnet: SingleMutatorRandomSubnet) -> None:
def set_subnet(self, subnet) -> None:
"""Set a subnet."""
self.mutator.set_choices(subnet)

def set_max_subnet(self) -> None:
self.mutator.set_max_choices()
"""Set max subnet."""
self.mutator.set_choices(self.mutator.max_choices())

def set_min_subnet(self) -> None:
return self.mutator.set_min_choices()
"""Set min subnet."""
self.mutator.set_choices(self.mutator.min_choices())

def train_step(self, data: List[dict],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""Train step."""

def distill_step(
batch_inputs: torch.Tensor, data_samples: List[BaseDataElement]
@@ -109,7 +128,9 @@ def distill_step(
accumulative_counts=self.num_samples + 2)
self._optim_wrapper_count_status_reinitialized = True

batch_inputs, data_samples = self.data_preprocessor(data, True)
input_data = self.data_preprocessor(data, True)
batch_inputs = input_data['inputs']
data_samples = input_data['data_samples']

total_losses = dict()
self.set_max_subnet()
@@ -136,6 +157,7 @@ def distill_step(

@MODEL_WRAPPERS.register_module()
class AutoSlimDDP(MMDistributedDataParallel):
"""DDPwapper for autoslim."""

def __init__(self,
*,
@@ -175,7 +197,9 @@ def distill_step(
accumulative_counts=self.module.num_samples + 2)
self._optim_wrapper_count_status_reinitialized = True

batch_inputs, data_samples = self.module.data_preprocessor(data, True)
input_data = self.module.data_preprocessor(data, True)
batch_inputs = input_data['inputs']
data_samples = input_data['data_samples']

total_losses = dict()
self.module.set_max_subnet()
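For context, the `train_step` touched above follows AutoSlim's sandwich rule: the largest subnet trains on the supervised loss and acts as the teacher, while the smallest subnet and `num_samples` random subnets are distilled from it. A condensed sketch with method names taken from the diff (`add_prefix` is simplified and all optimizer bookkeeping is omitted):

```python
from typing import Callable, Dict

def add_prefix(losses: Dict, prefix: str) -> Dict:
    # Simplified stand-in for mmrazor.models.utils.add_prefix.
    return {f'{prefix}.{k}': v for k, v in losses.items()}

def sandwich_train_step(model, batch_inputs, data_samples,
                        distill_step: Callable, num_samples: int = 2) -> Dict:
    total_losses = {}
    # 1. Largest subnet: ordinary supervised loss (also the teacher).
    model.set_max_subnet()
    max_losses = model(batch_inputs, data_samples, mode='loss')
    total_losses.update(add_prefix(max_losses, 'max_subnet'))
    # 2. Smallest subnet: distilled from the max subnet's predictions.
    model.set_min_subnet()
    total_losses.update(
        add_prefix(distill_step(batch_inputs, data_samples), 'min_subnet'))
    # 3. num_samples random subnets in between, also distilled.
    for idx in range(num_samples):
        model.set_subnet(model.sample_subnet())
        total_losses.update(
            add_prefix(distill_step(batch_inputs, data_samples),
                       f'random_subnet_{idx}'))
    return total_losses
```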