In [1]:
import glob
import sys
import cv2
import matplotlib.pyplot as plt

In [2]:
def union_area(a,b):

    x = min(a[0], b[0])
    y = min(a[1], b[1])
    w = max(a[0]+a[2], b[0]+b[2]) - x
    h = max(a[1]+a[3], b[1]+b[3]) - y
    return w*h

def intersection_area(a,b):
    x = max(a[0], b[0])
    y = max(a[1], b[1])
    
    w = min(a[0]+a[2], b[0]+b[2]) - x
    h = min(a[1]+a[3], b[1]+b[3]) - y
    if w < 0 or h < 0:
        return 0
    else:
        return w * h

def str2int(a):
    return [int(x) for x in a]

def extract_boxes(fname):
    with open(fname) as f:
        content = f.readlines()
        f.close()
        content = [x.strip() for x in content]
        content = [str2int(x.split(' ')[-4:]) for x in content]
        return content

In [3]:
paths = glob.glob("/media/nasir/Drive1/code/SAR/AutomatedSARShipDetection/python_cfar/SAR-Ship-Dataset/detection-results/*.txt")
len(paths)


43819

In [10]:
def get_precision_recall(threshold):
    files_stats = {}


    falseNegative = 0
    truePositive = 0
    falsePositive = 0
    trueNegative = 0

    for index, path in enumerate(paths):
        pred_bboxes = extract_boxes(path)
        gt_bboxes = extract_boxes(path.replace('detection-results', 'ground-truth'))
        fp = 0; tp = 0; fn = 0
        box_index_of_tp = []
        for index_g, gt_box in enumerate(gt_bboxes):
            ious = []
            for index_p, pred_box in enumerate(pred_bboxes):
                iou = intersection_area(gt_box, pred_box) / union_area(gt_box, pred_box)
                if iou >threshold:
                    box_index_of_tp.append(index_p)
                    ious.append(iou)

            if len(ious) == 0:
                fn+=1
            elif len(ious) > 0:
                tp+=1

        diff = len(pred_bboxes) - (len(list(set(box_index_of_tp))))
        if diff > 0:
            fp+=diff

        falseNegative+=fn
        truePositive+=tp
        falsePositive+=fp

        files_stats[path.split('/')[-1].split('.')[0]] = {
            "falseNegative": fn,
            "truePositive": tp,
            "falsePositive": fp
        }

        sys.stdout.write(f"\r {index + 1} / {len(paths)}")
        sys.stdout.flush()
    print(f"\n\nfalsePositives: {falsePositive} , truePositives: {truePositive} , falseNegatives: {falseNegative}")
    recall = truePositive / (truePositive + falseNegative)
    precision = truePositive / (truePositive + falsePositive)
    return precision, recall
    

In [9]:
thresholds = [0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
precisions = []
recalls = []

for threshold in thresholds:
    precision, recall = get_precision_recall(threshold)
    precisions.append(precision)
    recalls.append(recall)
    print(f"\nthreshold: {threshold} recall: {round(recall * 100, 2)}% precision: {round(precision*100, 2)}% \n")

 43819 / 43819
falsePositives: 53343 , truePositives: 8605 , falseNegatives: 50930

threshold: 0.7 recall: 14.45% precision: 13.89% 

 43819 / 43819
falsePositives: 46160 , truePositives: 15796 , falseNegatives: 43739

threshold: 0.6 recall: 26.53% precision: 25.5% 

 40309 / 43819
falsePositives: 39661 , truePositives: 22315 , falseNegatives: 37220

threshold: 0.5 recall: 37.48% precision: 36.01% 

 43819 / 43819
falsePositives: 31117 , truePositives: 30897 , falseNegatives: 28638

threshold: 0.3 recall: 51.9% precision: 49.82% 

 43728 / 43819