Commit 0cf9be4 (parent e096df8): 7 changed files with 454 additions and 4 deletions.
@@ -0,0 +1,242 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmocr.registry import MODELS
from mmocr.utils import ConfigType, MultiConfig, OptConfigType


@MODELS.register_module()
class BiFPN(BaseModule):
    """Illustration of a minimal BiFPN unit:

        P7_0 -------------------------> P7_2 -------->
           |-------------|                ↑
                         ↓                |
        P6_0 ---------> P6_1 ---------> P6_2 -------->
           |-------------|--------------↑ ↑
                         ↓                |
        P5_0 ---------> P5_1 ---------> P5_2 -------->
           |-------------|--------------↑ ↑
                         ↓                |
        P4_0 ---------> P4_1 ---------> P4_2 -------->
           |-------------|--------------↑ ↑
                         |--------------↓ |
        P3_0 -------------------------> P3_2 -------->
    """

    def __init__(self,
                 in_channels: List[int],
                 out_channels: int,
                 num_outs: int,
                 repeat_times: int = 2,
                 start_level: int = 0,
                 end_level: int = -1,
                 add_extra_convs: bool = False,
                 relu_before_extra_convs: bool = False,
                 no_norm_on_lateral: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 act_cfg: OptConfigType = None,
                 laterial_conv1x1: bool = False,
                 upsample_cfg: ConfigType = dict(mode='nearest'),
                 pool_cfg: ConfigType = dict(),
                 init_cfg: MultiConfig = dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.upsample_cfg = upsample_cfg.copy()
        self.repeat_times = repeat_times
        if end_level == -1 or end_level == self.num_ins - 1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # If end_level is not the last level, no extra level is allowed.
            self.backbone_end_level = end_level + 1
            assert end_level < self.num_ins
            assert num_outs == end_level - start_level + 1
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        self.lateral_convs = nn.ModuleList()
        self.extra_convs = nn.ModuleList()
        self.bifpn_convs = nn.ModuleList()
        # 1x1 lateral convs project each backbone level to `out_channels`.
        for i in range(self.start_level, self.backbone_end_level):
            if in_channels[i] == out_channels:
                l_conv = nn.Identity()
            else:
                l_conv = ConvModule(
                    in_channels[i],
                    out_channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    act_cfg=act_cfg,
                    inplace=False)
            self.lateral_convs.append(l_conv)

        # Stack `repeat_times` bidirectional fusion layers.
        for _ in range(repeat_times):
            self.bifpn_convs.append(
                BiFPNLayer(
                    channels=out_channels,
                    levels=num_outs,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    pool_cfg=pool_cfg))

        # Add extra conv layers (e.g., RetinaNet).
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        if add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):
                if i == 0:
                    in_channels = self.in_channels[self.backbone_end_level - 1]
                else:
                    in_channels = out_channels
                if in_channels == out_channels:
                    extra_fpn_conv = nn.MaxPool2d(
                        kernel_size=3, stride=2, padding=1)
                else:
                    extra_fpn_conv = nn.Sequential(
                        ConvModule(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=1,
                            norm_cfg=norm_cfg,
                            act_cfg=act_cfg),
                        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
                self.extra_convs.append(extra_fpn_conv)

    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # Build laterals.
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        # Append one extra level per extra conv, if requested.
        if self.num_outs > len(laterals) and self.add_extra_convs:
            extra_source = inputs[self.backbone_end_level - 1]
            for extra_conv in self.extra_convs:
                extra_source = extra_conv(extra_source)
                laterals.append(extra_source)

        for bifpn_module in self.bifpn_convs:
            laterals = bifpn_module(laterals)
        outs = laterals

        return tuple(outs)


def swish(x):
    """Swish activation: x * sigmoid(x)."""
    return x * x.sigmoid()


class BiFPNLayer(BaseModule):
    """One bidirectional (top-down then bottom-up) feature fusion pass."""

    def __init__(self,
                 channels,
                 levels,
                 init=0.5,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=None,
                 pool_cfg=None,
                 eps=0.0001,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.act_cfg = act_cfg
        self.upsample_cfg = upsample_cfg
        self.pool_cfg = pool_cfg
        self.eps = eps
        self.levels = levels
        self.bifpn_convs = nn.ModuleList()
        # Learnable fusion weights: two-input nodes (top-down path and the
        # coarsest bottom-up node) and three-input nodes (intermediate
        # levels on the bottom-up path).
        self.weight_two_nodes = nn.Parameter(
            torch.Tensor(2, levels).fill_(init))
        self.weight_three_nodes = nn.Parameter(
            torch.Tensor(3, levels - 2).fill_(init))
        self.relu = nn.ReLU()
        # One 3x3 conv per fused node: (levels - 1) on each of the two paths.
        for _ in range(2):
            for _ in range(self.levels - 1):
                fpn_conv = nn.Sequential(
                    ConvModule(
                        channels,
                        channels,
                        3,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        inplace=False))
                self.bifpn_convs.append(fpn_conv)

    def forward(self, inputs):
        assert len(inputs) == self.levels
        levels = self.levels
        # Normalize the two-input fusion weights once; the three-input
        # weights are normalized per node inside the bottom-up loop.
        w1 = self.relu(self.weight_two_nodes)
        w1 /= torch.sum(w1, dim=0) + self.eps
        w2 = self.relu(self.weight_three_nodes)

        idx_bifpn = 0
        pathtd = inputs
        # Keep the original inputs for the three-input skip connections.
        inputs_clone = []
        for in_tensor in inputs:
            inputs_clone.append(in_tensor.clone())

        # Top-down path: upsample the coarser level and fuse two inputs.
        for i in range(levels - 1, 0, -1):
            _, _, h, w = pathtd[i - 1].shape
            pathtd[i - 1] = (
                w1[0, i - 1] * pathtd[i - 1] +
                w1[1, i - 1] * F.interpolate(
                    pathtd[i], size=(h, w), mode='nearest'))
            pathtd[i - 1] = swish(pathtd[i - 1])
            pathtd[i - 1] = self.bifpn_convs[idx_bifpn](pathtd[i - 1])
            idx_bifpn = idx_bifpn + 1

        # Bottom-up path: downsample the finer level and fuse three inputs
        # (skip connection, top-down output, downsampled lower level).
        for i in range(0, levels - 2, 1):
            tmp_path = torch.stack([
                inputs_clone[i + 1], pathtd[i + 1],
                F.max_pool2d(pathtd[i], kernel_size=3, stride=2, padding=1)
            ],
                                   dim=-1)
            norm_weight = w2[:, i] / (w2[:, i].sum() + self.eps)
            pathtd[i + 1] = (norm_weight * tmp_path).sum(dim=-1)
            pathtd[i + 1] = swish(pathtd[i + 1])
            pathtd[i + 1] = self.bifpn_convs[idx_bifpn](pathtd[i + 1])
            idx_bifpn = idx_bifpn + 1

        # The coarsest level is a two-input node on the bottom-up path.
        pathtd[levels - 1] = (
            w1[0, levels - 1] * pathtd[levels - 1] +
            w1[1, levels - 1] * F.max_pool2d(
                pathtd[levels - 2], kernel_size=3, stride=2, padding=1))
        pathtd[levels - 1] = swish(pathtd[levels - 1])
        pathtd[levels - 1] = self.bifpn_convs[idx_bifpn](pathtd[levels - 1])
        return pathtd
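
A minimal shape check for the neck above. This is a sketch, not part of the commit: the batch size, channel counts, and feature-map sizes are illustrative assumptions, and it assumes BiFPN from this file is importable.

    import torch

    # Hypothetical 4-level backbone outputs (channels and sizes are made up).
    in_channels = [256, 512, 1024, 2048]
    feats = [
        torch.rand(1, c, 64 // 2**i, 64 // 2**i)
        for i, c in enumerate(in_channels)
    ]

    neck = BiFPN(
        in_channels=in_channels,
        out_channels=256,
        num_outs=5,
        repeat_times=2,
        add_extra_convs=True)  # 5th level comes from an extra conv + maxpool
    outs = neck(feats)
    print([tuple(o.shape) for o in outs])
    # Expect 5 maps, all with 256 channels, spatial size halving per level.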
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmocr.registry import MODELS


@MODELS.register_module()
class CoordinateHead(BaseModule):
    """Appends normalized x/y coordinate maps to each feature (CoordConv
    style) and refines the result with a small conv stack plus a residual
    connection."""

    def __init__(self,
                 in_channel=256,
                 conv_num=4,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        mask_convs = list()
        for i in range(conv_num):
            if i == 0:
                mask_conv = ConvModule(
                    in_channels=in_channel + 2,  # +2 for the coord channels
                    out_channels=in_channel,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            else:
                mask_conv = ConvModule(
                    in_channels=in_channel,
                    out_channels=in_channel,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            mask_convs.append(mask_conv)
        self.mask_convs = nn.Sequential(*mask_convs)

    def forward(self, features):
        coord_features = list()
        for feature in features:
            # Normalized coordinate grids in [-1, 1], matching H x W.
            x_range = torch.linspace(
                -1, 1, feature.shape[-1], device=feature.device)
            y_range = torch.linspace(
                -1, 1, feature.shape[-2], device=feature.device)
            y, x = torch.meshgrid(y_range, x_range)
            y = y.expand([feature.shape[0], 1, -1, -1])
            x = x.expand([feature.shape[0], 1, -1, -1])
            coord = torch.cat([x, y], 1)
            # Concatenate coords, refine, and add back the input feature.
            feature_with_coord = torch.cat([feature, coord], dim=1)
            feature_with_coord = self.mask_convs(feature_with_coord)
            feature_with_coord = feature_with_coord + feature
            coord_features.append(feature_with_coord)
        return coord_features
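
A quick sanity check for the head above; again a sketch, with arbitrary batch size and feature shapes assumed for illustration.

    import torch

    head = CoordinateHead(in_channel=256, conv_num=4)
    feats = [torch.rand(2, 256, 32, 32), torch.rand(2, 256, 16, 16)]
    outs = head(feats)
    # The residual design keeps each output the same shape as its input.
    assert all(o.shape == f.shape for o, f in zip(outs, feats))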