-
Notifications
You must be signed in to change notification settings - Fork 83
/
PCBModel.py
63 lines (53 loc) · 1.81 KB
/
PCBModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from .resnet import resnet50
class PCBModel(nn.Module):
    """Stripe-based model built on the ``resnet50`` backbone.

    The backbone feature map is cut into ``num_stripes`` equal-height
    horizontal stripes. Each stripe is average-pooled to 1x1 and passed
    through a shared 1x1 conv -> BatchNorm -> ReLU. When ``num_classes > 0``,
    each stripe feature is also fed to its own linear classifier.
    """

    def __init__(
        self,
        last_conv_stride=1,
        num_stripes=6,
        local_conv_out_channels=256,
        num_classes=0
    ):
        """
        Args:
            last_conv_stride: forwarded to ``resnet50`` as ``last_conv_stride``.
            num_stripes: number of horizontal stripes the feature map is split
                into; the backbone's output height must be divisible by it.
            local_conv_out_channels: output channels of the shared 1x1 conv,
                i.e. the size ``c`` of every per-stripe feature.
            num_classes: when > 0, one ``nn.Linear(c, num_classes)`` classifier
                per stripe is created in ``self.fc_list``; otherwise no
                ``fc_list`` attribute exists and ``forward`` returns features only.
        """
        super(PCBModel, self).__init__()
        # The backbone is requested with pretrained=True; local_conv below
        # expects its output feature map to have 2048 channels.
        self.base = resnet50(pretrained=True, last_conv_stride=last_conv_stride)
        self.local_conv = nn.Conv2d(2048, local_conv_out_channels, 1)
        self.local_bn = nn.BatchNorm2d(local_conv_out_channels)
        self.local_relu = nn.ReLU(inplace=True)
        self.num_stripes = num_stripes

        if num_classes > 0:
            self.fc_list = nn.ModuleList()
            for _ in range(num_stripes):
                fc = nn.Linear(local_conv_out_channels, num_classes)
                # Use the in-place initialisers; the non-underscore
                # init.normal / init.constant are deprecated aliases.
                init.normal_(fc.weight, std=0.001)
                init.constant_(fc.bias, 0)
                self.fc_list.append(fc)

    def forward(self, x):
        """Compute per-stripe features and, if classifiers exist, logits.

        Args:
            x: input batch passed directly to ``self.base``.

        Returns:
            ``local_feat_list`` when the model has no classifiers, otherwise the
            tuple ``(local_feat_list, logits_list)`` where
              - local_feat_list: ``num_stripes`` tensors, each of shape [N, c]
              - logits_list: ``num_stripes`` tensors, each of shape
                [N, num_classes]

        Raises:
            ValueError: if the backbone feature-map height is not divisible by
                ``num_stripes`` (explicit check so it survives ``python -O``).
        """
        # shape [N, C, H, W]
        feat = self.base(x)
        height = feat.size(2)
        if height % self.num_stripes != 0:
            raise ValueError(
                'feature map height {} is not divisible by num_stripes {}'.format(
                    height, self.num_stripes))
        stripe_h = height // self.num_stripes
        # Classifiers exist only when the model was built with num_classes > 0.
        has_classifiers = hasattr(self, 'fc_list')

        local_feat_list = []
        logits_list = []
        for i in range(self.num_stripes):
            # Pool stripe i over its full height and width -> shape [N, C, 1, 1]
            local_feat = F.avg_pool2d(
                feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],
                (stripe_h, feat.size(-1)))
            # shape [N, c, 1, 1]
            local_feat = self.local_relu(self.local_bn(self.local_conv(local_feat)))
            # shape [N, c]
            local_feat = local_feat.view(local_feat.size(0), -1)
            local_feat_list.append(local_feat)
            if has_classifiers:
                logits_list.append(self.fc_list[i](local_feat))

        if has_classifiers:
            return local_feat_list, logits_list
        return local_feat_list