In [None]:
import pathlib
import os
from pyvis import network as net
from dotenv import load_dotenv
import neo4j
from pprint import pprint

# Pull connection settings from db_config.env in the current working
# directory; the cells below read NEO4J_URI, NEO4J_USERNAME and
# NEO4J_PASSWORD from the environment, so that file is expected to define them.
env_path = pathlib.Path().resolve().joinpath('db_config.env')
load_dotenv(dotenv_path=env_path)

In [None]:
URI = os.environ.get("NEO4J_URI")
# (username, password) pair handed to the driver as basic-auth credentials;
# either entry is None when the variable is unset.
AUTH = tuple(
    os.environ.get(var_name) for var_name in ("NEO4J_USERNAME", "NEO4J_PASSWORD")
)


def find_shortest_path(from_person, to_person):
    """Return the shortest station route between two people's nearby stations.

    Each person is matched by ``name`` and linked to a ``Station`` through an
    outgoing ``LIVES_NEAR`` relationship. The route between the two stations
    follows ``CONNECTS_TO`` relationships in either direction.

    Args:
        from_person: ``name`` of the starting ``Person`` node.
        to_person: ``name`` of the destination ``Person`` node.

    Returns:
        The single result record with fields ``from_person`` and
        ``to_person`` (the Person nodes) and ``p`` (the station path), or
        ``None`` (after printing a message) when the query returns no rows.
    """
    # A person may live near several stations, which makes the MATCH produce
    # one row per station pair. ORDER BY + LIMIT keeps the overall shortest
    # route instead of whichever row the database happens to return first.
    # NOTE(review): Neo4j may reject shortestPath when sa and sb are the same
    # node (both people near one station) -- confirm against server config.
    QUERY = """
        MATCH (from_person:Person {name: $from_person})-[:LIVES_NEAR]->(sa:Station),
              (to_person:Person {name: $to_person})-[:LIVES_NEAR]->(sb:Station),
              p = shortestPath((sa)-[:CONNECTS_TO*]-(sb))
        RETURN from_person, to_person, p
        ORDER BY length(p)
        LIMIT 1;
    """
    # A fresh driver is opened (and closed by the with-block) on every call.
    with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
        driver.verify_connectivity()
        # execute_query returns (records, summary, keys); only records are used.
        records, _, _ = driver.execute_query(
            QUERY, from_person=from_person, to_person=to_person
        )
        if not records:
            print(f"No path found between {from_person} and {to_person}")
            return None

        return records[0]


def visualize_path(path, output_file="sample.html"):
    """Render a person-to-person route as an interactive pyvis HTML graph.

    The two people are drawn in orange at either end and the stations in
    blue. Consecutive nodes are joined by directed edges in travel order. The
    node and edge lists are also pretty-printed for inspection.

    Args:
        path: A record shaped like the one ``find_shortest_path`` returns. It
            must expose ``from_person`` and ``to_person`` nodes and a path
            ``p`` whose nodes all carry a ``name`` property.
        output_file: HTML file written by ``Network.show``. The default
            ``"sample.html"`` matches the previous hard-coded name.
    """
    station_names = [station["name"] for station in path["p"].nodes]
    nodes = [
        path["from_person"]["name"],
        *station_names,
        path["to_person"]["name"],
    ]
    # One colour per node, in the same order as ``nodes``.
    colors = ["orange", *(["blue"] * len(station_names)), "orange"]

    g = net.Network(cdn_resources="remote", directed=True)
    # NOTE(review): node ids are the ``name`` values, so a person and a
    # station sharing a name would presumably be merged into one node.
    g.add_nodes(nodes, color=colors)

    # Materialised as a list so the same edges can be both added and printed.
    edges = list(zip(nodes[:-1], nodes[1:]))
    g.add_edges(edges)

    pprint(nodes)
    pprint(edges)
    g.show(output_file, notebook=False)


# Demo: look up the route from Alison to Bob, show the raw record, and draw
# it only when the query actually found one.
path = find_shortest_path("Alison", "Bob")
pprint(path)
if path is not None:
    visualize_path(path)