In [None]:
%pip install neo4j
%pip install graphdatascience
%pip install numpy

In [1]:
import datetime
import getpass
import os
import random

import graphdatascience
import numpy as np
from neo4j import GraphDatabase



In [47]:
# Connection settings. Each value can be overridden through an environment
# variable so that credentials do not have to be stored in the notebook.
endpoint = os.environ.get("NEO4J_URI", "neo4j://localhost:7687")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
# The password is never hard-coded: it is read from the environment, and the
# user is prompted interactively when the variable is not set.
password = os.environ.get("NEO4J_PASSWORD") or getpass.getpass("Neo4j password: ")
database = os.environ.get("NEO4J_DATABASE", "neo4j")

# Graph Data Science client (graph handles, FastRP calls).
gds = graphdatascience.GraphDataScience(endpoint=endpoint, auth=(username, password))
gds.set_database(database)

# Plain Neo4j driver for the Cypher statements issued directly by this notebook.
db_driver = GraphDatabase.driver(endpoint, auth=(username, password))

In [None]:
def random_timestamp_for_station(station_id, suffix="start"):
    """Return a reproducible pseudo-random timestamp in 2017 for a station.

    The generator is seeded with ``f"{station_id}_{suffix}"``, so the same
    station and suffix always produce the same timestamp, while different
    suffixes (e.g. "start" vs. "end") produce independent timestamps.

    Args:
        station_id: Identifier of the station; only its string form is used.
        suffix: Label mixed into the seed to distinguish several timestamps
            of the same station.

    Returns:
        A naive ``datetime.datetime`` in the year 2017.
    """
    # A private generator yields the same sequence as seeding the module-level
    # generator with this string, but leaves the global random state untouched.
    rng = random.Random(f"{station_id}_{suffix}")

    year = 2017
    month = rng.randint(1, 12)
    day = rng.randint(1, 28)  # capped at 28 so the date is valid in every month
    hour = rng.randint(0, 23)
    minute = rng.randint(0, 59)
    second = rng.randint(0, 59)

    return datetime.datetime(year, month, day, hour, minute, second)

In [None]:
def write_station_embeddings_batchwise(driver, batch_size=500):
    """Compute start/end time embeddings for every Station node and store them.

    Station ids are streamed from the database. For each station two
    reproducible pseudo-random timestamps are generated
    (``random_timestamp_for_station``), converted with
    ``timestamp_to_embedding`` and written back in batches through
    ``_send_station_embedding_batch``.

    Args:
        driver: Neo4j driver used for reading the ids and writing the batches.
        batch_size: Number of stations collected before a batch is written.
    """
    with driver.session(database=database) as session:
        result = session.run("""
            MATCH (s:Station)
            RETURN id(s) AS station_id
        """)

        batch = []
        count = 0
        for record in result:
            sid = record["station_id"]

            try:
                start_dt = random_timestamp_for_station(sid, "start")
                end_dt = random_timestamp_for_station(sid, "end")

                start_embedding = timestamp_to_embedding(start_dt).tolist()
                end_embedding = timestamp_to_embedding(end_dt).tolist()
            except (AttributeError, TypeError, ValueError, OverflowError, OSError) as e:
                # Only data-related failures are skipped. Programming errors
                # (e.g. a NameError because a helper is not defined yet) must
                # surface instead of silently skipping every station.
                print(f"Skipping station {sid}: {e}")
                continue

            batch.append({
                "station_id": sid,
                "start": start_embedding,
                "end": end_embedding
            })

            if len(batch) >= batch_size:
                _send_station_embedding_batch(driver, batch)
                count += len(batch)
                print(f"{count} stations processed.")
                batch = []

        # Write and report the remaining partial batch as well.
        if batch:
            _send_station_embedding_batch(driver, batch)
            count += len(batch)
            print(f"{count} stations processed.")

def _send_station_embedding_batch(driver, batch):
    """Write one batch of station embeddings to the database.

    Args:
        driver: Neo4j driver used to open the write session.
        batch: List of dicts with the keys ``station_id``, ``start`` and
            ``end``; ``start``/``end`` are stored as ``startTimeEmbedding``
            and ``endTimeEmbedding`` on the matching Station node.
    """
    query = """
    CALL apoc.periodic.iterate(
      'UNWIND $batch AS row RETURN row',
      '
      MATCH (s:Station) WHERE id(s) = row.station_id
      SET s.startTimeEmbedding = row.start,
          s.endTimeEmbedding = row.end
      ',
      {batchSize: 100, parallel: true, params: {batch: $batch}}
    )
    """
    with driver.session(database=database) as session:
        # Fetch the summary row: apoc.periodic.iterate reports failed inner
        # batches in its result instead of raising, so an unread result would
        # hide them.
        summary = session.run(query, batch=batch).single()

    if summary is not None and summary["failedOperations"]:
        print(
            f"Warning: {summary['failedOperations']} station updates failed: "
            f"{summary['errorMessages']}"
        )


# NOTE: timestamp_to_embedding is defined in a later cell of this notebook. That cell
# must have been executed before this call; otherwise the per-station embedding
# computation fails with a NameError.
write_station_embeddings_batchwise(db_driver)

In [26]:
def timestamp_to_embedding(timestamp):
    """Encode a timestamp as an 11-dimensional feature vector.

    Layout of the returned array:
        [hour_sin, hour_cos, weekday_sin, weekday_cos, day_sin, day_cos,
         month_sin, month_cos, is_weekend, day_of_year, unix_scaled]

    Args:
        timestamp: Either an object offering ``to_native()`` (Neo4j temporal
            types) or a native ``datetime.datetime``.

    Returns:
        ``numpy.ndarray`` with 11 entries.
    """
    try:
        d = timestamp.to_native()  # Neo4j temporal types
    except AttributeError:
        d = timestamp  # already a native datetime (station embeddings)

    unix_timestamp = int(d.timestamp())
    # fromtimestamp() without a tz argument yields the local time of the
    # machine running the notebook.
    dt = datetime.datetime.fromtimestamp(unix_timestamp)

    def cyclic(value, period):
        """Map a periodic value onto the unit circle as (sin, cos)."""
        angle = 2 * np.pi * value / period
        return np.sin(angle), np.cos(angle)

    hour_sin, hour_cos = cyclic(dt.hour, 24)
    weekday_sin, weekday_cos = cyclic(dt.weekday(), 7)
    # Day of month: period 31 (longest month). A period of 7 would map days
    # 1, 8, 15, ... onto the same point and merely duplicate a weekly signal.
    day_sin, day_cos = cyclic(dt.day, 31)
    month_sin, month_cos = cyclic(dt.month, 12)

    # weekday(): Monday == 0 ... Saturday == 5, Sunday == 6.
    is_weekend = 1 if dt.weekday() >= 5 else 0

    # NOTE(review): day_of_year is taken from `d`, while the cyclic features use
    # the local-time `dt`; for timezone-aware inputs the two can differ.
    day_of_year = d.timetuple().tm_yday
    unix_scaled = unix_timestamp / 1e9

    return np.array([
        hour_sin, hour_cos,
        weekday_sin, weekday_cos,
        day_sin, day_cos,
        month_sin, month_cos,
        is_weekend, day_of_year, unix_scaled,
    ])

In [27]:
#52 min
def write_embeddings_batchwise(driver, batch_size=500):
    """Compute start/end time embeddings for Trip nodes and store them.

    Streams every Trip that has both ``validFrom`` and ``validTo``, converts
    the two timestamps with ``timestamp_to_embedding`` and writes the results
    in batches through ``_send_embedding_batch``.

    Args:
        driver: Neo4j driver used for reading the trips and writing the batches.
        batch_size: Number of trips collected before a batch is written.
    """
    with driver.session(database=database) as session:
        result = session.run("""
            MATCH (t:Trip)
            WHERE t.validFrom IS NOT NULL AND t.validTo IS NOT NULL
            RETURN id(t) AS node_id, t.validFrom AS startTime, t.validTo AS endTime
        """)

        batch = []
        embeddings_written = 0
        for record in result:
            try:
                start_embedding = timestamp_to_embedding(record["startTime"]).tolist()
                end_embedding = timestamp_to_embedding(record["endTime"]).tolist()
            except (AttributeError, TypeError, ValueError, OverflowError, OSError) as e:
                # Skip trips whose timestamps cannot be converted; programming
                # errors (e.g. NameError) are deliberately not caught here.
                print(f"Skipping node {record['node_id']}: {e}")
                continue

            batch.append({
                "node_id": record["node_id"],
                "externalTimeEmbeddingStart": start_embedding,
                "externalTimeEmbeddingEnd": end_embedding
            })

            if len(batch) >= batch_size:
                _send_embedding_batch(driver, batch)
                # Count the rows actually sent rather than assuming batch_size.
                embeddings_written += len(batch)
                print(f"{embeddings_written} embeddings geschrieben")
                batch = []

        # Write and report the remaining partial batch as well.
        if batch:
            _send_embedding_batch(driver, batch)
            embeddings_written += len(batch)
            print(f"{embeddings_written} embeddings geschrieben")

def _send_embedding_batch(driver, batch):
    """Write one batch of trip embeddings to the database.

    Args:
        driver: Neo4j driver used to open the write session.
        batch: List of dicts with the keys ``node_id``,
            ``externalTimeEmbeddingStart`` and ``externalTimeEmbeddingEnd``;
            the two vectors are stored as ``startTimeEmbedding`` and
            ``endTimeEmbedding`` on the matching Trip node.
    """
    query = """
    CALL apoc.periodic.iterate(
      'UNWIND $batch AS row RETURN row',
      '
      MATCH (t:Trip) WHERE id(t) = row.node_id
      SET t.startTimeEmbedding = row.externalTimeEmbeddingStart,
          t.endTimeEmbedding = row.externalTimeEmbeddingEnd
      ',
      {batchSize: 100, parallel: true, params: {batch: $batch}}
    )
    """
    with driver.session(database=database) as session:
        # Fetch the summary row: apoc.periodic.iterate reports failed inner
        # batches in its result instead of raising, so an unread result would
        # hide them.
        summary = session.run(query, batch=batch).single()

    if summary is not None and summary["failedOperations"]:
        print(
            f"Warning: {summary['failedOperations']} trip updates failed: "
            f"{summary['errorMessages']}"
        )

# Compute and store the time embeddings for all Trip nodes that have validFrom and validTo set.
write_embeddings_batchwise(db_driver)


500 embeddings geschrieben
1000 embeddings geschrieben
1500 embeddings geschrieben
2000 embeddings geschrieben
2500 embeddings geschrieben
3000 embeddings geschrieben
3500 embeddings geschrieben
4000 embeddings geschrieben
4500 embeddings geschrieben
5000 embeddings geschrieben
5500 embeddings geschrieben
6000 embeddings geschrieben
6500 embeddings geschrieben
7000 embeddings geschrieben
7500 embeddings geschrieben
8000 embeddings geschrieben
8500 embeddings geschrieben
9000 embeddings geschrieben
9500 embeddings geschrieben
10000 embeddings geschrieben
10500 embeddings geschrieben
11000 embeddings geschrieben
11500 embeddings geschrieben
12000 embeddings geschrieben
12500 embeddings geschrieben
13000 embeddings geschrieben
13500 embeddings geschrieben
14000 embeddings geschrieben
14500 embeddings geschrieben
15000 embeddings geschrieben
15500 embeddings geschrieben
16000 embeddings geschrieben
16500 embeddings geschrieben
17000 embeddings geschrieben
17500 embeddings geschrieben
18000

# Start

In [49]:
# Cypher projection for the "start" run. It registers a graph named 'externalGraph'
# built from the HAS_START/HAS_END relationships between Trip and Station nodes,
# projects all relationship types as undirected, and exposes each node's
# startTimeEmbedding under the projected property name externalTimeEmbeddingStart.
projection_query_start  = """
MATCH (source)-[r:HAS_START|HAS_END]->(target)
WHERE source:Trip AND target:Station
WITH gds.graph.project(
  'externalGraph',
  source,
  target,
  {
    sourceNodeProperties: source {
      externalTimeEmbeddingStart: source.startTimeEmbedding
    },
    targetNodeProperties: target {
     externalTimeEmbeddingStart: target.startTimeEmbedding
    }},
  {undirectedRelationshipTypes: ['*']}
) AS g
RETURN g.graphName AS graph, g.nodeCount AS nodes, g.relationshipCount AS rels
"""

In [50]:
# Takes about 4 min 40 s.
with db_driver.session(database=database) as session:
    # consume() waits for the projection to finish and raises any server-side
    # error here instead of leaving it to the implicit cleanup on session close.
    session.run(projection_query_start).consume()

In [51]:
# Obtain a client-side handle for the projected graph 'externalGraph'.
G = gds.graph.get("externalGraph")

In [52]:
# Memory estimate for the FastRP run on the start-time features (same settings as the write call).
gds.fastRP.write.estimate(
    G,
    featureProperties=["externalTimeEmbeddingStart"],
    propertyRatio=0.5,
    embeddingDimension=128,
    iterationWeights=[1.0],
    nodeSelfInfluence=1.0,
    randomSeed=42,
    writeProperty="externalEmbeddingStart",
)

requiredMemory                                                  24 GiB
treeView             Memory Estimation: 24 GiB\n|-- algorithm: 24 G...
mapView              {'memoryUsage': '24 GiB', 'name': 'Memory Esti...
bytesMin                                                   26708600512
bytesMax                                                   26708600512
nodeCount                                                     16365505
relationshipCount                                             65480276
heapPercentageMin                                                  0.8
heapPercentageMax                                                  0.8
Name: 0, dtype: object

In [53]:
# Takes about 19 min.
# Run FastRP on the start-time features and write the result to `externalEmbeddingStart`.
gds.fastRP.write(
    G,
    featureProperties=["externalTimeEmbeddingStart"],
    propertyRatio=0.5,
    embeddingDimension=128,
    iterationWeights=[1.0],
    nodeSelfInfluence=1.0,
    randomSeed=42,
    writeProperty="externalEmbeddingStart",
)

 FastRP:   0%|          | 0/100 [00:00<?, ?%/s]

nodeCount                                                         16365505
nodePropertiesWritten                                             16365505
preProcessingMillis                                                      0
computeMillis                                                        28657
writeMillis                                                         305952
configuration            {'writeProperty': 'externalEmbeddingStart', 'r...
Name: 0, dtype: object

In [54]:
# Drop the projection so that the name 'externalGraph' can be reused by the "end" projection below.
G.drop()

graphName                                                    externalGraph
database                                                             neo4j
databaseLocation                                                     local
memoryUsage                                                               
sizeInBytes                                                             -1
nodeCount                                                         16365505
relationshipCount                                                 65480276
configuration            {'readConcurrency': 4, 'undirectedRelationship...
density                                                                0.0
creationTime                           2025-04-18T14:02:08.899277000+02:00
modificationTime                       2025-04-18T14:02:08.899277000+02:00
schema                   {'graphProperties': {}, 'nodes': {'__ALL__': {...
schemaWithOrientation    {'graphProperties': {}, 'nodes': {'__ALL__': {...
Name: 0, dtype: object

# End

In [56]:
# Cypher projection for the "end" run. It registers a graph named 'externalGraph'
# built from the HAS_START/HAS_END relationships between Trip and Station nodes,
# projects all relationship types as undirected, and exposes each node's
# endTimeEmbedding under the projected property name externalTimeEmbeddingEnd.
projection_query_end  = """
MATCH (source)-[r:HAS_START|HAS_END]->(target)
WHERE source:Trip AND target:Station
WITH gds.graph.project(
  'externalGraph',
  source,
  target,
  {
    sourceNodeProperties: source {
      externalTimeEmbeddingEnd: source.endTimeEmbedding
    },
    targetNodeProperties: target {
     externalTimeEmbeddingEnd: target.endTimeEmbedding
    }},
  {undirectedRelationshipTypes: ['*']}
) AS g
RETURN g.graphName AS graph, g.nodeCount AS nodes, g.relationshipCount AS rels
"""

In [57]:
with db_driver.session(database=database) as session:
    # consume() waits for the projection to finish and raises any server-side
    # error here instead of leaving it to the implicit cleanup on session close.
    session.run(projection_query_end).consume()

In [58]:
# Obtain a client-side handle for the projected graph 'externalGraph'.
G = gds.graph.get("externalGraph")

In [59]:
# Memory estimate for the FastRP run on the end-time features (same settings as the write call).
gds.fastRP.write.estimate(
    G,
    featureProperties=["externalTimeEmbeddingEnd"],
    propertyRatio=0.5,
    embeddingDimension=128,
    iterationWeights=[1.0],
    nodeSelfInfluence=1.0,
    randomSeed=42,
    writeProperty="externalEmbeddingEnd",
)

requiredMemory                                                  24 GiB
treeView             Memory Estimation: 24 GiB\n|-- algorithm: 24 G...
mapView              {'memoryUsage': '24 GiB', 'name': 'Memory Esti...
bytesMin                                                   26708600512
bytesMax                                                   26708600512
nodeCount                                                     16365505
relationshipCount                                             65480276
heapPercentageMin                                                  0.8
heapPercentageMax                                                  0.8
Name: 0, dtype: object

In [60]:
# Takes about 16 min.
# Run FastRP on the end-time features and write the result to `externalEmbeddingEnd`.
gds.fastRP.write(
    G,
    featureProperties=["externalTimeEmbeddingEnd"],
    propertyRatio=0.5,
    embeddingDimension=128,
    iterationWeights=[1.0],
    nodeSelfInfluence=1.0,
    randomSeed=42,
    writeProperty="externalEmbeddingEnd",
)

 FastRP:   0%|          | 0/100 [00:00<?, ?%/s]

nodeCount                                                         16365505
nodePropertiesWritten                                             16365505
preProcessingMillis                                                      0
computeMillis                                                        27775
writeMillis                                                         971161
configuration            {'writeProperty': 'externalEmbeddingEnd', 'ran...
Name: 0, dtype: object

In [61]:
# Drop the 'externalGraph' projection; the remaining cells do not use it.
G.drop()

graphName                                                    externalGraph
database                                                             neo4j
databaseLocation                                                     local
memoryUsage                                                               
sizeInBytes                                                             -1
nodeCount                                                         16365505
relationshipCount                                                 65480276
configuration            {'readConcurrency': 4, 'undirectedRelationship...
density                                                                0.0
creationTime                           2025-04-18T14:25:37.376512000+02:00
modificationTime                       2025-04-18T14:25:37.376512000+02:00
schema                   {'graphProperties': {}, 'nodes': {'__ALL__': {...
schemaWithOrientation    {'graphProperties': {}, 'nodes': {'__ALL__': {...
Name: 0, dtype: object

Im nächsten Schritt werden die beiden Embeddings (Start und Ende) gemittelt; anschließend wird ein Vektorindex auf dem Ergebnis angelegt.

In [63]:
# Takes about 9 min.
# Average the two FastRP embeddings of every Trip element-wise into
# `externalIntervalEmbedding`. The filter checks the properties that are actually
# averaged (`externalEmbeddingStart` / `externalEmbeddingEnd`, written by FastRP);
# the `externalTimeEmbedding*` names exist only inside the graph projection.
average_query = """
CALL apoc.periodic.iterate(
  "MATCH (t:Trip) WHERE t.externalEmbeddingStart IS NOT NULL AND t.externalEmbeddingEnd IS NOT NULL RETURN t",
  "WITH t, apoc.coll.zip(t.externalEmbeddingStart, t.externalEmbeddingEnd) AS zipped
   SET t.externalIntervalEmbedding = [pair IN zipped | (pair[0] + pair[1]) / 2.0]",
  {batchSize:10000, parallel:true}
)"""

# Use the configured database, like every other session in this notebook.
with db_driver.session(database=database) as session:
    # Fetch the summary row: apoc.periodic.iterate reports failed inner batches
    # in its result instead of raising.
    average_summary = session.run(average_query).single()

if average_summary is not None and average_summary["failedOperations"]:
    print(
        f"Warning: {average_summary['failedOperations']} averaging operations failed: "
        f"{average_summary['errorMessages']}"
    )

In [64]:
def create_vector_index(index_name, label, property_name, vector_dimension, similarity="cosine"):
    """Create a vector index on ``label.property_name`` if it does not exist yet.

    Args:
        index_name: Name of the index to create.
        label: Node label the index applies to.
        property_name: Node property that holds the vector.
        vector_dimension: Dimension of the indexed vectors; an int or a string
            of digits (converted with ``int``).
        similarity: Value for ``vector.similarity_function``.

    Raises:
        ValueError: If ``vector_dimension`` cannot be converted to an int.
    """
    # The dimension is interpolated into the statement, so make sure it really
    # is a whole number before building the query.
    dimension = int(vector_dimension)
    query = f"""
    CREATE VECTOR INDEX {index_name} IF NOT EXISTS
    FOR (n:{label})
    ON (n.{property_name})
    OPTIONS {{
    indexConfig: {{
        `vector.dimensions`: {dimension},
        `vector.similarity_function`: '{similarity}'
        }}
    }}
    """
    with db_driver.session(database=database) as session:
        # consume() raises any error from the index creation here instead of
        # leaving it to the implicit cleanup on session close.
        session.run(query).consume()


# 128 matches the embeddingDimension used for the FastRP embeddings.
create_vector_index('externalIntervalIndex', 'Trip', 'externalIntervalEmbedding', 128)

In [65]:
# Close the Graph Data Science client and the Neo4j driver created in the configuration cell.
gds.close()
db_driver.close()