Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] abcnetv2 train #1704

Open
wants to merge 1 commit into
base: dev-1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions projects/ABCNet/abcnet/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .bifpn import BiFPN
from .coordinate_head import CoordinateHead
from .rec_roi_head import RecRoIHead
from .task_utils import * # noqa: F401,F403

__all__ = [
'ABCNetDetHead', 'ABCNetDetPostprocessor', 'ABCNetRecBackbone',
Expand Down
6 changes: 3 additions & 3 deletions projects/ABCNet/abcnet/model/abcnet_det_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ABCNetDetHead(BaseTextDetHead):

def __init__(self,
in_channels,
module_loss=dict(type='ABCNetLoss'),
module_loss=dict(type='ABCNetDetModuleLoss'),
postprocessor=dict(type='ABCNetDetPostprocessor'),
num_classes=1,
strides=(4, 8, 16, 32, 64),
Expand Down Expand Up @@ -181,8 +181,8 @@ def forward_single(self, x, scale, stride):
# float to avoid overflow when enabling FP16
if self.use_scale:
bbox_pred = scale(bbox_pred).float()
else:
bbox_pred = bbox_pred.float()
# else:
# bbox_pred = bbox_pred.float()
if self.norm_on_bbox:
# bbox_pred needed for gradient computation has been modified
# by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
Expand Down
18 changes: 13 additions & 5 deletions projects/ABCNet/abcnet/model/abcnet_det_module_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
from mmdet.models.task_modules.prior_generators import MlvlPointGenerator
from mmdet.models.utils import multi_apply
from mmdet.utils import reduce_mean
Expand Down Expand Up @@ -149,11 +150,17 @@ def forward(self, inputs: Tuple[Tensor],
avg_factor=centerness_denorm)
loss_centerness = self.loss_centerness(
pos_centerness, pos_centerness_targets, avg_factor=num_pos)
loss_bezier = self.loss_bezier(
pos_bezier_preds,
pos_bezier_targets,
weight=pos_centerness_targets[:, None],
avg_factor=centerness_denorm)
# loss_bezier = self.loss_bezier(
# pos_bezier_preds,
# pos_bezier_targets,
# weight=pos_centerness_targets[:, None],
# avg_factor=centerness_denorm)

loss_bezier = F.smooth_l1_loss(
pos_bezier_preds, pos_bezier_targets, reduction='none')
loss_bezier = (
(loss_bezier.mean(dim=-1) * pos_centerness_targets).sum() /
centerness_denorm)
else:
loss_bbox = pos_bbox_preds.sum()
loss_centerness = pos_centerness.sum()
Expand Down Expand Up @@ -250,6 +257,7 @@ def _get_targets_single(self, data_sample: TextDetDataSample,
polygons = gt_instances.polygons
beziers = gt_bboxes.new([poly2bezier(poly) for poly in polygons])
gt_instances.beziers = beziers
# beziers = gt_instances.beziers
if num_gts == 0:
return gt_labels.new_full((num_points,), self.num_classes), \
gt_bboxes.new_zeros((num_points, 4)), \
Expand Down
4 changes: 2 additions & 2 deletions projects/ABCNet/abcnet/model/abcnet_det_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ def __call__(self, pred_results, data_samples, training: bool = False):
Returns:
list[TextDetDataSample]: Batch of post-processed datasamples.
"""
if training:
return data_samples
# if training:
# return data_samples
Comment on lines +219 to +220
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

del

cfg = self.train_cfg if training else self.test_cfg
if cfg is None:
cfg = {}
Expand Down
2 changes: 1 addition & 1 deletion projects/ABCNet/abcnet/model/bezier_roi_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def forward(self, feats: Tuple[Tensor],
# convert fp32 to fp16 when amp is on
rois = rois.type_as(feats[0])
out_size = self.roi_layers[0].output_size
feats = feats[:3]
# feats = feats[:3]
num_levels = len(feats)
roi_feats = feats[0].new_zeros(
rois.size(0), self.out_channels, *out_size)
Expand Down
15 changes: 9 additions & 6 deletions projects/ABCNet/abcnet/model/bifpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,12 @@ def __init__(self,
self.bifpn_convs = nn.ModuleList()
# weighted
self.weight_two_nodes = nn.Parameter(
torch.Tensor(2, levels).fill_(init))
torch.Tensor(2, levels).fill_(init), requires_grad=True)

self.weight_three_nodes = nn.Parameter(
torch.Tensor(3, levels - 2).fill_(init))
self.relu = nn.ReLU()
torch.Tensor(3, levels - 2).fill_(init), requires_grad=True)

# self.relu = nn.ReLU(inplace=False)
for _ in range(2):
for _ in range(self.levels - 1): # 1,2,3
fpn_conv = nn.Sequential(
Expand All @@ -193,9 +195,10 @@ def forward(self, inputs):
# build top-down and down-top path with stack
levels = self.levels
# w relu
w1 = self.relu(self.weight_two_nodes)
w1 /= torch.sum(w1, dim=0) + self.eps # normalize
w2 = self.relu(self.weight_three_nodes)

_w1 = F.relu(self.weight_two_nodes)
w1 = _w1 / (torch.sum(_w1, dim=0) + self.eps) # normalize
w2 = F.relu(self.weight_three_nodes)
# w2 /= torch.sum(w2, dim=0) + self.eps # normalize
# build top-down
idx_bifpn = 0
Expand Down
48 changes: 41 additions & 7 deletions projects/ABCNet/abcnet/model/rec_roi_head.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
from typing import Optional, Sequence, Tuple

from mmengine.structures import LabelData
from torch import Tensor
Expand All @@ -15,14 +15,19 @@ class RecRoIHead(BaseRoIHead):
"""Simplest base roi head including one bbox head and one mask head."""

def __init__(self,
neck=None,
inputs_indices: Optional[Sequence] = None,
neck: OptMultiConfig = None,
assigner: OptMultiConfig = None,
sampler: OptMultiConfig = None,
roi_extractor: OptMultiConfig = None,
rec_head: OptMultiConfig = None,
init_cfg=None):
super().__init__(init_cfg)
if sampler is not None:
self.sampler = TASK_UTILS.build(sampler)
self.inputs_indices = inputs_indices
self.assigner = assigner
if assigner is not None:
self.assigner = TASK_UTILS.build(assigner)
self.sampler = TASK_UTILS.build(sampler)
if neck is not None:
self.neck = MODELS.build(neck)
self.roi_extractor = MODELS.build(roi_extractor)
Expand All @@ -43,11 +48,39 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
Returns:
dict[str, Tensor]: A dictionary of loss components
"""
proposals = [
ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
]

if self.inputs_indices is not None:
inputs = [inputs[i] for i in self.inputs_indices]
# proposals = [
# ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
# ]
proposals = list()
for ds in data_samples:
pred_instances = ds.pred_instances
gt_instances = ds.gt_instances
# # assign
# gt_beziers = gt_instances.beziers
# pred_beziers = pred_instances.beziers
# assign_index = [
# int(
# torch.argmin(
# torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1)))
# for i in range(len(pred_beziers))
# ]
# proposal = InstanceData()
# proposal.texts = gt_instances.texts + gt_instances[
# assign_index].texts
# proposal.beziers = torch.cat(
# [gt_instances.beziers, pred_instances.beziers], dim=0)
if self.assigner:
gt_instances, pred_instances = self.assigner.assign(
gt_instances, pred_instances)
proposal = self.sampler.sample(gt_instances, pred_instances)
proposals.append(proposal)

proposals = [p for p in proposals if len(p) > 0]
if hasattr(self, 'neck') and self.neck is not None:
inputs = self.neck(inputs)
bbox_feats = self.roi_extractor(inputs, proposals)
rec_data_samples = [
TextRecogDataSample(gt_text=LabelData(item=text))
Expand All @@ -57,6 +90,7 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:

def predict(self, inputs: Tuple[Tensor],
data_samples: DetSampleList) -> RecSampleList:
inputs = inputs[:3]
if hasattr(self, 'neck') and self.neck is not None:
inputs = self.neck(inputs)
pred_instances = [ds.pred_instances for ds in data_samples]
Expand Down
5 changes: 5 additions & 0 deletions projects/ABCNet/abcnet/model/task_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .assigner import L1DistanceAssigner
from .sampler import ConcatSampler, OnlyGTSampler

__all__ = ['L1DistanceAssigner', 'ConcatSampler', 'OnlyGTSampler']
20 changes: 20 additions & 0 deletions projects/ABCNet/abcnet/model/task_utils/assigner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmocr.registry import TASK_UTILS


@TASK_UTILS.register_module()
class L1DistanceAssigner:
    """Assign each predicted instance to its nearest ground-truth instance.

    Nearness is measured by the L1 distance between the flattened Bezier
    control-point vectors (``instances.beziers``, one row per instance).
    """

    def assign(self, gt_instances, pred_instances):
        """Attach ``assign_index`` to ``pred_instances``.

        Args:
            gt_instances: Instance container with a ``beziers`` tensor of
                shape (num_gt, D).
            pred_instances: Instance container with a ``beziers`` tensor of
                shape (num_pred, D).

        Returns:
            tuple: ``(gt_instances, pred_instances)`` where
            ``pred_instances.assign_index`` is a list of ints mapping each
            prediction to the index of its nearest ground truth.
        """
        gt_beziers = gt_instances.beziers
        pred_beziers = pred_instances.beziers
        if len(pred_beziers) == 0:
            # No predictions: nothing to assign (mirrors the empty-loop
            # behavior of a per-prediction implementation).
            assign_index = []
        else:
            # Pairwise L1 distances in one broadcasted op, shape
            # (num_pred, num_gt), instead of a Python loop over predictions.
            dists = torch.abs(pred_beziers[:, None, :] -
                              gt_beziers[None, :, :]).sum(dim=-1)
            assign_index = dists.argmin(dim=1).tolist()
        pred_instances.assign_index = assign_index
        return gt_instances, pred_instances
26 changes: 26 additions & 0 deletions projects/ABCNet/abcnet/model/task_utils/sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.structures import InstanceData

from mmocr.registry import TASK_UTILS


@TASK_UTILS.register_module()
class ConcatSampler:
    """Build training proposals by concatenating GT and predicted instances.

    Predicted instances contribute their own Bezier curves but take the
    text label of the ground truth they were assigned to
    (``pred_instances.assign_index``).
    """

    def sample(self, gt_instances, pred_instances):
        """Return an ``InstanceData`` with combined ``texts`` and ``beziers``.

        Falls back to the ground-truth instances alone when there are no
        predictions.
        """
        if len(pred_instances) == 0:
            return gt_instances
        # Texts for the predicted half come from their assigned GTs.
        matched_gts = gt_instances[pred_instances.assign_index]
        proposals = InstanceData()
        proposals.texts = gt_instances.texts + matched_gts.texts
        proposals.beziers = torch.cat(
            [gt_instances.beziers, pred_instances.beziers], dim=0)
        return proposals


@TASK_UTILS.register_module()
class OnlyGTSampler:
    """Sample only the non-ignored ground-truth instances.

    Predictions are discarded entirely; useful when training the
    recognition head purely on ground-truth regions.
    """

    def sample(self, gt_instances, pred_instances):
        """Return the subset of ``gt_instances`` whose ``ignored`` flag is unset."""
        keep = ~gt_instances.ignored
        return gt_instances[keep]
2 changes: 1 addition & 1 deletion projects/ABCNet/config/_base_/default_runtime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
default_scope = 'mmocr'
env_cfg = dict(
cudnn_benchmark=True,
cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
type='OptimWrapper',
optimizer=dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001),
clip_grad=dict(type='value', clip_value=1))
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500, val_interval=20)
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500, val_interval=10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# learning policy
Expand Down
68 changes: 65 additions & 3 deletions projects/ABCNet/config/abcnet_v2/_base_abcnet-v2_resnet50_bifpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,47 @@
std=0.01,
bias=-4.59511985013459), # -log((1-p)/p) where p=0.01
),
module_loss=None,
module_loss=dict(
type='ABCNetDetModuleLoss',
num_classes=num_classes,
strides=strides,
center_sampling=True,
center_sample_radius=1.5,
bbox_coder=bbox_coder,
norm_on_bbox=norm_on_bbox,
loss_cls=dict(
type='mmdet.FocalLoss',
use_sigmoid=use_sigmoid_cls,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=1.0),
loss_centerness=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0)),
postprocessor=dict(
type='ABCNetDetPostprocessor',
# rescale_fields=['polygons', 'bboxes'],
use_sigmoid_cls=use_sigmoid_cls,
strides=[8, 16, 32, 64, 128],
bbox_coder=dict(type='mmdet.DistancePointBBoxCoder'),
with_bezier=True,
train_cfg=dict(
# rescale_fields=['polygon', 'bboxes', 'bezier'],
nms_pre=1000,
nms=dict(type='nms', iou_threshold=0.4),
score_thr=0.7),
test_cfg=dict(
# rescale_fields=['polygon', 'bboxes', 'bezier'],
nms_pre=1000,
nms=dict(type='nms', iou_threshold=0.4),
score_thr=0.3))),
score_thr=0.4))),
roi_head=dict(
type='RecRoIHead',
inputs_indices=(0, 1, 2),
assigner=dict(type='L1DistanceAssigner'),
sampler=dict(type='ConcatSampler'),
neck=dict(type='CoordinateHead'),
roi_extractor=dict(
type='BezierRoIExtractor',
Expand All @@ -97,7 +123,14 @@
decoder=dict(
type='ABCNetRecDecoder',
dictionary=dictionary,
postprocessor=dict(type='AttentionPostprocessor'),
postprocessor=dict(
type='AttentionPostprocessor',
ignore_chars=['padding', 'unknown']),
module_loss=dict(
type='CEModuleLoss',
ignore_first_char=False,
ignore_char=-1,
reduction='mean'),
max_seq_len=25))),
postprocessor=dict(
type='ABCNetPostprocessor',
Expand All @@ -120,3 +153,32 @@
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
]

train_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args=file_client_args,
color_type='color_ignore_orientation'),
dict(
type='LoadOCRAnnotations',
with_polygon=True,
with_bbox=True,
with_label=True,
with_text=True),
dict(type='RemoveIgnored'),
dict(type='RandomCrop', min_side_ratio=0.1),
dict(
type='RandomRotate',
max_angle=30,
pad_with_fixed_color=True,
use_canvas=True),
dict(
type='RandomChoiceResize',
scales=[(980, 2900), (1044, 2900), (1108, 2900), (1172, 2900),
(1236, 2900), (1300, 2900), (1364, 2900), (1428, 2900),
(1492, 2900)],
keep_ratio=True),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
]
Loading