In [1]:
import openai
from langchain.llms import OpenAI
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory

import os
import re
import json
import requests
import openai
from datetime import datetime

In [2]:
# Read the key from the environment instead of hardcoding it in the notebook:
# a pasted key would end up in version control and in every shared copy.
# Falls back to '' when OPENAI_API_KEY is unset, matching the previous placeholder.
OPEN_AI_KEY = os.environ.get('OPENAI_API_KEY', '')
openai.api_key = OPEN_AI_KEY


In [3]:
# System prompt shared by the query-generation helpers below.
_QUERY_SYSTEM_PROMPT = "You are an expert assistant that can turn user queries into GraphQL queries."


def query_chatgpt(prompt, model="gpt-3.5-turbo"):
    """Ask a chat model to turn ``prompt`` into a GraphQL query.

    Args:
        prompt: Full user prompt (instructions, optional schema/example, and the question).
        model: Chat-completion model name. Defaults to ``gpt-3.5-turbo``.

    Returns:
        The model's reply text (expected to be a GraphQL query).
    """
    response = openai.ChatCompletion.create(
        model=model,
        messages=[
            {"role": "system", "content": _QUERY_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
        # Deterministic decoding so each run of the notebook is comparable.
        temperature=0,
        max_tokens=1024,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=["###"],
    )
    return response.choices[0].message['content']


# swap the chatGPT model to enable extra long context
def query_chatgpt_16k(prompt):
    """Same as ``query_chatgpt`` but on the 16k-context model, for prompts that embed the full schema."""
    return query_chatgpt(prompt, model="gpt-3.5-turbo-16k")

def query_langchain(llm, prompt):
    """Send ``prompt`` through a LangChain chain and return its text reply.

    ``llm`` must expose ``predict(input=...)``; in this notebook it is a
    ``ConversationChain``, so earlier turns stay in the chain's memory.
    """
    return llm.predict(input=prompt)

def query_graphql(query_string,
                  base_url="https://api.platform.opentargets.org/api/v4/graphql",
                  timeout=60):
    """Run a GraphQL query against the Open Targets Platform API.

    Args:
        query_string: GraphQL query text (LLM-generated here, so possibly invalid).
        base_url: GraphQL endpoint; defaults to the public Open Targets v4 API.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        ``(error_message, hits_list)``:
        - network failure or non-JSON body: ``error_message`` is a string;
        - JSON body without a ``data`` key: ``error_message`` is the whole parsed
          body (a dict, e.g. ``{'syntaxError': ...}``);
        - ``data`` is null: ``error_message`` is the first ``errors[].message``
          (or the whole body if there is none);
        - otherwise ``error_message`` is ``None`` and ``hits_list`` is the first
          search hit, or ``None`` when the query did not use ``search`` or found nothing.
        ``hits_list`` is ``None`` in every error case.
    """
    error_message, hits_list = None, None

    try:
        response = requests.post(base_url, json={"query": query_string}, timeout=timeout)
    except requests.exceptions.RequestException as err:
        # API unreachable or timed out: there is no body to parse.
        error_message = str(err)
        print('ERROR:\n {}\n\n'.format(error_message))
        return error_message, hits_list

    # raise_for_status() is deliberately not used: an invalid query may come back
    # with a non-2xx status whose JSON body describes the problem, and that body is
    # what the error-correction loops feed back to the model.
    try:
        api_response = json.loads(response.text)
    except ValueError:
        # Non-JSON body (e.g. an error page); keep it non-empty so callers see a failure.
        error_message = response.text or 'Empty response body from {}'.format(base_url)
        print('ERROR:\n {}\n\n'.format(error_message))
        return error_message, hits_list

    if 'data' not in api_response:
        error_message = api_response
        print('ERROR:\n {}\n\n'.format(error_message))
    elif api_response['data'] is None:
        errors = api_response.get('errors') or [{}]
        error_message = errors[0].get('message') or api_response
        print('ERROR:\n {}\n\n'.format(error_message))
    else:
        try:
            hits_list = api_response["data"]["search"]["hits"][0]
        except (KeyError, IndexError, TypeError):
            # Valid response but no search hit to extract (different root field or
            # empty hits): not an API error, so just show the payload.
            print(api_response)

    return error_message, hits_list

def convert_json_response(hits_list, user_input, model="gpt-3.5-turbo-16k"):
    """Ask a chat model to answer ``user_input`` as a list, using ``hits_list`` as context.

    Args:
        hits_list: Parsed GraphQL result (the value returned by ``query_graphql``);
            may be ``None`` when the query failed.
        user_input: The original natural-language question.
        model: Chat-completion model name; defaults to the 16k-context model.

    Returns:
        The model's answer text.
    """
    # Serialize as real JSON: the prompt calls this a "json response", but formatting a
    # dict directly yields Python syntax (single quotes, None/True). ensure_ascii=False
    # keeps non-ASCII names readable; default=str guards against non-JSON values.
    payload = json.dumps(hits_list, ensure_ascii=False, default=str)
    response = openai.ChatCompletion.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are an helpful assistant."},
            {"role": "user", "content": "Given the following json response, answer the question as a list. Json Response:\n{}\n\nQuestion:\n{}".format(payload, user_input)},
        ],
        temperature=0,
        max_tokens=1024,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=["###"],
    )
    return response.choices[0].message['content']

In [4]:
# Natural-language questions used to evaluate each query-generation strategy below.
user_inputs = [
    "Find the top 2 diseases linked to HBB",
    "What are the top 3 diseases associated with ABCA4?",
    "What are the targets of vorinostat?",
    "Find drugs that are used for treating ulcerative colitis.",
    "Which diseases are associated with the genes targetted by fasudil?",
    "Show all the diseases that have at least 5 pathways associated with Alzheimer",
]


In [5]:
def _read_text_file(path):
    """Return the full contents of the text file at ``path``."""
    with open(path, "r") as handle:
        return handle.read()


# Example query used as the few-shot prompt template
prompt_template = _read_text_file("graphql_schema.txt")
# Full GraphQL schema, used to ground query generation
schema = _read_text_file("graphql_real_schema.txt")

# Opening of a target query, meant for priming completion
# (NOTE(review): not referenced by any of the cells shown here).
prime_prompt = "query top_n_associated_diseases {\n  search(queryString:"

## Let's start with the simplest approach

In [6]:
# Baseline: one worked example in the prompt, no schema grounding.
for user_input in user_inputs:
    print('-' * 100)
    print('USER INPUT: {}\n'.format(user_input))
    prompt = (
        "Given the following example, create the GraphQL query to answer the following user input.\n\n"
        "Example: \n{}\n\nUser Input: {}\n\nGenerated Query:\n"
    ).format(prompt_template, user_input)
    response = query_string = query_chatgpt(prompt)
    print('QUERY:\n {}\n\n'.format(query_string))

    # error_message is not consulted here; a failed query simply leaves hits_list as None
    error_message, hits_list = query_graphql(query_string)
    response_text = convert_json_response(hits_list, user_input)
    print('ANSWER:\n{}'.format(response_text))

----------------------------------------------------------------------------------------------------
USER INPUT: Find the top 2 diseases linked to HBB

QUERY:
 query top_n_associated_diseases {
    search(queryString: "HBB", entityNames: "target") {
    hits { id,
           name, 
           entity,
            object {
               ... on Target {
              associatedDiseases(page: {index: 0, size: 2}) {
                 rows {
                  score
                   disease {
                     name
                       }
                    }
                }
            }
        }
     }
   }
 }


ANSWER:
1. Beta-thalassemia HBB/LCRB
2. Sickle cell anemia
----------------------------------------------------------------------------------------------------
USER INPUT: What are the top 3 diseases associated with ABCA4?

QUERY:
 query top_n_associated_diseases {
    search(queryString: "ABCA4", entityNames: "target") {
    hits { id,
           name, 
           entity,

### Analysis: 
This approach works well for the simple queries given in the repo, but quickly falters on complex queries such as `Find drugs that are used for treating ulcerative colitis.` The main problem appears to be that ChatGPT invents fields or uses them incorrectly, because query generation isn't grounded in the database schema.

## Grounding Output in DB Schema

In [7]:
# Same as the baseline, but the full schema is included in the prompt; the
# 16k-context model is used because the schema makes the prompt long.
for user_input in user_inputs:
    print('-' * 100)
    print('USER INPUT: {}\n'.format(user_input))
    prompt = (
        "Given the following schema and example, create the GraphQL query to answer the following user input.\n\n"
        "Schema:\n{}\n\nExample: \n{}\n\nUser Input: {}\n\nGenerated Query:\n"
    ).format(schema, prompt_template, user_input)
    query_string = query_chatgpt_16k(prompt)
    print('QUERY:\n {}\n\n'.format(query_string))

    error_message, hits_list = query_graphql(query_string)
    response_text = convert_json_response(hits_list, user_input)
    print('ANSWER:\n{}'.format(response_text))

----------------------------------------------------------------------------------------------------
USER INPUT: Find the top 2 diseases linked to HBB

QUERY:
 query top_n_linked_diseases {
  search(queryString: "HBB", entityNames: "target") {
    hits {
      id
      name
      entity
      object {
        ... on Target {
          linkedDiseases {
            rows(size: 2) {
              name
            }
          }
        }
      }
    }
  }
}


ERROR:
 Cannot query field 'linkedDiseases' on type 'Target'. Did you mean 'associatedDiseases'? (line 9, column 11):
          linkedDiseases {
          ^


ANSWER:
Unfortunately, without the actual JSON response, I am unable to provide a list of the top 2 diseases linked to HBB. Please provide the JSON response so that I can assist you further.
----------------------------------------------------------------------------------------------------
USER INPUT: What are the top 3 diseases associated with ABCA4?

QUERY:
 query top_n_associ

### Analysis: 
The model performs better on the question `Find drugs that are used for treating ulcerative colitis.` However, its output still doesn't conform perfectly to the schema, which results in errors. This might be because ChatGPT's performance degrades on inputs longer than roughly 8,000 tokens.

## What if we use ChatGPT to create a consolidated schema and generate the query from it?

In [8]:
# Two stages per question: first ask the model to prune the schema down to the
# relevant fields, then generate the query against that pruned schema.
for user_input in user_inputs:
    print('-' * 100)
    print('USER INPUT: {}\n'.format(user_input))

    # Stage 1: consolidate the schema (16k model, since the full schema is in the prompt)
    prompt = (
        "Given the following schema, create the consolidated the GraphQL schema to only include fields that can be used to answer the following user input. "
        "The consolidated schema should only contains fields from the original schema that are useful in the generation of an answer to the user input.\n\n"
        "Schema:\n{}\n\nUser Input:\n{}\n\nConsolidated Schema:\n"
    ).format(schema, user_input)
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-16k",
        messages=[
            {"role": "system", "content": "You are an helpful assistant with extensive knowledge of GraphQL and the Open Targets database"},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
        max_tokens=1024,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=["###"],
    )
    consolidated_schema = response.choices[0].message['content']

    # Stage 2: generate the query from the pruned schema with query_chatgpt
    prompt = (
        "Given the following schema and example, create the GraphQL query to answer the following user input.\n\n"
        "Schema:\n{}\n\nExample: \n{}\n\nUser Input: {}\n\nGenerated Query:\n"
    ).format(consolidated_schema, prompt_template, user_input)
    query_string = query_chatgpt(prompt)
    print('QUERY:\n {}\n\n'.format(query_string))

    error_message, hits_list = query_graphql(query_string)
    response_text = convert_json_response(hits_list, user_input)
    print('ANSWER:\n{}'.format(response_text))

----------------------------------------------------------------------------------------------------
USER INPUT: Find the top 2 diseases linked to HBB

QUERY:
 query top_n_linked_diseases {
  disease(efoId: "HBB") {
    linkedDiseases {
      rows(page: {index: 0, size: 2}) {
        name
      }
    }
  }
}


ERROR:
 Cannot query field 'linkedDiseases' on type 'Disease'. (line 3, column 5):
    linkedDiseases {
    ^


ANSWER:
Unfortunately, without the actual JSON response, I am unable to provide you with the top 2 diseases linked to HBB. Please provide the JSON response so that I can assist you further.
Final QUERY: 
Unfortunately, without the actual JSON response, I am unable to provide you with the top 2 diseases linked to HBB. Please provide the JSON response so that I can assist you further.
----------------------------------------------------------------------------------------------------
USER INPUT: What are the top 3 diseases associated with ABCA4?

QUERY:
 query top_n_assoc

KeyboardInterrupt: 

### Analysis
That approach did much worse, so let's roll back the changes.

## Error Correction with LangChain

In [7]:
# Self-correction: generate a query, run it, and feed any API error back into the
# same conversation so the model can repair its own query (at most 10 retries).
for user_input in user_inputs:
    # Fresh chain and memory per question so one question's corrections don't leak into the next
    chat = OpenAI(temperature=0.2, top_p=0.1, openai_api_key=openai.api_key)
    llm = ConversationChain(
        llm=chat,
        memory=ConversationBufferMemory(),
        verbose=False
    )
    print('-'*100)
    print('USER INPUT: {}\n'.format(user_input))
    # create the prompt
    prompt = "Given the following example, create the GraphQL query to answer the following user input. If a number isn't given, return all possible answers. {}\n\nUser Input: {}\n\nGenerated Query:\n".format(prompt_template, user_input)
    query_string = query_langchain(llm, prompt)

    print('INITIAL QUERY:\n {}\n\n'.format(query_string))
    # Run the generated query exactly once before entering the correction loop
    error_message, hits_list = query_graphql(query_string)
    runs = 0
    while error_message and runs < 10:
        prompt = "The query failed with the following error: {}\n Can you regenerate the query to fix the error? \nRegenerated Query:\n".format(error_message)
        query_string = query_langchain(llm, prompt)
        error_message, hits_list = query_graphql(query_string)
        runs += 1

    response_text = convert_json_response(hits_list, user_input)
    print(hits_list)
    print('ANSWER:\n{}'.format(response_text))

----------------------------------------------------------------------------------------------------
USER INPUT: Find the top 2 diseases linked to HBB

INITIAL QUERY:
  query top_n_associated_diseases {
    search(queryString: "HBB", entityNames: "target") {
    hits { id,
           name, 
           entity,
            object {
               ... on Target {
              associatedDiseases(page: {index: 0, size: 2}) {
                 rows {
                  score
                   disease {
                     name
                       }
                    }
                }
            }
        }
     }
   }
 }
}


ERROR:
 {'syntaxError': "Syntax error while parsing GraphQL query. Invalid input '}', expected ExecutableDefinition or TypeSystemDefinition (line 21, column 1):\n}\n^", 'locations': [{'line': 21, 'column': 1}]}


ERROR:
 {'syntaxError': "Syntax error while parsing GraphQL query. Invalid input '}', expected ExecutableDefinition or TypeSystemDefinition (line 21, c

### Analysis
The basic queries work, and the model does try to correct its errors: the API error changes with every iteration. This is promising and better than the earlier approaches, but the output is definitely wrong for some questions. For example, the last question, `Show all the diseases that have at least 5 pathways associated with Alzheimer`, returns `Alzheimer disease` and gene targets. The output for `Find drugs that are used for treating ulcerative colitis.` also differs from previous attempts.

## Error Correction with LangChain with Schema Grounding

In [9]:
# Same self-correction loop, but the initial prompt also carries the full schema
# and the chain runs on the 16k-context model.
for user_input in user_inputs:
    chat = OpenAI(
        temperature=0.1,
        top_p=0.1,
        openai_api_key=openai.api_key,
        model_name='gpt-3.5-turbo-16k',
    )
    llm = ConversationChain(llm=chat, memory=ConversationBufferMemory(), verbose=False)
    print('-' * 50)
    print('USER INPUT: {}\n'.format(user_input))
    prompt = (
        "Given the following schema and example, create the GraphQL query to answer the following user input. "
        "If a number isn't given, return all possible answers.\n\n"
        "Schema:\n{}\n\nExample: \n{}\n\nUser Input: {}\n\nGenerated Query:\n"
    ).format(schema, prompt_template, user_input)

    query_string = query_langchain(llm, prompt)
    print('INITIAL QUERY:\n {}\n\n'.format(query_string))
    error_message, hits_list = query_graphql(query_string)

    # Feed API errors back into the conversation; give up after 10 regenerations.
    runs = 0
    while error_message:
        prompt = (
            "The query failed with the following error: {}\n "
            "Can you regenerate the query to fix the error? \nRegenerated Query:\n"
        ).format(error_message)
        query_string = query_langchain(llm, prompt)
        error_message, hits_list = query_graphql(query_string)
        runs += 1
        if runs == 10:
            error_message = None
            break

    response_text = convert_json_response(hits_list, user_input)
    print(hits_list)
    print('ANSWER:\n{}'.format(response_text))

--------------------------------------------------
USER INPUT: Find the top 2 diseases linked to HBB

INITIAL QUERY:
 query top_n_linked_diseases {
    search(queryString: "HBB", entityNames: "target") {
    hits { id,
           name, 
           entity,
            object {
               ... on Target {
              linkedDiseases {
                 count
                 rows(page: {index: 0, size: 2}) {
                  name
                }
              }
            }
        }
     }
   }
 }


ERROR:
 Cannot query field 'linkedDiseases' on type 'Target'. Did you mean 'associatedDiseases'? (line 8, column 15):
              linkedDiseases {
              ^


ERROR:
 Unknown argument 'page' on field 'rows' of type 'AssociatedDiseases'. (line 10, column 23):
                 rows(page: {index: 0, size: 2}) {
                      ^


ERROR:
 Cannot query field 'name' on type 'AssociatedDisease'. (line 11, column 19):
                  name
                  ^


{'id': 'ENSG000

### Analysis
While grounding helps the model, it doesn't help it correct its errors as effectively as hoped. More work could be done on generating correction suggestions from the schema.

## Schema Informed Suggestions

In [10]:
# Self-correction with schema-informed hints: for non-syntax errors, a 16k-context
# model first writes a natural-language fix suggestion citing schema fields, which
# is passed to the conversation together with the error (at most 20 retries).
for user_input in user_inputs:
    # Fresh chain and memory per question so one question's corrections don't leak into the next
    chat = OpenAI(temperature=0.2, top_p=0.1, openai_api_key=openai.api_key)
    llm = ConversationChain(
        llm=chat,
        memory=ConversationBufferMemory(),
        verbose=False
    )
    print('-'*100)
    print('USER INPUT: {}\n'.format(user_input))
    # create the prompt
    prompt = "Given the following example, create the GraphQL query to answer the following user input. If a number isn't given, return all possible answers. {}\n\nUser Input: {}\n\nGenerated Query:\n".format(prompt_template, user_input)
    query_string = query_langchain(llm, prompt)

    print('INITIAL QUERY:\n {}\n\n'.format(query_string))
    error_message, hits_list = query_graphql(query_string)
    runs = 0

    while error_message and runs < 20:
        # error_message is either a str (substring test) or the raw error dict (key test)
        if 'syntaxError' not in error_message:
            # If it's not a syntax error, ask GPT for a schema-grounded suggestion first.
            # The newline before "Suggestion:" keeps it from being glued to the error text.
            prompt = "Given a query, error message and schema, generate a suggestion on how the query can be improved in natural language citing specific fields within the schema. Generate only a suggestion, not a query.\nSchema:\n{} \nQuery:\n{}\nError Message:\n{}\nSuggestion:\n".format(schema, query_string, error_message)
            suggestion = query_chatgpt_16k(prompt)
            prompt = "The query failed with the following error: {}\n Here is a suggestion on how to correct the query. {} Can you regenerate the query to fix the error? \nRegenerated Query:\n".format(error_message, suggestion)
        else:
            # suggestions on syntax errors confuse the model
            prompt = "The query failed with the following error: {}\n Can you regenerate the query to fix the error? \nRegenerated Query:\n".format(error_message)
        query_string = query_langchain(llm, prompt)
        error_message, hits_list = query_graphql(query_string)
        runs += 1

    response_text = convert_json_response(hits_list, user_input)
    print(hits_list)
    print('ANSWER:\n{}'.format(response_text))

----------------------------------------------------------------------------------------------------
USER INPUT: Find the top 2 diseases linked to HBB

INITIAL QUERY:
  query top_n_associated_diseases {
    search(queryString: "HBB", entityNames: "target") {
    hits { id,
           name, 
           entity,
            object {
               ... on Target {
              associatedDiseases(page: {index: 0, size: 2}) {
                 rows {
                  score
                   disease {
                     name
                       }
                    }
                }
            }
        }
     }
   }
 }
}


ERROR:
 {'syntaxError': "Syntax error while parsing GraphQL query. Invalid input '}', expected ExecutableDefinition or TypeSystemDefinition (line 21, column 1):\n}\n^", 'locations': [{'line': 21, 'column': 1}]}


{'id': 'ENSG00000244734', 'name': 'HBB', 'entity': 'target', 'object': {'associatedDiseases': {'rows': [{'score': 0.793993706260743, 'disease': {'name'