Skip to content

Commit

Permalink
[Feature] Add MobileOne Backbone For MMCls 1.x. (#1030)
Browse files Browse the repository at this point in the history
* add mobileOne

* add train cfg

* update cfgs

* update URL

* update configs

* update inceptionv3 metafile

* add configs

* fix lint

* update checkpoint urls

* Update configs

* Update README

Co-authored-by: mzr1996 <mzr1996@163.com>
  • Loading branch information
Ezra-Yu and mzr1996 committed Sep 16, 2022
1 parent 9999da6 commit f1d2f50
Show file tree
Hide file tree
Showing 26 changed files with 1,317 additions and 63 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/cspnet)
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/poolformer)
- [x] [Inception V3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/inception_v3)
- [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)

</details>

Expand Down
1 change: 1 addition & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ mim install -e .
- [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/cspnet)
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/poolformer)
- [x] [Inception V3](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/inception_v3)
- [x] [MobileOne](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobileone)

</details>

Expand Down
19 changes: 19 additions & 0 deletions configs/_base_/models/mobileone/mobileone_s0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='MobileOne',
arch='s0',
out_indices=(3, ),
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1024,
loss=dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
),
topk=(1, 5),
))
19 changes: 19 additions & 0 deletions configs/_base_/models/mobileone/mobileone_s1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='MobileOne',
arch='s1',
out_indices=(3, ),
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1280,
loss=dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
),
topk=(1, 5),
))
19 changes: 19 additions & 0 deletions configs/_base_/models/mobileone/mobileone_s2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='MobileOne',
arch='s2',
out_indices=(3, ),
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
),
topk=(1, 5),
))
19 changes: 19 additions & 0 deletions configs/_base_/models/mobileone/mobileone_s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='MobileOne',
arch='s3',
out_indices=(3, ),
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
),
topk=(1, 5),
))
19 changes: 19 additions & 0 deletions configs/_base_/models/mobileone/mobileone_s4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='MobileOne',
arch='s4',
out_indices=(3, ),
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
),
topk=(1, 5),
))
4 changes: 2 additions & 2 deletions configs/inception_v3/metafile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ Collections:
Title: "Rethinking the Inception Architecture for Computer Vision"
README: configs/inception_v3/README.md
Code:
URL: TODO
Version: TODO
URL: https://github.com/open-mmlab/mmclassification/blob/v1.0.0rc1/configs/inception_v3/metafile.yml
Version: v1.0.0rc1

Models:
- Name: inception-v3_3rdparty_8xb32_in1k
Expand Down
132 changes: 132 additions & 0 deletions configs/mobileone/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# MobileOne

> [An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040)
<!-- [ALGORITHM] -->

## Abstract

Efficient neural network backbones for mobile devices are often optimized for metrics such as FLOPs or parameter count. However, these metrics may not correlate well with latency of the network when deployed on a mobile device. Therefore, we perform extensive analysis of different metrics by deploying several mobile-friendly networks on a mobile device. We identify and analyze architectural and optimization bottlenecks in recent efficient neural networks and provide ways to mitigate these bottlenecks. To this end, we design an efficient backbone MobileOne, with variants achieving an inference time under 1 ms on an iPhone12 with 75.9% top-1 accuracy on ImageNet. We show that MobileOne achieves state-of-the-art performance within the efficient architectures while being many times faster on mobile. Our best model obtains similar performance on ImageNet as MobileFormer while being 38x faster. Our model obtains 2.3% better top-1 accuracy on ImageNet than EfficientNet at similar latency. Furthermore, we show that our model generalizes to multiple tasks - image classification, object detection, and semantic segmentation with significant improvements in latency and accuracy as compared to existing efficient architectures when deployed on a mobile device.

<div align=center>
<img src="https://user-images.githubusercontent.com/18586273/183552452-74657532-f461-48f7-9aa7-c23f006cdb07.png" width="40%"/>
</div>

## Results and models

### ImageNet-1k

| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :------------: | :-----------------------------: | :----------------------------: | :-------: | :-------: | :--------------------------------------------------: | :-----------------------------------------------------: |
| MobileOne-s0\* | 5.29(train) \| 2.08 (deploy) | 1.09 (train) \| 0.28 (deploy) | 71.36 | 89.87 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/mobileone/mobileone-s0_8xb128_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/mobileone/deploy/mobileone-s0_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_3rdparty_in1k_20220915-007ae971.pth) |
| MobileOne-s1\* | 4.83 (train) \| 4.76 (deploy) | 0.86 (train) \| 0.84 (deploy) | 75.76 | 92.77 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/mobileone/mobileone-s1_8xb128_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/mobileone/deploy/mobileone-s1_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s1_3rdparty_in1k_20220915-473c8469.pth) |
| MobileOne-s2\* | 7.88 (train) \| 7.88 (deploy) | 1.34 (train) \| 1.31 (deploy) | 77.39 | 93.63 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/mobileone/mobileone-s2_8xb128_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/mobileone/deploy/mobileone-s2_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s2_3rdparty_in1k_20220915-ed2e4c30.pth) |
| MobileOne-s3\* | 10.17 (train) \| 10.08 (deploy) | 1.95 (train) \| 1.91 (deploy) | 77.93 | 93.89 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/mobileone/mobileone-s3_8xb128_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/mobileone/deploy/mobileone-s3_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s3_3rdparty_in1k_20220915-84d6a02c.pth) |
| MobileOne-s4\* | 14.95 (train) \| 14.84 (deploy) | 3.05 (train) \| 3.00 (deploy) | 79.30 | 94.37 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/mobileone/mobileone-s4_8xb128_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/mobileone/deploy/mobileone-s4_deploy_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s4_3rdparty_in1k_20220915-ce9509ee.pth) |

*Models with * are converted from the [official repo](https://github.com/apple/ml-mobileone). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*

*Because the [official repo.](https://github.com/apple/ml-mobileone) does not give a strategy for training and testing, the test data pipline of [RepVGG](https://github.com/open-mmlab/mmclassification/tree/master/configs/repvgg) is used here, and the result is about 0.1 lower than that in the paper. Refer to [this issue](https://github.com/apple/ml-mobileone/issues/2).*

## How to use

The checkpoints provided are all `training-time` models. Use the reparameterize tool to switch them to more efficient `inference-time` architecture, which not only has fewer parameters but also less calculations.

### Use tool

Use provided tool to reparameterize the given model and save the checkpoint:

```bash
python tools/convert_models/reparameterize_model.py ${CFG_PATH} ${SRC_CKPT_PATH} ${TARGET_CKPT_PATH}
```

`${CFG_PATH}` is the config file path, `${SRC_CKPT_PATH}` is the source chenpoint file path, `${TARGET_CKPT_PATH}` is the target deploy weight file path.

For example:

```shell
python ./tools/convert_models/reparameterize_model.py ./configs/mobileone/mobileone-s0_8xb128_in1k.py https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_3rdparty_in1k_20220811-db5ce29b.pth ./mobileone_s0_deploy.pth
```

To use reparameterized weights, the config file must switch to **the deploy config files**.

```bash
python tools/test.py ${Deploy_CFG} ${Deploy_Checkpoint} --metrics accuracy
```

For example of using the reparameterized weights above:

```shell
python ./tools/test.py ./configs/mobileone/deploy/mobileone-s0_deploy_8xb128_in1k.py mobileone_s0_deploy.pth --metrics accuracy
```

### In the code

Use the API `switch_to_deploy` of `MobileOne` backbone to to switch to the deploy mode. Usually called like `backbone.switch_to_deploy()` or `classificer.backbone.switch_to_deploy()`.

For Backbones:

```python
from mmcls.models import build_backbone
import torch

x = torch.randn( (1, 3, 224, 224) )
backbone_cfg=dict(type='MobileOne', arch='s0')
backbone = build_backbone(backbone_cfg)
backbone.init_weights()
backbone.eval()
outs_ori = backbone(x)

backbone.switch_to_deploy()
outs_dep = backbone(x)

for out1, out2 in zip(outs_ori, outs_dep):
assert torch.allclose(out1, out2)
```

For ImageClassifiers:

```python
from mmcls.models import build_classifier
import torch
import numpy as np

cfg = dict(
type='ImageClassifier',
backbone=dict(
type='MobileOne',
arch='s0',
out_indices=(3, ),
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1024,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

x = torch.randn( (1, 3, 224, 224) )
classifier = build_classifier(cfg)
classifier.init_weights()
classifier.eval()
y_ori = classifier(x, return_loss=False)

classifier.backbone.switch_to_deploy()
y_dep = classifier(x, return_loss=False)

for y1, y2 in zip(y_ori, y_dep):
assert np.allclose(y1, y2)
```

## Citation

```bibtex
@article{mobileone2022,
title={An Improved One millisecond Mobile Backbone},
author={Vasu, Pavan Kumar Anasosalu and Gabriel, James and Zhu, Jeff and Tuzel, Oncel and Ranjan, Anurag},
journal={arXiv preprint arXiv:2206.04040},
year={2022}
}
```
3 changes: 3 additions & 0 deletions configs/mobileone/deploy/mobileone-s0_deploy_8xb128_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = ['../mobileone-s0_8xb128_in1k.py']

model = dict(backbone=dict(deploy=True))
3 changes: 3 additions & 0 deletions configs/mobileone/deploy/mobileone-s1_deploy_8xb128_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = ['../mobileone-s1_8xb128_in1k.py']

model = dict(backbone=dict(deploy=True))
3 changes: 3 additions & 0 deletions configs/mobileone/deploy/mobileone-s2_deploy_8xb128_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = ['../mobileone-s2_8xb128_in1k.py']

model = dict(backbone=dict(deploy=True))
3 changes: 3 additions & 0 deletions configs/mobileone/deploy/mobileone-s3_deploy_8xb128_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = ['../mobileone-s3_8xb128_in1k.py']

model = dict(backbone=dict(deploy=True))
3 changes: 3 additions & 0 deletions configs/mobileone/deploy/mobileone-s4_deploy_8xb128_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = ['../mobileone-s4_8xb128_in1k.py']

model = dict(backbone=dict(deploy=True))
98 changes: 98 additions & 0 deletions configs/mobileone/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
Collections:
- Name: MobileOne
Metadata:
Training Data: ImageNet-1k
Architecture:
- re-parameterization Convolution
- VGG-style Neural Network
- Depthwise Convolution
- Pointwise Convolution
Paper:
URL: https://arxiv.org/abs/2206.04040
Title: 'An Improved One millisecond Mobile Backbone'
README: configs/mobileone/README.md
Code:
URL: https://github.com/open-mmlab/mmclassification/blob/v1.0.0rc1/configs/mobileone/metafile.yml
Version: v1.0.0rc1

Models:
- Name: mobileone-s0_3rdparty_8xb128_in1k
In Collection: MobileOne
Config: configs/mobileone/mobileone-s0_8xb128_in1k.py
Metadata:
FLOPs: 1091227648 # 1.09G
Parameters: 5293272 # 5.29M
Results:
- Dataset: ImageNet-1k
Task: Image Classification
Metrics:
Top 1 Accuracy: 71.36
Top 5 Accuracy: 89.87
Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s0_3rdparty_in1k_20220915-007ae971.pth
Converted From:
Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar
Code: https://github.com/apple/ml-mobileone
- Name: mobileone-s1_3rdparty_8xb128_in1k
In Collection: MobileOne
Config: configs/mobileone/mobileone-s1_8xb128_in1k.py
Metadata:
FLOPs: 863491328 # 8.6G
Parameters: 4825192 # 4.82M
Results:
- Dataset: ImageNet-1k
Task: Image Classification
Metrics:
Top 1 Accuracy: 75.76
Top 5 Accuracy: 92.77
Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s1_3rdparty_in1k_20220915-473c8469.pth
Converted From:
Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar
Code: https://github.com/apple/ml-mobileone
- Name: mobileone-s2_3rdparty_8xb128_in1k
In Collection: MobileOne
Config: configs/mobileone/mobileone-s2_8xb128_in1k.py
Metadata:
FLOPs: 1344083328
Parameters: 7884648
Results:
- Dataset: ImageNet-1k
Task: Image Classification
Metrics:
Top 1 Accuracy: 77.39
Top 5 Accuracy: 93.63
Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s2_3rdparty_in1k_20220915-ed2e4c30.pth
Converted From:
Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar
Code: https://github.com/apple/ml-mobileone
- Name: mobileone-s3_3rdparty_8xb128_in1k
In Collection: MobileOne
Config: configs/mobileone/mobileone-s3_8xb128_in1k.py
Metadata:
FLOPs: 1951043584
Parameters: 10170600
Results:
- Dataset: ImageNet-1k
Task: Image Classification
Metrics:
Top 1 Accuracy: 77.93
Top 5 Accuracy: 93.89
Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s3_3rdparty_in1k_20220915-84d6a02c.pth
Converted From:
Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar
Code: https://github.com/apple/ml-mobileone
- Name: mobileone-s4_3rdparty_8xb128_in1k
In Collection: MobileOne
Config: configs/mobileone/mobileone-s4_8xb128_in1k.py
Metadata:
FLOPs: 3052580688
Parameters: 14951248
Results:
- Dataset: ImageNet-1k
Task: Image Classification
Metrics:
Top 1 Accuracy: 79.30
Top 5 Accuracy: 94.37
Weights: https://download.openmmlab.com/mmclassification/v0/mobileone/mobileone-s4_3rdparty_in1k_20220915-ce9509ee.pth
Converted From:
Weights: https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar
Code: https://github.com/apple/ml-mobileone
Loading

0 comments on commit f1d2f50

Please sign in to comment.