In [None]:
import os
from dotenv import load_dotenv
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from neo4j import Query, GraphDatabase, RoutingControl, Result

In [None]:
# pandas is used for the small previews produced by .toPandas() and Result.to_df().
import pandas as pd
# Show every column when a pandas DataFrame is rendered (no column truncation).
pd.set_option("display.max_columns", None)

## Setup Spark

In [None]:
# Path of the dotenv file holding the Neo4j connection settings and the input file location.
env_file = '.env'

In [None]:
# Load the dotenv file when it is present; otherwise fall back to whatever is
# already exported in the process environment.
if os.path.exists(env_file):
    load_dotenv(env_file, override=True)
else:
    print(f"File {env_file} not found.")

# The settings are read unconditionally so these names always exist. Before,
# a missing dotenv file left them undefined and later cells failed with
# NameError even when the variables were exported in the environment.

# Neo4j
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
NEO4J_DATABASE = os.getenv('NEO4J_DATABASE')

# Files
COMPANIES_FILE = os.getenv('COMPANIES_FILE')

# Report unset settings here instead of as an obscure failure in a later cell.
_unset_settings = [
    setting_name
    for setting_name in (
        'NEO4J_URI',
        'NEO4J_USERNAME',
        'NEO4J_PASSWORD',
        'NEO4J_DATABASE',
        'COMPANIES_FILE',
    )
    if os.getenv(setting_name) is None
]
if _unset_settings:
    print(f"Environment variables not set: {', '.join(_unset_settings)}")

In [None]:
# Maven coordinates of the Neo4j Spark connector, fetched through spark.jars.packages.
NEO4J_CONNECTOR = "org.neo4j:neo4j-connector-apache-spark_2.12:5.3.10_for_spark_3"

# Local Spark session using all cores. The neo4j.* entries are set at session
# level so the reads/writes below do not repeat the connection options.
# NOTE(review): the bare "url" key is set next to "neo4j.url"; confirm whether
# anything actually reads it.
spark_builder = SparkSession.builder.appName("CompanyHouse").master("local[*]")
for conf_key, conf_value in [
    ("spark.jars.packages", NEO4J_CONNECTOR),
    ("url", NEO4J_URI),
    ("neo4j.url", NEO4J_URI),
    ("neo4j.authentication.basic.username", NEO4J_USERNAME),
    ("neo4j.authentication.basic.password", NEO4J_PASSWORD),
    ("neo4j.database", NEO4J_DATABASE),
]:
    spark_builder = spark_builder.config(conf_key, conf_value)
spark = spark_builder.getOrCreate()

In [None]:
# Display the SparkSession object as the cell output.
spark

In [None]:
# Ask the JVM for the Scala version string.
# NOTE(review): presumably a check against the "_2.12" connector build above — confirm.
spark.sparkContext._jvm.scala.util.Properties.versionString()

In [None]:
# Read back the spark.jars.packages value from the active Spark configuration.
spark.sparkContext.getConf().get("spark.jars.packages")

## Read Company Data

In [None]:
# Read the companies CSV with its header row. Schema inference is switched off
# so every column is a string: identifiers such as company numbers keep their
# leading zeros whatever the file (or sample) contains, and Spark does not need
# an extra full pass over the file to guess types. The cells below only apply
# string operations (trim / upper / regexp / hashing) to these columns.
companies_raw = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "false")
    .csv(COMPANIES_FILE)
)

In [None]:
# Number of rows read from the CSV (this is a Spark action, so it scans the file).
companies_raw.count()

In [None]:
# Print the column names and types of the raw frame.
companies_raw.printSchema()

In [None]:
# Preview the first five raw rows as a pandas DataFrame.
companies_raw.limit(5).toPandas()

In [None]:
# Project the raw CSV onto the properties of the Company node: the company
# number is trimmed and upper-cased, the remaining columns are only renamed.
# Note that two source headers (" CompanyNumber", " ConfStmtLastMadeUpDate")
# carry a leading space. Rows without a company number are discarded and one
# row is kept per company number.
companies_df = (
    companies_raw
    .select(
        F.upper(F.trim(F.col(" CompanyNumber"))).alias("company_number"),
        *[
            F.col(source_name).alias(property_name)
            for source_name, property_name in [
                ("CompanyName", "name"),
                ("CompanyStatus", "status"),
                ("CompanyCategory", "category"),
                ("CountryOfOrigin", "country_of_origin"),
                ("IncorporationDate", "incorporation_date"),
                ("DissolutionDate", "dissolution_date"),
                ("URI", "uri"),
                ("ConfStmtNextDueDate", "conf_stmt_next_due_date"),
                (" ConfStmtLastMadeUpDate", "conf_stmt_last_made_up_date"),
            ]
        ],
    )
    .filter(F.col("company_number").isNotNull() & (F.length("company_number") > 0))
    .drop_duplicates(subset=["company_number"])
)

In [None]:
# Preview the first five cleaned company rows as a pandas DataFrame.
companies_df.limit(5).toPandas()

In [None]:
def norm(colname):
    """Return a Column for ``colname`` with NULL replaced by "", trimmed and upper-cased.

    Used to normalise the address fields before they are hashed into an ID.
    """
    blank_when_null = F.coalesce(F.col(colname), F.lit(""))
    return F.upper(F.trim(blank_when_null))

# One row per raw CSV row: the normalised company number, the registered
# address fields under snake_case names, and an address_id that is the SHA-256
# of the normalised fields joined with "||".
# NOTE(review): because norm() maps NULL to "", every row whose address fields
# are all empty gets the same address_id — confirm that a single shared
# "blank" address is intended.
has_address_df = (
    companies_raw
    .select(
        F.upper(F.trim(F.col(" CompanyNumber"))).alias("company_number"),
        *[
            F.col(source_name).alias(field_name)
            for source_name, field_name in [
                ("`RegAddress.CareOf`", "care_of"),
                ("`RegAddress.POBox`", "po_box"),
                ("`RegAddress.AddressLine1`", "address_line_1"),
                ("` RegAddress.AddressLine2`", "address_line_2"),
                ("`RegAddress.PostTown`", "post_town"),
                ("`RegAddress.County`", "county"),
                ("`RegAddress.Country`", "country"),
                ("`RegAddress.PostCode`", "post_code"),
            ]
        ],
    )
    .withColumn(
        "address_id",
        F.sha2(
            F.concat_ws(
                "||",
                *[
                    norm(field_name)
                    for field_name in [
                        "care_of",
                        "po_box",
                        "address_line_1",
                        "address_line_2",
                        "post_town",
                        "county",
                        "country",
                        "post_code",
                    ]
                ],
            ),
            256,
        ),
    )
    # Reorder so the two key columns come first.
    .select(
        "company_number", "address_id",
        "care_of", "po_box",
        "address_line_1", "address_line_2",
        "post_town", "county", "country", "post_code",
    )
)

In [None]:
# Preview the first five company/address rows as a pandas DataFrame.
has_address_df.limit(5).toPandas()

In [None]:
# Unpivot the four SIC text columns into one row per (company, SIC entry),
# then split each entry into its leading 4-5 digit code and the text that
# follows the "<code> - " prefix. Blank entries and entries without a leading
# numeric code are dropped.
has_sic_df = (
    companies_raw
    .select(
        F.upper(F.trim(F.col(" CompanyNumber"))).alias("company_number"),
        *[
            F.col(f"`SICCode.SicText_{slot}`").alias(f"sic{slot}")
            for slot in range(1, 5)
        ],
    )
    .selectExpr("company_number", "stack(4, sic1, sic2, sic3, sic4) as sic_raw")
    .filter(F.col("sic_raw").isNotNull() & (F.length(F.trim(F.col("sic_raw"))) > 0))
    .withColumn("sic_code", F.regexp_extract(F.col("sic_raw"), r"^\s*([0-9]{4,5})", 1))
    .withColumn("sic_text", F.trim(F.regexp_replace(F.col("sic_raw"), r"^\s*[0-9]{4,5}\s*-\s*", "")))
    # regexp_extract yields "" when nothing matched; turn that into NULL so the
    # next filter removes the row.
    .withColumn("sic_code", F.when(F.length(F.col("sic_code")) == 0, F.lit(None)).otherwise(F.col("sic_code")))
    .filter(F.col("sic_code").isNotNull())
    .select("company_number", "sic_code", F.col("sic_text").alias("description"))
)

In [None]:
# Address node table: the address columns only, one row per address_id.
address_df = (
    has_address_df
    .select(
        "address_id", "care_of", "po_box",
        "address_line_1", "address_line_2",
        "post_town", "county", "country", "post_code",
    )
    .drop_duplicates(subset=["address_id"])
)

In [None]:
# Preview the first five distinct addresses as a pandas DataFrame.
address_df.limit(5).toPandas()

In [None]:
# Edge table for Company -> Address: distinct (company_number, address_id)
# pairs in which both keys are present.
company_registered_at_df = (
    has_address_df
    .select("company_number", "address_id")
    .filter(F.col("address_id").isNotNull() & F.col("company_number").isNotNull())
    .distinct()
)

In [None]:
# Preview the first five company/address key pairs as a pandas DataFrame.
company_registered_at_df.limit(5).toPandas()

In [None]:
# SIC node table: one row per sic_code. The company_number column is still
# present here; it is dropped when the frame is written.
sic_df = has_sic_df.drop_duplicates(subset=["sic_code"])

In [None]:
# Preview the first five distinct SIC codes as a pandas DataFrame.
sic_df.limit(5).toPandas()

In [None]:
# Edge table for Company -> SIC: distinct (company_number, sic_code) pairs in
# which both keys are present.
company_has_sic_df = (
    has_sic_df
    .select("company_number", "sic_code")
    .filter(F.col("sic_code").isNotNull() & F.col("company_number").isNotNull())
    .distinct()
)

In [None]:
# Preview the first five company/SIC key pairs as a pandas DataFrame.
company_has_sic_df.limit(5).toPandas()

## Load data

### Connection to Neo4j

In [None]:
# Look at the company rows once more before they are written to Neo4j.
companies_df.limit(5).toPandas()

In [None]:
# Write one :Company node per row through the Neo4j Spark connector; the
# connection settings come from the neo4j.* session configuration.
# NOTE(review): the save mode is "Append" — confirm in the connector docs
# whether re-running this cell creates duplicate nodes instead of updating.
(
    companies_df
    .write
    .format("org.neo4j.spark.DataSource")
    .options(**{"labels": ":Company", "node.keys": "company_number"})
    .mode("Append")
    .save()
)

In [None]:
# Look at the address rows once more before they are written to Neo4j.
address_df.limit(5).toPandas()

In [None]:
# Write one :Address node per row, keyed by address_id.
# NOTE(review): the save mode is "Append" — confirm in the connector docs
# whether re-running this cell creates duplicate nodes instead of updating.
(
    address_df
    .select(
        "address_id", "care_of", "po_box",
        "address_line_1", "address_line_2",
        "post_town", "county", "country", "post_code",
    )
    .write
    .format("org.neo4j.spark.DataSource")
    .options(**{"labels": ":Address", "node.keys": "address_id"})
    .mode("Append")
    .save()
)

In [None]:
# Look at the SIC rows once more before they are written to Neo4j.
sic_df.limit(5).toPandas()

In [None]:
# Write one :SIC node per row with only the code and its description
# (the company_number column carried by sic_df is left out).
# NOTE(review): the save mode is "Append" — confirm in the connector docs
# whether re-running this cell creates duplicate nodes instead of updating.
(
    sic_df
    .select("sic_code", "description")
    .write
    .format("org.neo4j.spark.DataSource")
    .options(**{"labels": ":SIC", "node.keys": "sic_code"})
    .mode("Append")
    .save()
)

In [None]:
# Neo4j Python driver used for the Cypher statements below (counts, constraints, URI fix).
# NOTE(review): the driver is not closed anywhere in the visible notebook.
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

In [None]:
# Count all nodes in the database and show the result as a pandas DataFrame.
driver.execute_query(
    """
    MATCH (n) RETURN COUNT(n) as Count
    """,
    result_transformer_=Result.to_df,
    routing_=RoutingControl.READ,
    database_=NEO4J_DATABASE,
)

In [None]:
# Uniqueness constraint on Company.company_number; IF NOT EXISTS means the
# statement does not fail when the constraint is already there.
driver.execute_query(
    """
        CREATE CONSTRAINT company_number IF NOT EXISTS FOR (c:Company) REQUIRE c.company_number IS UNIQUE
    """,
    result_transformer_=Result.to_df,
    routing_=RoutingControl.WRITE,
    database_=NEO4J_DATABASE,
)

In [None]:
# Uniqueness constraint on Address.address_id; IF NOT EXISTS means the
# statement does not fail when the constraint is already there.
driver.execute_query(
    """
        CREATE CONSTRAINT address_id IF NOT EXISTS FOR (a:Address) REQUIRE a.address_id IS UNIQUE
    """,
    result_transformer_=Result.to_df,
    routing_=RoutingControl.WRITE,
    database_=NEO4J_DATABASE,
)

In [None]:
# Uniqueness constraint on SIC.sic_code; IF NOT EXISTS means the statement
# does not fail when the constraint is already there.
driver.execute_query(
    """
        CREATE CONSTRAINT sic_code IF NOT EXISTS FOR (s:SIC) REQUIRE s.sic_code IS UNIQUE
    """,
    result_transformer_=Result.to_df,
    routing_=RoutingControl.WRITE,
    database_=NEO4J_DATABASE,
)

In [None]:
# List the constraints currently defined, as a pandas DataFrame.
driver.execute_query(
    """
    SHOW CONSTRAINTS
    """,
    result_transformer_=Result.to_df,
    routing_=RoutingControl.READ,
    database_=NEO4J_DATABASE,
)

In [None]:
# Write the REGISTERED_AT relationships from :Company to :Address. With the
# "keys" strategy both end nodes are looked up by the listed key columns
# (dataframe column:node property). The frame is collapsed to one partition
# before writing.
# NOTE(review): presumably a single partition avoids concurrent writers
# locking the same nodes — confirm.
(
    company_registered_at_df
    .repartition(1)
    .write
    .format("org.neo4j.spark.DataSource")
    .options(**{
        "relationship": "REGISTERED_AT",
        "relationship.save.strategy": "keys",
        "relationship.source.labels": ":Company",
        "relationship.source.node.keys": "company_number:company_number",
        "relationship.target.labels": ":Address",
        "relationship.target.node.keys": "address_id:address_id",
    })
    .mode("Append")
    .save()
)

In [None]:
# Write the HAS_SIC relationships from :Company to :SIC. With the "keys"
# strategy both end nodes are looked up by the listed key columns
# (dataframe column:node property). The frame is collapsed to one partition
# before writing.
# NOTE(review): presumably a single partition avoids concurrent writers
# locking the same nodes — confirm.
(
    company_has_sic_df
    .repartition(1)
    .write
    .format("org.neo4j.spark.DataSource")
    .options(**{
        "relationship": "HAS_SIC",
        "relationship.save.strategy": "keys",
        "relationship.source.labels": ":Company",
        "relationship.source.node.keys": "company_number:company_number",
        "relationship.target.labels": ":SIC",
        "relationship.target.node.keys": "sic_code:sic_code",
    })
    .mode("Append")
    .save()
)

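### Correcting URI

Rewrite each company's `uri` property so that it points at the Companies House "Find and update company information" service.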

In [None]:
# Rewrite Company.uri to the find-and-update service URL, in batches of 10,000
# via apoc.periodic.iterate, and show the procedure's summary row as a pandas
# DataFrame.
#
# Two fixes over the earlier version of this statement:
#   * The split marker now includes the trailing slash ('gov.uk/id/'). Splitting
#     on 'gov.uk/id' left a leading '/' on the suffix, so the result contained
#     '...service.gov.uk//company/...'.
#   * Only nodes whose uri still contains the marker are touched. Without that
#     guard a second run found no marker, split(...)[1] evaluated to null, and
#     SET c.uri = null removed the property.
# NOTE(review): assumes source URIs look like
# http://business.data.gov.uk/id/company/<number> — confirm against the data.
driver.execute_query(
    """
    CALL apoc.periodic.iterate(
      "MATCH (c:Company) WHERE c.uri CONTAINS 'gov.uk/id/' RETURN c",
      "SET c.uri = 'https://find-and-update.company-information.service.gov.uk/' + split(c.uri, 'gov.uk/id/')[1]",
      {batchSize:10000, parallel:true})
    """,
    database_=NEO4J_DATABASE,
    routing_=RoutingControl.WRITE,
    result_transformer_= lambda r: r.to_df()
)