In [None]:
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import re
from torchvision.ops import box_iou
import torch
import os
import re
from collections import OrderedDict
import json
from sklearn.metrics import precision_score, recall_score, f1_score
from collections import defaultdict

In [None]:
# Lower-case names of the anatomical regions, one per line.
# NOTE(review): presumably the label vocabulary that the `bbx_dict` keys are
# drawn from — confirm against the data before relying on it.
anatomies = [
    "abdomen",
    "aortic arch",
    "cardiac silhouette",
    "carina",
    "cavoatrial junction",
    "descending aorta",
    "left apical zone",
    "left cardiac silhouette",
    "left cardiophrenic angle",
    "left clavicle",
    "left costophrenic angle",
    "left hemidiaphragm",
    "left hilar structures",
    "left lower lung zone",
    "left lung",
    "left mid lung zone",
    "left upper abdomen",
    "left upper lung zone",
    "mediastinum",
    "right apical zone",
    "right atrium",
    "right cardiac silhouette",
    "right cardiophrenic angle",
    "right clavicle",
    "right costophrenic angle",
    "right hemidiaphragm",
    "right hilar structures",
    "right lower lung zone",
    "right lung",
    "right mid lung zone",
    "right upper abdomen",
    "right upper lung zone",
    "spine",
    "svc",
    "trachea",
    "upper mediastinum",
]

In [None]:
def load_json(path):
    """
    Read a JSON file and return the decoded Python object.

    Args:
        path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's top-level value
        (dict, list, str, number, bool or None).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Pin the encoding: without it open() falls back to the platform locale,
    # which can mis-decode non-ASCII text on systems whose default is not UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

In [None]:
# Directory containing the evaluation dumps. This is the single place to change
# when the notebook runs in an environment that does not provide '/content'.
EVAL_DIR = '/content'


def _load_eval(run_name):
    """Load ``eval_all_batches_<run_name>_formatted.json`` from ``EVAL_DIR``."""
    return load_json(f'{EVAL_DIR}/eval_all_batches_{run_name}_formatted.json')


# One variable per evaluated run; the names mirror the run identifiers in the file names.
ep1_multipleInstructions_formatted = _load_eval('ep1_multipleInstructions')
ep1_multipleInstructions_weighted_formatted = _load_eval('ep1_multipleInstructions_weighted')
ep1_singleInstruction_formatted = _load_eval('ep1_singleInstruction')
ep1_singleInstruction_weighted_formatted = _load_eval('ep1_singleInstruction_weighted')
ep2_singleInstruction_formatted = _load_eval('ep2_singleInstruction')
ep3_multipleInstructions_formatted = _load_eval('ep3_multipleInstructions')
ep3_singleInstruction_formatted = _load_eval('ep3_singleInstruction')

In [None]:
import re

def extract_bboxes(text, bbx_dict):
    """
    Parse predicted anatomy names and bounding boxes out of a model's text output.

    The text is lower-cased and scanned for (a) every whole-word occurrence of a
    name taken from the keys of ``bbx_dict`` and (b) every ``[...]`` group.
    Names and boxes are then paired purely by order of appearance.

    A ``[...]`` group is kept only if it parses to exactly four floats
    ``x1, y1, x2, y2`` with ``x2 > x1`` and ``y2 > y1``. Groups that cannot be
    parsed (non-numeric content or a wrong number of coordinates) are reported
    with a printed message and skipped; zero/negative-size boxes are skipped
    silently.

    Args:
        text (str): Model output containing anatomy names and ``[x1, y1, x2, y2]`` groups.
        bbx_dict (dict): Mapping keyed by anatomy name. Only the keys are used;
            they are matched case-insensitively.

    Returns:
        tuple: ``(anatomy_to_bbox, mismatch_count)``.
            ``anatomy_to_bbox`` maps each key of ``bbx_dict`` (original spelling)
            to a tuple of ``[x1, y1, x2, y2]`` float lists, in order of appearance.
            ``mismatch_count`` is 0 on success. It is 1, with an empty dict, when
            the number of matched names differs from the number of valid boxes
            or when the set of matched names differs from the keys of ``bbx_dict``.
    """
    text = text.lower()

    # The text is lower-cased, so match against lower-cased keys and remember the
    # original spelling so the result can be looked up in bbx_dict again.
    key_by_lower = {key.lower(): key for key in bbx_dict}

    if key_by_lower:
        # Longest alternatives first so the regex prefers the most specific name
        # when one name is a prefix of another.
        names = sorted(key_by_lower, key=len, reverse=True)
        anatomy_regex = re.compile(r'\b(' + '|'.join(map(re.escape, names)) + r')\b')
        anatomy_matches = anatomy_regex.findall(text)
    else:
        # An empty alternation would match empty strings at every word boundary.
        anatomy_matches = []

    bbox_matches = []
    for bbox in re.findall(r"\[(.*?)\]", text):
        try:
            bbox_floats = list(map(float, bbox.split(',')))
            if len(bbox_floats) != 4:
                # Funnel wrong-length groups into the same "skip" path instead of
                # letting the coordinate checks below fail with an IndexError.
                raise ValueError(f"expected 4 coordinates, got {len(bbox_floats)}")
        except ValueError as e:
            print(f"Error parsing bbox: {bbox}. Skipping it. Error: {e}")
            continue

        x1, y1, x2, y2 = bbox_floats
        if x2 > x1 and y2 > y1:  # Ensure width and height are positive
            bbox_matches.append(bbox_floats)

    if len(anatomy_matches) != len(bbox_matches):
        print("Anatomy and bounding box count mismatch.")
        return {}, 1

    if set(anatomy_matches) != set(key_by_lower):
        print("Mismatch between predicted anatomies and ground truth anatomies.")
        return {}, 1

    # Collect boxes per anatomy in lists, then freeze them into tuples once.
    grouped = {}
    for anatomy, bbox in zip(anatomy_matches, bbox_matches):
        grouped.setdefault(key_by_lower[anatomy], []).append(bbox)

    anatomy_to_bbox = {anatomy: tuple(boxes) for anatomy, boxes in grouped.items()}
    return anatomy_to_bbox, 0


In [None]:
def calculate_mismatch_percentage(data):
    """
    Compute how often `extract_bboxes` reports a mismatch over a set of records.

    Args:
        data: Iterable of dict-like records, each providing an 'output' entry
            (the text to parse) and a 'bbx_dict' entry (passed to
            `extract_bboxes` as its second argument).

    Returns:
        tuple: ``(mismatch_percentage, total_mismatches, total_images)`` where
        ``total_images`` is the number of records, ``total_mismatches`` is the
        sum of the mismatch counts returned by `extract_bboxes`, and
        ``mismatch_percentage`` is their ratio times 100 (0 when there are no
        records).
    """
    # Keep only the mismatch count from each (parsed_boxes, mismatch_count) result.
    per_record_mismatches = [
        extract_bboxes(record['output'], record['bbx_dict'])[1]
        for record in data
    ]

    total_images = len(per_record_mismatches)
    total_mismatches = sum(per_record_mismatches)

    if total_images > 0:
        mismatch_percentage = (total_mismatches / total_images) * 100
    else:
        # No records: avoid dividing by zero.
        mismatch_percentage = 0

    return mismatch_percentage, total_mismatches, total_images

In [None]:
# Mismatch statistics for the `ep1_multipleInstructions` evaluation output.
mismatch_percentage, total_mismatches, total_images = calculate_mismatch_percentage(
    ep1_multipleInstructions_formatted
)
print(f"Mismatch Percentage: {mismatch_percentage}%")
print(f"Total mismatches: {total_mismatches} out of {total_images} images")


In [None]:
# Mismatch statistics for the `ep1_multipleInstructions_weighted` evaluation output.
mismatch_percentage, total_mismatches, total_images = calculate_mismatch_percentage(
    ep1_multipleInstructions_weighted_formatted
)
print(f"Mismatch Percentage: {mismatch_percentage}%")
print(f"Total mismatches: {total_mismatches} out of {total_images} images")

In [None]:
# Mismatch statistics for the `ep1_singleInstruction` evaluation output.
mismatch_percentage, total_mismatches, total_images = calculate_mismatch_percentage(
    ep1_singleInstruction_formatted
)
print(f"Mismatch Percentage: {mismatch_percentage}%")
print(f"Total mismatches: {total_mismatches} out of {total_images} images")

In [None]:
# Mismatch statistics for the `ep1_singleInstruction_weighted` evaluation output.
mismatch_percentage, total_mismatches, total_images = calculate_mismatch_percentage(
    ep1_singleInstruction_weighted_formatted
)
print(f"Mismatch Percentage: {mismatch_percentage}%")
print(f"Total mismatches: {total_mismatches} out of {total_images} images")

In [None]:
# Mismatch statistics for the `ep1_singleInstruction_weighted` evaluation output.
# NOTE(review): the cell directly before this one reports the very same dataset;
# this looks like a copy-paste duplicate — confirm whether another run was meant.
mismatch_percentage, total_mismatches, total_images = calculate_mismatch_percentage(
    ep1_singleInstruction_weighted_formatted
)
print(f"Mismatch Percentage: {mismatch_percentage}%")
print(f"Total mismatches: {total_mismatches} out of {total_images} images")

In [None]:
# Mismatch statistics for the `ep2_singleInstruction` evaluation output.
mismatch_percentage, total_mismatches, total_images = calculate_mismatch_percentage(
    ep2_singleInstruction_formatted
)
print(f"Mismatch Percentage: {mismatch_percentage}%")
print(f"Total mismatches: {total_mismatches} out of {total_images} images")

In [None]:
# Mismatch statistics for the `ep3_multipleInstructions` evaluation output.
mismatch_percentage, total_mismatches, total_images = calculate_mismatch_percentage(
    ep3_multipleInstructions_formatted
)
print(f"Mismatch Percentage: {mismatch_percentage}%")
print(f"Total mismatches: {total_mismatches} out of {total_images} images")

In [None]:
# Mismatch statistics for the `ep3_singleInstruction` evaluation output.
mismatch_percentage, total_mismatches, total_images = calculate_mismatch_percentage(
    ep3_singleInstruction_formatted
)
print(f"Mismatch Percentage: {mismatch_percentage}%")
print(f"Total mismatches: {total_mismatches} out of {total_images} images")