In [None]:
# Set up the Neo4j driver to connect to the local database.
import os

import pandas as pd
from neo4j import GraphDatabase

class _BufferedResult(object):
    """Records of one Cypher query, copied into memory before the transaction ended.

    Offers the part of the driver's result interface that does not need a live
    transaction: ``keys()``, ``data()``, iteration and ``len()``.
    """

    def __init__(self, keys, records):
        self._keys = list(keys)
        self._records = list(records)

    def keys(self):
        """Return the column names of the query result."""
        return list(self._keys)

    def data(self, *keys):
        """Return every record as a dictionary, optionally restricted to ``keys``."""
        return [record.data(*keys) for record in self._records]

    def __iter__(self):
        return iter(self._records)

    def __len__(self):
        return len(self._records)


class NeoDriver(object):
    """Small convenience wrapper around a Neo4j driver for running Cypher queries."""

    def __init__(self, uri, user, password):
        """Create the underlying driver.

        Args:
            uri: Connection URI of the database, e.g. ``bolt://localhost:7687``.
            user: User name used for authentication.
            password: Password used for authentication.
        """
        self._driver = GraphDatabase.driver(uri, auth=(user, password))
        print("Created Neo4j driver. URI=" + uri)

    def close(self):
        """Close the underlying driver."""
        self._driver.close()

    def query(self, query, **kwargs):
        """Run a Cypher query in a write transaction and return its records.

        Args:
            query: Cypher text to execute.
            **kwargs: Query parameters, referenced in the Cypher text as ``$name``.

        Returns:
            A buffered result. ``.data()`` returns the records as a list of
            dictionaries; ``.keys()``, iteration and ``len()`` are also available.
        """
        with self._driver.session() as session:
            return session.write_transaction(self._run_and_buffer, query, **kwargs)

    @staticmethod
    def run_cypher(tx, query, **kwargs):
        """Run ``query`` on the open transaction ``tx`` and return the driver's raw result.

        The returned object must be read before ``tx`` finishes.
        """
        return tx.run(query, **kwargs)

    @staticmethod
    def _run_and_buffer(tx, query, **kwargs):
        """Run ``query`` on ``tx`` and copy its keys and records into memory."""
        result = NeoDriver.run_cypher(tx, query, **kwargs)
        # The driver's result is only readable while its transaction is open, so
        # everything the caller may need is pulled out here, before returning.
        keys = result.keys()
        records = list(result)
        return _BufferedResult(keys, records)
    
# Connection settings are read from the environment so they can be changed
# without editing the notebook; the fallbacks are the original local values.
# NOTE(review): the fallback password is still in source - drop it once
# NEO4J_PASSWORD is set wherever this notebook runs.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "sunshine")

driver = NeoDriver(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)

In [None]:
# Fetch up to five Product nodes and print each part number with its short description.
result = driver.query("""
MATCH (p:Product) 
RETURN p LIMIT 5
""")

for record in result.data():
    product = record['p']
    print(product['partNumber'] + ": " + product['shortDescription'])

## Example algorithm - similarity between departments

In [None]:
# Name of the in-memory GDS graph. Both queries below receive it as the $name
# parameter so the graph that is dropped is always the one that is re-created.
GRAPH_NAME = 'departments-products'

# This query drops the graph if it already exists, else it does nothing.
driver.query("""
CALL gds.graph.exists($name) YIELD exists
WHERE exists
CALL gds.graph.drop($name) yield graphName
RETURN *
""", name = GRAPH_NAME)


# Create a Cypher projection graph of similar departments (based on shared products).
# Two departments are linked when one product has a HAS_DEPARTMENT relationship
# to each of them; both hops run Product -> Department.
result = driver.query("""
CALL gds.graph.create.cypher(
    $name,
    'MATCH (d:Department) RETURN id(d) as id',
    'MATCH (d:Department)<-[:HAS_DEPARTMENT]-(p:Product)-[:HAS_DEPARTMENT]->(d2:Department) RETURN id(d) AS source, id(d2) AS target')
""", name = GRAPH_NAME)

print(result.data())

In [None]:
# Ask GDS for an estimate of running node similarity in stream mode on the
# projected graph, then print the memory requirement and the graph size.
result = driver.query("""
CALL gds.nodeSimilarity.stream.estimate('departments-products',  { similarityCutoff: 0.5 })
""")

for estimate in result.data():
    print(estimate['requiredMemory'])
    print(estimate['nodeCount'], "nodes")
    print(estimate['relationshipCount'], "rels")

In [None]:
# Stream the node similarity results and show them as a DataFrame.
result = driver.query("""
CALL gds.nodeSimilarity.stream('departments-products', { similarityCutoff: 0.75 })
""")

similarity_rows = result.data()
df = pd.DataFrame(similarity_rows)
print(df)
