Commit 3b6fed1: abcnetv2 train

Harold-lkk committed Feb 2, 2023
1 parent bf41194 commit 3b6fed1
Showing 13 changed files with 188 additions and 24 deletions.
1 change: 1 addition & 0 deletions projects/ABCNet/abcnet/model/__init__.py
@@ -12,6 +12,7 @@
from .bifpn import BiFPN
from .coordinate_head import CoordinateHead
from .rec_roi_head import RecRoIHead
+from .task_utils import * # noqa: F401,F403

__all__ = [
'ABCNetDetHead', 'ABCNetDetPostprocessor', 'ABCNetRecBackbone',
4 changes: 2 additions & 2 deletions projects/ABCNet/abcnet/model/abcnet_det_head.py
@@ -181,8 +181,8 @@ def forward_single(self, x, scale, stride):
# float to avoid overflow when enabling FP16
if self.use_scale:
bbox_pred = scale(bbox_pred).float()
-else:
-bbox_pred = bbox_pred.float()
+# else:
+# bbox_pred = bbox_pred.float()
if self.norm_on_bbox:
# bbox_pred needed for gradient computation has been modified
# by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
4 changes: 2 additions & 2 deletions projects/ABCNet/abcnet/model/abcnet_det_postprocessor.py
@@ -216,8 +216,8 @@ def __call__(self, pred_results, data_samples, training: bool = False):
Returns:
list[TextDetDataSample]: Batch of post-processed datasamples.
"""
-if training:
-return data_samples
+# if training:
+# return data_samples
cfg = self.train_cfg if training else self.test_cfg
if cfg is None:
cfg = {}
2 changes: 1 addition & 1 deletion projects/ABCNet/abcnet/model/bezier_roi_extractor.py
@@ -88,7 +88,7 @@ def forward(self, feats: Tuple[Tensor],
# convert fp32 to fp16 when amp is on
rois = rois.type_as(feats[0])
out_size = self.roi_layers[0].output_size
-feats = feats[:3]
+# feats = feats[:3]
num_levels = len(feats)
roi_feats = feats[0].new_zeros(
rois.size(0), self.out_channels, *out_size)
15 changes: 9 additions & 6 deletions projects/ABCNet/abcnet/model/bifpn.py
@@ -170,10 +170,12 @@ def __init__(self,
self.bifpn_convs = nn.ModuleList()
# weighted
self.weight_two_nodes = nn.Parameter(
-torch.Tensor(2, levels).fill_(init))
+torch.Tensor(2, levels).fill_(init), requires_grad=True)
+
self.weight_three_nodes = nn.Parameter(
-torch.Tensor(3, levels - 2).fill_(init))
-self.relu = nn.ReLU()
+torch.Tensor(3, levels - 2).fill_(init), requires_grad=True)
+
+# self.relu = nn.ReLU(inplace=False)
for _ in range(2):
for _ in range(self.levels - 1): # 1,2,3
fpn_conv = nn.Sequential(
@@ -193,9 +195,10 @@ def forward(self, inputs):
# build top-down and down-top path with stack
levels = self.levels
# w relu
-w1 = self.relu(self.weight_two_nodes)
-w1 /= torch.sum(w1, dim=0) + self.eps # normalize
-w2 = self.relu(self.weight_three_nodes)
+
+_w1 = F.relu(self.weight_two_nodes)
+w1 = _w1 / (torch.sum(_w1, dim=0) + self.eps) # normalize
+w2 = F.relu(self.weight_three_nodes)
# w2 /= torch.sum(w2, dim=0) + self.eps # normalize
# build top-down
idx_bifpn = 0
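Note on the BiFPN change above: the fused-weight normalization is rewritten out of place, presumably because dividing the ReLU output in place modifies a tensor that autograd saved for the backward pass. A minimal sketch in plain PyTorch (illustrative shapes and names, not code from this repository):

```python
import torch
import torch.nn.functional as F

eps = 1e-4
w = torch.nn.Parameter(torch.full((2, 5), 0.5))  # stands in for weight_two_nodes

# In-place variant (the removed lines): the division mutates the ReLU output,
# which autograd saved for ReLU's backward pass, so backward() raises a
# RuntimeError about an in-place operation.
# w1 = F.relu(w)
# w1 /= torch.sum(w1, dim=0) + eps
# w1.sum().backward()

# Out-of-place variant (the committed lines): a fresh tensor is created and
# gradients reach the parameter as expected.
_w1 = F.relu(w)
w1 = _w1 / (torch.sum(_w1, dim=0) + eps)
w1.sum().backward()
print(w.grad.shape)  # torch.Size([2, 5])
```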
48 changes: 41 additions & 7 deletions projects/ABCNet/abcnet/model/rec_roi_head.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Tuple
+from typing import Optional, Sequence, Tuple

from mmengine.structures import LabelData
from torch import Tensor
@@ -15,14 +15,19 @@ class RecRoIHead(BaseRoIHead):
"""Simplest base roi head including one bbox head and one mask head."""

def __init__(self,
-neck=None,
+inputs_indices: Optional[Sequence] = None,
+neck: OptMultiConfig = None,
+assigner: OptMultiConfig = None,
sampler: OptMultiConfig = None,
roi_extractor: OptMultiConfig = None,
rec_head: OptMultiConfig = None,
init_cfg=None):
super().__init__(init_cfg)
-if sampler is not None:
-self.sampler = TASK_UTILS.build(sampler)
+self.inputs_indices = inputs_indices
+self.assigner = assigner
+if assigner is not None:
+self.assigner = TASK_UTILS.build(assigner)
+self.sampler = TASK_UTILS.build(sampler)
if neck is not None:
self.neck = MODELS.build(neck)
self.roi_extractor = MODELS.build(roi_extractor)
@@ -43,11 +48,39 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:
Returns:
dict[str, Tensor]: A dictionary of loss components
"""
-proposals = [
-ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
-]
+
+if self.inputs_indices is not None:
+inputs = [inputs[i] for i in self.inputs_indices]
+# proposals = [
+# ds.gt_instances[~ds.gt_instances.ignored] for ds in data_samples
+# ]
+proposals = list()
+for ds in data_samples:
+pred_instances = ds.pred_instances
+gt_instances = ds.gt_instances
+# # assign
+# gt_beziers = gt_instances.beziers
+# pred_beziers = pred_instances.beziers
+# assign_index = [
+# int(
+# torch.argmin(
+# torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1)))
+# for i in range(len(pred_beziers))
+# ]
+# proposal = InstanceData()
+# proposal.texts = gt_instances.texts + gt_instances[
+# assign_index].texts
+# proposal.beziers = torch.cat(
+# [gt_instances.beziers, pred_instances.beziers], dim=0)
+if self.assigner:
+gt_instances, pred_instances = self.assigner.assign(
+gt_instances, pred_instances)
+proposal = self.sampler.sample(gt_instances, pred_instances)
+proposals.append(proposal)
+
+proposals = [p for p in proposals if len(p) > 0]
if hasattr(self, 'neck') and self.neck is not None:
inputs = self.neck(inputs)
bbox_feats = self.roi_extractor(inputs, proposals)
rec_data_samples = [
TextRecogDataSample(gt_text=LabelData(item=text))
@@ -57,6 +90,7 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:

def predict(self, inputs: Tuple[Tensor],
data_samples: DetSampleList) -> RecSampleList:
+inputs = inputs[:3]
if hasattr(self, 'neck') and self.neck is not None:
inputs = self.neck(inputs)
pred_instances = [ds.pred_instances for ds in data_samples]
5 changes: 5 additions & 0 deletions projects/ABCNet/abcnet/model/task_utils/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .assigner import L1DistanceAssigner
from .sampler import ConcatSampler, OnlyGTSampler

__all__ = ['L1DistanceAssigner', 'ConcatSampler', 'OnlyGTSampler']
20 changes: 20 additions & 0 deletions projects/ABCNet/abcnet/model/task_utils/assigner.py
@@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmocr.registry import TASK_UTILS


@TASK_UTILS.register_module()
class L1DistanceAssigner:

    def assign(self, gt_instances, pred_instances):
        gt_beziers = gt_instances.beziers
        pred_beziers = pred_instances.beziers
        assign_index = [
            int(
                torch.argmin(
                    torch.abs(gt_beziers - pred_beziers[i]).sum(dim=1)))
            for i in range(len(pred_beziers))
        ]
        pred_instances.assign_index = assign_index
        return gt_instances, pred_instances
26 changes: 26 additions & 0 deletions projects/ABCNet/abcnet/model/task_utils/sampler.py
@@ -0,0 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.structures import InstanceData

from mmocr.registry import TASK_UTILS


@TASK_UTILS.register_module()
class ConcatSampler:

    def sample(self, gt_instances, pred_instances):
        if len(pred_instances) == 0:
            return gt_instances
        proposals = InstanceData()
        proposals.texts = gt_instances.texts + gt_instances[
            pred_instances.assign_index].texts
        proposals.beziers = torch.cat(
            [gt_instances.beziers, pred_instances.beziers], dim=0)
        return proposals


@TASK_UTILS.register_module()
class OnlyGTSampler:

    def sample(self, gt_instances, pred_instances):
        return gt_instances[~gt_instances.ignored]
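For orientation, a minimal usage sketch of the two new task utilities, mirroring how `RecRoIHead.loss` chains them above. The shapes and text values are made up for illustration; it assumes a full MMOCR development environment (mmengine, mmdet, mmocr) and the repository root as the working directory so that `projects.ABCNet.abcnet` is importable:

```python
import torch
from mmengine.structures import InstanceData

from projects.ABCNet.abcnet.model.task_utils import (ConcatSampler,
                                                      L1DistanceAssigner)

# Two ground-truth instances and three detections, each described by a
# 16-dim Bezier control-point vector (8 control points x 2 coordinates).
gt_instances = InstanceData(
    beziers=torch.rand(2, 16), texts=['hello', 'world'])
pred_instances = InstanceData(beziers=torch.rand(3, 16))

# Each detection is matched to the nearest ground truth by L1 distance over
# the Bezier parameters; the indices end up on pred_instances.assign_index.
gt_instances, pred_instances = L1DistanceAssigner().assign(
    gt_instances, pred_instances)
print(pred_instances.assign_index)  # e.g. [1, 0, 1]

# ConcatSampler trains the recognizer on GT and detected regions together,
# reusing the matched GT transcriptions as labels for the detections.
proposals = ConcatSampler().sample(gt_instances, pred_instances)
print(len(proposals.texts), proposals.beziers.shape)  # 5 torch.Size([5, 16])
```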
2 changes: 1 addition & 1 deletion projects/ABCNet/config/_base_/default_runtime.py
@@ -1,6 +1,6 @@
default_scope = 'mmocr'
env_cfg = dict(
-cudnn_benchmark=True,
+cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
@@ -3,7 +3,7 @@
type='OptimWrapper',
optimizer=dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001),
clip_grad=dict(type='value', clip_value=1))
-train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500, val_interval=20)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=500, val_interval=10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# learning policy
67 changes: 64 additions & 3 deletions projects/ABCNet/config/abcnet_v2/_base_abcnet-v2_resnet50_bifpn.py
@@ -68,21 +68,46 @@
std=0.01,
bias=-4.59511985013459), # -log((1-p)/p) where p=0.01
),
-module_loss=None,
+module_loss=dict(
+type='ABCNetDetModuleLoss',
+num_classes=num_classes,
+strides=strides,
+center_sampling=True,
+center_sample_radius=1.5,
+bbox_coder=bbox_coder,
+norm_on_bbox=norm_on_bbox,
+loss_cls=dict(
+type='mmdet.FocalLoss',
+use_sigmoid=use_sigmoid_cls,
+gamma=2.0,
+alpha=0.25,
+loss_weight=1.0),
+loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=1.0),
+loss_centerness=dict(
+type='mmdet.CrossEntropyLoss',
+use_sigmoid=True,
+loss_weight=1.0)),
postprocessor=dict(
type='ABCNetDetPostprocessor',
# rescale_fields=['polygons', 'bboxes'],
use_sigmoid_cls=use_sigmoid_cls,
strides=[8, 16, 32, 64, 128],
bbox_coder=dict(type='mmdet.DistancePointBBoxCoder'),
with_bezier=True,
+train_cfg=dict(
+# rescale_fields=['polygon', 'bboxes', 'bezier'],
+nms_pre=1000,
+nms=dict(type='nms', iou_threshold=0.4),
+score_thr=0.7),
test_cfg=dict(
# rescale_fields=['polygon', 'bboxes', 'bezier'],
nms_pre=1000,
nms=dict(type='nms', iou_threshold=0.4),
-score_thr=0.3)))
+score_thr=0.4))),
roi_head=dict(
type='RecRoIHead',
+assigner=dict(type='L1DistanceAssigner'),
+sampler=dict(type='ConcatSampler'),
neck=dict(type='CoordinateHead'),
roi_extractor=dict(
type='BezierRoIExtractor',
@@ -97,7 +122,14 @@
decoder=dict(
type='ABCNetRecDecoder',
dictionary=dictionary,
-postprocessor=dict(type='AttentionPostprocessor'),
+postprocessor=dict(
+type='AttentionPostprocessor',
+ignore_chars=['padding', 'unknown']),
+module_loss=dict(
+type='CEModuleLoss',
+ignore_first_char=False,
+ignore_char=-1,
+reduction='mean'),
max_seq_len=25))),
postprocessor=dict(
type='ABCNetPostprocessor',
@@ -120,3 +152,32 @@
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
]

+train_pipeline = [
+dict(
+type='LoadImageFromFile',
+file_client_args=file_client_args,
+color_type='color_ignore_orientation'),
+dict(
+type='LoadOCRAnnotations',
+with_polygon=True,
+with_bbox=True,
+with_label=True,
+with_text=True),
+dict(type='RemoveIgnored'),
+dict(type='RandomCrop', min_side_ratio=0.1),
+dict(
+type='RandomRotate',
+max_angle=30,
+pad_with_fixed_color=True,
+use_canvas=True),
+dict(
+type='RandomChoiceResize',
+scales=[(980, 2900), (1044, 2900), (1108, 2900), (1172, 2900),
+(1236, 2900), (1300, 2900), (1364, 2900), (1428, 2900),
+(1492, 2900)],
+keep_ratio=True),
+dict(
+type='PackTextDetInputs',
+meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
+]
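A quick way to sanity-check the enlarged base config (a sketch; it assumes the file has no unresolved `_base_` dependencies and is loaded from the MMOCR repository root):

```python
from mmengine.config import Config

cfg = Config.fromfile(
    'projects/ABCNet/config/abcnet_v2/_base_abcnet-v2_resnet50_bifpn.py')

# The newly added train_pipeline should parse and end with the packing step.
print([t['type'] for t in cfg.train_pipeline])
# ['LoadImageFromFile', 'LoadOCRAnnotations', 'RemoveIgnored', 'RandomCrop',
#  'RandomRotate', 'RandomChoiceResize', 'PackTextDetInputs']
```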
@@ -2,12 +2,21 @@
'_base_abcnet-v2_resnet50_bifpn.py',
'../_base_/datasets/icdar2015.py',
'../_base_/default_runtime.py',
+'../_base_/schedules/schedule_sgd_500e.py',
]

# dataset settings
+icdar2015_textspotting_train = _base_.icdar2015_textspotting_train
+icdar2015_textspotting_train.pipeline = _base_.train_pipeline
icdar2015_textspotting_test = _base_.icdar2015_textspotting_test
icdar2015_textspotting_test.pipeline = _base_.test_pipeline

+train_dataloader = dict(
+batch_size=1,
+num_workers=8,
+persistent_workers=True,
+sampler=dict(type='DefaultSampler', shuffle=False),
+dataset=icdar2015_textspotting_train)
val_dataloader = dict(
batch_size=1,
num_workers=4,
@@ -20,4 +29,9 @@
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

-custom_imports = dict(imports=['abcnet'], allow_failed_imports=False)
+custom_imports = dict(
+imports=['projects.ABCNet.abcnet'], allow_failed_imports=False)
+
+load_from = 'checkpoints/abcnet-v2_resnet50_bifpn_500e_pretrain.pth'
+
+find_unused_parameters = True
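The `custom_imports` path changes from `'abcnet'` to `'projects.ABCNet.abcnet'`, so the project package is now imported relative to the repository root when the config is parsed. A rough sketch of the effect, using mmengine's helper (assumes the repo root is the working directory and a full MMOCR environment):

```python
from mmengine.utils import import_modules_from_strings

# Roughly what MMEngine does with the custom_imports entry at config-parse
# time: import the package so that every @TASK_UTILS.register_module()
# decorator in it has executed before the model is built from the config.
import_modules_from_strings(
    imports=['projects.ABCNet.abcnet'], allow_failed_imports=False)

from mmocr.registry import TASK_UTILS
print(TASK_UTILS.get('L1DistanceAssigner'))  # the class registered above
```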
