In [1]:
from lib_all import *

## Xây dựng hàm decode

In [2]:
def decode(default_boxes, offsets, variances=(0.1, 0.2)):
    """Turn predicted offsets into corner-form bounding boxes (SSD decoding).

    :param default_boxes: [num_boxes, 4] default (prior) boxes in centre form
        (cx, cy, w, h); 8732 boxes for SSD300.
    :param offsets: [num_boxes, 4] predicted offsets
        (delta cx, delta cy, delta w, delta h) relative to the default boxes.
    :param variances: two scaling factors, the first applied to the centre
        offsets and the second to the size offsets. The defaults (0.1, 0.2)
        are the constants this function previously hard-coded.
    :return: [num_boxes, 4] boxes as (x_min, y_min, x_max, y_max).
        The inputs are not modified.
    """
    # Centre offsets are expressed in units of the default box size, so they
    # are scaled by (w, h) before being added to the default centre.
    centers = default_boxes[:, :2] + variances[0] * offsets[:, :2] * default_boxes[:, 2:]
    # Size offsets live in log space.
    sizes = default_boxes[:, 2:] * torch.exp(variances[1] * offsets[:, 2:])

    # torch.cat takes ONE sequence of tensors; dim=1 joins them column-wise.
    boxes = torch.cat((centers, sizes), dim=1)

    # Centre form -> corner form: the top-left corner is half a box away from
    # the centre, and the bottom-right corner is top-left plus the full size.
    boxes[:, :2] -= boxes[:, 2:] / 2  # (x_min, y_min)
    boxes[:, 2:] += boxes[:, :2]      # (x_max, y_max)
    return boxes

## Xây dựng hàm non-maximum suppression (NMS)

Với mỗi vật thể, chỉ giữ lại bounding box có độ tự tin cao nhất và loại bỏ những box chồng lấn nhiều với nó (IoU vượt ngưỡng).

In [3]:
def nms(boxes, confident, threshold = 0.45, top_k = 200):
    """Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    other candidate whose IoU with it exceeds ``threshold``.

    :param boxes: [num_boxes, 4] boxes as (x_min, y_min, x_max, y_max).
    :param confident: [num_boxes] confidence score of each box.
    :param threshold: IoU above which a candidate is suppressed.
    :param top_k: only the ``top_k`` highest-scoring boxes are considered.
    :return: ``(keep, count)`` where ``keep`` is a long tensor shaped like
        ``confident`` whose first ``count`` entries are the indices of the
        kept boxes, ordered by decreasing confidence. Entries past ``count``
        are zero padding.
    """
    count = 0
    # Same shape and device as the scores, integer dtype, zero filled.
    keep = torch.zeros_like(confident, dtype=torch.long)

    # Nothing to suppress.
    if boxes.numel() == 0:
        return keep, count

    x_min = boxes[:, 0]
    y_min = boxes[:, 1]
    x_max = boxes[:, 2]
    y_max = boxes[:, 3]
    area = (x_max - x_min) * (y_max - y_min)

    # Ascending sort: the best candidate is always the last element.
    _, index = confident.sort(0)
    index = index[-top_k:]

    while index.numel() > 0:
        index_box_max = index[-1]

        # The current best box is always kept.
        keep[count] = index_box_max
        count += 1

        if index.size(0) == 1:  # no other candidate left to compare against
            break

        index = index[:-1]  # drop the box that was just kept

        # Intersection rectangle of every remaining candidate with the kept
        # box: clamp each candidate's corners to the kept box's corners.
        inter_x_min = torch.clamp(x_min[index], min=x_min[index_box_max])
        inter_y_min = torch.clamp(y_min[index], min=y_min[index_box_max])
        inter_x_max = torch.clamp(x_max[index], max=x_max[index_box_max])
        inter_y_max = torch.clamp(y_max[index], max=y_max[index_box_max])

        # Disjoint boxes give negative extents; treat them as zero overlap.
        inter_w = torch.clamp(inter_x_max - inter_x_min, min=0.0)
        inter_h = torch.clamp(inter_y_max - inter_y_min, min=0.0)
        area_of_overlap = inter_w * inter_h

        area_of_union = area[index_box_max] + area[index] - area_of_overlap
        iou = area_of_overlap / area_of_union

        # Keep only the candidates that do not overlap the kept box too much.
        index = index[iou.le(threshold)]
    return keep, count

## Xây dựng class Detect

In [4]:
class Detect(nn.Module):
    """Post-processing layer that turns raw SSD outputs into detections.

    Implemented as an ``nn.Module`` so that calling the instance runs
    ``forward``. The legacy pattern of an ``autograd.Function`` with
    ``__init__`` and an instance-level ``forward`` is rejected by newer
    PyTorch releases when the object is called.
    """

    def __init__(self, confident_thresh = 0.01, top_k = 200, nms_thresh = 0.45):
        """
        :param confident_thresh: boxes scoring at or below this are ignored.
        :param top_k: maximum number of detections kept per class and image.
        :param nms_thresh: IoU threshold handed to ``nms``.
        """
        super().__init__()
        # Normalise over the class axis (the last axis of conf_data).
        self.softmax = nn.Softmax(dim=-1)
        self.confident_thresh = confident_thresh
        self.top_k = top_k
        self.nms_thresh = nms_thresh

    @torch.no_grad()  # inference-only post-processing, no autograd graph needed
    def forward(self, loc_data, conf_data, dbox_list):
        """Decode, threshold and NMS-filter the predictions of a batch.

        :param loc_data: (num_batch, num_dbox, 4) predicted box offsets.
        :param conf_data: (num_batch, num_dbox, num_classes) raw class scores.
        :param dbox_list: (num_dbox, 4) default boxes, passed to ``decode``.
        :return: (num_batch, num_classes, top_k, 5) tensor. Each row is
            (score, x_min, y_min, x_max, y_max); unused rows are zero.
            Class 0 is skipped by the loop below (presumably the background
            class), so its slice stays zero.
        """
        num_batch = loc_data.size(0)
        num_classes = conf_data.size(2)

        # (num_batch, num_dbox, num_classes) -> (num_batch, num_classes, num_dbox)
        conf_preds = self.softmax(conf_data).transpose(2, 1)

        # Allocated once for the whole batch so each image's results survive.
        output = conf_preds.new_zeros(num_batch, num_classes, self.top_k, 5)

        for batch_idx in range(num_batch):
            # Boxes for this image from its offsets and the default boxes.
            decode_boxes = decode(dbox_list, loc_data[batch_idx])
            conf_score = conf_preds[batch_idx].clone()

            for class_idx in range(1, num_classes):
                # Keep only boxes whose score for this class beats the threshold.
                c_mask = conf_score[class_idx].gt(self.confident_thresh)
                score = conf_score[class_idx][c_mask]

                if score.numel() == 0:
                    continue

                # Broadcast the per-box mask over the 4 coordinates.
                l_mask = c_mask.unsqueeze(1).expand_as(decode_boxes)
                boxes = decode_boxes[l_mask].view(-1, 4)

                keep, num_object = nms(boxes, score, self.nms_thresh, self.top_k)
                kept = keep[:num_object]

                # One row per detection: [score, x_min, y_min, x_max, y_max].
                output[batch_idx, class_idx, :num_object] = torch.cat(
                    (score[kept].unsqueeze(1), boxes[kept]), 1
                )
        return output