-
Notifications
You must be signed in to change notification settings - Fork 1
/
VGG.py
119 lines (96 loc) · 3.7 KB
/
VGG.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
from torch import nn
from ConvMod import ConvMod
class VGGClf(nn.Module):
    """VGG-style CNN classifier whose feature maps are gated by an externally
    supplied explanation map at every stage.

    Four ConvMod stages (32 -> 64 -> 128 -> 256 channels), each followed by max
    pooling; after each stage the feature maps are multiplied element-wise by a
    progressively downsampled explanation map. A global max pool and two
    sigmoid-activated dense layers produce the class logits.
    """

    def __init__(
        self,
        in_channels=3,
        dec_conv_filter_size=(3, 3),
        dec_pool_size=2,
        img_size=(224, 224),
        dropout=0.3,
        num_classes=2,
    ):
        """__init__ class constructor

        Keyword Arguments:
            in_channels {int} -- number of input channels (default: {3})
            dec_conv_filter_size {tuple} -- kernel size for convolutional layers (default: {(3, 3)})
            dec_pool_size {int} -- pooling factor (default: {2})
            img_size {tuple} -- input image (height, width) (default: {(224, 224)})
            dropout {float} -- dropout rate for dense/fully connected layers (default: {0.3})
            num_classes {int} -- number of classes (default: {2})
        """
        super().__init__()
        self.in_channels = in_channels
        self.dec_conv_filter_size = dec_conv_filter_size
        self.dec_pool_size = dec_pool_size
        self.dropout = dropout
        self.img_size = img_size
        self.num_classes = num_classes
        self.conv_mod0 = ConvMod(self.in_channels, 32, self.dec_conv_filter_size)
        self.conv_mod1 = ConvMod(32, 64, self.dec_conv_filter_size)
        self.conv_mod2 = ConvMod(64, 128, self.dec_conv_filter_size)
        self.conv_mod3 = ConvMod(128, 256, self.dec_conv_filter_size)
        # Spatial extent remaining after the 4 pooling stages; with the
        # defaults this is 224 // 2**4 == 14, matching the previously
        # hard-coded (14, 14), but it now tracks img_size / dec_pool_size.
        ds = self.dec_pool_size ** 4
        self.global_pool = nn.MaxPool2d((self.img_size[0] // ds, self.img_size[1] // ds))
        self.dense0 = nn.Linear(256, 128)
        self.actv0 = nn.Sigmoid()
        self.dropout0 = nn.Dropout(p=self.dropout)
        self.dense1 = nn.Linear(128, 128)
        self.actv1 = nn.Sigmoid()
        self.dropout1 = nn.Dropout(p=self.dropout)
        self.last_dense = nn.Linear(128, self.num_classes)
        self.clf_pool = nn.MaxPool2d(self.dec_pool_size)
        self.exp_pool = nn.AvgPool2d(
            self.dec_pool_size
        )  # explanations are downsampled using average pooling --> see paper for details
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """_init_weights initializes linear layers' weights

        Arguments:
            m {torch.nn.Module} -- submodule visited by ``self.apply``
        """
        # isinstance (not ``type(m) ==``) so nn.Linear subclasses are covered.
        if isinstance(m, nn.Linear):
            if m.bias is not None:  # bias can be disabled on nn.Linear
                m.bias.data.fill_(0.0)
            nn.init.xavier_uniform_(m.weight)

    def forward(self, x, expl):
        """forward forward pass

        Arguments:
            x {torch.Tensor} -- input image batch (B, in_channels, H, W)
            expl {torch.Tensor} -- explanation map; single-channel (B, 1, H, W)
                so it broadcasts over the feature channels (the original
                ``torch.cat`` replication also required a 1-channel map)

        Returns:
            torch.Tensor -- class logits of shape (B, num_classes)
        """
        stages = (self.conv_mod0, self.conv_mod1, self.conv_mod2, self.conv_mod3)
        last_expl = expl
        for i, conv in enumerate(stages):
            # Connection between explainer and classifier: gate the features
            # with the explanation map. Broadcasting the 1-channel map over
            # the channel dim is equivalent to the original
            # torch.cat([last_expl] * C, dim=1) without materializing the
            # replicated tensor.
            x = self.clf_pool(conv(x) * last_expl)
            if i < len(stages) - 1:  # the map is unused after the last stage
                last_expl = self.exp_pool(last_expl)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout0(self.actv0(self.dense0(x)))
        x = self.dropout1(self.actv1(self.dense1(x)))
        return self.last_dense(x)