/
squeezenet.py
178 lines (149 loc) · 6.45 KB
/
squeezenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""
MindSpore implementation of `SqueezeNet`.
Refer to the paper "SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size" (https://arxiv.org/abs/1602.07360).
"""
import mindspore.common.initializer as init
from mindspore import Tensor, nn, ops
from .helpers import load_pretrained
from .layers.compatibility import Dropout
from .layers.pooling import GlobalAvgPooling
from .registry import register_model
# Public API of this module: the model class and its two registered factory functions.
__all__ = ["SqueezeNet", "squeezenet1_0", "squeezenet1_1"]
def _cfg(url="", **kwargs):
return {
"url": url,
"num_classes": 1000,
"first_conv": "features.0",
"classifier": "classifier.1",
**kwargs,
}
# Per-variant configs (incl. pretrained checkpoint URL), keyed by the factory-function name
# that looks them up below.
default_cfgs = dict(
    squeezenet1_0=_cfg(url="https://download.mindspore.cn/toolkits/mindcv/squeezenet/squeezenet1_0-eb911778.ckpt"),
    squeezenet1_1=_cfg(url="https://download.mindspore.cn/toolkits/mindcv/squeezenet/squeezenet1_1-da256d3a.ckpt"),
)
class Fire(nn.Cell):
    """Fire module, the basic building block of SqueezeNet.

    A 1x1 "squeeze" convolution (+ReLU) reduces the input to ``squeeze_channels``;
    its output then feeds two parallel "expand" branches, a 1x1 conv and a 3x3 conv
    (``pad_mode="same"``), each followed by ReLU. The branch outputs are concatenated
    along axis 1 (the channel axis, assuming NCHW input), so the module produces
    ``expand1x1_channels + expand3x3_channels`` channels.
    """

    def __init__(
        self,
        in_channels: int,
        squeeze_channels: int,
        expand1x1_channels: int,
        expand3x3_channels: int,
    ) -> None:
        super().__init__()
        # Attribute names/creation order are kept stable: they determine parameter names
        # used by checkpoints.
        self.squeeze = nn.Conv2d(
            in_channels=in_channels, out_channels=squeeze_channels, kernel_size=1, has_bias=True
        )
        self.squeeze_activation = nn.ReLU()
        self.expand1x1 = nn.Conv2d(
            in_channels=squeeze_channels, out_channels=expand1x1_channels, kernel_size=1, has_bias=True
        )
        self.expand1x1_activation = nn.ReLU()
        self.expand3x3 = nn.Conv2d(
            in_channels=squeeze_channels,
            out_channels=expand3x3_channels,
            kernel_size=3,
            pad_mode="same",
            has_bias=True,
        )
        self.expand3x3_activation = nn.ReLU()

    def construct(self, x: Tensor) -> Tensor:
        squeezed = self.squeeze_activation(self.squeeze(x))
        branch_1x1 = self.expand1x1_activation(self.expand1x1(squeezed))
        branch_3x3 = self.expand3x3_activation(self.expand3x3(squeezed))
        return ops.concat((branch_1x1, branch_3x3), axis=1)
class SqueezeNet(nn.Cell):
    r"""SqueezeNet model class, based on
    `"SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size" <https://arxiv.org/abs/1602.07360>`_  # noqa: E501

    The network is a convolutional feature extractor (``features``: a stem conv, max-pools
    and eight :class:`Fire` modules) followed by a fully convolutional head (``classifier``:
    dropout, a 1x1 conv to ``num_classes`` channels, ReLU and global average pooling).

    .. note::
        **Important**: In contrast to most other models, SqueezeNet expects tensors with a size of
        N x 3 x 227 x 227, so ensure your images are sized accordingly.

    Args:
        version: version of the architecture, '1_0' or '1_1'. Default: '1_0'.
        num_classes: number of classification classes. Default: 1000.
        drop_rate: dropout rate of the classifier. Default: 0.5.
        in_channels: number of channels of the input. Default: 3.

    Raises:
        ValueError: if ``version`` is neither '1_0' nor '1_1'.
    """

    def __init__(
        self,
        version: str = "1_0",
        num_classes: int = 1000,
        drop_rate: float = 0.5,
        in_channels: int = 3,
    ) -> None:
        super().__init__()
        if version == "1_0":
            # v1.0: 7x7 stride-2 stem with 96 filters, 8 Fire modules, 3 max-pool stages.
            self.features = nn.SequentialCell([
                nn.Conv2d(in_channels, 96, kernel_size=7, stride=2, pad_mode="valid", has_bias=True),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2),
                Fire(512, 64, 256, 256),
            ])
        elif version == "1_1":
            # v1.1: lighter 3x3 stride-2 stem with 64 filters; pools are placed earlier,
            # so more Fire modules run at lower resolution.
            self.features = nn.SequentialCell([
                nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1, pad_mode="pad", has_bias=True),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            ])
        else:
            raise ValueError(f"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected")

        # The last Fire module outputs 256 + 256 = 512 channels; a 1x1 conv maps them to class
        # scores. It is kept as an attribute so _initialize_weights can recognise it by identity.
        self.final_conv = nn.Conv2d(512, num_classes, kernel_size=1, has_bias=True)
        self.classifier = nn.SequentialCell([
            Dropout(p=drop_rate),
            self.final_conv,
            nn.ReLU(),
            GlobalAvgPooling()
        ])
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize weights for cells.

        The final classifier conv gets ``init.Normal()`` (default arguments); every other
        ``nn.Conv2d`` gets ``init.HeUniform()``. All conv biases are set to zeros.
        """
        for _, cell in self.cells_and_names():
            if not isinstance(cell, nn.Conv2d):
                continue
            weight_init = init.Normal() if cell is self.final_conv else init.HeUniform()
            cell.weight.set_data(init.initializer(weight_init, cell.weight.shape, cell.weight.dtype))
            if cell.bias is not None:
                cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype))

    def forward_features(self, x: Tensor) -> Tensor:
        """Run the convolutional feature extractor (``self.features``)."""
        x = self.features(x)
        return x

    def forward_head(self, x: Tensor) -> Tensor:
        """Map extracted features to class scores via ``self.classifier``."""
        x = self.classifier(x)
        return x

    def construct(self, x: Tensor) -> Tensor:
        """Full forward pass: features followed by the classification head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
@register_model
def squeezenet1_0(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> SqueezeNet:
    """Build a SqueezeNet v1.0 model.

    Extra keyword arguments are forwarded to `models.SqueezeNet` (e.g. ``drop_rate``).
    When ``pretrained`` is true, weights are loaded with ``load_pretrained`` from the
    ``"squeezenet1_0"`` entry of ``default_cfgs``, passing ``num_classes`` and
    ``in_channels`` through (presumably so the loader can adapt the head / stem when they
    differ from the checkpoint; confirm in helpers). Refer to the base class
    `models.SqueezeNet` for more details.
    """
    cfg = default_cfgs["squeezenet1_0"]
    net = SqueezeNet(version="1_0", num_classes=num_classes, in_channels=in_channels, **kwargs)
    if not pretrained:
        return net
    load_pretrained(net, cfg, num_classes=num_classes, in_channels=in_channels)
    return net
@register_model
def squeezenet1_1(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> SqueezeNet:
    """Build a SqueezeNet v1.1 model.

    Extra keyword arguments are forwarded to `models.SqueezeNet` (e.g. ``drop_rate``).
    When ``pretrained`` is true, weights are loaded with ``load_pretrained`` from the
    ``"squeezenet1_1"`` entry of ``default_cfgs``, passing ``num_classes`` and
    ``in_channels`` through (presumably so the loader can adapt the head / stem when they
    differ from the checkpoint; confirm in helpers). Refer to the base class
    `models.SqueezeNet` for more details.
    """
    cfg = default_cfgs["squeezenet1_1"]
    net = SqueezeNet(version="1_1", num_classes=num_classes, in_channels=in_channels, **kwargs)
    if not pretrained:
        return net
    load_pretrained(net, cfg, num_classes=num_classes, in_channels=in_channels)
    return net