In [91]:
import numpy as np
import warnings
import pprint as pp

In [73]:
def prefilter_boxes(
    boxes,
    scores,
    labels,
    weights,
    skip_thr
):
    """Group every model's detections by label, dropping low-score boxes.

    Parameters
    ----------
    boxes : sequence
        One entry per model; each entry is a sequence of boxes whose first
        four values are ``xmin, ymin, xmax, ymax``.
    scores : sequence
        One entry per model, parallel to ``boxes``, with one score per box.
    labels : sequence
        One entry per model, parallel to ``boxes``, with one label per box.
        Labels are converted with ``int()``.
    weights : sequence of float
        Per-model weight, indexed by the model's position in ``boxes``.
    skip_thr : float
        Boxes whose raw score is strictly below this value are discarded.

    Returns
    -------
    dict[int, numpy.ndarray]
        Maps each label to a ``(k, 8)`` array whose rows are
        ``[label, score * weight, weight, model index, xmin, ymin, xmax, ymax]``,
        sorted by weighted score, highest first. Labels with no surviving
        boxes are absent.

    Raises
    ------
    ValueError
        If ``boxes``, ``scores`` and ``labels`` disagree in length, either in
        the number of models or in the number of boxes for some model.
    """
    # Raise instead of calling exit(): exit() raises SystemExit, which aborts a
    # script (or the notebook cell) without a traceback pointing at the cause.
    if len(scores) != len(boxes):
        raise ValueError("Number of score arrays not equal to number of boxes arrays")
    if len(labels) != len(boxes):
        raise ValueError("Number of label arrays not equal to number of boxes arrays")

    new_boxes = {}
    for model_idx, (model_boxes, model_scores, model_labels) in enumerate(zip(boxes, scores, labels)):
        if len(model_boxes) != len(model_scores):
            raise ValueError("Length of boxes arrays not equal to length of scores array")
        if len(model_boxes) != len(model_labels):
            raise ValueError("Length of boxes arrays not equal to length of labels array")

        weight = weights[model_idx]
        for box_coord, score, raw_label in zip(model_boxes, model_scores, model_labels):
            if score < skip_thr:
                continue

            label = int(raw_label)
            xmin, ymin, xmax, ymax = (float(c) for c in box_coord[:4])

            # [label, score, weight, model index, xmin, ymin, xmax, ymax]
            new_boxes.setdefault(label, []).append(
                [label, float(score) * weight, weight, model_idx, xmin, ymin, xmax, ymax]
            )

    # Highest weighted score first: the clustering step consumes boxes in this order.
    for label, label_boxes in new_boxes.items():
        current_boxes = np.array(label_boxes)
        new_boxes[label] = current_boxes[current_boxes[:, 1].argsort()[::-1]]

    return new_boxes

In [74]:
def get_weighted_box(boxes, conf_type='avg'):
    """Fuse a cluster of boxes into one box with score-weighted coordinates.

    Parameters
    ----------
    boxes : sequence of numpy.ndarray
        Non-empty cluster of rows laid out as
        ``[label, score, weight, model index, xmin, ymin, xmax, ymax]``.
    conf_type : str, optional
        Accepted for interface compatibility but not consulted here; the fused
        score is always the mean of the member scores ("avg").

    Returns
    -------
    numpy.ndarray
        float32 array of length 8 in the same layout: the first box's label,
        the mean score, the summed weight, ``-1`` as model index, and the
        coordinates averaged with each box's score as its weight.
    """
    box = np.zeros(8, dtype=np.float32)
    conf = 0
    w = 0

    for b in boxes:
        # Accumulate score-weighted coordinates; normalised by total score below.
        box[4:] += (b[1] * b[4:])
        conf += b[1]
        w += b[2]

    box[0] = boxes[0][0]
    box[1] = conf / len(boxes)
    box[2] = w
    box[3] = -1 # model index field is retained for consistency but is not used
    if conf != 0:
        box[4:] /= conf
    else:
        # Every member scored 0 (e.g. a zero model weight, or skip_box_thr=0.0
        # letting zero scores through): score-weighting is undefined and the
        # division would yield NaN coordinates, so use the plain mean instead.
        box[4:] = np.mean([b[4:] for b in boxes], axis=0)
    return box

In [75]:
def find_matching_box_fast(F, B, match_iou):
    """Find the fused box that best overlaps a candidate box of the same label.

    Parameters
    ----------
    F : numpy.ndarray
        Fused boxes so far, one row per cluster, laid out as
        ``[label, score, weight, model index, xmin, ymin, xmax, ymax]``.
        May have zero rows.
    B : numpy.ndarray
        Candidate box in the same 8-field layout.
    match_iou : float
        The IoU must be strictly greater than this to count as a match.

    Returns
    -------
    tuple
        ``(index, iou)`` of the best same-label match, or ``(-1, match_iou)``
        when ``F`` is empty or no same-label row clears the threshold.
    """
    if F.shape[0] == 0:
        return -1, match_iou

    fused = F[:, 4:]
    cand = B[4:]

    # Overlap rectangle between the candidate and every fused box, clamped at 0.
    overlap_w = np.maximum(np.minimum(fused[:, 2], cand[2]) - np.maximum(fused[:, 0], cand[0]), 0)
    overlap_h = np.maximum(np.minimum(fused[:, 3], cand[3]) - np.maximum(fused[:, 1], cand[1]), 0)
    overlap = overlap_w * overlap_h

    fused_area = (fused[:, 2] - fused[:, 0]) * (fused[:, 3] - fused[:, 1])
    cand_area = (cand[2] - cand[0]) * (cand[3] - cand[1])
    # The 1e-8 term keeps zero-area boxes from dividing by zero.
    ious = overlap / (fused_area + cand_area - overlap + 1e-8)

    # Rows with a different label can never be chosen.
    ious = np.where(F[:, 0] == B[0], ious, -1)

    winner = np.argmax(ious)
    top_iou = ious[winner]
    if top_iou <= match_iou:
        return -1, match_iou
    return winner, top_iou

In [88]:
def weighted_boxes_fusion(
    boxes_list,
    scores_list,
    labels_list,
    weights=None,
    iou_thr=0.55,
    skip_box_thr=0.0,
    conf_type='avg'
):
    """Fuse detections from several models with Weighted Boxes Fusion (WBF).

    Boxes are grouped per label and visited in descending weighted-score
    order. Each box joins the fused box it overlaps most (IoU strictly above
    ``iou_thr``) or starts a new cluster. Each cluster's fused score is then
    multiplied by ``min(cluster size, n_models) / sum(weights)``, which lowers
    the score of clusters supported by fewer boxes.

    Parameters
    ----------
    boxes_list : list
        One list per model of ``[xmin, ymin, xmax, ymax]`` boxes.
    scores_list : list
        One list per model of scores, parallel to ``boxes_list``.
    labels_list : list
        One list per model of labels, parallel to ``boxes_list``.
    weights : sequence of float, optional
        Per-model weights. ``None`` means all ones; a wrong length also falls
        back to all ones and emits a ``UserWarning``.
    iou_thr : float, optional
        IoU a box must exceed to join an existing cluster.
    skip_box_thr : float, optional
        Boxes with a score below this are discarded before fusion.
    conf_type : str, optional
        Only ``'avg'`` is implemented; any other value emits a
        ``UserWarning`` and is treated as ``'avg'``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(boxes, scores, labels)`` with shapes ``(n, 4)``, ``(n,)`` and
        ``(n,)``, sorted by fused score, highest first. Labels are floats.

    Raises
    ------
    ValueError
        If there are boxes to fuse but the model weights sum to zero.
    """
    if conf_type != 'avg':
        warnings.warn(
            "conf_type {!r} is not implemented; falling back to 'avg'".format(conf_type)
        )

    if weights is None:
        weights = np.ones(len(boxes_list))
    if len(weights) != len(boxes_list):
        warnings.warn(
            "Incorrect number of weights {}. Must be {}. Set weights equal to 1".format(
                len(weights), len(boxes_list)
            )
        )
        weights = np.ones(len(boxes_list))
    weights = np.array(weights, dtype=float)

    # Filter boxes
    filtered_boxes = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr)
    if len(filtered_boxes) == 0:
        return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))

    total_weight = weights.sum()
    if total_weight == 0:
        # The rescaling step divides by this; fail loudly instead of producing NaN/inf scores.
        raise ValueError("Sum of model weights must be non-zero")
    n_models = len(weights)

    overall_boxes = []
    for boxes in filtered_boxes.values():
        clusters = []                       # clusters[k]: raw boxes merged into fused box k
        weighted_boxes = np.empty((0, 8))   # row k: current fused box for clusters[k]

        # Cluster: attach each box to its best-matching fused box or open a new cluster.
        for box in boxes:
            index, _ = find_matching_box_fast(weighted_boxes, box, iou_thr)
            if index != -1:
                clusters[index].append(box)
                weighted_boxes[index] = get_weighted_box(clusters[index], conf_type)
            else:
                clusters.append([box.copy()])
                weighted_boxes = np.vstack((weighted_boxes, box.copy()))

        # Rescale confidence: clusters backed by fewer boxes than there are models score lower.
        cluster_sizes = np.array([len(c) for c in clusters])
        weighted_boxes[:, 1] = weighted_boxes[:, 1] * np.minimum(cluster_sizes, n_models) / total_weight

        overall_boxes.append(weighted_boxes)

    overall_boxes = np.concatenate(overall_boxes, axis=0)
    overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
    return overall_boxes[:, 4:], overall_boxes[:, 1], overall_boxes[:, 0]

In [96]:
# Example: fuse detections from two models on a 144x144-pixel image.
IMG_SIZE = 144.  # image side length in pixels; used to normalise coordinates

# Raw detections in pixel coordinates, one list per model: [xmin, ymin, xmax, ymax]
boxes_list_px = [
    [
        [10, 40, 38, 140],
        [40, 50, 78, 118],
        [19, 24, 64, 127]
    ],
    [
        [30, 45, 68, 125],
        [25, 28, 74, 132],
        [20, 25, 64, 128],
        [23, 30, 68, 123]
    ]
]

scores_list = [[0.62, 0.35, 0.99], [0.86, 0.94, 0.91, 0.89]]
labels_list = [[1, 1, 1], [1, 1, 1, 1]]
weights = [1, 1]

iou_thr = 0.3
skip_box_thr = 0.01

# Build a normalised copy instead of dividing in place, so the pixel-space
# inputs stay available for later inspection or plotting.
boxes_list = [
    [[coord / IMG_SIZE for coord in box] for box in model_boxes]
    for model_boxes in boxes_list_px
]

# WBF -- pass the `weights` variable so editing it above actually takes effect.
boxes, scores, labels = weighted_boxes_fusion(
    boxes_list,
    scores_list,
    labels_list,
    iou_thr=iou_thr,
    skip_box_thr=skip_box_thr,
    weights=weights
)
print(boxes)
print(scores)
print(scores)

[[0.16978744 0.21891868 0.47438708 0.87798017]
 [0.06944444 0.27777778 0.26388889 0.97222222]]
[0.82333332 0.31      ]
