In [None]:
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, Field
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from typing import List, Any, Literal
from typing_extensions import TypedDict
from pprint import pprint
from time import sleep
import json
import requests

In [None]:
import getpass
import os
from os.path import dirname, join

from dotenv import load_dotenv

# Load environment variables (e.g. OPENAI_API_KEY) from ../.env, if that file exists.
load_dotenv("../.env")

# Fall back to an interactive prompt when the key is still missing.
# getpass must be imported for this branch to work on a fresh kernel.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Please enter OpenAI API Key: ")

In [None]:
# Download the meta knowledge graph once; findPredicate reads `metaKG.content`.
# A timeout keeps a stalled connection from hanging the kernel forever, and
# raise_for_status() surfaces HTTP errors here instead of as a JSON error later.
metaKG = requests.get("https://bte.transltr.io/v1/meta_knowledge_graph", timeout=60)
metaKG.raise_for_status()

def findPredicate(subject: str, object: str, meta_edges=None):
    """Return the predicates the meta knowledge graph lists for a subject/object pair.

    Matching on the subject and object categories is case-insensitive.

    Args:
        subject: Category of the subject node (e.g. "biolink:Disease").
        object: Category of the object node. (The name shadows the builtin but is
            kept so existing keyword callers continue to work.)
        meta_edges: Optional list of edge dicts with "subject", "object" and
            "predicate" keys. When omitted, the edges are parsed from the
            module-level ``metaKG`` HTTP response.

    Returns:
        list: Predicates of every matching edge, in meta-KG order. Empty when
        either category is missing or nothing matches.
    """
    # Callers pass dict.get(...) results, so either category may be None.
    if not subject or not object:
        return []

    if meta_edges is None:
        meta_edges = json.loads(metaKG.content).get("edges") or []

    wanted_subject = subject.lower()
    wanted_object = object.lower()

    predicates = []
    for edge in meta_edges:
        # Tolerate edges that lack a subject/object instead of crashing on None.lower().
        edge_subject = (edge.get("subject") or "").lower()
        edge_object = (edge.get("object") or "").lower()
        if edge_subject == wanted_subject and edge_object == wanted_object:
            predicate = edge.get("predicate")
            if predicate:
                predicates.append(predicate)

    return predicates

In [None]:
def extractDict(rawstring: str):
    """Parse the JSON object contained in an LLM response.

    The text is tried, in order, as:
      1. a bare JSON document (the whole string);
      2. the span from the first "{" to the first "}" that is followed by a
         newline and a backtick (an object closed just before a Markdown fence);
      3. the span from the first "{" to the last "}" in the string.

    Args:
        rawstring: Raw text, e.g. the ``content`` of a chat-model reply.

    Returns:
        dict: The first candidate that decodes to a JSON object.

    Raises:
        ValueError: If no candidate decodes to a JSON object. Raising (instead of
            returning a placeholder dict) lets callers that retry on exceptions
            ask the model again rather than pass a bogus result downstream.
    """
    text = rawstring.strip()
    candidates = [text]

    start_index = text.find("{")
    if start_index != -1:
        fence_end = text.find("}\n`", start_index)
        last_brace = text.rfind("}")
        for end_index in (fence_end, last_brace):
            # find/rfind return -1 when absent, which this comparison also rejects.
            if end_index > start_index:
                candidates.append(text[start_index:end_index + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    raise ValueError("could not parse dict")
        

In [None]:
# Original code
@tool("TRAPIQuery")
def TRAPIQuery(query: str, entity_data: dict, failed_trapis: list = []):
    """Use this tool to build a TRAPI JSON query"""
    # NOTE(review): the docstring above is kept byte-identical because @tool is
    # understood to expose it as the tool description; maintainer docs live here.
    #
    # Flow:
    #   1. identify_nodes  - asks the LLM for the biolink subject/object categories.
    #   2. findPredicate   - looks up the predicates available for that pair.
    #   3. predicates already used by entries of failed_trapis (edge "e01") are dropped.
    #   4. choosePredicate - asks the LLM to pick one of the remaining predicates.
    #   5. build_TRAPI     - asks the LLM to assemble the final TRAPI query JSON.
    #
    # Args:
    #   query:         natural-language question from the user.
    #   entity_data:   mapping of entity names found in the query to their IDs.
    #   failed_trapis: TRAPI query dicts that previously produced no results. The
    #                  mutable default is kept for interface compatibility; it is
    #                  only read here, never mutated.
    # Returns:
    #   The parsed TRAPI query (dict) on success, otherwise a short error string.

    MAX_RETRIES = 3
    llm = ChatOpenAI(temperature=0, model="gpt-4o")

    def choosePredicate(predicateList: List, query: str):
        # Literal[tuple(...)] is equivalent to Literal[*...] but also parses on
        # Python versions older than 3.11.
        class predicateChoice(TypedDict):
            """The most appropriate predicate chosen for the TRAPI query"""

            predicate: Literal[tuple(predicateList)]

        # Must be an f-string: otherwise the literal text "{query}" is sent to the model.
        predicatePrompt = f"""Choose the most specific predicate by examining the query closely and choosing the closest answer.  
            
            Here is the query: {query}
            """

        chosen_predicate = llm.with_structured_output(predicateChoice).invoke(predicatePrompt)

        return str(chosen_predicate["predicate"])

    def invoke_with_retries(prompt, parser_func, max_retries=3, delay=5):
        # Calls the LLM and parses its reply; any exception triggers another attempt.
        # Returns None when every attempt failed.
        for attempt in range(max_retries):
            try:
                response = llm.invoke(prompt)
                return parser_func(response.content)
            except Exception as e:
                print(f"Retry {attempt+1}/{max_retries} failed: {e}")
                # No point in waiting after the final attempt.
                if attempt < max_retries - 1:
                    sleep(delay)
        return None

    def identify_nodes(query: str):
        nodeprompt = f"""
        Your task is to help build a TRAPI query by identifying the correct subject and object nodes from the list of nodes below. 

        Subject: The entity that initiates or is the focus of the relationship in your question.
        Object: The entity that is affected or related to the subject in your question.
        
        Each of these must have the prefix "biolink:". You must use the context and intent behind the user query to make the correct choice.

        Nodes: Disease, PhysiologicalProcess, BiologicalEntity, Gene, PathologicalProcess, Polypeptide, SmallMolecule, PhenotypicFeature

        Here is the user query: {query}

        Be as specific as possible as downstream results will be affected by your node choices.
        Your response MUST ONLY CONTAIN a dictionary object with "subject" and "object" as keys along with their corresponding values.
        """

        return invoke_with_retries(nodeprompt, extractDict)

    def build_TRAPI(query: str, subject_object: dict, predicate: str, entity_data: dict, failed_trapis: list):
        print("\n\n Building trapi....")

        # TRAPI example that could be included in the prompt
        trapi_ex = {
            "message": {
                "query_graph": {
                    "nodes": {
                        "n0": {
                            "categories": ["biolink:Disease"],
                            "ids": ["MONDO:0016575"],
                        },
                        "n1": {
                            "categories": ["biolink:PhenotypicFeature"],
                        },
                    },
                    "edges": {
                        "e01": {
                            "subject": "n0",
                            "object": "n1",
                            "predicates": ["biolink:has_phenotype"],
                        }
                    },
                }
            }
        }

        trapiPrompt = f"""
            You are a smart assistant that has access to a knowledge graph. 
            You are tasked with correctly parsing the user prompt into a TRAPI query. 
            Here's an example TRAPI query:
            {json.dumps(trapi_ex, indent=2)}

            Here is the actual user query: "{query}"

            Here are the biological entities and their IDs extracted from the query:
            {json.dumps(entity_data, indent=2)}

            Here are the nodes you MUST use:
            {json.dumps(subject_object, indent=2)}

            Here is the chosen predicate
            {predicate}

            Do NOT use the following TRAPI queries as they have failed to get results: 
            {failed_trapis}

            Some predicates have directionality ("treated_by" is NOT the same as "treats"). When defining the edges, you MUST use the correct predicate with the correct directionality.
            (e.g., biolink:Disease is "treated_by" (NOT "treats") biolink:SmallMolecule; 
            biolink:Disease as subject and biolink:Gene as object should result in "condition_associated_with_gene", NOT "gene_associated_with_condition")

            In this KG, the subject of the predicate is always the source (domain), and the object is the target (range). For example, if 'X treats Y', then X is the subject and Y is the object.
            
            Please make sure that only one node in the TRAPI query has the "ids" field (either "n0" or "n1" but NOT both); 
            decide which one would result in more useful results (for "Which of these genes are involved in a physiological process?", including the IDs for Gene would result in more useful results. In this case DO NOT include the ID for Physiological Process)

            Using these, output a JSON object containing the completed TRAPI query. 
        """

        return invoke_with_retries(trapiPrompt, extractDict)

    # Main dataflow
    subject_object = identify_nodes(query)
    # Reject anything that is not a dict carrying both categories, so the
    # predicate lookup below never receives None.
    if (
        not isinstance(subject_object, dict)
        or not subject_object.get("subject")
        or not subject_object.get("object")
    ):
        print("Could not determine subject/object nodes")
        return "Could not determine subject/object nodes"

    print(f"\n\nIdentified subject/object: {json.dumps(subject_object)}")

    predicateList = findPredicate(subject_object.get("subject"), subject_object.get("object"))

    # Extract all failed predicates from the failed TRAPI queries
    failed_predicates = set()
    for trapi_query in failed_trapis:
        try:
            preds = trapi_query.get("message", {}).get("query_graph", {}).get("edges", {}).get("e01", {}).get("predicates", [])
            failed_predicates.update(preds)
        except Exception as e:
            print(f"Error extracting predicates from failed_trapis: {e}")

    # Remove failed predicates from predicateList
    predicateList = [p for p in predicateList if p not in failed_predicates]

    # An empty list cannot be turned into a Literal[...] choice for the LLM.
    if not predicateList:
        message = "No untried predicate is available for the identified subject/object nodes"
        print(message)
        return message

    predicate = choosePredicate(predicateList, query)

    # One initial attempt plus MAX_RETRIES retries; every call receives failed_trapis
    # (the retry call previously omitted it and raised TypeError).
    for _ in range(1 + MAX_RETRIES):
        trapi = build_TRAPI(query, subject_object, predicate, entity_data, failed_trapis)
        if trapi:
            return trapi

    return "Invalid TRAPI"
            



    

In [None]:
# Modified manual code
# NOTE(review): this tool is registered under the same name ("TRAPIQuery") as the
# LLM-built variant; confirm the two are never handed to the same agent together.
@tool("TRAPIQuery")
def manualTRAPIQuery(query: str, entity_data: dict, failed_trapis: list = []):
    """Use this tool to build a TRAPI JSON query"""
    # NOTE(review): the docstring above is kept byte-identical because @tool is
    # understood to expose it as the tool description; maintainer docs live here.
    #
    # Flow:
    #   1. identify_nodes  - asks the LLM for the biolink subject/object categories.
    #   2. findPredicate   - looks up the predicates available for that pair.
    #   3. predicates already used by entries of failed_trapis (edge "e01") are dropped.
    #   4. choosePredicate - asks the LLM to pick one of the remaining predicates.
    #   5. build_TRAPI     - assembles the TRAPI dict in code; the LLM is only asked
    #                        which entity ID belongs on the subject node.
    #
    # Args:
    #   query:         natural-language question from the user.
    #   entity_data:   mapping of entity names found in the query to their IDs.
    #   failed_trapis: TRAPI query dicts that previously produced no results. The
    #                  mutable default is kept for interface compatibility; it is
    #                  only read here, never mutated.
    # Returns:
    #   The TRAPI query (dict) on success, otherwise a short error string.

    MAX_RETRIES = 3
    llm = ChatOpenAI(temperature=0, model="gpt-4o")

    def choosePredicate(predicateList: List, query: str):
        # Literal[tuple(...)] is equivalent to Literal[*...] but also parses on
        # Python versions older than 3.11.
        class predicateChoice(TypedDict):
            """The most appropriate predicate chosen for the TRAPI query"""

            predicate: Literal[tuple(predicateList)]

        # Must be an f-string: otherwise the literal text "{query}" is sent to the model.
        predicatePrompt = f"""Choose the most specific predicate by examining the query closely and choosing the closest answer.  
            
            Here is the query: {query}
            """

        chosen_predicate = llm.with_structured_output(predicateChoice).invoke(predicatePrompt)

        return str(chosen_predicate["predicate"])

    def invoke_with_retries(prompt, parser_func, max_retries=3, delay=5):
        # Calls the LLM and parses its reply; any exception triggers another attempt.
        # Returns None when every attempt failed.
        for attempt in range(max_retries):
            try:
                response = llm.invoke(prompt)
                return parser_func(response.content)
            except Exception as e:
                print(f"Retry {attempt+1}/{max_retries} failed: {e}")
                # No point in waiting after the final attempt.
                if attempt < max_retries - 1:
                    sleep(delay)
        return None

    def identify_nodes(query: str):
        nodeprompt = f"""
        Your task is to help build a TRAPI query by identifying the correct subject and object nodes from the list of nodes below. 

        Subject: The entity that initiates or is the focus of the relationship in your question.
        Object: The entity that is affected or related to the subject in your question.
        
        Each of these must have the prefix "biolink:". You must use the context and intent behind the user query to make the correct choice.

        Nodes: Disease, PhysiologicalProcess, BiologicalEntity, Gene, PathologicalProcess, Polypeptide, SmallMolecule, PhenotypicFeature

        Here is the user query: {query}

        Be as specific as possible as downstream results will be affected by your node choices.
        Your response MUST ONLY CONTAIN a dictionary object with "subject" and "object" as keys along with their corresponding values.
        """

        return invoke_with_retries(nodeprompt, extractDict)

    def build_TRAPI(query: str, subject_object: dict, predicate: str, entity_data: dict, failed_trapis: list):
        print("\n\n Building trapi....")

        # Initialize node dictionary
        nodes = {}

        # Subject node
        nodes["n0"] = {
            "categories": [subject_object.get("subject")],
        }

        subject_id_prompt = f"""
                            Based on the following query and the identified subject-object, 
                            select the most appropriate subject ID from the entity data below.
                            ONLY RETURN THE ID.
                            
                            Here is the query: {query}
                            
                            Here are the subject and object: {subject_object}
                            
                            Here is the entity data: {entity_data}"""

        # Strip surrounding whitespace and stray quoting so only the bare ID ends up in "ids".
        subject_id = llm.invoke(subject_id_prompt).content.strip().strip("`\"'")

        print(subject_id)

        if subject_id:
            nodes["n0"]["ids"] = [subject_id]

        # Object node
        nodes["n1"] = {
            "categories": [subject_object.get("object")],
        }

        # Edge
        edges = {
            "e01": {
                "subject": "n0",
                "object": "n1",
                "predicates": [predicate] if predicate else []
            }
        }

        # Build final TRAPI
        trapi = {
            "message": {
                "query_graph": {
                    "nodes": nodes,
                    "edges": edges
                }
            }
        }

        print(trapi)

        # Check if this TRAPI already failed before
        if trapi in failed_trapis:
            print("⚠️ Skipping TRAPI (was in failed list).")
            return None

        return trapi

    # Main dataflow
    subject_object = identify_nodes(query)
    # Reject anything that is not a dict carrying both categories, so the
    # predicate lookup below never receives None.
    if (
        not isinstance(subject_object, dict)
        or not subject_object.get("subject")
        or not subject_object.get("object")
    ):
        print("Could not determine subject/object nodes")
        return "Could not determine subject/object nodes"

    print(f"\n\nIdentified subject/object: {json.dumps(subject_object)}")

    predicateList = findPredicate(subject_object.get("subject"), subject_object.get("object"))

    # Extract all failed predicates from the failed TRAPI queries
    failed_predicates = set()
    for trapi_query in failed_trapis:
        try:
            preds = trapi_query.get("message", {}).get("query_graph", {}).get("edges", {}).get("e01", {}).get("predicates", [])
            failed_predicates.update(preds)
        except Exception as e:
            print(f"Error extracting predicates from failed_trapis: {e}")

    # Remove failed predicates from predicateList
    predicateList = [p for p in predicateList if p not in failed_predicates]

    # An empty list cannot be turned into a Literal[...] choice for the LLM.
    if not predicateList:
        message = "No untried predicate is available for the identified subject/object nodes"
        print(message)
        return message

    predicate = choosePredicate(predicateList, query)

    # One initial attempt plus MAX_RETRIES retries; every call receives failed_trapis
    # (the retry call previously omitted it and raised TypeError).
    for _ in range(1 + MAX_RETRIES):
        trapi = build_TRAPI(query, subject_object, predicate, entity_data, failed_trapis)
        if trapi:
            return trapi

    return "Invalid TRAPI"
            


In [None]:
# Example natural-language question passed to the tool invocations further down
query = (
    "What diseases can Elvitegravir treat?"
)

In [None]:
# Example entity-name -> identifier mapping supplied alongside the sample query
entity_data = dict(
    [
        ("viral genome", "UMLS:C0042720"),
        ("replication", "UMLS:C1160687"),
        ("biomedical", "UMLS:C0042720"),
        ("list", "No appropriate IDs were found"),
        ("Elvitegravir\n2", "UMLS:C2606637"),
        ("Diseases", "UMLS:C0012634"),
        ("HIV/AIDS", "UMLS:C0001175"),
        ("Elvitegravir", "UMLS:C2606637"),
    ]
)

In [None]:
# Run the LLM-assembled variant on the sample inputs; the result is the cell output
TRAPIQuery.invoke({"query": query, "entity_data": entity_data})

In [None]:
# Run the code-assembled variant on the same sample inputs; the result is the cell output
manualTRAPIQuery.invoke({"query": query, "entity_data": entity_data})