# Aura Graph Analytics with Spark

<a target="_blank" href="https://colab.research.google.com/github/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This Jupyter notebook is hosted [here](https://github.com/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless.ipynb) in the Neo4j Graph Data Science Client GitHub repository.

The notebook shows how to use the `graphdatascience` Python library to create, manage, and use a GDS Session.

We consider a graph of bicycle rentals, which we use as a simple example to show how to project data from Spark to a GDS Session, run algorithms, and eventually retrieve the results back to Spark.
We will cover all management operations: creation, listing, and deletion.

## Prerequisites

This notebook requires an available AuraDB instance and the Aura Graph Analytics [feature](https://neo4j.com/docs/aura/graph-analytics/#aura-gds-serverless) enabled for your project.

You also need to have the `graphdatascience` Python library installed, version `1.18a2` or later, as pinned in the install command below.

In [None]:
%pip install "graphdatascience>=1.18a2" python-dotenv "pyspark[sql]"

In [None]:
from dotenv import load_dotenv

# Load required secrets from a `.env` file in the local directory.
# This can include Aura API credentials and database credentials.
# If the file does not exist, this is a no-op.
load_dotenv("sessions.env")

### Connecting to a Spark Session

To interact with the Spark cluster, we first need to instantiate a Spark session. In this example we use a local Spark session, which runs Spark on the same machine.
Working with a remote Spark cluster is similar.

In [None]:
import os

from pyspark.sql import SparkSession

# Path to a local Java installation; adjust it for your machine.
os.environ["JAVA_HOME"] = "/home/max/.sdkman/candidates/java/current"

spark = SparkSession.builder.master("local[4]").appName("GraphAnalytics").getOrCreate()

# Enable Arrow-based columnar data transfers
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
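
The `JAVA_HOME` path in the cell above is specific to one machine. As a minimal sketch, assuming a `java` executable is reachable on your `PATH`, the location could be resolved instead of hard-coded. The helper name `resolve_java_home` is hypothetical and not part of PySpark, and on some systems (for example macOS, where `/usr/bin/java` is a stub) the derived path may not be a real Java home.

```python
import os
import shutil
from typing import Optional


def resolve_java_home(fallback: Optional[str] = None) -> Optional[str]:
    """Return JAVA_HOME if set, else derive it from `java` on PATH, else `fallback`."""
    java_home = os.environ.get("JAVA_HOME")
    if java_home:
        return java_home
    java_bin = shutil.which("java")
    if java_bin:
        # Assumes the usual layout <JAVA_HOME>/bin/java; symlinks are resolved first.
        return os.path.dirname(os.path.dirname(os.path.realpath(java_bin)))
    return fallback


resolved = resolve_java_home()
print(resolved)
```

If `resolved` is not `None`, it can be assigned to `os.environ["JAVA_HOME"]` before the Spark session is created.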

## Aura API credentials

The entry point for managing GDS Sessions is the `GdsSessions` object, which requires creating [Aura API credentials](https://neo4j.com/docs/aura/api/authentication).

In [None]:
import os

from graphdatascience.session import AuraAPICredentials, GdsSessions

# you can also use AuraAPICredentials.from_env() to load credentials from environment variables
api_credentials = AuraAPICredentials(
    client_id=os.environ["CLIENT_ID"],
    client_secret=os.environ["CLIENT_SECRET"],
    # If your account is a member of several projects, you must also specify the project ID to use
    project_id=os.environ.get("PROJECT_ID", None),
)

sessions = GdsSessions(api_credentials=api_credentials)
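
The cell above reads `CLIENT_ID` and `CLIENT_SECRET` with `os.environ[...]`, which raises a bare `KeyError` when one of them is missing. A small sketch of an up-front check, using only the standard library, could give a clearer message. The helper `missing_env_vars` is an illustration and not part of the `graphdatascience` library.

```python
import os
from typing import List, Mapping, Optional, Sequence

REQUIRED_VARS = ("CLIENT_ID", "CLIENT_SECRET")


def missing_env_vars(
    required: Sequence[str] = REQUIRED_VARS,
    environ: Optional[Mapping[str, str]] = None,
) -> List[str]:
    """Return the names in `required` that are unset or empty in `environ`."""
    env = os.environ if environ is None else environ
    return [name for name in required if not env.get(name)]


# Demo with an explicit mapping instead of the real environment
print(missing_env_vars(environ={"CLIENT_ID": "abc"}))  # ['CLIENT_SECRET']
```

Calling `missing_env_vars()` without arguments checks the real environment, so it can run right after `load_dotenv` and before the credentials object is built.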

## Creating a new session

A new session is created by calling `sessions.get_or_create()` with the following parameters:

* A session name, which lets you reconnect to an existing session by calling `get_or_create` again.
* The session memory.
* The cloud location.
* A time-to-live (TTL), which ensures that the session is automatically deleted after being unused for the set time, to avoid incurring costs.

See the API reference [documentation](https://neo4j.com/docs/graph-data-science-client/current/api/sessions/gds_sessions/#graphdatascience.session.gds_sessions.GdsSessions.get_or_create) or the manual for more details on the parameters.

In [None]:
from datetime import timedelta

from graphdatascience.session import CloudLocation, SessionMemory

# Create a GDS session!
gds = sessions.get_or_create(
    # we give it a representative name
    session_name="bike_trips",
    memory=SessionMemory.m_2GB,
    ttl=timedelta(minutes=30),
    cloud_location=CloudLocation("gcp", "europe-west1"),
)

## Adding a dataset

As the next step, we will set up a dataset in Spark. In this example we will use the New York Bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips).

In [None]:
import io
import os
import zipfile

import requests

download_path = "bike_trips_data"
if not os.path.exists(download_path):
    url = "https://www.kaggle.com/api/v1/datasets/download/gabrielramos87/bike-trips"

    response = requests.get(url)
    response.raise_for_status()

    # Unzip the content
    with zipfile.ZipFile(io.BytesIO(response.content)) as z:
        z.extractall(download_path)

df = spark.read.csv(download_path, header=True, inferSchema=True)
df.createOrReplaceTempView("bike_trips")
df.limit(10).show()
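
If the download returns something other than a zip archive (for example an HTML error page), `zipfile.ZipFile` fails with a fairly opaque error. The following is a sketch of the unzip step as a standalone helper that checks the payload first and reports which CSV files were extracted. The helper name `extract_csv_files` and the demo archive are illustrative assumptions, not part of the original notebook.

```python
import io
import os
import tempfile
import zipfile
from typing import List


def extract_csv_files(zip_bytes: bytes, target_dir: str) -> List[str]:
    """Extract an in-memory zip archive and return the paths of its CSV files."""
    buffer = io.BytesIO(zip_bytes)
    if not zipfile.is_zipfile(buffer):
        raise ValueError("Downloaded content is not a zip archive")
    with zipfile.ZipFile(buffer) as archive:
        archive.extractall(target_dir)
        names = archive.namelist()
    return sorted(os.path.join(target_dir, n) for n in names if n.lower().endswith(".csv"))


# Demo with a tiny archive built in memory, standing in for the real download
demo = io.BytesIO()
with zipfile.ZipFile(demo, "w") as archive:
    archive.writestr("trips.csv", "start_station_id,end_station_id\n1,2\n")
    archive.writestr("README.txt", "not a CSV")

demo_dir = tempfile.mkdtemp()
csv_files = extract_csv_files(demo.getvalue(), demo_dir)
print(csv_files)
```

In the notebook, `response.content` would take the place of `demo.getvalue()` and `download_path` the place of `demo_dir`.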

## Projecting Graphs

Now that our dataset is available within our Spark session, it is time to project it to the GDS Session.

We first need to get access to the GDSArrowClient. This client allows us to communicate directly with the Arrow Flight server provided by the session.

Our input data already resembles edge triplets, where each row represents an edge from a source station to a target station. This allows us to use the Arrow server's "graph import from triplets" functionality, which requires the following protocol:

1. Send an action `v2/graph.project.fromTriplets`.
   This initializes the import process and allows us to specify the graph name and settings like `undirected_relationship_types`. It returns a job id, which we need to reference the import job in the following steps.
2. Send the data in batches to the Arrow server.
3. Send another action called `v2/graph.project.fromTriplets.done` to tell the import process that no more data will be sent. This triggers the final graph creation inside the session.
4. Wait for the import process to reach the `DONE` state.
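
Step 2 sends the edges in batches rather than all at once. In the notebook Spark produces those batches per partition, but the idea can be sketched in plain Python. The `batched` helper and the sample `edges` below are illustrative only and use no Spark or GDS APIs.

```python
from itertools import islice
from typing import Iterable, Iterator, List, Tuple

Edge = Tuple[int, int]  # (sourceNode, targetNode)


def batched(rows: Iterable[Edge], batch_size: int) -> Iterator[List[Edge]]:
    """Yield consecutive lists of at most `batch_size` rows."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    iterator = iter(rows)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


edges = [(1, 2), (2, 3), (3, 1), (1, 3), (2, 1)]
batches = list(batched(edges, 2))
print(batches)  # [[(1, 2), (2, 3)], [(3, 1), (1, 3)], [(2, 1)]]
```

Each batch would be sent as one upload, and the final, shorter batch shows why the server needs an explicit "done" signal: it cannot tell from batch sizes alone that the data is complete.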

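While the overall process is straightforward, we need to somehow tell Spark to upload the data from its workers. In the cell below, `mapInArrow` does this: it calls `upload_batch` on the partitions, which sends each Arrow batch to the session and yields the number of rows it uploaded.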

In [None]:
import pandas as pd
import pyarrow
from pyspark.sql import functions

graph_name = "bike_trips"

arrow_client = gds.arrow_client()

# 1. Start the import process
job_id = arrow_client.create_graph_from_triplets(graph_name, concurrency=4)


# Define a function that receives an arrow batch and uploads it to the session
def upload_batch(iterator):
    for batch in iterator:
        arrow_client.upload_triplets(job_id, [batch])
        yield pyarrow.RecordBatch.from_pandas(pd.DataFrame({"batch_rows_imported": [len(batch)]}))


# Select the source target pairs from our source data
source_target_pairs = spark.sql("""
  SELECT start_station_id AS sourceNode, end_station_id AS targetNode
  FROM bike_trips
""")

# 2. Use the `mapInArrow` function to upload the data to the session.
# It returns a DataFrame with a single column that holds the batch sizes.
uploaded_batches = source_target_pairs.mapInArrow(upload_batch, "batch_rows_imported long")

# Aggregate the batch sizes to receive the row count.
uploaded_batches.agg(functions.sum("batch_rows_imported").alias("rows_imported")).show()

# 3. Finish the import process
arrow_client.triplet_load_done(job_id)

# 4. Wait for the import to finish, sleeping between polls to avoid a busy loop
import time

while not arrow_client.job_status(job_id).succeeded():
    time.sleep(1)

G = gds.v2.graph.get(graph_name)
G
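
Step 4 above polls the job status in a loop that never ends if the job does not succeed. A sketch of a more defensive variant bounds the wait with a timeout. The `wait_until` helper is hypothetical and only wraps a callable; the demo uses a fake condition instead of a real session.

```python
import time
from typing import Callable


def wait_until(
    condition: Callable[[], bool],
    timeout_s: float = 300.0,
    poll_interval_s: float = 1.0,
) -> None:
    """Poll `condition` until it returns True, or raise TimeoutError."""
    deadline = time.monotonic() + timeout_s
    while not condition():
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Condition not met within {timeout_s} seconds")
        time.sleep(poll_interval_s)


# Demo: a fake job that reports success on the third status check
calls = {"n": 0}


def fake_job_succeeded() -> bool:
    calls["n"] += 1
    return calls["n"] >= 3


wait_until(fake_job_succeeded, timeout_s=5.0, poll_interval_s=0.01)
print(calls["n"])  # 3
```

In the notebook this could be called as `wait_until(lambda: arrow_client.job_status(job_id).succeeded())`, reusing the status call from the cell above.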

## Running Algorithms

We can run algorithms on the constructed graph using the standard GDS Python Client API. See the other tutorials for more examples.

In [None]:
print("Running PageRank ...")
pr_result = gds.v2.page_rank.mutate(G, mutate_property="pagerank")

## Sending the computation result back to Spark

Once the computation is done, we might want to use the result further in Spark.
We can do this in a similar way to the projection, by streaming batches of data into each of the Spark workers.
Retrieving the data is a bit more complicated, since we need an input DataFrame to trigger computations on the Spark workers.
As our driving table, we use a range with one row per worker in our cluster.
On the workers, we disregard the input and instead stream the computation data from the GDS Session.

In [None]:
# 1. Start the node property export on the session
job_id = arrow_client.get_node_properties(G.name(), ["pagerank"])


# Define a function that receives data from the GDS Session and turns it into data batches
def retrieve_data(ignored):
    stream_data = arrow_client.stream_job(G.name(), job_id)
    batches = pyarrow.Table.from_pandas(stream_data).to_batches(1000)
    for b in batches:
        yield b


# Create DataFrame with a single column and one row per worker
input_partitions = spark.range(spark.sparkContext.defaultParallelism).toDF("batch_id")
# 2. Stream the data from the GDS Session into the Spark workers
received_batches = input_partitions.mapInArrow(retrieve_data, "nodeId long, pagerank double")
# Optional: Repartition the data to make sure it is distributed equally
result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism)

result.show()

## Cleanup

Now that we have finished our analysis, we can delete the session and stop the Spark session.

Deleting the session will release all resources associated with it, and stop incurring costs.

In [None]:
gds.delete()
spark.stop()