In [None]:
from langchain_openai import ChatOpenAI
from scispacy.linking import EntityLinker
from pydantic import BaseModel, Field
from langchain_core.tools import tool
from typing import Annotated, Literal, Union
from typing_extensions import TypedDict
from os.path import dirname, join
from dotenv import load_dotenv
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import re
import spacy
import requests
import time
import json
import pandas as pd
import getpass
import os

In [None]:
# Make sure an OpenAI API key is available before any LLM call is made.
# Variables are first loaded from the project's .env file.
load_dotenv("/Users/mastorga/Documents/BTE-LLM/.env")

# A missing or empty key falls back to an interactive, non-echoing prompt.
if os.getenv("OPENAI_API_KEY") in (None, ""):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Please enter OpenAI API Key: ")

In [None]:
# Setting up SciSpacy
# Two module-level spaCy pipelines are loaded by name; both are read by the NER
# tools defined further down in the notebook.
nlp = spacy.load("en_core_sci_lg")
drug_disease_nlp = spacy.load("en_ner_bc5cdr_md")
# Only `nlp` gets the "scispacy_linker" pipe (configured for the "umls" linker);
# the tools fetch it back with nlp.get_pipe("scispacy_linker").
# NOTE(review): presumably the `EntityLinker` import at the top is what registers
# this factory, and `resolve_abbreviations` likely needs an abbreviation-detector
# pipe in the pipeline to have any effect - confirm against the scispacy docs.
nlp.add_pipe("scispacy_linker", config={"resolve_abbreviations": True, "linker_name": "umls"})

In [None]:
def idCosineSimilarity(id1: str, id2: str) -> float:
    """Score the lexical overlap between the synonym sets of two identifiers.

    Each ID is expanded to its equivalent identifiers with ``nodeNormalize``, the
    names registered for those identifiers are fetched from the name-lookup
    ``/synonyms`` endpoint, and the two resulting bags of words are compared with
    a binary (set-based) cosine similarity.

    Args:
        id1: First identifier (CURIE).
        id2: Second identifier (CURIE).

    Returns:
        A float between 0.0 and 1.0. 0.0 is also returned when either ID cannot
        be normalized or when either side ends up with no usable tokens.
    """
    def getIDSynonymEntities(
        ids: list[str],
        base_url: str = "https://name-lookup.ci.transltr.io/synonyms",
        timeout: float = 30,
    ):
        """Return ``{curie: [names, ...]}``; every list is empty if the request fails."""
        try:
            response = requests.get(
                base_url,
                # Let requests do the percent-encoding instead of patching ":" by hand.
                params=[("preferred_curies", curie) for curie in ids],
                headers={"accept": "application/json"},
                # Without a timeout a stalled service would hang the whole evaluation loop.
                timeout=timeout,
            )
            response.raise_for_status()
            data = response.json()
            # `or` guards keep a null entry for one CURIE from raising and thereby
            # discarding the synonyms of every other CURIE in the batch.
            return {curie: (data.get(curie) or {}).get("names") or [] for curie in ids}
        except Exception as e:
            print(f"Error querying synonyms for {ids}: {e}")
            return {curie: [] for curie in ids}

    def getCosineSimilarity(str1: str, str2: str):
        """Binary cosine similarity between the stop-word-filtered token sets of two strings."""
        # A set gives constant-time membership tests; the list was scanned once per token.
        sw = set(stopwords.words("english"))

        # NOTE(review): tokens are compared case-sensitively, so capitalised stop
        # words and case variants of the same name are not merged - confirm intended.
        str1_set = {w for w in word_tokenize(str1) if w not in sw}
        str2_set = {w for w in word_tokenize(str2) if w not in sw}

        if not str1_set or not str2_set:
            return 0.0

        # For 0/1 vectors the dot product is the size of the intersection and each
        # squared norm is the size of the corresponding set.
        shared = len(str1_set & str2_set)
        return shared / float((len(str1_set) * len(str2_set)) ** 0.5)

    # Expand IDs via node normalization
    id1list = list(nodeNormalize(id1))
    id2list = list(nodeNormalize(id2))

    if not id1list or not id2list:
        # Report whichever ID(s) actually failed, not unconditionally the first one.
        failed = [curie for curie, expanded in ((id1, id1list), (id2, id2list)) if not expanded]
        print("ISSUE WITH ENTITY: ")
        print(", ".join(str(curie) for curie in failed))
        return 0.0

    # Fetch synonyms
    synonyms1 = getIDSynonymEntities(id1list)
    synonyms2 = getIDSynonymEntities(id2list)

    # Flatten synonyms into one string each
    ent1_text = " ".join(name for names in synonyms1.values() for name in names)
    ent2_text = " ".join(name for names in synonyms2.values() for name in names)

    return getCosineSimilarity(ent1_text, ent2_text)


In [None]:
def nodeNormalize(entity: str):
    """Expand one CURIE into the set of identifiers the node normalizer treats as equivalent.

    Posts ``entity`` to the node-normalization ``get_normalized_nodes`` endpoint
    and collects the preferred identifier plus every equivalent identifier.

    Args:
        entity: The CURIE to normalize.

    Returns:
        A ``set`` of identifier strings. The set is empty when the service has no
        entry for ``entity`` or when the request fails (the error is printed).
        A set is returned on every path so callers can rely on one type.
    """
    try:
        payload = {
            "curies": [entity],
            "conflate": True,
            "description": True,
            "drug_chemical_conflate": False
        }

        response = requests.post(
            "https://nodenormalization-sri.renci.org/1.5/get_normalized_nodes",
            headers={"Content-Type": "application/json"},
            json=payload,
            # Without a timeout a stalled service would block the caller indefinitely.
            timeout=30,
        )
        response.raise_for_status()

        entry = response.json().get(entity)  # access the entity's entry
        if not entry:
            return set()

        ids = set()

        # Main id
        main_id = (entry.get("id") or {}).get("identifier")
        if main_id:
            ids.add(main_id)

        # Equivalent IDs
        for equiv in entry.get("equivalent_identifiers") or []:
            identifier = equiv.get("identifier")
            if identifier:
                # Drop list/quote residue that may surround the identifier text.
                ids.add(identifier.strip("[]'\" "))

        if not ids:
            # An entry came back but carried no identifiers; surface the input for debugging.
            print(entity)

        return ids
    except Exception as e:
        print(f"Error normalizing {entity}: {e}")
        return set()

def checkAnswer(result, groundTruth):
    """Tell whether ``result`` matches ``groundTruth`` according to ``nodeNormalizationTest``.

    The outcome of ``nodeNormalizationTest`` is coerced to a plain ``bool``.
    """
    return bool(nodeNormalizationTest(result, groundTruth))

In [None]:
def nodeNormalizationTest(ent: str, gtruth: str) -> bool:
    """
    Check two identifiers for overlap after normalization.

    Both inputs are expanded with ``nodeNormalize`` (``ent`` first, then
    ``gtruth``); the result is True exactly when the two expansions have at
    least one identifier in common.
    """
    shared_ids = set(nodeNormalize(ent)).intersection(nodeNormalize(gtruth))
    return bool(shared_ids)

In [None]:
class BERNToolWrapper:
    """Small HTTP client for a BERN2 plain-text annotation endpoint.

    Offers the same ``invoke({"query": ...})`` call shape as the other NER tools
    in this notebook, so it can sit next to them in the evaluation harness.
    """

    def __init__(self, url: str = "http://bern2.korea.ac.kr/plain", timeout: float = 30.0):
        """
        Args:
            url: Endpoint that accepts ``{"text": ...}`` as a JSON body.
            timeout: Seconds to wait for the endpoint before giving up.
        """
        self.url = url
        self.timeout = timeout

    def invoke(self, input: dict) -> dict:
        """Annotate ``input["query"]`` and map each mention to an identifier.

        Args:
            input: Dict carrying the text under the ``"query"`` key; a missing
                key is treated as an empty string.

        Returns:
            ``{mention: identifier}``. When an annotation lists several
            identifiers the last one wins. An empty dict is returned when the
            request fails (the error is printed).
        """
        text = input.get("query", "")
        try:
            # The timeout keeps an unresponsive server from stalling the caller.
            response = requests.post(self.url, json={"text": text}, timeout=self.timeout)
            response.raise_for_status()
            annotations = response.json().get("annotations", [])
        except Exception as e:
            print(f"BERN request failed for '{text}': {e}")
            return {}

        results = {}
        for ann in annotations or []:
            mention = ann.get("mention", "")
            prob = ann.get("prob", None)

            # `or []` tolerates an explicit null "id"; the f-string tolerates
            # identifiers that are not strings.
            for curie in ann.get("id") or []:
                results[mention] = curie
                print(f"{mention} - {curie}; prob: {prob}")

        return results

In [None]:
# Smoke test: build a wrapper with the default endpoint, send it one sample
# phrase and print the {mention: id} mapping that comes back.
# Running this cell issues a live HTTP request.
BERNTool = BERNToolWrapper()

res = BERNTool.invoke({"query": "cholesterol biosynthetic process"})
print(res)

In [None]:
# Argument schema shared by both NER tools (passed as `args_schema` to `@tool`):
# a single free-text `query` field.
# Documented with comments rather than a class docstring because a docstring
# would presumably end up in the generated schema - confirm before adding one.
class NERInput(BaseModel):
    query: str = Field(description = "Query to extract biological entities from")

In [None]:
@tool("BioNERTool", args_schema=NERInput)
def modifiedBioNERTool(query: str):
    """Extract biological entities from a query and returns them along with their ID"""


### Step 1: Biomedical entity extraction
    llm = ChatOpenAI(temperature=0, model="gpt-4o")
    global nlp, drug_disease_nlp

    def extractEnts(query: str):
        entities = []
        
        docs = [
        nlp(query),
        drug_disease_nlp(query)
        ]

        # Extract entity texts from each doc
        for doc in docs:
            for ent in doc.ents:
                entities.append(ent.text.strip())

        # Extract biological process terms using LLM
        bp_prompt = f"""You are a helpful assistant that can extract biological processes from a given query. 
                    These might include concepts such as "cholesterol biosynthesis", "Aminergic neurotransmitter loading into synaptic vesicle", etc. Each entity should be a noun.

                    You must always return the full phrase/long form of each biomedical entity 
                    
                    Return results as a list. Return "" if no biological processes are in the query. DO NOT INCLUDE YOUR THOUGHTS
                    Here is your query: {query}"""

        bp_list = llm.invoke(bp_prompt).content.strip()
        entities.append(bp_list)
    
        # Deduplicate while preserving order
        entities = list(dict.fromkeys(entities))

        return entities
        
### Step 2: Linking each entity into a semantic type
    def classifyEnt(query: str, ent: str):
        supportedEntTypes = ["biologicalProcess", "general"]
        
        class entType(TypedDict):
            """The most appropriate entity type given a specific entity"""
            entType: Literal[*supportedEntTypes]

        classify_prompt = f"""
        Classify the following biomedical entity into one of: 
        {supportedEntTypes}

        Biomedical entity: {ent}

        For context, here is the query that the entity was extracted from: {query}
        """
        chosen_type = llm.with_structured_output(entType).invoke(classify_prompt)

        chosen_type = str(chosen_type["entType"])

        print(ent + " - " + chosen_type)
        
        return chosen_type


### Step 3: Entity-linking using different linkers based on entity type mapping
    def linkEnt(entList: list, query: str):
        idList = {}

        def sriNameResolver(ent: str, base_url = "https://name-lookup.ci.transltr.io/lookup", is_bp: bool = False, k = 50) -> list:    
            candidates = []
            autoComplete = True
            
            processed_ent = ent.replace(" ", "%20")

            final_url = base_url + "?string=" + processed_ent + "&autocomplete=" + str(autoComplete).lower() + "&limit=" + str(k)

            if is_bp:
                final_url = final_url + "&only_prefixes=GO&biolink_type=BiologicalProcess"
                
            response = requests.get(final_url, headers={"accept": "application/json"})
        
            candidate_list = json.loads(response.content.decode("utf-8"))
        
            for item in candidate_list:
                parsed = {
                    "label": item.get("label", ""),
                    "curie": item.get("curie", ""),
                    "score": item.get("score", "")
                }
        
                candidates.append(parsed)
        
            return candidates

        def remove_TUI(text):
            parts = text.split("TUI", 1)
            return parts[0]
        
        def selectID(doc: object):
            bioID = ""
            candidates = {}

            linker = nlp.get_pipe("scispacy_linker")   
            
            for ent in doc.ents:
                if ent._.kb_ents:  # Check if entity has linked knowledge base IDs
                    for id in ent._.kb_ents:
                        candidates[id[0]] = remove_TUI(str(linker.kb.cui_to_entity[id[0]]))
        
                select_prompt = f"""You are a smart biomedical assistant that can understand the context and the intent behind a query. 
                            Be careful when choosing IDs for entities that can refer to different concepts (for example, HIV can refer either to the virus or the disease; you MUST choose the most appropriate concept/definition based on the query). 
                            Use the context and the intent behind the query to choose the most appropriate ID. 
                            Here is the complete query: {query}
                            Select the one most appropriate ID/CUI for {ent.text} from the list below:
                            {candidates}
                            If none of the choices are appropriate, return "".
                            Otherwise, return only the ID/CUI.
                            """
    
                # LLM selects most appropriate ID from list
                selectedID = llm.invoke(select_prompt).content.strip()
    
                # Extract just the UMLS CUI using regex
                match = re.search(r"C\d{7}", selectedID)
                if match:
                    bioID = "UMLS:" + match.group(0)
                    definition = candidates[match.group(0)]
                else:
                    bioID = ""
                    definition = ""
    
                # Printing chosen ID + definition, if any
                print(ent.text + " - " + bioID + '\n' + definition)

            return bioID

        def selectIDbp(entity: str, candidates: list):
            choices = [""]

            for entry in candidates:
                curie = entry.get("curie")
                choices.append(curie)

            class selectedID(TypedDict):
                """The most appropriate CUI/ID from the given candidates"""
                selectedID: Literal[*choices]
            
            select_prompt = f"""You are a smart biomedical assistant that can understand the context and the intent behind a query. 
                        Be careful when choosing IDs for entities that can refer to different concepts (for example, HIV can refer either to the virus or the disease; you MUST choose the most appropriate concept/definition based on the query). 
                        Use the context and the intent behind the query to choose the most appropriate ID. 
                        Here is the complete query: {query}
                        Select the one most appropriate ID/CUI for {entity} from the list below:
                        {candidates}
                        If none of the choices are appropriate, return "".
                        Otherwise, return only the ID/CUI.
                        """
    
            # LLM selects most appropriate ID from list
            selectedID = llm.with_structured_output(selectedID).invoke(select_prompt)

            bioID = str(selectedID["selectedID"])

            # Printing chosen ID + definition, if any
            print(entity + " - " + bioID + '\n')

            return bioID

        # General Pipeline            
        for entity in entList:
            chosenID = ""
            entclass = classifyEnt(query, entity)
            if entclass == "biologicalProcess":
                chosenID = selectIDbp(entity, sriNameResolver(entity, is_bp=True))
            else:
                doc = nlp(entity)
                chosenID = selectID(doc)
                if chosenID == "":
                    chosenID = selectIDbp(entity, sriNameResolver(entity, is_bp=False))

            idList[entity] = chosenID
            

        return idList
    
    entityList = extractEnts(query)
    bioIDs = linkEnt(entityList, query)

    print(bioIDs)
    
    return bioIDs if bioIDs else {"message": "No entities found"}

In [None]:
@tool("BioNERTool", args_schema=NERInput)
def BioNERTool(query: str):
    """Extract biological entities from a query and returns them along with their ID"""
    # Maintainer notes live in comments: presumably the docstring above is what the
    # `@tool` decorator exposes as the tool description, so it is left untouched.
    #
    # Runs the `nlp` pipeline on `query`, gathers the scispacy-linker candidates of
    # every entity that has any, and asks the LLM to pick one CUI per entity.
    # Returns {mention: "UMLS:<CUI>" or "No appropriate IDs were found"}, or
    # {"message": "No entities found"} when no entity had linker candidates.

    def remove_TUI(text):
        # Keep only the text that precedes the first "TUI" marker.
        return text.split("TUI", 1)[0]

    # `nlp` is the module-level pipeline; it is only read here.
    linker = nlp.get_pipe("scispacy_linker")

    bioIDs = {}
    # NOTE(review): candidates accumulate across entities, so later entities are
    # also offered earlier entities' candidates - confirm this is intended.
    idList = {}

    llm = ChatOpenAI(temperature=0, model="gpt-4o")

    doc = nlp(query)

    for ent in doc.ents:
        if ent._.kb_ents:  # Check if entity has linked knowledge base IDs
            for kb_ent in ent._.kb_ents:
                idList[kb_ent[0]] = remove_TUI(str(linker.kb.cui_to_entity[kb_ent[0]]))

            select_prompt = f"""You are a smart biomedical assistant that can understand the context and the intent behind a query. 
                        Be careful when choosing IDs for entities that can refer to different concepts (for example, HIV can refer either to the virus or the disease; you MUST choose the most appropriate concept/definition based on the query). 
                        Use the context and the intent behind the query to choose the most appropriate ID. 
                        Here is the complete query: {query}
                        Select the one most appropriate ID/CUI for {ent.text} from the list below:
                        {idList}
                        If none of the choices are appropriate, return "".
                        Otherwise, return only the ID/CUI.
                        """

            # LLM selects most appropriate ID from list
            answer = llm.invoke(select_prompt).content.strip()

            # Extract just the CUI; accept it only if it was actually offered,
            # otherwise the definition lookup below would raise KeyError.
            match = re.search(r"C\d{7}", answer)
            if match and match.group(0) in idList:
                bioIDs[ent.text] = "UMLS:" + match.group(0)
                definition = idList[match.group(0)]
            else:
                bioIDs[ent.text] = "No appropriate IDs were found"
                definition = ""

            # Printing chosen ID + definition, if any
            print(ent.text + " - " + bioIDs[ent.text] + '\n' + definition)

    return bioIDs if bioIDs else {"message": "No entities found"}

In [None]:
def run_ner_test(
    df,
    name_col: str,
    id_col: str,
    tools: dict[str, object],
    check_answer_fn,
    similarity_fn
):
    """
    Generalizable NER test function.

    For every row, each tool is invoked on the entity name; among the IDs it
    returns, the one most similar to the ground truth is kept and checked.

    Args:
        df: pandas DataFrame with entities and ground-truth IDs. It is modified
            in place: for every tool the columns ``{id_col}_{tool}`` (raw result),
            ``{id_col}_{tool}_ans``, ``{id_col}_{tool}_correct`` and
            ``{id_col}_{tool}_similarity`` are created or reset. Because the
            names derive from ``id_col``, two runs sharing an ``id_col`` on the
            same frame overwrite each other. The index must be unique.
        name_col: column containing entity names (e.g. "bp_name")
        id_col: column containing ground truth IDs (e.g. "bp")
        tools: dict mapping tool_name -> tool_object (must have .invoke())
        check_answer_fn: function(result_id, groundTruth) -> bool
        similarity_fn: function(id1, id2) -> float

    Returns:
        dict with tally counts (number of correct answers) for each tool
    """
    n = len(df)
    tallies = {tool_name: 0 for tool_name in tools}

    # Add results columns to df. The first two hold arbitrary Python objects
    # (a dict, a str or None), so they are created with an explicit object dtype.
    for tool_name in tools:
        df[f"{id_col}_{tool_name}"] = pd.Series([""] * n, index=df.index, dtype=object)      # raw results dict
        df[f"{id_col}_{tool_name}_ans"] = pd.Series([""] * n, index=df.index, dtype=object)  # best chosen ID
        df[f"{id_col}_{tool_name}_correct"] = False    # correctness flag
        df[f"{id_col}_{tool_name}_similarity"] = -1.0  # cosine similarity

    # Rows are addressed by their real index label for reads and writes alike.
    # Reading by position while writing by label only worked on a 0..n-1 index.
    for position, (row_label, entity) in enumerate(df[name_col].items(), start=1):
        groundTruth = df.at[row_label, id_col]
        print(f"\n{position}. Entity: {entity}\nGround truth: {groundTruth}\n")

        # Store per-entity results
        entity_results = {}

        # `ner_tool` (not `tool`) so the imported `tool` decorator is not shadowed.
        for tool_name, ner_tool in tools.items():
            best_ans = None
            best_sim = -1.0
            check = False

            try:
                # Run tool
                results = ner_tool.invoke(input={"query": entity})
                df.at[row_label, f"{id_col}_{tool_name}"] = results

                # Pick best candidate
                for candidate_id in results.values():
                    sim = similarity_fn(candidate_id, groundTruth)
                    if sim > best_sim:
                        best_ans, best_sim = candidate_id, sim
                        check = check_answer_fn(best_ans, groundTruth)
                    if sim == 1:
                        # A perfect similarity score counts as correct outright.
                        check = True
                        break

                # Update DataFrame
                is_correct = bool(check)
                df.at[row_label, f"{id_col}_{tool_name}_ans"] = best_ans
                df.at[row_label, f"{id_col}_{tool_name}_correct"] = is_correct
                df.at[row_label, f"{id_col}_{tool_name}_similarity"] = best_sim
                if is_correct:
                    tallies[tool_name] += 1

                print(f"\n{tool_name} answer: {best_ans}; cosine similarity = {best_sim:.3f}\n\n-----")

            except Exception as e:
                # A failing tool must not abort the run; the error text is shown
                # in the per-entity summary below and the tally is left unchanged.
                best_ans = None
                best_sim = -1.0
                check = f"Error: {e}"

            # Save to entity summary
            entity_results[tool_name] = {
                "ans": best_ans,
                "sim": best_sim,
                "check": check,
                "tally": tallies[tool_name]
            }

        # === Print summary for this entity ===
        for tool_name, result in entity_results.items():
            print(f"""{tool_name} answer: {result['ans']} - {result['check']}
    cosine similarity with gt: {result['sim']:.3f}
    tally: {result['tally']} / {n}\n""")

        # Separator between entities (a stray "{" used to be printed here).
        print("\n---------------------------\n")

    return tallies


In [None]:
# Source table for the biological-process / disease / drug evaluations.
# NOTE(review): absolute, machine-specific path - consider a configurable data directory.
DMDB = pd.read_csv('/Users/mastorga/Documents/BTE-LLM/Prototype/data/DMDB_go_bp_filtered_jjoy_05_08_2025.csv')

In [None]:
# Evaluation subset of DMDB: up to 100 rows, drawn with a fixed seed so the
# benchmark rows (and therefore the tallies) are the same on every run.
# `pairID` keeps each row's original DMDB index label; the frame is then
# re-indexed 0..n-1.
drugDiseaseSet = (
    DMDB.sample(n=min(100, len(DMDB)), random_state=42)  # min() avoids an error on tables with < 100 rows
    .assign(pairID=lambda sampled: sampled.index)
    .reset_index(drop=True)
)

In [None]:
# Display the sampled drug/disease/biological-process rows as the cell output.
drugDiseaseSet

In [None]:
# Source table for the metabolite evaluation.
# NOTE(review): absolute, machine-specific path - consider a configurable data directory.
metaboliteDB = pd.read_csv('/Users/mastorga/Documents/BTE-LLM/Prototype/data/DMDB_chebi_metabolite_filtered.csv')

In [None]:
# Evaluation subset of metaboliteDB: up to 100 rows, drawn with a fixed seed so
# the benchmark rows are the same on every run. `pairID` keeps each row's
# original index label; the frame is then re-indexed 0..n-1.
metaboliteSet = (
    metaboliteDB.sample(n=min(100, len(metaboliteDB)), random_state=42)  # min() avoids an error on tables with < 100 rows
    .assign(pairID=lambda sampled: sampled.index)
    .reset_index(drop=True)
)

In [None]:
# Display the sampled metabolite rows as the cell output.
metaboliteSet

In [None]:
# Source table for the gene and protein evaluations.
# NOTE(review): absolute, machine-specific path - consider a configurable data directory.
geneDB = pd.read_csv('/Users/mastorga/Documents/BTE-LLM/Prototype/data/DMDB_mechanistic_genes_filtered.csv')

In [None]:
# Evaluation subset of geneDB: up to 100 rows, drawn with a fixed seed so the
# benchmark rows are the same on every run. `pairID` keeps each row's original
# index label; the frame is then re-indexed 0..n-1.
#
# The three derived columns strip the leading/trailing characters [ ] ' " from
# the source text; `protein_ID` additionally rewrites the "UniProt:" prefix to
# "UniProtKB:". The vectorised `.str` methods pass missing values through
# instead of raising as the per-row lambda did.
# NOTE(review): only the outer characters are stripped, so a cell holding
# several items keeps its inner quotes and commas - confirm one item per cell.
geneSet = (
    geneDB.sample(n=min(100, len(geneDB)), random_state=42)  # min() avoids an error on tables with < 100 rows
    .assign(
        pairID=lambda sampled: sampled.index,
        gene_str=lambda sampled: sampled["protein_gene_symbol"].str.strip("""['"]"""),
        protein_str=lambda sampled: sampled["protein_name"].str.strip("""['"]"""),
        protein_ID=lambda sampled: (
            sampled["protein"]
            .str.strip("""['"]""")
            .str.replace("UniProt:", "UniProtKB:", regex=False)
        ),
    )
    .reset_index(drop=True)
)

In [None]:
# Display the sampled gene/protein rows as the cell output.
geneSet

In [None]:
# Wrapper instance (default endpoint) used in the `tools` registry below.
# The same name is also bound by the earlier smoke-test cell.
BERNTool = BERNToolWrapper()

In [None]:
# NER systems under evaluation. The keys become part of the result-column names
# written by run_ner_test; every value exposes an `.invoke()` method.
tools = dict(
    BioNER=BioNERTool,
    modifiedBioNER=modifiedBioNERTool,
    BERN2=BERNTool,
)

In [None]:
# BIOLOGICAL PROCESS NER
# Runs every tool in `tools` on the `bp_name` column of drugDiseaseSet and scores
# the returned IDs against the `bp` column. run_ner_test adds `bp_<tool>*` result
# columns to drugDiseaseSet in place and returns the per-tool count of correct answers.
bp_tallies = run_ner_test(
    df=drugDiseaseSet,
    name_col="bp_name",
    id_col="bp",
    tools=tools,
    check_answer_fn=checkAnswer,
    similarity_fn=idCosineSimilarity
)

print("Final tallies for the BIOLOGICAL PROCESS entity type:", bp_tallies)

In [None]:
# DISEASE NER
# Same harness on the `disease_name` column, scored against `disease`; adds
# `disease_<tool>*` result columns to drugDiseaseSet in place.
disease_tallies = run_ner_test(
    df=drugDiseaseSet,
    name_col="disease_name",
    id_col="disease",
    tools=tools,
    check_answer_fn=checkAnswer,
    similarity_fn=idCosineSimilarity
)

print("Final tallies for the DISEASE entity type:", disease_tallies)

In [None]:
# DRUG NER
# Same harness on the `drug_name` column, scored against `Drug_MeshID`; adds
# `Drug_MeshID_<tool>*` result columns to drugDiseaseSet in place.
# NOTE(review): assumes `Drug_MeshID` values are in a form nodeNormalize accepts - confirm.
drug_tallies = run_ner_test(
    df=drugDiseaseSet,
    name_col="drug_name",
    id_col="Drug_MeshID",
    tools=tools,
    check_answer_fn=checkAnswer,
    similarity_fn=idCosineSimilarity
)

print("Final tallies for the DRUG entity type:", drug_tallies)

In [None]:
# Save responses
# Writes drugDiseaseSet, including the result columns added by the runs above,
# to an Excel workbook.
# NOTE(review): absolute, machine-specific path with a hard-coded date prefix.

drugDiseaseSet.to_excel('/Users/mastorga/Documents/BTE-LLM/Prototype/logs/NER/09.03.25_100pairs_drug_disease_bp.xlsx')

In [None]:
# METABOLITE NER
# Same harness on the `metabolite_name_str` column of metaboliteSet, scored against
# `metabolite`; adds `metabolite_<tool>*` result columns to metaboliteSet in place.
metabolite_tallies = run_ner_test(
    df=metaboliteSet,
    name_col="metabolite_name_str",
    id_col="metabolite",
    tools=tools,
    check_answer_fn=checkAnswer,
    similarity_fn=idCosineSimilarity
)

print("Final tallies for the METABOLITE entity type:", metabolite_tallies)

In [None]:
# Save responses
# Writes metaboliteSet, including the result columns added by the run above,
# to an Excel workbook.
# NOTE(review): absolute, machine-specific path with a hard-coded date prefix.

metaboliteSet.to_excel('/Users/mastorga/Documents/BTE-LLM/Prototype/logs/NER/09.03.25_100pairs_metabolite.xlsx')

In [None]:
# GENE NER
# run_ner_test names its result columns after `id_col` and resets them at the
# start of every run. The PROTEIN run later in the notebook also uses
# id_col="protein_ID" on this same frame, so sharing that id_col here would let
# it wipe the per-row gene results before they are saved. The ground truth is
# therefore mirrored into a dedicated `gene_ID` column, giving this run its own
# `gene_ID_<tool>*` result columns.
geneSet["gene_ID"] = geneSet["protein_ID"]

gene_tallies = run_ner_test(
    df=geneSet,
    name_col="gene_str",
    id_col="gene_ID",
    tools=tools,
    check_answer_fn=checkAnswer,
    similarity_fn=idCosineSimilarity
)

print("Final tallies for the GENE entity type:", gene_tallies)

In [None]:
# PROTEIN NER
# Same harness on the `protein_str` column of geneSet, scored against `protein_ID`.
# Result columns are named from id_col, i.e. `protein_ID_<tool>*`, and are
# created or reset in geneSet in place.
protein_tallies = run_ner_test(
    df=geneSet,
    name_col="protein_str",
    id_col="protein_ID",
    tools=tools,
    check_answer_fn=checkAnswer,
    similarity_fn=idCosineSimilarity
)

print("Final tallies for the PROTEIN entity type:", protein_tallies)

In [None]:
# Save responses
# Writes geneSet, including the result columns present at this point, to an
# Excel workbook.
# NOTE(review): absolute, machine-specific path with a hard-coded date prefix.

geneSet.to_excel('/Users/mastorga/Documents/BTE-LLM/Prototype/logs/NER/09.03.25_100pairs_mechanisticgenes.xlsx')