In [2]:
import torch
import torch.nn as nn
import os
import import_ipynb
from utils import intersection_over_union

importing Jupyter notebook from utils.ipynb


In [14]:
# Scratch shape check: taking channels 2..3 of the last axis keeps every
# leading dimension and leaves a last axis of length 2.
t = torch.empty(32, 7, 7, 25).narrow(-1, 2, 2)
print(t.size())

torch.Size([32, 7, 7, 2])


In [15]:
class YoloLoss(nn.Module):
    """
    Calculate the loss for the YOLO (v1) model.

    Sum-squared-error loss of the YOLO v1 paper, summed over the whole batch
    (``nn.MSELoss(reduction="sum")``).

    Channel layouts, as implied by the slicing in ``forward`` (confirm against
    the model and dataset code):

    * ``output`` reshapes to ``(batch, S, S, C + B * 5)``: per cell, ``C``
      class scores followed by ``B`` predicted boxes, each stored as
      ``[confidence, x, y, w, h]``.
    * ``target`` is ``(batch, S, S, C + 5)``: per cell, ``C`` class scores,
      one objectness value (the paper's indicator 1^obj_i, presumably 0 or 1)
      and a single ``[x, y, w, h]`` box.

    Args:
        S: number of grid cells along each image side.
        B: number of boxes predicted per cell.
        C: number of classes.
    """

    def __init__(self, S=7, B=2, C=20):
        super().__init__()
        # Summed (not averaged) squared error over every element.
        self.mse = nn.MSELoss(reduction="sum")

        self.S = S
        self.B = B
        self.C = C

        # Loss weights: up-weight coordinate errors and down-weight the
        # confidence errors of cells that contain no object.
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    @staticmethod
    def _signed_sqrt(x, eps=1e-6):
        """Return ``sign(x) * sqrt(|x| + eps)``.

        Raw width/height predictions may be negative, so the sign is kept. The
        epsilon is added *after* ``abs`` so the argument of ``sqrt`` is never 0
        and its gradient stays finite (``sqrt(|x + eps|)`` hits 0 at
        ``x == -eps`` and back-propagates NaN).
        """
        return torch.sign(x) * torch.sqrt(torch.abs(x) + eps)

    def forward(self, output, target):
        """
        Compute the YOLO v1 loss.

        Args:
            output: network predictions of any shape that reshapes to
                ``(batch, S, S, C + B * 5)``, e.g. the flat
                ``(batch, S * S * (C + B * 5))`` head output.
            target: ground truth of shape ``(batch, S, S, C + 5)``.

        Returns:
            Scalar tensor ``lambda_coord * box_loss + object_loss
            + lambda_noobj * no_object_loss + class_loss``.
        """
        S, B, C = self.S, self.B, self.C
        output = output.reshape(-1, S, S, C + B * 5)                    # (N, S, S, C + 5B)

        # Every predicted box as [conf, x, y, w, h]: (N, S, S, B, 5).
        pred_boxes = output[..., C:C + 5 * B].reshape(-1, S, S, B, 5)

        # The target holds a single box per cell.
        exists_box = target[..., C:C + 1]                                # (N, S, S, 1): 1^obj_i
        target_coords = target[..., C + 1:C + 5]                         # (N, S, S, 4)

        # IoU of each predicted box with the target box, stacked on a new
        # leading axis of size B. Assumes utils.intersection_over_union takes
        # two (..., 4) box tensors in the same format and returns (..., 1).
        ious = torch.stack(
            [intersection_over_union(pred_boxes[..., b, 1:5], target_coords) for b in range(B)],
            dim=0,
        )

        # Index (0..B-1) of the "responsible" box, i.e. highest IoU, per cell.
        _, bestbox = torch.max(ious, dim=0)

        # Select the responsible box for every cell: (N, S, S, 5). gather also
        # keeps a non-finite value in an unselected box out of the loss, which
        # 0/1 blending (0 * inf = NaN) would not.
        index = bestbox.reshape(-1, S, S, 1, 1).expand(-1, -1, -1, 1, 5)
        responsible = torch.gather(pred_boxes, 3, index).squeeze(3)

        # ======================== #
        #   FOR BOX COORDINATES    #
        # ======================== #
        # Only cells that contain an object contribute (masked by 1^obj_i).
        box_predictions = exists_box * responsible[..., 1:5]
        box_targets = exists_box * target_coords

        # As in the YOLO v1 paper, width and height (channels 2:4) are compared
        # through their square roots. New tensors are built with cat instead of
        # assigning in place into slices of tensors tracked by autograd.
        box_predictions = torch.cat(
            [box_predictions[..., :2], self._signed_sqrt(box_predictions[..., 2:4])],
            dim=-1,
        )
        box_targets = torch.cat(
            [box_targets[..., :2], torch.sqrt(box_targets[..., 2:4])],
            dim=-1,
        )

        box_loss = self.mse(
            torch.flatten(box_predictions, end_dim=-2),
            torch.flatten(box_targets, end_dim=-2),
        )

        # ==================== #
        #   FOR OBJECT LOSS    #
        # ==================== #
        # Confidence of the responsible box vs. the target objectness.
        pred_conf = responsible[..., 0:1]
        object_loss = self.mse(
            torch.flatten(exists_box * pred_conf),
            torch.flatten(exists_box * target[..., C:C + 1]),
        )

        # ======================= #
        #   FOR NO OBJECT LOSS    #
        # ======================= #
        # Confidence of every predicted box in cells without an object.
        no_obj = 1 - exists_box
        no_object_loss = 0
        for b in range(B):
            no_object_loss = no_object_loss + self.mse(
                torch.flatten(no_obj * pred_boxes[..., b, 0:1], start_dim=1),
                torch.flatten(no_obj * target[..., C:C + 1], start_dim=1),
            )

        # ================== #
        #   FOR CLASS LOSS   #
        # ================== #
        class_loss = self.mse(
            torch.flatten(exists_box * output[..., :C], end_dim=-2),
            torch.flatten(exists_box * target[..., :C], end_dim=-2),
        )

        loss = (
            self.lambda_coord * box_loss           # first two rows in paper
            + object_loss                          # third row in paper
            + self.lambda_noobj * no_object_loss   # fourth row
            + class_loss                           # fifth row
        )

        return loss