/
densenet.py
278 lines (229 loc) · 10.2 KB
/
densenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
"""
MindSpore implementation of `DenseNet`.
Refer to: Densely Connected Convolutional Networks
"""
import math
from collections import OrderedDict
from typing import Tuple
import mindspore.common.initializer as init
from mindspore import Tensor, nn, ops
from .helpers import load_pretrained
from .layers.compatibility import Dropout
from .layers.pooling import GlobalAvgPooling
from .registry import register_model
# Public API: the model class plus its registered factory functions.
__all__ = ["DenseNet", "densenet121", "densenet161", "densenet169", "densenet201"]
def _cfg(url="", **kwargs):
return {
"url": url,
"num_classes": 1000,
"first_conv": "features.conv0",
"classifier": "classifier",
**kwargs,
}
# Per-variant default configuration (pretrained checkpoint URL plus the
# defaults supplied by `_cfg`), looked up by the factory functions below.
default_cfgs = dict(
    densenet121=_cfg(url="https://download.mindspore.cn/toolkits/mindcv/densenet/densenet121-120_5004_Ascend.ckpt"),
    densenet169=_cfg(url="https://download.mindspore.cn/toolkits/mindcv/densenet/densenet169-120_5004_Ascend.ckpt"),
    densenet201=_cfg(url="https://download.mindspore.cn/toolkits/mindcv/densenet/densenet201-120_5004_Ascend.ckpt"),
    densenet161=_cfg(url="https://download.mindspore.cn/toolkits/mindcv/densenet/densenet161-120_5004_Ascend.ckpt"),
)
class _DenseLayer(nn.Cell):
    """Bottleneck unit of a dense block: BN-ReLU-Conv1x1, then BN-ReLU-Conv3x3.

    The 1x1 convolution maps the input to ``bn_size * growth_rate`` channels and
    the 3x3 convolution (padding 1, stride 1) maps those to ``growth_rate``
    channels. Only the new features are returned; the enclosing block performs
    the concatenation.

    Args:
        num_input_features: channel count of the incoming features.
        growth_rate: number of channels produced by this layer.
        bn_size: width multiplier of the 1x1 bottleneck convolution.
        drop_rate: dropout probability on the output; 0 skips dropout.
    """

    def __init__(
        self,
        num_input_features: int,
        growth_rate: int,
        bn_size: int,
        drop_rate: float,
    ) -> None:
        super().__init__()
        bottleneck_width = bn_size * growth_rate
        # Attribute names feed parameter names (cf. "features.conv0" in the
        # default cfg), and creation order affects weight init order, so both
        # are kept stable for checkpoint compatibility.
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(num_input_features, bottleneck_width, kernel_size=1, stride=1)
        self.norm2 = nn.BatchNorm2d(bottleneck_width)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(bottleneck_width, growth_rate, kernel_size=3, stride=1, pad_mode="pad", padding=1)
        self.drop_rate = drop_rate
        self.dropout = Dropout(p=self.drop_rate)

    def construct(self, features: Tensor) -> Tensor:
        x = self.norm1(features)
        x = self.relu1(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv2(x)
        # Dropout is bypassed entirely when disabled.
        if self.drop_rate > 0.0:
            x = self.dropout(x)
        return x
class _DenseBlock(nn.Cell):
    """Stack of densely connected `_DenseLayer`s.

    Each layer receives the block input concatenated (along axis 1) with the
    outputs of all preceding layers, so the block's output has
    ``num_input_features + num_layers * growth_rate`` channels.

    Args:
        num_layers: number of `_DenseLayer`s in the block.
        num_input_features: channel count of the block input.
        bn_size: bottleneck width multiplier passed to every layer.
        growth_rate: channels added by each layer.
        drop_rate: dropout probability passed to every layer.
    """

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
    ) -> None:
        super().__init__()
        self.cell_list = nn.CellList()
        for layer_idx in range(num_layers):
            # Layer ``layer_idx`` sees the input plus ``layer_idx`` earlier outputs.
            layer_in_channels = num_input_features + layer_idx * growth_rate
            self.cell_list.append(
                _DenseLayer(
                    num_input_features=layer_in_channels,
                    growth_rate=growth_rate,
                    bn_size=bn_size,
                    drop_rate=drop_rate,
                )
            )

    def construct(self, init_features: Tensor) -> Tensor:
        out = init_features
        for dense_layer in self.cell_list:
            # Append this layer's new maps to the running feature tensor.
            out = ops.concat((out, dense_layer(out)), axis=1)
        return out
class _Transition(nn.Cell):
    """Transition between two adjacent dense blocks.

    Applies BN, ReLU, a 1x1 convolution to ``num_output_features`` channels,
    then 2x2 average pooling with stride 2.

    Args:
        num_input_features: channel count produced by the preceding block.
        num_output_features: channel count after the 1x1 convolution.
    """

    def __init__(
        self,
        num_input_features: int,
        num_output_features: int,
    ) -> None:
        super().__init__()
        # Keys become parameter-name components (features.norm, features.conv, ...).
        stages = OrderedDict()
        stages["norm"] = nn.BatchNorm2d(num_input_features)
        stages["relu"] = nn.ReLU()
        stages["conv"] = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1)
        stages["pool"] = nn.AvgPool2d(kernel_size=2, stride=2)
        self.features = nn.SequentialCell(stages)

    def construct(self, x: Tensor) -> Tensor:
        return self.features(x)
class DenseNet(nn.Cell):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Layout: a stem (7x7 stride-2 conv, BN, ReLU, 1-pixel zero pad + 3x3 stride-2
    max-pool), one dense block per entry of ``block_config`` with a transition
    layer that halves the channel count between consecutive blocks, a final
    BN + ReLU, global average pooling and a linear classifier.

    Args:
        growth_rate: how many filters to add each layer (`k` in paper). Default: 32.
        block_config: how many layers in each dense block. Any number of blocks
            is accepted; a transition layer is inserted between each adjacent
            pair. Default: (6, 12, 24, 16).
        num_init_features: number of filters in the first Conv2d. Default: 64.
        bn_size: multiplicative factor for number of bottleneck layers
            (i.e. bn_size * k features in the bottleneck layer). Default: 4.
        drop_rate: dropout rate after each dense layer. Default: 0.
        in_channels: number of input channels. Default: 3.
        num_classes: number of classification classes. Default: 1000.
    """

    def __init__(
        self,
        growth_rate: int = 32,
        block_config: Tuple[int, ...] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0.0,
        in_channels: int = 3,
        num_classes: int = 1000,
    ) -> None:
        super().__init__()
        layers = OrderedDict()
        # Stem. The OrderedDict keys become parameter-name components
        # ("features.conv0", ...), so they must match pretrained checkpoints.
        num_features = num_init_features
        layers["conv0"] = nn.Conv2d(in_channels, num_features, kernel_size=7, stride=2, pad_mode="pad", padding=3)
        layers["norm0"] = nn.BatchNorm2d(num_features)
        layers["relu0"] = nn.ReLU()
        # Pad the two trailing (spatial) dims by 1 with a constant before the
        # max-pool. The input is post-ReLU (non-negative), so zero padding
        # cannot beat real activations in the max.
        layers["pool0"] = nn.SequentialCell([
            nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT"),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ])
        # Dense blocks; each one adds ``num_layers * growth_rate`` channels.
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
            )
            layers[f"denseblock{i + 1}"] = block
            num_features += num_layers * growth_rate
            # No transition after the last block.
            if i != len(block_config) - 1:
                # Compression factor 0.5 (the "C" in DenseNet-BC).
                transition = _Transition(num_features, num_features // 2)
                layers[f"transition{i + 1}"] = transition
                num_features = num_features // 2
        # final bn+ReLU
        layers["norm5"] = nn.BatchNorm2d(num_features)
        layers["relu5"] = nn.ReLU()
        # Channel count fed to the classifier (exposed for feature extraction).
        self.num_features = num_features
        self.features = nn.SequentialCell(layers)
        self.pool = GlobalAvgPooling()
        self.classifier = nn.Dense(self.num_features, num_classes)
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize weights for cells.

        - Conv2d: He-normal weights (fan_out, relu); He-uniform bias if present.
        - BatchNorm2d: gamma = 1, beta = 0.
        - Dense: He-uniform weights (fan_in, leaky_relu); zero bias if present.
        """
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                # NOTE(review): the sqrt(5) slope presumably has no effect with
                # nonlinearity="relu"; kept so initialization is unchanged.
                cell.weight.set_data(
                    init.initializer(init.HeNormal(math.sqrt(5), mode="fan_out", nonlinearity="relu"),
                                     cell.weight.shape, cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(
                        init.initializer(init.HeUniform(math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"),
                                         cell.bias.shape, cell.bias.dtype))
            elif isinstance(cell, nn.BatchNorm2d):
                cell.gamma.set_data(init.initializer("ones", cell.gamma.shape, cell.gamma.dtype))
                cell.beta.set_data(init.initializer("zeros", cell.beta.shape, cell.beta.dtype))
            elif isinstance(cell, nn.Dense):
                cell.weight.set_data(
                    init.initializer(init.HeUniform(math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"),
                                     cell.weight.shape, cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype))

    def forward_features(self, x: Tensor) -> Tensor:
        """Run the stem, dense blocks, transitions and final BN+ReLU."""
        x = self.features(x)
        return x

    def forward_head(self, x: Tensor) -> Tensor:
        """Apply global average pooling followed by the linear classifier."""
        x = self.pool(x)
        x = self.classifier(x)
        return x

    def construct(self, x: Tensor) -> Tensor:
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
@register_model
def densenet121(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> DenseNet:
    """Get the 121-layer DenseNet model (growth rate 32, blocks (6, 12, 24, 16)).

    Refer to the base class `models.DenseNet` for more details.

    Args:
        pretrained: if True, load weights via ``load_pretrained`` using
            ``default_cfgs["densenet121"]``. Default: False.
        num_classes: number of classification classes. Default: 1000.
        in_channels: number of input channels. Default: 3.
        **kwargs: forwarded to `DenseNet`; must not repeat ``growth_rate``,
            ``block_config`` or ``num_init_features`` (these are fixed here).

    Returns:
        The constructed `DenseNet`.
    """
    default_cfg = default_cfgs["densenet121"]
    model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, in_channels=in_channels,
                     num_classes=num_classes, **kwargs)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def densenet161(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> DenseNet:
    """Get the 161-layer DenseNet model (growth rate 48, blocks (6, 12, 36, 24)).

    Refer to the base class `models.DenseNet` for more details.

    Args:
        pretrained: if True, load weights via ``load_pretrained`` using
            ``default_cfgs["densenet161"]``. Default: False.
        num_classes: number of classification classes. Default: 1000.
        in_channels: number of input channels. Default: 3.
        **kwargs: forwarded to `DenseNet`; must not repeat ``growth_rate``,
            ``block_config`` or ``num_init_features`` (these are fixed here).

    Returns:
        The constructed `DenseNet`.
    """
    default_cfg = default_cfgs["densenet161"]
    model = DenseNet(growth_rate=48, block_config=(6, 12, 36, 24), num_init_features=96, in_channels=in_channels,
                     num_classes=num_classes, **kwargs)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def densenet169(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> DenseNet:
    """Get the 169-layer DenseNet model (growth rate 32, blocks (6, 12, 32, 32)).

    Refer to the base class `models.DenseNet` for more details.

    Args:
        pretrained: if True, load weights via ``load_pretrained`` using
            ``default_cfgs["densenet169"]``. Default: False.
        num_classes: number of classification classes. Default: 1000.
        in_channels: number of input channels. Default: 3.
        **kwargs: forwarded to `DenseNet`; must not repeat ``growth_rate``,
            ``block_config`` or ``num_init_features`` (these are fixed here).

    Returns:
        The constructed `DenseNet`.
    """
    default_cfg = default_cfgs["densenet169"]
    model = DenseNet(growth_rate=32, block_config=(6, 12, 32, 32), num_init_features=64, in_channels=in_channels,
                     num_classes=num_classes, **kwargs)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def densenet201(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> DenseNet:
    """Get the 201-layer DenseNet model (growth rate 32, blocks (6, 12, 48, 32)).

    Refer to the base class `models.DenseNet` for more details.

    Args:
        pretrained: if True, load weights via ``load_pretrained`` using
            ``default_cfgs["densenet201"]``. Default: False.
        num_classes: number of classification classes. Default: 1000.
        in_channels: number of input channels. Default: 3.
        **kwargs: forwarded to `DenseNet`; must not repeat ``growth_rate``,
            ``block_config`` or ``num_init_features`` (these are fixed here).

    Returns:
        The constructed `DenseNet`.
    """
    default_cfg = default_cfgs["densenet201"]
    model = DenseNet(growth_rate=32, block_config=(6, 12, 48, 32), num_init_features=64, in_channels=in_channels,
                     num_classes=num_classes, **kwargs)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model