Skip to content

Commit

Permalink
[Feature] Add DBNet++ (#973)
Browse files Browse the repository at this point in the history
* add dbnet++

* fix docstring

* fix some param names

* fix

* fix docstring

* add init

* add doc; remove configs

* add dbnet++ to readme

* fix readme

* update config

* update readme

* update readme

* update ocr.py

* update metafile.yml

* update readme

* update readme

* move to dbnetpp

* fix paths

* fix head level

* fix typo

* update demo.md

* Update configs/textdet/dbnetpp/README.md

Co-authored-by: Tong Gao <gaotongxiao@gmail.com>

* fix typo

* fix link
  • Loading branch information
xinke-wang committed May 5, 2022
1 parent b4678eb commit fbc138d
Show file tree
Hide file tree
Showing 14 changed files with 320 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Supported algorithms:
<details open>
<summary>Text Detection</summary>

- [x] [DBNet](configs/textdet/dbnet/README.md) (AAAI'2020)
- [x] [DBNet](configs/textdet/dbnet/README.md) (AAAI'2020) / [DBNet++](configs/textdet/dbnetpp/README.md) (TPAMI'2022)
- [x] [Mask R-CNN](configs/textdet/maskrcnn/README.md) (ICCV'2017)
- [x] [PANet](configs/textdet/panet/README.md) (ICCV'2019)
- [x] [PSENet](configs/textdet/psenet/README.md) (CVPR'2019)
Expand Down
2 changes: 1 addition & 1 deletion README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ MMOCR 是基于 PyTorch 和 mmdetection 的开源工具箱,专注于文本检
<details open>
<summary>文字检测</summary>

- [x] [DBNet](configs/textdet/dbnet/README.md) (AAAI'2020)
- [x] [DBNet](configs/textdet/dbnet/README.md) (AAAI'2020) / [DBNet++](configs/textdet/dbnetpp/README.md) (TPAMI'2022)
- [x] [Mask R-CNN](configs/textdet/maskrcnn/README.md) (ICCV'2017)
- [x] [PANet](configs/textdet/panet/README.md) (ICCV'2019)
- [x] [PSENet](configs/textdet/psenet/README.md) (CVPR'2019)
Expand Down
28 changes: 28 additions & 0 deletions configs/_base_/det_models/dbnetpp_r50dcnv2_fpnc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
model = dict(
type='DBNet',
backbone=dict(
type='mmdet.ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=False,
style='pytorch',
dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
stage_with_dcn=(False, True, True, True)),
neck=dict(
type='FPNC',
in_channels=[256, 512, 1024, 2048],
lateral_channels=256,
asf_cfg=dict(attention_type='ScaleChannelSpatial')),
bbox_head=dict(
type='DBHead',
in_channels=256,
loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True),
postprocessor=dict(
type='DBPostprocessor', text_repr_type='quad',
epsilon_ratio=0.002)),
train_cfg=None,
test_cfg=None)
2 changes: 1 addition & 1 deletion configs/textdet/dbnet/metafile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Collections:
Training Techniques:
- SGD with Momentum
- Weight Decay
Training Resources: 8x GeForce GTX 1080 Ti
Training Resources: 1x GeForce GTX 1080 Ti
Architecture:
- ResNet
- FPNC
Expand Down
33 changes: 33 additions & 0 deletions configs/textdet/dbnetpp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# DBNetpp

> [Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion](https://arxiv.org/abs/2202.10304)
<!-- [ALGORITHM] -->

## Abstract

Recently, segmentation-based scene text detection methods have drawn extensive attention in the scene text detection field, because of their superiority in detecting the text instances of arbitrary shapes and extreme aspect ratios, profiting from the pixel-level descriptions. However, the vast majority of the existing segmentation-based approaches are limited to their complex post-processing algorithms and the scale robustness of their segmentation models, where the post-processing algorithms are not only isolated to the model optimization but also time-consuming and the scale robustness is usually strengthened by fusing multi-scale feature maps directly. In this paper, we propose a Differentiable Binarization (DB) module that integrates the binarization process, one of the most important steps in the post-processing procedure, into a segmentation network. Optimized along with the proposed DB module, the segmentation network can produce more accurate results, which enhances the accuracy of text detection with a simple pipeline. Furthermore, an efficient Adaptive Scale Fusion (ASF) module is proposed to improve the scale robustness by fusing features of different scales adaptively. By incorporating the proposed DB and ASF with the segmentation network, our proposed scene text detector consistently achieves state-of-the-art results, in terms of both detection accuracy and speed, on five standard benchmarks.

<div align=center>
<img src="https://user-images.githubusercontent.com/45810070/166850828-f1e48c25-4a0f-429d-ae54-6997ed25c062.png"/>
</div>

## Results and models

### ICDAR2015

| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
| :---------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------: | :-------------: | :------------: | :-----: | :-------: | :----: | :-------: | :---: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [DBNetpp_r50dcn](/configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py) | [Synthtext](/configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_100k_iter_synthtext.py) ([model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnetpp_r50dcnv2_fpnc_100k_iter_synthtext-20220502-db297554.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnetpp_r50dcnv2_fpnc_100k_iter_synthtext-20220502-db297554.log.json))| ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.822 | 0.901 | 0.860 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015-20220502-d7a76fff.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015-20220502-d7a76fff.log.json) |

## Citation

```bibtex
@article{liao2022real,
title={Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion},
author={Liao, Minghui and Zou, Zhisheng and Wan, Zhaoyi and Yao, Cong and Bai, Xiang},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
year={2022},
publisher={IEEE}
}
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
_base_ = [
'../../_base_/default_runtime.py',
'../../_base_/schedules/schedule_sgd_100k_iters.py',
'../../_base_/det_models/dbnetpp_r50dcnv2_fpnc.py',
'../../_base_/det_datasets/synthtext.py',
'../../_base_/det_pipelines/dbnet_pipeline.py'
]

train_list = {{_base_.train_list}}
test_list = {{_base_.test_list}}

train_pipeline_r50dcnv2 = {{_base_.train_pipeline_r50dcnv2}}
test_pipeline_4068_1024 = {{_base_.test_pipeline_4068_1024}}

data = dict(
samples_per_gpu=16,
workers_per_gpu=8,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='UniformConcatDataset',
datasets=train_list,
pipeline=train_pipeline_r50dcnv2),
val=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline_4068_1024),
test=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline_4068_1024))

evaluation = dict(interval=200000, metric='hmean-iou') # do not evaluate
39 changes: 39 additions & 0 deletions configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
_base_ = [
'../../_base_/default_runtime.py',
'../../_base_/schedules/schedule_sgd_1200e.py',
'../../_base_/det_models/dbnetpp_r50dcnv2_fpnc.py',
'../../_base_/det_datasets/icdar2015.py',
'../../_base_/det_pipelines/dbnet_pipeline.py'
]

train_list = {{_base_.train_list}}
test_list = {{_base_.test_list}}

train_pipeline_r50dcnv2 = {{_base_.train_pipeline_r50dcnv2}}
test_pipeline_4068_1024 = {{_base_.test_pipeline_4068_1024}}

load_from = 'checkpoints/textdet/dbnetpp/res50dcnv2_synthtext.pth'

data = dict(
samples_per_gpu=32,
workers_per_gpu=8,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='UniformConcatDataset',
datasets=train_list,
pipeline=train_pipeline_r50dcnv2),
val=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline_4068_1024),
test=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline_4068_1024))

evaluation = dict(
interval=100,
metric='hmean-iou',
save_best='0_hmean-iou:hmean',
rule='greater')
28 changes: 28 additions & 0 deletions configs/textdet/dbnetpp/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
Collections:
- Name: DBNetpp
Metadata:
Training Data: ICDAR2015
Training Techniques:
- SGD with Momentum
- Weight Decay
Training Resources: 1x Nvidia A100
Architecture:
- ResNet
- FPNC
Paper:
URL: https://arxiv.org/abs/2202.10304
Title: 'Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion'
README: configs/textdet/dbnetpp/README.md

Models:
- Name: dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py
In Collection: DBNetpp
Config: configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py
Metadata:
Training Data: ICDAR2015
Results:
- Task: Text Detection
Dataset: ICDAR2015
Metrics:
hmean-iou: 0.860
Weights: https://download.openmmlab.com/mmocr/textdet/dbnet/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015-20220502-d7a76fff.pth
1 change: 1 addition & 0 deletions demo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ means that `batch_mode` and `print_result` are set to `True`)
| ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------: |
| DB_r18 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: |
| DB_r50 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: |
| DBPP_r50 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#dbnetpp) | :x: |
| DRRG | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#drrg) | :x: |
| FCE_IC15 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#fourier-contour-embedding-for-arbitrary-shaped-text-detection) | :x: |
| FCE_CTW_DCNv2 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#fourier-contour-embedding-for-arbitrary-shaped-text-detection) | :x: |
Expand Down
1 change: 1 addition & 0 deletions demo/README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ mmocr 为了方便使用提供了预置的模型配置和对应的预训练权
| ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------: |
| DB_r18 | [链接](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: |
| DB_r50 | [链接](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: |
| DBPP_r50 | [链接](https://mmocr.readthedocs.io/en/latest/textdet_models.html#dbnetpp) | :x: |
| DRRG | [链接](https://mmocr.readthedocs.io/en/latest/textdet_models.html#drrg) | :x: |
| FCE_IC15 | [链接](https://mmocr.readthedocs.io/en/latest/textdet_models.html#fourier-contour-embedding-for-arbitrary-shaped-text-detection) | :x: |
| FCE_CTW_DCNv2 | [链接](https://mmocr.readthedocs.io/en/latest/textdet_models.html#fourier-contour-embedding-for-arbitrary-shaped-text-detection) | :x: |
Expand Down
132 changes: 130 additions & 2 deletions mmocr/models/textdet/necks/fpn_cat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, ModuleList, auto_fp16
from mmcv.runner import BaseModule, ModuleList, Sequential, auto_fp16

from mmocr.models.builder import NECKS

Expand All @@ -26,6 +27,8 @@ class FPNC(BaseModule):
bias_on_smooth (bool): Whether to use bias on smoothing layer.
bn_re_on_smooth (bool): Whether to use BatchNorm and ReLU on smoothing
layer.
asf_cfg (dict): Adaptive Scale Fusion module configs. The
attention_type can be 'ScaleChannelSpatial'.
conv_after_concat (bool): Whether to add a convolution layer after
the concatenation of predictions.
init_cfg (dict or list[dict], optional): Initialization configs.
Expand All @@ -39,8 +42,13 @@ def __init__(self,
bn_re_on_lateral=False,
bias_on_smooth=False,
bn_re_on_smooth=False,
asf_cfg=None,
conv_after_concat=False,
init_cfg=None):
init_cfg=[
dict(type='Kaiming', layer='Conv'),
dict(
type='Constant', layer='BatchNorm', val=1., bias=1e-4)
]):
super().__init__(init_cfg=init_cfg)
assert isinstance(in_channels, list)
self.in_channels = in_channels
Expand All @@ -49,6 +57,7 @@ def __init__(self,
self.num_ins = len(in_channels)
self.bn_re_on_lateral = bn_re_on_lateral
self.bn_re_on_smooth = bn_re_on_smooth
self.asf_cfg = asf_cfg
self.conv_after_concat = conv_after_concat
self.lateral_convs = ModuleList()
self.smooth_convs = ModuleList()
Expand Down Expand Up @@ -88,6 +97,24 @@ def __init__(self,

self.lateral_convs.append(l_conv)
self.smooth_convs.append(smooth_conv)

if self.asf_cfg is not None:
self.asf_conv = ConvModule(
out_channels * self.num_outs,
out_channels * self.num_outs,
3,
padding=1,
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
inplace=False)
if self.asf_cfg['attention_type'] == 'ScaleChannelSpatial':
self.asf_attn = ScaleChannelSpatialAttention(
self.out_channels * self.num_outs,
(self.out_channels * self.num_outs) // 4, self.num_outs)
else:
raise NotImplementedError

if self.conv_after_concat:
norm_cfg = dict(type='BN')
act_cfg = dict(type='ReLU')
Expand Down Expand Up @@ -135,9 +162,110 @@ def forward(self, inputs):
for i, out in enumerate(outs):
outs[i] = F.interpolate(
outs[i], size=outs[0].shape[2:], mode='nearest')

out = torch.cat(outs, dim=1)
if self.asf_cfg is not None:
asf_feature = self.asf_conv(out)
attention = self.asf_attn(asf_feature)
enhanced_feature = []
for i, out in enumerate(outs):
enhanced_feature.append(attention[:, i:i + 1] * outs[i])
out = torch.cat(enhanced_feature, dim=1)

if self.conv_after_concat:
out = self.out_conv(out)

return out


class ScaleChannelSpatialAttention(BaseModule):
"""Spatial Attention module in Real-Time Scene Text Detection with
Differentiable Binarization and Adaptive Scale Fusion.
This was partially adapted from https://github.com/MhLiao/DB
Args:
in_channels (int): A numbers of input channels.
c_wise_channels (int): Number of channel-wise attention channels.
out_channels (int): Number of output channels.
init_cfg (dict or list[dict], optional): Initialization configs.
"""

def __init__(self,
in_channels,
c_wise_channels,
out_channels,
init_cfg=[dict(type='Kaiming', layer='Conv', bias=0)]):
super().__init__(init_cfg=init_cfg)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# Channel Wise
self.channel_wise = Sequential(
ConvModule(
in_channels,
c_wise_channels,
1,
bias=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
inplace=False),
ConvModule(
c_wise_channels,
in_channels,
1,
bias=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='Sigmoid'),
inplace=False))
# Spatial Wise
self.spatial_wise = Sequential(
ConvModule(
1,
1,
3,
padding=1,
bias=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
inplace=False),
ConvModule(
1,
1,
1,
bias=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='Sigmoid'),
inplace=False))
# Attention Wise
self.attention_wise = ConvModule(
in_channels,
out_channels,
1,
bias=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='Sigmoid'),
inplace=False)

@auto_fp16()
def forward(self, inputs):
"""
Args:
inputs (Tensor): A concat FPN feature tensor that has the shape of
:math:`(N, C, H, W)`.
Returns:
Tensor: An attention map of shape :math:`(N, C_{out}, H, W)`
where :math:`C_{out}` is ``out_channels``.
"""
out = self.avg_pool(inputs)
out = self.channel_wise(out)
out = out + inputs
inputs = torch.mean(out, dim=1, keepdim=True)
out = self.spatial_wise(inputs) + out
out = self.attention_wise(out)

return out
Loading

0 comments on commit fbc138d

Please sign in to comment.