abcnetv2 inference
Harold-lkk committed Dec 27, 2022
1 parent e096df8 commit 0cf9be4
Showing 7 changed files with 454 additions and 4 deletions.
5 changes: 4 additions & 1 deletion projects/ABCNet/abcnet/model/__init__.py
@@ -8,10 +8,13 @@
from .abcnet_rec_decoder import ABCNetRecDecoder
from .abcnet_rec_encoder import ABCNetRecEncoder
from .bezier_roi_extractor import BezierRoIExtractor
from .bifpn import BiFPN
from .coordinate_head import CoordinateHead
from .only_rec_roi_head import OnlyRecRoIHead

__all__ = [
'ABCNetDetHead', 'ABCNetDetPostprocessor', 'ABCNetRecBackbone',
'ABCNetRecDecoder', 'ABCNetRecEncoder', 'ABCNet', 'ABCNetRec',
-    'BezierRoIExtractor', 'OnlyRecRoIHead', 'ABCNetPostprocessor'
+    'BezierRoIExtractor', 'OnlyRecRoIHead', 'ABCNetPostprocessor', 'BiFPN',
+    'CoordinateHead'
]
3 changes: 1 addition & 2 deletions projects/ABCNet/abcnet/model/abcnet_rec_backbone.py
@@ -46,8 +46,7 @@ def __init__(self, init_cfg=None):
stride=(2, 1),
bias='auto',
norm_cfg=dict(type='GN', num_groups=32),
-                act_cfg=dict(type='ReLU')),
-            nn.AvgPool2d(kernel_size=(2, 1), stride=1))
+                act_cfg=dict(type='ReLU')), nn.AdaptiveAvgPool2d((1, None)))

def forward(self, x):
return self.convs(x)
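
A note on the change above: nn.AvgPool2d(kernel_size=(2, 1), stride=1) produces an output height that depends on the input height, while nn.AdaptiveAvgPool2d((1, None)) always collapses the height to 1 and leaves the width untouched, so the recognition decoder sees a fixed-height sequence for any input size. A minimal standalone sketch of the difference (tensor sizes are illustrative, not taken from the project):

import torch
import torch.nn as nn

# Dummy recognition feature map: (batch, channels, height, width).
feat = torch.randn(2, 256, 3, 25)

# Old pooling: output height depends on the input height (3 - 2 + 1 = 2).
fixed_pool = nn.AvgPool2d(kernel_size=(2, 1), stride=1)
print(fixed_pool(feat).shape)  # torch.Size([2, 256, 2, 25])

# New pooling: height is always reduced to 1; width is kept as-is.
adaptive_pool = nn.AdaptiveAvgPool2d((1, None))
print(adaptive_pool(feat).shape)  # torch.Size([2, 256, 1, 25])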
242 changes: 242 additions & 0 deletions projects/ABCNet/abcnet/model/bifpn.py
@@ -0,0 +1,242 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmocr.registry import MODELS
from mmocr.utils import ConfigType, MultiConfig, OptConfigType


@MODELS.register_module()
class BiFPN(BaseModule):
"""illustration of a minimal bifpn unit P7_0 ------------------------->
P7_2 -------->
|-------------| ↑ ↓ |
P6_0 ---------> P6_1 ---------> P6_2 -------->
|-------------|--------------↑ ↑ ↓ | P5_0
---------> P5_1 ---------> P5_2 --------> |-------------|--------------↑
↑ ↓ | P4_0 ---------> P4_1 ---------> P4_2
--------> |-------------|--------------↑ ↑
|--------------↓ | P3_0 -------------------------> P3_2 -------->
"""

def __init__(self,
in_channels: List[int],
out_channels: int,
num_outs: int,
repeat_times: int = 2,
start_level: int = 0,
end_level: int = -1,
add_extra_convs: bool = False,
relu_before_extra_convs: bool = False,
no_norm_on_lateral: bool = False,
conv_cfg: OptConfigType = None,
norm_cfg: OptConfigType = None,
act_cfg: OptConfigType = None,
laterial_conv1x1: bool = False,
upsample_cfg: ConfigType = dict(mode='nearest'),
pool_cfg: ConfigType = dict(),
init_cfg: MultiConfig = dict(
type='Xavier', layer='Conv2d', distribution='uniform')):
super().__init__(init_cfg=init_cfg)
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
self.num_ins = len(in_channels)
self.num_outs = num_outs
self.relu_before_extra_convs = relu_before_extra_convs
self.no_norm_on_lateral = no_norm_on_lateral
self.upsample_cfg = upsample_cfg.copy()
self.repeat_times = repeat_times
if end_level == -1 or end_level == self.num_ins - 1:
self.backbone_end_level = self.num_ins
assert num_outs >= self.num_ins - start_level
else:
# if end_level is not the last level, no extra level is allowed
self.backbone_end_level = end_level + 1
assert end_level < self.num_ins
assert num_outs == end_level - start_level + 1
self.start_level = start_level
self.end_level = end_level
self.add_extra_convs = add_extra_convs

self.lateral_convs = nn.ModuleList()
self.extra_convs = nn.ModuleList()
self.bifpn_convs = nn.ModuleList()
for i in range(self.start_level, self.backbone_end_level):
if in_channels[i] == out_channels:
l_conv = nn.Identity()
else:
l_conv = ConvModule(
in_channels[i],
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
act_cfg=act_cfg,
inplace=False)
self.lateral_convs.append(l_conv)

for _ in range(repeat_times):
self.bifpn_convs.append(
BiFPNLayer(
channels=out_channels,
levels=num_outs,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
pool_cfg=pool_cfg))

# add extra conv layers (e.g., RetinaNet)
extra_levels = num_outs - self.backbone_end_level + self.start_level
if add_extra_convs and extra_levels >= 1:
for i in range(extra_levels):
if i == 0:
in_channels = self.in_channels[self.backbone_end_level - 1]
else:
in_channels = out_channels
if in_channels == out_channels:
extra_fpn_conv = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1)
else:
extra_fpn_conv = nn.Sequential(
ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
self.extra_convs.append(extra_fpn_conv)

def forward(self, inputs):

def extra_convs(inputs, extra_convs):
outputs = list()
for extra_conv in extra_convs:
inputs = extra_conv(inputs)
outputs.append(inputs)
return outputs

assert len(inputs) == len(self.in_channels)

# build laterals
laterals = [
lateral_conv(inputs[i + self.start_level])
for i, lateral_conv in enumerate(self.lateral_convs)
]
if self.num_outs > len(laterals) and self.add_extra_convs:
extra_source = inputs[self.backbone_end_level - 1]
for extra_conv in self.extra_convs:
extra_source = extra_conv(extra_source)
laterals.append(extra_source)

for bifpn_module in self.bifpn_convs:
laterals = bifpn_module(laterals)
outs = laterals

return tuple(outs)


def swish(x):
return x * x.sigmoid()


class BiFPNLayer(BaseModule):

def __init__(self,
channels,
levels,
init=0.5,
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
upsample_cfg=None,
pool_cfg=None,
eps=0.0001,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.act_cfg = act_cfg
self.upsample_cfg = upsample_cfg
self.pool_cfg = pool_cfg
self.eps = eps
self.levels = levels
self.bifpn_convs = nn.ModuleList()
# weighted
self.weight_two_nodes = nn.Parameter(
torch.Tensor(2, levels).fill_(init))
self.weight_three_nodes = nn.Parameter(
torch.Tensor(3, levels - 2).fill_(init))
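        # Fast normalized fusion (as in EfficientDet): every fusion node
        # keeps one learnable scalar per incoming edge; forward() passes
        # them through ReLU and divides by their sum plus eps,
        #     w_i = relu(w_i) / (sum_j relu(w_j) + eps),
        # a cheaper alternative to a softmax over the edge weights.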
self.relu = nn.ReLU()
for _ in range(2):
            for _ in range(self.levels - 1):  # one conv per fusion node
fpn_conv = nn.Sequential(
ConvModule(
channels,
channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False))
self.bifpn_convs.append(fpn_conv)

def forward(self, inputs):
assert len(inputs) == self.levels
# build top-down and down-top path with stack
levels = self.levels
# w relu
w1 = self.relu(self.weight_two_nodes)
w1 /= torch.sum(w1, dim=0) + self.eps # normalize
w2 = self.relu(self.weight_three_nodes)
# w2 /= torch.sum(w2, dim=0) + self.eps # normalize
# build top-down
idx_bifpn = 0
pathtd = inputs
inputs_clone = []
for in_tensor in inputs:
inputs_clone.append(in_tensor.clone())

for i in range(levels - 1, 0, -1):
_, _, h, w = pathtd[i - 1].shape
# pathtd[i - 1] = (
# w1[0, i - 1] * pathtd[i - 1] + w1[1, i - 1] *
# F.interpolate(pathtd[i], size=(h, w), mode='nearest')) / (
# w1[0, i - 1] + w1[1, i - 1] + self.eps)
            pathtd[i - 1] = (
                w1[0, i - 1] * pathtd[i - 1] +
                w1[1, i - 1] *
                F.interpolate(pathtd[i], size=(h, w), mode='nearest'))
pathtd[i - 1] = swish(pathtd[i - 1])
pathtd[i - 1] = self.bifpn_convs[idx_bifpn](pathtd[i - 1])
idx_bifpn = idx_bifpn + 1
# build down-top
for i in range(0, levels - 2, 1):
tmp_path = torch.stack([
inputs_clone[i + 1], pathtd[i + 1],
F.max_pool2d(pathtd[i], kernel_size=3, stride=2, padding=1)
],
dim=-1)
norm_weight = w2[:, i] / (w2[:, i].sum() + self.eps)
pathtd[i + 1] = (norm_weight * tmp_path).sum(dim=-1)
# pathtd[i + 1] = w2[0, i] * inputs_clone[i + 1]
# + w2[1, i] * pathtd[
# i + 1] + w2[2, i] * F.max_pool2d(
# pathtd[i], kernel_size=3, stride=2, padding=1)
pathtd[i + 1] = swish(pathtd[i + 1])
pathtd[i + 1] = self.bifpn_convs[idx_bifpn](pathtd[i + 1])
idx_bifpn = idx_bifpn + 1

pathtd[levels - 1] = w1[0, levels - 1] * pathtd[levels - 1] + w1[
1, levels - 1] * F.max_pool2d(
pathtd[levels - 2], kernel_size=3, stride=2, padding=1)
pathtd[levels - 1] = swish(pathtd[levels - 1])
pathtd[levels - 1] = self.bifpn_convs[idx_bifpn](pathtd[levels - 1])
return pathtd
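
Since BiFPN is registered in MODELS, it can be built from a config or used directly. Below is a minimal smoke test; the import path assumes the repository root is on PYTHONPATH, and the channel/level choices are illustrative rather than ABCNet's released config:

import torch

from projects.ABCNet.abcnet.model.bifpn import BiFPN  # path assumed

# Three backbone levels (channels here are an assumption, e.g. C3-C5 of a
# ResNet-50) fused into five 256-channel outputs; the two extra levels come
# from the stride-2 extra convs appended after the last backbone level.
neck = BiFPN(
    in_channels=[512, 1024, 2048],
    out_channels=256,
    num_outs=5,
    repeat_times=2,
    add_extra_convs=True)

feats = [
    torch.randn(1, c, s, s)
    for c, s in zip([512, 1024, 2048], [64, 32, 16])
]
with torch.no_grad():
    outs = neck(feats)
print([tuple(o.shape) for o in outs])
# [(1, 256, 64, 64), (1, 256, 32, 32), (1, 256, 16, 16),
#  (1, 256, 8, 8), (1, 256, 4, 4)]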
56 changes: 56 additions & 0 deletions projects/ABCNet/abcnet/model/coordinate_head.py
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmocr.registry import MODELS


@MODELS.register_module()
class CoordinateHead(BaseModule):

def __init__(self,
in_channel=256,
conv_num=4,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)

mask_convs = list()
for i in range(conv_num):
if i == 0:
mask_conv = ConvModule(
in_channels=in_channel + 2, # 2 for coord
out_channels=in_channel,
kernel_size=3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
else:
mask_conv = ConvModule(
in_channels=in_channel,
out_channels=in_channel,
kernel_size=3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
mask_convs.append(mask_conv)
self.mask_convs = nn.Sequential(*mask_convs)

def forward(self, features):
coord_features = list()
for feature in features:
x_range = torch.linspace(
-1, 1, feature.shape[-1], device=feature.device)
y_range = torch.linspace(
-1, 1, feature.shape[-2], device=feature.device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([feature.shape[0], 1, -1, -1])
x = x.expand([feature.shape[0], 1, -1, -1])
coord = torch.cat([x, y], 1)
feature_with_coord = torch.cat([feature, coord], dim=1)
feature_with_coord = self.mask_convs(feature_with_coord)
feature_with_coord = feature_with_coord + feature
coord_features.append(feature_with_coord)
return coord_features
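
CoordinateHead is a CoordConv-style block: it concatenates two normalized coordinate channels to each feature map, applies conv_num 3x3 convs, and adds the result back to the input, so output shapes match the inputs. A small shape check (import path and sizes assumed, as above):

import torch

from projects.ABCNet.abcnet.model.coordinate_head import CoordinateHead

head = CoordinateHead(in_channel=256, conv_num=4).eval()  # eval: BN uses running stats
feats = [torch.randn(2, 256, 32, 80), torch.randn(2, 256, 16, 40)]
with torch.no_grad():
    outs = head(feats)
print([tuple(o.shape) for o in outs])
# [(2, 256, 32, 80), (2, 256, 16, 40)] -- the two coordinate channels are
# absorbed by the first conv (256 + 2 -> 256 channels).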
6 changes: 5 additions & 1 deletion projects/ABCNet/abcnet/model/only_rec_roi_head.py
@@ -14,13 +14,16 @@ class OnlyRecRoIHead(BaseRoIHead):
"""Simplest base roi head including one bbox head and one mask head."""

def __init__(self,
+                 neck=None,
sampler: OptMultiConfig = None,
roi_extractor: OptMultiConfig = None,
rec_head: OptMultiConfig = None,
init_cfg=None):
super().__init__(init_cfg)
if sampler is not None:
self.sampler = TASK_UTILS.build(sampler)
+        if neck is not None:
+            self.neck = MODELS.build(neck)
self.roi_extractor = MODELS.build(roi_extractor)
self.rec_head = MODELS.build(rec_head)

@@ -44,7 +47,8 @@ def loss(self, inputs: Tuple[Tensor], data_samples: DetSampleList) -> dict:

def predict(self, inputs: Tuple[Tensor],
data_samples: DetSampleList) -> RecSampleList:

+        if hasattr(self, 'neck') and self.neck is not None:
+            inputs = self.neck(inputs)
pred_instances = [ds.pred_instances for ds in data_samples]
bbox_feats = self.roi_extractor(inputs, pred_instances)
if bbox_feats.size(0) == 0:
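
For reference, a hypothetical config fragment showing how the new optional neck could be wired into this RoI head; the key names mirror the diff above, but every value is illustrative rather than taken from the project's released ABCNet v2 config:

roi_head = dict(
    type='OnlyRecRoIHead',
    # New in this commit: an optional neck, built via MODELS.build and
    # applied to the shared features before RoI extraction at predict time.
    neck=dict(type='CoordinateHead', in_channel=256, conv_num=4),
    roi_extractor=dict(type='BezierRoIExtractor'),  # remaining args omitted
    rec_head=dict(type='ABCNetRec'))  # remaining args omitted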