In [1]:
import torch

In [2]:
# Intersection over union (IoU) between predicted and ground-truth boxes
def intersection_over_union(boxes_preds, boxes_label, box_format="midpoint"):
    """Compute the intersection over union (IoU) of paired bounding boxes.

    Args:
        boxes_preds: Tensor of predicted boxes whose last dimension holds
            the four box coordinates, e.g. shape (batch_size, 4) or (4,).
        boxes_label: Tensor of ground-truth boxes laid out like
            ``boxes_preds``.
        box_format: ``"midpoint"`` for (mid_x, mid_y, width, height) or
            ``"corners"`` for (x1, y1, x2, y2).

    Returns:
        Tensor of IoU values shaped like the inputs with the last dimension
        reduced to size 1, e.g. (batch_size, 1).

    Raises:
        ValueError: If ``box_format`` is neither ``"midpoint"`` nor
            ``"corners"``.
    """
    # Slicing (0:1 rather than 0) keeps the trailing dimension, and the
    # ellipsis lets the same code handle a single (4,) box or an (N, 4) batch.
    if box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]

        box2_x1 = boxes_label[..., 0:1]
        box2_y1 = boxes_label[..., 1:2]
        box2_x2 = boxes_label[..., 2:3]
        box2_y2 = boxes_label[..., 3:4]
    elif box_format == "midpoint":
        # Indices 0, 1, 2, 3 are mid_x, mid_y, width, height.
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2

        box2_x1 = boxes_label[..., 0:1] - boxes_label[..., 2:3] / 2
        box2_y1 = boxes_label[..., 1:2] - boxes_label[..., 3:4] / 2
        box2_x2 = boxes_label[..., 0:1] + boxes_label[..., 2:3] / 2
        box2_y2 = boxes_label[..., 1:2] + boxes_label[..., 3:4] / 2
    else:
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    # Corners of the overlap rectangle; computed for BOTH formats.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) zeroes the extent when the boxes do not overlap on an axis.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    union = box1_area + box2_area - intersection

    # The small epsilon avoids division by zero for degenerate boxes.
    return intersection / (union + 1e-7)

In [3]:
# Non Maximum Suppression
def NMS(bboxes, prob_thresh, iou_thresh, box_format="corners"):
    """Filter detections with per-class non-maximum suppression.

    Args:
        bboxes: List of detections, each laid out as
            ``[class, proba, c0, c1, c2, c3]`` where the last four values
            are box coordinates in ``box_format`` layout (for "corners":
            x1, y1, x2, y2).
        prob_thresh: Detections whose probability is not strictly greater
            than this value are discarded up front.
        iou_thresh: A remaining detection of the same class as an already
            chosen box is kept only if its IoU with that box is strictly
            below this value.
        box_format: Coordinate layout forwarded to
            ``intersection_over_union`` ("corners" or "midpoint").

    Returns:
        List of the surviving detections, ordered from highest to lowest
        probability.

    Raises:
        TypeError: If ``bboxes`` is not a list.
    """
    if not isinstance(bboxes, list):
        raise TypeError(f"bboxes must be a list, got {type(bboxes).__name__}")

    # Drop low-confidence boxes, then order from high to low probability.
    candidates = sorted(
        (box for box in bboxes if box[1] > prob_thresh),
        key=lambda box: box[1],
        reverse=True,
    )
    bboxes_after_nms = []

    while candidates:
        chosen_box = candidates.pop(0)  # highest remaining probability
        # Built once per chosen box; shaped (1, 4) so the IoU helper can
        # index it as a batch.
        chosen_coords = torch.tensor(chosen_box[2:]).unsqueeze(0)

        candidates = [
            box
            for box in candidates
            if box[0] != chosen_box[0]  # other classes are never suppressed
            or intersection_over_union(
                chosen_coords,
                torch.tensor(box[2:]).unsqueeze(0),
                box_format=box_format,
            ).item() < iou_thresh
        ]
        bboxes_after_nms.append(chosen_box)

    return bboxes_after_nms