In [1]:
!pip install jellyfish

Collecting jellyfish
  Downloading jellyfish-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.6 kB)
Downloading jellyfish-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (356 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m356.9/356.9 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: jellyfish
Successfully installed jellyfish-1.2.0


In [2]:
import os
import sys
import multiprocessing

# The repository root is two directories above the notebook's working directory.
project_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))

# Spark settings file to load; use "spark.production.conf" for production runs.
DEV_ENV = "spark.local.conf"
_conf_dir = os.path.join(project_root, "packages", "conf")
CONFIG_PATH = os.path.join(_conf_dir, DEV_ENV)
OUTPUT_PATH = os.path.join(project_root, "output", "structs")

# Put the repository root first on the import path (only once) so `packages.*` resolves.
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    
from functools import reduce

from itertools import combinations
from pyspark.sql import SparkSession
from pyspark.sql.functions import row_number, lit, col, concat_ws, struct, collect_list, expr, explode, monotonically_increasing_id, count, when, broadcast, size,  avg, stddev
from pyspark.sql.types import StructType, StringType, IntegerType, VarcharType
from pyspark.storagelevel import StorageLevel

from packages.utils.dataset_generation import MyDatasets
from packages.utils.spark_udfs import soundex_udf

In [3]:
# ===============================================================================================
# 1. Create a Spark session


def _iter_conf_pairs(path):
    """Yield (key, value) pairs from a simple ``key=value`` settings file.

    Blank lines, lines starting with ``#`` and lines without ``=`` are ignored.
    Only the first ``=`` separates key from value, so values may themselves contain ``=``.
    Keys and values are stripped of surrounding whitespace.
    """
    with open(path, 'r') as conf_file:
        for raw_line in conf_file:
            entry = raw_line.strip()
            if not entry or entry.startswith('#') or '=' not in entry:
                continue
            name, setting = entry.split('=', 1)
            yield name.strip(), setting.strip()


spark_builder = SparkSession.builder.master("local[*]").appName("LocalTesting")
# Settings are applied in file order, after master/appName.
for conf_key, conf_value in _iter_conf_pairs(CONFIG_PATH):
    spark_builder = spark_builder.config(conf_key, conf_value)

spark = spark_builder.getOrCreate()
# ===============================================================================================

In [4]:
# Build the party datasets through the MyDatasets fluent pipeline. Per the method names this is:
# 1000-record size, exclude ["0"], soundex, noise, SHA-256 hashing. See
# packages.utils.dataset_generation for the exact semantics.
datasets = MyDatasets()
dfs = (
    datasets
    .size_1000()
    .exclude(["0"])
    .soundex()
    .noise_10()
    .hash_sha256()
)

In [7]:
# Export every generated party dataset to its own CSV: <prefix>_A.csv, <prefix>_B.csv, ...
# NOTE(review): the "1000_10" prefix is hard-coded to match size_1000()/noise_10() in the
# generation cell; update it when that chain changes.
export_prefix = "1000_10"
for party_idx, party_frame in enumerate(dfs):
    party_frame.to_csv(f"{export_prefix}_{chr(ord('A') + party_idx)}.csv")


In [11]:
# DEBUG enables the optional `.show()` calls in later cells.
DEBUG = False

# Attribute columns used to build blocking keys ("_c0" is handled separately as the record key).
COLUMNS = [f"_c{n}" for n in range(1, 6)]

# NOTE(review): none of the alternative configurations below is active, so `dfs` is still the
# one built in the dataset-generation cell; this fresh MyDatasets() is not used by later cells.
datasets = MyDatasets()
# Alternative dataset configurations:
# dfs = datasets.size_1000().exclude(["0"]).soundex()
# dfs = datasets.size_1000().exclude(["0"]).soundex().noise_50().hash_sha256()
# dfs = datasets.size_10000()
# dfs = datasets.size_50000()
# dfs = datasets.size_75000()
# dfs = datasets.size_100000().exclude(["0"]).soundex().noise_200().hash_sha256() # full dataset size


In [12]:
# ===============================================================================================
# 2. Load data and preprocess them
# Each party's pandas frame becomes a Spark DataFrame tagged with its party index ("origin")
# and a row id ("id"). All parties are then stacked into one DataFrame for blocking.
processed_dfs = []

for origin_idx, raw_pdf in enumerate(dfs):
    # Keep column "0" (renamed "_c0", the record key the ground truth is joined on later) and
    # the attribute columns "1".."5" (renamed to COLUMNS).
    party_pdf = raw_pdf[["0", "1", "2", "3", "4", "5"]]
    party_pdf.columns = ["_c0", *COLUMNS]

    party_sdf = (
        spark.createDataFrame(party_pdf)
        # NOTE(review): monotonically_increasing_id values depend on partitioning. They only line
        # up with later joins on processed_dfs if re-creating this frame partitions it identically.
        .withColumn("id", monotonically_increasing_id())
        .withColumn("origin", lit(origin_idx))
    )

    # Optional phonetic encoding of the attribute columns (disabled):
    # for c in COLUMNS:
    #     party_sdf = party_sdf.withColumn(c, soundex_udf(party_sdf[c]))

    processed_dfs.append(party_sdf.select("origin", "id", "_c0", *COLUMNS))

# Stack all parties with a positional union (every frame has the same column order).
multiparty_datasets = reduce(lambda left, right: left.union(right), processed_dfs)
# Cache it for two reasons: the blocking step scans it once per column combination, and a
# later cell calls multiparty_datasets.unpersist(), which does nothing unless it was persisted.
multiparty_datasets = multiparty_datasets.persist(StorageLevel.MEMORY_AND_DISK)

# ===============================================================================================

In [13]:
# ===============================================================================================
# 3. Build the blocking struct. It does not depend on the later query, so any combination of
# data requests can be answered from it.
# There is one blocking pass per 3-column combination of COLUMNS. Within a pass, records whose
# "_"-joined key is equal form a block, and only blocks with more than one record are kept.
column_combinations = list(combinations(COLUMNS, 3))


def _blocking_pass(pass_id, combo):
    """Return the blocks of one pass: one row per key, with its (origin, id) records and pass_id.

    NOTE(review): concat_ws skips null inputs, so records with nulls in different columns of
    `combo` can end up with the same key. Confirm the blocking columns are never null.
    """
    key_expr = concat_ws("_", *(col(name) for name in combo))
    return (
        multiparty_datasets
        .withColumn("block_key", key_expr)
        .groupBy("block_key")
        .agg(collect_list(struct("origin", "id")).alias("records"))
        .withColumn("pass_id", lit(pass_id))
        .filter(size(col("records")) > 1)
    )


blocking_passes = [
    _blocking_pass(pass_id, combo)
    for pass_id, combo in enumerate(column_combinations)
]

multiparty_struct = reduce(lambda df1, df2: df1.unionByName(df2), blocking_passes)
multiparty_struct = multiparty_struct.persist(StorageLevel.MEMORY_AND_DISK)
multiparty_struct.count()  # action: materialize the cache now

if DEBUG:
    multiparty_struct.show(truncate=False)
# ===============================================================================================

In [14]:
# ============================================
# Per-pass block metrics
# Each row of multiparty_struct is one block. block_size is the number of records in it.
sized_blocks = multiparty_struct.withColumn("block_size", size(col("records")))

pass_block_counts = multiparty_struct.groupBy("pass_id").count().orderBy("pass_id")
pass_block_counts.show(truncate=False)

block_size_metrics = (
    sized_blocks
    .groupBy("pass_id")
    .agg(
        avg("block_size").alias("average_block_size"),
        stddev("block_size").alias("stddev_block_size"),
        count("*").alias("number_of_blocks"),
    )
    .orderBy("pass_id")
)

block_size_metrics.show(truncate=False)

# ============================================
# Overall block metrics across all passes
# count("*") counts blocks, not matches. The candidate record pairs generated by a block of
# size k are k*(k-1)/2 (exact: k*(k-1) is always even); BIGINT avoids int overflow.
overall_block_metrics = sized_blocks.agg(
    avg("block_size").alias("overall_average_block_size"),
    stddev("block_size").alias("overall_stddev_block_size"),
    count("*").alias("total_blocks"),
    expr("sum(CAST(block_size AS BIGINT) * (block_size - 1) div 2)").alias("total_candidate_pairs"),
)

overall_block_metrics.show(truncate=False)

+-------+-----+
|pass_id|count|
+-------+-----+
|0      |28726|
|1      |21735|
|2      |37717|
|3      |24194|
|4      |40636|
|5      |96293|
|6      |27591|
|7      |49822|
|8      |46654|
|9      |48149|
+-------+-----+

+-------+------------------+------------------+----------------+
|pass_id|average_block_size|stddev_block_size |number_of_blocks|
+-------+------------------+------------------+----------------+
|0      |2.28702220984474  |0.6318635837788373|28726           |
|1      |2.3437313089487004|0.6190506253664175|21735           |
|2      |2.43351804226211  |0.9492441424139563|37717           |
|3      |2.330577829213855 |0.6174697571559731|24194           |
|4      |2.4723151885028054|1.0820396294686985|40636           |
|5      |2.4953942654190855|1.0434780792407468|96293           |
|6      |2.3637055561596174|0.6784951993596732|27591           |
|7      |3.00764722411786  |2.6086555787640253|49822           |
|8      |2.565010502850774 |1.1999464397700417|46654        

In [15]:
# ===============================================================================================
# 4. Query the struct for one experiment: link the records of `to_party` against the records of
# every party in `from_parties`. A block qualifies when it holds at least one record from
# `to_party` and at least one from `from_parties`. Its in-scope records are the predictions.


# multiparty_struct = spark.read.parquet(os.path.join(OUTPUT_PATH, "multiparty_struct"))

to_party = 0
from_parties = (1, 2, 3, 4)
# Parties whose records may appear in the result (order-preserving, no duplicates).
selected_parties = tuple(dict.fromkeys((to_party, *from_parties)))

party = processed_dfs[to_party]

# Union only the from_parties datasets (positional union: identical column order).
ground_truth_datasets = reduce(
    lambda left, right: left.union(right),
    [processed_dfs[p] for p in from_parties],
)

# Ground truth: to_party's keys joined with the non-"fake" keys of the from_parties.
# NOTE(review): this counts joined key pairs. A key present in several from_parties (or repeated
# within one) is counted once per pair.
ground_truth = party.select('_c0').join(
    ground_truth_datasets
        .filter(col("_c0") != 'fake')
        .select('_c0'),
    on='_c0',
    how='inner'
)
gt_counted = ground_truth.count()

# Party ids for the Spark SQL IN clause (coerced to int, so safe to inline).
from_parties_str = ','.join(str(int(p)) for p in from_parties)

flattened_df = (
    multiparty_struct.filter(
        expr(f"""
            exists(records, x -> x.origin = {int(to_party)})
            AND exists(records, x -> x.origin IN ({from_parties_str}))
        """)
    )
    .select(explode(col("records")).alias("record"))
    .select(
        col("record.origin").alias("origin"),
        col("record.id").alias("id")
    )
    # A qualifying block can also contain records from parties outside this experiment.
    # Drop them so they are not counted as predictions.
    .filter(col("origin").isin(list(selected_parties)))
)

# A record can appear in several blocks/passes; keep each (origin, id) once.
filtered_df = flattened_df.dropDuplicates(["origin", "id"])

# Recover the record key (_c0) of each predicted record, party by party.
# processed_dfs[p] already carries origin == p, so there is no need to add it again.
results = [
    processed_dfs[p]
    .join(filtered_df.filter(col("origin") == p), on=["origin", "id"], how="inner")
    .select("origin", "id", "_c0")
    for p in selected_parties
]

final_result = reduce(lambda df1, df2: df1.unionByName(df2), results)

multiparty_datasets.unpersist()

if DEBUG:
    final_result.show()
# ===============================================================================================

In [16]:
# ===============================================================================================
# 5. Development-only statistics: score the predicted records against the ground truth.
# Nothing in production should depend on this cell.


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is 0."""
    return numerator / denominator if denominator != 0 else 0


# Distinct ground-truth keys, and the predicted records whose key is one of them.
gt_keys = ground_truth.distinct().select("_c0")
filtered = final_result.join(gt_keys, on="_c0", how="inner")

total = final_result.count()

gt = gt_counted
tp = filtered.count()
fp = total - tp
# NOTE(review): gt_counted is the row count of a key join (pairs), while tp counts predicted
# records, so fn mixes units and may even go negative. Confirm the intended definition.
fn = gt - tp

precision = _safe_ratio(tp, tp + fp)
recall = _safe_ratio(tp, tp + fn)
f1_score = _safe_ratio(2 * (precision * recall), precision + recall)

print(f"Ground Truth (GT): {gt}")
print(f"True Positives (TP): {tp}")
print(f"False Positives (FP): {fp}")
print(f"False Negatives (FN): {fn}")
print(f"Total Predictions: {total}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}")
# ===============================================================================================

Ground Truth (GT): 100000
True Positives (TP): 96421
False Positives (FP): 48010
False Negatives (FN): 3579
Total Predictions: 144431
Precision: 0.6676
Recall: 0.9642
F1 Score: 0.7889
