In [5]:
import json
from collections import defaultdict

def load_json(filepath):
    """Load and return the parsed contents of a JSON file.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The deserialized JSON value (whatever ``json.load`` produces).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)

def compare_extracted_criteria(gpt_data, ground_truth_data):
    """Compare GPT-extracted criteria with the ground truth, key by key.

    Trials are paired positionally with ``zip``, so if the two inputs have
    different lengths the surplus trials of the longer one are not compared.

    Every key of every ground-truth section is evaluated. A key (or a whole
    section) absent from the GPT output is compared against ``None`` and
    therefore counts as a mismatch unless the ground-truth value is also
    ``None``.

    Args:
        gpt_data: Iterable of trial dicts, each read via its
            ``"Extracted Criteria"`` entry.
        ground_truth_data: Iterable of trial dicts with the same layout.

    Returns:
        A ``(results, mismatch_counts)`` tuple. ``results`` is a list of
        dicts with the keys ``section``, ``key``, ``ground_truth``,
        ``gpt_extracted`` and ``match``. ``mismatch_counts`` maps
        section -> key -> number of mismatches.
    """
    skipped_sections = {"text_relating_to_medical_history", "unused_text"}
    results = []
    mismatch_counts = defaultdict(lambda: defaultdict(int))  # Track mismatches by section and key

    for gpt_trial, gt_trial in zip(gpt_data, ground_truth_data):
        # ``or {}`` also covers an explicit JSON null for the criteria entry.
        gpt_criteria = gpt_trial.get("Extracted Criteria") or {}
        gt_criteria = gt_trial.get("Extracted Criteria") or {}

        # Drive the comparison from the ground truth: a section that the GPT
        # output omits entirely must still be scored, exactly like a single
        # missing key is.
        for section, gt_section in gt_criteria.items():
            if section in skipped_sections:
                continue

            gpt_section = gpt_criteria.get(section) or {}

            for key, gt_value in (gt_section or {}).items():
                gpt_value = gpt_section.get(key, None)
                match = gt_value == gpt_value

                results.append({
                    "section": section,
                    "key": key,
                    "ground_truth": gt_value,
                    "gpt_extracted": gpt_value,
                    "match": match
                })

                # Increment mismatch count for the key in the section if there's a mismatch
                if not match:
                    mismatch_counts[section][key] += 1

    return results, mismatch_counts


def calculate_accuracy(results):
    """Return the fraction of comparison entries whose ``"match"`` is truthy.

    An empty ``results`` yields ``0``.
    """
    entry_count = len(results)
    if entry_count <= 0:
        return 0

    hits = 0
    for entry in results:
        if entry["match"]:
            hits += 1
    return hits / entry_count


def main():
    """Score the GPT extraction against the ground truth and print a report.

    Reads ``extracted_criteria_30.json`` and ``ground_truth_30.json``, then
    prints the overall accuracy, each mismatched entry, and the mismatch
    count per section and key.
    """
    # Load data
    gpt_data = load_json("extracted_criteria_30.json")
    ground_truth_data = load_json("ground_truth_30.json")

    # Compare criteria
    comparison = compare_extracted_criteria(gpt_data, ground_truth_data)
    comparison_results, mismatch_counts = comparison

    print(f"Overall Accuracy: {calculate_accuracy(comparison_results):.2f}")

    # Only print the mismatch header when there is at least one mismatch.
    failed_entries = list(filter(lambda entry: not entry["match"], comparison_results))
    if failed_entries:
        print("\nMismatched Entries:")
        for entry in failed_entries:
            print(entry)

    print("\nMismatch Count by Section and Key:")
    for section in mismatch_counts:
        print(f"Section: {section}")
        per_key = mismatch_counts[section]
        for key in per_key:
            print(f"  {key}: {per_key[key]} mismatches")

# Run the report only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()

Overall Accuracy: 0.69

Mismatched Entries:
{'section': 'demographics_and_general_characteristics', 'key': 'signed_consent', 'ground_truth': 'Yes', 'gpt_extracted': None, 'match': False}
{'section': 'demographics_and_general_characteristics', 'key': 'protocol_compliance', 'ground_truth': 'Yes', 'gpt_extracted': None, 'match': False}
{'section': 'disease_characteristics', 'key': 'confirmed_locally_recurrent_breast_cancer', 'ground_truth': 'Yes', 'gpt_extracted': None, 'match': False}
{'section': 'health_and_organ_function', 'key': 'tumor_tissue_lesion_availability', 'ground_truth': 'Yes', 'gpt_extracted': None, 'match': False}
{'section': 'demographics_and_general_characteristics', 'key': 'signed_consent', 'ground_truth': 'Yes', 'gpt_extracted': None, 'match': False}
{'section': 'demographics_and_general_characteristics', 'key': 'protocol_compliance', 'ground_truth': 'Yes', 'gpt_extracted': None, 'match': False}
{'section': 'disease_characteristics', 'key': 'confirmed_locally_recurrent_