In [None]:
import json
from tqdm import tqdm

import numpy as np

In [None]:
def group_boxes_by_class(boxes):
    """Bucket ``[class, x, y, w, h]`` boxes by their class label.

    Args:
        boxes: iterable of sequences whose first element is the class label.

    Returns:
        dict mapping each class label (in first-seen order) to a list of the
        remaining box fields (``box[1:]``), in input order.
    """
    grouped_boxes = {}
    for box in boxes:
        grouped_boxes.setdefault(box[0], []).append(box[1:])
    return grouped_boxes


def approximate_nms_with_class(boxes, iou_threshold=0.99):
    """Greedy, score-free non-maximum suppression applied per class.

    There are no confidence scores, so within a class the boxes are visited in
    their input order. The first remaining box is always kept. Every later box
    of the same class whose IoU with it is ``>= iou_threshold`` is discarded.
    This repeats until no boxes of that class remain. Boxes of different
    classes never suppress each other.

    Args:
        boxes: list of ``[class, x, y, w, h]`` boxes. The list is NOT modified.
        iou_threshold: IoU at or above which a box counts as a duplicate of
            an already-kept box. The default 0.99 only removes (near-)exact
            duplicates.

    Returns:
        list of kept ``[class, x, y, w, h]`` boxes, ordered by class label
        (class labels must be mutually comparable) and then by input order.
    """
    nms_boxes = []

    # sorted() is stable, so per-class input order is preserved. Unlike
    # list.sort(), it leaves the caller's list untouched.
    grouped = group_boxes_by_class(sorted(boxes, key=lambda box: box[0]))

    for class_label, remaining in grouped.items():
        # Every class is processed, including singleton classes, which simply
        # keep their one box.
        while remaining:
            kept = [class_label, *remaining[0]]
            nms_boxes.append(kept)
            remaining = [
                candidate
                for candidate in remaining[1:]
                if calculate_iou(kept, candidate) < iou_threshold
            ]

    return nms_boxes


def calculate_iou(box1, box2):
    """Intersection-over-union of two axis-aligned ``(x, y, w, h)`` boxes.

    Each argument may be either ``[x, y, w, h]`` or ``[class, x, y, w, h]``;
    only the last four fields are used. Returns 0 when the union area is not
    positive (e.g. two degenerate zero-size boxes).
    """
    x1, y1, w1, h1 = box1[-4:]
    x2, y2, w2, h2 = box2[-4:]

    intersection_x = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
    intersection_y = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))

    intersection_area = intersection_x * intersection_y
    union_area = w1 * h1 + w2 * h2 - intersection_area

    return intersection_area / union_area if union_area > 0 else 0

# Toy input: class 0 holds an exact duplicate pair, class 1 holds one box that
# partially overlaps an exact duplicate pair, and class 2 holds a lone box.
boxes_with_class = [
    [0, 50, 50, 100, 100],
    [1, 75, 75, 100, 100],
    [0, 50, 50, 100, 100],
    [1, 100, 100, 100, 100],
    [1, 100, 100, 100, 100],
    [2, 50, 150, 100, 100],
]

nms_result_with_class = approximate_nms_with_class(boxes_with_class)

# The input list is printed after the NMS call, so any in-place change the
# call makes to it is what gets shown.
for caption, shown in (
    ("Original Boxes with Class:", boxes_with_class),
    ("NMS Result with Class:", nms_result_with_class),
):
    print(caption, shown)


In [None]:
def nms_dataset(input_json, thresh, output_json=None):
    """Write a copy of a detection JSON with duplicate boxes removed per image.

    The input is read as a dict with ``images``, ``annotations``, ``info``,
    ``licenses`` and ``categories`` keys. Each annotation's ``bbox`` is read as
    ``[x, y, w, h]``. Within every image, ``approximate_nms_with_class`` is run
    on that image's boxes. New annotations are then emitted with fresh
    sequential ids starting at 0, with ``area = w * h`` and ``iscrowd = 0``.
    Any other annotation fields present in the input (e.g. segmentation) are
    not carried over.

    Only images that had at least one annotation in the input are kept in the
    output ``images`` list.

    Args:
        input_json: path of the source JSON file.
        thresh: IoU threshold passed to ``approximate_nms_with_class``.
        output_json: destination path. Defaults to
            ``../../dataset/train_nms_thresh{thresh}.json``.

    Returns:
        The dict that was written to ``output_json``.
    """
    with open(input_json) as f:
        data = json.load(f)

    images = data['images']
    annotations = data['annotations']

    # Bucket the annotations by image once (O(A)) instead of rescanning the full
    # annotation list for every image (O(I * A)).
    annos_by_image = {}
    for anno in annotations:
        bbox = anno['bbox']
        annos_by_image.setdefault(anno['image_id'], []).append(
            [anno['category_id'], bbox[0], bbox[1], bbox[2], bbox[3]]
        )

    # The dict keys give a set-backed O(1) membership test for "image has annotations".
    train_images = [x for x in images if x.get('id') in annos_by_image]

    nms_train_annotations = []
    next_id = 0
    for image in tqdm(images):
        kept_boxes = approximate_nms_with_class(annos_by_image.get(image['id'], []), thresh)
        for category_id, x, y, w, h in kept_boxes:
            nms_train_annotations.append({
                'image_id': image['id'],
                'category_id': category_id,
                'area': w * h,
                'bbox': [x, y, w, h],
                'iscrowd': 0,
                'id': next_id,
            })
            next_id += 1

    train_data = {
        'info': data['info'],
        'licenses': data['licenses'],
        'images': train_images,
        'categories': data['categories'],
        'annotations': nms_train_annotations,
    }

    if output_json is None:
        output_json = f'../../dataset/train_nms_thresh{thresh}.json'
    with open(output_json, 'w') as f:
        json.dump(train_data, f, indent=4)

    return train_data


In [None]:
# Dedupe fold 4 at IoU 0.99; the result is written under ../../dataset/ as train_nms_thresh0.99.json.
nms_dataset(input_json='../../dataset/train_fold_4.json', thresh=0.99)