Skip to content

Commit

Permalink
Significantly faster training
Browse files Browse the repository at this point in the history
  • Loading branch information
eriklindernoren committed Apr 26, 2019
1 parent 689f38a commit d1041e0
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 119 deletions.
80 changes: 37 additions & 43 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,30 @@ def __init__(self, anchors, num_classes, img_dim=416):
self.noobj_scale = 100
self.metrics = {}
self.img_dim = img_dim
self.nG = 0 # grid size

def compute_grid_offsets(self, grid_size, cuda=True):
    """Cache the grid offsets and stride-scaled anchors for one grid size.

    Stores on ``self``: ``nG`` (the grid size), ``stride``
    (``img_dim / nG``), ``grid_x`` / ``grid_y`` (cell column / row indices,
    each viewed as ``(1, 1, nG, nG)``), ``scaled_anchors`` (anchors divided
    by the stride) and ``anchor_w`` / ``anchor_h`` (the two anchor columns,
    each viewed as ``(1, num_anchors, 1, 1)``).

    Args:
        grid_size: Number of cells along one side of the grid.
        cuda: When true the cached tensors are CUDA float tensors,
            otherwise CPU float tensors.
    """
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    self.nG = grid_size
    self.stride = self.img_dim / self.nG
    grid_view = [1, 1, self.nG, self.nG]

    # Every row of `cell_index` is 0..nG-1, so it holds each cell's column
    # index; its transpose holds each cell's row index.
    cell_index = torch.arange(self.nG).repeat(self.nG, 1)
    self.grid_x = cell_index.view(grid_view).type(FloatTensor)
    self.grid_y = cell_index.t().view(grid_view).type(FloatTensor)

    # Anchors expressed in grid cells instead of input-image pixels
    anchors_in_cells = [
        (width / self.stride, height / self.stride)
        for width, height in self.anchors
    ]
    self.scaled_anchors = FloatTensor(anchors_in_cells)

    anchor_view = (1, self.num_anchors, 1, 1)
    self.anchor_w = self.scaled_anchors[:, 0:1].view(anchor_view)
    self.anchor_h = self.scaled_anchors[:, 1:2].view(anchor_view)

def forward(self, x, targets=None):
nA = self.num_anchors
nB = x.size(0)
nG = x.size(2)
stride = self.img_dim / nG

# Tensors for cuda support
FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if x.is_cuda else torch.ByteTensor

prediction = x.view(nB, nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous()
nB = x.size(0)
nG = x.size(2)

prediction = x.view(nB, self.num_anchors, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous()

# Get outputs
x = torch.sigmoid(prediction[..., 0]) # Center x
Expand All @@ -148,58 +159,43 @@ def forward(self, x, targets=None):
pred_conf = torch.sigmoid(prediction[..., 4]) # Conf
pred_cls = torch.sigmoid(prediction[..., 5:]) # Cls pred.

# Calculate offsets for each grid
grid_x = torch.arange(nG).repeat(nG, 1).view([1, 1, nG, nG]).type(FloatTensor)
grid_y = torch.arange(nG).repeat(nG, 1).t().view([1, 1, nG, nG]).type(FloatTensor)
scaled_anchors = FloatTensor([(a_w / stride, a_h / stride) for a_w, a_h in self.anchors])
anchor_w = scaled_anchors[:, 0:1].view((1, nA, 1, 1))
anchor_h = scaled_anchors[:, 1:2].view((1, nA, 1, 1))
# If grid size does not match current we compute new offsets
if nG != self.nG:
self.compute_grid_offsets(nG, cuda=x.is_cuda)

# Add offset and scale with anchors
pred_boxes = FloatTensor(prediction[..., :4].shape)
pred_boxes[..., 0] = x.data + grid_x
pred_boxes[..., 1] = y.data + grid_y
pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
pred_boxes[..., 3] = torch.exp(h.data) * anchor_h
pred_boxes[..., 0] = x.data + self.grid_x
pred_boxes[..., 1] = y.data + self.grid_y
pred_boxes[..., 2] = torch.exp(w.data) * self.anchor_w
pred_boxes[..., 3] = torch.exp(h.data) * self.anchor_h

output = torch.cat(
(pred_boxes.view(nB, -1, 4) * stride, pred_conf.view(nB, -1, 1), pred_cls.view(nB, -1, self.num_classes)),
(
pred_boxes.view(nB, -1, 4) * self.stride,
pred_conf.view(nB, -1, 1),
pred_cls.view(nB, -1, self.num_classes),
),
-1,
)

if targets is None:
# Inference
return output
else:
# Training
if x.is_cuda:
self.mse_loss = self.mse_loss.cuda()
self.bce_loss = self.bce_loss.cuda()

iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tconf, tcls = build_targets(
pred_boxes=to_cpu(pred_boxes),
pred_cls=to_cpu(pred_cls),
target=to_cpu(targets),
anchors=to_cpu(scaled_anchors),
num_anchors=nA,
num_classes=self.num_classes,
grid_size=nG,
iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls = build_targets(
pred_boxes=pred_boxes,
pred_cls=pred_cls,
target=targets,
anchors=self.scaled_anchors,
ignore_thres=self.ignore_thres,
)

# Target variables
tx = tx.type(FloatTensor)
ty = ty.type(FloatTensor)
tw = tw.type(FloatTensor)
th = th.type(FloatTensor)
tconf = tconf.type(FloatTensor)
tcls = tcls.type(FloatTensor)

# Loss : Mask outputs to ignore non-existing objects (except with conf. loss)
loss_x = self.mse_loss(x[obj_mask], tx[obj_mask])
loss_y = self.mse_loss(y[obj_mask], ty[obj_mask])
loss_w = self.mse_loss(w[obj_mask], tw[obj_mask])
loss_h = self.mse_loss(h[obj_mask], th[obj_mask])
tconf = obj_mask.float()
loss_conf_obj = self.bce_loss(pred_conf[obj_mask], tconf[obj_mask])
loss_conf_noobj = self.bce_loss(pred_conf[noobj_mask], tconf[noobj_mask])
loss_conf = self.obj_scale * loss_conf_obj + self.noobj_scale * loss_conf_noobj
Expand All @@ -210,12 +206,10 @@ def forward(self, x, targets=None):
cls_acc = 100 * class_mask[obj_mask].mean()
conf_obj = pred_conf[obj_mask].mean()
conf_noobj = pred_conf[noobj_mask].mean()
class_mask = class_mask.type(FloatTensor)
conf50 = (pred_conf > 0.5).float()
iou50 = (iou_scores > 0.5).type(FloatTensor)
iou75 = (iou_scores > 0.75).type(FloatTensor)
obj_mask = obj_mask.type(FloatTensor)
detected_mask = conf50 * class_mask * obj_mask
iou50 = (iou_scores > 0.5).float()
iou75 = (iou_scores > 0.75).float()
detected_mask = conf50 * class_mask * tconf
precision = torch.sum(iou50 * detected_mask) / (conf50.sum() + 1e-16)
recall50 = torch.sum(iou50 * detected_mask) / (obj_mask.sum() + 1e-16)
recall75 = torch.sum(iou75 * detected_mask) / (obj_mask.sum() + 1e-16)
Expand Down
7 changes: 6 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@
# Get dataloader
dataset = ListDataset(train_path)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu, pin_memory=True
dataset,
batch_size=opt.batch_size,
shuffle=True,
num_workers=opt.n_cpu,
pin_memory=True,
collate_fn=dataset.collate_fn,
)

optimizer = torch.optim.Adam(model.parameters())
Expand Down
20 changes: 13 additions & 7 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,19 @@ def __getitem__(self, index):
if np.random.random() < 0.5:
img, labels = horisontal_flip(img, labels)

# Fill matrix
filled_labels = torch.zeros((self.max_objects, 5))
if labels is not None:
labels = labels[: self.max_objects]
filled_labels[: len(labels)] = labels

return img_path, img, filled_labels
boxes = torch.zeros((len(labels), 6))
boxes[:, 1:] = labels

return img_path, img, boxes

@staticmethod
def collate_fn(batch):
    """Merge a list of ``(path, image, boxes)`` samples into one batch.

    Column 0 of each sample's ``boxes`` tensor is overwritten in place with
    the sample's position in the batch, so rows remain attributable to
    their image once all boxes are concatenated.

    Args:
        batch: Sequence of ``(path, image, boxes)`` triples.

    Returns:
        A ``(paths, imgs, labels)`` triple: the tuple of paths, the images
        stacked along a new leading dimension, and all boxes concatenated
        along dimension 0.
    """
    paths, imgs, labels = tuple(zip(*batch))
    for sample_index in range(len(labels)):
        # In-place write: tags every row with the owning sample's index
        labels[sample_index][:, 0] = sample_index
    stacked_imgs = torch.stack(imgs, 0)
    merged_labels = torch.cat(labels, 0)
    return paths, stacked_imgs, merged_labels

def __len__(self):
return len(self.img_files)
131 changes: 63 additions & 68 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,15 @@ def get_batch_statistics(outputs, targets, iou_threshold):
return batch_metrics


def bbox_wh_iou(wh1, wh2):
    """IoU computed from widths and heights only.

    The intersection is ``min(w1, w2) * min(h1, h2)``, i.e. box positions
    play no part in the result.

    Args:
        wh1: Indexable pair holding one width and one height.
        wh2: 2-D tensor whose rows are ``(width, height)`` pairs; it is
            transposed so all widths and all heights are handled together.

    Returns:
        The IoU of ``wh1`` against every row of ``wh2``.
    """
    ref_w, ref_h = wh1[0], wh1[1]
    transposed = wh2.t()
    other_w, other_h = transposed[0], transposed[1]
    overlap = torch.min(ref_w, other_w) * torch.min(ref_h, other_h)
    # The epsilon keeps the denominator non-zero for degenerate boxes
    combined = (ref_w * ref_h + 1e-16) + other_w * other_h - overlap
    return overlap / combined


def bbox_iou(box1, box2, x1y1x2y2=True):
"""
Returns the IoU of two bounding boxes
Expand Down Expand Up @@ -269,71 +278,57 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
return output


def build_targets(pred_boxes, pred_cls, target, anchors, num_anchors, num_classes, grid_size, ignore_thres):
    """Build per-cell training targets and masks, one ground-truth box at a time.

    Args:
        pred_boxes: Predicted boxes, indexed as ``[b, anchor, gj, gi]``.
        pred_cls: Class predictions, indexed as ``[b, anchor, gj, gi]`` with
            the class scores on the last dimension.
        target: Ground truth indexed as ``[b, t]``; element 0 of each row is
            read as the class label and elements 1-4 as x, y, w, h, which
            are multiplied by the grid size. Rows summing to zero are
            treated as padding and skipped.
        anchors: Anchor ``(w, h)`` pairs, compared directly against the
            grid-scaled box sizes.
        num_anchors: Number of anchors per grid cell.
        num_classes: Number of classes.
        grid_size: Number of cells along one side of the grid.
        ignore_thres: Anchors whose shape IoU with a ground-truth box
            exceeds this value are removed from the no-object mask at that
            box's cell.

    Returns:
        Tuple ``(iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw,
        th, tconf, tcls)``.
    """
    nB = target.size(0)
    nA = num_anchors
    nC = num_classes
    nG = grid_size
    obj_mask = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
    noobj_mask = torch.ByteTensor(nB, nA, nG, nG).fill_(1)
    class_mask = torch.zeros(nB, nA, nG, nG).float()
    iou_scores = torch.zeros(nB, nA, nG, nG).float()
    tx = torch.zeros(nB, nA, nG, nG)
    ty = torch.zeros(nB, nA, nG, nG)
    tw = torch.zeros(nB, nA, nG, nG)
    th = torch.zeros(nB, nA, nG, nG)
    tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
    tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0)

    # Shapes of the anchor boxes (centered at (100, 100)). They do not depend
    # on the target, so they are built once instead of once per box.
    # NOTE(review): assumes bbox_iou does not modify its arguments - confirm.
    anchor_shapes = torch.ones((len(anchors), 4)).float() * 100
    anchor_shapes[:, 2:] = anchors

    for b in range(nB):
        for t in range(target.shape[1]):
            # An all-zero row is padding, not a real box
            if target[b, t].sum() == 0:
                continue
            # Convert to position relative to box
            gx = target[b, t, 1] * nG
            gy = target[b, t, 2] * nG
            gw = target[b, t, 3] * nG
            gh = target[b, t, 4] * nG
            # Get grid box indices
            gi = int(gx)
            gj = int(gy)
            # Get the shape of the gt box (centered at (100, 100))
            gt_shape = torch.FloatTensor([100, 100, gw, gh]).unsqueeze(0)
            # Compute iou between gt and anchor shapes
            anch_ious = bbox_iou(gt_shape, anchor_shapes, x1y1x2y2=False)
            # Find the best matching anchor box
            best_n = torch.argmax(anch_ious)
            # Get ground truth box
            gt_box = torch.FloatTensor([gx, gy, gw, gh]).unsqueeze(0)
            # Get class of target box
            target_label = int(target[b, t, 0])
            # Get the ious of prediction at each anchor
            pred_ious = bbox_iou(gt_box, pred_boxes[b, :, gj, gi], x1y1x2y2=False)
            # Get label correctness
            label_correctness = (torch.argmax(pred_cls[b, :, gj, gi], -1) == target_label).float()
            # Masks
            obj_mask[b, best_n, gj, gi] = 1
            noobj_mask[b, best_n, gj, gi] = 0
            # Coordinates
            tx[b, best_n, gj, gi] = gx - gi
            ty[b, best_n, gj, gi] = gy - gj
            # Width and height
            tw[b, best_n, gj, gi] = math.log(gw / anchors[best_n][0] + 1e-16)
            th[b, best_n, gj, gi] = math.log(gh / anchors[best_n][1] + 1e-16)
            # One-hot encoding of label
            tcls[b, best_n, gj, gi, target_label] = 1
            tconf[b, best_n, gj, gi] = 1
            # Compute label correctness and iou at best anchor
            class_mask[b, best_n, gj, gi] = label_correctness[best_n]
            iou_scores[b, best_n, gj, gi] = pred_ious[best_n]
            # Where the overlap is larger than threshold set mask to zero (ignore)
            noobj_mask[b, anch_ious > ignore_thres, gj, gi] = 0

    return iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tconf, tcls
def build_targets(pred_boxes, pred_cls, target, anchors, ignore_thres):
    """Build per-cell training targets and masks for all boxes at once.

    Args:
        pred_boxes: Predicted boxes; dimensions 0, 1 and 2 are read as the
            batch size, anchor count and grid size. Its device decides
            where the output tensors are allocated.
        pred_cls: Class predictions; the last dimension is the class count.
        target: 2-D tensor with one row per ground-truth box. Column 0 is
            used as the batch index, column 1 as the class label, and
            columns 2-5 as x, y, w, h, which are multiplied by the grid
            size.
        anchors: Tensor of anchor ``(w, h)`` pairs, compared directly
            against the grid-scaled box sizes.
        ignore_thres: Anchors whose width/height IoU with a ground-truth
            box exceeds this value are removed from the no-object mask at
            that box's cell.

    Returns:
        Tuple ``(iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw,
        th, tcls)``.
    """
    on_gpu = pred_boxes.is_cuda
    ByteTensor = torch.cuda.ByteTensor if on_gpu else torch.ByteTensor
    FloatTensor = torch.cuda.FloatTensor if on_gpu else torch.FloatTensor

    nB = pred_boxes.size(0)
    nA = pred_boxes.size(1)
    nC = pred_cls.size(-1)
    nG = pred_boxes.size(2)
    cell_dims = (nB, nA, nG, nG)

    # Output tensors
    obj_mask = ByteTensor(*cell_dims).fill_(0)
    noobj_mask = ByteTensor(*cell_dims).fill_(1)
    class_mask = FloatTensor(*cell_dims).fill_(0)
    iou_scores = FloatTensor(*cell_dims).fill_(0)
    tx = FloatTensor(*cell_dims).fill_(0)
    ty = FloatTensor(*cell_dims).fill_(0)
    tw = FloatTensor(*cell_dims).fill_(0)
    th = FloatTensor(*cell_dims).fill_(0)
    tcls = FloatTensor(nB, nA, nG, nG, nC).fill_(0)

    # Ground-truth boxes expressed in grid cells
    target_boxes = target[:, 2:6] * nG
    centers = target_boxes[:, :2]
    sizes = target_boxes[:, 2:]

    # One row of IoUs per anchor, one column per ground-truth box
    ious = torch.stack([bbox_wh_iou(anchor, sizes) for anchor in anchors])
    _, best_n = ious.max(0)

    # Split the target columns into per-box vectors
    sample_idx, target_labels = target[:, :2].long().t()
    gx, gy = centers.t()
    gw, gh = sizes.t()
    col_idx, row_idx = centers.long().t()

    # The best-matching anchor at each box's cell is responsible for it
    obj_mask[sample_idx, best_n, row_idx, col_idx] = 1
    noobj_mask[sample_idx, best_n, row_idx, col_idx] = 0

    # Anchors overlapping a box above the threshold are not penalised as
    # background at that box's cell
    for box_pos in range(ious.size(1)):
        over_thres = ious[:, box_pos] > ignore_thres
        noobj_mask[sample_idx[box_pos], over_thres, row_idx[box_pos], col_idx[box_pos]] = 0

    # Center offsets within the cell
    tx[sample_idx, best_n, row_idx, col_idx] = gx - gx.floor()
    ty[sample_idx, best_n, row_idx, col_idx] = gy - gy.floor()

    # Log-ratio of box size to the responsible anchor's size
    matched_anchors = anchors[best_n]
    tw[sample_idx, best_n, row_idx, col_idx] = torch.log(gw / matched_anchors[:, 0] + 1e-16)
    th[sample_idx, best_n, row_idx, col_idx] = torch.log(gh / matched_anchors[:, 1] + 1e-16)

    # One-hot encoding of label
    tcls[sample_idx, best_n, row_idx, col_idx, target_labels] = 1

    # Label correctness and IoU of the prediction made by the responsible anchor
    predicted_labels = pred_cls[sample_idx, best_n, row_idx, col_idx].argmax(-1)
    class_mask[sample_idx, best_n, row_idx, col_idx] = (predicted_labels == target_labels).float()
    iou_scores[sample_idx, best_n, row_idx, col_idx] = bbox_iou(
        pred_boxes[sample_idx, best_n, row_idx, col_idx], target_boxes, x1y1x2y2=False
    )

    return iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls

0 comments on commit d1041e0

Please sign in to comment.