# Class objects: Neo4j access, OpenAI function calling, and conversation printing

In [326]:
import getpass
import json
import os

import openai
import requests
from neo4j import GraphDatabase
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored

# Connection settings. Secrets come from the environment (or an interactive
# prompt) instead of being hardcoded in the notebook; the password and API key
# that were previously written here should be treated as leaked and rotated.
host = os.environ.get("NEO4J_URI", "neo4j://ec2-107-22-100-129.compute-1.amazonaws.com:7687/")
user = os.environ.get("NEO4J_USER", "mingran")
password = os.environ.get("NEO4J_PASSWORD") or getpass.getpass("Neo4j password: ")

graph = GraphDatabase.driver(host, auth=(user, password))
# Creating the driver does not by itself prove the server is reachable;
# verify_connectivity() opens a connection and raises if the server is
# unreachable or rejects the credentials, so the message below is truthful.
graph.verify_connectivity()
print('connect to neo4j successfully')


GPT_MODEL = "gpt-3.5-turbo-0613"
openai_api_key = os.environ.get("OPENAI_API_KEY") or getpass.getpass("OpenAI API key: ")
# Database.ask_relevant_questions calls the openai package directly (not the raw
# REST request in OpenAIRequest), so the package needs the key as well.
openai.api_key = openai_api_key

Failed to write data to connection IPv4Address(('ec2-107-22-100-129.compute-1.amazonaws.com', 7687)) (ResolvedIPv4Address(('107.22.100.129', 7687)))


connect to neo4j successfully


In [327]:
def get_database_info(driver):
    """Describe the Neo4j schema reachable through `driver`, for use in prompts.

    Mirrors ``Database.get_database_info`` (defined in a later cell) so that this
    cell also runs on a fresh kernel, where neither that class nor any earlier
    module-level ``get_database_info`` exists yet.

    Args:
        driver: A Neo4j driver (e.g. the ``graph`` created in the config cell).

    Returns:
        A one-element list holding a dict with keys ``node_type`` (list of
        labels), ``properties`` (list of property keys) and ``relationships``
        (list of ``[first_node, relationship_type, second_node]``).
    """
    with driver.session() as session:
        labels = [record['label'] for record in session.run('Call db.labels()')]
        property_keys = [record['propertyKey'] for record in session.run("""CALL db.propertyKeys()""")]
        relationships = [
            [record['first_node'], record['relationship_type'], record['second_node']]
            for record in session.run(
                """MATCH (n)-[r]->(m)
                   RETURN DISTINCT head(labels(n)) AS first_node,
                   type(r) AS relationship_type,
                   head(labels(m)) AS second_node"""
            )
        ]
    return [{"node_type": labels, "properties": property_keys, "relationships": relationships}]


# Flatten the schema description into the text embedded in the ask_database prompt.
database_schema_dict = get_database_info(graph)
database_schema_string = "\n".join(
    f"Node: {keys['node_type']}\nproperties: {', '.join(keys['properties'])} \nrelationship: {keys['relationships']}\n"
    for keys in database_schema_dict
)


# Function specs advertised to the chat model (the OpenAI "functions" parameter).
# Each name matches a method on `Database` that
# `OpenAIRequest.execute_function_call` dispatches to.
functions = [
    {
        "name": "ask_database",
        # Describes this gene/disease/drug Cypher database (the system prompt's
        # domain), not a music SQL database.
        "description": "Use this function to answer user questions about genes, diseases and drugs. Output should be a fully formed Cypher query.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                            Cypher query extracting info to answer the user's question.
                            Cypher query should be written using this database schema:
                            {database_schema_string}
                            The query should be returned in plain text, not in JSON.
                            """,
                }
            },
            "required": ["query"],
        },
    },
    {
        "name": "ask_relevant_questions",
        "description": "Use this function to ask the user relevant questions to clarify their request.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": """
                            Ask relevant question to the user's question. Make 
                            sure to include the relevant information from the user's question.
                            The query should be returned in plain text, not in JSON.
                            """,
                }
            },
            "required": ["query"],
        },
    },
]

In [328]:
import requests
from termcolor import colored
# from retry import retry, wait_random_exponential, stop_after_attempt
from neo4j import GraphDatabase

class Database:
    """Neo4j access plus an OpenAI chat model, backing the model-callable functions.

    ``ask_database`` and ``ask_relevant_questions`` are the handlers for the
    functions advertised to the chat model; the ``get_*`` helpers describe the
    graph schema so it can be embedded in prompts.
    """

    def __init__(self, model, host, user, password):
        """Create the Neo4j driver and remember the chat model name.

        Args:
            model: OpenAI chat model name used by ``ask_relevant_questions``.
            host: Neo4j connection URI.
            user: Neo4j user name.
            password: Neo4j password.
        """
        self.graph = GraphDatabase.driver(host, auth=(user, password))
        self.model = model

    def close(self):
        """Close the underlying Neo4j driver."""
        self.graph.close()

    def _run_column(self, query, column):
        """Run `query` in a new session and return ``record[column]`` for every record."""
        with self.graph.session() as session:
            # Consume the result inside the session, before it is closed.
            return [record[column] for record in session.run(query)]

    def get_node_info(self):
        """Return the list of node labels reported by ``db.labels()``."""
        return self._run_column('Call db.labels()', 'label')

    def get_relationship_info(self):
        """Return the distinct relationship patterns in the graph.

        Each entry is ``[first_node, relationship_type, second_node]``, where the
        node entries are the first label of the start and end node. The query
        matches every relationship, so it scans the whole graph.
        """
        query = """MATCH (n)-[r]->(m)
                   RETURN DISTINCT head(labels(n)) AS first_node,
                   type(r) AS relationship_type,
                   head(labels(m)) AS second_node"""
        with self.graph.session() as session:
            return [
                [record['first_node'], record['relationship_type'], record['second_node']]
                for record in session.run(query)
            ]

    def get_node_properties(self):
        """Return the property keys reported by ``db.propertyKeys()``."""
        return self._run_column("""CALL db.propertyKeys()""", 'propertyKey')

    def get_database_info(self):
        """Describe the graph schema.

        Returns:
            A one-element list holding a dict with keys ``node_type`` (list of
            labels), ``properties`` (list of property keys) and ``relationships``
            (list of ``[first_node, relationship_type, second_node]``).
        """
        return [{
            "node_type": self.get_node_info(),
            "properties": self.get_node_properties(),
            "relationships": self.get_relationship_info(),
        }]

    def ask_database(self, query):
        """Run a Cypher `query` and return the first column of each record, as a string.

        Only the first returned column is kept. Any error raised while running
        the query is returned as ``"query failed with error: ..."`` instead of
        being raised, so it ends up in the conversation as the function result.
        """
        try:
            with self.graph.session() as session:
                results = [record.values()[0] for record in session.run(query)]
        except Exception as e:
            results = f"query failed with error: {e}"
        return str(results)

    def ask_relevant_questions(self, query):
        """Ask the chat model to write a clarifying question about `query`.

        Returns:
            The model's reply text, or ``"query failed with error: ..."`` if the
            API call raises.
        """
        # NOTE(review): relies on the openai package already being configured with
        # an API key (openai.api_key or OPENAI_API_KEY); this class never sets it.
        try:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=[
                    {
                        "role": "user",
                        "content": f"""Write a relevant question to the user's input question.
                            User question: {query}
                            The question should be returned in plain text, not in JSON.""",
                    }
                ],
                temperature=0,
            )
        except Exception as e:
            # Return the error text directly; subscripting it like an API
            # response below would raise TypeError.
            return f"query failed with error: {e}"
        return response['choices'][0]['message']['content']

class OpenAIRequest:
    """Minimal client for the OpenAI Chat Completions REST endpoint with function calling."""

    def __init__(self, openai_api_key, model):
        """Store the API key and the chat model name used for every request."""
        self.openai_api_key = openai_api_key
        self.model = model

    # reraise=True propagates the last underlying exception (e.g. requests.HTTPError)
    # after the final attempt instead of wrapping it in tenacity's RetryError.
    @retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3), reraise=True)
    def chat_completion_request(self, messages, functions=None, function_call=None, timeout=60):
        """POST a chat completion request and return the ``requests.Response``.

        Args:
            messages: Chat history, a list of message dicts.
            functions: Optional list of function specs the model may call.
            function_call: Optional ``{"name": ...}`` forcing a specific function.
            timeout: Seconds to wait for the HTTP response.

        Raises:
            requests.RequestException: on a network failure or a non-2xx status,
                after up to 3 attempts with randomized exponential backoff.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.openai_api_key,
        }
        json_data = {"model": self.model, "messages": messages}
        if functions is not None:
            json_data["functions"] = functions
        if function_call is not None:
            json_data["function_call"] = function_call
        try:
            response = requests.post(
                "https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=json_data,
                timeout=timeout,
            )
            # Turn 4xx/5xx responses (rate limits, server errors, bad key) into
            # exceptions so @retry can act on them.
            response.raise_for_status()
        except Exception as e:
            print("Unable to generate ChatCompletion response")
            print(f"Exception: {e}")
            # Re-raise: returning the exception would disable @retry and hand the
            # caller an object without .json().
            raise
        return response

    def execute_function_call(self, message, db):
        """Dispatch the model's ``function_call`` in `message` to the matching `db` method.

        Args:
            message: Assistant message dict whose ``function_call`` holds a
                ``name`` and JSON-encoded ``arguments``.
            db: Object providing ``ask_database(query)`` and
                ``ask_relevant_questions(query)`` (e.g. ``Database``).

        Returns:
            The handler's result, or an ``"Error: ..."`` string when the function
            name is unknown or its arguments are not JSON with a ``query`` key.
        """
        name = message["function_call"]["name"]
        if name not in ("ask_database", "ask_relevant_questions"):
            return f"Error: function {name} does not exist"
        # Arguments are generated by the model and are not guaranteed to be valid JSON.
        try:
            query = json.loads(message["function_call"]["arguments"])["query"]
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            return f"Error: invalid arguments for function {name}: {e}"
        return getattr(db, name)(query)

class MessageHandler:
    """Console pretty-printer for chat histories."""

    def pretty_print_conversation(self, messages):
        """Print each message of `messages` in a role-specific color.

        Handles ``system``, ``user``, ``assistant`` (plain reply or function call)
        and ``function`` messages; messages with any other role are skipped.
        Each entry is printed as ``"<role>: <text>\\n"``, so consecutive messages
        are separated by a blank line.
        """
        role_to_color = {
            "system": "red",
            "user": "green",
            "assistant": "blue",
            "function": "magenta",
        }
        for message in messages:
            role = message["role"]
            if role in ("system", "user"):
                text = f"{role}: {message['content']}\n"
            elif role == "assistant" and message.get("function_call"):
                text = f"assistant: {message['function_call']}\n"
            elif role == "assistant":
                text = f"assistant: {message['content']}\n"
            elif role == "function":
                text = f"function ({message['name']}): {message['content']}\n"
            else:
                continue  # unhandled role: not printed
            # Take the color from this message's own role. Looking it up by the
            # formatted text's position misaligns colors once a message is skipped.
            print(colored(text, role_to_color[role]))

In [336]:
# One round of function calling: force the model to call `ask_database`, run the
# generated Cypher against Neo4j, then print the whole exchange.
model = GPT_MODEL

db = Database(model, host, user, password)
openai_req = OpenAIRequest(openai_api_key, model)
msg_handler = MessageHandler()

messages = [
    {"role": "system", "content": "Answer user questions by generating cypher queries against the gene disease drug databases."},
    {"role": "user", "content": "What genes downregulate APC?"},
]
chat_response = openai_req.chat_completion_request(messages, functions, function_call={"name": "ask_database"})
# Surface API error responses (bad key, rate limit, ...) as an HTTPError rather
# than an opaque KeyError on the missing "choices" field.
chat_response.raise_for_status()
assistant_message = chat_response.json()["choices"][0]["message"]
messages.append(assistant_message)

if assistant_message.get("function_call"):
    results = openai_req.execute_function_call(assistant_message, db)
    messages.append({"role": "function", "name": assistant_message["function_call"]["name"], "content": results})

msg_handler.pretty_print_conversation(messages)

[31msystem: Answer user questions by generating cypher queries against the gene disease drug databases.
[0m
[32muser: What genes downregulate APC?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "MATCH (g:GeneProtein)-[:DOWNREGULATES]->(t:GeneProtein {name: \'APC\'}) RETURN g.name"\n}'}
[0m
[35mfunction (ask_database): ['DVL1', 'PRKACA']
[0m
