Showing 15 changed files with 923 additions and 3 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# DEST | ||
|
||
[DEST: Depth Estimation with Simplified Transformer](https://arxiv.org/abs/2204.13791) | ||
|
||
## Description | ||
|
||
Transformer and its variants have shown state-of-the-art results in many vision tasks recently, ranging from image classification to dense prediction. Despite their success, limited work has been reported on improving the model efficiency for deployment in latency-critical applications, such as autonomous driving and robotic navigation. In this paper, we aim at improving upon the existing transformers in vision, and propose a method for Dense Estimation with Simplified Transformer (DEST), which is efficient and particularly suitable for deployment on GPU-based platforms. Through strategic design choices, our model leads to significant reduction in model size, complexity, as well as inference latency, while achieving superior accuracy as compared to state-of-the-art in the task of self-supervised monocular depth estimation. We also show that our design generalizes well to other dense prediction tasks such as semantic segmentation without bells and whistles. | ||
|
||
## Usage | ||
|
||
### Prerequisites | ||
|
||
- Python 3.7 | ||
- PyTorch 1.6 or higher | ||
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc2 or higher | ||
- mmcv v2.0.0rc4 | ||
- MIM v0.33 or higher | ||
|
||
### Dataset preparation | ||
|
||
Prepare the `cityscapes` dataset by following the [Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#prepare-datasets). | ||
|
||
### Training commands | ||
|
||
```shell | ||
mim train mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest | ||
``` | ||
|
||
To train on multiple GPUs, e.g. 8 GPUs, run the following command: | ||
|
||
```shell | ||
mim train mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest --launcher pytorch --gpus 8 | ||
``` | ||
|
||
### Testing commands | ||
|
||
```shell | ||
mim test mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest --checkpoint ${CHECKPOINT_PATH} --eval mIoU | ||
``` | ||
|
||
## Results and models | ||
|
||
### Cityscapes | ||
|
||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | | ||
| ------ | -------- | --------- | ------: | -------: | -------------- | ----: | ------------- | ---------------------------------------------------------------------------------------------------------------------------- | -------- | | ||
| DEST | SMIT-B0 | 1024x1024 | 160000 | - | - | 64.34 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b0_1024x1024_160k_cityscapes.py) | - | | ||
| DEST | SMIT-B1 | 1024x1024 | 160000 | - | - | 68.21 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py) | - | | ||
| DEST | SMIT-B2 | 1024x1024 | 160000 | - | - | 71.89 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b2_1024x1024_160k_cityscapes.py) | - | | ||
| DEST | SMIT-B3 | 1024x1024 | 160000 | - | - | 73.51 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b3_1024x1024_160k_cityscapes.py) | - | | ||
| DEST | SMIT-B4 | 1024x1024 | 160000 | - | - | 73.99 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b4_1024x1024_160k_cityscapes.py) | - | | ||
| DEST | SMIT-B5 | 1024x1024 | 160000 | - | - | 75.28 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b5_1024x1024_160k_cityscapes.py) | - | | ||
|
||
Note: | ||
|
||
- The above models are all trained from scratch without pretrained backbones. Accuracy can be further improved by appropriate pretraining. | ||
- Training of DEST is not very stable and is sensitive to random seeds. | ||
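
Because results are sensitive to the random seed, it may help to pin the seed when comparing runs. The fragment below is a minimal sketch, assuming the MMEngine-style `randomness` config field that an MMSegmentation 1.x runner reads when it is built from a config. The field is not part of the original DEST configs, and the seed value is arbitrary.

```python
# Hypothetical addition to a DEST config file; not part of the original configs.
# `randomness` is read by the MMEngine Runner when it is built from a config.
randomness = dict(
    seed=0,               # arbitrary fixed seed; try several and compare the spread
    deterministic=False)  # True requests deterministic cuDNN ops (usually slower)
```

Reporting mIoU over a few seeds, rather than a single run, gives a fairer picture when training is unstable.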
|
||
## Citation | ||
|
||
```bibtex | ||
@article{YangDEST, | ||
title={Depth Estimation with Simplified Transformer}, | ||
author={Yang, John and An, Le and Dixit, Anurag and Koo, Jinkyu and Park, Su Inn}, | ||
journal={arXiv preprint arXiv:2204.13791}, | ||
year={2022} | ||
} | ||
``` | ||
|
||
## Checklist | ||
|
||
- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. | ||
|
||
  - [x] Finish the code | ||
|
||
  - [x] Basic docstrings & proper citation | ||
|
||
  - [x] Test-time correctness | ||
|
||
  - [x] A full README | ||
|
||
- [x] Milestone 2: Indicates a successful model implementation. | ||
|
||
  - [x] Training-time correctness | ||
|
||
- [ ] Milestone 3: Good to be a part of our core package! | ||
|
||
  - [ ] Type hints and docstrings | ||
|
||
  - [ ] Unit tests | ||
|
||
  - [ ] Code polishing | ||
|
||
  - [ ] Metafile.yml | ||
|
||
- [ ] Move your modules into the core package following the codebase's file hierarchy structure. | ||
|
||
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# DEST | ||
|
||
[DEST: Depth Estimation with Simplified Transformer](https://arxiv.org/abs/2204.13791) | ||
|
||
## Introduction | ||
|
||
<!-- [ALGORITHM] --> | ||
|
||
<a href="https://github.com/NVIDIA/DL4AGX/tree/master/DEST">Official Repo</a> | ||
|
||
## Abstract | ||
|
||
<!-- [ABSTRACT] --> | ||
|
||
Transformer and its variants have shown state-of-the-art results in many vision tasks recently, ranging from image classification to dense prediction. Despite their success, limited work has been reported on improving the model efficiency for deployment in latency-critical applications, such as autonomous driving and robotic navigation. In this paper, we aim at improving upon the existing transformers in vision, and propose a method for Dense Estimation with Simplified Transformer (DEST), which is efficient and particularly suitable for deployment on GPU-based platforms. Through strategic design choices, our model leads to significant reduction in model size, complexity, as well as inference latency, while achieving superior accuracy as compared to state-of-the-art in the task of self-supervised monocular depth estimation. We also show that our design generalizes well to other dense prediction tasks such as semantic segmentation without bells and whistles. | ||
|
||
<!-- [IMAGE] --> | ||
|
||
<div align=center> | ||
<img src="./simplifiedAtt.png" width="70%"/> | ||
</div> | ||
|
||
## Citation | ||
|
||
```bibtex | ||
@article{YangDEST, | ||
title={Depth Estimation with Simplified Transformer}, | ||
author={Yang, John and An, Le and Dixit, Anurag and Koo, Jinkyu and Park, Su Inn}, | ||
journal={arXiv preprint arXiv:2204.13791}, | ||
year={2022} | ||
} | ||
``` | ||
|
||
## Results and models | ||
|
||
### Cityscapes | ||
|
||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | | ||
| ------ | -------- | --------- | ------: | -------: | -------------- | ----: | ------------- | ---------------------------------------------------------------------------------------------------------------------------- | -------- | | ||
| DEST | SMIT-B0 | 1024x1024 | 160000 | - | - | 64.34 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b0_1024x1024_160k_cityscapes.py) | - | | ||
| DEST | SMIT-B1 | 1024x1024 | 160000 | - | - | 68.21 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py) | - | | ||
| DEST | SMIT-B2 | 1024x1024 | 160000 | - | - | 71.89 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b2_1024x1024_160k_cityscapes.py) | - | | ||
| DEST | SMIT-B3 | 1024x1024 | 160000 | - | - | 73.51 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b3_1024x1024_160k_cityscapes.py) | - | | ||
| DEST | SMIT-B4 | 1024x1024 | 160000 | - | - | 73.99 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b4_1024x1024_160k_cityscapes.py) | - | | ||
| DEST | SMIT-B5 | 1024x1024 | 160000 | - | - | 75.28 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b5_1024x1024_160k_cityscapes.py) | - | | ||
|
||
Note: | ||
|
||
- The above models are all trained from scratch without pretrained backbones. Accuracy can be further improved by appropriate pretraining. | ||
- Training of DEST is not very stable and is sensitive to random seeds. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# model settings | ||
embed_dims = [32, 64, 160, 256] | ||
norm_cfg = dict(type='SyncBN', requires_grad=True) | ||
model = dict( | ||
type='EncoderDecoder', | ||
pretrained=None, | ||
backbone=dict( | ||
type='SimplifiedMixTransformer', | ||
in_channels=3, | ||
embed_dims=embed_dims, | ||
num_stages=4, | ||
num_layers=[2, 2, 2, 2], | ||
num_heads=[1, 2, 5, 8], | ||
patch_sizes=[7, 3, 3, 3], | ||
strides=[4, 2, 2, 2], | ||
sr_ratios=[8, 4, 2, 1], | ||
out_indices=(0, 1, 2, 3), | ||
mlp_ratios=[8, 8, 4, 4], | ||
qkv_bias=True, | ||
drop_rate=0.0, | ||
attn_drop_rate=0.0, | ||
drop_path_rate=0.1, | ||
norm_cfg=norm_cfg), | ||
decode_head=dict( | ||
type='DESTHead', | ||
in_channels=[32, 64, 160, 256], | ||
in_index=[0, 1, 2, 3], | ||
channels=32, | ||
dropout_ratio=0.1, | ||
num_classes=19, | ||
norm_cfg=norm_cfg, | ||
align_corners=False, | ||
loss_decode=dict( | ||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), | ||
# model training and testing settings | ||
train_cfg=dict(), | ||
test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768))) |
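
The `test_cfg` above requests sliding-window inference with 1024x1024 crops and a stride of 768. The sketch below shows how such windows could be laid out; it is written from the usual sliding-window bookkeeping rather than copied from MMSegmentation, the function name `slide_windows` is hypothetical, and the 1024x2048 input is the standard Cityscapes frame size.

```python
def slide_windows(img_hw, crop_hw, stride_hw):
    """Return (y1, x1, y2, x2) boxes covering the image with overlapping crops."""
    h_img, w_img = img_hw
    h_crop, w_crop = crop_hw
    h_stride, w_stride = stride_hw
    # Number of window positions along each axis (ceil division, at least one).
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    boxes = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            # Shift the last window back so it stays inside the image.
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            boxes.append((y1, x1, y2, x2))
    return boxes


# Cityscapes frame (height 1024, width 2048) with the crop and stride from test_cfg.
windows = slide_windows((1024, 2048), (1024, 1024), (768, 768))
print(windows)
```

Under these assumptions each frame is covered by three overlapping crops, and predictions in the overlapping regions are typically averaged.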
33 changes: 33 additions & 0 deletions
projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
_base_ = [ | ||
'./dest_simpatt-b0.py', | ||
'../../../configs/_base_/datasets/cityscapes_1024x1024.py', | ||
'../../../configs/_base_/default_runtime.py', | ||
'../../../configs/_base_/schedules/schedule_160k.py' | ||
] | ||
|
||
custom_imports = dict(imports=['projects.dest.models']) | ||
|
||
optimizer = dict( | ||
_delete_=True, | ||
type='AdamW', | ||
lr=0.00006, | ||
betas=(0.9, 0.999), | ||
weight_decay=0.01, | ||
paramwise_cfg=dict( | ||
custom_keys={ | ||
'pos_block': dict(decay_mult=0.), | ||
'norm': dict(decay_mult=0.), | ||
'head': dict(lr_mult=10.) | ||
})) | ||
|
||
lr_config = dict( | ||
_delete_=True, | ||
policy='poly', | ||
warmup='linear', | ||
warmup_iters=1500, | ||
warmup_ratio=1e-6, | ||
power=1.0, | ||
min_lr=0.0, | ||
by_epoch=False) | ||
|
||
data = dict(samples_per_gpu=1, workers_per_gpu=1) |
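
The `lr_config` above combines a polynomial decay with a linear warmup. The following sketch computes the resulting learning rate per iteration, assuming the usual definitions of the `poly` policy and `linear` warmup and a total of 160,000 iterations (taken from the `schedule_160k` base file name). The function `lr_at` is a hypothetical helper, not part of the project.

```python
def lr_at(it, base_lr=6e-5, max_iters=160_000, power=1.0, min_lr=0.0,
          warmup_iters=1500, warmup_ratio=1e-6):
    """Learning rate at iteration `it` for poly decay with linear warmup."""
    # Polynomial decay from base_lr towards min_lr over max_iters.
    poly = (base_lr - min_lr) * (1 - it / max_iters) ** power + min_lr
    if it >= warmup_iters:
        return poly
    # Linear warmup: the scale ramps from warmup_ratio up to 1.
    k = (1 - it / warmup_iters) * (1 - warmup_ratio)
    return poly * (1 - k)


for step in (0, 750, 1500, 80_000, 160_000):
    print(step, lr_at(step))
```

With `power=1.0` the decay is linear, so the rate starts near `6e-11`, approaches `6e-5` at the end of warmup, and falls to zero at the last iteration. The `head` entry with `lr_mult=10.` scales this value by 10 for parameters whose names contain `head`.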
9 changes: 9 additions & 0 deletions
projects/dest/configs/dest_simpatt-b1_1024x1024_160k_cityscapes.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py'] | ||
|
||
embed_dims = [64, 128, 250, 320] | ||
|
||
model = dict( | ||
type='EncoderDecoder', | ||
pretrained=None, | ||
backbone=dict(embed_dims=embed_dims), | ||
decode_head=dict(in_channels=embed_dims, channels=64)) |
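
The variant configs list only the keys that differ from the B0 base, because inherited dictionaries are merged key by key. The sketch below illustrates that behaviour with a simplified recursive merge. It is a stand-in for the config system's inheritance, not its actual implementation: `merge` is a hypothetical helper, the base dictionary is abbreviated, and special keys such as `_delete_` are ignored.

```python
import copy


def merge(base, override):
    """Recursively merge `override` into a copy of `base` (simplified stand-in)."""
    out = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)  # descend into nested dicts
        else:
            out[key] = copy.deepcopy(value)    # replace leaves and lists wholesale
    return out


# Abbreviated B0 model settings (only a few keys, for illustration).
b0_model = dict(
    backbone=dict(
        embed_dims=[32, 64, 160, 256],
        num_layers=[2, 2, 2, 2],
        num_heads=[1, 2, 5, 8]),
    decode_head=dict(in_channels=[32, 64, 160, 256], channels=32, num_classes=19))

# Overrides as written in the B1 config.
embed_dims = [64, 128, 250, 320]
b1_override = dict(
    backbone=dict(embed_dims=embed_dims),
    decode_head=dict(in_channels=embed_dims, channels=64))

b1_model = merge(b0_model, b1_override)
print(b1_model)
```

In this sketch B1 keeps `num_layers`, `num_heads`, and `num_classes` from B0 and changes only the channel widths.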
9 changes: 9 additions & 0 deletions
projects/dest/configs/dest_simpatt-b2_1024x1024_160k_cityscapes.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py'] | ||
|
||
embed_dims = [64, 128, 250, 320] | ||
|
||
model = dict( | ||
type='EncoderDecoder', | ||
pretrained=None, | ||
backbone=dict(embed_dims=embed_dims, num_layers=[3, 3, 6, 3]), | ||
decode_head=dict(in_channels=embed_dims, channels=64)) |
22 changes: 22 additions & 0 deletions
projects/dest/configs/dest_simpatt-b3_1024x1024_160k_cityscapes.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py'] | ||
|
||
embed_dims = [64, 128, 250, 320] | ||
|
||
optimizer = dict( | ||
_delete_=True, | ||
type='AdamW', | ||
lr=0.00006, | ||
betas=(0.9, 0.999), | ||
weight_decay=0.01, | ||
paramwise_cfg=dict( | ||
custom_keys={ | ||
'pos_block': dict(decay_mult=0.), | ||
'norm': dict(decay_mult=0.), | ||
'head': dict(lr_mult=1.) | ||
})) | ||
|
||
model = dict( | ||
type='EncoderDecoder', | ||
pretrained=None, | ||
backbone=dict(embed_dims=embed_dims, num_layers=[3, 6, 8, 3]), | ||
decode_head=dict(in_channels=embed_dims, channels=64)) |
22 changes: 22 additions & 0 deletions
projects/dest/configs/dest_simpatt-b4_1024x1024_160k_cityscapes.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py'] | ||
|
||
embed_dims = [64, 128, 250, 320] | ||
|
||
optimizer = dict( | ||
_delete_=True, | ||
type='AdamW', | ||
lr=0.00006, | ||
betas=(0.9, 0.999), | ||
weight_decay=0.01, | ||
paramwise_cfg=dict( | ||
custom_keys={ | ||
'pos_block': dict(decay_mult=0.), | ||
'norm': dict(decay_mult=0.), | ||
'head': dict(lr_mult=1.) | ||
})) | ||
|
||
model = dict( | ||
type='EncoderDecoder', | ||
pretrained=None, | ||
backbone=dict(embed_dims=embed_dims, num_layers=[3, 8, 12, 5]), | ||
decode_head=dict(in_channels=embed_dims, channels=64)) |
22 changes: 22 additions & 0 deletions
projects/dest/configs/dest_simpatt-b5_1024x1024_160k_cityscapes.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py'] | ||
|
||
embed_dims = [64, 128, 250, 320] | ||
|
||
optimizer = dict( | ||
_delete_=True, | ||
type='AdamW', | ||
lr=0.00006, | ||
betas=(0.9, 0.999), | ||
weight_decay=0.01, | ||
paramwise_cfg=dict( | ||
custom_keys={ | ||
'pos_block': dict(decay_mult=0.), | ||
'norm': dict(decay_mult=0.), | ||
'head': dict(lr_mult=1.) | ||
})) | ||
|
||
model = dict( | ||
type='EncoderDecoder', | ||
pretrained=None, | ||
backbone=dict(embed_dims=embed_dims, num_layers=[3, 10, 16, 5]), | ||
decode_head=dict(in_channels=embed_dims, channels=64)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .dest_head import DESTHead | ||
from .smit import SimplifiedMixTransformer | ||
|
||
__all__ = ['SimplifiedMixTransformer', 'DESTHead'] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import torch | ||
import torch.nn as nn | ||
from mmcv.cnn import ConvModule | ||
|
||
from mmseg.models import HEADS | ||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead | ||
|
||
|
||
@HEADS.register_module() | ||
class DESTHead(BaseDecodeHead): | ||
    """Decode head of DEST. Features are fused from the deepest stage to | ||
    the shallowest: each stage is projected by a 1x1 conv, concatenated | ||
    with the fused feature of the next deeper stage, upsampled by 2x and | ||
    refined by a 3x3 conv. | ||
    """ | ||
|
||
    def __init__(self, interpolate_mode='bilinear', **kwargs): | ||
        super().__init__(input_transform='multiple_select', **kwargs) | ||
        self.interpolate_mode = interpolate_mode | ||
        num_inputs = len(self.in_channels) | ||
        assert num_inputs == len(self.in_index) | ||
        # Every stage except the deepest also receives the next stage's channels. | ||
        self.fuse_in_channels = self.in_channels.copy() | ||
        for i in range(num_inputs - 1): | ||
            self.fuse_in_channels[i] += self.fuse_in_channels[i + 1] | ||
        self.convs = nn.ModuleList() | ||
        for i in range(num_inputs): | ||
            self.convs.append( | ||
                ConvModule( | ||
                    in_channels=self.in_channels[i], | ||
                    out_channels=self.in_channels[i], | ||
                    kernel_size=1, | ||
                    stride=1, | ||
                    act_cfg=self.act_cfg)) | ||
|
||
        self.fuse_convs = nn.ModuleList() | ||
        for i in range(num_inputs): | ||
            self.fuse_convs.append( | ||
                ConvModule( | ||
                    in_channels=self.fuse_in_channels[i], | ||
                    out_channels=self.in_channels[i], | ||
                    kernel_size=3, | ||
                    stride=1, | ||
                    padding=1, | ||
                    act_cfg=self.act_cfg)) | ||
|
||
        self.upsample = nn.ModuleList([ | ||
            nn.Sequential(nn.Upsample(scale_factor=2, mode=interpolate_mode)) | ||
        ] * len(self.in_channels)) | ||
|
||
    def forward(self, inputs): | ||
        feat = None | ||
        for idx in reversed(range(len(inputs))): | ||
            x = self.convs[idx](inputs[idx]) | ||
            if idx != len(inputs) - 1: | ||
                # torch.cat also works on PyTorch < 1.10, which lacks torch.concat. | ||
                x = torch.cat([feat, x], dim=1) | ||
            x = self.upsample[idx](x) | ||
            feat = self.fuse_convs[idx](x) | ||
        return self.cls_seg(feat) |
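
The channel and resolution bookkeeping of `DESTHead` can be traced without building the network. The sketch below is a plain-Python illustration under two assumptions: that the backbone stages output features at cumulative strides 4, 8, 16 and 32 (derived from `strides=[4, 2, 2, 2]` in the base config), and that the input is a 1024x1024 crop. The helper `trace_dest_head` is hypothetical and not part of the project.

```python
def trace_dest_head(in_channels, crop_size=1024, strides=(4, 8, 16, 32)):
    """Trace fuse-conv input/output channels and output sizes, deepest stage first."""
    n = len(in_channels)
    # Same bookkeeping as DESTHead.__init__: every stage except the deepest
    # also receives the fused feature of the next deeper stage.
    fuse_in = [
        c + in_channels[i + 1] if i < n - 1 else c
        for i, c in enumerate(in_channels)
    ]
    sizes = [crop_size // s for s in strides]  # assumed backbone feature sizes
    trace = []
    for idx in reversed(range(n)):
        out_size = sizes[idx] * 2  # 2x upsampling happens before the 3x3 fuse conv
        trace.append((idx, fuse_in[idx], in_channels[idx], out_size))
    return fuse_in, trace


fuse_in, trace = trace_dest_head([32, 64, 160, 256])
print(fuse_in)
for idx, c_in, c_out, size in trace:
    print(f"stage {idx}: fuse conv {c_in} -> {c_out} channels at {size}x{size}")
```

Under these assumptions the final fused feature has `in_channels[0]` channels at half the crop resolution, which is consistent with `channels=32` for B0 and `channels=64` for the larger variants.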