In [None]:
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, convert_to_messages
from os.path import dirname, join
from dotenv import load_dotenv
from typing import Literal
from typing_extensions import TypedDict, Annotated
from langgraph.graph import StateGraph, START, MessagesState, END
from langgraph.types import Command
from operator import add
import pandas as pd
import getpass
import os
import re
from pydantic import BaseModel, Field
from typing import Annotated, Literal, Union
from os.path import dirname, join
from dotenv import load_dotenv
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import re
import requests
import time
import json

In [None]:
# Load environment variables (OPENAI_API_KEY, if present) from the project's .env file
load_dotenv(dotenv_path="/Users/mastorga/Documents/BTE-LLM/.env")

# Fall back to an interactive prompt when the key is missing or empty
if not os.getenv("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Please enter OpenAI API Key: ")

In [None]:
from Agent import BTEx

In [None]:
# Shared chat model, used by the LLM-only baseline and by the ID-selection step.
# temperature=0 minimises sampling randomness; change the model name if needed.
llm = ChatOpenAI(model="gpt-4.1", temperature=0)

In [None]:
def idCosineSimilarity(id1: str, id2: str) -> float:
    """Bag-of-words cosine similarity between the synonym sets of two CURIEs.

    Each ID is first expanded to all of its equivalent identifiers with
    ``nodeNormalize``. The synonym names of every equivalent identifier are then
    fetched from the SRI name-lookup ``/synonyms`` endpoint and concatenated into
    one document per ID. The two documents are compared as binary (set-based)
    word vectors with English stopwords removed.

    Args:
        id1: first CURIE (e.g. an answer ID chosen by the LLM).
        id2: second CURIE (e.g. the ground-truth ID).

    Returns:
        Cosine similarity in [0.0, 1.0]. Returns 0.0 when either ID cannot be
        normalized or yields no non-stopword tokens.
    """
    def getIDSynonymEntities(ids: list[str], base_url: str = "https://name-lookup.ci.transltr.io/synonyms", timeout: float = 30):
        """Map each CURIE in ``ids`` to its list of synonym names ([] for every ID on failure)."""
        try:
            # requests URL-encodes each CURIE (":" -> "%3A") and repeats the
            # parameter once per list element: preferred_curies=a&preferred_curies=b
            response = requests.get(
                base_url,
                params={"preferred_curies": ids},
                headers={"accept": "application/json"},
                timeout=timeout,
            )
            response.raise_for_status()
            data = response.json()
            # `or {}` tolerates a null entry for an unknown CURIE
            return {curie: (data.get(curie) or {}).get("names", []) for curie in ids}
        except Exception as e:
            # Best effort: a failed lookup contributes no synonyms instead of aborting the run.
            print(f"Error querying synonyms for {ids}: {e}")
            return {curie: [] for curie in ids}

    def getCosineSimilarity(str1: str, str2: str):
        """Cosine similarity of the binary bag-of-words vectors of two strings."""
        sw = set(stopwords.words("english"))
        str1_set = {w for w in word_tokenize(str1) if w not in sw}
        str2_set = {w for w in word_tokenize(str2) if w not in sw}

        if not str1_set or not str2_set:
            return 0.0

        # With binary vectors the dot product is the size of the intersection and
        # each vector's squared norm is the size of its set.
        return len(str1_set & str2_set) / float((len(str1_set) * len(str2_set)) ** 0.5)

    # Expand IDs via node normalization (nodeNormalize returns a collection of CURIEs)
    id1list = list(nodeNormalize(id1))
    id2list = list(nodeNormalize(id2))

    if not id1list or not id2list:
        # Report whichever ID(s) could not be normalized.
        print("ISSUE WITH ENTITY: ")
        for original_id, expanded in ((id1, id1list), (id2, id2list)):
            if not expanded:
                print(original_id)
        return 0.0

    # Fetch synonyms for every equivalent identifier
    synonyms1 = getIDSynonymEntities(id1list)
    synonyms2 = getIDSynonymEntities(id2list)

    # Flatten synonyms into one document per ID
    ent1_text = " ".join(name for names in synonyms1.values() for name in names)
    ent2_text = " ".join(name for names in synonyms2.values() for name in names)

    return getCosineSimilarity(ent1_text, ent2_text)

In [None]:
def nodeNormalize(
    entity: str,
    base_url: str = "https://nodenormalization-sri.renci.org/1.5/get_normalized_nodes",
    timeout: float = 30,
) -> set:
    """Expand a CURIE to the set of equivalent identifiers known to the SRI Node Normalizer.

    Args:
        entity: CURIE to normalize (e.g. "UniProtKB:P04637").
        base_url: Node Normalizer ``get_normalized_nodes`` endpoint.
        timeout: seconds to wait for the HTTP response.

    Returns:
        Set of CURIEs: the preferred ID plus all equivalent identifiers. The set is
        empty when ``entity`` is empty, the normalizer returns no entry for it, or
        the request fails. The return type is always a set.
    """
    if not entity:
        # e.g. selectIDbp returns "" when no candidate fits; nothing to normalize.
        return set()

    try:
        payload = {
            "curies": [entity],
            "conflate": True,
            "description": True,
            "drug_chemical_conflate": False,
        }

        response = requests.post(
            base_url,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()

        # Access the entity's entry; it may be missing or null if the normalizer does not know it.
        entry = response.json().get(entity)
        if not entry:
            return set()

        ids = set()

        # Main (preferred) id
        main_id = (entry.get("id") or {}).get("identifier")
        if main_id:
            ids.add(main_id)

        # Equivalent IDs; strip stray list/quote characters around each identifier
        for equiv in entry.get("equivalent_identifiers", []):
            identifier = equiv.get("identifier")
            if identifier:
                ids.add(identifier.strip("[]'\" "))

        if not ids:
            print(entity)

        return ids
    except Exception as e:
        # Best effort: a failed lookup is reported and treated as "no equivalent IDs".
        print(f"Error normalizing {entity}: {e}")
        return set()

def nodeNormalizationTest(ent: str, gtruth: str) -> bool:
    """
    Check whether two CURIEs refer to the same concept.

    Both IDs are expanded with ``nodeNormalize`` (``ent`` first, then ``gtruth``).
    They match when their sets of equivalent identifiers overlap.

    Returns:
        True if ent and gtruth share at least one normalized ID, otherwise False.
    """
    shared_ids = set(nodeNormalize(ent)).intersection(nodeNormalize(gtruth))
    return bool(shared_ids)

In [None]:
def extract_marked_entities(text: str) -> list[str]:
    """Return the answer entities tagged as ``**entity**`` in an LLM response.

    Each match is the text between an opening ``**`` and the next closing ``**`` on
    the same line (``.`` does not cross newlines), with surrounding whitespace
    stripped. Empty tags such as ``****`` or ``** **`` are skipped, so callers never
    try to resolve a blank entity.
    """
    stripped = (m.strip() for m in re.findall(r"\*\*(.*?)\*\*", text))
    return [m for m in stripped if m]

def sriNameResolver(ent: str, base_url: str = "https://name-lookup.ci.transltr.io/lookup", is_bp: bool = False, k: int = 50, timeout: float = 30) -> list:
    """Look up candidate CURIEs for a free-text entity name via the SRI Name Resolver.

    Args:
        ent: entity name to search for (autocomplete matching is enabled).
        base_url: name-lookup ``/lookup`` endpoint.
        is_bp: if True, restrict results to GO identifiers of biolink type BiologicalProcess.
        k: maximum number of candidates to request.
        timeout: seconds to wait for the HTTP response.

    Returns:
        List of ``{"label", "curie", "score"}`` dicts, in the order returned by the service.

    Raises:
        requests.HTTPError: if the service responds with an error status.
    """
    params = {
        "string": ent,
        "autocomplete": "true",
        "limit": k,
    }
    if is_bp:
        params["only_prefixes"] = "GO"
        params["biolink_type"] = "BiologicalProcess"

    # Let requests URL-encode the query: names containing "&", "+", "#", "/" etc.
    # would otherwise corrupt a hand-built query string.
    response = requests.get(
        base_url,
        params=params,
        headers={"accept": "application/json"},
        timeout=timeout,
    )
    response.raise_for_status()

    return [
        {
            "label": item.get("label", ""),
            "curie": item.get("curie", ""),
            "score": item.get("score", ""),
        }
        for item in response.json()
    ]

def selectIDbp(entity: str, candidates: list):
    """Ask the LLM to pick the most appropriate CURIE for ``entity`` from ``candidates``.

    The answer is constrained through structured output to one of the candidate
    CURIEs or "" (meaning no candidate fits). When there is nothing to choose from,
    "" is returned without calling the LLM.

    Args:
        entity: free-text entity name the candidates were retrieved for.
        candidates: list of dicts with a "curie" key (e.g. from ``sriNameResolver``).

    Returns:
        The selected CURIE as a string, or "" if none is appropriate.
    """
    # "" is always a legal answer so the model can decline every candidate;
    # duplicates are dropped while keeping the candidates' order.
    choices = list(dict.fromkeys([""] + [entry.get("curie") for entry in candidates]))

    class selectedID(TypedDict):
        """The most appropriate CUI/ID from the given candidates"""
        # Literal[tuple(...)] is equivalent to Literal[*choices] but also works before Python 3.11
        selectedID: Literal[tuple(choices)]

    select_prompt = f"""You are a smart biomedical assistant that can understand the context and the intent behind a query. 
                Be careful when choosing IDs for entities that can refer to different concepts (for example, HIV can refer either to the virus or the disease; you MUST choose the most appropriate concept/definition based on the query). 
                Use the context and the intent behind the query to choose the most appropriate ID. 
                Select the one most appropriate ID/CUI for {entity} from the list below:
                {candidates}
                If none of the choices are appropriate, return "".
                Otherwise, return only the ID/CUI.
                """

    if choices == [""]:
        # Only "" is allowed, so the LLM call cannot change the outcome.
        bioID = ""
    else:
        # LLM selects the most appropriate ID from the list (result kept separate from the schema class)
        selection = llm.with_structured_output(selectedID).invoke(select_prompt)
        bioID = str(selection["selectedID"])

    # Log the chosen ID
    print(entity + " - " + bioID + '\n')

    return bioID

def checkAnswer(result, groundTruth):
    """Return True if ``result`` refers to the same concept as the ``groundTruth`` CURIE.

    ``result`` may be either of:
      * a free-text LLM response with answers tagged as ``**entity**``. Each tagged
        entity is resolved to a CURIE, and the check passes if any of them shares a
        normalized ID with ``groundTruth``.
      * a single CURIE, e.g. the resolved best answer that ``run_test`` passes in.
        It is compared with ``groundTruth`` directly, because a bare ID has no
        ``**`` tags and would otherwise never be judged correct.

    Empty or None input returns False.
    """
    if not result:
        return False

    answers = extract_marked_entities(result)

    if not answers:
        # No tags: accept the input only if it looks like a single CURIE (prefix:local_id).
        candidate = result.strip()
        return bool(re.fullmatch(r"[A-Za-z][\w.\-]*:\S+", candidate)) and nodeNormalizationTest(candidate, groundTruth)

    for answer in answers:
        ans_ID = selectIDbp(answer, sriNameResolver(answer, is_bp=False))
        if nodeNormalizationTest(ans_ID, groundTruth):
            return True

    return False

In [None]:
def _score_tool_answers(results, groundTruth, check_answer_fn, similarity_fn, extract_fn, resolve_fn):
    """Resolve every tagged answer in ``results`` and keep the one most similar to ``groundTruth``.

    Returns:
        (best_ans, best_sim, check), where:
          * best_ans is the best-scoring resolved ID, or None if there were no answers;
          * best_sim is its similarity, or -1.0 if there were no answers;
          * check is the result of ``check_answer_fn`` for best_ans.
        An answer with similarity exactly 1 counts as correct and stops the search.
    """
    best_ans, best_sim, check = None, -1.0, False
    for answer in extract_fn(results):
        ans_ID = resolve_fn(answer)

        sim = similarity_fn(ans_ID, groundTruth)
        if sim > best_sim:
            best_ans, best_sim = ans_ID, sim
            check = check_answer_fn(best_ans, groundTruth)
        if sim == 1:
            check = True
            break
    return best_ans, best_sim, check


def run_test(
    df,
    question_col: str,
    ans_col: str,
    ans_id_col: str,
    tools: dict[str, object],
    check_answer_fn,
    similarity_fn,
    extract_fn=None,
    resolve_fn=None,
):
    """
    Generalizable test function: ask every question to every tool and score the answers.

    For each row and each tool, the following steps run:
      1. The tool's response is stored.
      2. The tagged answer entities are extracted and resolved to IDs.
      3. The answer most similar to the ground-truth ID is kept as the tool's best answer.

    ``df`` is mutated in place. Four columns are added per tool:
        {ans_id_col}_{tool}             raw tool response
        {ans_id_col}_{tool}_ans         best-scoring resolved ID
        {ans_id_col}_{tool}_correct     whether that answer was judged correct
        {ans_id_col}_{tool}_similarity  its similarity (-1.0 if no answer)

    Rows are addressed by position, so any index (including a non-RangeIndex) works.

    Args:
        df: pandas DataFrame with questions and ground-truth IDs.
        question_col: column containing the question to be answered.
        ans_col: column containing the ground-truth name (e.g. "drug_name"); only printed.
        ans_id_col: column containing ground-truth IDs (e.g. "drug").
        tools: dict mapping tool_name -> callable(question) returning the response text.
        check_answer_fn: function(result_id, groundTruth) -> bool. Only a result
            equal to True is tallied as correct.
        similarity_fn: function(id1, id2) -> float.
        extract_fn: optional function(response_text) -> list of answer strings.
            Defaults to ``extract_marked_entities``.
        resolve_fn: optional function(answer_text) -> ID. Defaults to resolving via
            ``sriNameResolver`` (is_bp=False) followed by ``selectIDbp``.

    Returns:
        dict mapping tool_name -> number of questions judged correct.
    """
    if extract_fn is None:
        extract_fn = extract_marked_entities
    if resolve_fn is None:
        def _default_resolve(answer):
            return selectIDbp(answer, sriNameResolver(answer, is_bp=False))
        resolve_fn = _default_resolve

    def set_cell(position, column, value):
        # Positional write: consistent with the positional reads below, for any index.
        df.iat[position, df.columns.get_loc(column)] = value

    n = len(df)
    tallies = {tool_name: 0 for tool_name in tools}

    # Add results columns to df
    for tool_name in tools:
        df[f"{ans_id_col}_{tool_name}"] = ""               # raw tool response
        df[f"{ans_id_col}_{tool_name}_ans"] = ""           # best chosen ID
        df[f"{ans_id_col}_{tool_name}_correct"] = False    # correctness flag
        df[f"{ans_id_col}_{tool_name}_similarity"] = -1.0  # cosine similarity

    for position in range(n):
        question = df[question_col].iloc[position]
        groundTruth = df[ans_id_col].iloc[position]
        groundTruthName = df[ans_col].iloc[position]
        print(f"\n{position+1}. Question: {question}\nGround truth: {groundTruth} - {groundTruthName}\n")

        # Store per-question results
        question_results = {}

        for tool_name, tool in tools.items():
            raw_col = f"{ans_id_col}_{tool_name}"

            try:
                # Run tool
                results = tool(question)
                set_cell(position, raw_col, results)

                best_ans, best_sim, check = _score_tool_answers(
                    results, groundTruth, check_answer_fn, similarity_fn, extract_fn, resolve_fn
                )

                # Update DataFrame
                set_cell(position, f"{raw_col}_ans", best_ans)
                set_cell(position, f"{raw_col}_correct", check)
                set_cell(position, f"{raw_col}_similarity", best_sim)
                if check == True:  # noqa: E712 -- also accepts numpy.bool_
                    tallies[tool_name] += 1

                print(f"\n{tool_name} answer: {best_ans}; cosine similarity = {best_sim:.3f}\n\n-----")

            except Exception as e:
                # Keep benchmarking the remaining tools/questions; the error is shown in the summary.
                best_ans = None
                best_sim = -1.0
                check = f"Error: {e}"

            # Save to question summary
            question_results[tool_name] = {
                "ans": best_ans,
                "sim": best_sim,
                "check": check,
                "tally": tallies[tool_name]
            }

        # === Print summary for this question ===
        for tool_name, result in question_results.items():
            print(f"""{tool_name} answer: {result['ans']} - {result['check']}
    cosine similarity with gt: {result['sim']:.3f}
    tally: {result['tally']} / {n}\n""")

        print("\n---------------------------\n")

    return tallies


In [None]:
def LLM_only(question):
    """Baseline tool: answer ``question`` with the bare chat model and return the reply text."""
    response = llm.invoke(question)
    return response.content

In [None]:
# Load the filtered mechanistic-gene table that the question set is sampled from
geneDB = pd.read_csv(filepath_or_buffer='/Users/mastorga/Documents/BTE-LLM/Prototype/data/DMDB_mechanistic_genes_filtered.csv')

In [None]:
# Build the question set: a random sample of rows from geneDB, each turned into one question.
N_QUESTIONS = 50
SAMPLE_SEED = 42  # fixed seed so the sampled question set is reproducible

geneSet = geneDB.sample(N_QUESTIONS, random_state=SAMPLE_SEED)
geneSet["pairID"] = geneSet.index  # remember each row's index in geneDB before re-indexing

# Strip "[", "]", quote characters from both ends (the columns presumably hold
# stringified lists like "['TP53']" -- TODO confirm against the CSV).
# Ground-truth protein IDs are rewritten to the "UniProtKB:" prefix.
_LIST_CHARS = """['"]"""
geneSet["gene_str"] = geneSet["protein_gene_symbol"].str.strip(_LIST_CHARS)
geneSet["protein_str"] = geneSet["protein_name"].str.strip(_LIST_CHARS)
geneSet["protein_ID"] = geneSet["protein"].str.strip(_LIST_CHARS).str.replace("UniProt:", "UniProtKB:", regex=False)
geneSet = geneSet.reset_index(drop=True)

geneSet['question'] = (
    "Which gene plays the most significant role in how the drug "
    + geneSet['drug_name']
    + " treats or impacts the disease "
    + geneSet['disease_name']
    + "? Enumerate 5 answers and do not include anything else in your response. Each of your answer entities MUST be tagged with ** at the start AND end of the phrase (**....**), otherwise it will not be assessed"
)

print(geneSet)

In [None]:

# Snapshot the sampled question set before any tool is run
geneSet.to_excel(excel_writer='/Users/mastorga/Documents/BTE-LLM/Prototype/logs/drug <- diseasebp/questionset_50questions_mechanisticgenes_09.03.25.xlsx')

In [None]:
# Tools under comparison: the bare-LLM baseline and the agentic BTE pipeline (BTEx)
tools = dict(
    LLM_only=LLM_only,
    Agentic_BTE=BTEx,
)

In [None]:
# Ask every question to every tool and score answers against the ground-truth protein IDs
tallies = run_test(
    df=geneSet,
    tools=tools,
    question_col="question",
    ans_col="gene_str",
    ans_id_col="protein_ID",
    similarity_fn=idCosineSimilarity,
    check_answer_fn=checkAnswer,
)

print("Final tallies for the questions:", tallies)

In [None]:
# Save the questions together with each tool's raw responses, chosen IDs,
# correctness flags and similarity scores (columns added by run_test)
geneSet.to_excel(excel_writer='/Users/mastorga/Documents/BTE-LLM/Prototype/logs/drug <- diseasebp/50questions_mechanisticgenes_09.03.25.xlsx')