In [None]:
# Bounding box pruner: merges bounding boxes that likely correspond to the same object.

import cv2
import csv
import pandas
import matplotlib.pyplot as plt 
import os, math
import operator
import numpy as np
import matplotlib.patches as patches

def getIOU(bbox1, bbox2):
    """Return the intersection-over-union (IoU) of two corner-format boxes.

    Each box is a sequence ``[label, tl_x, tl_y, br_x, br_y]``; the label at
    index 0 is not read here.  The arithmetic expects ``tl_x <= br_x`` and
    ``tl_y >= br_y``, i.e. the "top" edge carries the larger y value.

    Returns:
        ``-1`` when the boxes do not overlap at all, ``0.0`` when they touch
        or when their union has no area (degenerate boxes), otherwise the
        ratio ``intersection_area / union_area``.
    """
    # Extent of the overlap rectangle along each axis; a negative extent
    # means the boxes are separated along that axis.
    inter_width = min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1])
    inter_height = min(bbox1[2], bbox2[2]) - max(bbox1[4], bbox2[4])
    if inter_width < 0 or inter_height < 0:
        return -1

    inter_area = inter_width * inter_height
    area1 = (bbox1[3] - bbox1[1]) * (bbox1[2] - bbox1[4])
    area2 = (bbox2[3] - bbox2[1]) * (bbox2[2] - bbox2[4])
    union_area = area1 + area2 - inter_area

    # Two zero-area boxes at the same spot pass the overlap test above but
    # have an empty union; report "no measurable overlap" instead of
    # dividing by zero.
    if union_area == 0:
        return 0.0

    return inter_area / union_area
    



def getMatches(bounding_boxes, iou_thresh = 0.8):
    """Group each box with the later boxes it should be merged with.

    For every index ``i`` the result holds a list that starts with ``i`` and
    is followed by every index ``j > i`` whose box has an IoU with box ``i``
    above ``iou_thresh`` and carries the same class label (element 0).

    Returns:
        A list with one index list per input box, in input order.
    """
    groups = []
    for anchor_idx, anchor in enumerate(bounding_boxes):
        group = [anchor_idx]
        groups.append(group)

        # Only look at boxes after the anchor so each pair is scored once.
        first_other = anchor_idx + 1
        for other_idx, other in enumerate(bounding_boxes[first_other:], start=first_other):
            overlap = getIOU(anchor, other)

            # Merge candidates need enough overlap AND an identical class.
            if overlap > iou_thresh and anchor[0] == other[0]:
                group.append(other_idx)
    return groups


def aggregateBoundingBoxes(df, show_plots = False):
    """Merge strongly overlapping same-class boxes of one label table.

    The rows of ``df`` are read positionally as
    ``class, x_center, y_center, width, height`` (the layout the code calls
    "darknet format").  Boxes are converted to corner format, repeatedly
    grouped with ``getMatches`` and each group is replaced by the mean of
    its members, then the result is converted back to the input layout.

    Args:
        df: DataFrame with one box per row in the column order above.
        show_plots: If true, draw the input boxes and the boxes produced by
            each merge pass with matplotlib and call ``plt.show()``.

    Returns:
        A list of ``[class, x_center, y_center, width, height]`` lists.
        Centres slightly outside ``[0, 1]`` (by less than 0.05) are clamped;
        boxes with centres further outside, or with a width/height outside
        ``[0, 1]``, are dropped.
    """
    if show_plots:
        fig, ax = plt.subplots()
        curax = plt.gca()

    # Convert every row to corner format [class, tl_x, tl_y, br_x, br_y].
    # "tl_y" is the larger y value, "br_y" the smaller one.
    bboxes = []
    for row in df.itertuples(index = True):
        # row[0] is the DataFrame index; the data columns start at row[1].
        tl_x = row[2] - row[4] / 2
        tl_y = row[3] + row[5] / 2
        br_x = row[2] + row[4] / 2
        br_y = row[3] - row[5] / 2

        if show_plots:
            curax.add_patch(patches.Rectangle((tl_x, br_y), row[4], row[5], edgecolor='r',facecolor=(0, 1, 1, 0.7)))
            curax.text(tl_x + 0.05, br_y + 0.05, row[0])

        bboxes.append([row[1], tl_x, tl_y, br_x, br_y])

    num_bboxes_original = len(bboxes)
    old_bboxes = bboxes

    finished = False
    max_it = 5
    it = 0
    if show_plots:
        fig2, ax2 = plt.subplots(max_it, 1, figsize=(8, 6))
    while not finished and it < max_it:
        matched_bboxes = getMatches(old_bboxes)

        new_bboxes = []
        processed_bboxes = set()  # indices already copied or merged in this pass
        merged_any = False
        for matches in matched_bboxes:
            if len(matches) > 1:
                # Skip groups whose members were all consumed by earlier groups.
                if not set(matches).issubset(processed_bboxes):
                    group = [old_bboxes[m] for m in matches]

                    # The merged box is the coordinate-wise mean of the group.
                    new_tlx = np.mean([c[1] for c in group])
                    new_tly = np.mean([c[2] for c in group])
                    new_brx = np.mean([c[3] for c in group])
                    new_bry = np.mean([c[4] for c in group])

                    # All members of a group share one class (getMatches
                    # compares element 0), so take it from the group itself
                    # rather than from the first box of the whole list.
                    new_bboxes.append([group[0][0], new_tlx, new_tly, new_brx, new_bry])
                    merged_any = True

                    if show_plots:
                        ax2[it].add_patch(patches.Rectangle((new_tlx, new_bry), new_brx - new_tlx, new_tly - new_bry, edgecolor='g',facecolor=(0.4, 0.1, 1, 0.4)))
                        ax2[it].text(new_tlx + 0.05, new_bry + 0.05, matches )
                    processed_bboxes.update(matches)

            elif matches[0] not in processed_bboxes:
                # Unmatched box that no earlier group absorbed: keep as is.
                kept = old_bboxes[matches[0]]
                new_bboxes.append(kept)
                processed_bboxes.add(matches[0])
                if show_plots:
                    ax2[it].add_patch(patches.Rectangle((kept[1], kept[4]),
                                                      kept[3] - kept[1],
                                                      kept[2] - kept[4],
                                                      edgecolor='y',facecolor=(0.4, 0.1, 1, 0.4)))

        old_bboxes = new_bboxes
        if show_plots:
            ax2[it].autoscale_view()
            ax2[it].set_title("Iteration" + str(it))
        it += 1

        # A pass without any merge copies every box unchanged, so further
        # passes would reproduce the same list; stop early.  (With
        # show_plots the remaining per-iteration axes then stay empty.)
        finished = not merged_any
    # end of while loop

    if show_plots:
        curax.autoscale_view()
        plt.show()
    print("Bounding boxes reduced from ", str(num_bboxes_original),"to", len(old_bboxes),".")

    # convert back to darknet format
    darknet_bboxes = []
    for corner_box in old_bboxes:
        temp_bbox = [corner_box[0],
                     (corner_box[1] + corner_box[3]) / 2,
                     (corner_box[2] + corner_box[4]) / 2,
                     (corner_box[3] - corner_box[1]),
                     (corner_box[2] - corner_box[4])]

        # Position this box would take in the output list; only used in the
        # log messages below.
        out_idx = len(darknet_bboxes)

        # clean up rounding errors and bad bboxes
        # j = 1, 2 are the centre coordinates, j = 3, 4 are width and height.
        remove_bad_label = False
        for j in range(1, 5):
            if temp_bbox[j] > 1:
                if j < 3:
                    if temp_bbox[j] < 1.05:  # this looks like a rounding error
                        temp_bbox[j] = 1.0
                    else:
                        remove_bad_label = True
                else:  # don't want sizes larger than 1.0 ever
                    remove_bad_label = True
                    print("ERROR: coordinate ", out_idx, "is suspiciously large, removing.")
            if temp_bbox[j] < 0:
                if j < 3:
                    if temp_bbox[j] > -0.05:  # this looks like a rounding error
                        temp_bbox[j] = 0.0
                    else:
                        remove_bad_label = True
                else:
                    remove_bad_label = True
                    print("ERROR: coordinate ", out_idx, "is suspiciously small, removing.")

        if remove_bad_label:
            print("removing ", corner_box)
        else:
            darknet_bboxes.append(temp_bbox)
    return darknet_bboxes
    
    #fig.savefig("asd.png")
    
#aggregateBoundingBoxes()

In [None]:
# Directory that holds the label files to prune (every regular *.txt file).
file_path = "imgs"
fn = [
    os.path.join(file_path, f)
    for f in os.listdir(file_path)
    if os.path.isfile(os.path.join(file_path, f)) and f.lower().endswith('.txt')
]

bad_labels = 0

# loop through all label files and find overlapping bounding boxes (i.e. same class, IOU > thresh)
for i, file in enumerate(fn, start=1):
    print("(",i,"/",len(fn),")")
    try:
        datafrm = pandas.read_csv(file, delimiter = ' ', header = None)
    except pandas.errors.EmptyDataError:
        # A label file without any rows has nothing to merge; leave it
        # untouched instead of aborting the whole batch.
        print("Skipping empty label file:", file)
        continue

    new_df = aggregateBoundingBoxes(datafrm, show_plots = False)

    # Flag results that still contain a coordinate or size above 1.
    if any(itm > 1 for sublist in new_df for itm in sublist[1:]):
        bad_labels += 1
        print("WARNING: label is outside of image bounds (", bad_labels, ":", new_df)
        if i < 1000:
            print(datafrm)

    # rename old label file to <filename>.backup
    # NOTE(review): a backup left by an earlier run may be replaced or make
    # this call fail, depending on the platform -- confirm before re-running.
    os.rename(file, file + ".backup")

    # save new label file
    new_df = pandas.DataFrame(new_df)
    new_df.to_csv(file, sep = ' ', header = False, index = False)