In [1]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pyspark.sql.functions as F
import yaml

import graphframes as gf
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

# Move up the directory tree until the working directory contains a `data`
# entry (taken to be the project root); every relative path below depends on it.
while not Path("data").exists():
    if Path.cwd() == Path.cwd().parent:
        # At the filesystem root `os.chdir("..")` is a no-op, so stop with a
        # clear error instead of looping forever.
        raise FileNotFoundError(
            "No 'data' directory found in the working directory or any of its parents."
        )
    os.chdir("..")

# matplotlib >= 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names; use whichever one this installation provides.
plot_style = (
    "seaborn-white" if "seaborn-white" in plt.style.available else "seaborn-v0_8-white"
)
plt.style.use(plot_style)

# Project configuration (dataset paths etc.), relative to the project root.
conf_dict = yaml.safe_load(Path("config/conf.yaml").read_text())

checkpoint_dir = str(Path("spark-checkpoints").absolute())
# NOTE(review): the jar path is tied to a python3.9 virtualenv layout — confirm
# it matches the interpreter running this kernel.
graphframes_jar_path = str(
    Path(
        ".venv/lib/python3.9/site-packages/pyspark/jars/graphframes-0.8.2-spark3.1-s_2.12.jar"
    ).absolute()
)

# Driver memory must be configured before the driver JVM is launched, i.e. on
# the SparkConf used to create the SparkContext. Setting it afterwards on
# SparkSession.builder has no effect once a context is already running.
spark_conf = (
    SparkConf()
    .set("spark.jars", graphframes_jar_path)
    .set("spark.sql.sources.partitionOverwriteMode", "dynamic")
    .set("spark.driver.memory", "8g")
)

# getOrCreate is a classmethod: calling it on the class reuses the active
# context when this cell is re-run, instead of constructing a second
# SparkContext first (which raises ValueError).
sc = SparkContext.getOrCreate(conf=spark_conf)
sc.setCheckpointDir(checkpoint_dir)
sc.setLogLevel("ERROR")

# Wraps the context created above.
spark = SparkSession.builder.getOrCreate()

22/06/25 17:36:49 WARN Utils: Your hostname, domvwt-XPS-13-9305 resolves to a loopback address: 127.0.1.1; using 192.168.0.24 instead (on interface wlp164s0)
22/06/25 17:36:49 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
22/06/25 17:36:49 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/06/25 17:36:50 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [2]:
def _read_configured_parquet(conf_key):
    """Read the parquet dataset whose path is stored under `conf_key` in conf_dict."""
    return spark.read.parquet(conf_dict[conf_key])


# Processed tables and graph artefacts; all locations come from config/conf.yaml.
companies_processed_df = _read_configured_parquet("companies_processed")
relationships_processed_df = _read_configured_parquet("relationships_processed")
persons_processed_df = _read_configured_parquet("persons_processed")
nodes_df = _read_configured_parquet("nodes")
edges_df = _read_configured_parquet("edges")
connected_components = _read_configured_parquet("connected_components")

In [3]:
# Overall size of the full graph, before any component filtering.
node_count = nodes_df.count()
edge_count = edges_df.count()
print(f"Node count: {node_count:,}")
print(f"Edge count: {edge_count:,}")

Node count: 12,716,813
Edge count: 5,704,926


## Graph

In [4]:
# Vertices are the connected-components output (each row carries a `component`
# id); edges are the full edge list.
graph = gf.GraphFrame(v=connected_components, e=edges_df)

In [5]:
# Number of vertices in each connected component, largest component first.
component_sizes = (
    connected_components
    .groupBy("component")
    .count()
    .orderBy(F.col("count").desc())
)
print(f"Connected component count: {component_sizes.count():,}")



Connected component count: 7,648,306


                                                                                

In [14]:
# Largest components with at most 300 vertices (still sorted by size, descending).
component_sizes.where(F.col("count") <= 300).show()



+-----------+-----+
|  component|count|
+-----------+-----+
| 8589988227|  300|
|      16691|  298|
|17179878787|  289|
|      22303|  283|
|      40669|  273|
|       6587|  265|
|      15187|  263|
| 8589934701|  258|
|      12283|  252|
| 8589969829|  247|
|        774|  243|
|       1211|  236|
| 8589960992|  234|
|       9278|  231|
|      51040|  227|
|17179910232|  207|
|      38111|  207|
|17179898898|  206|
| 8589982438|  205|
| 8589941079|  203|
+-----------+-----+
only showing top 20 rows



                                                                                

In [6]:
# Components with at least this many vertices are kept for the graph analysis.
MIN_COMPONENT_SIZE = 10

# Cached because it is counted here and collected again in the next cell; the
# result is small (one row per qualifying component).
large_components = component_sizes.filter(
    F.col("count") >= MIN_COMPONENT_SIZE
).cache()
# Count once: a bare `.count()` that is not the cell's last expression runs a
# full Spark job and discards the result.
print(f"Large component count: {large_components.count():,}")



Large component count: 12,039


                                                                                

In [7]:
# Bring the ids of the large components to the driver (~12k values per the
# output above) and keep only vertices that belong to one of them.
# NOTE(review): `isin` over a long literal list builds a large expression; a
# semi-join may scale better if the threshold is lowered.
large_component_ids = [
    component_id
    for (component_id,) in large_components.select("component").collect()
]
in_large_component = F.col("component").isin(large_component_ids)

# Drop vertices left without edges, then cache: the cells below reuse the
# filtered graph's vertices and edges several times.
graph_filtered = (
    graph
    .filterVertices(in_large_component)
    .dropIsolatedVertices()
    .cache()
)

                                                                                

In [8]:
# Persist the edges of the filtered graph, replacing any previous output.
edges_filtered_df = graph_filtered.edges
(
    edges_filtered_df.write
    .mode("overwrite")
    .parquet("data/graph/component-edges.parquet")
)

                                                                                

In [9]:
# Vertex counts of the filtered graph, split by the `isCompany` flag.
vertex_type_counts = graph_filtered.vertices.groupBy(F.col("isCompany")).count()
vertex_type_counts.show()

                                                                                

+---------+------+
|isCompany| count|
+---------+------+
|     true|229736|
|    false| 37120|
+---------+------+



In [10]:
# NOTE(review): PageRank is currently disabled; the node-feature cell below has a
# matching commented-out `.join(nodes_pageranked, ["id"])`. If re-enabled, note
# that `.cast("Long")` truncates the fractional PageRank score to an integer —
# confirm that is intended.
# graph_pageranked = graph_filtered.pageRank(resetProbability=0.1, maxIter=20)
# nodes_pageranked = graph_pageranked.vertices.select("id", F.col("pagerank").cast("Long"))

In [11]:
# Graph features added to each vertex of the filtered graph.
GRAPH_FEATURE_COLUMNS = ["inDegree", "outDegree", "triangleCount"]

vertex_triangle_counts = (
    graph_filtered.triangleCount()
    .withColumnRenamed("count", "triangleCount")
    .select("id", "triangleCount")
)

nodes_filtered_df = (
    graph_filtered.vertices.join(graph_filtered.inDegrees, ["id"], how="left")
    .join(graph_filtered.outDegrees, ["id"], how="left")
    .join(vertex_triangle_counts, ["id"])
    # .join(nodes_pageranked, ["id"])
    # Vertices without incoming (outgoing) edges get no inDegrees (outDegrees)
    # row, so the left joins leave nulls there; treat those as 0. The fill is
    # limited to the feature columns so that nulls in the vertex attribute
    # columns are not silently overwritten with 0.
    .fillna(0, subset=GRAPH_FEATURE_COLUMNS)
)
nodes_filtered_df.write.parquet("data/graph/component-nodes.parquet", mode="overwrite")

                                                                                