-
Notifications
You must be signed in to change notification settings - Fork 1
/
faster_rcnn.py
175 lines (148 loc) · 5.79 KB
/
faster_rcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#!/usr/bin/env python
# coding=utf-8
"""
本脚本是Faster R-CNN模块的组装
"""
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import MultiScaleRoIAlign
from utils.generalized_rcnn import GeneralizedRCNN
from utils.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from utils.roi_head import RoIHeads
from utils.transform import GeneralizedRCNNTransform
class FasterRCNN(GeneralizedRCNN):
    """Assemble a complete Faster R-CNN detector from its components.

    The constructor builds a ``GeneralizedRCNNTransform``, a
    ``RegionProposalNetwork`` and ``RoIHeads`` around ``backbone`` (creating a
    default for every component that is not supplied) and passes them to
    ``GeneralizedRCNN``.

    Args:
        backbone: Feature extractor. Must expose an ``out_channels``
            attribute; it sizes the default RPN head and the default box head.
        num_classes: Output size of the default box predictor. Required when
            ``box_predictor`` is None; not used when a predictor is supplied.
        min_size, max_size: Forwarded to ``GeneralizedRCNNTransform``
            (defaults 800 and 1333).
        image_mean, image_std: Forwarded to ``GeneralizedRCNNTransform``.
            Default to ``[0.485, 0.456, 0.406]`` and
            ``[0.229, 0.224, 0.225]``.
        rpn_anchor_generator: ``AnchorGenerator`` or None. The default uses
            one anchor size per level, (32, 64, 128, 256, 512), each with
            aspect ratios (0.5, 1.0, 2.0).
        rpn_head: RPN head module. Defaults to ``RPNHead`` built from
            ``backbone.out_channels`` and the first entry of
            ``rpn_anchor_generator.num_anchors_per_location()``.
        rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test: Packed into a dict
            with keys ``"training"`` / ``"testing"`` and forwarded to
            ``RegionProposalNetwork``.
        rpn_post_nms_top_n_train, rpn_post_nms_top_n_test: Packed the same
            way and forwarded to ``RegionProposalNetwork``.
        rpn_nms_thresh, rpn_fg_iou_thresh, rpn_bg_iou_thresh,
        rpn_batch_size_per_img, rpn_positive_fraction: Forwarded unchanged to
            ``RegionProposalNetwork``.
        box_roi_pool: ``MultiScaleRoIAlign`` or None. The default pools
            feature maps ``"0"``..``"3"`` to ``output_size=7`` with
            ``sampling_ratio=2``.
        box_head: Module applied to the pooled RoI features. Defaults to
            ``TwoMLHead`` with a 1024-wide representation.
        box_predictor: Module producing class scores and box deltas. Defaults
            to ``FastRCNNPredictor(1024, num_classes)``.
        box_fg_iou_thresh, box_bg_iou_thresh, box_batch_size_per_img,
        box_positive_fraction, box_score_thresh, box_nms_thresh,
        box_detections_per_img, bbox_reg_weights: Forwarded unchanged to
            ``RoIHeads``.

    Raises:
        ValueError: If ``backbone`` has no ``out_channels`` attribute, or if
            both ``num_classes`` and ``box_predictor`` are None.
        TypeError: If ``rpn_anchor_generator`` is neither None nor an
            ``AnchorGenerator``, or ``box_roi_pool`` is neither None nor a
            ``MultiScaleRoIAlign``.
    """

    def __init__(self, backbone, num_classes=None,
                 # transform parameters
                 min_size=800, max_size=1333,
                 image_mean=None, image_std=None,
                 # RPN parameters
                 rpn_anchor_generator=None, rpn_head=None,
                 rpn_pre_nms_top_n_train=2000,
                 rpn_pre_nms_top_n_test=1000,
                 rpn_post_nms_top_n_train=2000,
                 rpn_post_nms_top_n_test=1000,
                 rpn_nms_thresh=0.7,
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_img=256,
                 rpn_positive_fraction=0.5,
                 # RoIHeads parameters
                 box_roi_pool=None,
                 box_head=None,
                 box_predictor=None,
                 box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
                 box_batch_size_per_img=512,
                 box_positive_fraction=0.25,
                 box_score_thresh=0.05,
                 box_nms_thresh=0.5,
                 box_detections_per_img=100,
                 bbox_reg_weights=None):
        if not hasattr(backbone, "out_channels"):
            # out_channels is needed below to size the default heads.
            raise ValueError("backbone should have the out_channels attr")
        # Explicit raises instead of ``assert`` so the checks survive
        # ``python -O``.
        if not isinstance(rpn_anchor_generator,
                          (AnchorGenerator, type(None))):
            raise TypeError(
                "rpn_anchor_generator should be of type AnchorGenerator "
                "or None"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                "box_roi_pool should be of type MultiScaleRoIAlign or None"
            )
        if box_predictor is None and num_classes is None:
            # Without this guard the default predictor would fail inside
            # nn.Linear with an unrelated-looking error.
            raise ValueError(
                "num_classes should not be None when box_predictor "
                "is not specified"
            )

        out_channels = backbone.out_channels
        # Width shared by the default box head and default box predictor.
        representation_size = 1024

        # ---- region proposal network ----
        if rpn_anchor_generator is None:
            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            rpn_anchor_generator = AnchorGenerator(anchor_sizes,
                                                   aspect_ratios)
        if rpn_head is None:
            rpn_head = RPNHead(
                out_channels,
                rpn_anchor_generator.num_anchors_per_location()[0]
            )
        rpn_pre_nms_top_n = dict(
            training=rpn_pre_nms_top_n_train,
            testing=rpn_pre_nms_top_n_test
        )
        rpn_post_nms_top_n = dict(
            training=rpn_post_nms_top_n_train,
            testing=rpn_post_nms_top_n_test
        )
        rpn = RegionProposalNetwork(
            rpn_anchor_generator, rpn_head,
            rpn_pre_nms_top_n, rpn_post_nms_top_n,
            rpn_nms_thresh, rpn_fg_iou_thresh, rpn_bg_iou_thresh,
            rpn_batch_size_per_img, rpn_positive_fraction
        )

        # ---- RoI heads ----
        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(
                featmap_names=["0", "1", "2", "3"],
                output_size=7,
                sampling_ratio=2
            )
        if box_head is None:
            # Side length of the pooled RoI feature map.
            resolution = box_roi_pool.output_size[0]
            box_head = TwoMLHead(
                out_channels * resolution ** 2,
                representation_size
            )
        if box_predictor is None:
            box_predictor = FastRCNNPredictor(
                representation_size,
                num_classes=num_classes
            )
        roi_heads = RoIHeads(
            box_roi_pool, box_head, box_predictor,
            box_fg_iou_thresh, box_bg_iou_thresh,
            box_batch_size_per_img, box_positive_fraction,
            box_score_thresh, box_nms_thresh,
            box_detections_per_img,
            bbox_reg_weights
        )

        # ---- input transform ----
        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size,
                                             image_mean, image_std)

        super(FasterRCNN, self).__init__(transform, backbone, rpn, roi_heads)
class TwoMLHead(nn.Module):
    """Two-layer fully connected head for RoI-aligned box features.

    The input is flattened from dimension 1 onward and passed through two
    linear layers, each followed by a ReLU.

    Args:
        in_channels: Number of input features after flattening.
        representation_size: Output width of both linear layers.
    """

    def __init__(self, in_channels, representation_size):
        super(TwoMLHead, self).__init__()
        # Layer names (fc6, fc7) are part of the state-dict keys.
        self.fc6 = nn.Linear(in_features=in_channels,
                             out_features=representation_size)
        self.fc7 = nn.Linear(in_features=representation_size,
                             out_features=representation_size)

    def forward(self, x):
        """Return the ReLU-activated output of fc7 for the flattened ``x``."""
        features = x.flatten(start_dim=1)
        for layer in (self.fc6, self.fc7):
            features = F.relu(layer(features))
        return features
class FastRCNNPredictor(nn.Module):
    """Final classification and box-regression heads.

    Args:
        in_channels: Number of input features per RoI.
        num_classes: Number of class logits to produce; the regression head
            emits ``4 * num_classes`` values.

    ``forward`` returns ``(scores, bbox_deltas)``: the class logits and the
    predicted box offsets.
    """

    def __init__(self, in_channels, num_classes):
        super(FastRCNNPredictor, self).__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, 4 * num_classes)

    def forward(self, x):
        """Compute class logits and box deltas for the features ``x``."""
        # A 4-D input is only accepted when its last two dims are both 1.
        assert x.dim() != 4 or list(x.shape[2:]) == [1, 1]
        flat = x.flatten(start_dim=1)
        scores = self.cls_score(flat)
        bbox_deltas = self.bbox_pred(flat)
        return scores, bbox_deltas
if __name__ == "__main__":
    # Smoke test: build the detector on a ResNet-50 FPN backbone and print
    # its module tree.
    from utils.backbone_utils import resnet_fpn_backbone

    # NOTE(review): the second positional argument is presumably the
    # "pretrained" flag — confirm against utils.backbone_utils.
    backbone = resnet_fpn_backbone("resnet50", True)
    model = FasterRCNN(backbone, num_classes=2)
    print(model)