# Benchmarking Aura features

In [None]:
import os
import timeit

import numpy as np
import seaborn as sns
from tqdm import tqdm

from graphdatascience.aura_sessions import AuraSessions
from graphdatascience.query_runner.aura_db_arrow_query_runner import AuraDbConnectionInfo

In [None]:
from dotenv import load_dotenv

# Load credentials and instance identifiers from "credentials.env" into the
# process environment.
load_dotenv("credentials.env")

CLIENT_ID = os.environ.get("CLIENT_ID")
CLIENT_SECRET = os.environ.get("CLIENT_SECRET")

DB_PASSWORD = os.environ.get("DB_PASSWORD")
DB_ID = os.environ.get("DB_ID")
AURA_ENV = os.environ.get("AURA_ENV")

# Fail fast with a single, explicit message instead of building a URL that
# contains "None" or hitting a bare KeyError for AURA_ENV.
missing_vars = [
    name
    for name, value in [
        ("CLIENT_ID", CLIENT_ID),
        ("CLIENT_SECRET", CLIENT_SECRET),
        ("DB_PASSWORD", DB_PASSWORD),
        ("DB_ID", DB_ID),
        ("AURA_ENV", AURA_ENV),
    ]
    if not value
]
if missing_vars:
    raise RuntimeError(f"Missing required environment variables: {', '.join(missing_vars)}")

db_connection_info = AuraDbConnectionInfo(
    f"neo4j+s://{DB_ID}-{AURA_ENV}.databases.neo4j-dev.io", ("neo4j", DB_PASSWORD)
)

In [None]:
sessions = AuraSessions(db_connection_info, (CLIENT_ID, CLIENT_SECRET))
session_name = "benchmark-session"
# Prefer a password from the environment; fall back to the previous hardcoded
# default so existing sessions can still be reconnected.
session_pw = os.environ.get("SESSION_PASSWORD", "my-password")

print("Starting GDS session")
# Reuse an existing session with this name, otherwise create a new 24GB one.
if any(session.name == session_name for session in sessions.list_sessions()):
    gds = sessions.connect(session_name, session_pw)
else:
    gds = sessions.create_gds(session_name, session_pw, "24GB")

In [None]:
def measure(func, setup, iterations, warmup_iterations):
    """Time ``func`` with :func:`timeit.repeat` after an unrecorded warm-up phase.

    Parameters
    ----------
    func : callable
        Zero-argument callable whose runtime is measured (one call per repetition).
    setup : callable
        Zero-argument callable that ``timeit`` runs before every repetition
        (warm-up and measured); its runtime is not included in the timings.
    iterations : int
        Number of measured repetitions.
    warmup_iterations : int
        Number of warm-up repetitions run first; their timings are discarded.

    Returns
    -------
    dict
        ``"iterations"``: list of per-repetition runtimes in seconds.
        ``"mean"`` and ``"avg"``: arithmetic mean of those runtimes (the same
        value; both keys are kept for existing callers).
        ``"min"``, ``"max"``, ``"std"``: spread of the runtimes (NaN when
        ``iterations`` is 0).
    """
    pbar = tqdm(total=iterations + warmup_iterations)

    def wrapper():
        value = func()
        # Note: the progress-bar update is inside the timed region; its cost is
        # negligible compared to the database calls being benchmarked.
        pbar.update(1)
        return value

    try:
        pbar.set_description("Warmup")
        timeit.repeat(wrapper, setup=setup, number=1, repeat=warmup_iterations)

        pbar.set_description("Measurement")
        measurement = timeit.repeat(wrapper, setup=setup, number=1, repeat=iterations)
    finally:
        # Always close the bar, even if func/setup raised mid-benchmark.
        pbar.close()

    return {
        "iterations": measurement,
        "mean": np.mean(measurement),
        "avg": np.average(measurement),
        "min": min(measurement, default=float("nan")),
        "max": max(measurement, default=float("nan")),
        "std": np.std(measurement) if measurement else float("nan"),
    }

## Projection

In [None]:
def run_remote_projection(query, concurrency):
    """Project the in-memory graph "graph" from a remote-projection Cypher query.

    Only the projection itself is of interest for the benchmarks, so the body
    of the ``with`` block is empty. The result of ``gds.graph.project`` is used
    as a context manager; presumably it drops the projected graph on exit —
    NOTE(review): confirm against the graphdatascience version in use.

    Parameters
    ----------
    query : str
        Cypher query that calls ``gds.graph.project.remote``.
    concurrency : int
        Concurrency passed to the projection.
    """
    with gds.graph.project("graph", query, concurrency=concurrency):
        pass

In [None]:
# Single-threaded projection with the parallel runtime, run once before the
# measurements below (presumably as a warm-up — confirm intent).
run_remote_projection(
    query="""
    CYPHER runtime = parallel
    MATCH (u)
    OPTIONAL MATCH (u)-[r]->(t)
    RETURN gds.graph.project.remote(u, t, {})
    """,
    concurrency=1,
)

### Project entire graph, structure only

#### Default (non-parallel) runtime

In [None]:
# Structure-only projection of the whole graph with the default Cypher runtime,
# timed once per concurrency setting. The setup drops any leftover "graph".
data = {}
for concurrency in (8,):
    result = measure(
        func=lambda: run_remote_projection(
            """
            MATCH (u)
            OPTIONAL MATCH (u)-[r]->(t)
            RETURN gds.graph.project.remote(u, t, null)
            """,
            concurrency=concurrency,
        ),
        setup=lambda: gds.graph.drop("graph", failIfMissing=False),
        warmup_iterations=1,
        iterations=1,
    )
    data[concurrency] = result

In [None]:
# Mean runtime (seconds, as measured by timeit) per concurrency setting.
values = {concurrency: stats["mean"] for concurrency, stats in data.items()}

plot = sns.barplot(values)
plot.set(xlabel="concurrency", ylabel="average runtime")

#### Parallel runtime

In [None]:
# Structure-only projection of the whole graph with the parallel Cypher runtime,
# timed once per concurrency setting (no setup step between runs).
data = {}
for concurrency in (8,):
    result = measure(
        func=lambda: run_remote_projection(
            """
            CYPHER runtime = parallel
            MATCH (u)
            OPTIONAL MATCH (u)-[r]->(t)
            RETURN gds.graph.project.remote(u, t, null)
            """,
            concurrency=concurrency,
        ),
        setup=lambda: (),
        warmup_iterations=1,
        iterations=1,
    )
    data[concurrency] = result

# Bar chart of the mean runtime for each concurrency setting.
plot = sns.barplot({conc: stats["mean"] for conc, stats in data.items()})
plot.set(xlabel="concurrency", ylabel="average runtime")

### Project entire graph, with properties and labels

In [None]:
query = """ CYPHER runtime = parallel
            MATCH (u)
            OPTIONAL MATCH (u)-[r]->(t)
            RETURN gds.graph.project.remote(u, t, {
                sourceNodeLabels: labels(u),
                sourceNodeProperties: {id: id(u)},
                targetNodeLabels: labels(t),
                targetNodeProperties: {id: id(t)},
                relationshipType: type(r),
                relationshipProperties: {id: id(r)}
            })
        """

# Projection including labels, relationship types and id properties:
# 5 warm-up runs, then 10 measured runs per concurrency setting.
data = {}
for concurrency in (8,):
    data[concurrency] = result = measure(
        func=lambda: run_remote_projection(query, concurrency=concurrency),
        setup=lambda: (),
        warmup_iterations=5,
        iterations=10,
    )

# Bar chart of the mean runtime for each concurrency setting.
plot = sns.barplot({conc: stats["mean"] for conc, stats in data.items()})
plot.set(xlabel="concurrency", ylabel="average runtime")

## Write-back

In [None]:
# Project the whole database into the in-memory graph "graph" (parallel Cypher
# runtime, concurrency 4). G is used by the write-back cells that follow.
G, _ = gds.graph.project(
    "graph",
    """
    CYPHER runtime = parallel
    MATCH (u)
    OPTIONAL MATCH (u)-[r]->(t)
    RETURN gds.graph.project.remote(u, t, {})
    """,
    concurrency=4,
)

In [None]:
# Obtain a handle to the already projected graph "graph" (presumably so later
# cells can be re-run without re-projecting — confirm intent).
G = gds.graph.get("graph")

### Node properties

In [None]:
# Compute node degrees and store them in the in-memory graph as property "degree".
gds.degree.mutate(G, mutateProperty="degree")

In [None]:
# One untimed write-back of "degree" to check that it works before benchmarking.
gds.graph.nodeProperties.write(G, node_properties=["degree"])

In [None]:
# Sanity check: count the database nodes that now carry a "degree" property.
gds.run_cypher("MATCH (n) WHERE NOT n.degree IS null RETURN count(n)")

In [None]:
# Time writing the scalar "degree" property back to the database. The untimed
# setup removes the property before each run so every run writes it afresh.
result = measure(
    func=lambda: gds.graph.nodeProperties.write(G, node_properties=["degree"]),
    setup=lambda: gds.run_cypher("MATCH (n) REMOVE n.degree"),
    warmup_iterations=1,
    iterations=1,
)
print(f"scalar property write-back: {result}")

In [None]:
# Compute 10-dimensional FastRP embeddings and store them in the in-memory graph
# as the array property "embedding".
gds.fastRP.mutate(G, mutateProperty="embedding", embeddingDimension=10, iterationWeights=[1.0])

In [None]:
# Time writing the array property "embedding" back to the database.
# The write must be the measured statement and the cleanup the untimed setup
# (the previous timeit.repeat call had them swapped and referenced an undefined
# `iterations`). The setup clears the property before each run.
result = measure(
    lambda: gds.graph.nodeProperties.write(G, node_properties=["embedding"]),
    lambda: gds.run_cypher("MATCH (n) SET n.embedding = null"),
    iterations=1,
    warmup_iterations=1,
)

print(f"array property: {result}")

### Relationships

In [None]:
# Up to two nodes (highest internal ids first) with more than five outgoing
# relationships; the first one is used as the shortest-path source.
# Select the aliased column explicitly: `.squeeze()` would collapse a single
# matching row into a scalar, breaking `matched_nodes[0]`.
matched_nodes = gds.run_cypher(
    "MATCH (n)-[r]->() WITH count(r) AS degree, n WHERE degree > 5 "
    "RETURN id(n) AS nodeId ORDER BY nodeId DESC LIMIT 2"
)["nodeId"]
if matched_nodes.empty:
    raise ValueError("No node with more than 5 outgoing relationships found")

In [None]:
# Run `allShortestPaths.delta` from the first matched node and add the results to
# the in-memory graph as relationships of type MUTATED_RELS (written back below).
gds.allShortestPaths.delta.mutate(
    G, sourceNode=matched_nodes[0], mutateRelationshipType="MUTATED_RELS", concurrency=6
)

In [None]:
# Untimed write-back of the MUTATED_RELS relationships with their "totalCost" property.
gds.graph.relationship.write(G, relationship_type="MUTATED_RELS", relationship_property="totalCost")

In [None]:
# Report how many MUTATED_RELS relationships were written, then delete them.
# Deletion runs in batches of 2000 rows, each committed in its own transaction,
# to avoid one huge transaction.
written_rels_count = gds.run_cypher("MATCH ()-[r:MUTATED_RELS]->() RETURN count(r) AS relCount")["relCount"][0]
print(f"written rels: {written_rels_count}")
gds.run_cypher("MATCH ()-[r:MUTATED_RELS]->() CALL { WITH r DELETE r } IN TRANSACTIONS OF 2000 ROWS")

In [None]:
# Time writing the MUTATED_RELS relationships (with "totalCost") back to the
# database. The write is the measured statement. The untimed setup deletes
# previously written MUTATED_RELS in batches before each run.
# (The previous version timed the delete instead of the write, used an undefined
# `iterations`, and wrote a KNN_RELS type that is never created in this notebook.)
result = measure(
    lambda: gds.graph.relationship.write(G, relationship_type="MUTATED_RELS", relationship_property="totalCost"),
    lambda: gds.run_cypher(
        "MATCH ()-[r:MUTATED_RELS]->() CALL { WITH r DELETE r } IN TRANSACTIONS OF 2000 ROWS"
    ),
    iterations=1,
    warmup_iterations=1,
)

print(f"relationships: {result}")

## Cleanup

In [None]:
# Drop the in-memory graph "graph" used by the write-back benchmarks.
gds.graph.drop("graph")

In [None]:
# Delete the GDS session that was created or connected to at the start.
sessions.delete_gds(session_name)

In [None]:
# cleanup
# WARNING: this deletes the Aura DB instance DB_ID itself (not only the GDS
# session). Run this cell only against a throwaway benchmark instance.

from graphdatascience.aura_api import AuraApi

aura_api = AuraApi(CLIENT_ID, CLIENT_SECRET)
aura_api.delete_instance(DB_ID)