In [1]:
import pandas as pd

# We start with the o1 responses here.
# The first CSV column is used as the row index (index_col=0). Later cells read the
# 'service_answer' column (the model's answer text) and the 'metadata' column
# (the prompt file name used to look up the correct diagnosis).
o1_responses = pd.read_csv('../supplemental_data/gpt_o1_response/gpt_o1_response.csv', index_col=0)
o1_responses

Unnamed: 0,problem,service_answer,metadata
0,I am running an experiment on a clinical case ...,"I'm sorry, but I cannot generate a differentia...",PMID_34722527_individual_103_7_Hui_Wang_Compre...
1,I am running an experiment on a clinical case ...,1. VACTERL association \n2. Feingold syndrome ...,PMID_32730804_Individual_3_en-prompt.txt
2,I am running an experiment on a clinical case ...,1. Autosomal recessive hyper-IgE syndrome (DOC...,PMID_19776401_Patient_6_1_en-prompt.txt
3,I am running an experiment on a clinical case ...,1. Sclerosteosis \n2. Van Buchem disease \n3. ...,PMID_20358596_Patient_A_en-prompt.txt
4,I am running an experiment on a clinical case ...,1. Smith-Lemli-Opitz syndrome \n2. ATR-X syndr...,PMID_36586412_8_en-prompt.txt
...,...,...,...
5262,I am running an experiment on a clinical case ...,1. Wolfram syndrome \n2. Alström syndrome \n3....,PMID_9817917_Family_4_individual_13070_en-prom...
5263,I am running an experiment on a clinical case ...,1. GM1 gangliosidosis \n2. Galactosialidosis \...,PMID_1907800_TS_en-prompt.txt
5264,I am running an experiment on a clinical case ...,1. Mitochondrial neurogastrointestinal encepha...,PMID_28673863_28673863_P1_en-prompt.txt
5265,I am running an experiment on a clinical case ...,1. Down syndrome\n2. Kabuki syndrome \n3. 22q1...,PMID_31021519_individual_SATB2_112_en-prompt.txt


In [2]:
import re
from oaklib.interfaces.text_annotator_interface import TextAnnotationConfiguration
from oaklib.interfaces.text_annotator_interface import TextAnnotatorInterface
import logging
from typing import Tuple, List

# Regex that detects a "Differential Diagnosis" header line. Any run of leading
# non-letter characters (markdown markers, digits, whitespace, punctuation) is
# allowed before the phrase.
# The class is spelled [^A-Za-z] on purpose: the range A-z also spans the ASCII
# characters [ \ ] ^ _ and `, so [^A-z] would fail to match headers such as
# "_Differential Diagnosis_".
dd_re = re.compile(r"^[^A-Za-z]*Differential Diagnosis")
related_re = re.compile(r'\s*\S+\-related\s*')  # Pattern to detect "something-related"

# General terms to exclude when doing inexact matching.
# These are MONDO:0700096 human disease and its immediate subclasses
default_exclude_list = [
    "MONDO:0700096", # human disease
    "MONDO:0002022", "MONDO:0002025", "MONDO:0002051", "MONDO:0002081", "MONDO:0002118",
    "MONDO:0002254", "MONDO:0002409", "MONDO:0002657", "MONDO:0003847", "MONDO:0003900",
    "MONDO:0004335", "MONDO:0004995", "MONDO:0005039", "MONDO:0005046", "MONDO:0005066",
    "MONDO:0005071", "MONDO:0005087", "MONDO:0005137", "MONDO:0005151", "MONDO:0005550",
    "MONDO:0005570", "MONDO:0006858", "MONDO:0019040", "MONDO:0019303", "MONDO:0020683",
    "MONDO:0021147", "MONDO:0021166", "MONDO:0021669", "MONDO:0024458", "MONDO:0024623",
    "MONDO:0029000", "MONDO:0043459", "MONDO:0043543", "MONDO:0043839", "MONDO:0044970",
    "MONDO:0044991", "MONDO:0045024", "MONDO:0100086", "MONDO:0100366", "MONDO:0700003",
    "MONDO:0700007", "MONDO:0700220"
]

# Drop "Differential Diagnosis" header lines from a raw model answer
def clean_service_answer(answer: str) -> str:
    """Return ``answer`` without the lines that match the ``dd_re`` header pattern.

    The text is split on newline characters, every line matched by ``dd_re`` is
    discarded, and the remaining lines are joined back with newlines. All other
    lines are kept untouched (no stripping is done here).
    """
    kept_lines = []
    for text_line in answer.split('\n'):
        if dd_re.match(text_line):
            # Header line: leave it out of the cleaned answer
            continue
        kept_lines.append(text_line)
    return '\n'.join(kept_lines)

# Clean the diagnosis line by removing list numbering, markdown asterisks, and surrounding spaces
def clean_diagnosis_line(line: str) -> str:
    """Return the bare diagnosis text of one line of a differential diagnosis list.

    Removes, in this order:
      1. surrounding whitespace,
      2. a leading list number such as ``1.`` (optionally preceded by markdown
         asterisks, e.g. ``**5. ``),
      3. asterisks wrapped around the remaining text,
      4. any whitespace exposed by the previous steps.

    Whitespace is stripped *before* the other steps because model answers often
    carry trailing spaces (``"1. **Foo** "``); otherwise the trailing asterisks
    and an indented list number would survive.

    Args:
        line: One raw line of the model answer.

    Returns:
        The cleaned diagnosis text; an empty string for blank lines.
    """
    line = line.strip()
    # Only the list number at the very start is removed; digits inside the
    # diagnosis itself (e.g. "22q11.2 deletion syndrome") are left alone.
    line = re.sub(r'^\**\s*\d+\.\s*', '', line)
    return line.strip('*').strip()

# Separate a trailing parenthesised synonym from the main diagnosis name
def split_diagnosis_and_synonym(diagnosis: str) -> Tuple[str, str]:
    """Split ``"Name (Synonym)"`` into ``("Name", "Synonym")``.

    The text must end with a parenthesised part (optionally followed by
    whitespace). Both parts are stripped of surrounding whitespace. When the
    text has no such trailing parenthesis, the input is returned unchanged
    together with ``None`` as the synonym.
    """
    parsed = re.match(r'^(.*)\s*\((.*)\)\s*$', diagnosis)
    if parsed is None:
        # No trailing "(...)": nothing to split
        return diagnosis, None
    name_part = parsed.group(1).strip()
    synonym_part = parsed.group(2).strip()
    return name_part, synonym_part

# Drop "<something>-related" phrases from a diagnosis
def strip_related_phrase(diagnosis: str) -> str:
    """Delete every match of ``related_re`` from ``diagnosis`` and strip the result.

    ``related_re`` matches a whitespace-free token ending in ``-related`` together
    with the whitespace around it, wherever it occurs in the text.
    """
    without_related = re.sub(related_re, '', diagnosis)
    return without_related.strip()

# Perform grounding on the text to MONDO ontology and return the result
def perform_grounding(
    annotator: TextAnnotatorInterface,
    diagnosis: str,
    exact_match: bool = True,
    verbose: bool = False,
    include_list: List[str] = ["MONDO:"],
    exclude_list: List[str] = default_exclude_list
) -> List[Tuple[str, str]]:
    """
    Ground one diagnosis string with the given annotator.

    Args:
        annotator: Object whose ``annotate_text`` method produces the annotations.
        diagnosis: Text to ground.
        exact_match: Passed as ``matches_whole_text`` to the annotation
            configuration; ``True`` requests whole-text matches, ``False``
            allows inexact (partial) matches.
        verbose: When true, log a warning if nothing could be grounded.
        include_list: CURIE prefixes to keep; an annotation is kept only if its
            ``object_id`` starts with one of them.
        exclude_list: Exact CURIEs to drop (overly general terms).

    Returns:
        The distinct ``(object_id, object_label)`` pairs that pass both filters,
        in the order the annotator produced them, or the sentinel
        ``[('N/A', 'No grounding found')]`` when none remain.

    Note:
        The list defaults are shared between calls; they are only read here and
        must not be mutated.
    """
    config = TextAnnotationConfiguration(matches_whole_text=exact_match)
    annotations = annotator.annotate_text(diagnosis, configuration=config)

    allowed_prefixes = tuple(include_list)  # str.startswith accepts a tuple of prefixes
    excluded_ids = set(exclude_list)        # constant-time membership checks

    # Filter and de-duplicate. dict.fromkeys keeps the first occurrence of each pair in
    # annotator order, so the output is reproducible between runs (a plain set would
    # order the pairs by string hash, which changes from one interpreter run to the next).
    filtered_annotations = list(
        dict.fromkeys(
            (ann.object_id, ann.object_label)
            for ann in annotations
            if ann.object_id.startswith(allowed_prefixes)
            and ann.object_id not in excluded_ids
        )
    )

    if filtered_annotations:
        return filtered_annotations

    if verbose:
        match_type = "exact" if exact_match else "inexact"
        logging.warning(f"No {match_type} grounded IDs found for: {diagnosis}")
    return [('N/A', 'No grounding found')]

# Ground the diagnosis text to MONDO ontology
def ground_diagnosis_text_to_mondo(
    annotator: TextAnnotatorInterface,
    differential_diagnosis: str,
    verbose: bool = False,
    include_list: List[str] = ["MONDO:"],
    exclude_list: List[str] = default_exclude_list
) -> List[Tuple[str, List[Tuple[str, str]]]]:
    """
    Ground every diagnosis line of a differential diagnosis text.

    Each non-empty line is cleaned with ``clean_diagnosis_line`` and then grounded
    with ``perform_grounding`` using the following fallbacks, stopping at the
    first one that yields a grounding:

      1. exact match on the whole cleaned line;
      2. if the line ends in a parenthesised synonym, exact match on the main
         name and on the synonym separately (results are merged);
      3. if the line contains a "<something>-related" phrase, exact match on the
         line with that phrase removed;
      4. inexact match on the whole cleaned line.

    Lines that are empty after cleaning, or that contain the phrase
    "differential diagnosis" (case-insensitive), are skipped as headers.

    Args:
        annotator: Annotator passed through to ``perform_grounding``.
        differential_diagnosis: Multi-line text, one diagnosis per line.
        verbose: When true, log a warning for each failed grounding step.
        include_list: CURIE prefixes to keep (passed through).
        exclude_list: CURIEs to drop (passed through).

    Returns:
        One ``(cleaned_line, groundings)`` tuple per processed line, in input
        order. ``groundings`` is a list of ``(id, label)`` pairs, or
        ``[('N/A', 'No grounding found')]`` when every step failed.
    """
    # Sentinel that perform_grounding returns when nothing could be grounded
    no_grounding = [('N/A', 'No grounding found')]
    # Options shared by every perform_grounding call below
    grounding_options = dict(verbose=verbose, include_list=include_list, exclude_list=exclude_list)
    results = []

    # Split the input into lines and process each one
    for line in differential_diagnosis.splitlines():
        clean_line = clean_diagnosis_line(line)

        # Skip blank lines and header lines like "**Differential diagnosis:**".
        # The needle must be lowercase because it is compared to the lowercased line.
        if not clean_line or "differential diagnosis" in clean_line.lower():
            continue

        # Try grounding the full line first (exact match)
        grounded = perform_grounding(annotator, clean_line, exact_match=True, **grounding_options)

        # If grounding fails and there is a synonym (text in parentheses), ground both parts
        if grounded == no_grounding:
            main_diagnosis, synonym = split_diagnosis_and_synonym(clean_line)
            if synonym:
                main_grounded = perform_grounding(annotator, main_diagnosis, exact_match=True, **grounding_options)
                synonym_grounded = perform_grounding(annotator, synonym, exact_match=True, **grounding_options)
                # Merge the real hits only: when just one of the two parts grounded, the
                # other part's 'N/A' sentinel must not be reported next to a real ID.
                real_hits = [hit for hit in main_grounded + synonym_grounded if hit not in no_grounding]
                if real_hits:
                    # Remove duplicates while keeping a reproducible order
                    grounded = list(dict.fromkeys(real_hits))

        # If the diagnosis contains "-related", try stripping it and exact matching again
        if grounded == no_grounding and related_re.search(clean_line):
            stripped_line = strip_related_phrase(clean_line)
            grounded = perform_grounding(annotator, stripped_line, exact_match=True, **grounding_options)

        # If none of the exact attempts succeeded, try inexact matching
        if grounded == no_grounding:
            if verbose:
                logging.warning(f"Exact grounding failed for: {clean_line}, attempting inexact match.")
            grounded = perform_grounding(annotator, clean_line, exact_match=False, **grounding_options)

        # If still no grounding is found, log the final failure
        if grounded == no_grounding and verbose:
            logging.warning(f"Final grounding failed for: {clean_line}")

        # Append the grounded results (even if no grounding was found)
        results.append((clean_line, grounded))

    return results

In [3]:
# Get the OAK annotator for MONDO
from oaklib import get_adapter
# Set up OAK SQLite implementation for MONDO.
# This adapter is the `annotator` handed to ground_diagnosis_text_to_mondo in the cells below.
# NOTE(review): presumably the "sqlite:obo:" selector fetches and caches a prebuilt MONDO
# database on first use, so the first run may need network access -- confirm.
annotator = get_adapter("sqlite:obo:mondo")

In [None]:
##
## RUN SOME TESTS
##

# Sample answer covering each grounding path: a header line to be removed, plain exact
# matches, a name that is expected not to ground ("Unicorn syndrome"), a name with a
# parenthesised synonym, and a "-related" name wrapped in markdown asterisks.
differential_diagnosis_text = """
**Differential Diagnosis:**
1. Branchiooculofacial syndrome
2. Unicorn syndrome
3. Cystic fibrosis
4. 22q11.2 deletion syndrome (Velocardiofacial syndrome)
**5. ATP6V0A4-related distal renal tubular acidosis**
"""

# What the grounding is expected to return for the sample, one entry per diagnosis line
expected_result = [
    ('Branchiooculofacial syndrome', [('MONDO:0007235', 'branchiooculofacial syndrome')]),
    ('Unicorn syndrome', [('N/A', 'No grounding found')]),
    ('Cystic fibrosis', [('MONDO:0009061', 'cystic fibrosis')]),
    (
        '22q11.2 deletion syndrome (Velocardiofacial syndrome)',
        [
            ('MONDO:0018923', '22q11.2 deletion syndrome'),
            ('MONDO:0008564', 'DiGeorge syndrome'),
            ('MONDO:0008644', 'velocardiofacial syndrome'),
        ],
    ),
    ('ATP6V0A4-related distal renal tubular acidosis', [('MONDO:0015827', 'distal renal tubular acidosis')]),
]

# Step 1: remove the header; the cleaned text must not end up empty
cleaned_text = clean_service_answer(differential_diagnosis_text)
assert cleaned_text, "Cleaning failed: the cleaned text is empty."

# Step 2: ground the cleaned text to MONDO and show what came back
result = ground_diagnosis_text_to_mondo(annotator, cleaned_text, verbose=False)
print("Grounding Result:", result, sep="\n")

# Step 3: compare with the expectation, entry by entry
assert len(result) == len(expected_result), "Grounding result length does not match expected result length"

for (actual_name, actual_groundings), (expected_name, expected_groundings) in zip(result, expected_result):
    # The diagnosis name has to match exactly
    assert actual_name == expected_name, f"Diagnosis mismatch: {actual_name} != {expected_name}"

    # The groundings are compared as sets, so their order does not matter
    assert set(actual_groundings) == set(expected_groundings), f"Grounding mismatch for {actual_name}"

In [None]:
# Ground every model answer in the 'service_answer' column, showing a progress bar
import pandas as pd
from tqdm import tqdm

# Registers the progress_apply method on pandas objects
tqdm.pandas()


def _ground_service_answer(answer):
    """Clean one raw model answer and ground each of its diagnosis lines."""
    cleaned_answer = clean_service_answer(answer)
    return ground_diagnosis_text_to_mondo(annotator, cleaned_answer, verbose=False)


service_answers = o1_responses['service_answer']
o1_responses['grounded_diagnosis'] = service_answers.progress_apply(_ground_service_answer)

# Persist the DataFrame, including the new 'grounded_diagnosis' column, as CSV (index not written)
output_file = '../supplemental_data/gpt_o1_response/gpt_o1_response_grounded.csv'
o1_responses.to_csv(output_file, index=False)

# Show the first rows of the updated DataFrame
o1_responses.head()

In [None]:
# Count the number of items for which no grounding was found (about 1.6% of all items):
#  grep -v "I'm sorry" ../supplemental_data/gpt_o1_response/gpt_o1_response_grounded.csv | grep -o "No grounding found" | wc -l 
#  551
# Compare this to the number of grounded items:
# (.venv) ~/PythonProject/malco/notebooks short_letter $ grep -o "('MONDO:[^']*', '[^']*')" ../supplemental_data/gpt_o1_response/gpt_o1_response_grounded.csv | wc -l
# 34539
# So about 98.4% of the items are grounded: 34539 / (34539 + 551).

# I think we'll need to run these through OntoGPT.

In [None]:
# Load the DataFrame with the grounded diagnosis text (not using o1_responses from above to avoid re-running the previous cell).
# Note: after the CSV round trip the 'grounded_diagnosis' column holds the string form of each
# list; the scoring cell below parses it back with ast.literal_eval.
o1_responses = pd.read_csv('../supplemental_data/gpt_o1_response/gpt_o1_response_grounded.csv')
o1_responses

In [None]:
import csv

# Tab-separated file listing the correct diagnosis for every prompt
file_path = '../supplemental_data/correct_results.tsv'

# Build {prompt file name: (correct ID, correct disease name)}.
# Columns are read by position: 0 = disease name, 1 = disease ID, 2 = prompt file name.
with open(file_path, 'r', newline='') as tsvfile:
    correct_answers_dict = {
        row[2]: (row[1], row[0])
        for row in csv.reader(tsvfile, delimiter='\t')
    }

correct_answers_dict

In [None]:
from tqdm import tqdm
from malco.post_process.mondo_score_utils import score_grounded_result
import warnings
import csv
import ast
from oaklib import get_adapter
from pathlib import Path

# Set to True to refuse to write into an output directory that already exists
dont_nuke_existing_output = False

# Create the output directory when missing. When it already exists, either stop
# (flag set) or warn that existing files may be overwritten (flag not set).
output_dir = Path("../outputdir_all_2024_07_04/gpt_o1_disease_results")
if not output_dir.exists():
    output_dir.mkdir(parents=True)
elif dont_nuke_existing_output:
    raise FileExistsError(f"Directory {output_dir} already exists. Please remove it first.")
else:
    warnings.warn(f"Directory {output_dir} already exists. Existing files may be overwritten.")

# Function to write results to a file
def write_result_to_file(file_path, results):
    """Write the scored results of one case to a tab-separated file.

    The file starts with the header row
    ``rank, disease_name, disease_identifier, correct_ID, grounded_score, is_correct``
    followed by one line per entry of ``results``. An existing file is overwritten.

    The file is always written as UTF-8: disease names can contain non-ASCII
    characters (e.g. "Alström syndrome"), and the files are read back with
    ``pd.read_csv``, whose default encoding is UTF-8, whereas ``open`` would
    otherwise fall back to the platform's locale encoding.

    Args:
        file_path: Destination path (string or path-like).
        results: Iterable of rows; each row is a sequence of six values in
            header order.
    """
    with open(file_path, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file, delimiter='\t')
        # Write header
        writer.writerow(["rank", "disease_name", "disease_identifier", "correct_ID", "grounded_score", "is_correct"])
        # Write all result rows
        writer.writerows(results)

# Initialize Mondo adapter
mondo = get_adapter("sqlite:obo:mondo")

# Score every response: each row of 'o1_responses' yields one TSV file in output_dir,
# named after the row's metadata (the prompt file name).
for index, row in tqdm(o1_responses.iterrows(), total=len(o1_responses)):
    grounded_diagnoses_str = row['grounded_diagnosis']

    # The column was stored as text in the CSV; turn it back into a list of
    # (diagnosis name, [(grounded ID, label), ...]) tuples.
    try:
        grounded_diagnoses = ast.literal_eval(grounded_diagnoses_str)
    except (ValueError, SyntaxError) as e:
        print(f"Error parsing grounded diagnosis for index {index}: {e}")
        continue

    metadata = row['metadata']  # Prompt file name; key into correct_answers_dict
    correct_disease = correct_answers_dict.get(metadata)  # (correct ID, correct disease name) or None

    if not correct_disease:
        logging.warning(f"No correct ID found for metadata: {metadata}")
        continue  # Skip rows with no correct ID

    # Only the ID belongs in the 'correct_ID' output column, not the whole (ID, name) tuple
    correct_id = correct_disease[0]

    results = []

    # Loop through each grounded diagnosis and score them; the rank is the 1-based
    # position of the diagnosis in the model's answer.
    for rank, (disease_name, grounded_list) in enumerate(grounded_diagnoses, start=1):
        for grounded_id, _ in grounded_list:  # this is a list because there may be multiple groundings
            grounded_score = score_grounded_result(grounded_id, correct_id, mondo)
            is_correct = grounded_score > 0  # Score > 0 means either exact or subclass match

            # One output row per grounding, in the column order of the file header
            results.append([rank, disease_name, grounded_id, correct_id, grounded_score, is_correct])

    # Define the output file path
    output_file = output_dir / f"{metadata}.tsv"

    # Write results to file
    write_result_to_file(output_file, results)

print(f"Finished writing scored results to {output_dir}")

In [None]:
import os
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt

def compute_summary_statistics(input_dir, output_file, output_plot):
    """Summarise at which rank the correct diagnosis was first found, over all cases.

    Every ``*.tsv`` file in ``input_dir`` is treated as one case and must contain
    the columns ``rank`` and ``is_correct``. For each case the lowest ``rank``
    among the rows flagged correct is determined. The ``rank`` column is used
    rather than the row position because one rank can span several rows (one
    per grounding), so the row position would overstate the rank.

    A case is counted under ``n<k>`` when that rank is ``k`` with 1 <= k <= 10,
    and under ``nf`` otherwise (no correct row, or first correct rank above 10).

    Args:
        input_dir: Directory holding the per-case TSV files.
        output_file: Path of the one-row TSV summary to write
            (columns n1..n10, n10p, nf).
        output_plot: Path of the bar chart (Top-1 / Top-3 / Top-10 percentages)
            to save.

    Returns:
        None. The summary and the plot are written to disk and both paths are printed.
    """
    # Initialize the counter for each rank
    rank_counter = Counter()

    # Iterate through all files in the directory ending with .tsv
    for filename in os.listdir(input_dir):
        if not filename.endswith('.tsv'):
            continue
        df = pd.read_csv(os.path.join(input_dir, filename), sep='\t')

        # Rank of the best-ranked correct diagnosis, or None when no row is correct
        correct_ranks = df.loc[df['is_correct'] == True, 'rank']
        correct_rank = None if correct_ranks.empty else int(correct_ranks.min())

        # Increment the appropriate counter based on the rank or nf if not found
        if correct_rank is not None and 1 <= correct_rank <= 10:
            rank_counter[f'n{correct_rank}'] += 1
        else:
            rank_counter['nf'] += 1

    # Get the total number of records processed
    total_files = sum(rank_counter.values())

    # Prepare the row to be written to the output file (without the 'lang' column)
    output_row = [rank_counter.get(f'n{k}', 0) for k in range(1, 11)]
    # n10p: proportion of n10 hits.
    # NOTE(review): as computed this is the share of cases whose first correct hit is at
    # exactly rank 10 -- confirm this matches the meaning of 'n10p' expected downstream.
    output_row.append(rank_counter.get('n10', 0) / total_files if total_files else 0)
    output_row.append(rank_counter.get('nf', 0))

    # Write the results to the output file (without 'lang' column)
    with open(output_file, 'w') as f:
        f.write('n1\tn2\tn3\tn4\tn5\tn6\tn7\tn8\tn9\tn10\tn10p\tnf\n')
        f.write('\t'.join(map(str, output_row)) + '\n')

    print(f"Summary statistics written to {output_file}")

    # Cumulative top-k percentages for the plot (0 when there are no cases at all)
    hits = ['Top-1', 'Top-3', 'Top-10']
    percentages = [
        rank_counter.get('n1', 0) / total_files * 100 if total_files else 0,
        sum(rank_counter.get(f'n{i}', 0) for i in range(1, 4)) / total_files * 100 if total_files else 0,
        sum(rank_counter.get(f'n{i}', 0) for i in range(1, 11)) / total_files * 100 if total_files else 0,
    ]

    # Plotting
    plt.figure(figsize=(10, 6))
    plt.bar(hits, percentages, color=['blue', 'green', 'orange'])
    plt.xlabel('Hits')
    plt.ylabel('Percent of cases')
    plt.title('Top-k accuracy of correct diagnoses')
    plt.ylim(0, 100)  # Adjust this as needed
    plt.savefig(output_plot)
    plt.close()  # Release the figure so repeated calls do not accumulate open figures

    print(f"Plot saved to {output_plot}")

# Run the summary for the o1 results: per-case TSV files in, summary table and plot out
input_dir = "../outputdir_all_2024_07_04/gpt_o1_disease_results"
output_file = "../outputdir_all_2024_07_04/plots/topn_result_gpt_o1.tsv"
output_plot = "../outputdir_all_2024_07_04/plots/topn_result_gpt_o1_plot.png"

# The summary and the plot share one directory; create it when it is missing
plots_dir = os.path.dirname(output_file)
os.makedirs(plots_dir, exist_ok=True)

# Compute the statistics and write both output files
compute_summary_statistics(input_dir=input_dir, output_file=output_file, output_plot=output_plot)