
[Feature] Add Dsnas Algorithm #226

Merged
merged 39 commits into dev-1.x from gy/dsnas on Sep 29, 2022

Commits (39)
6fad48c
[tmp] Update Dsnas
Aug 1, 2022
6ded2d4
[tmp] refactor arch_loss & flops_loss
Aug 5, 2022
1932d82
Update Dsnas & MMRAZOR_EVALUATOR:
Aug 10, 2022
be98e8b
Merge branch 'dev-1.x' into gy/dsnas
Aug 11, 2022
b1f5c7d
Update lr scheduler & fix a bug:
Aug 12, 2022
8f8beef
remove old evaluators
Aug 16, 2022
6f4e6e8
Merge branch 'dev-1.x' of github.com:open-mmlab/mmrazor into gy/dsnas
Aug 16, 2022
3deb833
remove old evaluators
Aug 16, 2022
2247bdd
update param_scheduler config
Aug 16, 2022
8d8d1b8
merge dev-1.x into gy/estimator
Aug 23, 2022
c2dbcaf
add flops_loss in Dsnas using ResourcesEstimator
Aug 23, 2022
bc813f3
get resources before mutator.prepare_from_supernet
Aug 25, 2022
ce87f89
delete unness broadcast api from gml
Aug 25, 2022
73d9c3b
broadcast spec_modules_resources when estimating
Aug 25, 2022
5de5bc9
update early fix mechanism for Dsnas
Aug 25, 2022
1e4014d
merge dev-1.x into gy/dsnas
Aug 25, 2022
e98043a
fix merge
Aug 25, 2022
2fbdd01
update units in estimator
Aug 26, 2022
d676d93
minor change
Aug 26, 2022
420dcac
merge dev-1.x into gy/dsnas
Aug 29, 2022
d6a401b
fix data_preprocessor api
Aug 31, 2022
aa2e0f2
add flops_loss_coef
Aug 31, 2022
08964ce
remove DsnasOptimWrapper
Aug 31, 2022
32baf69
fix bn eps and data_preprocessor
Aug 31, 2022
22cf3ed
fix bn weight decay bug
Sep 6, 2022
e0bafe2
add betas for mutator optimizer
Sep 6, 2022
84b5367
set diff_rank_seed=True for dsnas
Sep 6, 2022
eef0514
fix start_factor of lr when warm up
Sep 6, 2022
7258926
remove .module in non-ddp mode
Sep 6, 2022
2dfdc79
add GlobalAveragePoolingWithDropout
Sep 6, 2022
79d69ef
add UT for dsnas
Sep 6, 2022
b511403
remove unness channel adjustment for shufflenetv2
Sep 7, 2022
91ef7d8
update supernet configs
Sep 26, 2022
cfccb5e
delete unness dropout
Sep 26, 2022
6794188
delete unness part with minor change on dsnas
Sep 26, 2022
bbc93ff
merge dev-1.x into gy/dsnas
Sep 29, 2022
3a6b696
minor change on the flag of search stage
Sep 29, 2022
6df18d1
update README and subnet configs
Sep 29, 2022
a77d894
add UT for OneHotMutableOP
Sep 29, 2022
28 changes: 28 additions & 0 deletions configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py
@@ -0,0 +1,28 @@
norm_cfg = dict(type='BN', eps=0.01)

_STAGE_MUTABLE = dict(
    type='mmrazor.OneHotMutableOP',
    fix_threshold=0.3,
    candidates=dict(
        shuffle_3x3=dict(
            type='ShuffleBlock', kernel_size=3, norm_cfg=norm_cfg),
        shuffle_5x5=dict(
            type='ShuffleBlock', kernel_size=5, norm_cfg=norm_cfg),
        shuffle_7x7=dict(
            type='ShuffleBlock', kernel_size=7, norm_cfg=norm_cfg),
        shuffle_xception=dict(type='ShuffleXception', norm_cfg=norm_cfg)))

arch_setting = [
    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, mutable_cfg.
    [64, 4, _STAGE_MUTABLE],
    [160, 4, _STAGE_MUTABLE],
    [320, 8, _STAGE_MUTABLE],
    [640, 4, _STAGE_MUTABLE]
]

nas_backbone = dict(
    type='mmrazor.SearchableShuffleNetV2',
    widen_factor=1.0,
    arch_setting=arch_setting,
    norm_cfg=norm_cfg)
102 changes: 102 additions & 0 deletions configs/_base_/settings/imagenet_bs1024_dsnas.py
@@ -0,0 +1,102 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
data_preprocessor = dict(
    type='mmcls.ClsDataPreprocessor',
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(type='mmcls.RandomResizedCrop', scale=224),
    dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='mmcls.PackClsInputs'),
]

test_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
    dict(type='mmcls.CenterCrop', crop_size=224),
    dict(type='mmcls.PackClsInputs'),
]

train_dataloader = dict(
    batch_size=128,
    num_workers=4,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=True),
    persistent_workers=True,
)

val_dataloader = dict(
    batch_size=128,
    num_workers=4,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/val.txt',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5))

# If you want a standard test, configure the test dataset manually
test_dataloader = val_dataloader
test_evaluator = val_evaluator

# optimizer
paramwise_cfg = dict(bias_decay_mult=0.0, norm_decay_mult=0.0)

optim_wrapper = dict(
    constructor='mmrazor.SeparateOptimWrapperConstructor',
    architecture=dict(
        optimizer=dict(
            type='mmcls.SGD', lr=0.5, momentum=0.9, weight_decay=4e-5),
        paramwise_cfg=paramwise_cfg),
    mutator=dict(
        optimizer=dict(
            type='mmcls.Adam', lr=0.001, weight_decay=0.0,
            betas=(0.5, 0.999))))

search_epochs = 85
# learning policy
param_scheduler = dict(
    architecture=[
        # linear warmup for the first 5 epochs
        dict(
            type='mmcls.LinearLR',
            end=5,
            start_factor=0.2,
            by_epoch=True,
            convert_to_iter_based=True),
        # supernet pretrain & search stage: epoch 5 to search_epochs (85)
        dict(
            type='mmcls.CosineAnnealingLR',
            T_max=240,
            begin=5,
            end=search_epochs,
            by_epoch=True,
            convert_to_iter_based=True),
        # finetune stage: epoch 85 to 240
        dict(
            type='mmcls.CosineAnnealingLR',
            T_max=160,
            begin=search_epochs,
            end=240,
            eta_min=0.0,
            by_epoch=True,
            convert_to_iter_based=True)
    ],
    mutator=[])

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=240)
val_cfg = dict()
test_cfg = dict()
20 changes: 20 additions & 0 deletions configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml
@@ -0,0 +1,20 @@
backbone.layers.0.0: shuffle_3x3
backbone.layers.0.1: shuffle_3x3
backbone.layers.0.2: shuffle_xception
backbone.layers.0.3: shuffle_3x3
backbone.layers.1.0: shuffle_xception
backbone.layers.1.1: shuffle_7x7
backbone.layers.1.2: shuffle_3x3
backbone.layers.1.3: shuffle_3x3
backbone.layers.2.0: shuffle_xception
backbone.layers.2.1: shuffle_xception
backbone.layers.2.2: shuffle_7x7
backbone.layers.2.3: shuffle_xception
backbone.layers.2.4: shuffle_xception
backbone.layers.2.5: shuffle_xception
backbone.layers.2.6: shuffle_7x7
backbone.layers.2.7: shuffle_3x3
backbone.layers.3.0: shuffle_3x3
backbone.layers.3.1: shuffle_xception
backbone.layers.3.2: shuffle_xception
backbone.layers.3.3: shuffle_3x3
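
For context, a searched subnet description like the YAML above can be baked back into the supernet for deployment. A minimal, hedged sketch, assuming mmrazor 1.x's `load_fix_subnet` helper and standard MMEngine config building (neither is shown in this diff):

```python
# Hedged sketch: the import paths and accepted argument types are
# assumptions based on mmrazor 1.x; verify against the installed version.
from mmengine.config import Config

from mmrazor.registry import MODELS
from mmrazor.structures import load_fix_subnet

cfg = Config.fromfile('configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py')
algorithm = MODELS.build(cfg.model)  # mmrazor.Dsnas wrapping the supernet
# Fix every mutable op to the choice recorded in the searched YAML.
load_fix_subnet(algorithm.architecture,
                'configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml')
```

Alternatively, as `dsnas_subnet_8xb128_in1k.py` below shows, an equivalent mapping can be passed to the algorithm as `fix_subnet` directly in the config.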
43 changes: 43 additions & 0 deletions configs/nas/mmcls/dsnas/README.md
@@ -0,0 +1,43 @@
# DSNAS

> [DSNAS: Direct Neural Architecture Search without Parameter Retraining](https://arxiv.org/abs/2002.09128.pdf)

<!-- [ALGORITHM] -->

## Abstract

Most existing NAS methods require two-stage parameter optimization.
However, the performance of the same architecture in the two stages correlates poorly.
Based on this observation, DSNAS proposes a task-specific, end-to-end differentiable NAS framework that simultaneously optimizes architecture and parameters with a low-biased Monte Carlo estimate. Child networks derived from DSNAS can be deployed directly without parameter retraining.

![pipeline](/docs/en/imgs/model_zoo/dsnas/pipeline.jpg)
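
The sampling at the core of this pipeline can be illustrated with a minimal, hedged sketch: one candidate op is drawn from the softmax over architecture logits, and a straight-through one-hot weight lets a single forward/backward pass update both the sampled op's weights and the logits. `OneHotChoice` below is illustrative only, not the `mmrazor.OneHotMutableOP` implementation added in this PR.

```python
# Illustrative sketch; class name and structure are assumptions, not the
# actual mmrazor.OneHotMutableOP implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F


class OneHotChoice(nn.Module):
    """Run exactly one sampled candidate op per forward pass."""

    def __init__(self, candidates):
        super().__init__()
        self.candidates = nn.ModuleList(candidates)
        self.arch_params = nn.Parameter(torch.zeros(len(candidates)))

    def forward(self, x):
        probs = F.softmax(self.arch_params, dim=-1)
        index = int(torch.multinomial(probs, 1))  # Monte Carlo sample
        one_hot = F.one_hot(torch.tensor(index), len(self.candidates)).float()
        # Straight-through trick: the forward value equals the hard one-hot,
        # while gradients flow through `probs` into `arch_params`.
        weight = one_hot + probs - probs.detach()
        return weight[index] * self.candidates[index](x)
```

Because only the sampled candidate executes, search costs roughly as much as training a single path, which is what lets DSNAS skip parameter retraining for the derived child network.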

## Results and models

### Supernet

| Dataset | Params(M) | FLOPs (G) | Top-1 Acc (%) | Top-5 Acc (%) | Config | Download | Remarks |
| :------: | :-------: | :-------: | :-----------: | :-----------: | :---------------------------------------: | :----------------------: | :--------------: |
| ImageNet | 3.33 | 0.299 | 73.56 | 91.24 | [config](./dsnas_supernet_8xb128_in1k.py) | [model](<>) \| [log](<>) | MMRazor searched |

**Note**:

1. There **may be small differences (though not in every case)** between our experiments and the official ones, introduced to stay consistent with other OpenMMLab repos. For example, we
   normalize images in data preprocessing, resize with cv2 rather than PIL during training, and do not use dropout in the network. **Please refer to the corresponding config for details.**
2. We converted the official searched checkpoint `DSNASsearch240.pth` to mmrazor style and evaluated it with pytorch1.8_cuda11.0: Top-1 is 74.1 and Top-5 is 91.51.
3. The ShuffleNetV2 implementation in official DSNAS differs from OpenMMLab's, and we follow the structure design in OpenMMLab. With the
   original ShuffleNetV2 design of official DSNAS, Top-1 is 73.92 and Top-5 is 91.59.
4. The finetune stage in our implementation corresponds to the 'search-from-search' stage described in official DSNAS.
5. We obtain params and FLOPs using `mmrazor.ResourceEstimator`, so they may differ from the original repo; see the usage sketch after this list.
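
A minimal usage sketch for the estimator mentioned in note 5, assuming the mmrazor 1.x import path and an `estimate()` entry point (check the installed version for the exact signature):

```python
# Hedged sketch: import path, constructor and estimate() signature are
# assumptions based on mmrazor 1.x; verify against the installed version.
import torch.nn as nn

from mmrazor.models.task_modules import ResourceEstimator

# Any built nn.Module works, e.g. a fixed DSNAS subnet; a toy model is
# used here to keep the sketch self-contained.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))

estimator = ResourceEstimator(
    flops_params_cfg=dict(input_shape=(1, 3, 224, 224)))
results = estimator.estimate(model)
print(results['flops'], results['params'])
```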

## Citation

```latex
@inproceedings{hu2020dsnas,
title={Dsnas: Direct neural architecture search without parameter retraining},
author={Hu, Shoukang and Xie, Sirui and Zheng, Hehui and Liu, Chunxiao and Shi, Jianping and Liu, Xunying and Lin, Dahua},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={12084--12092},
year={2020}
}
```
29 changes: 29 additions & 0 deletions configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py
@@ -0,0 +1,29 @@
_base_ = ['./dsnas_supernet_8xb128_in1k.py']

# NOTE: Replace this with the mutable_cfg searched by yourself.
fix_subnet = {
    'backbone.layers.0.0': 'shuffle_3x3',
    'backbone.layers.0.1': 'shuffle_7x7',
    'backbone.layers.0.2': 'shuffle_3x3',
    'backbone.layers.0.3': 'shuffle_5x5',
    'backbone.layers.1.0': 'shuffle_3x3',
    'backbone.layers.1.1': 'shuffle_3x3',
    'backbone.layers.1.2': 'shuffle_3x3',
    'backbone.layers.1.3': 'shuffle_7x7',
    'backbone.layers.2.0': 'shuffle_xception',
    'backbone.layers.2.1': 'shuffle_3x3',
    'backbone.layers.2.2': 'shuffle_3x3',
    'backbone.layers.2.3': 'shuffle_5x5',
    'backbone.layers.2.4': 'shuffle_3x3',
    'backbone.layers.2.5': 'shuffle_5x5',
    'backbone.layers.2.6': 'shuffle_7x7',
    'backbone.layers.2.7': 'shuffle_7x7',
    'backbone.layers.3.0': 'shuffle_xception',
    'backbone.layers.3.1': 'shuffle_3x3',
    'backbone.layers.3.2': 'shuffle_7x7',
    'backbone.layers.3.3': 'shuffle_3x3',
}

model = dict(fix_subnet=fix_subnet)

find_unused_parameters = False
36 changes: 36 additions & 0 deletions configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py
@@ -0,0 +1,36 @@
_base_ = [
    'mmrazor::_base_/settings/imagenet_bs1024_dsnas.py',
    'mmrazor::_base_/nas_backbones/dsnas_shufflenet_supernet.py',
    'mmcls::_base_/default_runtime.py',
]

# model
model = dict(
    type='mmrazor.Dsnas',
    architecture=dict(
        type='ImageClassifier',
        data_preprocessor=_base_.data_preprocessor,
        backbone=_base_.nas_backbone,
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=1024,
            loss=dict(
                type='LabelSmoothLoss',
                num_classes=1000,
                label_smooth_val=0.1,
                mode='original',
                loss_weight=1.0),
            topk=(1, 5))),
    mutator=dict(type='mmrazor.DiffModuleMutator'),
    pretrain_epochs=15,
    finetune_epochs=_base_.search_epochs,
)

model_wrapper_cfg = dict(
    type='mmrazor.DsnasDDP',
    broadcast_buffers=False,
    find_unused_parameters=True)

randomness = dict(seed=48, diff_rank_seed=True)
4 changes: 2 additions & 2 deletions mmrazor/engine/__init__.py
@@ -10,6 +10,6 @@
     'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
-    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
-    'EstimateResourcesHook'
+    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'EstimateResourcesHook',
+    'SelfDistillValLoop'
 ]
5 changes: 3 additions & 2 deletions mmrazor/models/algorithms/__init__.py
@@ -3,12 +3,13 @@
 from .distill import (DAFLDataFreeDistillation, DataFreeDistillation,
                       FpnTeacherDistill, OverhaulFeatureDistillation,
                       SelfDistill, SingleTeacherDistill)
-from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
+from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP
 from .pruning import SlimmableNetwork, SlimmableNetworkDDP

 __all__ = [
     'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
     'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
     'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation',
-    'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation'
+    'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'Dsnas',
+    'DsnasDDP'
 ]
5 changes: 4 additions & 1 deletion mmrazor/models/algorithms/nas/__init__.py
@@ -1,6 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .autoslim import AutoSlim, AutoSlimDDP
 from .darts import Darts, DartsDDP
+from .dsnas import Dsnas, DsnasDDP
 from .spos import SPOS

-__all__ = ['SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP']
+__all__ = [
+    'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP'
+]