In [5]:
import torch

import sourc_code.utils.box_utils as box_utils
from lib_all import *

In [None]:
class MultiBoxLoss(nn.Module):
    """SSD MultiBox loss: Smooth-L1 localisation loss plus cross-entropy
    confidence loss with hard negative mining.

    Both terms are summed over the selected default boxes and divided by the
    total number of positive (non-background) default boxes in the batch.
    """

    def __init__(self, jaccard_threshold = 0.5, scale_ne_po = 3, device = "cpu"):
        """
        :param jaccard_threshold: Jaccard (IoU) threshold forwarded to ``box_utils.match``;
            default boxes overlapping a ground truth above it are treated as positives.
        :param scale_ne_po: maximum ratio of hard negatives to positives kept per image.
        :param device: device on which the matched label/offset tensors are built.
        """
        super().__init__()
        self.jaccard_threshold = jaccard_threshold
        self.scale_ne_po = scale_ne_po
        self.device = device

    def forward(self, predictions, targets):
        """Compute the localisation and confidence losses.

        :param predictions: tuple ``(loc_data, conf_data, dbox_list)``; ``loc_data`` is
            (batch_size, num_dbox, 4), ``conf_data`` is (batch_size, num_dbox, num_classes),
            ``dbox_list`` holds the default boxes (presumably (num_dbox, 4) -- confirm
            against ``box_utils.match``).
        :param targets: sequence of length batch_size; element ``i`` is a tensor whose
            last column is the class label and whose other columns are box coordinates.
        :return: ``(loss_loc, loss_conf)``, each divided by the number of positive
            default boxes in the batch (treated as at least 1 to avoid NaN).
        """
        loc_data, conf_data, dbox_list = predictions
        batch_size = loc_data.size(0)
        num_dbox = loc_data.size(1)

        loc_t, conf_t_label = self._match_targets(targets, dbox_list, batch_size, num_dbox)

        # Positives are default boxes whose matched label is not background (0).
        pos_mask = conf_t_label > 0  # (batch_size, num_dbox)

        # Localisation loss: Smooth-L1 over positive default boxes only.
        loc_p = loc_data[pos_mask]    # (num_pos_total, 4)
        loc_target = loc_t[pos_mask]  # (num_pos_total, 4)
        loss_loc = F.smooth_l1_loss(loc_p, loc_target, reduction="sum")

        # Confidence loss over positives plus mined hard negatives.
        num_pos = pos_mask.long().sum(1, keepdim=True)  # (batch_size, 1)
        neg_mask = self._hard_negative_mask(conf_data, conf_t_label, pos_mask, num_pos)
        selected = pos_mask | neg_mask               # (batch_size, num_dbox)
        conf_p = conf_data[selected]                 # (num_selected, num_classes)
        conf_target = conf_t_label[selected]         # (num_selected,)
        loss_conf = F.cross_entropy(conf_p, conf_target, reduction="sum")

        # Normalise by the number of positives; clamp so an image batch with no
        # matches yields 0 instead of 0/0 = NaN.
        n_pos = num_pos.sum().clamp(min=1)
        return loss_loc / n_pos, loss_conf / n_pos

    def _match_targets(self, targets, dbox_list, batch_size, num_dbox):
        """Build per-default-box offset targets and class labels with ``box_utils.match``."""
        # Zero-initialised buffers on the training device; box_utils.match is
        # presumably writing row `index` of both in place.
        loc_t = torch.zeros(batch_size, num_dbox, 4, device=self.device)
        conf_t_label = torch.zeros(batch_size, num_dbox, dtype=torch.long, device=self.device)
        dbox = dbox_list.to(self.device)
        variances = [0.1, 0.2]  # coefficients used to encode offsets between dbox and bbox
        for index in range(batch_size):
            truths_box = targets[index][:, :-1].to(self.device)  # (num_objects, coords)
            labels = targets[index][:, -1].to(self.device)
            box_utils.match(self.jaccard_threshold, truths_box, dbox, variances,
                            labels, loc_t, conf_t_label, index)
        return loc_t, conf_t_label

    def _hard_negative_mask(self, conf_data, conf_t_label, pos_mask, num_pos):
        """Per image, mark the negatives with the highest confidence loss.

        At most ``scale_ne_po * num_pos`` negatives (capped at num_dbox) are kept.
        """
        batch_size, num_dbox, num_classes = conf_data.shape
        with torch.no_grad():  # mining only picks indices; no gradient needed
            mining_loss = F.cross_entropy(
                conf_data.reshape(-1, num_classes), conf_t_label.view(-1), reduction="none"
            ).view(batch_size, num_dbox)
            # Positives must not compete for the negative slots.
            mining_loss = mining_loss.masked_fill(pos_mask, 0.0)
            # Double argsort: rank[i, j] is the position of box j when image i's
            # losses are sorted in descending order (0 = hardest).
            _, order = mining_loss.sort(1, descending=True)
            _, rank = order.sort(1)
            num_neg = torch.clamp(num_pos * self.scale_ne_po, max=num_dbox)  # (batch_size, 1)
            return (rank < num_neg) & ~pos_mask

# Test: ranking losses with a double argsort (the trick used for hard negative mining)

In [16]:
# Double argsort: the first pass lists column indices from largest to smallest
# loss; the second converts that ordering into each column's rank (0 = largest).
loss = torch.Tensor([[5, 2, 4, 3], [3, 1, 5, 6]])
index = torch.argsort(loss, dim=1, descending=True)  # column indices, largest loss first
index_final = torch.argsort(index, dim=1)  # rank of every original column
print(index)
print(index_final)

tensor([[0, 2, 3, 1],
        [3, 2, 0, 1]])
tensor([[0, 3, 1, 2],
        [2, 3, 1, 0]])
