In [None]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, split, col, concat, lit
from pyspark.sql.types import TimestampType
from pyspark.sql.streaming import DataStreamReader
from neo4j import GraphDatabase


# Single-threaded local Spark session shared by every cell in this notebook.
spark = (
    SparkSession.builder
    .master("local")
    .appName("Transactions")
    .config("spark.executor.memory", "1g")
    .config("spark.executor.cores", "1")
    .config("spark.driver.memory", "2g")
    .getOrCreate()
)
# Obtain the streaming reader through the session's public entry point rather
# than instantiating DataStreamReader by hand; the object is the same kind of
# reader, but this form does not depend on the constructor's signature.
StreamReader = spark.readStream

# Neo4j endpoints. URI is the one every later cell connects to; point it at
# URI_local instead to use the localhost bolt endpoint.
URI_container = "neo4j://neo4j:7687"
URI_local = "bolt://localhost:7687"
URI = URI_container

In [None]:
# Create a text index on the key property of every node label to improve
# insert performance. IF NOT EXISTS keeps this cell re-runnable: without it,
# running the cell a second time fails because the index already exists.
index_targets = (
    ("Person", "ssn"),
    ("Transaction", "trans_num"),
    ("Merchant", "merchant"),
    ("Location", "city"),
    ("Account", "acct_num"),
    ("CreditCard", "cc_num"),
)

with GraphDatabase.driver(URI) as driver:
    for node_label, key_property in index_targets:
        # Labels and properties come from the constant table above, so
        # interpolating them into the statement is safe.
        driver.execute_query(
            f"CREATE TEXT INDEX IF NOT EXISTS FOR (n:{node_label}) ON (n.{key_property})"
        )

In [None]:
# Batch-read the pipe-delimited sample file (first line is the header) purely
# to capture its schema and column names.
sample_in = (
    spark.read
    .option("sep", "|")
    .option("header", True)
    .csv("schema.csv")
)
schema, headers = sample_in.schema, sample_in.columns

In [None]:
# Column groups: each list names the CSV columns that make up one node type.
person_fields = [
    "ssn",
    "first",
    "last",
    "gender",
    "job",
    "dob",
]
acct_fields = ["acct_num"]
cc_fields = ["cc_num"]
per_loc_fields = [
    "street",
    "city",
    "state",
    "zip",
    "lat",
    "long",
    "city_pop",
]
# "trans_datetime" is not a CSV column; cast_data derives it from
# trans_date and trans_time.
trans_fields = [
    "trans_num",
    "trans_date",
    "trans_time",
    "amt",
    "trans_datetime",
]
merch_fields = [
    "category",
    "merchant",
    "merch_lat",
    "merch_long",
]

In [None]:
# Directory the streaming CSV reader watches for input files.
path = "Sparkov_Data_Generation/output/"

# Columns that cast_data converts to "date" and to "double" respectively.
date_fields = [
    "trans_date",
    "dob",
]
float_fields = [
    "amt",
    "lat",
    "long",
    "merch_lat",
    "merch_long",
]

def cast_data(df, fields):
    """Cast the date and float columns among ``fields`` to proper types.

    When ``fields`` contains "trans_date", a "trans_datetime" column is first
    built by joining trans_date and trans_time with a space and casting the
    result to a timestamp.

    Columns listed in ``date_fields`` are cast to "date" and columns listed in
    ``float_fields`` to "double". Each cast is done by projecting ``fields``
    (in the given order), so the returned frame holds only ``fields`` whenever
    at least one cast applies; if none applies, the frame keeps all of its
    columns.

    Args:
        df: DataFrame to convert.
        fields: list of column names the caller is interested in.

    Returns:
        The converted DataFrame.
    """
    wanted = set(fields)

    def _project_with_cast(frame, targets, type_name):
        # Select every requested column, casting only those named in targets.
        projection = []
        for name in fields:
            column = col(name)
            projection.append(column.cast(type_name) if name in targets else column)
        return frame.select(projection)

    if "trans_date" in fields:
        stamp_text = concat(col("trans_date"), lit(" "), col("trans_time"))
        df = df.withColumn("trans_datetime", stamp_text)
        df = df.withColumn("trans_datetime", col("trans_datetime").cast(TimestampType()))

    if wanted.intersection(date_fields):
        df = _project_with_cast(df, date_fields, "date")

    if wanted.intersection(float_fields):
        df = _project_with_cast(df, float_fields, "double")

    return df

class StreamWriter(pyspark.sql.DataFrame):
    """Starts Neo4j sink queries (nodes or relationships) from one streaming frame.

    The wrapped streaming DataFrame is stored in ``read_stream`` and is never
    modified by the insert methods, so one writer can start several queries.

    NOTE(review): the class derives from ``pyspark.sql.DataFrame`` but does not
    call the parent initializer, so inherited DataFrame methods are presumably
    not usable on an instance. The base class is kept only so the class
    hierarchy seen by existing callers is unchanged; use ``read_stream`` for
    DataFrame operations.
    """

    def __init__(self, spark_read):
        """Store the streaming DataFrame that every query of this writer reads."""
        self.read_stream = spark_read

    def _data_rows(self):
        """Return the stream without CSV header lines.

        The filter keeps rows whose "first" column is not the literal text
        'first'. It runs before any projection so it does not depend on
        "first" being among the selected fields.
        """
        return self.read_stream.filter("first != 'first'")

    def insert_nodes(self, fields, checkpoint_path, label, key):
        """Start a streaming query that writes ``fields`` as Neo4j nodes.

        Args:
            fields: columns to store as node properties.
            checkpoint_path: checkpoint directory for this query.
            label: node label option passed to the connector (e.g. ":Person").
            key: property used as the node key.

        Returns:
            The object returned by ``writeStream.start()`` so the caller can
            monitor or stop the query.
        """
        # Work on a local frame: cast_data may narrow the columns to `fields`,
        # and storing that back on the writer would break later calls that
        # need other columns.
        typed_rows = cast_data(self._data_rows(), fields)
        return (
            typed_rows.select(fields)
            .dropna(how="any")
            .writeStream
            .format("org.neo4j.spark.DataSource")
            .option("url", URI)
            .option("checkpointLocation", checkpoint_path)
            .option("labels", label)
            .option("node.keys", key)
            # The connector's streaming save-mode option is spelled "save.mode".
            .option("save.mode", "Overwrite")
            .start()
        )

    def insert_relationships(self, fields, checkpoint_path, relationship, labels_keys):
        """Start a streaming query that writes relationships between keyed nodes.

        Args:
            fields: columns holding the source and target key values.
            checkpoint_path: checkpoint directory for this query.
            relationship: relationship type to write.
            labels_keys: dict with "s_label"/"s_key" for the source node and
                "t_label"/"t_key" for the target node.

        Returns:
            The object returned by ``writeStream.start()`` so the caller can
            monitor or stop the query.
        """
        # Header lines are removed here exactly as in insert_nodes; otherwise
        # they would be written as nodes and relationships whose keys are the
        # column names.
        return (
            self._data_rows()
            .select(fields)
            .dropna(how="any")
            .writeStream
            .format("org.neo4j.spark.DataSource")
            .option("relationship", relationship)
            .option("url", URI)
            .option("checkpointLocation", checkpoint_path)
            .option("relationship.save.strategy", "keys")
            .option("relationship.source.labels", labels_keys.get("s_label"))
            .option("relationship.source.save.mode", "Overwrite")
            .option("relationship.source.node.keys", labels_keys.get("s_key"))
            .option("relationship.target.labels", labels_keys.get("t_label"))
            .option("relationship.target.save.mode", "Overwrite")
            .option("relationship.target.node.keys", labels_keys.get("t_key"))
            .start()
        )

# A single streaming CSV source feeds every writer below; each node type and
# each relationship type gets its own writer so its sink can be started
# separately.
csv_reader = StreamReader.csv(path, schema=schema, sep="|")

# Node writers.
(
    per_stream,
    tran_stream,
    merchant_stream,
    loc_stream,
    acct_stream,
    cc_stream,
) = (StreamWriter(csv_reader) for _ in range(6))

# Relationship writers.
(
    sent_stream,
    received_stream,
    resides_stream,
    used_stream,
) = (StreamWriter(csv_reader) for _ in range(4))

In [None]:
# Node sinks: one query per node type, each with its own checkpoint directory.
per_stream.insert_nodes(
    fields=person_fields, checkpoint_path="/tmp/chpt1", label=":Person", key="ssn"
)
tran_stream.insert_nodes(
    fields=trans_fields, checkpoint_path="/tmp/chpt2", label=":Transaction", key="trans_num"
)
merchant_stream.insert_nodes(
    fields=merch_fields, checkpoint_path="/tmp/chpt3", label=":Merchant", key="merchant"
)
loc_stream.insert_nodes(
    fields=per_loc_fields, checkpoint_path="/tmp/chpt4", label=":Location", key="city"
)
acct_stream.insert_nodes(
    fields=acct_fields, checkpoint_path="/tmp/chpt5", label=":Account", key="acct_num"
)
cc_stream.insert_nodes(
    fields=cc_fields, checkpoint_path="/tmp/chpt6", label=":CreditCard", key="cc_num"
)

# Relationship sinks: (source)-[TYPE]->(target), matched on the listed keys.
# (:Person)-[:USED]->(:CreditCard)
used_stream.insert_relationships(
    fields=["ssn", "cc_num"],
    checkpoint_path="/tmp/chpt7",
    relationship="USED",
    labels_keys=dict(
        s_label=":Person", s_key="ssn", t_label=":CreditCard", t_key="cc_num"
    ),
)
# (:CreditCard)-[:SENT]->(:Transaction)
sent_stream.insert_relationships(
    fields=["cc_num", "trans_num"],
    checkpoint_path="/tmp/chpt8",
    relationship="SENT",
    labels_keys=dict(
        s_label=":CreditCard", s_key="cc_num", t_label=":Transaction", t_key="trans_num"
    ),
)
# (:Transaction)-[:RECEIVED]->(:Merchant)
received_stream.insert_relationships(
    fields=["trans_num", "merchant"],
    checkpoint_path="/tmp/chpt9",
    relationship="RECEIVED",
    labels_keys=dict(
        s_label=":Transaction", s_key="trans_num", t_label=":Merchant", t_key="merchant"
    ),
)
# (:Person)-[:RESIDES_IN]->(:Location)
resides_stream.insert_relationships(
    fields=["ssn", "city"],
    checkpoint_path="/tmp/chpt10",
    relationship="RESIDES_IN",
    labels_keys=dict(
        s_label=":Person", s_key="ssn", t_label=":Location", t_key="city"
    ),
)

In [None]:
# Uncomment and run to shut down the Spark session when finished:
# spark.stop()