# ⚡ PySpark to 🏹 PyArrow to 🌐 Neo4j GDS

> Note: this uses pre-alpha GDS capabilities. Some of the following steps (e.g. starting/stopping the Flight Server) will change in the first public alpha release.

In [1]:
!pip install "pyarrow<8.0,>=7.0" "graphdatascience<2.0,>=1.0"

Collecting graphdatascience<2.0,>=1.0
  Using cached graphdatascience-1.0.0-py3-none-any.whl (34 kB)
Collecting neo4j<5.0,>=4.4.2
  Using cached neo4j-4.4.3-py3-none-any.whl
Installing collected packages: neo4j, graphdatascience
Successfully installed graphdatascience-1.0.0 neo4j-4.4.3


In [2]:
from typing import Iterator, Tuple
import json
import uuid

import pyspark
from pyspark.sql import Row, SparkSession

import pyarrow as pa
import pyarrow.flight as flight

from graphdatascience import GraphDataScience, __version__ as __gdsversion__

print(f"using pyarrow v{pa.__version__} and GDS client v{__gdsversion__}")

using pyarrow v7.0.0 and GDS client v1.0.0


## Arrow Helpers & Higher Order Functions
Currently, a user must implement some of the integration to GDS using PyArrow directly. This will change in the future, but for now we have two helper functions: `send_action` and `write_table`.

In [3]:
def send_action(client: flight.FlightClient, action_type: str, meta_data: dict) -> dict:
    """
    Communicates an Arrow Action message to the GDS Arrow Service.

    Returns the decoded JSON response, or None if the action fails.
    """
    try:
        result = client.do_action(
            flight.Action(action_type, json.dumps(meta_data).encode("utf-8"))
        )
        return json.loads(next(result).body.to_pybytes().decode())
    except Exception as e:
        print(f"send_action error: {e}")
        return None

def write_table(client: flight.FlightClient, desc: bytes, table: pa.Table) -> int:
    """
    Write a PyArrow Table to the GDS Flight service.
    """
    # Writing an Arrow stream requires first communicating the intent via a "PUT"
    # which includes details on the incoming schema. The schema _must not_ change
    # mid-stream.
    upload_descriptor = flight.FlightDescriptor.for_command(desc)
    writer, _ = client.do_put(upload_descriptor, table.schema)
    rows = len(table)
    with writer:
        try:
            writer.write_table(table, max_chunksize=10_000)
            return rows
        except Exception as e:
            print(f"write_table error: {e}")
    return 0

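As a rough illustration of what `max_chunksize=10_000` means for a load, the sketch below computes the batch sizes you would expect when a single contiguous block of rows is cut into pieces of at most `max_chunksize` rows. The `chunk_sizes` helper is hypothetical and pure Python. It only mirrors the arithmetic and does not call PyArrow, whose actual batching also depends on how the `Table` is already chunked internally.

```python
from typing import List


def chunk_sizes(num_rows: int, max_chunksize: int = 10_000) -> List[int]:
    """Row counts of the pieces produced when num_rows contiguous rows are
    split into pieces of at most max_chunksize rows (hypothetical helper)."""
    full, remainder = divmod(num_rows, max_chunksize)
    return [max_chunksize] * full + ([remainder] if remainder else [])


# 33,732 is the node count that appears later in this notebook.
print(chunk_sizes(33_732))  # [10000, 10000, 10000, 3732]
```

Smaller chunks mean more, smaller messages on the wire. The value used in `write_table` matches the `maxRecordsPerBatch` setting used for the Spark session later on.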

We also need to provide some serializable Python functions for the PySpark workers to use for data prep and transformation:

* `load_rows_as_tables` converts the Row-based PySpark data into PyArrow Tables and sends them to GDS
* `guid_to_int` is a hack that translates our 128-bit GUID (string) node ids into 64-bit signed integer node ids for compatibility with the current GDS library
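To make the first bullet concrete, here is a minimal sketch of the row-to-column pivot, using plain dictionaries in place of PySpark `Row` objects. The `rows_to_columns` helper and the sample values are hypothetical. In the notebook itself, the resulting dictionary of lists is what gets handed to `pa.table(...)`.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List


def rows_to_columns(rows: Iterable[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Pivot row-oriented records into one list per field.

    Assumes every row carries the same fields; otherwise the column
    lengths would differ and could not form a table.
    """
    columns: Dict[str, List[Any]] = defaultdict(list)
    for row in rows:
        for field, value in row.items():
            columns[field].append(value)
    return dict(columns)


# Made-up sample rows, shaped like the node data used in this notebook.
rows = [
    {"node_id": 1, "fraudMoneyTransfer": 0},
    {"node_id": 2, "fraudMoneyTransfer": 1},
]
print(rows_to_columns(rows))
# {'node_id': [1, 2], 'fraudMoneyTransfer': [0, 1]}
```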

In [4]:
def load_rows_as_tables(arrow_host: str, arrow_port: int, desc: dict):
    """
    Higher-order function that converts PySpark Rows into PyArrow Tables and
    feeds them to Neo4j.
    """
    desc = json.dumps(desc).encode("utf-8")

    def loader(iterator: Iterator[Row]) -> Iterator[Row]:
        client = flight.FlightClient(
            flight.Location.for_grpc_tcp(arrow_host, arrow_port)
        )
        # XXX: this is a hack and will not scale, as it buffers the entire
        # partition in memory as Python lists before building the Table
        data = dict()
        for row in iterator:
            for field, value in row.asDict().items():
                data.setdefault(field, []).append(value)
        table = pa.table(data)
        
        # This is where the ✨ magic happens...
        num_sent = write_table(client, desc, table)
        
        yield Row(schema=str(table.schema), table_size=len(table), num_sent=num_sent)
    return loader

def guid_to_int(*fields):
    """Convert fields in a given row from a str-based guid to an int value"""
    def _guid_to_int(row: Row) -> Row:
        # this is a total hack...mask off the upper 65 bits :(
        result = {}
        for field in row.asDict():
            if field in fields:
                guid = uuid.UUID(str(row[field]))
                result[field] = int(guid) & ((1 << 63) - 1) # need a sign bit :(
            else:
                result[field] = row[field]
        # keep our data well-formed as PySpark Rows with named fields
        return Row(**result)
    return _guid_to_int
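
As a standalone illustration of the masking in `guid_to_int`, the sketch below applies the same 63-bit mask to one of the GUIDs from the sample data. The `guid_to_int63` name is hypothetical. It also shows the cost of the hack: any two GUIDs that share their lowest 63 bits map to the same integer.

```python
import uuid

# Keep only the lowest 63 bits so the result fits a non-negative signed int64.
INT63_MASK = (1 << 63) - 1


def guid_to_int63(guid: str) -> int:
    """Truncate a 128-bit GUID string to a 63-bit integer (lossy)."""
    return int(uuid.UUID(guid)) & INT63_MASK


node_id = guid_to_int63("fbfecb0d43298abedcfc51ed87f0c23d")
print(hex(node_id))  # 0x5cfc51ed87f0c23d

# A different GUID with the same low 63 bits collides with the first one.
collision = guid_to_int63("00000000000000005cfc51ed87f0c23d")
print(node_id == collision)  # True
```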

## Some Housekeeping

Let's first configure our GDS client connection and ensure our Spark session is alive.

In [5]:
with open("creds.txt", mode="r") as f:
    password = f.readline().strip()

with open("ip.txt", mode="r") as f:
    ip_address = f.readline().strip()

NEO4J_USER, NEO4J_PASS, NEO4J_HOST, NEO4J_GRAPH = (
    "neo4j", password, 
    ip_address, # not sure why internal dns doesn't work 🤷‍♂️
    "test"
)
gds = GraphDataScience(f"neo4j://{NEO4J_HOST}:7687", auth=(NEO4J_USER, NEO4J_PASS))

gds.run_cypher("RETURN 'hello pyspark!'")

Unnamed: 0,'hello pyspark!'
0,hello pyspark!


In [6]:
spark = (
    SparkSession.builder
    .appName("PySpark to PyArrow GDS Example")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .config("spark.sql.execution.arrow.pyspark.maxRecordsPerBatch", 10_000)
    .getOrCreate()
)
spark

## Prepare our Data with help from ⚡PySpark

We're using 2 tables in BigQuery. PySpark running in GCP Dataproc already includes the capabilities to read from BigQuery tables, so this part is super easy.

> _Note: we use PySpark to drop or rename some columns to transform the data to meet expectations of the current pre-alpha GDS Flight Server._

In [7]:
nodes = (
    spark.read
    .format("bigquery")
    .load("neo4j-se-team-201905.fraud_demo_data.user")
    .withColumnRenamed("guid", "node_id")
    .cache()
)
nodes.limit(5).toPandas()

Unnamed: 0,node_id,fraudMoneyTransfer,moneyTransferErrorCancelAcmount
0,fbfecb0d43298abedcfc51ed87f0c23d,0,0.0
1,8bd4336d2bbc95d23db4f58cae5b3ee6,0,0.0
2,19227007c6f5f714b37a2d85de3bcf89,0,0.0
3,99d83a47c8edda428639be9a6d6431dd,0,0.0
4,70e0bc1649566c06431b171dbe5b2fd2,0,0.0


In [8]:
edges = (
    spark.read
    .format("bigquery")
    .load("neo4j-se-team-201905.fraud_demo_data.p2p")
    .withColumnRenamed("start_guid", "source_id")
    .withColumnRenamed("end_guid", "destination_id")
    .drop("start_label", "end_label", "transactionDateTime", "totalAmount")
    .cache()
)
edges.limit(5).toPandas()

Unnamed: 0,source_id,destination_id
0,00000056b0d4d68e9b1f2a80fbe55823,125110adf289183a0fdb15e61f8a81c8
1,0011985949cf4de51ef628aa47c7e527,8584dd66824b27e381ab5752efc6569f
2,002c20d7382d262532b9c43a3e5d4a2f,be133538c37501fcfc351542a71d2f94
3,002c20d7382d262532b9c43a3e5d4a2f,be133538c37501fcfc351542a71d2f94
4,002c20d7382d262532b9c43a3e5d4a2f,be133538c37501fcfc351542a71d2f94


## Clean up any Existing Mess

Let's make sure we've got a clean slate for the demo.

In [9]:
try:
    gds.run_cypher("CALL gds.alpha.flightServer.stop()")
    print("Stopped existing FlightServer.")
except Exception:
    pass

try:
    G = gds.graph.get(NEO4J_GRAPH)
    print(f"Dropping existing graph for {NEO4J_GRAPH}")
    G.drop()
except ValueError:
    pass

Stopped existing FlightServer.


### 🟢 Start our GDS Flight Server

In this pre-alpha version, the Flight Server needs to be manually started. It's a simple stored procedure call.

In [10]:
gds.run_cypher("CALL gds.alpha.flightServer.start()")

Unnamed: 0,listenerAddress,listenerPort
0,0.0.0.0,4242


### Prepare some 🏹 PyArrow Stuff

We need a few Arrow-specific messaging components before we can take off. The GDS Flight Server uses these to signal state progression through the data loading process.

In [11]:
node_descriptor = { "name": NEO4J_GRAPH, "entity_type": "node" }
edge_descriptor = { "name": NEO4J_GRAPH, "entity_type": "relationship" }

node_descriptor, edge_descriptor

({'name': 'test', 'entity_type': 'node'},
 {'name': 'test', 'entity_type': 'relationship'})
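
Both helpers defined earlier serialize these dictionaries the same way before handing them to PyArrow: JSON text encoded as UTF-8 bytes. A minimal round-trip sketch, using only the standard library (the `encode_body` and `decode_body` names are hypothetical):

```python
import json


def encode_body(payload: dict) -> bytes:
    """Serialize a descriptor or action payload the way the helpers do."""
    return json.dumps(payload).encode("utf-8")


def decode_body(raw: bytes) -> dict:
    """Inverse of encode_body, as used when reading an action result."""
    return json.loads(raw.decode("utf-8"))


node_descriptor = {"name": "test", "entity_type": "node"}
raw = encode_body(node_descriptor)
print(raw)  # b'{"name": "test", "entity_type": "node"}'
print(decode_body(raw) == node_descriptor)  # True
```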

We create a `FlightClient` specifically for use from the *Spark Driver* instance. The *Driver* will be responsible for signalling state changes to the GDS backend.

In [12]:
client = flight.FlightClient(flight.Location.for_grpc_tcp(NEO4J_HOST, 4242))
send_action(client, "CREATE_GRAPH", {"name": NEO4J_GRAPH})

{'name': 'test'}

### 🚚 Load the Nodes

Here's where the more complicated processing occurs. We use our helper functions (`guid_to_int` and `load_rows_as_tables`) to transform the Spark data as needed and also batch load it into GDS.

Spark's `.mapPartitions` method on the `RDD` instance allows us to effectively create 1 "loader thread" per Spark partition.

Our helper function yields one `Row` per partition containing:
* the PyArrow `Table` schema as a string
* the number of rows (here, nodes) in our constructed PyArrow `Table`
* the number of rows our PyArrow client sent to GDS
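
To show the shape of that per-partition processing without a Spark cluster, here is a pure-Python stand-in. The `map_partitions` and `summarize` functions are hypothetical and run sequentially, whereas Spark runs the function on its executors in parallel. What carries over is the contract: the function receives an iterator over one partition and yields its results.

```python
from typing import Callable, Dict, Iterable, Iterator, List


def summarize(partition: Iterator[int]) -> Iterator[Dict[str, int]]:
    """Consume a whole partition and yield a single summary record."""
    rows = list(partition)
    yield {"table_size": len(rows)}


def map_partitions(
    partitions: Iterable[List[int]],
    func: Callable[[Iterator[int]], Iterator[Dict[str, int]]],
) -> List[Dict[str, int]]:
    """Sequential stand-in for RDD.mapPartitions(...).collect()."""
    results: List[Dict[str, int]] = []
    for partition in partitions:
        results.extend(func(iter(partition)))
    return results


partitions = [[1, 2, 3], [4, 5], [6]]
print(map_partitions(partitions, summarize))
# [{'table_size': 3}, {'table_size': 2}, {'table_size': 1}]
```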

In [13]:
results = (
    nodes.rdd
    .map(guid_to_int("node_id"), True)
    .mapPartitions(load_rows_as_tables(NEO4J_HOST, 4242, node_descriptor), True)
    .collect()
)

print(f"loaded nodes: {results}")

loaded nodes: [Row(num_sent=33732, schema='fraudMoneyTransfer: int64\nmoneyTransferErrorCancelAcmount: double\nnode_id: int64', table_size=33732)]


We need to tell the GDS server to stop looking for nodes. This currently __must__ be done prior to streaming relationships as the node ids must already exist for processing any edges.

In [14]:
send_action(client, "NODE_LOAD_DONE", {"name": NEO4J_GRAPH})

{'name': 'test', 'node_count': 33732}

### 🚚 Load the Relationships

Similar to how we loaded the Nodes, we do a little last-mile transformation via PySpark before feeding the data to GDS.

In [15]:
results = (
    edges.rdd
    .map(guid_to_int("source_id", "destination_id"), True)
    .mapPartitions(load_rows_as_tables(NEO4J_HOST, 4242, edge_descriptor))
    .collect()
)

print(f"loaded edges: {results}")

loaded edges: [Row(num_sent=102832, schema='destination_id: int64\nsource_id: int64', table_size=102832)]


Signal that the edge load is done. This triggers the graph being finalized into the Graph Catalog.

In [16]:
send_action(client, "RELATIONSHIP_LOAD_DONE", {"name": NEO4J_GRAPH})

{'name': 'test', 'relationship_count': 102832}
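
The three actions sent in this notebook must arrive in a fixed order, with the node and relationship streams in between. The sketch below captures that ordering as a small client-side state machine. The `LoadTracker` class and its state names are hypothetical and are not part of GDS or PyArrow. Only the action names come from the calls above.

```python
# Each action is valid in exactly one state and advances the load.
# State names are illustrative assumptions, not GDS terminology.
TRANSITIONS = {
    ("NEW", "CREATE_GRAPH"): "LOADING_NODES",
    ("LOADING_NODES", "NODE_LOAD_DONE"): "LOADING_RELATIONSHIPS",
    ("LOADING_RELATIONSHIPS", "RELATIONSHIP_LOAD_DONE"): "DONE",
}


class LoadTracker:
    """Tracks where a graph load is and rejects out-of-order actions."""

    def __init__(self) -> None:
        self.state = "NEW"

    def apply(self, action: str) -> str:
        key = (self.state, action)
        if key not in TRANSITIONS:
            raise ValueError(f"{action} is not allowed in state {self.state}")
        self.state = TRANSITIONS[key]
        return self.state


tracker = LoadTracker()
for action in ("CREATE_GRAPH", "NODE_LOAD_DONE", "RELATIONSHIP_LOAD_DONE"):
    print(action, "->", tracker.apply(action))
```

A driver could call `apply` just before each `send_action` to catch sequencing mistakes locally, for example signalling `RELATIONSHIP_LOAD_DONE` before the nodes are finished.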

### 🛑 Stop our Flight Server

This is manual today, but it's a simple stored procedure call.

In [17]:
gds.run_cypher("CALL gds.alpha.flightServer.stop();")

Unnamed: 0,listenerAddress,listenerPort
0,0.0.0.0,4242


## Inspect the GDS Graph

Let's make sure the graph was actually created. 😉

This is easy with the GDS client.

In [18]:
gds.graph.list()

Unnamed: 0,degreeDistribution,graphName,database,memoryUsage,sizeInBytes,nodeCount,relationshipCount,configuration,density,creationTime,modificationTime,schema
0,"{'p99': 49, 'min': 0, 'max': 1073, 'mean': 3.0...",test,neo4j,11481 KiB,11757144,33732,102832,{},9e-05,2022-05-04T18:48:15.258358000+00:00,2022-05-04T18:48:15.258341000+00:00,"{'relationships': {'REL': {}}, 'nodes': {'__AL..."


In [19]:
G = gds.graph.get(NEO4J_GRAPH)

We should have some node properties:

In [20]:
G.node_properties("__ALL__")

['fraudMoneyTransfer', 'moneyTransferErrorCancelAcmount']

Run an algo, just for fun.

In [26]:
gds.pageRank.stream(G)

Unnamed: 0,nodeId,score
0,6700320426250256957,0.277500
1,4446448716690964198,0.929590
2,3709327296075386761,0.961240
3,448599208369598941,0.150000
4,4835484041427496914,0.780385
...,...,...
33727,1051263292924472721,0.150000
33728,3064255076439635539,0.277500
33729,7158911241180865343,0.150000
33730,9015902886036781597,1.018529


## Drop our Graph and Clean Up

Your mother doesn't work here...

In [27]:
G.drop()

In [28]:
gds.graph.list()

Unnamed: 0,degreeDistribution,graphName,database,memoryUsage,sizeInBytes,nodeCount,relationshipCount,configuration,density,creationTime,modificationTime,schema


In [30]:
try:
    gds.graph.get(NEO4J_GRAPH)
except ValueError as e:
    print(e)

No projected graph named 'test' exists


# 👋 That's all Folks!

Questions? Comments? Concerns? `dave<dot>voutila [at] neotechnology.com`