[Feature] add DPT head (#605)
* add DPT head

* [fix] fix init error

* use mmcv function

* delete code

* remove transpose class

* support NLC output shape

* Delete post_process_layer.py

* add unittest and docstring

* rename variables

* fix project error and add unittest

* match dpt weights

* add configs

* fix vit pos_embed bug and dpt feature fusion bug

* match vit output

* fix gelu

* minor change

* update unittest

* fix configs error

* inference test

* remove auxiliary

* use local pretrain

* update training results

* update yml

* update fps and memory test

* update doc

* update readme

* add yml

* update doc

* remove with_cp

* update config

* update docstring

* remove dpt-l

* add init_cfg and modify readme.md

* Update dpt_vit-b16.py

* zh-cn README

* use constructor instead of build function

* prevent tensor being modified by ConvModule

* fix unittest

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
谢昕辰 and Junjun2016 committed Aug 30, 2021
1 parent 5753f41 commit 2825efe
Showing 8 changed files with 482 additions and 1 deletion.
31 changes: 31 additions & 0 deletions configs/_base_/models/dpt_vit-b16.py
@@ -0,0 +1,31 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='pretrain/vit-b16_p16_224-80ecf9dd.pth',  # noqa
    backbone=dict(
        type='VisionTransformer',
        img_size=224,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        out_indices=(2, 5, 8, 11),
        final_norm=False,
        with_cls_token=True,
        output_cls_token=True),
    decode_head=dict(
        type='DPTHead',
        in_channels=(768, 768, 768, 768),
        channels=256,
        embed_dims=768,
        post_process_channels=[96, 192, 384, 768],
        num_classes=150,
        readout_type='project',
        input_transform='multiple_select',
        in_index=(0, 1, 2, 3),
        norm_cfg=norm_cfg,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=None,
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))  # yapf: disable
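
As a quick sanity check (not part of the commit), a config like this can be turned into a model with the mmseg 0.x builder; a minimal sketch, assuming the mmcv/mmseg APIs of that release:

```python
import mmcv
from mmseg.models import build_segmentor

# Load the base config added above; in this config style train_cfg and
# test_cfg already live inside cfg.model, so no extra arguments are needed.
cfg = mmcv.Config.fromfile('configs/_base_/models/dpt_vit-b16.py')
model = build_segmentor(cfg.model)
# model.init_weights() would additionally try to load the local
# 'pretrain/vit-b16_p16_224-80ecf9dd.pth' checkpoint referenced above.
```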
47 changes: 47 additions & 0 deletions configs/dpt/README.md
@@ -0,0 +1,47 @@
# Vision Transformers for Dense Prediction

## Introduction

<!-- [ALGORITHM] -->

```latex
@article{dosovitskiy2020,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}
@article{ranftl2021,
  author={Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun},
  title={Vision Transformers for Dense Prediction},
  journal={arXiv preprint arXiv:2103.13413},
  year={2021}
}
```

## Usage

To use other repositories' pre-trained models, it is necessary to convert keys.

We provide a script [`vit2mmseg.py`](../../tools/model_converters/vit2mmseg.py) in the tools directory to convert the key of models from [timm](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to MMSegmentation style.

```shell
python tools/model_converters/vit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
```

For example:

```shell
python tools/model_converters/vit2mmseg.py https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth pretrain/jx_vit_base_p16_224-80ecf9dd.pth
```

This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
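
Conceptually the converter is a key-renaming pass over the checkpoint's state dict. The sketch below is illustrative only: the real rename rules live in `vit2mmseg.py`, and the single `blocks.` → `layers.` mapping shown here is an assumed example, not the script's full logic.

```python
import torch

def convert_keys(src_state_dict):
    """Rename timm-style ViT keys to mmseg-style ones (illustrative only)."""
    dst = {}
    for k, v in src_state_dict.items():
        # e.g. timm names transformer blocks 'blocks.0. ...', while
        # mmseg's VisionTransformer expects 'layers.0. ...'
        dst[k.replace('blocks.', 'layers.')] = v
    return dst

ckpt = torch.load('pretrain/jx_vit_base_p16_224-80ecf9dd.pth', map_location='cpu')
state = ckpt.get('state_dict', ckpt)  # some checkpoints wrap their weights
torch.save(convert_keys(state), 'pretrain/vit-b16_p16_224-80ecf9dd.pth')
```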

## Results and models

### ADE20K

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | ---------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| DPT | ViT-B | 512x512 | 160000 | 8.09 | 10.41 | 46.97 | 48.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-b16_512x512_160k_ade20k/dpt_vit-b16_512x512_160k_ade20k-db31cf52.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-b16_512x512_160k_ade20k/dpt_vit-b16_512x512_160k_ade20k-20210809_172025.log.json) |
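
The released checkpoint can be exercised with the mmseg 0.x high-level API; a hedged example (the image path is a placeholder):

```python
from mmseg.apis import inference_segmentor, init_segmentor

config_file = 'configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py'
checkpoint_file = 'dpt_vit-b16_512x512_160k_ade20k-db31cf52.pth'

model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
result = inference_segmentor(model, 'demo.png')  # list with one H x W label map
```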
28 changes: 28 additions & 0 deletions configs/dpt/dpt.yml
@@ -0,0 +1,28 @@
Collections:
- Metadata:
    Training Data:
    - ADE20K
  Name: dpt
Models:
- Config: configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py
  In Collection: dpt
  Metadata:
    backbone: ViT-B
    crop size: (512,512)
    inference time (ms/im):
    - backend: PyTorch
      batch size: 1
      hardware: V100
      mode: FP32
      resolution: (512,512)
      value: 96.06
    lr schd: 160000
    memory (GB): 8.09
  Name: dpt_vit-b16_512x512_160k_ade20k
  Results:
    Dataset: ADE20K
    Metrics:
      mIoU: 46.97
      mIoU(ms+flip): 48.34
    Task: Semantic Segmentation
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-b16_512x512_160k_ade20k/dpt_vit-b16_512x512_160k_ade20k-db31cf52.pth
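
The two speed figures reported in this commit are consistent with each other, assuming the README's fps column is simply the reciprocal of this latency:

```python
ms_per_im = 96.06          # 'value' from dpt.yml, in ms per image
fps = 1000.0 / ms_per_im   # 10.41..., matching 'Inf time (fps)' in the README
```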
32 changes: 32 additions & 0 deletions configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py
@@ -0,0 +1,32 @@
_base_ = [
    '../_base_/models/dpt_vit-b16.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2, workers_per_gpu=2)
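
For readers unfamiliar with these mmcv hooks, the two dicts above define the learning-rate curve sketched below; this mirrors mmcv's poly policy and linear-warmup convention as commonly described, so treat it as an approximation rather than mmcv's actual code:

```python
def lr_at(it, base_lr=6e-5, max_iters=160000, power=1.0,
          min_lr=0.0, warmup_iters=1500, warmup_ratio=1e-6):
    # poly decay: (base_lr - min_lr) * (1 - it / max_iters) ** power + min_lr
    regular = (base_lr - min_lr) * (1 - it / max_iters) ** power + min_lr
    if it < warmup_iters:
        # linear warmup scales the regular lr up from base_lr * warmup_ratio
        k = (1 - it / warmup_iters) * (1 - warmup_ratio)
        return regular * (1 - k)
    return regular

assert abs(lr_at(0) - 6e-5 * 1e-6) < 1e-15  # warmup start
assert lr_at(160000) == 0.0                 # fully decayed to min_lr
```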
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/__init__.py
@@ -6,6 +6,7 @@
from .da_head import DAHead
from .dm_head import DMHead
from .dnl_head import DNLHead
from .dpt_head import DPTHead
from .ema_head import EMAHead
from .enc_head import EncHead
from .fcn_head import FCNHead
@@ -29,5 +30,5 @@
    'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
    'SETRMLAHead', 'DPTHead', 'SegformerHead'
]
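
With the export in place, the new head can also be constructed directly; a usage sketch with the argument values copied from `dpt_vit-b16.py` above (plain `BN` is substituted for `SyncBN` so it runs outside distributed training):

```python
from mmseg.models.decode_heads import DPTHead

head = DPTHead(
    in_channels=(768, 768, 768, 768),
    channels=256,
    embed_dims=768,
    post_process_channels=[96, 192, 384, 768],
    num_classes=150,
    readout_type='project',
    input_transform='multiple_select',
    in_index=(0, 1, 2, 3),
    norm_cfg=dict(type='BN', requires_grad=True))
```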
