In [10]:
import os
import csv
from collections import defaultdict
from oaklib import get_adapter

# Folder holding the per-run TSV result files that get aggregated below
input_directory = "../outputdir_all_2024_07_04/gpt_o1_preview_disease_results"

# OAK adapter over the MONDO ontology (SQLite-backed selector); every label and
# ancestor lookup in this cell goes through it
adapter = get_adapter("sqlite:obo:mondo")

# disease name -> set of identifiers grounded for that name across all files
disease_mapping: defaultdict = defaultdict(set)

def get_mondo_label(mondo_id):
    """
    Look up the human-readable label of *mondo_id* via the module-level adapter.

    Never raises: an empty/missing label becomes the placeholder
    "Label not found", and any lookup exception is converted into an
    "Error retrieving label: ..." string that is returned in place of a label.
    """
    try:
        found_label = adapter.label(mondo_id)
        return found_label or "Label not found"
    except Exception as exc:
        return f"Error retrieving label: {str(exc)}"

def get_ancestors(mondo_id):
    """
    Return the ancestors the module-level adapter reports for *mondo_id*, as a set.

    A failing lookup is reported on stdout and yields an empty set so that
    callers can carry on with the remaining identifiers.
    """
    try:
        found = adapter.ancestors(mondo_id)
        return set(found)
    except Exception as err:
        print(f"Error retrieving ancestors for {mondo_id}: {str(err)}")
    return set()

def find_most_general_term(mondo_ids, ancestor_curie_prefix="MONDO", ancestor_lookup=None):
    """
    Find the most general term among a set of MONDO IDs.

    A term qualifies when it is itself one of ``mondo_ids`` and is an ancestor
    of (or identical to) every ID in ``mondo_ids``. Only ancestors whose CURIE
    starts with ``ancestor_curie_prefix`` are considered.

    Args:
        mondo_ids: Iterable of CURIEs grounded for a single disease name.
        ancestor_curie_prefix: Only ancestors with this prefix are considered.
        ancestor_lookup: Optional callable mapping a CURIE to an iterable of its
            ancestor CURIEs. Defaults to ``get_ancestors`` (the MONDO adapter).

    Returns:
        The single input ID that subsumes all the others, or ``None`` when no
        input ID does, or when ``mondo_ids`` is empty.
    """
    lookup = get_ancestors if ancestor_lookup is None else ancestor_lookup
    ids = set(mondo_ids)
    if not ids:
        # set.intersection() with no arguments would raise TypeError.
        return None

    ancestor_sets = {}
    for mondo_id in ids:
        # Add the term itself explicitly so the result does not depend on
        # whether the ancestor query is reflexive.
        closure = set(lookup(mondo_id)) | {mondo_id}
        ancestor_sets[mondo_id] = {
            ancestor for ancestor in closure if ancestor.startswith(ancestor_curie_prefix)
        }

    # Ancestors shared by every input ID, restricted to the input IDs themselves.
    common_ancestors = set.intersection(*ancestor_sets.values())
    candidates = common_ancestors & ids

    # Each candidate is an ancestor of every input, so two candidates would be
    # ancestors of each other (equivalent/cyclic terms). In that case, as in the
    # empty case, there is no single most general term.
    if len(candidates) == 1:
        return next(iter(candidates))
    return None


# Collect every grounded identifier per disease name across all TSV result files.
for filename in os.listdir(input_directory):
    file_path = os.path.join(input_directory, filename)

    # Skip subdirectories and anything that is not a .tsv file
    if not (os.path.isfile(file_path) and filename.endswith('.tsv')):
        continue

    # newline='' is what the csv module expects for its input files; it lets the
    # reader handle line endings itself, so quoted fields containing line breaks
    # are parsed correctly.
    with open(file_path, 'r', newline='') as file:
        reader = csv.DictReader(file, delimiter='\t')

        for row in reader:
            disease_name = row['disease_name']
            disease_identifier = row['disease_identifier']

            # Skip empty and 'N/A' identifiers
            if disease_identifier and disease_identifier != 'N/A':
                disease_mapping[disease_name].add(disease_identifier)

# Pick one representative MONDO term (with its label) for each disease name.
generalized_mapping = {}
for disease_name, mondo_ids in disease_mapping.items():
    if len(mondo_ids) == 1:
        # A single grounding needs no generalization
        most_general_id = next(iter(mondo_ids))
    else:
        # Several groundings: look for the input ID that subsumes all the others
        most_general_id = find_most_general_term(mondo_ids)

    if most_general_id:
        generalized_mapping[disease_name] = {
            "most_general_id": most_general_id,
            "most_general_label": get_mondo_label(most_general_id),
        }
    else:
        # No common general term: report every original ID and its label.
        # Both columns use the same sorted order so the i-th label belongs to
        # the i-th ID.
        ordered_ids = sorted(mondo_ids)
        generalized_mapping[disease_name] = {
            "most_general_id": "; ".join(ordered_ids) + " (no common general term found)",
            "most_general_label": "; ".join(get_mondo_label(mondo_id) for mondo_id in ordered_ids),
        }

# Order the summary alphabetically by disease name (keys are unique, so sorting
# the keys is equivalent to sorting the items)
sorted_generalized_mapping = {
    name: generalized_mapping[name] for name in sorted(generalized_mapping)
}

# Write one tab-separated row per disease, preceded by a header row
output_tsv_path = "o1_preview_disease_grounding_summary.tsv"
with open(output_tsv_path, 'w', newline='') as tsv_file:
    writer = csv.writer(tsv_file, delimiter='\t')
    writer.writerow(["Disease Name", "Most General MONDO ID", "Most General Label"])
    for disease_name, data in sorted_generalized_mapping.items():
        row_values = [disease_name, data["most_general_id"], data["most_general_label"]]
        writer.writerow(row_values)

print(f"Generalized disease summary saved in TSV format to {output_tsv_path}")

Generalized disease summary saved in TSV format to o1_preview_disease_grounding_summary.tsv
