diff --git a/README.md b/README.md index ac593fb42..34255fe7b 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ Supported algorithms:
Text Detection -- [x] [DBNet](configs/textdet/dbnet/README.md) (AAAI'2020) +- [x] [DBNet](configs/textdet/dbnet/README.md) (AAAI'2020) / [DBNet++](configs/textdet/dbnetpp/README.md) (TPAMI'2022) - [x] [Mask R-CNN](configs/textdet/maskrcnn/README.md) (ICCV'2017) - [x] [PANet](configs/textdet/panet/README.md) (ICCV'2019) - [x] [PSENet](configs/textdet/psenet/README.md) (CVPR'2019) diff --git a/README_zh-CN.md b/README_zh-CN.md index b70705786..a4c58ad05 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -66,7 +66,7 @@ MMOCR 是基于 PyTorch 和 mmdetection 的开源工具箱,专注于文本检
文字检测 -- [x] [DBNet](configs/textdet/dbnet/README.md) (AAAI'2020) +- [x] [DBNet](configs/textdet/dbnet/README.md) (AAAI'2020) / [DBNet++](configs/textdet/dbnetpp/README.md) (TPAMI'2022) - [x] [Mask R-CNN](configs/textdet/maskrcnn/README.md) (ICCV'2017) - [x] [PANet](configs/textdet/panet/README.md) (ICCV'2019) - [x] [PSENet](configs/textdet/psenet/README.md) (CVPR'2019) diff --git a/configs/_base_/det_models/dbnetpp_r50dcnv2_fpnc.py b/configs/_base_/det_models/dbnetpp_r50dcnv2_fpnc.py new file mode 100644 index 000000000..f8eaf2ffd --- /dev/null +++ b/configs/_base_/det_models/dbnetpp_r50dcnv2_fpnc.py @@ -0,0 +1,28 @@ +model = dict( + type='DBNet', + backbone=dict( + type='mmdet.ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + style='pytorch', + dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False), + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'), + stage_with_dcn=(False, True, True, True)), + neck=dict( + type='FPNC', + in_channels=[256, 512, 1024, 2048], + lateral_channels=256, + asf_cfg=dict(attention_type='ScaleChannelSpatial')), + bbox_head=dict( + type='DBHead', + in_channels=256, + loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True), + postprocessor=dict( + type='DBPostprocessor', text_repr_type='quad', + epsilon_ratio=0.002)), + train_cfg=None, + test_cfg=None) diff --git a/configs/textdet/dbnet/metafile.yml b/configs/textdet/dbnet/metafile.yml index 597fe42e4..c6abdbca6 100644 --- a/configs/textdet/dbnet/metafile.yml +++ b/configs/textdet/dbnet/metafile.yml @@ -5,7 +5,7 @@ Collections: Training Techniques: - SGD with Momentum - Weight Decay - Training Resources: 8x GeForce GTX 1080 Ti + Training Resources: 1x GeForce GTX 1080 Ti Architecture: - ResNet - FPNC diff --git a/configs/textdet/dbnetpp/README.md b/configs/textdet/dbnetpp/README.md new file mode 100644 index 000000000..a3e08c3be --- /dev/null +++ b/configs/textdet/dbnetpp/README.md @@ -0,0 +1,33 @@ +# DBNetpp + +> [Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion](https://arxiv.org/abs/2202.10304) + + + +## Abstract + +Recently, segmentation-based scene text detection methods have drawn extensive attention in the scene text detection field, because of their superiority in detecting the text instances of arbitrary shapes and extreme aspect ratios, profiting from the pixel-level descriptions. However, the vast majority of the existing segmentation-based approaches are limited to their complex post-processing algorithms and the scale robustness of their segmentation models, where the post-processing algorithms are not only isolated to the model optimization but also time-consuming and the scale robustness is usually strengthened by fusing multi-scale feature maps directly. In this paper, we propose a Differentiable Binarization (DB) module that integrates the binarization process, one of the most important steps in the post-processing procedure, into a segmentation network. Optimized along with the proposed DB module, the segmentation network can produce more accurate results, which enhances the accuracy of text detection with a simple pipeline. Furthermore, an efficient Adaptive Scale Fusion (ASF) module is proposed to improve the scale robustness by fusing features of different scales adaptively. By incorporating the proposed DB and ASF with the segmentation network, our proposed scene text detector consistently achieves state-of-the-art results, in terms of both detection accuracy and speed, on five standard benchmarks. + +
+ +
+ +## Results and models + +### ICDAR2015 + +| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| :---------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------: | :-------------: | :------------: | :-----: | :-------: | :----: | :-------: | :---: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [DBNetpp_r50dcn](/configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py) | [Synthtext](/configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_100k_iter_synthtext.py) ([model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnetpp_r50dcnv2_fpnc_100k_iter_synthtext-20220502-db297554.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnetpp_r50dcnv2_fpnc_100k_iter_synthtext-20220502-db297554.log.json))| ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.822 | 0.901 | 0.860 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015-20220502-d7a76fff.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015-20220502-d7a76fff.log.json) | + +## Citation + +```bibtex +@article{liao2022real, + title={Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion}, + author={Liao, Minghui and Zou, Zhisheng and Wan, Zhaoyi and Yao, Cong and Bai, Xiang}, + journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, + year={2022}, + publisher={IEEE} +} +``` diff --git a/configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_100k_iter_synthtext.py b/configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_100k_iter_synthtext.py new file mode 100644 index 000000000..cab4e9f77 --- /dev/null +++ b/configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_100k_iter_synthtext.py @@ -0,0 +1,33 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/schedules/schedule_sgd_100k_iters.py', + '../../_base_/det_models/dbnetpp_r50dcnv2_fpnc.py', + '../../_base_/det_datasets/synthtext.py', + '../../_base_/det_pipelines/dbnet_pipeline.py' +] + +train_list = {{_base_.train_list}} +test_list = {{_base_.test_list}} + +train_pipeline_r50dcnv2 = {{_base_.train_pipeline_r50dcnv2}} +test_pipeline_4068_1024 = {{_base_.test_pipeline_4068_1024}} + +data = dict( + samples_per_gpu=16, + workers_per_gpu=8, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), + train=dict( + type='UniformConcatDataset', + datasets=train_list, + pipeline=train_pipeline_r50dcnv2), + val=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline_4068_1024), + test=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline_4068_1024)) + +evaluation = dict(interval=200000, metric='hmean-iou') # do not evaluate diff --git a/configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py b/configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py new file mode 100644 index 000000000..bc6ab78ca --- /dev/null +++ b/configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py @@ -0,0 +1,39 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/schedules/schedule_sgd_1200e.py', + '../../_base_/det_models/dbnetpp_r50dcnv2_fpnc.py', + '../../_base_/det_datasets/icdar2015.py', + '../../_base_/det_pipelines/dbnet_pipeline.py' +] + +train_list = {{_base_.train_list}} +test_list = {{_base_.test_list}} + +train_pipeline_r50dcnv2 = {{_base_.train_pipeline_r50dcnv2}} +test_pipeline_4068_1024 = {{_base_.test_pipeline_4068_1024}} + +load_from = 'checkpoints/textdet/dbnetpp/res50dcnv2_synthtext.pth' + +data = dict( + samples_per_gpu=32, + workers_per_gpu=8, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), + train=dict( + type='UniformConcatDataset', + datasets=train_list, + pipeline=train_pipeline_r50dcnv2), + val=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline_4068_1024), + test=dict( + type='UniformConcatDataset', + datasets=test_list, + pipeline=test_pipeline_4068_1024)) + +evaluation = dict( + interval=100, + metric='hmean-iou', + save_best='0_hmean-iou:hmean', + rule='greater') diff --git a/configs/textdet/dbnetpp/metafile.yml b/configs/textdet/dbnetpp/metafile.yml new file mode 100644 index 000000000..b40571c11 --- /dev/null +++ b/configs/textdet/dbnetpp/metafile.yml @@ -0,0 +1,28 @@ +Collections: +- Name: DBNetpp + Metadata: + Training Data: ICDAR2015 + Training Techniques: + - SGD with Momentum + - Weight Decay + Training Resources: 1x Nvidia A100 + Architecture: + - ResNet + - FPNC + Paper: + URL: https://arxiv.org/abs/2202.10304 + Title: 'Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion' + README: configs/textdet/dbnetpp/README.md + +Models: + - Name: dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py + In Collection: DBNetpp + Config: configs/textdet/dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py + Metadata: + Training Data: ICDAR2015 + Results: + - Task: Text Detection + Dataset: ICDAR2015 + Metrics: + hmean-iou: 0.860 + Weights: https://download.openmmlab.com/mmocr/textdet/dbnet/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015-20220502-d7a76fff.pth diff --git a/demo/README.md b/demo/README.md index f5e5eb3d0..2b492b9b9 100644 --- a/demo/README.md +++ b/demo/README.md @@ -199,6 +199,7 @@ means that `batch_mode` and `print_result` are set to `True`) | ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------: | | DB_r18 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: | | DB_r50 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: | +| DBPP_r50 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#dbnetpp) | :x: | | DRRG | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#drrg) | :x: | | FCE_IC15 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#fourier-contour-embedding-for-arbitrary-shaped-text-detection) | :x: | | FCE_CTW_DCNv2 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#fourier-contour-embedding-for-arbitrary-shaped-text-detection) | :x: | diff --git a/demo/README_zh-CN.md b/demo/README_zh-CN.md index dd193b92c..1b6545cb1 100644 --- a/demo/README_zh-CN.md +++ b/demo/README_zh-CN.md @@ -196,6 +196,7 @@ mmocr 为了方便使用提供了预置的模型配置和对应的预训练权 | ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------: | | DB_r18 | [链接](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: | | DB_r50 | [链接](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: | +| DBPP_r50 | [链接](https://mmocr.readthedocs.io/en/latest/textdet_models.html#dbnetpp) | :x: | | DRRG | [链接](https://mmocr.readthedocs.io/en/latest/textdet_models.html#drrg) | :x: | | FCE_IC15 | [链接](https://mmocr.readthedocs.io/en/latest/textdet_models.html#fourier-contour-embedding-for-arbitrary-shaped-text-detection) | :x: | | FCE_CTW_DCNv2 | [链接](https://mmocr.readthedocs.io/en/latest/textdet_models.html#fourier-contour-embedding-for-arbitrary-shaped-text-detection) | :x: | diff --git a/mmocr/models/textdet/necks/fpn_cat.py b/mmocr/models/textdet/necks/fpn_cat.py index 90d9d222d..5acd2bcc1 100644 --- a/mmocr/models/textdet/necks/fpn_cat.py +++ b/mmocr/models/textdet/necks/fpn_cat.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule -from mmcv.runner import BaseModule, ModuleList, auto_fp16 +from mmcv.runner import BaseModule, ModuleList, Sequential, auto_fp16 from mmocr.models.builder import NECKS @@ -26,6 +27,8 @@ class FPNC(BaseModule): bias_on_smooth (bool): Whether to use bias on smoothing layer. bn_re_on_smooth (bool): Whether to use BatchNorm and ReLU on smoothing layer. + asf_cfg (dict): Adaptive Scale Fusion module configs. The + attention_type can be 'ScaleChannelSpatial'. conv_after_concat (bool): Whether to add a convolution layer after the concatenation of predictions. init_cfg (dict or list[dict], optional): Initialization configs. @@ -39,8 +42,13 @@ def __init__(self, bn_re_on_lateral=False, bias_on_smooth=False, bn_re_on_smooth=False, + asf_cfg=None, conv_after_concat=False, - init_cfg=None): + init_cfg=[ + dict(type='Kaiming', layer='Conv'), + dict( + type='Constant', layer='BatchNorm', val=1., bias=1e-4) + ]): super().__init__(init_cfg=init_cfg) assert isinstance(in_channels, list) self.in_channels = in_channels @@ -49,6 +57,7 @@ def __init__(self, self.num_ins = len(in_channels) self.bn_re_on_lateral = bn_re_on_lateral self.bn_re_on_smooth = bn_re_on_smooth + self.asf_cfg = asf_cfg self.conv_after_concat = conv_after_concat self.lateral_convs = ModuleList() self.smooth_convs = ModuleList() @@ -88,6 +97,24 @@ def __init__(self, self.lateral_convs.append(l_conv) self.smooth_convs.append(smooth_conv) + + if self.asf_cfg is not None: + self.asf_conv = ConvModule( + out_channels * self.num_outs, + out_channels * self.num_outs, + 3, + padding=1, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + inplace=False) + if self.asf_cfg['attention_type'] == 'ScaleChannelSpatial': + self.asf_attn = ScaleChannelSpatialAttention( + self.out_channels * self.num_outs, + (self.out_channels * self.num_outs) // 4, self.num_outs) + else: + raise NotImplementedError + if self.conv_after_concat: norm_cfg = dict(type='BN') act_cfg = dict(type='ReLU') @@ -135,9 +162,110 @@ def forward(self, inputs): for i, out in enumerate(outs): outs[i] = F.interpolate( outs[i], size=outs[0].shape[2:], mode='nearest') + out = torch.cat(outs, dim=1) + if self.asf_cfg is not None: + asf_feature = self.asf_conv(out) + attention = self.asf_attn(asf_feature) + enhanced_feature = [] + for i, out in enumerate(outs): + enhanced_feature.append(attention[:, i:i + 1] * outs[i]) + out = torch.cat(enhanced_feature, dim=1) if self.conv_after_concat: out = self.out_conv(out) return out + + +class ScaleChannelSpatialAttention(BaseModule): + """Spatial Attention module in Real-Time Scene Text Detection with + Differentiable Binarization and Adaptive Scale Fusion. + + This was partially adapted from https://github.com/MhLiao/DB + + Args: + in_channels (int): A numbers of input channels. + c_wise_channels (int): Number of channel-wise attention channels. + out_channels (int): Number of output channels. + init_cfg (dict or list[dict], optional): Initialization configs. + """ + + def __init__(self, + in_channels, + c_wise_channels, + out_channels, + init_cfg=[dict(type='Kaiming', layer='Conv', bias=0)]): + super().__init__(init_cfg=init_cfg) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + # Channel Wise + self.channel_wise = Sequential( + ConvModule( + in_channels, + c_wise_channels, + 1, + bias=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + inplace=False), + ConvModule( + c_wise_channels, + in_channels, + 1, + bias=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='Sigmoid'), + inplace=False)) + # Spatial Wise + self.spatial_wise = Sequential( + ConvModule( + 1, + 1, + 3, + padding=1, + bias=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + inplace=False), + ConvModule( + 1, + 1, + 1, + bias=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='Sigmoid'), + inplace=False)) + # Attention Wise + self.attention_wise = ConvModule( + in_channels, + out_channels, + 1, + bias=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='Sigmoid'), + inplace=False) + + @auto_fp16() + def forward(self, inputs): + """ + Args: + inputs (Tensor): A concat FPN feature tensor that has the shape of + :math:`(N, C, H, W)`. + + Returns: + Tensor: An attention map of shape :math:`(N, C_{out}, H, W)` + where :math:`C_{out}` is ``out_channels``. + """ + out = self.avg_pool(inputs) + out = self.channel_wise(out) + out = out + inputs + inputs = torch.mean(out, dim=1, keepdim=True) + out = self.spatial_wise(inputs) + out + out = self.attention_wise(out) + + return out diff --git a/mmocr/models/textdet/postprocess/db_postprocessor.py b/mmocr/models/textdet/postprocess/db_postprocessor.py index d9dbbeb2d..f185b6332 100644 --- a/mmocr/models/textdet/postprocess/db_postprocessor.py +++ b/mmocr/models/textdet/postprocess/db_postprocessor.py @@ -21,6 +21,7 @@ class DBPostprocessor(BasePostprocessor): min_text_width (int): The minimum width of boundary polygon/box predicted. unclip_ratio (float): The unclip ratio for text regions dilation. + epsilon_ratio (float): The epsilon ratio for approximation accuracy. max_candidates (int): The maximum candidate number. """ @@ -30,6 +31,7 @@ def __init__(self, min_text_score=0.3, min_text_width=5, unclip_ratio=1.5, + epsilon_ratio=0.01, max_candidates=3000, **kwargs): super().__init__(text_repr_type) @@ -37,6 +39,7 @@ def __init__(self, self.min_text_score = min_text_score self.min_text_width = min_text_width self.unclip_ratio = unclip_ratio + self.epsilon_ratio = epsilon_ratio self.max_candidates = max_candidates def __call__(self, preds): @@ -62,7 +65,7 @@ def __call__(self, preds): for i, poly in enumerate(contours): if i > self.max_candidates: break - epsilon = 0.01 * cv2.arcLength(poly, True) + epsilon = self.epsilon_ratio * cv2.arcLength(poly, True) approx = cv2.approxPolyDP(poly, epsilon, True) points = approx.reshape((-1, 2)) if points.shape[0] < 4: diff --git a/mmocr/utils/ocr.py b/mmocr/utils/ocr.py index 9304b20ce..c0d6d8dfc 100755 --- a/mmocr/utils/ocr.py +++ b/mmocr/utils/ocr.py @@ -201,6 +201,13 @@ def __init__(self, 'dbnet/' 'dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth' }, + 'DBPP_r50': { + 'config': + 'dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py', + 'ckpt': + 'dbnet/' + 'dbnetpp_r50dcnv2_fpnc_1200e_icdar2015-20220502-d7a76fff.pth' + }, 'DRRG': { 'config': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py', diff --git a/tests/test_models/test_textdet_neck.py b/tests/test_models/test_textdet_neck.py index 7bee9d7e9..4410a4d13 100644 --- a/tests/test_models/test_textdet_neck.py +++ b/tests/test_models/test_textdet_neck.py @@ -9,14 +9,20 @@ def test_fpnc(): in_channels = [64, 128, 256, 512] size = [112, 56, 28, 14] + asf_cfgs = [ + None, + dict(attention_type='ScaleChannelSpatial'), + ] for flag in [False, True]: - fpnc = FPNC( - in_channels=in_channels, - bias_on_lateral=flag, - bn_re_on_lateral=flag, - bias_on_smooth=flag, - bn_re_on_smooth=flag, - conv_after_concat=flag) + for asf_cfg in asf_cfgs: + fpnc = FPNC( + in_channels=in_channels, + bias_on_lateral=flag, + bn_re_on_lateral=flag, + bias_on_smooth=flag, + bn_re_on_smooth=flag, + asf_cfg=asf_cfg, + conv_after_concat=flag) fpnc.init_weights() inputs = [] for i in range(4):