In [None]:
import os
import cv2
import json
import torch
import pydicom
import requests
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score

In [None]:
# Set the following two paths to your local copies of the mapping files.
img_id_map_file = 'path/to/image_id_map.json'   # JSON dictionary of the form {image_name: image_id, ...}
cat_id_map_file = 'path/to/category_id_map.json'   # JSON dictionary of the form {label_name: label_id, ...}

In [None]:
def load_img_id_map():
    """Read the JSON file named by ``img_id_map_file`` and return its parsed content."""
    with open(img_id_map_file) as handle:
        mapping = json.load(handle)
    return mapping

def load_cat_id_map():
    """Read the JSON file named by ``cat_id_map_file`` and return its parsed content."""
    with open(cat_id_map_file) as handle:
        mapping = json.load(handle)
    return mapping

In [None]:
# Load the two lookup tables from the JSON files whose paths are configured in
# img_id_map_file / cat_id_map_file.
img_id_map = load_img_id_map()
cat_id_map = load_cat_id_map()

In [None]:
# Invert both lookups (value -> key). If two keys share the same value, the one
# iterated last wins.
id_img_map = {v: k for k, v in img_id_map.items()}
id_cat_map = {v: k for k, v in cat_id_map.items()}

In [None]:
def merge_box_max(results, score_filter=0.5, ca_percent=0.9):
    """Filter detections by score, then greedily fuse strongly overlapping boxes.

    Detections whose score is strictly greater than ``score_filter`` are kept in
    their original order. Each kept detection is compared with the detections
    accepted so far; it is fused into the first accepted box for which the
    intersection covers more than ``ca_percent`` of either box's area. A fused
    entry gets the bounding box of both boxes, the label of the higher-scoring
    one (the already-accepted label on a tie) and the larger score. A detection
    that matches no accepted box is accepted as a new entry.

    Args:
        results: mapping with parallel sequences under ``'boxes'`` (each box is
            ``(x1, y1, x2, y2)``), ``'labels'`` and ``'scores'``.
        score_filter: minimum score (exclusive) for a detection to be considered.
        ca_percent: overlap fraction (exclusive) above which two boxes are fused.

    Returns:
        ``(labels, boxes, scores)`` as three tuples of equal length, or three
        empty lists when no detection passes the score filter.
    """
    kept = [(box, label, score)
            for box, label, score in zip(results['boxes'], results['labels'], results['scores'])
            if score > score_filter]
    if not kept:
        return [], [], []

    merged = [kept[0]]
    for bbox, label, score in kept[1:]:
        bx1, by1, bx2, by2 = bbox
        area_new = (bx2 - bx1) * (by2 - by1)
        for i, (nbox, nlabel, nscore) in enumerate(merged):
            nx1, ny1, nx2, ny2 = nbox
            cx1, cy1, cx2, cy2 = max(bx1, nx1), max(by1, ny1), min(bx2, nx2), min(by2, ny2)
            if cx1 > cx2 or cy1 > cy2:
                continue  # the boxes do not intersect
            area_old = (nx2 - nx1) * (ny2 - ny1)
            area_common = (cx2 - cx1) * (cy2 - cy1)
            # A zero-area box has no meaningful coverage ratio; treat it as
            # "not covered" instead of dividing by zero.
            cover_new = area_common / area_new if area_new != 0 else 0.0
            cover_old = area_common / area_old if area_old != 0 else 0.0
            if cover_new > ca_percent or cover_old > ca_percent:
                union_box = (min(bx1, nx1), min(by1, ny1), max(bx2, nx2), max(by2, ny2))
                merged[i] = (union_box, label if score > nscore else nlabel, max(score, nscore))
                break
        else:
            # No accepted box overlapped enough: keep this detection on its own.
            merged.append((bbox, label, score))

    n_boxes, n_labels, n_scores = zip(*merged)
    return n_labels, n_boxes, n_scores

In [None]:
def load_ground_truth_test(img_id, test_dataset='int-test-inference', dtype='T2_Ax'):
    """Return the annotations of one image from ``annotations/<test_dataset>_<dtype>.json``.

    The JSON file is read on every call; its ``'annotations'`` list is scanned and
    every entry whose ``'image_id'`` equals ``img_id`` is returned, in file order.
    """
    annotation_path = 'annotations/%s_%s.json' % (test_dataset, dtype)
    with open(annotation_path) as handle:
        content = json.load(handle)
    matches = []
    for record in content['annotations']:
        if record['image_id'] == img_id:
            matches.append(record)
    return matches

In [None]:
def align_res_and_ground_truth(img_id, test_epoch, results, ground_truth, score_filter=0.5, ca_percent=0.5):
    """Pair the ground-truth annotations of one image with the predicted boxes.

    The raw detections are first reduced with ``merge_box_max``; while more than
    three boxes remain they are merged again with progressively looser overlap
    thresholds (0.8, 0.7, 0.6, 0.5). Every ground-truth annotation is then
    matched to the predicted box with the highest mean overlap ratio among those
    whose intersection covers more than ``ca_percent`` of either box.

    Args:
        img_id: unused here; kept for interface compatibility.
        test_epoch: unused here; kept for interface compatibility.
        results: detections with ``'boxes'`` (``(x1, y1, x2, y2)``), ``'labels'``
            and ``'scores'``, as accepted by ``merge_box_max``.
        ground_truth: iterable of annotations, each with ``'bbox'`` given as
            ``[x, y, width, height]`` and ``'category_id'``.
        score_filter: score threshold forwarded to ``merge_box_max``.
        ca_percent: cross-area fraction (exclusive) required to pair two boxes.

    Returns:
        A list of ``(truth_label, predicted_label)`` pairs. The placeholder
        label 7 stands for "no prediction" on the predicted side and for
        "no ground truth" on the truth side.
    """
    no_match = 7  # placeholder label for the missing side of a pair

    labels, boxes, scores = merge_box_max(results, score_filter)
    for merge_ca_percent in [0.8, 0.7, 0.6, 0.5]:
        if len(labels) > 3:
            labels, boxes, scores = merge_box_max(
                {'boxes': boxes, 'labels': labels, 'scores': scores}, score_filter, merge_ca_percent)

    label_pair = []
    paired_predicted = set()
    for an in ground_truth:
        tbox, tlabel = an['bbox'], an['category_id']
        tx1, ty1 = tbox[0], tbox[1]
        tx2, ty2 = tx1 + tbox[2], ty1 + tbox[3]
        areat = (tx2 - tx1) * (ty2 - ty1)

        candidates = []  # (prediction index, mean overlap ratio)
        for j, box in enumerate(boxes):
            nx1, ny1, nx2, ny2 = box
            cx1, cy1, cx2, cy2 = max(tx1, nx1), max(ty1, ny1), min(tx2, nx2), min(ty2, ny2)
            if cx1 > cx2 or cy1 > cy2:
                continue  # no intersection
            arean = (nx2 - nx1) * (ny2 - ny1)
            areac = (cx2 - cx1) * (cy2 - cy1)
            # Zero-area boxes have no meaningful ratio; count them as 0 overlap
            # rather than dividing by zero.
            ratio_t = areac / areat if areat != 0 else 0.0
            ratio_n = areac / arean if arean != 0 else 0.0
            if ratio_t > ca_percent or ratio_n > ca_percent:
                # Rank candidates by the mean of both overlap ratios. The
                # previous expression was parsed as a chained comparison and
                # only ever produced 0.0 or 0.5.
                candidates.append((j, (ratio_t + ratio_n) / 2))

        if candidates:
            # max() returns the first best candidate, i.e. the lowest index on ties.
            best_index = max(candidates, key=lambda cand: cand[1])[0]
            label_pair.append([tlabel, labels[best_index]])
            paired_predicted.add(best_index)
        else:
            # Ground truth exists but the model gave no matching prediction.
            label_pair.append([tlabel, no_match])

    # Predictions that were never paired with any ground-truth annotation.
    label_pair.extend((no_match, labels[i]) for i in range(len(labels)) if i not in paired_predicted)
    return label_pair


def cal_metrics(test_dataset, test_epoch, score_filter=0.5, ca_percent=0.5, output_dir='t2ax_output'):
    """Append accuracy/recall figures for one saved detection run to the test log.

    Loads ``/path/to/<output_dir>/<test_dataset>-<test_epoch>.pth``, pairs every
    image's detections with its ground truth, and appends to
    ``/path/to/<output_dir>/test_log_<test_dataset>.txt``:
    one summary line (epoch, score filter, overall accuracy, overall recall)
    followed by one line per class 0..6 (correct/total and the percentage).
    Ratios with an empty denominator are written as ``nan`` instead of raising
    ``ZeroDivisionError``.

    Args:
        test_dataset: dataset name; a name starting with ``'int'`` selects the
            ``'T2_Ax'`` annotations, anything else ``'T2*Ax'``.
        test_epoch: epoch number used in the result file name and the log line.
        score_filter: score threshold forwarded to the pairing step.
        ca_percent: cross-area threshold forwarded to the pairing step.
        output_dir: directory name (under ``/path/to``) holding results and log.
    """
    # NOTE(review): torch.load unpickles the file - only load result files from a trusted source.
    test = torch.load('/path/to/%s/%s-%s.pth' % (output_dir, test_dataset, test_epoch))
    t2 = 'T2_Ax' if test_dataset[:3] == 'int' else 'T2*Ax'

    label_pairs = []
    for img_id, res in test.items():
        ground_truth = load_ground_truth_test(img_id, test_dataset, t2)
        label_pairs.extend(
            align_res_and_ground_truth(img_id, test_epoch, res, ground_truth, score_filter, ca_percent))

    df = pd.DataFrame(label_pairs, columns=['truth_label', 'predict_label'])
    nan = float('nan')

    # Label 7 is the placeholder for "missing", so "< 7" selects real classes.
    truth_total = df[df.truth_label < 7].shape[0]
    detected = df[(df.truth_label < 7) & (df.predict_label < 7)].shape[0]
    acc = accuracy_score(list(df.truth_label), list(df.predict_label)) if label_pairs else nan
    overall_recall = detected / truth_total if truth_total else nan

    test_log = '/path/to/%s/test_log_%s.txt' % (output_dir, test_dataset)
    # Open the log once and append the summary plus the per-class lines.
    with open(test_log, 'a') as fp:
        fp.write('%s\t%s\tACC:%.4f\tRecall:%.4f\n' % (test_epoch, score_filter, acc, overall_recall))
        for i in range(7):
            label_count = df[df.truth_label == i].shape[0]
            true_count = df[(df.truth_label == i) & (df.predict_label == i)].shape[0]
            # A class absent from the ground truth has no defined hit rate.
            percent = true_count / label_count * 100 if label_count else nan
            fp.write('%s\t%4s/%4s\t%.4f\n' % (i, true_count, label_count, percent))

# Sweep epochs 9, 19, ..., 99 and several score thresholds; every cal_metrics
# call appends its lines to the test log of the 'int-test-reference' dataset.
for test_epoch in range(9, 100, 10): 
    for score_filter in [0.1, 0.25, 0.5, 0.7, 0.9]:
        cal_metrics('int-test-reference', test_epoch, score_filter, output_dir='t2ax_output')

In [None]:
def cal_results(test_dataset, test_epoch, score_filter=0.5, ca_percent=0.5, output_dir='t2ax_output'):
    """Print recall and accuracy summaries for one saved detection run.

    Loads ``/path/to/<output_dir>/<first three chars of test_dataset>-test-reference-<test_epoch>.pth``,
    pairs every image's detections with its ground truth and prints:

    * the ground-truth count of every class 0..6;
    * recall over all classes, over classes 0..3 ("central") and 4..6 ("side"),
      where a ground-truth box counts as recalled when it was paired with a
      prediction from the same group;
    * classification accuracy restricted to pairs where both labels fall in the
      respective group, for the original labels and for the coarser grouping
      {0,1}->0, {2,3}->1, {4,5}->2, 6->3 (7 -> 4 is the "missing" placeholder).

    Ratios and accuracies over an empty selection are printed as ``nan`` instead
    of raising.

    Args:
        test_dataset: dataset name; a name starting with ``'int'`` selects the
            ``'T2_Ax'`` annotations, anything else ``'T2*Ax'``.
        test_epoch: epoch number used in the result file name.
        score_filter: score threshold forwarded to the pairing step.
        ca_percent: cross-area threshold forwarded to the pairing step.
        output_dir: directory name (under ``/path/to``) holding the results.
    """
    nan = float('nan')

    def ratio(numerator, denominator):
        # Undefined ratios (empty denominator) are reported as nan.
        return numerator / denominator if denominator else nan

    def accuracy(frame):
        # Accuracy over a possibly empty selection of label pairs.
        if frame.shape[0] == 0:
            return nan
        return accuracy_score(list(frame.truth_label), list(frame.predict_label))

    # NOTE(review): the result file is always the '<prefix>-test-reference' one, while the
    # ground truth is loaded for `test_dataset` itself - confirm this asymmetry is intended.
    # NOTE(review): torch.load unpickles the file - only load result files from a trusted source.
    test = torch.load('/path/to/%s/%s-%s.pth' % (output_dir, '%s-test-reference' % test_dataset[:3], test_epoch))
    t2 = 'T2_Ax' if test_dataset[:3] == 'int' else 'T2*Ax'

    label_pairs = []
    for img_id, res in test.items():
        ground_truth = load_ground_truth_test(img_id, test_dataset, t2)
        label_pairs.extend(
            align_res_and_ground_truth(img_id, test_epoch, res, ground_truth, score_filter, ca_percent))

    df = pd.DataFrame(label_pairs, columns=['truth_label', 'predict_label'])

    for i in range(7):
        print(i, df[df.truth_label == i].shape[0])

    # --- recall: label 7 is the "missing" placeholder, so "< 7" means a real class ---
    truth_all = df[df.truth_label < 7]
    print("Overall Recall:%.6f" % ratio(truth_all[truth_all.predict_label < 7].shape[0], truth_all.shape[0]))

    truth_central = df[df.truth_label <= 3]
    print("Central Recall:%.6f" % ratio(truth_central[truth_central.predict_label <= 3].shape[0],
                                        truth_central.shape[0]))

    truth_side = df[(df.truth_label > 3) & (df.truth_label < 7)]
    side_hits = truth_side[(truth_side.predict_label > 3) & (truth_side.predict_label < 7)]
    print("Side Recall:%.6f" % ratio(side_hits.shape[0], truth_side.shape[0]))

    # --- multi-class accuracy on pairs where both labels lie in the group ---
    print('7 Classes Acc:%.6f' % accuracy(df[(df.truth_label < 7) & (df.predict_label < 7)]))
    print('Central 4 Classes Acc:%.6f' % accuracy(df[(df.truth_label <= 3) & (df.predict_label <= 3)]))
    print('Side 3 Classes Acc:%.6f' % accuracy(
        df[((df.truth_label > 3) & (df.truth_label < 7)) & ((df.predict_label > 3) & (df.predict_label < 7))]))

    # --- the same accuracies after collapsing the labels into coarser groups ---
    cat_72_map = {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2, 6: 3, 7: 4}
    pairs_72 = [(cat_72_map[p[0]], cat_72_map[p[1]]) for p in label_pairs]
    df_grouped = pd.DataFrame(pairs_72, columns=['truth_label', 'predict_label'])

    print('Overall 4 Classes Acc:%.6f' % accuracy(
        df_grouped[(df_grouped.truth_label <= 3) & (df_grouped.predict_label <= 3)]))
    print('Central 2 Classes Acc:%.6f' % accuracy(
        df_grouped[(df_grouped.truth_label <= 1) & (df_grouped.predict_label <= 1)]))
    print('Side 2 Classes Acc:%.6f' % accuracy(
        df_grouped[((df_grouped.truth_label > 1) & (df_grouped.truth_label < 4))
                   & ((df_grouped.predict_label > 1) & (df_grouped.predict_label < 4))]))
    
# Print the metric summary for the epoch-29 results using a score threshold of 0.7.
cal_results('int-test-reference', 29, 0.7)

In [None]:
def confusion_matrix(test_dataset, test_epoch, score_filter=0.5, ca_percent=0.5, output_dir='t2ax_output'):
    """Print confusion tables and count summaries for one saved detection run.

    Images whose name (looked up in ``id_img_map``) contains ``'NUH257'`` are
    skipped. For the remaining images the detections are paired with the ground
    truth, then four tables are printed (columns = truth, rows = prediction):
    classes 0..3, the groups [0, 1] / [2, 3], classes 4..6, and the groups
    [4, 5] / [6]. Finally truth / predicted / agreeing counts are printed for
    all real classes, for each class 0..6, and for each group.
    """
    prefix = test_dataset[:3]
    result_path = '/path/to/%s/%s-%s.pth' % (output_dir, '%s-test-reference' % prefix, test_epoch)
    detections = torch.load(result_path)
    sequence = 'T2*Ax' if prefix != 'int' else 'T2_Ax'

    label_pairs = []
    for img_id, res in detections.items():
        if 'NUH257' in id_img_map[img_id]:
            continue
        annotations = load_ground_truth_test(img_id, test_dataset, sequence)
        label_pairs += align_res_and_ground_truth(img_id, test_epoch, res, annotations, score_filter, ca_percent)

    df = pd.DataFrame(label_pairs, columns=['truth_label', 'predict_label'])
    truth, pred = df.truth_label, df.predict_label

    def count(mask):
        # Number of label pairs selected by a boolean mask.
        return df[mask].shape[0]

    def exact_table(classes):
        # Confusion counts between individual class ids.
        table = {}
        for t in classes:
            column = {}
            for p in classes:
                column['p%s' % p] = count((truth == t) & (pred == p))
            table['t%s' % t] = column
        return pd.DataFrame(table)

    def grouped_table(groups):
        # Confusion counts between groups of class ids.
        table = {}
        for t in groups:
            column = {}
            for p in groups:
                column['p%s' % p] = count((truth.isin(t)) & (pred.isin(p)))
            table['t%s' % t] = column
        return pd.DataFrame(table)

    # central classes: per class, then grouped
    print(exact_table([0, 1, 2, 3]))
    print(grouped_table([[0, 1], [2, 3]]))

    # side classes: per class, then grouped
    print(exact_table([4, 5, 6]))
    print(grouped_table([[4, 5], [6]]))

    # summary counts: truth, predicted, both
    print('all', count(truth < 7), count(pred < 7), count((truth < 7) & (pred < 7)))
    for i in range(7):
        print(i, count(truth == i), count(pred == i), count((truth == i) & (pred == i)))
    for g in [[0, 1], [2, 3], [4, 5], [6]]:
        print(g, count(truth.isin(g)), count(pred.isin(g)), count((truth.isin(g)) & (pred.isin(g))))
    
# Print the confusion tables for the epoch-29 results using a score threshold of 0.7.
confusion_matrix('int-test-reference', 29, 0.7)