In [1]:
import pandas as pd
import pickle
import numpy as np
from collections import defaultdict
import torch
import re
from sentence_transformers import SentenceTransformer,util
from functools import lru_cache  # For caching
import os
import ast
import subprocess

# Enable pandas copy-on-write so columns can be added to filtered frames below
pd.options.mode.copy_on_write = True

# The SNOMED CT release files are required, see https://www.snomed.org/get-snomed
# Start from the concept table: the description table still contains descriptions that are
# flagged active although their concept has been inactivated.
snomed_concepts = pd.read_csv(
    "../snomed/SnomedCT_InternationalRF2_PRODUCTION_20230331T120000Z/Snapshot/Terminology/sct2_Concept_Snapshot_INT_20230331.txt",
    delimiter="\t",
)
concept_is_active = snomed_concepts['active'] == 1
active_snomed_concepts = list(snomed_concepts.loc[concept_is_active, 'id'].values)

snomed_descriptions = pd.read_csv(
    "../snomed/SnomedCT_InternationalRF2_PRODUCTION_20230331T120000Z/Snapshot/Terminology/sct2_Description_Snapshot-en_INT_20230331.txt",
    delimiter="\t",
)
# Keep a description only if it is active itself AND belongs to an active concept
description_is_active = snomed_descriptions['active'] == 1
belongs_to_active_concept = snomed_descriptions['conceptId'].isin(active_snomed_concepts)
active_snomed_descriptions = snomed_descriptions[description_is_active & belongs_to_active_concept]

# typeId 900000000000003001 is SNOMED CT's "Fully specified name" (FSN) description type.
# The FSN carries the semantic tag in trailing parentheses, e.g. "... (disorder)".
# (The variable keeps its historical name even though it does not denote the preferred term.)
prefered_id = 900000000000003001

# Restrict the active descriptions to the FSN rows
has_fsn_type = active_snomed_descriptions['typeId'] == prefered_id
df_fsn = active_snomed_descriptions[has_fsn_type]

def find_in_parentheses(text):
    """Return the content of the last non-empty ``(...)`` group in ``text``.

    Returns ``None`` when ``text`` contains no parenthesised group with at
    least one character inside.
    """
    last_group = None
    # Walk over every "(...)" occurrence and remember the most recent one
    for found in re.finditer(r'\(([^)]+)\)', text):
        last_group = found.group(1)
    return last_group

# The semantic tag is the last parenthesised group of the fully specified name
df_fsn['semantic_tag'] = df_fsn['term'].apply(find_in_parentheses)

# Lookup tables: conceptId -> fully specified name, and conceptId -> semantic tag
fsn_by_concept = df_fsn.set_index('conceptId')
sctid_to_term = fsn_by_concept['term'].to_dict()
sctid_to_tag = fsn_by_concept['semantic_tag'].to_dict()

# Descriptions of the form "ABC - long form" introduce an abbreviation (2-4 capital letters)
mask = active_snomed_descriptions['term'].str.match(r'[A-Z]{2,4}\s-\s.*', na=False)
abbreviations = active_snomed_descriptions.loc[mask].copy()

# Keep only the abbreviation itself, i.e. the text before the first hyphen
abbreviations['term'] = [term.split('-')[0].strip() for term in abbreviations['term']]

# Keep only the first occurrence of each abbreviation
abbreviations_unique = abbreviations.groupby('term', as_index=False).nth(0)

# Drop abbreviations that already exist as a SNOMED term on their own
unique_terms = active_snomed_descriptions['term'].values
already_a_term = abbreviations_unique['term'].isin(unique_terms)
abbreviation_fresh = abbreviations_unique[~already_a_term]

# Append the new abbreviation rows to the original descriptions
snomed_active_withabbreviation = pd.concat(
    [active_snomed_descriptions, abbreviation_fresh],
    ignore_index=True,
    sort=False,
)

# Export only the term column; this file is the input of the external normalization step
terms_only = snomed_active_withabbreviation['term']
terms_only.to_csv("./snomed_active_withabbreviation_onlyterms.csv", index=False, sep=',')

# Note: the next step uses an external command line tool (LVG "norm") for normalization
# See: https://lhncbc.nlm.nih.gov/LSG/Projects/lvg/current/docs/userDoc/tools/norm.html
# Command: norm -i:snomed_active_withabbreviation_onlyterms.csv -o:snomed_active_withabbreviation_onlyterms_normalize.csv


In [2]:
# Load the output of the external normalization: one "label|normalized label" pair per line.
# Row 0 holds the term/header entry and is dropped; exact duplicate pairs are removed.
snomed_terms_normalized = (
    pd.read_csv(
        './chrome_matching/snomed_active_withabbreviation_onlyterms_normalize.csv',
        header=None,
        names=['label', 'label_normalized'],
        sep='|',
    )
    .drop(index=[0])
    .drop_duplicates(subset=['label', 'label_normalized'])
)

# Index the descriptions by their term
snomed_active_withabbreviation = snomed_active_withabbreviation.set_index('term')

# For every term, collect the set of conceptIds it is attached to
snomed_active_withabbreviation_uniqueset = (
    snomed_active_withabbreviation.groupby('term')['conceptId'].apply(set)
)

# Same data as a DataFrame (single column 'conceptId', indexed by term)
snomed_active_withabbreviation_uniqueset_df = snomed_active_withabbreviation_uniqueset.to_frame()

# Number of distinct conceptIds per term
snomed_active_withabbreviation_uniqueset_df['num_ids'] = snomed_active_withabbreviation_uniqueset.apply(len)

# Most terms map to exactly one SCTID. When a term maps to two or more SCTIDs, one is picked by
# semantic tag as described in the PNAS paper by Kurvers et al.: earlier entries of the list below
# win. If no candidate carries one of these 7 tags, the SCTID is chosen at random.
ordered_tags = ['disorder', 'finding', 'morphologic abnormality', 'body structure', 'person', 'organism', 'specimen']
ordered_tags_dict = {tag: priority for priority, tag in enumerate(ordered_tags)}

# Function to pick a single SCTID from a set of SCTIDs based on the preferred semantic tag
def pick_single_sctid_from_set(sct_set):
    """Choose one SCTID out of ``sct_set`` by semantic-tag priority.

    Each candidate is ranked by the position of its semantic tag in
    ``ordered_tags_dict`` (candidates without a known or listed tag rank as
    infinity). The best-ranked candidate is returned; if no candidate has a
    listed tag, one is drawn with ``np.random.choice``.
    """
    candidates = list(sct_set)
    ranks = [
        ordered_tags_dict.get(sctid_to_tag.get(candidate, '-'), np.inf)
        for candidate in candidates
    ]
    # No candidate carries a prioritised tag -> fall back to a random pick
    if all(rank == np.inf for rank in ranks):
        return np.random.choice(candidates)
    return candidates[np.argmin(ranks)]

# Resolve every term's set of candidate conceptIds to one SCTID and store it in 'single_sct_id'
candidate_id_sets = snomed_active_withabbreviation_uniqueset_df["conceptId"]
snomed_active_withabbreviation_uniqueset_df['single_sct_id'] = candidate_id_sets.map(pick_single_sctid_from_set)

# Add the 'conceptId' back to the DataFrame by mapping the labels to their corresponding 'conceptId'
def label_to_concept(label):
    """Return the set of conceptIds recorded for ``label``.

    Looks ``label`` up in ``snomed_active_withabbreviation_uniqueset``. A label
    that is not present is reported on stdout and yields an empty set. Only the
    lookup failure (``KeyError``) is handled; any other exception propagates
    instead of being silently turned into an empty result.
    """
    try:
        return snomed_active_withabbreviation_uniqueset.loc[label]
    except KeyError:
        print('did not find ', label)
        return set()

# Attach to every label the set of conceptIds it is recorded under
snomed_terms_normalized['conceptId'] = snomed_terms_normalized['label'].map(label_to_concept)

# label -> set of its normalized forms
label_to_normalized_df = snomed_terms_normalized.groupby('label')['label_normalized'].apply(set)
# normalized form -> union of the conceptId sets of all labels sharing that normalized form
normalized_to_sctid_df = snomed_terms_normalized.groupby('label_normalized')['conceptId'].apply(lambda x: set.union(*x))

# Convert the grouped Series to dictionaries.
# `to_dict` must be called (the bare attribute is just the bound method), and
# `normalized_to_sctid` has to exist because `crome_matching` uses it as its lookup table.
label_to_normalized = label_to_normalized_df.to_dict()
normalized_to_sctid = normalized_to_sctid_df.to_dict()


In [3]:
# Read the case table and the solve table
cases_df = pd.read_csv('data/case_data.csv')
solves_df = pd.read_csv('data/solve_data.csv')

# The list-valued columns are stored as text in the CSV files; parse them back with literal_eval
cases_df['diagnosis_names'] = cases_df['diagnosis_names'].map(ast.literal_eval)
solves_df['final_dxs'] = solves_df['final_dxs'].map(ast.literal_eval)

# Introductory phrases of LLM answers: when a response starts with one of these,
# its whole first line is discarded by _clean_response
RESPONSE_PRETEXT_BLOCKLIST = [
    "Sure, ",
    "Here is the ",
    "Here are ",
    "### Response:",
    "The probable",
    "The differential",
    "The most probable",
]

# Function to create a defaultdict of dicts
def create_dict_defaultdict():
    """Return a new, empty ``defaultdict`` whose missing keys default to an empty ``dict``."""
    nested = defaultdict(dict)
    return nested

# Function to clean the raw responses from LLMs
def _clean_response(response):
    # Strip leading/trailing whitespace
    response = response.strip()
    
    # Remove pretext blocks from the beginning of the response
    for blocklist_item in RESPONSE_PRETEXT_BLOCKLIST:
        if response.startswith(blocklist_item):
            response = "\n".join(response.split("\n")[1:])
            break

    # Process each diagnosis in the response
    responses = [
        re.sub("^(?:-|\*|•|\d+\.)", "", dx.strip()).strip()
        for dx in response.split("\n")
    ]
    
    # Remove various unwanted patterns from the responses
    responses = [re.sub("(\(SNOMED CT:.+)\)", "", dx.strip()).strip() for dx in responses]
    responses = [re.sub(r'\bSNOMED CT:\s*[\d|#]*\s*[|\-()]*\s*', "", dx.strip()).strip() for dx in responses]
    responses = [re.sub("(\(\d+)\)", "", dx.strip()).strip() for dx in responses]
    responses = [re.sub("(\([A-Z]\d+.+)\)", "", dx.strip()).strip() for dx in responses]
    responses = [re.sub("(- \d\d\d+)", "", dx.strip()).strip() for dx in responses]
    
    # Filter out empty responses
    responses = [res for res in responses if res != '']
    
    # Further clean responses by splitting at certain characters if the length exceeds 6 words
    responses = [res.split(':')[0] if len(res.split()) > 6 else res for res in responses]
    responses = [res.split(' - ')[0] if len(res.split()) > 6 else res for res in responses]
    
    # Remove vertical bars and strip leading/trailing whitespace
    responses = [dx.replace('|', '').strip() for dx in responses]
    
    # Return the first 5 cleaned responses
    return responses[:5]

# Raw LLM answers, keyed by model name
llm_data = {}

# Model names used as keys inside the data, and the provider tags used in the pickle file names
mkeys = ['OpenAI_gpt-4-1106-preview', 'GoogleAI_gemini-1.0-pro', 'AnthropicAI_claude-3-opus-20240229','MetaAI_llama-2-70b-f','MistralAI_mistral-large-latest']
saved_mkeys = ['OPENAI', 'GOOGLEAI', 'ANTHROPICAI', 'METAAI', 'MISTRALAI']

# Load LLM responses from pickle files.
# NOTE: unpickling can execute arbitrary code -- only load pickle files from a trusted source.
# The file is opened in a `with` block so the handle is closed right after loading.
for other_key, key in zip(mkeys, saved_mkeys):
    with open(f"./data/_data_machine_solves_varied_prompts_2024_03_29_{key}.pkl", "rb") as pkl_file:
        llm_data[other_key] = pickle.load(pkl_file)

# Case ids are taken from the first model's data.
# The experimental settings (prompts) are read from case 1883 of the first model;
# presumably every case and model shares the same settings -- TODO confirm.
all_case_ids = list(llm_data[mkeys[0]].keys())
exp_settings = list(llm_data[mkeys[0]][1883][mkeys[0]].keys())

# One record per (model, case, prompt, ranked diagnosis)
machine_data = []

for mkey in mkeys:
    for cid in all_case_ids:
        for exp in exp_settings:
            # Only the first stored response of each setting is used
            dialist = _clean_response(llm_data[mkey][cid][mkey][exp][0])
            if len(dialist) > 0:
                for i, dia in enumerate(dialist):
                    machine_data.append({
                        'llm_model': mkey,
                        'pc_id': cid,
                        'prompt': exp,
                        'diagnosis': dia,
                        'rank': i + 1
                    })
            else:
                # Keep a placeholder row (empty diagnosis, rank -1) when nothing usable was returned
                machine_data.append({
                    'llm_model': mkey,
                    'pc_id': cid,
                    'prompt': exp,
                    'diagnosis': '',
                    'rank': -1
                })

# Convert the list of machine data to a DataFrame
machine_data_df = pd.DataFrame(machine_data)


In [4]:
# Collect every diagnosis string that needs to be normalized
term_list = []

# Correct diagnoses of the cases
for diagnoses in cases_df['diagnosis_names'].values:
    term_list.extend(diagnoses)

# Final diagnoses of the solves
for diagnoses in solves_df['final_dxs'].values:
    term_list.extend(diagnoses)

# Diagnoses produced by the LLMs (one string per row)
term_list.extend(machine_data_df['diagnosis'].tolist())

# De-duplicate and sort, so the written file is identical from run to run
# (plain set order changes between interpreter sessions)
term_list = sorted(set(term_list), key=str)

# Write one term per line; the encoding is fixed to UTF-8 so the file does not depend on the
# platform's default encoding
with open('data/terms_to_normalize.csv', 'w', encoding='utf-8') as f:
    for line in term_list:
        f.write(f"{line}\n")

# Normalize terms via command line using the 'norm' tool (much quicker than individual normalization)
# Command to normalize terms:
# cmd = "norm -i:./data/terms_to_normalize.csv -o:./data/terms_normalized.csv"

In [5]:
# Read the "label|normalized label" pairs written by the external normalization step
normalized_pairs = pd.read_csv(
    'data/terms_normalized.csv',
    header=None,
    names=['label', 'label_normalized'],
    sep='|',
)

# A label can have several normalized forms: de-duplicate the pairs, then collect the forms per label
solves_terms_normalized = normalized_pairs.drop_duplicates(subset=['label', 'label_normalized'])
solves_terms_normalized = solves_terms_normalized.groupby('label')['label_normalized'].apply(set)

# label -> set of normalized forms, materialised as a defaultdict(set)
term_to_normalized = solves_terms_normalized.to_dict(into=defaultdict(set))

# Function to normalize diagnosis terms using an external normalization tool
def normalize_diagnosis(raw, term_to_normalized=None):
    """Return the normalized forms of ``raw``, using ``term_to_normalized`` as a cache.

    Parameters
    ----------
    raw : str
        The diagnosis text to normalize.
    term_to_normalized : dict, optional
        Mapping from raw text to its normalized forms. On a cache miss the
        result is stored in this mapping. When omitted, a fresh dict is used
        for the call (a shared mutable default is deliberately avoided).

    Returns
    -------
    set or list of str
        The cached value if ``raw`` is already known; otherwise a list with the
        last ``|``-separated field of every output line of the ``norm`` tool.
    """
    import shlex  # local import: only needed on a cache miss

    if term_to_normalized is None:
        term_to_normalized = {}

    existing = term_to_normalized.get(raw, None)
    if existing is not None:
        return existing

    # For a large number of raw labels, it is much faster to apply normalization to a file.
    # `raw` is free text (e.g. LLM output), so it is shell-quoted before being placed in the
    # command line; printf writes it verbatim followed by a newline.
    cmd = "printf '%s\\n' " + shlex.quote(str(raw)) + " | norm"
    ps = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, executable="/bin/bash", env=dict(PATH="~/hacid/lvg2024/bin/"))
    output = ps.communicate()[0].decode('UTF-8')
    # Each output line is "...|normalized"; the trailing empty element after the last newline is skipped
    normalized = [x.split("|")[-1] for x in output.split('\n')[:-1]]
    term_to_normalized[raw] = normalized
    return normalized

# Use memoization to cache results of the crome_matching function
@lru_cache(maxsize=1000000)  # Set a cache size limit to avoid memory overuse
def crome_matching(text, normalized_to_sctid=normalized_to_sctid):
    """Match ``text`` to a SCTID via its normalized forms.

    ``text`` is normalized with ``normalize_diagnosis`` and every normalized
    form is looked up in ``normalized_to_sctid``. All SCTIDs found are pooled:

    * no candidate   -> ``None``
    * one candidate  -> that SCTID
    * several        -> the candidate whose semantic tag comes first in
      ``ordered_tags_dict``; if none has a listed tag, a random candidate.

    Candidates without an entry in ``sctid_to_tag`` are treated as untagged
    ('-'), matching ``pick_single_sctid_from_set``, instead of raising KeyError.

    Note: results are memoized by ``lru_cache``, so ``text`` must be hashable and
    ``normalized_to_sctid`` should be left at its default (a dict passed
    explicitly is not hashable).
    """
    normalized_items = normalize_diagnosis(text, term_to_normalized=term_to_normalized)
    results = [normalized_to_sctid.get(item, None) for item in normalized_items]

    # Flatten results (a lookup hit may be a collection of ids) and remove None
    flat_results = [
        element
        for item in results
        for element in (item if isinstance(item, (set, list)) else [item])
        if element is not None
    ]

    # Ensure uniqueness using a set
    unique_results = list(set(flat_results))

    if not unique_results:
        return None
    if len(unique_results) == 1:
        return unique_results[0]

    # Choose the SCTID based on the ordered tags, prioritizing certain tags
    tags = [sctid_to_tag.get(sctid, '-') for sctid in unique_results]
    tags_num = [ordered_tags_dict.get(tag, np.inf) for tag in tags]

    if any(tag_num != np.inf for tag_num in tags_num):
        return unique_results[np.argmin(tags_num)]
    return np.random.choice(unique_results)

# Function to return all indices where the maximum value occurs
def all_argmax(b):
    """Return the indices, into the flattened array, of every element equal to ``b.max()``."""
    is_maximum = b == b.max()
    return np.nonzero(np.ravel(is_maximum))[0]

### Matching
# TOKENIZERS_PARALLELISM is set to "false" before the tokenizer is first used
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# load sentence transformer model; use the GPU when available, otherwise fall back to the CPU
# so the notebook also runs on machines without CUDA
sbert_device = 'cuda' if torch.cuda.is_available() else 'cpu'
sbiobertmodel = SentenceTransformer('pritamdeka/S-PubMedBert-MS-MARCO', device=sbert_device)

# encode all unique snomed terms (including all synonyms) and abbreviation and create embeddings
# this takes pretty long
sct_embeddings = sbiobertmodel.encode(snomed_active_withabbreviation_uniqueset_df.index.values)

# Use memoization to cache results of the str_to_sctid function
@lru_cache(maxsize=1000000)  # Cache results to improve performance
def str_to_sctid(text):
    """Map one diagnosis string to a SCTID.

    Returns ``None`` for inputs that carry no diagnosis: the empty string, '-',
    and anything that is not a string (e.g. the NaN that ``DataFrame.explode``
    produces for an empty list, or ``None``).

    Otherwise the normalization-based ``crome_matching`` is tried first. If it
    finds nothing, ``text`` is embedded with the sentence-transformer model and
    the SNOMED term(s) with the highest cosine similarity are taken; when
    several terms tie, one of their SCTIDs is drawn at random.
    """
    if not isinstance(text, str) or text == '' or text == '-':
        return None
    result = crome_matching(text)  # Crome method -> see Kurvers et.al PNAS
    if result is None:
        # Use SBERT model to find the closest SCTID if Crome matching fails
        diagnosis_emb = sbiobertmodel.encode(text)
        cos_sims = util.cos_sim(diagnosis_emb, sct_embeddings).numpy()
        # Rows of the term table whose embedding attains the maximum similarity
        results = snomed_active_withabbreviation_uniqueset_df.iloc[all_argmax(cos_sims)]['single_sct_id'].values
        if len(results) == 1:
            result = results[0]
        else:
            result = np.random.choice(results)
    return result

# Convert a list of strings to a set of SCTIDs
def strs_to_sctid_set(text_iter):
    """Apply ``str_to_sctid`` to every string in ``text_iter`` and return the distinct results as a set."""
    return {str_to_sctid(text) for text in text_iter}

# Match the correct diagnoses of every case to SNOMED ids (one set of SCTIDs per case)
case_sctid_sets = cases_df['diagnosis_names'].apply(strs_to_sctid_set)
cases_df['sctids'] = case_sctid_sets

# Save the matched cases, then index the frame by case id for the correctness lookups below
cases_df.to_csv('data/case_data_matched.csv', index=False)
cases_df = cases_df.set_index('id')


def is_correct(row):
    """Return whether ``row['sctid']`` is among the SCTIDs of the case ``row['pc_id']`` in ``cases_df``."""
    case_row = cases_df.loc[row['pc_id']]
    return row['sctid'] in case_row.sctids

# Match all solve diagnoses to SNOMED ids: one row per diagnosis of each solve
solves_df_expanded = (
    solves_df
    .explode('final_dxs')
    .rename(columns={'final_dxs': 'diagnosis'})
)
# Position of the diagnosis within its solve (rows of one solve share the original index)
solves_df_expanded['rank'] = solves_df_expanded.groupby(level=0).cumcount() + 1
# Nullable integer dtype so rows without a match can hold a missing value
solves_df_expanded['sctid'] = solves_df_expanded['diagnosis'].apply(str_to_sctid).astype('Int64')
solves_df_expanded['is_correct'] = solves_df_expanded.apply(is_correct, axis=1)
solves_df_expanded.to_csv('data/solves_data_matched.csv', index=False)

# Match all LLM diagnoses to SNOMED ids
machine_data_df['sctid'] = machine_data_df['diagnosis'].apply(str_to_sctid).astype('Int64')
# Keep only rows whose case id exists in cases_df before checking correctness
has_known_case = machine_data_df['pc_id'].isin(cases_df.index)
machine_data_df = machine_data_df[has_known_case]
machine_data_df['is_correct'] = machine_data_df.apply(is_correct, axis=1)
machine_data_df.to_csv('data_to_share/llm_data_matched_validcases.csv', index=False)