In [27]:
!pip install openai neo4j python-dotenv

import os, dotenv

dotenv_file = dotenv.find_dotenv()
dotenv.load_dotenv(dotenv_file)


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


True

In [28]:
# Schema-introspection queries. All three call the APOC procedure
# apoc.meta.data(), so the APOC plugin must be installed on the server.
# Each returns one `output` map per row; Neo4jGPTQuery.generate_schema runs them
# and embeds the results in the LLM system prompt.

# Node labels, each with the collected list of its (non-relationship) properties.
node_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
WITH label AS nodeLabels, collect(property) AS properties
RETURN {labels: nodeLabels, properties: properties} AS output

"""

# Relationship elements, each with the collected list of its properties.
# NOTE(review): the alias `nodeLabels` is returned under the key `type`, so for
# these rows it presumably holds the relationship type, not node labels.
rel_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
WITH label AS nodeLabels, collect(property) AS properties
RETURN {type: nodeLabels, properties: properties} AS output
"""

# Relationship patterns read from node rows whose type is RELATIONSHIP, mapped to
# {source, relationship, target}; `other` presumably lists the target labels.
rel_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE type = "RELATIONSHIP" AND elementType = "node"
RETURN {source: label, relationship: property, target: other} AS output
"""

In [29]:
from neo4j import GraphDatabase
from neo4j.exceptions import CypherSyntaxError
import openai


def schema_text(node_props, rel_props, rels):
    """Render the introspected graph schema as prompt text for the LLM.

    Args:
        node_props: Result of the node-properties query (labels and their properties).
        rel_props: Result of the relationship-properties query.
        rels: Result of the relationship-pattern query (source -> target).

    Returns:
        A multi-line string, each line indented by two spaces, describing the
        schema and reminding the model to respect relationship directions.
    """
    prompt_lines = [
        "",
        "  This is the schema representation of the Neo4j database.",
        "  Node properties are the following:",
        f"  {node_props}",
        "  Relationship properties are the following:",
        f"  {rel_props}",
        "  Relationship point from source to target nodes",
        f"  {rels}",
        "  Make sure to respect relationship types and directions",
        "  ",
    ]
    return "\n".join(prompt_lines)


class Neo4jGPTQuery:
    """Answer natural-language questions over a Neo4j graph with an OpenAI chat model.

    On construction the graph schema is introspected with the module-level
    ``node_properties_query``, ``rel_properties_query`` and ``rel_query`` and
    embedded in the system prompt. That prompt asks the model for Cypher whose
    RETURN clause fits a NeoDash report type. ``run`` sends a question, prints
    the generated Cypher and executes it against the database.

    Requires the legacy (<1.0) ``openai`` package, because it calls
    ``openai.ChatCompletion.create``.
    """

    def __init__(self, url, user, password, openai_api_key, database="neo4j"):
        """Connect to Neo4j, configure the OpenAI key and build the schema prompt.

        Args:
            url: Neo4j connection URI.
            user: Neo4j user name.
            password: Neo4j password.
            openai_api_key: API key assigned to the module-global ``openai.api_key``.
            database: Name of the Neo4j database that every session targets.
        """
        self.driver = GraphDatabase.driver(url, auth=(user, password))
        # The database is selected per session in the Neo4j driver, so remember it
        # here and pass it to every session opened in query_database().
        self.database = database
        # Sets the key on the openai module itself, i.e. process-wide.
        openai.api_key = openai_api_key
        # Introspect the graph once; call refresh_schema() after the model changes.
        self.schema = self.generate_schema()

    def close(self):
        """Close the underlying Neo4j driver and release its connections."""
        self.driver.close()

    def generate_schema(self):
        """Run the three introspection queries and format them with ``schema_text``."""
        node_props = self.query_database(node_properties_query)
        rel_props = self.query_database(rel_properties_query)
        rels = self.query_database(rel_query)
        return schema_text(node_props, rel_props, rels)

    def refresh_schema(self):
        """Re-introspect the database and replace the cached schema text."""
        self.schema = self.generate_schema()

    def get_system_message(self):
        """Build the system prompt: task, NeoDash report-type rules and the schema."""
        return f"""
        Task: Generate Cypher queries to query a Neo4j graph database based on the provided schema definition. These queries will be used inside NeoDash reports.
        Documentation for NeoDash is here : https://neo4j.com/labs/neodash/2.2/
        Instructions:
        Use only the provided relationship types and properties.
        Do not use any other relationship types or properties that are not provided.
        The Cypher RETURN clause must contained certain variables, based on the report type asked for.
        Report types :
        Single Value - A single value of a single variable
        Pie Chart - Two variables named category and value
        Sankey - Three variables, two being a node object (and not a property value) and one representing a relationship object (and not a property value).
        Schema:
        {self.schema}
        """

    def query_database(self, neo4j_query, params=None):
        """Run a Cypher query and return its result as a header-prefixed table.

        Args:
            neo4j_query: Cypher statement to execute.
            params: Optional query parameters; ``None`` means no parameters.
                (``None`` replaces a shared mutable ``{}`` default.)

        Returns:
            A list whose first element is the list of column names, followed by
            one list of values per record, in result order.
        """
        with self.driver.session(database=self.database) as session:
            result = session.run(neo4j_query, params or {})
            rows = [record.values() for record in result]
            return [result.keys()] + rows

    @staticmethod
    def _strip_code_fences(text):
        """Return *text* without a surrounding Markdown code fence, if it has one.

        The notebook's questions ask for "Plain cypher code" to avoid fenced
        replies. This guards against the model adding ```cypher fences anyway,
        which Neo4j would reject as a syntax error. Unfenced text is returned
        unchanged.
        """
        if not text:
            return text
        stripped = text.strip()
        if len(stripped) < 6 or not (stripped.startswith("```") and stripped.endswith("```")):
            return text
        # Drop the closing fence, then split off the opening fence line.
        lines = stripped[:-3].splitlines()
        if len(lines) == 1:
            # Everything on one line: ```MATCH ... RETURN ...```
            return lines[0][3:].strip()
        # The first line is the opening fence plus an optional language tag.
        return "\n".join(lines[1:]).strip()

    def construct_cypher(self, question, history=None):
        """Ask the chat model to translate *question* into a Cypher statement.

        Args:
            question: Natural-language question, sent as the user message.
            history: Optional extra chat messages appended after the question
                (used by the self-healing flow in ``run``).

        Returns:
            The model's reply with any surrounding Markdown code fence removed.
        """
        messages = [
            {"role": "system", "content": self.get_system_message()},
            {"role": "user", "content": question},
        ]
        # Used for Cypher healing flows
        if history:
            messages.extend(history)

        completions = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            temperature=0.0,
            max_tokens=1000,
            messages=messages
        )
        return self._strip_code_fences(completions.choices[0].message.content)

    def run(self, question, history=None, retry=False):
        """Generate Cypher for *question*, print it, execute it and return the rows.

        Args:
            question: Natural-language question (may include report-type hints).
            history: Optional extra chat messages forwarded to ``construct_cypher``.
            retry: When True, a ``CypherSyntaxError`` triggers one more generation
                attempt that shows the model its failing query and the error. That
                attempt runs with ``retry=False``, so at most one retry happens.

        Returns:
            The ``query_database`` table on success, or the string
            ``"Invalid Cypher syntax"`` when the query does not parse and no retry
            is left.
        """
        cypher = self.construct_cypher(question, history)
        print(cypher)
        try:
            return self.query_database(cypher)
        except CypherSyntaxError as e:
            if not retry:
                return "Invalid Cypher syntax"
            # Self-healing flow: feed the failing query and the database error back
            # to the same chat model and ask for a corrected statement.
            print("Retrying")
            error_feedback = (
                f"This query returns an error: {str(e)} \n"
                + " " * 24
                + "Give me a improved query that works without any explanations or apologies"
            )
            # Keep any caller-supplied history so earlier context is not lost.
            healing_history = list(history or []) + [
                {"role": "assistant", "content": cypher},
                {"role": "user", "content": error_feedback},
            ]
            return self.run(question, healing_history, retry=False)


### Demo dataset: my (infamous) Citation database from the Bloom training

In [30]:
# Connection settings come from the .env file loaded in the first cell.
# DATABASE_NAME is optional; without it the client uses "neo4j", the same
# default that Neo4jGPTQuery itself declares.
demo_db = Neo4jGPTQuery(
    url=os.environ["DATABASE_URL"],
    user=os.environ["DATABASE_USER"],
    password=os.environ["DATABASE_PASSWORD"],
    database=os.environ.get("DATABASE_NAME", "neo4j"),
    openai_api_key=os.environ["OPENAI_API_KEY"],
)


In [33]:
# Single Value report: the system prompt asks for one value of one variable.
demo_db.run(
    question=(
        "\n"
        "What is the title of the most cited Article ? "
        "Report type : Single value. Plain cypher code, no explanations."
        "\n"
    )
)

MATCH (a:Article)
RETURN a.title
ORDER BY a.n_citation DESC
LIMIT 1


[['a.title'], ['Seizing Power: Shaders and Storytellers']]

In [25]:
# Pie Chart report: the system prompt asks for `category` and `value` columns.
demo_db.run(
    question=(
        "\n"
        "Number of articles per year ? "
        "Report type : Pie chart. Plain cypher code, no explanations."
        "\n"
    )
)

MATCH (a:Article)
RETURN a.year AS category, count(*) AS value


[['category', 'value'],
 [2006, 7536],
 [2009, 817],
 [2003, 3678],
 [2002, 3612],
 [1998, 1620],
 [1999, 1446],
 [2011, 1162],
 [2005, 5056],
 [1994, 209],
 [2000, 1985],
 [2004, 5064],
 [2001, 2719],
 [1993, 213],
 [1997, 902],
 [2013, 959],
 [2012, 1935],
 [2014, 2040],
 [1991, 138],
 [1996, 243],
 [2008, 607],
 [2010, 812],
 [1990, 162],
 [1976, 118],
 [1978, 115],
 [1992, 129],
 [1983, 83],
 [1960, 53],
 [1973, 93],
 [1988, 113],
 [2015, 2075],
 [1979, 77],
 [1972, 116],
 [1995, 228],
 [1985, 94],
 [1977, 83],
 [1974, 91],
 [1980, 49],
 [1959, 42],
 [1981, 85],
 [1975, 80],
 [2007, 640],
 [1982, 81],
 [1989, 198],
 [1962, 201],
 [1969, 142],
 [1987, 81],
 [1961, 166],
 [1984, 90],
 [2016, 2308],
 [1964, 169],
 [1971, 69],
 [1963, 180],
 [1970, 147],
 [1986, 76],
 [1965, 154],
 [1958, 19],
 [1966, 145],
 [1968, 133],
 [1967, 150],
 [2017, 168]]

In [26]:
# Sankey report: the system prompt asks for two node objects and one relationship object.
demo_db.run(
    question=(
        "\n"
        "Provide a good example of a sankey chart using my data model ? "
        "Report type : Sankey. Plain cypher code, no explanations."
        "\n"
    )
)

MATCH (a1:Author)-[c:CO_AUTHOR]->(a2:Author)
RETURN a1, c, a2


[['a1', 'c', 'a2'],
 [<Node element_id='4:22e65545-2d65-4360-a691-783eae3c69e7:1004' labels=frozenset({'Author'}) properties={'address': POINT(12.531437482332922 47.376173735645594), 'name': 'Tegegne Marew'}>,
  <Relationship element_id='5:22e65545-2d65-4360-a691-783eae3c69e7:3600' nodes=(<Node element_id='4:22e65545-2d65-4360-a691-783eae3c69e7:1004' labels=frozenset({'Author'}) properties={'address': POINT(12.531437482332922 47.376173735645594), 'name': 'Tegegne Marew'}>, <Node element_id='4:22e65545-2d65-4360-a691-783eae3c69e7:1005' labels=frozenset({'Author'}) properties={'address': POINT(24.779568028936517 44.4569792559138), 'name': 'Doo-Hwan Bae'}>) type='CO_AUTHOR' properties={'year': 2006, 'collaborations': 2}>,
  <Node element_id='4:22e65545-2d65-4360-a691-783eae3c69e7:1005' labels=frozenset({'Author'}) properties={'address': POINT(24.779568028936517 44.4569792559138), 'name': 'Doo-Hwan Bae'}>],
 [<Node element_id='4:22e65545-2d65-4360-a691-783eae3c69e7:1006' labels=frozenset({