Skip to content

Commit

Permalink
[Feature] Add Dsnas Algorithm (#226)
Browse files Browse the repository at this point in the history
* [tmp] Update Dsnas

* [tmp] refactor arch_loss & flops_loss

* Update Dsnas & MMRAZOR_EVALUATOR:
1. finalized compute_loss & handle_grads in algorithm;
2. add MMRAZOR_EVALUATOR;
3. fix bugs.

* Update lr scheduler & fix a bug:
1. update param_scheduler & lr_scheduler for dsnas;
2. fix a bug of switching to finetune stage.

* remove old evaluators

* remove old evaluators

* update param_scheduler config

* merge dev-1.x into gy/estimator

* add flops_loss in Dsnas using ResourcesEstimator

* get resources before mutator.prepare_from_supernet

* delete unness broadcast api from gml

* broadcast spec_modules_resources when estimating

* update early fix mechanism for Dsnas

* fix merge

* update units in estimator

* minor change

* fix data_preprocessor api

* add flops_loss_coef

* remove DsnasOptimWrapper

* fix bn eps and data_preprocessor

* fix bn weight decay bug

* add betas for mutator optimizer

* set diff_rank_seed=True for dsnas

* fix start_factor of lr when warm up

* remove .module in non-ddp mode

* add GlobalAveragePoolingWithDropout

* add UT for dsnas

* remove unness channel adjustment for shufflenetv2

* update supernet configs

* delete unness dropout

* delete unness part with minor change on dsnas

* minor change on the flag of search stage

* update README and subnet configs

* add UT for OneHotMutableOP
  • Loading branch information
gaoyang07 committed Sep 29, 2022
1 parent d07dee9 commit 8d603d9
Show file tree
Hide file tree
Showing 18 changed files with 1,187 additions and 30 deletions.
28 changes: 28 additions & 0 deletions configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
norm_cfg = dict(type='BN', eps=0.01)

_STAGE_MUTABLE = dict(
type='mmrazor.OneHotMutableOP',
fix_threshold=0.3,
candidates=dict(
shuffle_3x3=dict(
type='ShuffleBlock', kernel_size=3, norm_cfg=norm_cfg),
shuffle_5x5=dict(
type='ShuffleBlock', kernel_size=5, norm_cfg=norm_cfg),
shuffle_7x7=dict(
type='ShuffleBlock', kernel_size=7, norm_cfg=norm_cfg),
shuffle_xception=dict(type='ShuffleXception', norm_cfg=norm_cfg)))

arch_setting = [
# Parameters to build layers. 3 parameters are needed to construct a
# layer, from left to right: channel, num_blocks, mutable_cfg.
[64, 4, _STAGE_MUTABLE],
[160, 4, _STAGE_MUTABLE],
[320, 8, _STAGE_MUTABLE],
[640, 4, _STAGE_MUTABLE]
]

nas_backbone = dict(
type='mmrazor.SearchableShuffleNetV2',
widen_factor=1.0,
arch_setting=arch_setting,
norm_cfg=norm_cfg)
102 changes: 102 additions & 0 deletions configs/_base_/settings/imagenet_bs1024_dsnas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
data_preprocessor = dict(
type='mmcls.ClsDataPreprocessor',
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)

train_pipeline = [
dict(type='mmcls.LoadImageFromFile'),
dict(type='mmcls.RandomResizedCrop', scale=224),
dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
dict(type='mmcls.PackClsInputs'),
]

test_pipeline = [
dict(type='mmcls.LoadImageFromFile'),
dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
dict(type='mmcls.CenterCrop', crop_size=224),
dict(type='mmcls.PackClsInputs'),
]

train_dataloader = dict(
batch_size=128,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='mmcls.DefaultSampler', shuffle=True),
persistent_workers=True,
)

val_dataloader = dict(
batch_size=128,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='mmcls.DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

# optimizer
paramwise_cfg = dict(bias_decay_mult=0.0, norm_decay_mult=0.0)

optim_wrapper = dict(
constructor='mmrazor.SeparateOptimWrapperConstructor',
architecture=dict(
optimizer=dict(
type='mmcls.SGD', lr=0.5, momentum=0.9, weight_decay=4e-5),
paramwise_cfg=paramwise_cfg),
mutator=dict(
optimizer=dict(
type='mmcls.Adam', lr=0.001, weight_decay=0.0, betas=(0.5,
0.999))))

search_epochs = 85
# leanring policy
param_scheduler = dict(
architecture=[
dict(
type='mmcls.LinearLR',
end=5,
start_factor=0.2,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='mmcls.CosineAnnealingLR',
T_max=240,
begin=5,
end=search_epochs,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='mmcls.CosineAnnealingLR',
T_max=160,
begin=search_epochs,
end=240,
eta_min=0.0,
by_epoch=True,
convert_to_iter_based=True)
],
mutator=[])

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=240)
val_cfg = dict()
test_cfg = dict()
20 changes: 20 additions & 0 deletions configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
backbone.layers.0.0: shuffle_3x3
backbone.layers.0.1: shuffle_3x3
backbone.layers.0.2: shuffle_xception
backbone.layers.0.3: shuffle_3x3
backbone.layers.1.0: shuffle_xception
backbone.layers.1.1: shuffle_7x7
backbone.layers.1.2: shuffle_3x3
backbone.layers.1.3: shuffle_3x3
backbone.layers.2.0: shuffle_xception
backbone.layers.2.1: shuffle_xception
backbone.layers.2.2: shuffle_7x7
backbone.layers.2.3: shuffle_xception
backbone.layers.2.4: shuffle_xception
backbone.layers.2.5: shuffle_xception
backbone.layers.2.6: shuffle_7x7
backbone.layers.2.7: shuffle_3x3
backbone.layers.3.0: shuffle_3x3
backbone.layers.3.1: shuffle_xception
backbone.layers.3.2: shuffle_xception
backbone.layers.3.3: shuffle_3x3
43 changes: 43 additions & 0 deletions configs/nas/mmcls/dsnas/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# DSNAS

> [DSNAS: Direct Neural Architecture Search without Parameter Retraining](https://arxiv.org/abs/2002.09128.pdf)
<!-- [ALGORITHM] -->

## Abstract

Most existing NAS methods require two-stage parameter optimization.
However, performance of the same architecture in the two stages correlates poorly.
Based on this observation, DSNAS proposes a task-specific end-to-end differentiable NAS framework that simultaneously optimizes architecture and parameters with a low-biased Monte Carlo estimate. Child networks derived from DSNAS can be deployed directly without parameter retraining.

![pipeline](/docs/en/imgs/model_zoo/dsnas/pipeline.jpg)

## Results and models

### Supernet

| Dataset | Params(M) | FLOPs (G) | Top-1 Acc (%) | Top-5 Acc (%) | Config | Download | Remarks |
| :------: | :-------: | :-------: | :-----------: | :-----------: | :---------------------------------------: | :----------------------: | :--------------: |
| ImageNet | 3.33 | 0.299 | 73.56 | 91.24 | [config](./dsnas_supernet_8xb128_in1k.py) | [model](<>) \| [log](<>) | MMRazor searched |

**Note**:

1. There **might be(not all the case)** some small differences in our experiment in order to be consistent with other repos in OpenMMLab. For example,
normalize images in data preprocessing; resize by cv2 rather than PIL in training; dropout is not used in network. **Please refer to corresponding config for details.**
2. We convert the official searched checkpoint DSNASsearch240.pth into mmrazor-style and evaluate with pytorch1.8_cuda11.0, Top-1 is 74.1 and Top-5 is 91.51.
3. The implementation of ShuffleNetV2 in official DSNAS is different from OpenMMLab's and we follow the structure design in OpenMMLab. Note that with the
origin ShuffleNetV2 design in official DSNAS, the Top-1 is 73.92 and Top-5 is 91.59.
4. The finetune stage in our implementation refers to the 'search-from-search' stage mentioned in official DSNAS.
5. We obtain params and FLOPs using `mmrazor.ResourceEstimator`, which may be different from the origin repo.

## Citation

```latex
@inproceedings{hu2020dsnas,
title={Dsnas: Direct neural architecture search without parameter retraining},
author={Hu, Shoukang and Xie, Sirui and Zheng, Hehui and Liu, Chunxiao and Shi, Jianping and Liu, Xunying and Lin, Dahua},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={12084--12092},
year={2020}
}
```
29 changes: 29 additions & 0 deletions configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
_base_ = ['./dsnas_supernet_8xb128_in1k.py']

# NOTE: Replace this with the mutable_cfg searched by yourself.
fix_subnet = {
'backbone.layers.0.0': 'shuffle_3x3',
'backbone.layers.0.1': 'shuffle_7x7',
'backbone.layers.0.2': 'shuffle_3x3',
'backbone.layers.0.3': 'shuffle_5x5',
'backbone.layers.1.0': 'shuffle_3x3',
'backbone.layers.1.1': 'shuffle_3x3',
'backbone.layers.1.2': 'shuffle_3x3',
'backbone.layers.1.3': 'shuffle_7x7',
'backbone.layers.2.0': 'shuffle_xception',
'backbone.layers.2.1': 'shuffle_3x3',
'backbone.layers.2.2': 'shuffle_3x3',
'backbone.layers.2.3': 'shuffle_5x5',
'backbone.layers.2.4': 'shuffle_3x3',
'backbone.layers.2.5': 'shuffle_5x5',
'backbone.layers.2.6': 'shuffle_7x7',
'backbone.layers.2.7': 'shuffle_7x7',
'backbone.layers.3.0': 'shuffle_xception',
'backbone.layers.3.1': 'shuffle_3x3',
'backbone.layers.3.2': 'shuffle_7x7',
'backbone.layers.3.3': 'shuffle_3x3',
}

model = dict(fix_subnet=fix_subnet)

find_unused_parameters = False
36 changes: 36 additions & 0 deletions configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
_base_ = [
'mmrazor::_base_/settings/imagenet_bs1024_dsnas.py',
'mmrazor::_base_/nas_backbones/dsnas_shufflenet_supernet.py',
'mmcls::_base_/default_runtime.py',
]

# model
model = dict(
type='mmrazor.Dsnas',
architecture=dict(
type='ImageClassifier',
data_preprocessor=_base_.data_preprocessor,
backbone=_base_.nas_backbone,
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1024,
loss=dict(
type='LabelSmoothLoss',
num_classes=1000,
label_smooth_val=0.1,
mode='original',
loss_weight=1.0),
topk=(1, 5))),
mutator=dict(type='mmrazor.DiffModuleMutator'),
pretrain_epochs=15,
finetune_epochs=_base_.search_epochs,
)

model_wrapper_cfg = dict(
type='mmrazor.DsnasDDP',
broadcast_buffers=False,
find_unused_parameters=True)

randomness = dict(seed=48, diff_rank_seed=True)
4 changes: 2 additions & 2 deletions mmrazor/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@
'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
'EstimateResourcesHook'
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'EstimateResourcesHook',
'SelfDistillValLoop'
]
5 changes: 3 additions & 2 deletions mmrazor/models/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from .distill import (DAFLDataFreeDistillation, DataFreeDistillation,
FpnTeacherDistill, OverhaulFeatureDistillation,
SelfDistill, SingleTeacherDistill)
from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP
from .pruning import SlimmableNetwork, SlimmableNetworkDDP

__all__ = [
'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation',
'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation'
'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'Dsnas',
'DsnasDDP'
]
5 changes: 4 additions & 1 deletion mmrazor/models/algorithms/nas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .autoslim import AutoSlim, AutoSlimDDP
from .darts import Darts, DartsDDP
from .dsnas import Dsnas, DsnasDDP
from .spos import SPOS

__all__ = ['SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP']
__all__ = [
'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP'
]
Loading

0 comments on commit 8d603d9

Please sign in to comment.