-
Notifications
You must be signed in to change notification settings - Fork 4
/
models.py
342 lines (295 loc) · 14.1 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
from __future__ import division
from itertools import chain
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from yoeo.utils.parse_config import parse_model_config
from yoeo.utils.utils import weights_init_normal, to_cpu, seg_iou
def create_modules(module_defs):
    """
    Constructs module list of layer blocks from module configuration in module_defs

    The first entry of ``module_defs`` is popped (the list is mutated in place)
    and used as the hyperparameter block; its numeric fields are converted
    from their config representation to ``int``/``float``. Every remaining
    entry is turned into one ``nn.Sequential`` holding the layers for that
    block.

    :param module_defs: list of dicts; the first one holds the hyperparameters,
        the others each describe one layer block via their ``"type"`` key
    :return: tuple ``(hyperparams, module_list)`` where ``module_list`` is an
        ``nn.ModuleList`` with one ``nn.Sequential`` per remaining entry
    """
    hyperparams = module_defs.pop(0)
    hyperparams.update({
        'batch': int(hyperparams['batch']),
        'subdivisions': int(hyperparams['subdivisions']),
        'width': int(hyperparams['width']),
        'height': int(hyperparams['height']),
        'channels': int(hyperparams['channels']),
        'optimizer': hyperparams.get('optimizer'),
        'momentum': float(hyperparams['momentum']),
        'decay': float(hyperparams['decay']),
        'learning_rate': float(hyperparams['learning_rate']),
        'burn_in': int(hyperparams['burn_in']),
        'max_batches': int(hyperparams['max_batches']),
        'policy': hyperparams['policy'],
        # Pairs of (batch number, learning-rate scale) built from the two
        # comma-separated config strings.
        'lr_steps': list(zip(map(int, hyperparams["steps"].split(",")),
                             map(float, hyperparams["scales"].split(","))))
    })
    assert hyperparams["height"] == hyperparams["width"], \
        "Height and width should be equal! Non square images are padded with zeros."

    # output_filters[0] is the input channel count; entry k + 1 is the channel
    # count produced by block k. Relative layer indices in route/shortcut
    # definitions therefore index into output_filters[1:].
    output_filters = [hyperparams["channels"]]
    module_list = nn.ModuleList()
    for module_i, module_def in enumerate(module_defs):
        modules = nn.Sequential()
        layer_type = module_def["type"]
        # Blocks that do not change the channel count (maxpool, upsample,
        # yolo, seg) inherit it from the previous block. Setting the default
        # here also covers a config whose very first block is one of those,
        # which previously failed because ``filters`` was never assigned.
        filters = output_filters[-1]

        if layer_type == "convolutional":
            bn = int(module_def["batch_normalize"])
            filters = int(module_def["filters"])
            kernel_size = int(module_def["size"])
            pad = (kernel_size - 1) // 2
            modules.add_module(
                f"conv_{module_i}",
                nn.Conv2d(
                    in_channels=output_filters[-1],
                    out_channels=filters,
                    kernel_size=kernel_size,
                    stride=int(module_def["stride"]),
                    padding=pad,
                    # The conv bias is dropped when batch norm follows
                    bias=not bn,
                ),
            )
            if bn:
                modules.add_module(f"batch_norm_{module_i}",
                                   nn.BatchNorm2d(filters, momentum=0.1, eps=1e-5))
            if module_def["activation"] == "leaky":
                modules.add_module(f"leaky_{module_i}", nn.LeakyReLU(0.1))
            elif module_def["activation"] == "mish":
                modules.add_module(f"mish_{module_i}", Mish())

        elif layer_type == "maxpool":
            kernel_size = int(module_def["size"])
            stride = int(module_def["stride"])
            if kernel_size == 2 and stride == 1:
                # Pad right/bottom by one so the 2x2, stride-1 pool keeps the
                # spatial size.
                modules.add_module(f"_debug_padding_{module_i}", nn.ZeroPad2d((0, 1, 0, 1)))
            maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride,
                                   padding=int((kernel_size - 1) // 2))
            modules.add_module(f"maxpool_{module_i}", maxpool)

        elif layer_type == "upsample":
            upsample = Upsample(scale_factor=int(module_def["stride"]), mode="nearest")
            modules.add_module(f"upsample_{module_i}", upsample)

        elif layer_type == "route":
            layers = [int(x) for x in module_def["layers"].split(",")]
            block_filters = output_filters[1:]
            filters = sum(block_filters[i] for i in layers) // int(module_def.get("groups", 1))
            # Placeholder only; the routing itself is done in Darknet.forward
            modules.add_module(f"route_{module_i}", nn.Sequential())

        elif layer_type == "shortcut":
            filters = output_filters[1:][int(module_def["from"])]
            # Placeholder only; the addition itself is done in Darknet.forward
            modules.add_module(f"shortcut_{module_i}", nn.Sequential())

        elif layer_type == "yolo":
            anchor_idxs = [int(x) for x in module_def["mask"].split(",")]
            # Extract anchors: flat "w0,h0,w1,h1,..." -> [(w0, h0), ...],
            # then keep only the ones selected by the mask
            anchors = [int(x) for x in module_def["anchors"].split(",")]
            anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
            anchors = [anchors[i] for i in anchor_idxs]
            num_classes = int(module_def["classes"])
            # Define detection layer
            yolo_layer = YOLOLayer(anchors, num_classes)
            modules.add_module(f"yolo_{module_i}", yolo_layer)

        elif layer_type == "seg":
            num_classes = int(module_def["classes"])
            modules.add_module(f"seg_{module_i}", SegLayer(num_classes))

        # Register module list and number of output filters
        module_list.append(modules)
        output_filters.append(filters)

    return hyperparams, module_list
class Upsample(nn.Module):
    """Module wrapper that resizes its input with ``F.interpolate``.

    :param scale_factor: multiplier passed to ``F.interpolate``
    :param mode: interpolation mode passed to ``F.interpolate``
    """

    def __init__(self, scale_factor, mode="nearest"):
        super().__init__()
        self.mode = mode
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return ``x`` resized by the configured scale factor and mode."""
        return F.interpolate(x, mode=self.mode, scale_factor=self.scale_factor)
class Mish(nn.Module):
    """ The MISH activation function (https://github.com/digantamisra98/Mish) """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x * tanh(softplus(x))`` computed element-wise."""
        gate = torch.tanh(F.softplus(x))
        return x * gate
class YOLOLayer(nn.Module):
    """Detection layer

    Reshapes a raw feature map into per-anchor predictions and, in inference
    mode, decodes them into box centres, box sizes and sigmoid-activated
    remaining channels.

    :param anchors: sequence of (width, height) anchor pairs for this layer
    :param num_classes: number of object classes predicted per anchor
    """

    def __init__(self, anchors, num_classes):
        super(YOLOLayer, self).__init__()
        self.num_anchors = len(anchors)
        self.num_classes = num_classes
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCELoss()
        self.no = num_classes + 5  # number of outputs per anchor
        # Placeholder; the real cell-offset grid is built lazily in forward()
        # once the feature-map size is known.
        self.grid = torch.zeros(1)

        anchors = torch.tensor(list(chain(*anchors))).float().view(-1, 2)
        self.register_buffer('anchors', anchors)
        # Same values, shaped to broadcast against (bs, anchors, ny, nx, 2)
        self.register_buffer(
            'anchor_grid', anchors.clone().view(1, -1, 1, 1, 2))
        self.stride = None

    def forward(self, x, img_size):
        """Run the detection head on a feature map.

        :param x: tensor of shape (bs, num_anchors * no, ny, nx)
        :param img_size: side length of the network input image, used to
            derive the stride of this feature map
        :return: in training mode the reshaped tensor
            (bs, num_anchors, ny, nx, no); in inference mode the decoded
            predictions flattened to (bs, num_anchors * ny * nx, no)
        """
        stride = img_size // x.size(2)
        self.stride = stride
        bs, _, ny, nx = x.shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
        x = x.view(bs, self.num_anchors, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

        if not self.training:  # inference
            # The grid is a plain attribute (not a buffer), so it does not
            # follow the module across devices. Rebuild it when either the
            # feature-map size or the input device changes; checking the shape
            # alone left a stale grid on the old device.
            if self.grid.shape[2:4] != x.shape[2:4] or self.grid.device != x.device:
                self.grid = self._make_grid(nx, ny).to(x.device)

            x = torch.cat([
                (x[..., 0:2].sigmoid() + self.grid) * stride,  # xy
                torch.exp(x[..., 2:4]) * self.anchor_grid,  # wh
                x[..., 4:].sigmoid(),
            ], dim=4).view(bs, -1, self.no)

        return x

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return a (1, 1, ny, nx, 2) float tensor holding the (x, y) index of each cell."""
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing='ij')
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
class SegLayer(nn.Module):
    """Segmentation layer

    Passes its input through unchanged while training and reduces it to the
    index of the largest value along dimension 1 otherwise.

    :param num_classes: number of segmentation classes (stored only)
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, x):
        """Return ``x`` in training mode, else ``argmax`` of ``x`` over dim 1."""
        if not self.training:
            return torch.argmax(x, dim=1)
        return x
class Darknet(nn.Module):
    """Darknet-style network assembled from a model configuration file.

    The network is a sequence of blocks built by ``create_modules``; it
    collects the outputs of all ``yolo`` (detection) and ``seg``
    (segmentation) blocks and can read/write weights in the Darknet binary
    format.

    :param config_path: path to the model definition file
    """

    def __init__(self, config_path):
        super(Darknet, self).__init__()
        self.module_defs = parse_model_config(config_path)
        # create_modules pops the hyperparameter entry from module_defs, so
        # afterwards module_defs and module_list line up one-to-one.
        self.hyperparams, self.module_list = create_modules(self.module_defs)
        self.yolo_layers = [layer[0] for layer in self.module_list if isinstance(layer[0], YOLOLayer)]
        self.seg_layers = [layer[0] for layer in self.module_list if isinstance(layer[0], SegLayer)]
        # NOTE: requires at least one seg block in the config
        self.num_seg_classes = self.seg_layers[0].num_classes
        self.seen = 0
        self.header_info = np.array([0, 0, 0, self.seen, 0], dtype=np.int32)

    def forward(self, x, bb_targets=None, mask_targets=None):
        """Run the network.

        :param x: input batch; its size along dimension 2 is taken as the image size
        :param bb_targets: unused, kept for interface compatibility
        :param mask_targets: unused, kept for interface compatibility
        :return: in training mode a tuple of two lists
            ``(yolo_outputs, segmentation_outputs)``; otherwise a tuple of two
            tensors, each the concatenation of the respective list along dim 1
        """
        img_size = x.size(2)
        # layer_outputs keeps every block's output so route/shortcut blocks
        # can refer back to earlier ones by index.
        layer_outputs, yolo_outputs, segmentation_outputs = [], [], []
        for module_def, module in zip(self.module_defs, self.module_list):
            layer_type = module_def["type"]
            if layer_type in ("convolutional", "upsample", "maxpool"):
                x = module(x)
            elif layer_type == "route":
                combined_outputs = torch.cat(
                    [layer_outputs[int(layer_i)] for layer_i in module_def["layers"].split(",")], 1)
                group_size = combined_outputs.shape[1] // int(module_def.get("groups", 1))
                group_id = int(module_def.get("group_id", 0))
                # Slice groupings used by yolo v4
                x = combined_outputs[:, group_size * group_id: group_size * (group_id + 1)]
            elif layer_type == "shortcut":
                layer_i = int(module_def["from"])
                x = layer_outputs[-1] + layer_outputs[layer_i]
            elif layer_type == "yolo":
                x = module[0](x, img_size)
                yolo_outputs.append(x)
            elif layer_type == "seg":
                x = module[0](x)
                segmentation_outputs.append(x)
            layer_outputs.append(x)

        if self.training:
            return yolo_outputs, segmentation_outputs
        return torch.cat(yolo_outputs, 1), torch.cat(segmentation_outputs, 1)

    @staticmethod
    def _copy_from_flat(weights, ptr, target):
        """Copy the next ``target.numel()`` values of ``weights`` into ``target``.

        :param weights: flat float32 numpy array read from the weights file
        :param ptr: offset of the first value to use
        :param target: parameter or buffer to overwrite in place
        :return: offset just past the consumed values
        """
        count = target.numel()
        values = torch.from_numpy(weights[ptr: ptr + count]).view_as(target)
        target.data.copy_(values)
        return ptr + count

    def load_darknet_weights(self, weights_path):
        """Parses and loads the weights stored in 'weights_path'

        :param weights_path: path to a Darknet-format binary weights file
        """
        # Open the weights file
        with open(weights_path, "rb") as f:
            # First five are header values
            header = np.fromfile(f, dtype=np.int32, count=5)
            self.header_info = header  # Needed to write header when saving weights
            self.seen = header[3]  # number of images seen during training
            weights = np.fromfile(f, dtype=np.float32)  # The rest are weights

        # Establish cutoff for loading backbone weights
        cutoff = None
        # If the weights file has a cutoff, we can find out about it by looking at the filename
        # examples: darknet53.conv.74 -> cutoff is 74
        filename = os.path.basename(weights_path)
        if ".conv." in filename:
            try:
                cutoff = int(filename.split(".")[-1])  # use last part of filename
            except ValueError:
                # Suffix is not a number: load every layer
                pass

        ptr = 0
        for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)):
            if i == cutoff:
                break
            if module_def["type"] != "convolutional":
                continue
            conv_layer = module[0]
            # int() matches how create_modules interprets this field, so a
            # string "0" is not mistaken for "batch norm enabled".
            if int(module_def["batch_normalize"]):
                # Load BN bias, weights, running mean and running variance
                bn_layer = module[1]
                ptr = self._copy_from_flat(weights, ptr, bn_layer.bias)
                ptr = self._copy_from_flat(weights, ptr, bn_layer.weight)
                ptr = self._copy_from_flat(weights, ptr, bn_layer.running_mean)
                ptr = self._copy_from_flat(weights, ptr, bn_layer.running_var)
            else:
                # Load conv. bias
                ptr = self._copy_from_flat(weights, ptr, conv_layer.bias)
            # Load conv. weights
            ptr = self._copy_from_flat(weights, ptr, conv_layer.weight)

    def save_darknet_weights(self, path, cutoff=-1):
        """
        @:param path    - path of the new weights file
        @:param cutoff  - save layers between 0 and cutoff (cutoff = -1 -> all are saved)
        """
        # A plain [:-1] slice would drop the final block, so map the
        # documented "-1 means everything" sentinel to an open-ended slice.
        end = None if cutoff == -1 else cutoff
        self.header_info[3] = self.seen
        # Context manager guarantees the file is closed even if a write fails
        with open(path, "wb") as fp:
            self.header_info.tofile(fp)

            # Iterate through layers
            for module_def, module in zip(self.module_defs[:end], self.module_list[:end]):
                if module_def["type"] != "convolutional":
                    continue
                conv_layer = module[0]
                # If batch norm, write bn first (same order the loader reads)
                if int(module_def["batch_normalize"]):
                    bn_layer = module[1]
                    bn_layer.bias.data.cpu().numpy().tofile(fp)
                    bn_layer.weight.data.cpu().numpy().tofile(fp)
                    bn_layer.running_mean.data.cpu().numpy().tofile(fp)
                    bn_layer.running_var.data.cpu().numpy().tofile(fp)
                # Otherwise write conv bias
                else:
                    conv_layer.bias.data.cpu().numpy().tofile(fp)
                # Write conv weights
                conv_layer.weight.data.cpu().numpy().tofile(fp)
def load_model(model_path, weights_path=None):
    """Build a Darknet model and optionally initialise it from a weights file.

    The model is placed on CUDA when available, otherwise on the CPU, and
    ``weights_init_normal`` is applied before any weights are loaded.

    :param model_path: Path to model definition file (.cfg)
    :type model_path: str
    :param weights_path: Path to weights or checkpoint file (.weights or .pth)
    :type weights_path: str
    :return: Returns model
    :rtype: Darknet
    """
    # Select device for inference
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    model = Darknet(model_path).to(device)
    model.apply(weights_init_normal)

    # Nothing to restore: hand back the freshly initialised model
    if not weights_path:
        return model

    if weights_path.endswith(".pth"):
        # Checkpoint: a serialized state dict.
        # NOTE(review): torch.load deserializes the file; only load
        # checkpoints from trusted sources.
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)
    else:
        # Anything else is treated as a Darknet binary weights file
        model.load_darknet_weights(weights_path)
    return model