Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DEST] add DEST model #2482

Merged
merged 4 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions LICENSES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

In this file, we list the features with other licenses instead of Apache 2.0. Users should be careful about adopting these features in any commercial matters.

| Feature | Files | License |
| :-------: | :-------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------: |
| SegFormer | [mmseg/models/decode_heads/segformer_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py) | [NVIDIA License](https://github.com/NVlabs/SegFormer#license) |
| Feature | Files | License |
| :-------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------: |
| SegFormer | [mmseg/models/decode_heads/segformer_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py) | [NVIDIA License](https://github.com/NVlabs/SegFormer#license) |
| DEST | [mmseg/models/backbones/smit.py](https://github.com/open-mmlab/mmsegmentation/blob/master/projects/dest/models/smit.py) [mmseg/models/decode_heads/dest_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/projects/dest/models/dest_head.py) | [NVIDIA License](https://github.com/NVIDIA/DL4AGX/blob/master/DEST/LICENSE) |
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ Supported methods:
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
- [x] [K-Net (NeurIPS'2021)](configs/knet)
- [x] [DEST (CVPRW'2022)](projects/dest)

Supported datasets:

Expand Down
99 changes: 99 additions & 0 deletions projects/dest/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# DEST

[DEST: Depth Estimation with Simplified Transformer](https://arxiv.org/abs/2204.13791)

## Description

Transformer and its variants have shown state-of-the-art results in many vision tasks recently, ranging from image classification to dense prediction. Despite of their success, limited work has been reported on improving the model efficiency for deployment in latency-critical applications, such as autonomous driving and robotic navigation. In this paper, we aim at improving upon the existing transformers in vision, and propose a method for Dense Estimation with Simplified Transformer (DEST), which is efficient and particularly suitable for deployment on GPU-based platforms. Through strategic design choices, our model leads to significant reduction in model size, complexity, as well as inference latency, while achieving superior accuracy as compared to state-of-the-art in the task of self-supervised monocular depth estimation. We also show that our design generalize well to other dense prediction task such as semantic segmentation without bells and whistles.

## Usage

### Prerequisites

- Python 3.8.12
- PyTorch 1.11
- mmcv v1.7.0
- Install [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) from source

All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the mmsegmentaions directory so that Python can locate the configuration files in mmsegmentation.

### Dataset preparing

Preparing `cityscapes` dataset following this [Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#prepare-datasets)

### Training commands

```shell
mim train mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest
```

To train on multiple GPUs, e.g. 8 GPUs, run the following command:

```shell
mim train mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest --launcher pytorch --gpus 8
```

### Testing commands

```shell
mim test mmsegmentation projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py --work-dir work_dirs/dest --checkpoint ${CHECKPOINT_PATH} --eval mIoU
```

## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------: | -------: | -------------- | ----: | ------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| DEST | SMIT-B0 | 1024x1024 | 160000 | - | - | 64.34 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b0_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b0_1024x1024_160k_cityscapes_20230105_232025-11f73f34.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b0_1024x1024_160k_cityscapes_20230105_232025.log) |
| DEST | SMIT-B1 | 1024x1024 | 160000 | - | - | 68.21 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358-0dd4e86e.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358.logmmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358.log) |
| DEST | SMIT-B2 | 1024x1024 | 160000 | - | - | 71.89 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b2_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b2_1024x1024_160k_cityscapes_20230105_231943-b06319ae.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b2_1024x1024_160k_cityscapes_20230105_231943.log) |
| DEST | SMIT-B3 | 1024x1024 | 160000 | - | - | 73.51 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b3_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b3_1024x1024_160k_cityscapes_20230105_231800-ee4cec5c.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b3_1024x1024_160k_cityscapes_20230105_231800.log) |
| DEST | SMIT-B4 | 1024x1024 | 160000 | - | - | 73.99 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b4_1024x1024_160k_cityscapes_20230105_232155-3ca9f4fc.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b4_1024x1024_160k_cityscapes_20230105_232155.log) |
| DEST | SMIT-B5 | 1024x1024 | 160000 | - | - | 75.28 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b5_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b5_1024x1024_160k_cityscapes_20230105_231411-e83819b5.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b5_1024x1024_160k_cityscapes_20230105_231411.log) |

Note:

- The above models are all training from scratch without pretrained backbones. Accuracy can be further enhanced by appropriate pretraining.
- Training of DEST is not very stable, which is sensitive to random seeds.

## Citation

```bibtex
@article{YangDEST,
title={Depth Estimation with Simplified Transformer},
author={Yang, John and An, Le and Dixit, Anurag and Koo, Jinkyu and Park, Su Inn},
journal={arXiv preprint arXiv:2204.13791},
year={2022}
}
```

## Checklist

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

- [x] Finish the code

- [x] Basic docstrings & proper citation

- [x] Test-time correctness

- [x] A full README

- [x] Milestone 2: Indicates a successful model implementation.

- [x] Training-time correctness

- [ ] Milestone 3: Good to be a part of our core package!

- [ ] Type hints and docstrings

- [ ] Unit tests

- [ ] Code polishing

- [ ] Metafile.yml

- [ ] Move your modules into the core package following the codebase's file hierarchy structure.

- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
50 changes: 50 additions & 0 deletions projects/dest/configs/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# DEST

[DEST: Depth Estimation with Simplified Transformer](https://arxiv.org/abs/2204.13791)

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/NVIDIA/DL4AGX/tree/master/DEST">Official Repo</a>

## Abstract

<!-- [ABSTRACT] -->

Transformer and its variants have shown state-of-the-art results in many vision tasks recently, ranging from image classification to dense prediction. Despite of their success, limited work has been reported on improving the model efficiency for deployment in latency-critical applications, such as autonomous driving and robotic navigation. In this paper, we aim at improving upon the existing transformers in vision, and propose a method for Dense Estimation with Simplified Transformer (DEST), which is efficient and particularly suitable for deployment on GPU-based platforms. Through strategic design choices, our model leads to significant reduction in model size, complexity, as well as inference latency, while achieving superior accuracy as compared to state-of-the-art in the task of self-supervised monocular depth estimation. We also show that our design generalize well to other dense prediction task such as semantic segmentation without bells and whistles.

<!-- [IMAGE] -->

<div align=center>
<img src="https://user-images.githubusercontent.com/76149310/219313665-49fa89ed-4973-4496-bb33-3256f107e82d.png" width="70%"/>
</div>

## Citation

```bibtex
@article{YangDEST,
title={Depth Estimation with Simplified Transformer},
author={Yang, John and An, Le and Dixit, Anurag and Koo, Jinkyu and Park, Su Inn},
journal={arXiv preprint arXiv:2204.13791},
year={2022}
}
```

## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------: | -------: | -------------- | ----: | ------------- | ---------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
| DEST | SMIT-B0 | 1024x1024 | 160000 | - | - | 64.34 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b0_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b0_1024x1024_160k_cityscapes_20230105_232025-11f73f34.pth) |
| DEST | SMIT-B1 | 1024x1024 | 160000 | - | - | 68.21 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b1_1024x1024_160k_cityscapes_20230105_232358-0dd4e86e.pth) |
| DEST | SMIT-B2 | 1024x1024 | 160000 | - | - | 71.89 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b2_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b2_1024x1024_160k_cityscapes_20230105_231943-b06319ae.pth) |
| DEST | SMIT-B3 | 1024x1024 | 160000 | - | - | 73.51 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b3_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b3_1024x1024_160k_cityscapes_20230105_231800-ee4cec5c.pth) |
| DEST | SMIT-B4 | 1024x1024 | 160000 | - | - | 73.99 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b4_1024x1024_160k_cityscapes_20230105_232155-3ca9f4fc.pth) |
| DEST | SMIT-B5 | 1024x1024 | 160000 | - | - | 75.28 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dest/dest_simpatt-b5_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dest/dest_simpatt-b5_1024x1024_160k_cityscapes_20230105_231411-e83819b5.pth) |

Note:

- The above models are all training from scratch without pretrained backbones. Accuracy can be further enhanced by appropriate pretraining.
- Training of DEST is not very stable, which is sensitive to random seeds.
37 changes: 37 additions & 0 deletions projects/dest/configs/dest_simpatt-b0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# model settings
embed_dims = [32, 64, 160, 256]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='SimplifiedMixTransformer',
in_channels=3,
embed_dims=embed_dims,
num_stages=4,
num_layers=[2, 2, 2, 2],
num_heads=[1, 2, 5, 8],
patch_sizes=[7, 3, 3, 3],
strides=[4, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
out_indices=(0, 1, 2, 3),
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_cfg=norm_cfg),
decode_head=dict(
type='DESTHead',
in_channels=[32, 64, 160, 256],
in_index=[0, 1, 2, 3],
channels=32,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))
33 changes: 33 additions & 0 deletions projects/dest/configs/dest_simpatt-b0_1024x1024_160k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
_base_ = [
'./dest_simpatt-b0.py',
'../../../configs/_base_/datasets/cityscapes_1024x1024.py',
'../../../configs/_base_/default_runtime.py',
'../../../configs/_base_/schedules/schedule_160k.py'
]

custom_imports = dict(imports=['projects.dest.models'])

optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=10.)
}))

lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)

data = dict(samples_per_gpu=1, workers_per_gpu=1)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']

embed_dims = [64, 128, 250, 320]

model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(embed_dims=embed_dims),
decode_head=dict(in_channels=embed_dims, channels=64))
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']

embed_dims = [64, 128, 250, 320]

model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(embed_dims=embed_dims, num_layers=[3, 3, 6, 3]),
decode_head=dict(in_channels=embed_dims, channels=64))
22 changes: 22 additions & 0 deletions projects/dest/configs/dest_simpatt-b3_1024x1024_160k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']

embed_dims = [64, 128, 250, 320]

optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=1.)
}))

model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(embed_dims=embed_dims, num_layers=[3, 6, 8, 3]),
decode_head=dict(in_channels=embed_dims, channels=64))
22 changes: 22 additions & 0 deletions projects/dest/configs/dest_simpatt-b4_1024x1024_160k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']

embed_dims = [64, 128, 250, 320]

optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=1.)
}))

model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(embed_dims=embed_dims, num_layers=[3, 8, 12, 5]),
decode_head=dict(in_channels=embed_dims, channels=64))
22 changes: 22 additions & 0 deletions projects/dest/configs/dest_simpatt-b5_1024x1024_160k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
_base_ = ['./dest_simpatt-b0_1024x1024_160k_cityscapes.py']

embed_dims = [64, 128, 250, 320]

optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=1.)
}))

model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(embed_dims=embed_dims, num_layers=[3, 10, 16, 5]),
decode_head=dict(in_channels=embed_dims, channels=64))
5 changes: 5 additions & 0 deletions projects/dest/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dest_head import DESTHead
from .smit import SimplifiedMixTransformer

__all__ = ['SimplifiedMixTransformer', 'DESTHead']
Loading