/
squeeze_net.py
76 lines (72 loc) · 2.68 KB
/
squeeze_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from torch import nn
class FireModule(nn.Module):
    """Fire block: a 1x1 "squeeze" convolution feeding parallel 1x1 and 3x3 "expand" convolutions.

    Every convolution is followed by batch normalisation and a ReLU. Both
    expand branches preserve the spatial size (the 3x3 branch is padded by 1),
    and their outputs are concatenated along the channel dimension, so the
    block emits ``output_e_1 + output_e_3`` channels.

    Args:
        input_s: number of channels of the incoming tensor.
        output_s: number of channels produced by the squeeze convolution.
        output_e_1: number of channels produced by the 1x1 expand branch.
        output_e_3: number of channels produced by the 3x3 expand branch.
    """

    def __init__(self, input_s, output_s, output_e_1, output_e_3):
        super().__init__()
        # Submodule names and registration order define the state_dict keys
        # and the order in which parameters are randomly initialised; keep
        # them stable so saved weights and seeded runs stay compatible.
        self.conv_s = nn.Conv2d(input_s, output_s, kernel_size=1)
        self.batchnorm1 = nn.BatchNorm2d(output_s)
        self.conv_e1 = nn.Conv2d(output_s, output_e_1, kernel_size=1)
        self.batchnorm2 = nn.BatchNorm2d(output_e_1)
        self.conv_e3 = nn.Conv2d(output_s, output_e_3, kernel_size=3, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(output_e_3)
        # ReLU is stateless, so one instance is shared by all three stages.
        self.relu = nn.ReLU()

    def forward(self, s):
        """Squeeze ``s``, run both expand branches, and concatenate them on dim 1."""
        squeezed = self.relu(self.batchnorm1(self.conv_s(s)))
        # The 1x1 branch comes first, then the 3x3 branch; this fixes the
        # channel order of the concatenated output.
        expand_stages = (
            (self.conv_e1, self.batchnorm2),
            (self.conv_e3, self.batchnorm3),
        )
        outputs = [self.relu(norm(conv(squeezed))) for conv, norm in expand_stages]
        return torch.cat(outputs, dim=1)
class SqueezeNet(nn.Module):
    """SqueezeNet-style network with bypass (residual) connections and a single sigmoid output.

    A 7x7 stem convolution is followed by ten ``FireModule`` stages interleaved
    with max-pooling. Several stages are wrapped in additive skip connections.
    A final 1x1 convolution maps the 128 feature channels to one channel. A
    sigmoid is applied at every spatial location, and the resulting map is
    average-pooled with a fixed 2x2 window.

    The fixed pooling geometry fits 3x224x224 inputs. For that size the
    feature map entering ``pool6`` is 2x2, so each sample yields exactly one
    score in ``(0, 1)``. Input sizes far from 224 may produce several scores
    per sample, or make a pooling layer fail with an empty output.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the parameter-initialisation order.

        # Stem: 3 -> 96 channels, stride 2 (224 -> 112 for a 224 input).
        self.conv1 = nn.Conv2d(3, 96, (7, 7), stride=2, padding=3)
        self.batchnorm1 = nn.BatchNorm2d(96)
        self.pool1 = nn.MaxPool2d((3, 3), stride=2, padding=0)      # 112 -> 55
        # Channel flow: FireModule outputs output_e_1 + output_e_3 channels.
        self.fire1 = FireModule(96, 16, 64, 64)      # 96 -> 128
        self.fire2 = FireModule(128, 16, 64, 64)     # 128 -> 128
        self.fire3 = FireModule(128, 32, 96, 96)     # 128 -> 192
        self.pool2 = nn.MaxPool2d((3, 3), stride=2, padding=0)      # 55 -> 27
        self.fire4 = FireModule(192, 48, 96, 96)     # 192 -> 192
        self.fire5 = FireModule(192, 48, 96, 96)     # 192 -> 192
        self.fire6 = FireModule(192, 48, 96, 96)     # 192 -> 192
        self.fire7 = FireModule(192, 48, 64, 64)     # 192 -> 128
        self.pool3 = nn.MaxPool2d((3, 3), stride=2, padding=0)      # 27 -> 13
        self.fire8 = FireModule(128, 48, 64, 64)     # 128 -> 128
        self.pool4 = nn.MaxPool2d((3, 3), stride=2, padding=0)      # 13 -> 6
        self.fire9 = FireModule(128, 48, 64, 64)     # 128 -> 128
        self.pool5 = nn.MaxPool2d((3, 3), stride=3, padding=0)      # 6 -> 2
        self.fire10 = FireModule(128, 48, 64, 64)    # 128 -> 128
        self.conv2 = nn.Conv2d(128, 1, (1, 1), padding=0)           # 128 -> 1
        self.pool6 = nn.AvgPool2d((2, 2))                           # 2 -> 1
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()        # used only after the stem
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: input batch of shape ``(N, 3, H, W)``. It must be 4-D because
                ``BatchNorm2d`` rejects unbatched input. ``H = W = 224``
                yields one score per sample.

        Returns:
            Tensor of shape ``(N,)`` when the pooled map is 1x1 (e.g. a 224x224
            input), including ``N == 1``. Otherwise the shape is ``(N, k)``,
            with one averaged sigmoid score per remaining location.
        """
        # Stem.
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.relu(x)
        x = self.pool1(x)

        x1 = self.fire1(x)
        x2 = self.fire2(x1)
        x3 = self.fire3(x2 + x1)       # bypass around fire2
        x = self.pool2(x3)

        x4 = self.fire4(x)
        x5 = self.fire5(x + x4)        # bypass around fire4
        x6 = self.fire6(x5)
        x7 = self.fire7(x6 + x5)       # bypass around fire6
        x = self.pool3(x7)

        x8 = self.fire8(x)
        x = self.pool4(x + x8)         # bypass around fire8
        x9 = self.fire9(x)
        x = self.pool5(x + x9)         # bypass around fire9
        x10 = self.fire10(x)
        x = self.conv2(x + x10)        # bypass around fire10

        # NOTE(review): the sigmoid is applied before pooling, so pool6
        # averages per-location probabilities rather than logits. Confirm
        # this is intended.
        x = self.sigmoid(x)
        x = self.pool6(x)
        x = self.flatten(x)
        # Squeeze only the feature dimension. A bare squeeze() would also drop
        # the batch dimension when N == 1 and return a 0-d tensor, which no
        # longer matches a (1,)-shaped target in size-checked losses.
        return torch.squeeze(x, dim=1)