From 3b6fed14034d3b7b27b3c7dbd8f8e0e32549ee5c Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Thu, 2 Feb 2023 17:21:09 +0800 Subject: [PATCH] abcnetv2 train --- projects/ABCNet/abcnet/model/__init__.py | 1 + .../ABCNet/abcnet/model/abcnet_det_head.py | 4 +- .../abcnet/model/abcnet_det_postprocessor.py | 4 +- .../abcnet/model/bezier_roi_extractor.py | 2 +- projects/ABCNet/abcnet/model/bifpn.py | 15 +++-- projects/ABCNet/abcnet/model/rec_roi_head.py | 48 +++++++++++-- .../abcnet/model/task_utils/__init__.py | 5 ++ .../abcnet/model/task_utils/assigner.py | 20 ++++++ .../ABCNet/abcnet/model/task_utils/sampler.py | 26 +++++++ .../ABCNet/config/_base_/default_runtime.py | 2 +- .../_base_/schedules/schedule_sgd_500e.py | 2 +- .../_base_abcnet-v2_resnet50_bifpn.py | 67 ++++++++++++++++++- ...abcnet-v2_resnet50_bifpn_500e_icdar2015.py | 16 ++++- 13 files changed, 188 insertions(+), 24 deletions(-) create mode 100644 projects/ABCNet/abcnet/model/task_utils/__init__.py create mode 100644 projects/ABCNet/abcnet/model/task_utils/assigner.py create mode 100644 projects/ABCNet/abcnet/model/task_utils/sampler.py diff --git a/projects/ABCNet/abcnet/model/__init__.py b/projects/ABCNet/abcnet/model/__init__.py index f22d9b4f1f..d466f0a543 100644 --- a/projects/ABCNet/abcnet/model/__init__.py +++ b/projects/ABCNet/abcnet/model/__init__.py @@ -12,6 +12,7 @@ from .bifpn import BiFPN from .coordinate_head import CoordinateHead from .rec_roi_head import RecRoIHead +from .task_utils import * # noqa: F401,F403 __all__ = [ 'ABCNetDetHead', 'ABCNetDetPostprocessor', 'ABCNetRecBackbone', diff --git a/projects/ABCNet/abcnet/model/abcnet_det_head.py b/projects/ABCNet/abcnet/model/abcnet_det_head.py index 4eb45d905e..12c0412d1d 100644 --- a/projects/ABCNet/abcnet/model/abcnet_det_head.py +++ b/projects/ABCNet/abcnet/model/abcnet_det_head.py @@ -181,8 +181,8 @@ def forward_single(self, x, scale, stride): # float to avoid overflow when enabling FP16 if self.use_scale: bbox_pred = scale(bbox_pred).float() - else: - bbox_pred = bbox_pred.float() + # else: + # bbox_pred = bbox_pred.float() if self.norm_on_bbox: # bbox_pred needed for gradient computation has been modified # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace diff --git a/projects/ABCNet/abcnet/model/abcnet_det_postprocessor.py b/projects/ABCNet/abcnet/model/abcnet_det_postprocessor.py index db9a4d141c..185008edf2 100644 --- a/projects/ABCNet/abcnet/model/abcnet_det_postprocessor.py +++ b/projects/ABCNet/abcnet/model/abcnet_det_postprocessor.py @@ -216,8 +216,8 @@ def __call__(self, pred_results, data_samples, training: bool = False): Returns: list[TextDetDataSample]: Batch of post-processed datasamples. """ - if training: - return data_samples + # if training: + # return data_samples cfg = self.train_cfg if training else self.test_cfg if cfg is None: cfg = {} diff --git a/projects/ABCNet/abcnet/model/bezier_roi_extractor.py b/projects/ABCNet/abcnet/model/bezier_roi_extractor.py index a4848d18e7..ace66fa038 100644 --- a/projects/ABCNet/abcnet/model/bezier_roi_extractor.py +++ b/projects/ABCNet/abcnet/model/bezier_roi_extractor.py @@ -88,7 +88,7 @@ def forward(self, feats: Tuple[Tensor], # convert fp32 to fp16 when amp is on rois = rois.type_as(feats[0]) out_size = self.roi_layers[0].output_size - feats = feats[:3] + # feats = feats[:3] num_levels = len(feats) roi_feats = feats[0].new_zeros( rois.size(0), self.out_channels, *out_size) diff --git a/projects/ABCNet/abcnet/model/bifpn.py b/projects/ABCNet/abcnet/model/bifpn.py index 7f117dffe6..3d45322d00 100644 --- a/projects/ABCNet/abcnet/model/bifpn.py +++ b/projects/ABCNet/abcnet/model/bifpn.py @@ -170,10 +170,12 @@ def __init__(self, self.bifpn_convs = nn.ModuleList() # weighted self.weight_two_nodes = nn.Parameter( - torch.Tensor(2, levels).fill_(init)) + torch.Tensor(2, levels).fill_(init), requires_grad=True) + self.weight_three_nodes = nn.Parameter( - torch.Tensor(3, levels - 2).fill_(init)) - self.relu = nn.ReLU() + torch.Tensor(3, levels - 2).fill_(init), requires_grad=True) + + # self.relu = nn.ReLU(inplace=False) for _ in range(2): for _ in range(self.levels - 1): # 1,2,3 fpn_conv = nn.Sequential( @@ -193,9 +195,10 @@ def forward(self, inputs): # build top-down and down-top path with stack levels = self.levels # w relu - w1 = self.relu(self.weight_two_nodes) - w1 /= torch.sum(w1, dim=0) + self.eps # normalize - w2 = self.relu(self.weight_three_nodes) + + _w1 = F.relu(self.weight_two_nodes) + w1 = _w1 / (torch.sum(_w1, dim=0) + self.eps) # normalize + w2 = F.relu(self.weight_three_nodes) # w2 /= torch.sum(w2, dim=0) + self.eps # normalize # build top-down idx_bifpn = 0 diff --git a/projects/ABCNet/abcnet/model/rec_roi_head.py b/projects/ABCNet/abcnet/model/rec_roi_head.py index a102902c53..9349560e5d 100644 --- a/projects/ABCNet/abcnet/model/rec_roi_head.py +++ b/projects/ABCNet/abcnet/model/rec_roi_head.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Tuple +from typing import Optional, Sequence, Tuple from mmengine.structures import LabelData from torch import Tensor @@ -15,14 +15,19 @@ class RecRoIHead(BaseRoIHead): """Simplest base roi head including one bbox head and one mask head.""" def __init__(self, - neck=None, + inputs_indices: Optional[Sequence] = None, + neck: OptMultiConfig = None, + assigner: OptMultiConfig = None, sampler: OptMultiConfig = None, roi_extractor: OptMultiConfig = None, rec_head: OptMultiConfig = None, init_cfg=None): super().__init__(init_cfg) - if sampler is not None: - self.sampler = TASK_UTILS.build(sampler) + self.inputs_indices = inputs_indices + self.assigner = assigner + if assigner is not None: + self.assigner = TASK_UTILS.build(assigner) + self.sampler = TASK_UTILS.build(sampler) if neck is not None: self.neck = MODELS.build(neck) self.roi_extractor = MODELS.build(roi_extractor) @@ -43,11 +48,39 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict: Returns: dict[str, Tensor]: A dictionary of loss components """ - proposals = [ - ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples - ] + + if self.inputs_indices is not None: + inputs = [inputs[i] for i in self.inputs_indices] + # proposals = [ + # ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples + # ] + proposals = list() + for ds in data_samples: + pred_instances = ds.pred_instances + gt_instances = ds.gt_instances + # # assign + # gt_beziers = gt_instances.beziers + # pred_beziers = pred_instances.beziers + # assign_index = [ + # int( + # torch.argmin( + # torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1))) + # for i in range(len(pred_beziers)) + # ] + # proposal = InstanceData() + # proposal.texts = gt_instances.texts + gt_instances[ + # assign_index].texts + # proposal.beziers = torch.cat( + # [gt_instances.beziers, pred_instances.beziers], dim=0) + if self.assigner: + gt_instances, pred_instances = self.assigner.assign( + gt_instances, pred_instances) + proposal = self.sampler.sample(gt_instances, pred_instances) + proposals.append(proposal) proposals = [p for p in proposals if len(p) > 0] + if hasattr(self, 'neck') and self.neck is not None: + inputs = self.neck(inputs) bbox_feats = self.roi_extractor(inputs, proposals) rec_data_samples = [ TextRecogDataSample(gt_text=LabelData(item=text)) @@ -57,6 +90,7 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict: def predict(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> RecSampleList: + inputs = inputs[:3] if hasattr(self, 'neck') and self.neck is not None: inputs = self.neck(inputs) pred_instances = [ds.pred_instances for ds in data_samples] diff --git a/projects/ABCNet/abcnet/model/task_utils/__init__.py b/projects/ABCNet/abcnet/model/task_utils/__init__.py new file mode 100644 index 0000000000..4339caf252 --- /dev/null +++ b/projects/ABCNet/abcnet/model/task_utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .assigner import L1DistanceAssigner +from .sampler import ConcatSampler, OnlyGTSampler + +__all__ = ['L1DistanceAssigner', 'ConcatSampler', 'OnlyGTSampler'] diff --git a/projects/ABCNet/abcnet/model/task_utils/assigner.py b/projects/ABCNet/abcnet/model/task_utils/assigner.py new file mode 100644 index 0000000000..c2c7ff2848 --- /dev/null +++ b/projects/ABCNet/abcnet/model/task_utils/assigner.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmocr.registry import TASK_UTILS + + +@TASK_UTILS.register_module() +class L1DistanceAssigner: + + def assign(self, gt_instances, pred_instances): + gt_beziers = gt_instances.beziers + pred_beziers = pred_instances.beziers + assign_index = [ + int( + torch.argmin( + torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1))) + for i in range(len(pred_beziers)) + ] + pred_instances.assign_index = assign_index + return gt_instances, pred_instances diff --git a/projects/ABCNet/abcnet/model/task_utils/sampler.py b/projects/ABCNet/abcnet/model/task_utils/sampler.py new file mode 100644 index 0000000000..01ba2d9f01 --- /dev/null +++ b/projects/ABCNet/abcnet/model/task_utils/sampler.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.structures import InstanceData + +from mmocr.registry import TASK_UTILS + + +@TASK_UTILS.register_module() +class ConcatSampler: + + def sample(self, gt_instances, pred_instances): + if len(pred_instances) == 0: + return gt_instances + proposals = InstanceData() + proposals.texts = gt_instances.texts + gt_instances[ + pred_instances.assign_index].texts + proposals.beziers = torch.cat( + [gt_instances.beziers, pred_instances.beziers], dim=0) + return proposals + + +@TASK_UTILS.register_module() +class OnlyGTSampler: + + def sample(self, gt_instances, pred_instances): + return gt_instances[~gt_instances.ignored] diff --git a/projects/ABCNet/config/_base_/default_runtime.py b/projects/ABCNet/config/_base_/default_runtime.py index 8a1f12a380..4b9b72c53f 100644 --- a/projects/ABCNet/config/_base_/default_runtime.py +++ b/projects/ABCNet/config/_base_/default_runtime.py @@ -1,6 +1,6 @@ default_scope = 'mmocr' env_cfg = dict( - cudnn_benchmark=True, + cudnn_benchmark=False, mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), dist_cfg=dict(backend='nccl'), ) diff --git a/projects/ABCNet/config/_base_/schedules/schedule_sgd_500e.py b/projects/ABCNet/config/_base_/schedules/schedule_sgd_500e.py index 431c48ff9d..9b5d3f5961 100644 --- a/projects/ABCNet/config/_base_/schedules/schedule_sgd_500e.py +++ b/projects/ABCNet/config/_base_/schedules/schedule_sgd_500e.py @@ -3,7 +3,7 @@ type='OptimWrapper', optimizer=dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001), clip_grad=dict(type='value', clip_value=1)) -train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500, val_interval=20) +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500, val_interval=10) val_cfg = dict(type='ValLoop') test_cfg = dict(type='TestLoop') # learning policy diff --git a/projects/ABCNet/config/abcnet_v2/_base_abcnet-v2_resnet50_bifpn.py b/projects/ABCNet/config/abcnet_v2/_base_abcnet-v2_resnet50_bifpn.py index 3f16df00b7..5f38102d81 100644 --- a/projects/ABCNet/config/abcnet_v2/_base_abcnet-v2_resnet50_bifpn.py +++ b/projects/ABCNet/config/abcnet_v2/_base_abcnet-v2_resnet50_bifpn.py @@ -68,7 +68,25 @@ std=0.01, bias=-4.59511985013459), # -log((1-p)/p) where p=0.01 ), - module_loss=None, + module_loss=dict( + type='ABCNetDetModuleLoss', + num_classes=num_classes, + strides=strides, + center_sampling=True, + center_sample_radius=1.5, + bbox_coder=bbox_coder, + norm_on_bbox=norm_on_bbox, + loss_cls=dict( + type='mmdet.FocalLoss', + use_sigmoid=use_sigmoid_cls, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=1.0), + loss_centerness=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0)), postprocessor=dict( type='ABCNetDetPostprocessor', # rescale_fields=['polygons', 'bboxes'], @@ -76,13 +94,20 @@ strides=[8, 16, 32, 64, 128], bbox_coder=dict(type='mmdet.DistancePointBBoxCoder'), with_bezier=True, + train_cfg=dict( + # rescale_fields=['polygon', 'bboxes', 'bezier'], + nms_pre=1000, + nms=dict(type='nms', iou_threshold=0.4), + score_thr=0.7), test_cfg=dict( # rescale_fields=['polygon', 'bboxes', 'bezier'], nms_pre=1000, nms=dict(type='nms', iou_threshold=0.4), - score_thr=0.3))), + score_thr=0.4))), roi_head=dict( type='RecRoIHead', + assigner=dict(type='L1DistanceAssigner'), + sampler=dict(type='ConcatSampler'), neck=dict(type='CoordinateHead'), roi_extractor=dict( type='BezierRoIExtractor', @@ -97,7 +122,14 @@ decoder=dict( type='ABCNetRecDecoder', dictionary=dictionary, - postprocessor=dict(type='AttentionPostprocessor'), + postprocessor=dict( + type='AttentionPostprocessor', + ignore_chars=['padding', 'unknown']), + module_loss=dict( + type='CEModuleLoss', + ignore_first_char=False, + ignore_char=-1, + reduction='mean'), max_seq_len=25))), postprocessor=dict( type='ABCNetPostprocessor', @@ -120,3 +152,32 @@ type='PackTextDetInputs', meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) ] + +train_pipeline = [ + dict( + type='LoadImageFromFile', + file_client_args=file_client_args, + color_type='color_ignore_orientation'), + dict( + type='LoadOCRAnnotations', + with_polygon=True, + with_bbox=True, + with_label=True, + with_text=True), + dict(type='RemoveIgnored'), + dict(type='RandomCrop', min_side_ratio=0.1), + dict( + type='RandomRotate', + max_angle=30, + pad_with_fixed_color=True, + use_canvas=True), + dict( + type='RandomChoiceResize', + scales=[(980, 2900), (1044, 2900), (1108, 2900), (1172, 2900), + (1236, 2900), (1300, 2900), (1364, 2900), (1428, 2900), + (1492, 2900)], + keep_ratio=True), + dict( + type='PackTextDetInputs', + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) +] diff --git a/projects/ABCNet/config/abcnet_v2/abcnet-v2_resnet50_bifpn_500e_icdar2015.py b/projects/ABCNet/config/abcnet_v2/abcnet-v2_resnet50_bifpn_500e_icdar2015.py index 5b51f56243..ca201ec083 100644 --- a/projects/ABCNet/config/abcnet_v2/abcnet-v2_resnet50_bifpn_500e_icdar2015.py +++ b/projects/ABCNet/config/abcnet_v2/abcnet-v2_resnet50_bifpn_500e_icdar2015.py @@ -2,12 +2,21 @@ '_base_abcnet-v2_resnet50_bifpn.py', '../_base_/datasets/icdar2015.py', '../_base_/default_runtime.py', + '../_base_/schedules/schedule_sgd_500e.py', ] # dataset settings +icdar2015_textspotting_train = _base_.icdar2015_textspotting_train +icdar2015_textspotting_train.pipeline = _base_.train_pipeline icdar2015_textspotting_test = _base_.icdar2015_textspotting_test icdar2015_textspotting_test.pipeline = _base_.test_pipeline +train_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=icdar2015_textspotting_train) val_dataloader = dict( batch_size=1, num_workers=4, @@ -20,4 +29,9 @@ val_cfg = dict(type='ValLoop') test_cfg = dict(type='TestLoop') -custom_imports = dict(imports=['abcnet'], allow_failed_imports=False) +custom_imports = dict( + imports=['projects.ABCNet.abcnet'], allow_failed_imports=False) + +load_from = 'checkpoints/abcnet-v2_resnet50_bifpn_500e_pretrain.pth' + +find_unused_parameters = True