In [1]:
# Setup Spark and Graph Libraries
!pip install -q pyspark networkx nx_altair

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyspark
  Downloading pyspark-3.3.0.tar.gz (281.3 MB)
[K     |████████████████████████████████| 281.3 MB 46 kB/s 
Collecting nx_altair
  Downloading nx_altair-0.1.6-py3-none-any.whl (7.9 kB)
Collecting py4j==0.10.9.5
  Downloading py4j-0.10.9.5-py2.py3-none-any.whl (199 kB)
[K     |████████████████████████████████| 199 kB 43.8 MB/s 
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.3.0-py2.py3-none-any.whl size=281764026 sha256=baedaa8184245c267f414552b0c326783606f4b8554ed38cb1cd82dd2a92f5c4
  Stored in directory: /root/.cache/pip/wheels/7a/8e/1b/f73a52650d2e5f337708d9f6a1750d451a7349a867f928b885
Successfully built pyspark
Installing collected packages: py4j, pyspark, nx-altair
Successfully installed nx-altair-0.1.6 py4j-0.10.9.5 pyspark-3.3.0


In [2]:
from functools import reduce

import networkx as nx
import nx_altair as nxa
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql.window import Window

# Build the Spark session step by step: local master, eager DataFrame display
# in the notebook, and the GraphFrames package requested via `spark.jars.packages`.
GRAPHFRAMES_PACKAGE = "graphframes:graphframes:0.8.2-spark3.2-s_2.12"

session_builder = SparkSession.builder.master("local[*]").appName("Spark Session")
session_builder = session_builder.config("spark.sql.repl.eagerEval.enabled", True)
session_builder = session_builder.config("spark.jars.packages", GRAPHFRAMES_PACKAGE)
spark = session_builder.getOrCreate()

# Kept after session creation on purpose, mirroring the original statement order
# (presumably the module only becomes importable once Spark has loaded the
# package above; confirm before moving this import to the top of the notebook).
from graphframes import GraphFrame

# Display the session summary as the cell output
spark

# Identity Resolution

In this exercise we'll be using fake customer records to demonstrate the Connected Components algorithm.

Here we've set up DataFrames from three different systems that interact with customers:
1. Loyalty: this could be a CRM or POS system
2. Ecommerce: this could be an online platform like Shopify
3. Subscriber: this could be an ESP like Klaviyo


Across these systems we have two identifiers - `email` and `loyalty_id`. We've generated data quality issues to see how the algorithm can overcome these challenges.

In [3]:
# Toy customer records from three source systems. Every frame exposes the same
# two columns: `email` and `loyalty_id`.
loyalty_records = spark.createDataFrame(
    [
        Row(email="alice@readmail.test", loyalty_id=101),
        Row(email="blake@readmail.test", loyalty_id=102),
        Row(email="clare@readmail.test", loyalty_id=103),
        Row(email="dylan@readmail.test", loyalty_id=104),
        Row(email="ethan@readmail.test", loyalty_id=105),
        Row(email="fiona@readmail.test", loyalty_id=106),
        Row(email="grant@readmail.test", loyalty_id=107),
        Row(email="hanna@readmail.test", loyalty_id=108),
    ],
)

ecommerce_records = spark.createDataFrame(
    [
        Row(email="alice@readmail.test", loyalty_id=101),
        Row(email="blake@readmail.test", loyalty_id=102),
        # Same email address as in loyalty_records, different loyalty_id
        Row(email="ethan@readmail.test", loyalty_id=112),
        # Same email address as in loyalty_records, different loyalty_id
        Row(email="fiona@readmail.test", loyalty_id=212),
        # Known email address, but no loyalty_id captured
        Row(email="grant@readmail.test", loyalty_id=None),
        Row(email="hanna@readmail.test", loyalty_id=None),
    ],
)

subscriber_records = spark.createDataFrame(
    [
        # Different email domain, same loyalty_id
        Row(email="alice@bulkmail.test", loyalty_id=101),
        # Misspelled email username, same loyalty_id
        Row(email="bloke@readmail.test", loyalty_id=102),
        Row(email="dylan@readmail.test", loyalty_id=None),
        # Many loyalty_ids by the same email address
        Row(email="grant@readmail.test", loyalty_id=157),
        Row(email="grant@readmail.test", loyalty_id=257),
        Row(email="grant@readmail.test", loyalty_id=357),
    ],
)

# Stack the three sources into one frame. `unionByName` lines columns up by
# name; plain `union` lines them up by position and would silently swap
# `email` and `loyalty_id` if a source ever declared its fields in another order.
customers = loyalty_records.unionByName(ecommerce_records).unionByName(
    subscriber_records
)

customers

email,loyalty_id
alice@readmail.test,101.0
blake@readmail.test,102.0
clare@readmail.test,103.0
dylan@readmail.test,104.0
ethan@readmail.test,105.0
fiona@readmail.test,106.0
grant@readmail.test,107.0
hanna@readmail.test,108.0
alice@readmail.test,101.0
blake@readmail.test,102.0


# Vertices

Data from all systems is unioned into a single dataframe `customers`. To form our `vertices` table we simply add a row number as the `id` column. This column is necessary for constructing the graph.

Note that we prefer [row_number](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.row_number.html#pyspark.sql.functions.row_number) over [monotonically_increasing_id](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.monotonically_increasing_id.html#pyspark.sql.functions.monotonically_increasing_id) which is non-deterministic.

In [4]:
# Vertices are the customer records plus a row number used as the `id`.
# This id is necessary for constructing the graph.
#
# The same email occurs in several rows, so ordering by `email` alone leaves
# ties whose relative order is not fixed. `vertices` is lazy and is evaluated
# again by later cells, so tied rows could receive different ids each time.
# `loyalty_id` is therefore used as a tie-breaker (nulls last); any rows still
# tied are exact duplicates, so swapping their ids changes nothing.
vertex_order = Window.partitionBy(F.lit(1)).orderBy(
    F.asc("email"), F.asc_nulls_last("loyalty_id")
)

vertices = customers.withColumn("id", F.row_number().over(vertex_order))

vertices

email,loyalty_id,id
alice@bulkmail.test,101.0,1
alice@readmail.test,101.0,2
alice@readmail.test,101.0,3
blake@readmail.test,102.0,4
blake@readmail.test,102.0,5
bloke@readmail.test,102.0,6
clare@readmail.test,103.0,7
dylan@readmail.test,104.0,8
dylan@readmail.test,,9
ethan@readmail.test,105.0,10


# Edges

Edges are defined as relationships between our vertices where an identifier is shared. We'll need to *discover* these relationships ourselves and provide them to GraphFrames as another dataframe.

Here's the method we'll use:
1. Organise the identifiers of each `id` row into a single stack `identifiers`
2. Copy the `identifiers` dataframe as `mirror` and prepend column names with `_`
3. Join `identifiers` with `mirror` wherever the identifiers match but `identifiers.id` < `mirror._id`

In [5]:
# Columns of `vertices` that can tie two records to the same customer.
IDENTIFIER_FIELDS = ["email", "loyalty_id"]

# One frame per identifier field, all with the same four columns:
#   id         - vertex id
#   link_value - the identifier's value, cast to string so that the string
#                `email` and the numeric `loyalty_id` can be stacked in one
#                column without relying on implicit type coercion in the union
#   link_type  - name of the identifier field
#   action     - human-readable label, later shown as the edge tooltip
# A list (not a generator) is used so the frames can be inspected or reused
# after the reduce below has consumed them.
identifier_frames = [
    vertices.select(
        "id",
        F.col(idef).cast("string").alias("link_value"),
        F.lit(idef).alias("link_type"),
        F.lit("Shares " + idef).alias("action"),
    )
    for idef in IDENTIFIER_FIELDS
]

# Stack the per-field frames and drop rows where the identifier is missing:
# a null value cannot link two records.
identifiers = reduce(DataFrame.unionAll, identifier_frames).where(
    F.col("link_value").isNotNull()
)

identifiers

id,link_value,link_type,action
1,alice@bulkmail.test,email,Shares email
2,alice@readmail.test,email,Shares email
3,alice@readmail.test,email,Shares email
4,blake@readmail.test,email,Shares email
5,blake@readmail.test,email,Shares email
6,bloke@readmail.test,email,Shares email
7,clare@readmail.test,email,Shares email
8,dylan@readmail.test,email,Shares email
9,dylan@readmail.test,email,Shares email
10,ethan@readmail.test,email,Shares email


In [6]:
# Copy of `identifiers` in which every column name carries a leading "_",
# so both sides stay distinguishable when the frame is joined with itself.
mirror_columns = [f"_{name}" for name in identifiers.columns]
mirror = identifiers.toDF(*mirror_columns)

mirror

_id,_link_value,_link_type,_action
1,alice@bulkmail.test,email,Shares email
2,alice@readmail.test,email,Shares email
3,alice@readmail.test,email,Shares email
4,blake@readmail.test,email,Shares email
5,blake@readmail.test,email,Shares email
6,bloke@readmail.test,email,Shares email
7,clare@readmail.test,email,Shares email
8,dylan@readmail.test,email,Shares email
9,dylan@readmail.test,email,Shares email
10,ethan@readmail.test,email,Shares email


In [7]:
# An edge connects two different vertices that share the same value for the
# same kind of identifier:
#   * `id < _id` keeps a single edge per pair and excludes self-matches;
#   * `link_type` must match as well as `link_value`, otherwise equal strings
#     from different identifier fields would produce a false edge labelled
#     with only one side's type.
join_conditions = [
    (F.col("id") < F.col("_id"))
    & (F.col("link_type") == F.col("_link_type"))
    & (F.col("link_value") == F.col("_link_value"))
]

# `src`/`dst` hold the two vertex ids; the remaining columns describe the link.
edges = identifiers.join(mirror, join_conditions).select(
    F.col("id").alias("src"),
    F.col("_id").alias("dst"),
    "link_value",
    "link_type",
    "action",
)

edges

src,dst,link_value,link_type,action
2,3,alice@readmail.test,email,Shares email
4,5,blake@readmail.test,email,Shares email
8,9,dylan@readmail.test,email,Shares email
10,11,ethan@readmail.test,email,Shares email
12,13,fiona@readmail.test,email,Shares email
14,18,grant@readmail.test,email,Shares email
14,17,grant@readmail.test,email,Shares email
14,16,grant@readmail.test,email,Shares email
14,15,grant@readmail.test,email,Shares email
15,18,grant@readmail.test,email,Shares email


# Graph

Now we have everything in place to construct the `graph` and call [connectedComponents](https://graphframes.github.io/graphframes/docs/_site/user-guide.html#connected-components).

In [8]:
# Point Spark at a checkpoint directory before connected components is run
spark.sparkContext.setCheckpointDir("/tmp/checkpoints")

# Assemble the graph from the vertex and edge frames
graph = GraphFrame(vertices, edges)

# Each vertex row comes back with an extra `component` column
cc = graph.connectedComponents()

cc

  "DataFrame.sql_ctx is an internal property, and will be removed "


email,loyalty_id,id,component
alice@bulkmail.test,101.0,1,1
alice@readmail.test,101.0,2,1
alice@readmail.test,101.0,3,1
blake@readmail.test,102.0,4,4
blake@readmail.test,102.0,5,4
bloke@readmail.test,102.0,6,4
clare@readmail.test,103.0,7,7
dylan@readmail.test,104.0,8,8
dylan@readmail.test,,9,8
ethan@readmail.test,105.0,10,10


# Visualisation

Connected components has discovered the distinct set of identities in our data. The `component` number stores this for us - it is the minimum `id` of all vertices in that component.

As we're using a toy dataset it's easy enough to visualise this with NetworkX.

In [9]:
# Fixed seed for the spring layout: without it node positions are random and
# the figure changes on every run.
LAYOUT_SEED = 42

G = nx.Graph()

# Nodes: one per vertex id, carrying the remaining columns as attributes.
# `component` is cast to string before it is used as the node colour field.
node_data = (
    cc.withColumn("component", F.col("component").astype("string"))
    .toPandas()
    .set_index("id")
    .to_dict("index")
)
G.add_nodes_from(node_data.items())

# Edges: a pair of vertices can share more than one identifier, so keep a
# single row per (src, dst) pair before indexing on it.
edge_data = (
    edges.toPandas()
    .drop_duplicates(subset=["src", "dst"])
    .set_index(["src", "dst"])
    .to_dict("index")
)
G.add_edges_from((src, dst, attrs) for (src, dst), attrs in edge_data.items())


nxa.draw_networkx(
    G=G,
    pos=nx.spring_layout(G, seed=LAYOUT_SEED),
    node_color="component",
    node_tooltip=["component", "email", "loyalty_id"],
    edge_tooltip=["action"],
).properties(
    title="Identity Graph View",
    height=800,
    width=1200,
).interactive()