In [10]:
import json
from io import BytesIO
from zipfile import ZipFile

import requests
from neo4j_parallel_spark_loader import ingest_spark_dataframe
from neo4j_parallel_spark_loader.predefined_components import group_and_batch_spark_dataframe
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit
from pyspark.sql.types import StructType, StructField, IntegerType

StatementMeta(medium, 26, 9, Finished, Available, Finished)

## Create Spark session
Update the connection values below to match your Neo4j environment.

In [4]:
# Neo4j connection settings: replace the placeholders with the values for
# your environment. Do not commit real credentials to the notebook.
username = "NEO4J_USER"
password = "NEO4J_PASSWORD"
url = "NEO4J_URL"
dbname = "NEO4J_DATABASE"
# Used as num_groups when the relationship DataFrame is grouped and batched.
spark_executor_count=5

# Session-level Neo4j settings, so the connector reads and writes below do
# not have to repeat the connection options.
spark = (
    SparkSession.builder
    .appName("ReditThreads")
    .config("neo4j.url", url)
    .config("url", url)
    .config("neo4j.authentication.basic.username", username)
    # Use the configured variable; a literal password must never live here.
    .config("neo4j.authentication.basic.password", password)
    .config("neo4j.database", dbname)
    .getOrCreate()
)

StatementMeta(medium, 26, 3, Finished, Available, Finished)

## Download data

In [46]:
# Each row is one edge (source_id -> target_id) inside one thread graph.
schema = StructType([
    StructField("graph_id", IntegerType(), True),
    StructField("source_id", IntegerType(), True),
    StructField("target_id", IntegerType(), True)
])

# Download the archive. Fail fast on HTTP errors and never hang forever:
# without the status check an error page would surface later as an opaque
# BadZipFile exception.
response = requests.get("https://snap.stanford.edu/data/reddit_threads.zip", timeout=120)
response.raise_for_status()

# Read the JSON edge file straight out of the in-memory zip; the context
# managers close both the archive and the member file.
with ZipFile(BytesIO(response.content)) as zip_file:
    with zip_file.open("reddit_threads/reddit_edges.json") as file:
        data = json.load(file)

# The JSON maps graph id -> list of [source, target] pairs; flatten it into
# one [graph_id, source_id, target_id] record per edge.
flattened = [
    [int(graph_id), int(edge[0]), int(edge[1])]
    for graph_id, edges in data.items()
    for edge in edges
]

# Create DataFrame from parsed JSON
reddit_df = spark.createDataFrame(flattened, schema=schema)

# Show the result
reddit_df.show()

StatementMeta(medium, 26, 45, Finished, Available, Finished)

+--------+---------+---------+
|graph_id|source_id|target_id|
+--------+---------+---------+
|       0|        0|        2|
|       0|        1|        5|
|       0|        2|        4|
|       0|        2|        5|
|       0|        2|        6|
|       0|        2|        7|
|       0|        2|        8|
|       0|        2|        9|
|       0|        2|       10|
|       0|        3|        8|
|       1|        0|        3|
|       1|        0|        6|
|       1|        1|        8|
|       1|        2|        8|
|       1|        4|        8|
|       1|        5|        8|
|       1|        6|        8|
|       1|        7|        8|
|       1|        8|        9|
|       1|        8|       10|
+--------+---------+---------+
only showing top 20 rows



In [47]:
# Number of rows in reddit_df (one row per flattened edge record).
reddit_df.count()

StatementMeta(medium, 26, 46, Finished, Available, Finished)

5074915

## Load nodes

In [65]:
# A node is identified by (graph_id, nodeId) and can appear as either end of
# an edge, so stack both endpoint columns under one name and de-duplicate
# the combined result.
node_df = (
    reddit_df.select('graph_id', reddit_df['source_id'].alias('nodeId'))
    .union(reddit_df.select('graph_id', reddit_df['target_id'].alias('nodeId')))
    .distinct()
)
node_df.count()

StatementMeta(medium, 26, 64, Finished, Available, Finished)

4859280

In [66]:
# Write node_df through the Neo4j Spark connector as :Node nodes, keyed on
# the DataFrame columns graph_id and nodeId (stored as graphId and nodeId).
# NOTE(review): "schema.optimization.node.keys" = "KEY" presumably makes the
# connector create a node key constraint on those properties; confirm
# against the connector version in use.
node_write_options = {
    "labels": ":Node",
    "node.keys": "graph_id:graphId,nodeId:nodeId",
    "schema.optimization.node.keys": "KEY",
}
node_writer = node_df.write.format("org.neo4j.spark.DataSource").mode("Overwrite")
node_writer.options(**node_write_options).save()

StatementMeta(medium, 26, 65, Finished, Available, Finished)

## Load rels

In [69]:
# Assign each edge row a group and a batch, partitioning on graph_id, with
# one group per Spark executor.
rel_batch_df = group_and_batch_spark_dataframe(
    spark_dataframe=reddit_df,
    partition_col='graph_id',
    num_groups=spark_executor_count,
)

StatementMeta(medium, 26, 68, Finished, Available, Finished)

In [70]:
# Preview the batched DataFrame, including its group and batch columns.
rel_batch_df.show()

StatementMeta(medium, 26, 69, Finished, Available, Finished)

+--------+---------+---------+-----+-----+
|graph_id|source_id|target_id|group|batch|
+--------+---------+---------+-----+-----+
|   10223|        0|        2|    3|    0|
|   10223|        1|        2|    3|    0|
|   10223|        2|        3|    3|    0|
|   10222|        7|       23|    2|    0|
|   10222|        8|       16|    2|    0|
|   10222|        9|       16|    2|    0|
|   10222|       10|       16|    2|    0|
|   10222|       11|       16|    2|    0|
|   10222|       13|       16|    2|    0|
|   10222|       14|       16|    2|    0|
|   10222|       15|       16|    2|    0|
|   10222|       16|       17|    2|    0|
|   10222|       16|       18|    2|    0|
|   10222|       16|       19|    2|    0|
|   10222|       16|       20|    2|    0|
|   10222|       16|       21|    2|    0|
|   10222|       16|       22|    2|    0|
|   10222|       16|       23|    2|    0|
|   10222|       16|       24|    2|    0|
|   10222|       16|       25|    2|    0|
+--------+-

In [71]:
# Cypher run for each incoming row (exposed as `event`): look up both
# endpoint nodes by (graphId, nodeId) and MERGE the relationship between them.
query = """
    MATCH(source:Node {graphId: event.graph_id, nodeId: event.source_id})
    MATCH(target:Node {graphId: event.graph_id, nodeId: event.target_id})
    MERGE(source)-[r:RELATES_TO]->(target)
    """

# Load the relationships using the group/batch assignment in rel_batch_df.
rel_ingest_options = {"query": query}
ingest_spark_dataframe(
    spark_dataframe=rel_batch_df,
    save_mode="Overwrite",
    options=rel_ingest_options,
)

StatementMeta(medium, 26, 70, Finished, Available, Finished)

## Delete rels

In [72]:
# Size the delete job: one batch per 10,000 edge rows, plus one extra batch
# to cover the remainder.
rel_count = reddit_df.count()
batch_count = 1 + rel_count // 10_000
print(rel_count, batch_count)

StatementMeta(medium, 26, 71, Finished, Available, Finished)

5074915 508


In [73]:
# Driver DataFrame for the batched delete: batch_count rows whose only column
# is a constant "id" of 1. The values are irrelevant; only the number of rows
# matters when this DataFrame is written with a query.
# (`lit` is imported in the notebook's top import cell.)
del_df = spark.range(batch_count).select(lit(1).alias("id"))
print(del_df.count())

StatementMeta(medium, 26, 72, Finished, Available, Finished)

508


In [74]:
# Each execution of this query removes at most 10,000 RELATES_TO relationships.
del_query = """
    MATCH ()-[r:RELATES_TO]->()
    WITH r LIMIT 10000
    DELETE r"""

# Write the dummy rows from a single partition with a batch size of 1.
# NOTE(review): presumably this runs the delete query once per row, each in
# its own transaction; confirm against the connector docs.
delete_writer = (
    del_df.coalesce(1)
    .write.format("org.neo4j.spark.DataSource")
    .mode("Overwrite")
    .options(**{"query": del_query, "batch.size": 1})
)
delete_writer.save()

StatementMeta(medium, 26, 73, Finished, Available, Finished)

## Load rels serially

In [75]:
# Load the same relationships with the plain connector writer, collapsing
# the edge DataFrame to a single partition first.
serial_writer = (
    reddit_df.coalesce(1)
    .write.format("org.neo4j.spark.DataSource")
    .mode("Overwrite")
)
serial_writer.option("query", query).save()

StatementMeta(medium, 26, 74, Finished, Available, Finished)