In [0]:
%pip install neo4j-parallel-spark-loader

In [0]:
import os
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType
from pyspark.sql import functions as F
import requests
from io import BytesIO
from zipfile import ZipFile
from neo4j import Query, GraphDatabase, RoutingControl, Result
from neo4j_parallel_spark_loader.bipartite import group_and_batch_spark_dataframe
from neo4j_parallel_spark_loader import ingest_spark_dataframe
from neo4j_parallel_spark_loader.visualize import create_ingest_heatmap
import time

## Create Spark session
Update the values below (database name, secret scope, and secret keys) to match your environment.

In [0]:
# Publish the Neo4j connection settings as environment variables so that the
# later cells can read them through os.environ. The database name is a literal;
# the username, password, and URI are fetched from the "kv_db" secret scope via
# dbutils.secrets so that no credential is written into the notebook.
_secret_key_by_env_var = {
    'NEO4J_USERNAME': "neo4jUsername",
    'NEO4J_PASSWORD': "neo4jPassword",
    'NEO4J_URI': "neo4jUri",
}

os.environ['NEO4J_DATABASE'] = "neo4j"
for _env_var, _secret_key in _secret_key_by_env_var.items():
    os.environ[_env_var] = dbutils.secrets.get(scope="kv_db", key=_secret_key)

In [0]:
# Number of groups used when splitting relationships for the parallel load;
# it is passed as `num_groups` to both the batching and the ingest calls.
spark_executor_count = 5

In [0]:
# Session-level settings for the Neo4j connector, listed in the order they are
# applied. All values are read from the environment variables set above;
# os.environ.get yields None for any variable that is missing.
neo4j_session_conf = [
    ("neo4j.url", os.environ.get("NEO4J_URI")),
    ("url", os.environ.get("NEO4J_URI")),
    ("neo4j.authentication.basic.username", os.environ.get("NEO4J_USERNAME")),
    ("neo4j.authentication.basic.password", os.environ.get("NEO4J_PASSWORD")),
    ("neo4j.database", os.environ.get("NEO4J_DATABASE")),
]

session_builder = SparkSession.builder.appName("AmazonRatings")
for conf_key, conf_value in neo4j_session_conf:
    session_builder = session_builder.config(conf_key, conf_value)

# Reuse the active session if one exists, otherwise start a new one.
spark = session_builder.getOrCreate()

## Connect to Neo4j

To connect to the database we use the [Neo4j Python Driver](https://neo4j.com/docs/python-manual/5/). The credentials are stored in environment variables, so they can be passed directly to the driver.

In [0]:
# Build the Neo4j driver from the connection settings held in the environment.
# Authentication is supplied as a (username, password) pair.
neo4j_uri = os.environ.get("NEO4J_URI")
neo4j_auth = (
    os.environ.get("NEO4J_USERNAME"),
    os.environ.get("NEO4J_PASSWORD"),
)
driver = GraphDatabase.driver(neo4j_uri, auth=neo4j_auth)

In [0]:
# Connectivity check: count the nodes currently in the database and show the
# result as a pandas DataFrame.
node_count_query = """
    MATCH (n) RETURN COUNT(n) as Count
    """

driver.execute_query(
    node_count_query,
    database_=os.environ['NEO4J_DATABASE'],
    routing_=RoutingControl.READ,
    result_transformer_=Result.to_df,
)

## Download and Transform Data 

In [0]:
# Schema of the edge list: four columns per row, no header line.
schema = StructType([
    StructField("source_id", IntegerType(), True),
    StructField("target_id", IntegerType(), True),
    StructField("rating", FloatType(), True),
    StructField("timestamp", IntegerType(), True)
])

# Download the ZIP archive into memory.
# - `timeout` stops the cell from hanging indefinitely on a stalled connection.
# - `raise_for_status()` reports an HTTP error here, instead of letting an
#   error page fail later as an unrelated-looking BadZipFile.
response = requests.get(
    "https://nrvis.com/download/data/dynamic/rec-amazon-ratings.zip",
    timeout=60,
)
response.raise_for_status()

# Extract the edge file as text. Both the archive and the member are opened as
# context managers so their handles are released once the content is read.
with ZipFile(BytesIO(response.content)) as zip_file:
    with zip_file.open("rec-amazon-ratings.edges") as file:
        content = file.read().decode('utf-8')

# Distribute the lines as an RDD and parse them as CSV using the schema above.
rdd = spark.sparkContext.parallelize(content.splitlines())
rating_df = spark.read.csv(rdd, schema=schema, header=False)

# rating_df is now a Spark DataFrame with the column names and types declared
# in `schema`; preview a few rows and print the schema to verify.
rating_df.show()
rating_df.printSchema()

In [0]:
# Number of rows parsed from the edge file.
rating_df.count()

## Load Nodes

In [0]:
# One row per distinct source_id; these rows are written as :Source nodes.
source_ids = rating_df.select("source_id")
source_df = source_ids.distinct()

In [0]:
# Number of distinct source ids.
source_df.count()

In [0]:
# Preview the first five source ids.
source_df.limit(5).display()

In [0]:
# Ensure a uniqueness constraint exists on :Source(source_id) before loading.
# IF NOT EXISTS makes the statement safe to re-run.
source_constraint_query = """
        CREATE CONSTRAINT source IF NOT EXISTS
        FOR (s:Source) REQUIRE s.source_id IS UNIQUE    
    """

driver.execute_query(
    source_constraint_query,
    database_=os.environ['NEO4J_DATABASE'],
    routing_=RoutingControl.WRITE,
    result_transformer_=Result.to_df,
)

In [0]:
# Write the distinct source ids as :Source nodes through the Neo4j Spark
# connector, keyed on source_id.
source_node_options = {
    "labels": ":Source",
    "node.keys": "source_id:source_id",
    "schema.optimization.node.keys": "KEY",
}

source_writer = source_df.write.format("org.neo4j.spark.DataSource").mode("Overwrite")
source_writer.options(**source_node_options).save()

In [0]:
# One row per distinct target_id; these rows are written as :Target nodes.
target_ids = rating_df.select("target_id")
target_df = target_ids.distinct()

In [0]:
# Number of distinct target ids.
target_df.count()

In [0]:
# Preview the first five target ids.
target_df.limit(5).display()

In [0]:
# Ensure a uniqueness constraint exists on :Target(target_id) before loading.
# IF NOT EXISTS makes the statement safe to re-run.
target_constraint_query = """
    CREATE CONSTRAINT target IF NOT EXISTS
    FOR (t:Target) REQUIRE t.target_id IS UNIQUE
    """

driver.execute_query(
    target_constraint_query,
    database_=os.environ['NEO4J_DATABASE'],
    routing_=RoutingControl.WRITE,
    result_transformer_=Result.to_df,
)

In [0]:
# Write the distinct target ids as :Target nodes through the Neo4j Spark
# connector, keyed on target_id.
target_node_options = {
    "labels": ":Target",
    "node.keys": "target_id:target_id",
    "schema.optimization.node.keys": "KEY",
}

target_writer = target_df.write.format("org.neo4j.spark.DataSource").mode("Overwrite")
target_writer.options(**target_node_options).save()

## Load Relations in Parallel

In [0]:
# Start of the parallel-load timing window (wall-clock seconds since the epoch).
t0 = time.time()

In [0]:
# Prepare the ratings for parallel ingest with the bipartite grouping helper,
# using one group per configured executor.
rel_batch_df = group_and_batch_spark_dataframe(
    spark_dataframe=rating_df,
    source_col='source_id',
    target_col='target_id',
    num_groups=spark_executor_count,
)

In [0]:
# Preview the grouped and batched relationship rows.
rel_batch_df.show()

In [0]:
# Visualize the grouped/batched DataFrame with the loader's heatmap helper.
create_ingest_heatmap(rel_batch_df)

In [0]:
# Display the DataFrame's representation as the cell output.
rel_batch_df

In [0]:
# Cypher executed for each row (`event`): look up the existing :Source and
# :Target nodes, MERGE a RELATES_TO relationship identified by its timestamp,
# then set the rating on it.
query = """
    MATCH (source:Source {source_id: event.source_id})
    MATCH (target:Target {target_id: event.target_id})
    MERGE (source)-[r:RELATES_TO {timestamp:event.timestamp}]->(target)
    SET r.rating = event.rating
    """

# Hand the batched DataFrame to the parallel loader, using the same number of
# groups that was used when the batches were built.
ingest_options = {"query": query}
ingest_spark_dataframe(
    spark_dataframe=rel_batch_df,
    save_mode="Overwrite",
    options=ingest_options,
    num_groups=spark_executor_count,
)

In [0]:
# Elapsed wall-clock seconds since t0 was recorded.
# NOTE(review): t0 is captured before the batching, preview, and heatmap cells,
# so this interval covers those cells (and any pause between running them) in
# addition to the ingest itself.
t1 = time.time()
parallel_load_time = t1-t0
parallel_load_time

In [0]:
# Count the relationships created by the parallel load.
# The pattern is directed on purpose: an undirected pattern ()-[r]-() matches
# every relationship once per direction and would report double the true count.
relationship_count_query = """
    MATCH ()-[r]->()
    RETURN COUNT(r) AS Count
    """

driver.execute_query(
    relationship_count_query,
    database_=os.environ['NEO4J_DATABASE'],
    routing_=RoutingControl.READ,
    result_transformer_=Result.to_df,
)

## Delete Relations

In [0]:
# Remove all relationships (nodes are kept) in batches of 10,000 so that the
# serial load below starts from the same state as the parallel load did.
# The driving query uses a directed pattern: an undirected ()-[r]-() would emit
# each relationship twice, doubling the work and asking later batches to delete
# relationships that an earlier batch has already removed.
delete_relationships_query = """
    CALL apoc.periodic.iterate(
        "MATCH ()-[r]->() RETURN r",
        "DELETE r",
        {batchSize:10000, parallel:false}
    )
    """

driver.execute_query(
    delete_relationships_query,
    database_=os.environ['NEO4J_DATABASE'],
    routing_=RoutingControl.WRITE,
    result_transformer_=Result.to_df,
)

## Load Relations Serially

In [0]:
# Start of the serial-load timing window (wall-clock seconds since the epoch).
t3 = time.time()

In [0]:
# Serial baseline: write the RELATES_TO relationships through the Neo4j Spark
# connector from a single partition.

# Rename the id columns to the "source."/"target." prefixed names referenced by
# the node-key options below, then collapse everything into one partition.
serial_rel_df = (
    rating_df
    .withColumnRenamed('source_id', 'source.source_id')
    .withColumnRenamed('target_id', 'target.target_id')
    .repartition(1)
)

serial_write_options = {
    # Connection settings, read from the environment.
    "url": os.environ.get("NEO4J_URI"),
    "authentication.type": "basic",
    "authentication.basic.username": os.environ.get("NEO4J_USERNAME"),
    "authentication.basic.password": os.environ.get("NEO4J_PASSWORD"),
    "database": os.environ.get("NEO4J_DATABASE"),
    # Relationship type and how its two end nodes are identified.
    "relationship": "RELATES_TO",
    "relationship.source.labels": ":Source",
    "relationship.source.save.mode": "overwrite",
    "relationship.source.node.keys": "source.source_id:source_id",
    "relationship.target.labels": ":Target",
    "relationship.target.save.mode": "overwrite",
    "relationship.target.node.keys": "target.target_id:target_id",
}

serial_writer = serial_rel_df.write.format("org.neo4j.spark.DataSource").mode("overwrite")
serial_writer.options(**serial_write_options).save()

In [0]:
# Elapsed wall-clock seconds since t3 was recorded, i.e. the serial load time.
t4 = time.time()
load_serial_time = t4 - t3
load_serial_time

In [0]:
# Count the relationships created by the serial load.
# The pattern is directed on purpose: an undirected pattern ()-[r]-() matches
# every relationship once per direction and would report double the true count.
relationship_count_query = """
    MATCH ()-[r]->()
    RETURN COUNT(r) AS Count
    """

driver.execute_query(
    relationship_count_query,
    database_=os.environ['NEO4J_DATABASE'],
    routing_=RoutingControl.READ,
    result_transformer_=Result.to_df,
)

In [0]:
# Final cleanup: remove every node, together with its relationships (DETACH),
# in batches of 10,000.
delete_all_nodes_query = """
    CALL apoc.periodic.iterate(
        "MATCH (n) RETURN n",
        "DETACH DELETE n",
        {batchSize:10000, parallel:false}
    )
    """

driver.execute_query(
    delete_all_nodes_query,
    database_=os.environ['NEO4J_DATABASE'],
    routing_=RoutingControl.WRITE,
    result_transformer_=Result.to_df,
)