Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] add DPT head #605

Merged
merged 48 commits into from Aug 30, 2021
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
46dce78
add DPT head
Jun 17, 2021
7b80fd0
[fix] fix init error
Jun 17, 2021
01b3da2
use mmcv function
Jun 18, 2021
e9df435
delete code
Jun 19, 2021
93635c0
merge upstream
Jun 19, 2021
b21ea15
remove transpose clas
Jun 19, 2021
2efb2eb
support NLC output shape
Jun 19, 2021
685644a
Merge branch 'add_vit_output_type' into dpt
Jun 19, 2021
5f877e1
Delete post_process_layer.py
Jun 22, 2021
5ce02d3
add unittest and docstring
Jun 22, 2021
7f7e4a4
Merge branch 'dpt' of https://github.com/xiexinch/mmsegmentation into…
Jun 22, 2021
de5b3a2
merge conflict
Jun 22, 2021
adbfb60
merge upstream master
Jul 5, 2021
31c42bd
rename variables
Jul 5, 2021
bf900b6
fix project error and add unittest
Jul 5, 2021
716863b
match dpt weights
Jul 6, 2021
94bf935
add configs
Jul 6, 2021
d4cd924
fix vit pos_embed bug and dpt feature fusion bug
Jul 7, 2021
ded2834
merge master
Jul 20, 2021
f147aa9
match vit output
Jul 20, 2021
0e4fb4f
fix gelu
Jul 20, 2021
6073dfa
minor change
Jul 20, 2021
1ebb558
update unitest
Jul 20, 2021
b3903ca
fix configs error
Jul 20, 2021
ef87aa5
inference test
Jul 22, 2021
9669d54
remove auxilary
Jul 22, 2021
0363746
use local pretrain
Jul 29, 2021
e1ecf6a
update training results
Aug 11, 2021
0126c24
Merge branch 'master' of https://github.com/open-mmlab/mmsegmentation…
Aug 11, 2021
7726d2b
update yml
Aug 11, 2021
c5593af
update fps and memory test
Aug 12, 2021
30aabc4
update doc
Aug 19, 2021
64e6f64
update readme
Aug 19, 2021
b749507
merge master
Aug 19, 2021
96ce175
add yml
Aug 19, 2021
fa61339
update doc
Aug 19, 2021
55bcd74
remove with_cp
Aug 19, 2021
4b33f6f
update config
Aug 19, 2021
76344cd
update docstring
Aug 19, 2021
94fb8d4
remove dpt-l
Aug 25, 2021
5e56d1b
add init_cfg and modify readme.md
Aug 25, 2021
f4ad2fa
Update dpt_vit-b16.py
Junjun2016 Aug 25, 2021
161d494
zh-n README
Aug 25, 2021
6b506ba
Merge branch 'dpt' of github.com:xiexinch/mmsegmentation into dpt
Aug 25, 2021
dca6387
solve conflict
Aug 30, 2021
a41ce05
use constructor instead of build function
Aug 30, 2021
78b56b1
prevent tensor being modified by ConvModule
Aug 30, 2021
522cdff
fix unittest
Aug 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 32 additions & 0 deletions configs/_base_/models/dpt_vit-b16.py
@@ -0,0 +1,32 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='pretrain/vit-b16_p16_224.pth', # noqa
backbone=dict(
type='VisionTransformer',
img_size=224,
embed_dims=768,
num_layers=12,
num_heads=12,
out_indices=(2, 5, 8, 11),
final_norm=False,
with_cls_token=True,
with_cp=True,
output_cls_token=True),
decode_head=dict(
type='DPTHead',
in_channels=(768, 768, 768, 768),
channels=256,
embed_dims=768,
post_process_channels=[96, 192, 384, 768],
num_classes=150,
readout_type='project',
input_transform='multiple_select',
in_index=(0, 1, 2, 3),
norm_cfg=norm_cfg,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=None,
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole')) # yapf: disable
40 changes: 40 additions & 0 deletions configs/dpt/README.md
@@ -0,0 +1,40 @@
# Vision Transformer for Dense Prediction

## Introduction

<!-- [ALGORITHM] -->

```latex
@article{dosoViTskiy2020,
title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
author={DosoViTskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
journal={arXiv preprint arXiv:2010.11929},
year={2020}
}

@article{Ranftl2021,
author = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun},
title = {Vision Transformers for Dense Prediction},
journal = {ArXiv preprint},
year = {2021},
}
```

## How to use ViT pretrain weights
xiexinch marked this conversation as resolved.
Show resolved Hide resolved

We convert the backbone weights from the pytorch-image-models repo (https://github.com/rwightman/pytorch-image-models) with `tools/model_converters/vit_convert.py`.
xiexinch marked this conversation as resolved.
Show resolved Hide resolved

You may follow below steps to start segformer training preparation:
xiexinch marked this conversation as resolved.
Show resolved Hide resolved

1. Download segformer pretrain weights (Suggest put in `pretrain/`);
xiexinch marked this conversation as resolved.
Show resolved Hide resolved
2. Run convert script to convert official pretrain weights: `python tools/model_converters/vit_convert.py pretrain/vit_timm.pth pretrain/vit-b16__p16_224.pth`;
xiexinch marked this conversation as resolved.
Show resolved Hide resolved
3. Modify `pretrained` of VisionTransformer model config, for example, `pretrained` of `dpt_vit-b16.py` is set to `pretrain/vit-b16_p16_224.pth`;

## Results and models

### ADE20K

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | ---------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| DPT | ViT-B | 512x512 | 160000 | 8.09 | 10.41 | 46.97 | 48.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-b16_512x512_160k_ade20k/dpt_vit-b16_512x512_160k_ade20k-db31cf52.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-b16_512x512_160k_ade20k/dpt_vit-b16_512x512_160k_ade20k-20210809_172025.log.json) |
| DPT | ViT-L | 512x512 | 160000 | 18.37 | 4.36 | 46.19 | 46.97 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dpt/dpt_vit-l16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-l16_512x512_160k_ade20k/dpt_vit-l16_512x512_160k_ade20k-7b753ca6.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-l16_512x512_160k_ade20k/dpt_vit-l16_512x512_160k_ade20k-20210809_172025.log.json) |
50 changes: 50 additions & 0 deletions configs/dpt/dpt.yml
@@ -0,0 +1,50 @@
Collections:
- Metadata:
Training Data:
- ADE20K
Name: dpt
Models:
- Config: configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py
In Collection: dpt
Metadata:
backbone: ViT-B
crop size: (512,512)
inference time (ms/im):
- backend: PyTorch
batch size: 1
hardware: V100
mode: FP32
resolution: (512,512)
value: 96.06
lr schd: 160000
memory (GB): 8.09
Name: dpt_vit-b16_512x512_160k_ade20k
Results:
Dataset: ADE20K
Metrics:
mIoU: 46.97
mIoU(ms+flip): 48.34
Task: Semantic Segmentation
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-b16_512x512_160k_ade20k/dpt_vit-b16_512x512_160k_ade20k-db31cf52.pth
- Config: configs/dpt/dpt_vit-l16_512x512_160k_ade20k.py
In Collection: dpt
Metadata:
backbone: ViT-L
crop size: (512,512)
inference time (ms/im):
- backend: PyTorch
batch size: 1
hardware: V100
mode: FP32
resolution: (512,512)
value: 229.36
lr schd: 160000
memory (GB): 18.37
Name: dpt_vit-l16_512x512_160k_ade20k
Results:
Dataset: ADE20K
Metrics:
mIoU: 46.19
mIoU(ms+flip): 46.97
Task: Semantic Segmentation
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/dpt/dpt_vit-l16_512x512_160k_ade20k/dpt_vit-l16_512x512_160k_ade20k-7b753ca6.pth
32 changes: 32 additions & 0 deletions configs/dpt/dpt_vit-b16_512x512_160k_ade20k.py
@@ -0,0 +1,32 @@
_base_ = [
'../_base_/models/dpt_vit-b16.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.00006,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'pos_embed': dict(decay_mult=0.),
'cls_token': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))

lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(workers_per_gpu=2)
xiexinch marked this conversation as resolved.
Show resolved Hide resolved
24 changes: 24 additions & 0 deletions configs/dpt/dpt_vit-l16_512x512_160k_ade20k.py
@@ -0,0 +1,24 @@
_base_ = './dpt_vit-b16_512x512_160k_ade20k.py'

model = dict(
type='EncoderDecoder',
xiexinch marked this conversation as resolved.
Show resolved Hide resolved
pretrained='pretrain/vit-l16_p16_384.pth', # noqa
backbone=dict(
type='VisionTransformer',
xiexinch marked this conversation as resolved.
Show resolved Hide resolved
img_size=384,
embed_dims=1024,
num_heads=16,
num_layers=24,
out_indices=(5, 11, 17, 23),
final_norm=False,
with_cls_token=True,
xiexinch marked this conversation as resolved.
Show resolved Hide resolved
output_cls_token=True),
xiexinch marked this conversation as resolved.
Show resolved Hide resolved
decode_head=dict(
type='DPTHead',
xiexinch marked this conversation as resolved.
Show resolved Hide resolved
in_channels=(1024, 1024, 1024, 1024),
channels=256,
xiexinch marked this conversation as resolved.
Show resolved Hide resolved
embed_dims=1024,
post_process_channels=[256, 512, 1024, 1024]),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole')) # yapf: disable
xiexinch marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/__init__.py
Expand Up @@ -6,6 +6,7 @@
from .da_head import DAHead
from .dm_head import DMHead
from .dnl_head import DNLHead
from .dpt_head import DPTHead
from .ema_head import EMAHead
from .enc_head import EncHead
from .fcn_head import FCNHead
Expand All @@ -29,5 +30,5 @@
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
'SETRMLAHead', 'SegformerHead'
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead'
]