In [1]:
import sys
import os
from neo4j import GraphDatabase


sys.path.append(os.path.abspath("."))


def load_config():
    cfg = {}
    with open("config.txt", "r") as f:
        for line in f:
            if "=" in line:
                k, v = line.strip().split("=")
                cfg[k] = v
    return cfg

config = load_config()

URI = config["URI"]
USERNAME = config["USERNAME"]
PASSWORD = config["PASSWORD"]



In [2]:
from graph_retriever import GraphRetriever


In [3]:
gr = GraphRetriever(URI, USERNAME, PASSWORD)

print("GraphRetriever initialized.")


GraphRetriever initialized.


## Baseline Retrieval Example

In [4]:
result = gr.retrieve("flights_from", {"origin": "MEX"})
result


[{'flight': '4709', 'origin': 'MEX'},
 {'flight': '1839', 'origin': 'MEX'},
 {'flight': '5565', 'origin': 'MEX'},
 {'flight': '475', 'origin': 'MEX'},
 {'flight': '429', 'origin': 'MEX'},
 {'flight': '2330', 'origin': 'MEX'},
 {'flight': '2253', 'origin': 'MEX'},
 {'flight': '62', 'origin': 'MEX'},
 {'flight': '445', 'origin': 'MEX'},
 {'flight': '1065', 'origin': 'MEX'},
 {'flight': '1090', 'origin': 'MEX'},
 {'flight': '833', 'origin': 'MEX'}]

In [5]:
result = gr.retrieve("flights_to", {"destination": "IAX"})
result


[{'flight': '500', 'destination': 'IAX'},
 {'flight': '128', 'destination': 'IAX'},
 {'flight': '298', 'destination': 'IAX'},
 {'flight': '1469', 'destination': 'IAX'},
 {'flight': '614', 'destination': 'IAX'},
 {'flight': '6131', 'destination': 'IAX'},
 {'flight': '1046', 'destination': 'IAX'},
 {'flight': '880', 'destination': 'IAX'},
 {'flight': '1496', 'destination': 'IAX'},
 {'flight': '6112', 'destination': 'IAX'},
 {'flight': '984', 'destination': 'IAX'},
 {'flight': '1282', 'destination': 'IAX'},
 {'flight': '1902', 'destination': 'IAX'},
 {'flight': '1891', 'destination': 'IAX'},
 {'flight': '1414', 'destination': 'IAX'},
 {'flight': '2244', 'destination': 'IAX'},
 {'flight': '780', 'destination': 'IAX'},
 {'flight': '2373', 'destination': 'IAX'},
 {'flight': '6315', 'destination': 'IAX'},
 {'flight': '1930', 'destination': 'IAX'},
 {'flight': '5', 'destination': 'IAX'},
 {'flight': '1765', 'destination': 'IAX'},
 {'flight': '462', 'destination': 'IAX'},
 {'flight': '253', 'de

In [8]:
result = gr.retrieve("journey_flight", {"feedback_id": "F_1"})
result



[{'journey_id': 'F_1', 'flight': '2411'}]

In [9]:
result = gr.retrieve("flights_between",
                     {"origin": "MEX", "destination": "IAX"})
result


[{'flight': '1839'}, {'flight': '429'}, {'flight': '2330'}, {'flight': '1090'}]

## Train Node2Vec Embeddings

In [10]:
print("Training Node2Vec embeddings...")
output = gr.build_embeddings(method="node2vec")
output


Training Node2Vec embeddings...


Computing transition probabilities:   0%|          | 0/7944 [00:00<?, ?it/s]

Generating walks (CPU: 1): 100%|██████████| 25/25 [00:00<00:00, 35.79it/s]
Generating walks (CPU: 2): 100%|██████████| 25/25 [00:00<00:00, 36.07it/s]


"Embedding model 'node2vec_embed' trained and indexed."

## Train GraphSAGE Embeddings

In [11]:
print("Training GraphSAGE embeddings...")
output = gr.build_embeddings(method="graphsage")
output


Training GraphSAGE embeddings...


"Embedding model 'sage_embed' trained and indexed."

## Similarity Search Example

In [12]:
with gr.driver.session() as session:
    eid = session.run("""
        MATCH (p:Passenger)
        RETURN elementId(p) AS eid
        LIMIT 1
    """).single()["eid"]

eid


'4:5db1414e-a7ed-4877-a4a1-3456bfe83c2f:0'

In [13]:
print("Node2Vec similarity search (Passenger):")

result = gr.retrieve(
    "similar_nodes",
    {
        "embedding_name": "node2vec_embed",
        "label": "Passenger",
        "node_eid": eid,
        "k": 5
    }
)

result


Node2Vec similarity search (Passenger):


[{'id': '4:5db1414e-a7ed-4877-a4a1-3456bfe83c2f:0',
  'score': 0.9999357461929321},
 {'id': '4:5db1414e-a7ed-4877-a4a1-3456bfe83c2f:3567',
  'score': 0.9911239147186279},
 {'id': '4:5db1414e-a7ed-4877-a4a1-3456bfe83c2f:3139',
  'score': 0.990331768989563},
 {'id': '4:5db1414e-a7ed-4877-a4a1-3456bfe83c2f:3017',
  'score': 0.9903204441070557},
 {'id': '4:5db1414e-a7ed-4877-a4a1-3456bfe83c2f:7099',
  'score': 0.9902714490890503}]

In [14]:
print("GraphSAGE similarity search (Passenger):")

result = gr.retrieve(
    "similar_nodes",
    {
        "embedding_name": "sage_embed",
        "label": "Passenger",
        "node_eid": eid,
        "k": 5
    }
)

result



GraphSAGE similarity search (Passenger):


[{'id': '4:5db1414e-a7ed-4877-a4a1-3456bfe83c2f:0',
  'score': 0.9997367858886719},
 {'id': '4:5db1414e-a7ed-4877-a4a1-3456bfe83c2f:8241',
  'score': 0.9977693557739258},
 {'id': '4:5db1414e-a7ed-4877-a4a1-3456bfe83c2f:7059',
  'score': 0.9976536631584167},
 {'id': '4:5db1414e-a7ed-4877-a4a1-3456bfe83c2f:4608',
  'score': 0.9973176121711731},
 {'id': '4:5db1414e-a7ed-4877-a4a1-3456bfe83c2f:2454',
  'score': 0.997127890586853}]