In [3]:
import torch
import torch.nn as nn

### Intersection over union (IoU) for yolo coordinates

In [53]:
def intersection_over_union(boxes_preds,boxes_labels,img_w,img_h):
    
    b1_x1 = (boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2)*img_w
    b1_y1 = (boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2)*img_h
    b1_x2 = (boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2)*img_w
    b1_y2 = (boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2)*img_h
    
    b2_x1 = (boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2)*img_w
    b2_y1 = (boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2)*img_h
    b2_x2 = (boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2)*img_w
    b2_y2 = (boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2)*img_h
    
    x1 = torch.max(b1_x1, b2_x1)
    y1 = torch.max(b1_y1, b2_y1)
    x2 = torch.min(b1_x2, b2_x2)
    y2 = torch.min(b1_y2, b2_y2)
    
    intersection = ((x2 - x1).clamp(0) * (y2 - y1).clamp(0))
    
    box1_area = abs((b1_x2 - b1_x1) * (b1_y2 - b1_y1))
    box2_area = abs((b2_x2 - b2_x1) * (b2_y2 - b2_y1))

    return intersection / (box1_area + box2_area - intersection)

In [54]:
class YoloLoss(nn.Module):
        def __init__(self, S=7, B=2, C=20):
            super(YoloLoss, self).__init__()
        
            self.mse = nn.MSELoss(reduction="sum")

            self.S = S
            self.B = B
            self.C = C

            self.lambda_noobj = 0.5
            self.lambda_coord = 5
        
        def forward(self, predictions, target):
            
            predictions = predictions.view(self.S * self.S,self.B*5 + self.C)
            
            iou_b1 = intersection_over_union(predictions[..., 21:25], target[..., 21:25],448,448)
            iou_b2 = intersection_over_union(predictions[..., 26:30], target[..., 21:25])
            ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)], dim=0)
            
            iou_maxes, bestbox = torch.max(ious, dim=0)
            exists_box = target[..., 20:21]
            
            box_predictions = exists_box * ((bestbox * predictions[..., 26:30]+ (1 - bestbox) * predictions[..., 21:25]))
            box_targets = exists_box * target[..., 21:25]
            
            box_predictions[..., 2:4] = torch.sign(box_predictions[..., 2:4]) * torch.sqrt(
            torch.abs(box_predictions[..., 2:4] + 1e-6))
            
            box_targets[..., 2:4] = torch.sqrt(box_targets[..., 2:4])
            
            box_loss = self.mse((box_predictions),(box_targets))
            
            pred_box = (bestbox * predictions[..., 25:26] + (1 - bestbox) * predictions[..., 20:21])
            
            object_loss = self.mse((exists_box * pred_box),(exists_box * target[..., 20:21]))
            
            no_object_loss = self.mse(((1 - exists_box) * predictions[..., 20:21]),((1 - exists_box) * target[..., 20:21]))
        
            no_object_loss += self.mse(((1 - exists_box) * predictions[..., 25:26]),((1 - exists_box) * target[..., 20:21]))
        
            class_loss = self.mse((exists_box * predictions[..., :20]),(exists_box * target[..., :20]))
            
            loss = (
            self.lambda_coord * box_loss  # first two rows in paper
            + object_loss  # third row in paper
            + self.lambda_noobj * no_object_loss  # forth row
            + class_loss  # fifth row
            )
            
            return loss     

In [60]:
yolo_pred=torch.rand(7*7,30)

In [61]:
yolo_target=torch.rand(7*7,25)

In [67]:
loss_tens=YoloLoss().forward(yolo_pred,yolo_target)

In [68]:
loss_tens

tensor(94.2044)

Loss