Skip to content

Commit

Permalink
[Feature] Support DDRNet (open-mmlab#2855)
Browse files Browse the repository at this point in the history
Thanks for your contribution and we appreciate it a lot. The following
instructions would make your pull request more healthy and more easily
get feedback. If you do not understand some items, don't worry, just
make the pull request and seek help from maintainers.

## Motivation

Support DDRNet
Paper: [Deep Dual-resolution Networks for Real-time and Accurate
Semantic Segmentation of Road Scenes](https://arxiv.org/pdf/2101.06085)
official Code: https://github.com/ydhongHIT/DDRNet


There is already a PR
open-mmlab#1722 , but it has been
inactive for a long time.

## Current Result

### Cityscapes

#### inference with converted official weights

| Method | Backbone      | mIoU(official) | mIoU(converted weight) |
| ------ | ------------- | -------------- | ---------------------- |
| DDRNet | DDRNet23-slim | 77.8           | 77.84                  |
| DDRNet | DDRNet23 | 79.5 | 79.53 |

#### training with converted pretrained backbone

| Method | Backbone | Crop Size | Lr schd | Inf time(fps) | Device |
mIoU | mIoU(ms+flip) | config | download |
| ------ | ------------- | --------- | ------- | ------- | -------- |
----- | ------------- | ------------ | ------------ |
| DDRNet | DDRNet23-slim | 1024x1024 | 120000 | 85.85 | RTX 8000 | 77.85
| 79.80 |
[config](https://github.com/whu-pzhang/mmsegmentation/blob/ddrnet/configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py)
| model \| log |
| DDRNet | DDRNet23 | 1024x1024 | 120000 | 33.41 | RTX 8000 | 79.53 |
80.98 |
[config](https://github.com/whu-pzhang/mmsegmentation/blob/ddrnet/configs/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py)
| model \| log |


The converted pretrained backbone weights download link:

1.
[ddrnet23s_in1k_mmseg.pth](https://drive.google.com/file/d/1Ni4F1PMGGjuld-1S9fzDTmneLfpMuPTG/view?usp=sharing)
2.
[ddrnet23_in1k_mmseg.pth](https://drive.google.com/file/d/11rsijC1xOWB6B0LgNQkAG-W6e1OdbCyJ/view?usp=sharing)

## To do

- [x] support inference with converted official weights
- [x] support training on cityscapes dataset

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
  • Loading branch information
whu-pzhang and xiexinch committed Apr 27, 2023
1 parent 820c555 commit 3bc5cc4
Show file tree
Hide file tree
Showing 9 changed files with 595 additions and 2 deletions.
47 changes: 47 additions & 0 deletions configs/ddrnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# DDRNet

> [Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes](http://arxiv.org/abs/2101.06085)
## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/ydhongHIT/DDRNet">Official Repo</a>

## Abstract

<!-- [ABSTRACT] -->

Semantic segmentation is a key technology for autonomous vehicles to understand the surrounding scenes. The appealing performances of contemporary models usually come at the expense of heavy computations and lengthy inference time, which is intolerable for self-driving. Using light-weight architectures (encoder-decoder or two-pathway) or reasoning on low-resolution images, recent methods realize very fast scene parsing, even running at more than 100 FPS on a single 1080Ti GPU. However, there is still a significant gap in performance between these real-time methods and the models based on dilation backbones. To tackle this problem, we proposed a family of efficient backbones specially designed for real-time semantic segmentation. The proposed deep dual-resolution networks (DDRNets) are composed of two deep branches between which multiple bilateral fusions are performed. Additionally, we design a new contextual information extractor named Deep Aggregation Pyramid Pooling Module (DAPPM) to enlarge effective receptive fields and fuse multi-scale context based on low-resolution feature maps. Our method achieves a new state-of-the-art trade-off between accuracy and speed on both Cityscapes and CamVid dataset. In particular, on a single 2080Ti GPU, DDRNet-23-slim yields 77.4% mIoU at 102 FPS on Cityscapes test set and 74.7% mIoU at 230 FPS on CamVid test set. With widely used test augmentation, our method is superior to most state-of-the-art models and requires much less computation. Codes and trained models are available online.

<!-- [IMAGE] -->

<div align=center>
<img src="https://raw.githubusercontent.com/ydhongHIT/DDRNet/main/figs/DDRNet_seg.png" width="80%"/>
</div>

## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem(GB) | Inf time(fps) | Device | mIoU | mIoU(ms+flip) | config | download |
| ------ | ------------- | --------- | ------- | ------- | ------------- | ------ | ----- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| DDRNet | DDRNet23-slim | 1024x1024 | 120000 | 1.70 | 85.85 | A100 | 77.84 | 80.15 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230426_145312-6a5e5174.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230426_145312.json) |
| DDRNet | DDRNet23 | 1024x1024 | 120000 | 7.26 | 33.41 | A100 | 79.99 | 81.71 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230425_162633-81601db0.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230425_162633.json) |

## Notes

The pretrained weights in config files are converted from [the official repo](https://github.com/ydhongHIT/DDRNet#pretrained-models).

## Citation

```bibtex
@misc{hong2021ddrnet,
title={Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes},
author={Hong, Yuanduo and Pan, Huihui and Sun, Weichao and Jia, Yisong},
year={2021},
eprint={2101.06085},
archivePrefix={arXiv},
primaryClass={cs.CV},
}
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
_base_ = [
'../_base_/datasets/cityscapes_1024x1024.py',
'../_base_/default_runtime.py',
]

# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
# Licensed under the MIT License
class_weight = [
0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
1.0507
]

crop_size = (1024, 1024)
data_preprocessor = dict(
type='SegDataPreProcessor',
size=crop_size,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
backbone=dict(
type='DDRNet',
in_channels=3,
channels=32,
ppm_channels=128,
norm_cfg=norm_cfg,
align_corners=False,
init_cfg=dict(
type='Pretrained',
checkpoint='pretrained/ddrnet23s_in1k_mmseg.pth')),
decode_head=dict(
type='DDRHead',
in_channels=32 * 4,
channels=64,
dropout_ratio=0.,
num_classes=19,
align_corners=False,
norm_cfg=norm_cfg,
loss_decode=[
dict(
type='OhemCrossEntropy',
thres=0.9,
min_kept=131072,
class_weight=class_weight,
loss_weight=1.0),
dict(
type='OhemCrossEntropy',
thres=0.9,
min_kept=131072,
class_weight=class_weight,
loss_weight=0.4),
]),

# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

train_dataloader = dict(batch_size=6, num_workers=4)

iters = 120000
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=0,
power=0.9,
begin=0,
end=iters,
by_epoch=False)
]

# training schedule for 120k
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(
type='CheckpointHook', by_epoch=False, interval=iters // 10),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))

randomness = dict(seed=304)
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
_base_ = [
'../_base_/datasets/cityscapes_1024x1024.py',
'../_base_/default_runtime.py',
]

# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
# Licensed under the MIT License
class_weight = [
0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
1.0507
]

crop_size = (1024, 1024)
data_preprocessor = dict(
type='SegDataPreProcessor',
size=crop_size,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
backbone=dict(
type='DDRNet',
in_channels=3,
channels=64,
ppm_channels=128,
norm_cfg=norm_cfg,
align_corners=False,
init_cfg=dict(
type='Pretrained',
checkpoint='pretrained/ddrnet23_in1k_mmseg.pth')),
decode_head=dict(
type='DDRHead',
in_channels=64 * 4,
channels=128,
dropout_ratio=0.,
num_classes=19,
align_corners=False,
norm_cfg=norm_cfg,
loss_decode=[
dict(
type='OhemCrossEntropy',
thres=0.9,
min_kept=131072,
class_weight=class_weight,
loss_weight=1.0),
dict(
type='OhemCrossEntropy',
thres=0.9,
min_kept=131072,
class_weight=class_weight,
loss_weight=0.4),
]),

# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

train_dataloader = dict(batch_size=6, num_workers=4)

iters = 120000
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=0,
power=0.9,
begin=0,
end=iters,
by_epoch=False)
]

# training schedule for 120k
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(
type='CheckpointHook', by_epoch=False, interval=iters // 10),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))

randomness = dict(seed=304)
14 changes: 14 additions & 0 deletions configs/ddrnet/metafile.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
Collections:
- Name: ''
License: Apache License 2.0
Metadata:
Training Data:
- Cityscapes
Paper:
Title: Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation
of Road Scenes
URL: http://arxiv.org/abs/2101.06085
README: configs/ddrnet/README.md
Frameworks:
- PyTorch
Models: []
4 changes: 3 additions & 1 deletion mmseg/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .bisenetv1 import BiSeNetV1
from .bisenetv2 import BiSeNetV2
from .cgnet import CGNet
from .ddrnet import DDRNet
from .erfnet import ERFNet
from .fast_scnn import FastSCNN
from .hrnet import HRNet
Expand All @@ -28,5 +29,6 @@
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN'
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
'DDRNet'
]
Loading

0 comments on commit 3bc5cc4

Please sign in to comment.