In [2]:
import torch
import torch.nn as nn
from IoU_implementation import intersection_over_union

In [3]:
class YoloLoss(nn.Module):
    """YOLOv1 sum-squared-error training loss.

    Each grid cell of ``predictions`` is laid out as
    ``[class scores (C) | conf_1, x_1, y_1, w_1, h_1 | ... | conf_B, x_B, y_B, w_B, h_B]``,
    i.e. ``C + 5 * B`` channels. ``target`` uses the same leading layout, but only
    its first box slot is read: channel ``C`` is the object indicator and channels
    ``C+1 : C+5`` hold the single ground-truth box (x, y, w, h).

    Args:
        S: grid size; predictions are reshaped to ``(N, S, S, C + 5 * B)``.
        B: number of boxes predicted per cell.
        C: number of classes.
    """

    def __init__(self, S=7, B=2, C=20):
        super().__init__()
        # reduction="sum": every loss term is a plain sum of squared errors.
        self.mse = nn.MSELoss(reduction="sum")
        self.S = S
        self.B = B
        self.C = C
        # Weights from the YOLOv1 paper: down-weight confidence error in empty
        # cells, up-weight box-coordinate error.
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    def forward(self, predictions, target):
        """Compute the total loss for a batch.

        Args:
            predictions: model output, reshapeable to ``(N, S, S, C + 5 * B)``.
            target: tensor of shape ``(N, S, S, K)`` with ``K >= C + 5``, in the
                layout described in the class docstring.

        Returns:
            Scalar tensor ``lambda_coord * box_loss + object_loss
            + lambda_noobj * no_object_loss + class_loss``.
        """
        S, B, C = self.S, self.B, self.C
        predictions = predictions.reshape(-1, S, S, C + B * 5)

        # (N, S, S, B, 5): per predicted box [conf, x, y, w, h].
        pred_boxes = predictions[..., C:].reshape(-1, S, S, B, 5)
        # (N, S, S, 1): object indicator (1obj_i in the YOLO paper); presumably 0/1.
        exist_box = target[..., C:C + 1]
        # (N, S, S, 4): the single ground-truth box of each cell.
        target_box = target[..., C + 1:C + 5]

        # IoU of EVERY predicted box against the SAME ground-truth box.
        # intersection_over_union is assumed to return (N, S, S, 1) per call
        # (TODO confirm against IoU_implementation), giving (B, N, S, S, 1).
        ious = torch.stack(
            [intersection_over_union(pred_boxes[..., b, 1:5], target_box) for b in range(B)],
            dim=0,
        )
        # Index of the "responsible" predictor per cell. On ties torch.max
        # returns the first maximal index, i.e. the lowest box slot.
        _, bestbox = torch.max(ious, dim=0)

        # One-hot mask (N, S, S, B, 1) that picks the responsible box; summing the
        # masked boxes yields (N, S, S, 5) = [conf, x, y, w, h] of that box.
        responsible = (
            nn.functional.one_hot(bestbox.reshape(-1, S, S), B)
            .unsqueeze(-1)
            .to(predictions.dtype)
        )
        best_pred = (responsible * pred_boxes).sum(dim=-2)

        # ---- Box coordinate loss (cells containing an object only) ----
        box_predictions = exist_box * best_pred[..., 1:5]
        box_targets = exist_box * target_box
        # Regress sqrt(w), sqrt(h) so errors on small boxes weigh more. Raw
        # predictions can be negative, hence sign * sqrt(|.|). The epsilon sits
        # OUTSIDE abs so the sqrt argument is always >= 1e-6 and its gradient is
        # finite (abs(x + eps) is 0 for x == -eps and would yield NaN gradients).
        pred_wh = torch.sign(box_predictions[..., 2:4]) * torch.sqrt(
            torch.abs(box_predictions[..., 2:4]) + 1e-6
        )
        box_predictions = torch.cat([box_predictions[..., :2], pred_wh], dim=-1)
        box_targets = torch.cat(
            [box_targets[..., :2], torch.sqrt(box_targets[..., 2:4])], dim=-1
        )
        box_loss = self.mse(box_predictions, box_targets)

        # ---- Object (confidence) loss for the responsible box ----
        pred_conf = best_pred[..., 0:1]
        object_loss = self.mse(exist_box * pred_conf, exist_box * target[..., C:C + 1])

        # ---- No-object loss: confidences of all B boxes in empty cells ----
        # NOTE(review): non-responsible boxes in object cells are not penalized
        # here (same as the previous implementation); confirm this is intended.
        no_obj = 1 - exist_box
        no_obj_pred = no_obj * pred_boxes[..., 0]  # (N, S, S, B)
        no_obj_target = (no_obj * target[..., C:C + 1]).expand_as(no_obj_pred)
        no_object_loss = self.mse(no_obj_pred, no_obj_target)

        # ---- Class loss (cells containing an object only) ----
        class_loss = self.mse(exist_box * predictions[..., :C], exist_box * target[..., :C])

        loss = (
            self.lambda_coord * box_loss
            + object_loss
            + self.lambda_noobj * no_object_loss
            + class_loss
        )

        return loss
