-
Notifications
You must be signed in to change notification settings - Fork 356
/
frcnn.py
71 lines (66 loc) · 2.46 KB
/
frcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn
from nets.vgg16 import decom_vgg16
from nets.resnet50 import resnet50
from nets.rpn import RegionProposalNetwork
from nets.classifier import VGG16RoIHead,Resnet50RoIHead
import time
import numpy as np
class FasterRCNN(nn.Module):
    """Faster R-CNN model: backbone feature extractor, RPN and RoI head.

    The backbone, RPN and RoI head are built from ``nets.*`` according to
    ``backbone``; ``forward`` chains them together.
    """

    def __init__(self, num_classes,
                 mode="training",
                 loc_normalize_mean=(0., 0., 0., 0.),
                 loc_normalize_std=(0.1, 0.1, 0.2, 0.2),
                 feat_stride=16,
                 anchor_scales=None,
                 ratios=None,
                 backbone='vgg'):
        """Build the extractor, RPN and RoI head for the chosen backbone.

        Args:
            num_classes: number of object classes. The RoI head is built with
                ``n_class=num_classes + 1`` (presumably the extra class is
                background -- confirm against the head implementation).
            mode: forwarded unchanged to ``RegionProposalNetwork``.
            loc_normalize_mean: stored on the instance; not used by this class.
            loc_normalize_std: stored on the instance; not used by this class.
            feat_stride: forwarded to the RPN; ``1 / feat_stride`` is also
                used as the RoI head's ``spatial_scale``.
            anchor_scales: anchor scales forwarded to the RPN. ``None`` (the
                default) means ``[8, 16, 32]``.
            ratios: anchor ratios forwarded to the RPN. ``None`` (the default)
                means ``[0.5, 1, 2]``.
            backbone: ``'vgg'`` or ``'resnet50'``.

        Raises:
            ValueError: if ``backbone`` is not ``'vgg'`` or ``'resnet50'``.
        """
        super(FasterRCNN, self).__init__()
        # Fresh lists per instance: list defaults in the signature would be a
        # single object shared by every model built with the defaults.
        if anchor_scales is None:
            anchor_scales = [8, 16, 32]
        if ratios is None:
            ratios = [0.5, 1, 2]

        self.loc_normalize_mean = loc_normalize_mean
        self.loc_normalize_std = loc_normalize_std
        self.feat_stride = feat_stride

        if backbone == 'vgg':
            self.extractor, classifier = decom_vgg16()
            self.rpn = RegionProposalNetwork(
                512, 512,
                ratios=ratios,
                anchor_scales=anchor_scales,
                feat_stride=self.feat_stride,
                mode=mode,
            )
            self.head = VGG16RoIHead(
                n_class=num_classes + 1,
                roi_size=7,
                spatial_scale=(1. / self.feat_stride),
                classifier=classifier,
            )
        elif backbone == 'resnet50':
            self.extractor, classifier = resnet50()
            # The RPN's first argument differs from the VGG branch (1024 vs
            # 512); presumably the channel count of this extractor's output.
            self.rpn = RegionProposalNetwork(
                1024, 512,
                ratios=ratios,
                anchor_scales=anchor_scales,
                feat_stride=self.feat_stride,
                mode=mode,
            )
            self.head = Resnet50RoIHead(
                n_class=num_classes + 1,
                roi_size=14,
                spatial_scale=(1. / self.feat_stride),
                classifier=classifier,
            )
        else:
            # Previously an unknown backbone left extractor/rpn/head unset and
            # only failed later with an AttributeError inside forward().
            raise ValueError(
                f"Unsupported backbone {backbone!r}; "
                f"expected 'vgg' or 'resnet50'"
            )

    def forward(self, x, scale=1.):
        """Run the extractor, RPN and RoI head on ``x``.

        Args:
            x: input image batch; ``x.shape[2:]`` is passed to the RPN as the
                image size (assumes NCHW layout -- confirm with callers).
            scale: forwarded unchanged to the RPN.

        Returns:
            ``(roi_cls_locs, roi_scores, rois, roi_indices)``: the two outputs
            of the RoI head, followed by the ``rois`` and ``roi_indices``
            produced by the RPN.
        """
        img_size = x.shape[2:]
        h = self.extractor(x)
        # Invoke submodules via __call__ (not .forward) so registered hooks
        # run. The RPN's own locs/scores and anchors are not needed here.
        _, _, rois, roi_indices, _ = self.rpn(h, img_size, scale)
        roi_cls_locs, roi_scores = self.head(h, rois, roi_indices)
        return roi_cls_locs, roi_scores, rois, roi_indices

    def freeze_bn(self):
        """Switch every ``nn.BatchNorm2d`` submodule to eval mode.

        Only the mode changes; ``requires_grad`` on the BatchNorm parameters
        is left untouched. A later ``self.train()`` call puts these layers
        back in training mode, so call this after ``train()``.
        """
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()