In [None]:
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, Field
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from typing import List, Any
from typing import Literal
from pprint import pprint
from time import sleep
from typing_extensions import TypedDict, Annotated
import json
import requests

In [None]:
def choosePredicate(predicateList, query, llm=None):
    """Ask an LLM to pick the single most appropriate predicate for ``query``.

    The model's structured output is constrained to a ``Literal`` over the
    candidate predicates, so it can only answer with one of them.

    Args:
        predicateList: Candidate predicate strings (e.g. "biolink:treats").
            Duplicates are ignored; order of first appearance is kept.
        query: Natural-language question used to choose the predicate.
        llm: Optional chat model exposing ``with_structured_output``. Defaults to
            ``ChatOpenAI(temperature=0, model="gpt-4o")``.

    Returns:
        The structured-output dict, e.g. ``{"predicate": "biolink:treats"}``.

    Raises:
        ValueError: if ``predicateList`` is empty (an empty Literal cannot
            describe a valid choice).
    """
    choices = list(dict.fromkeys(predicateList))
    if not choices:
        raise ValueError("predicateList must contain at least one predicate")

    class predicateChoice(TypedDict):
        """The most appropriate predicate chosen for the TRAPI query"""

        predicate: Literal[tuple(choices)]

    if llm is None:
        llm = ChatOpenAI(temperature=0, model="gpt-4o")

    chosen_predicate = llm.with_structured_output(predicateChoice).invoke(f"Choose the most appropriate predicate based on the query: {query}")

    return chosen_predicate

In [None]:

# Candidate biolink predicates for the standalone choosePredicate demo below.
# NOTE(review): several entries appear twice (related_to, associated_with,
# correlated_with, occurs_together_in_literature_with, negatively/positively
# correlated_with); kept exactly as originally listed.
predicateList = [
    "biolink:subject_of_treatment_application_or_study_for_treatment_by",
    "biolink:caused_by",
    "biolink:condition_prevented_by",
    "biolink:related_to",
    "biolink:affected_by",
    "biolink:disrupted_by",
    "biolink:condition_predisposed_by",
    "biolink:produces",
    "biolink:condition_exacerbated_by",
    "biolink:treated_by",
    "biolink:has_contraindication",
    "biolink:treatment_applications_from",
    "biolink:adverse_event_of",
    "biolink:treated_in_studies_by",
    "biolink:occurs_together_in_literature_with",
    "biolink:contribution_from",
    "biolink:associated_with",
    "biolink:tested_by_preclinical_trials_of",
    "biolink:tested_by_clinical_trials_of",
    "biolink:correlated_with",
    "biolink:increased_likelihood_associated_with",
    "biolink:associated_with_sensitivity_to",
    "biolink:associated_with_resistance_to",
    "biolink:negatively_correlated_with",
    "biolink:positively_correlated_with",
    "biolink:treats_or_applied_or_studied_to_treat",
    "biolink:causes",
    "biolink:preventative_for_condition",
    "biolink:related_to",
    "biolink:affects",
    "biolink:disrupts",
    "biolink:predisposes_to_condition",
    "biolink:produced_by",
    "biolink:exacerbates_condition",
    "biolink:treats",
    "biolink:contraindicated_in",
    "biolink:applied_to_treat",
    "biolink:has_adverse_event",
    "biolink:studied_to_treat",
    "biolink:occurs_together_in_literature_with",
    "biolink:contributes_to",
    "biolink:associated_with",
    "biolink:in_preclinical_trials_for",
    "biolink:in_clinical_trials_for",
    "biolink:correlated_with",
    "biolink:associated_with_increased_likelihood_of",
    "biolink:sensitivity_associated_with",
    "biolink:resistance_associated_with",
    "biolink:diagnoses",
    "biolink:ameliorates_condition",
    "biolink:negatively_correlated_with",
    "biolink:positively_correlated_with",
    "biolink:ameliorates",
]

In [None]:
# Sanity check: which predicate does the model pick for a treatment-style question?
choosePredicate(predicateList, "Which drugs can treat Allergic rhinitis?")

In [None]:
def extractEntities(query: str, llm=None) -> str:
    """Extract biomedical entities from ``query`` with an LLM, then self-check.

    The model first lists the entities. It is then asked whether anything was
    missed. If it does not answer TRUE, a third call asks for the missing terms,
    which are appended (space-separated) to the first answer.

    Args:
        query: Natural-language biomedical question.
        llm: Chat model exposing ``invoke(prompt)`` returning an object with a
            ``.content`` string. Defaults to ``ChatOpenAI(temperature=0,
            model="gpt-4o")``, the same configuration used elsewhere here.

    Returns:
        The raw model text listing the entities (also printed).
    """
    if llm is None:
        llm = ChatOpenAI(temperature=0, model="gpt-4o")

    ent_prompt = f"""You are a helpful assistant that can extract biological terms/entities from a given query. 
                    These might include diseases (such as malaria, early onset dementia, parkinsonism), genes, proteins, biological entities (such as viruses), etc.

                    Do NOT include common nouns/phrases such as "gene", "compound", "pathogen", "associated with", etc. 
                    Your response must ONLY include discrete proper noun entities such as "early onset dementia" or "sickle cell anemia". 
                    Make sure to include symptoms such as "fever" or "cough".

                    If extracting terms concerning proteins/genes such as "ESR1 upregulation", only return the gene/protein name ("ESR1").
                    You must always return the full phrase/long form of each biomedical entity 
                    (for example, "type I and type II diabetes" should result in "type I diabetes" and "type II diabetes"; "DNA polymerases of human vs. mice" should result in "human DNA polymerase" and "mouse DNA polymerase")
                    Return results as a list.
                    Here is your query: {query}"""

    entities = llm.invoke(ent_prompt).content

    verify_prompt = f"""Given the following query, have all relevant biomedical terms been extracted?
                    {query}

                    Here are the extracted biomedical terms: {entities}
                    
                    Return TRUE if all biomedical entities were extracted. Return FALSE if there are any missing biomedical entities from the list.
                    Your response MUST be boolean"""

    verified = llm.invoke(verify_prompt).content

    # Models often reply "TRUE." or " True\n"; normalize before comparing so a
    # correct answer does not trigger an unnecessary third call.
    if verified.strip().strip(".").strip().lower() != "true":
        retry_prompt = f"""Given the following query, which biomedical terms are missing from the list?
                    {query}

                    Here are the extracted biomedical terms: {entities}

                    Here were the extraction instructions: {ent_prompt}
                    
                    Return a list of ONLY the missing biomedical terms"""

        addtl_ents = llm.invoke(retry_prompt).content
        entities = entities + " " + addtl_ents

    print(entities + "\n")
    return entities

In [None]:
# Meta knowledge graph from the TRAPI endpoint; findPredicate reads the
# subject/predicate/object fields of its "edges" entries.
# The timeout keeps a stalled connection from hanging the kernel indefinitely.
metaKG = requests.get("https://bte.transltr.io/v1/meta_knowledge_graph", timeout=60)
# Fail here on HTTP errors instead of later inside json.loads on an error page.
metaKG.raise_for_status()

def findPredicate(subject: str, object: str, meta_edges: list = None):
    """Return the predicates linking two categories, in either direction.

    Scans meta knowledge graph edges for entries whose ``subject``/``object``
    categories match the arguments (case-insensitively) in either orientation,
    and collects their ``predicate`` values.

    Args:
        subject: Category such as "biolink:Disease". ``None``/empty yields [].
        object: Category such as "biolink:SmallMolecule". The name shadows the
            builtin but is kept so keyword callers keep working.
        meta_edges: Edge dicts to search. Defaults to the "edges" list of the
            module-level ``metaKG`` response.

    Returns:
        Unique predicates in first-seen order. Each predicate appears once, even
        when subject == object or when it exists in both directions.
    """
    if not subject or not object:
        return []
    if meta_edges is None:
        meta_edges = json.loads(metaKG.content).get("edges") or []

    subj, obj = subject.lower(), object.lower()
    wanted_pairs = {(subj, obj), (obj, subj)}

    found = {}  # dict used as an insertion-ordered set
    for edge in meta_edges:
        edge_pair = ((edge.get("subject") or "").lower(), (edge.get("object") or "").lower())
        predicate = edge.get("predicate")
        if predicate and edge_pair in wanted_pairs:
            found.setdefault(predicate, None)

    return list(found)

def extractDict(rawstring: str):
    """Parse the first JSON object embedded in an LLM reply.

    Handles bare JSON, JSON wrapped in a Markdown code fence, and JSON surrounded
    by prose. Decoding starts at each ``{`` in turn, so nested objects are parsed
    in full and stray braces in surrounding text are skipped.

    Args:
        rawstring: Raw model output.

    Returns:
        The parsed dict, or ``{"error": "could not parse dict"}`` when no JSON
        object can be decoded (callers test for the "error" key).
    """
    if not rawstring:
        return {"error": "could not parse dict"}

    decoder = json.JSONDecoder()
    start = rawstring.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the first complete value, so
            # trailing text such as a closing ``` fence is ignored.
            parsed, _ = decoder.raw_decode(rawstring, start)
            return parsed
        except json.JSONDecodeError:
            start = rawstring.find("{", start + 1)

    return {"error": "could not parse dict"}
        

@tool("TRAPIQuery")
def TRAPIQuery(query: str, entity_data: dict, failed_trapis: list = []):
    """Use this tool to build a TRAPI JSON query"""
    
    MAX_RETRIES = 3
    llm = ChatOpenAI(temperature=0, model="gpt-4o")

    def choosePredicate(predicateList: List, query: str):
        class predicateChoice(TypedDict):
            """The most appropriate predicate chosen for the TRAPI query"""

            predicate: Literal[*predicateList]

        llm = ChatOpenAI(temperature=0, model="gpt-4o")

        predicatePrompt = """Choose the most specific predicate by examining the query closely and choosing the closest answer.  
            
            Here is the query: {query}
            """

        chosen_predicate = llm.with_structured_output(predicateChoice).invoke(predicatePrompt)

        return str(chosen_predicate["predicate"])

    def invoke_with_retries(prompt, parser_func, max_retries=3, delay=5):
        for attempt in range(max_retries):
            try:
                response = llm.invoke(prompt)
                return parser_func(response.content)
            except Exception as e:
                print(f"Retry {attempt+1}/{max_retries} failed: {e}")
                sleep(delay)
        return None

    def identify_nodes(query: str):
        nodeprompt = f"""
        Your task is to help build a TRAPI query by identifying the correct subject and object nodes from the list of nodes below. 

        Subject: The entity that initiates or is the focus of the relationship in your question.
        Object: The entity that is affected or related to the subject in your question.
        
        Each of these must have the prefix "biolink:". You must use the context and intent behind the user query to make the correct choice.

        Nodes: Disease, PhysiologicalProcess, BiologicalEntity, Gene, PathologicalProcess, Polypeptide, SmallMolecule, PhenotypicFeature

        Here is the user query: {query}

        Be as specific as possible as downstream results will be affected by your node choices.
        Your response MUST ONLY CONTAIN a dictionary object with "subject" and "object" as keys along with their corresponding values.
        """
        
        return invoke_with_retries(nodeprompt, extractDict)
    
    def build_TRAPI(query: str, subject_object: dict, predicate: str, entity_data: dict, failed_trapis: list):
        print("\n\n Building trapi....")

        # TRAPI example that could be included in the prompt
        trapi_ex = {
                "message": {
                "query_graph": {
                "nodes": {
                    "n0": {
                    "categories": [
                        "biolink:Disease"
                    ],
                    "ids": [
                        "MONDO:0016575"
                    ]
                    },
                    "n1": {
                    "categories": [
                        "biolink:PhenotypicFeature"
                    ]
                    }
                },
                "edges": {
                    "e01": {
                    "subject": "n0",
                    "object": "n1",
                    "predicates": [
                        "biolink:has_phenotype"
                    ]
                    }
                }
                }
        }
        }

        trapiPrompt = f"""
            You are a smart assistant that has access to a knowledge graph. 
            You are tasked with correctly parsing the user prompt into a TRAPI query. 
            Here's an example TRAPI query:
            {json.dumps(trapi_ex, indent=2)}

            Here is the actual user query: "{query}"

            Here are the biological entities and their IDs extracted from the query:
            {json.dumps(entity_data, indent=2)}

            Here are the nodes you MUST use:
            {json.dumps(subject_object, indent=2)}

            Here is the chosen predicate
            {predicate}

            Do NOT use the following TRAPI queries as they have failed to get results: 
            {failed_trapis}

            Some predicates have directionality ("treated_by" is NOT the same as "treats"). When defining the edges, you MUST use the correct predicate with the correct directionality.
            (e.g., biolink:Disease is "treated_by" (NOT "treats") biolink:SmallMolecule; 
            biolink:Disease as subject and biolink:Gene as object should result in "condition_associated_with_gene", NOT "gene_associated_with_condition")

            In this KG, the subject of the predicate is always the source (domain), and the object is the target (range). For example, if 'X treats Y', then X is the subject and Y is the object.
            
            Please make sure that only one node in the TRAPI query has the "ids" field (either "n0" or "n1" but NOT both); 
            decide which one would result in more useful results (for "Which of these genes are involved in a physiological process?", including the IDs for Gene would result in more useful results. In this case DO NOT include the ID for Physiological Process)

            Using these, output a JSON object containing the completed TRAPI query. 
        """

        return invoke_with_retries(trapiPrompt, extractDict)
    
    # Main dataflow
    subject_object = identify_nodes(query)
    if not subject_object:
        print("Could not determine subject/object nodes")
        return "Could not determine subject/object nodes"
    
    print(f"\n\nIdentified subject/object: {json.dumps(subject_object)}")
    
    predicateList = findPredicate(subject_object.get("subject"), subject_object.get("object"))

    for item in failed_trapis:
        failed_predicates = item.get("edges", "").get("e01", "").get("predicates", "")

        

    predicate = choosePredicate(predicateList, query)

    trapi = build_TRAPI(query, subject_object, predicate, entity_data, failed_trapis)
    
    if trapi:
        return trapi

    # Retry if initial attempt fails
    for _ in range(MAX_RETRIES):
        trapi = build_TRAPI(query, subject_object, predicate, entity_data)
        if trapi:
            return trapi

    return "Invalid TRAPI"

In [None]:
# Hand-built entity -> CURIE map standing in for the entity-resolution step,
# used to exercise the TRAPIQuery tool end to end.
trial_ent_data = {
    "Allergic rhinitis": "UMLS:C2607914",
    "Response": "UMLS:C0871261",
    "histamine": "UMLS:C0019588",
}

TRAPIQuery.invoke(
    input={
        "query": "Which drugs can treat Allergic rhinitis?",
        "entity_data": trial_ent_data,
    }
)