In [1]:
!pip install jellyfish

Collecting jellyfish
  Downloading jellyfish-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.6 kB)
Downloading jellyfish-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (356 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m356.9/356.9 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: jellyfish
Successfully installed jellyfish-1.2.0


In [1]:
import os
import sys
import multiprocessing

# Resolve the directory two levels above the current working directory and put it first on
# sys.path so the project-local `packages.*` imports below can be resolved.
# NOTE(review): this depends on the kernel's working directory being two levels below the
# repository root — confirm when running the notebook from another location.
project_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    
from functools import reduce

from itertools import combinations
from pyspark.sql import SparkSession
from pyspark.sql.functions import row_number, lit, col, concat_ws, struct, collect_list, expr, explode, monotonically_increasing_id, count
from pyspark.storagelevel import StorageLevel

from packages.utils.dataset_generation import MyDatasets
from packages.utils.spark_udfs import soundex_udf

In [None]:
# ===============================================================================================
# 1. Create a Spark session
cpu_cores = multiprocessing.cpu_count()
print(cpu_cores)

# Settings file (one `key=value` pair per line) that lives inside the project tree.
config_path = os.path.join(project_root, "packages", "conf", "spark.local.conf")

# Start from a local-mode builder; every setting found in the file is layered on top of it.
spark_builder = SparkSession.builder.master("local[*]").appName("LocalTesting")

with open(config_path, 'r') as conf_file:
    for raw_line in conf_file:
        entry = raw_line.strip()
        # Ignore blank lines, '#' comments and anything that is not a key=value pair.
        if not entry or entry.startswith('#') or '=' not in entry:
            continue
        # Only the first '=' separates key from value, so values may contain '=' themselves.
        name, _, setting = entry.partition('=')
        spark_builder = spark_builder.config(name.strip(), setting.strip())

spark = spark_builder.getOrCreate()

# Earlier hand-written configuration, kept for reference only:
# spark = SparkSession.builder.master("local[*]")\
#     .appName("TungstenExample")\
#     .config("spark.sql.tungsten.enabled", "true")\
#     .config("spark.sql.codegen.wholeStage", "true")\
#     .config("spark.default.parallelism", "200")\
#     .config("spark.sql.shuffle.partitions", str(cpu_cores * 3))\
#     .config("spark.storage.memoryFraction", "0.4")\
#     .config("spark.executor.memory", "4g")\
#     .config("spark.driver.memory", "4g")\
#     .config("spark.sql.adaptive.enabled", "true")\
#     .config("spark.sql.adaptive.advisoryPartitionSizeInBytes", "64MB")\
#     .config("spark.sql.adaptive.coalescePartitions.enabled", "true")\
#     .config("spark.sql.adaptive.localShuffleReader.enabled", "true")\
#     .config("spark.sql.shuffle.partitions", "200")\
#     .config("spark.dynamicAllocation.enabled", "true")\
#     .config("spark.dynamicAllocation.minExecutors", "1")\
#     .config("spark.dynamicAllocation.maxExecutors", "10")\
#     .config("spark.dynamicAllocation.executorIdleTimeout", "30s")\
#     .getOrCreate()

# ===============================================================================================

8


In [8]:
# When True, the cells below print their intermediate DataFrames with .show().
DEBUG = False

# Spark-side names of the five attribute columns that blocking keys are built from.
COLUMNS = [f"_c{idx}" for idx in range(1, 6)]

datasets = MyDatasets()

# Active experiment: the 1000-record datasets, with column "0" excluded from the soundex step.
# NOTE(review): the exact meaning of exclude()/soundex() lives in MyDatasets — confirm there.
selected_datasets = datasets.size_1000()
dfs = selected_datasets.exclude(["0"]).soundex()

# Alternative dataset selections (uncomment exactly one to switch experiments):
# dfs = datasets.size_1000().exclude(["0"]).soundex().noise_50().hash_sha256()
# dfs = datasets.size_10000()
# dfs = datasets.size_50000()
# dfs = datasets.size_75000()
# dfs = datasets.size_100000().exclude(["0"]).soundex().noise_600().hash_sha256() # we call all the data here


In [9]:
# ===============================================================================================
# 2. Load data and preprocess them
# Every party's pandas frame becomes a Spark DataFrame with the columns
# (origin, id, _c0, _c1.._c5): `origin` is the party's index in `dfs`, `id` a per-party row number.
processed_dfs = []

for i, df in enumerate(dfs):
    # Work on a private copy so the frames held by `dfs` are never touched and re-running
    # this cell always starts from the same input.
    party_pdf = df[["0", "1", "2", "3", "4", "5"]].copy()
    party_pdf.columns = ["_c0", *COLUMNS]

    # Number the rows on the pandas side instead of with monotonically_increasing_id():
    # that function is documented as non-deterministic (it depends on partition ids), and
    # these DataFrames are re-evaluated by several later actions that join back on
    # (origin, id). Fixed values in the data cannot drift between evaluations.
    party_pdf["id"] = range(len(party_pdf))

    df_tmp = spark.createDataFrame(party_pdf).withColumn("origin", lit(i))

    # Optional per-column encoding on the Spark side (currently disabled):
    # for c in COLUMNS:
    #     df_tmp = df_tmp.withColumn(c, soundex_udf(df_tmp[c]))

    processed_dfs.append(df_tmp.select('origin', 'id', '_c0', *COLUMNS))

# All parties stacked into one DataFrame; union() is positional, which is safe here because
# every frame was projected to the same column order above.
multiparty_datasets = reduce(lambda left, right: left.union(right), processed_dfs)

# ===============================================================================================

# ===============================================================================================
# 3. Create the struct - The struct can be used for any combination of data request we give
# One blocking pass per 3-column combination of COLUMNS. Within a pass, records whose three
# selected column values produce the same key land in the same block.
column_combinations = list(combinations(COLUMNS, 3))

blocking_passes = []
for pass_id, combo in enumerate(column_combinations):

    # NOTE(review): concat_ws drops NULL inputs, so rows with a NULL in different positions
    # can produce the same key — confirm the datasets contain no NULLs in COLUMNS.
    df_with_block_key = multiparty_datasets.withColumn(
        "block_key",
        concat_ws("_", *[col(c) for c in combo])
    )

    # A block is the list of (origin, id) references that share a key in this pass.
    blocked_df = df_with_block_key.groupBy("block_key").agg(
        collect_list(struct('origin', 'id')).alias("records")
    ).withColumn("pass_id", lit(pass_id))

    blocking_passes.append(blocked_df)

# Stack all passes and keep the result on disk so later queries do not rebuild every pass.
multiparty_struct = reduce(
    lambda df1, df2: df1.unionByName(df2), blocking_passes
).persist(StorageLevel.DISK_ONLY)
if DEBUG:
    multiparty_struct.show(truncate=False)
# ===============================================================================================

In [None]:
# ===============================================================================================
# 2. Load data and preprocess them
# Every party's pandas frame becomes a Spark DataFrame with the columns
# (origin, id, _c0, _c1.._c5): `origin` is the party's index in `dfs`, `id` a per-party row number.
processed_dfs = []

for i, df in enumerate(dfs):
    # Work on a private copy so the frames held by `dfs` are never touched.
    party_pdf = df[["0", "1", "2", "3", "4", "5"]].copy()
    party_pdf.columns = ["_c0", *COLUMNS]

    # Number the rows on the pandas side instead of with monotonically_increasing_id():
    # that function is documented as non-deterministic (it depends on partition ids), and the
    # greedy loop below re-evaluates these DataFrames on every pass while matching on
    # (origin, id). Fixed values in the data cannot drift between evaluations.
    party_pdf["id"] = range(len(party_pdf))

    df_tmp = spark.createDataFrame(party_pdf).withColumn("origin", lit(i))

    processed_dfs.append(df_tmp.select('origin', 'id', '_c0', *COLUMNS))

# All parties stacked into one DataFrame (positional union; column order is identical).
multiparty_datasets = reduce(lambda left, right: left.union(right), processed_dfs)

# ===============================================================================================
# 3. Create the struct with greedy record assignment
# Passes run in order; a record that ends up in a block with at least one other record is
# "assigned" and is left out of every later pass.

column_combinations = list(combinations(COLUMNS, 3))

blocking_passes = []
# Empty (origin, id) tracker of the records that have already been assigned.
assigned_records = spark.createDataFrame([], schema=multiparty_datasets.select('origin', 'id').schema)

for pass_id, combo in enumerate(column_combinations):
    # Keep only records not yet assigned
    unassigned_records = multiparty_datasets.join(
        assigned_records, on=['origin', 'id'], how='left_anti'
    )

    # NOTE(review): isEmpty() is an action and nothing built in this loop is persisted, so
    # each pass presumably recomputes all earlier passes — consider persisting blocked_df
    # if this becomes slow on the larger datasets.
    if unassigned_records.isEmpty():
        print(f"All records assigned by pass {pass_id - 1}")
        break  # Stop if no records remain

    # Build block key
    df_with_block_key = unassigned_records.withColumn(
        "block_key",
        concat_ws("_", *[col(c) for c in combo])
    )

    # Singleton blocks are dropped: their record stays unassigned and is retried next pass.
    blocked_df = df_with_block_key.groupBy("block_key").agg(
        collect_list(struct('origin', 'id')).alias("records"),
        count("*").alias("block_size")
    ).filter(col("block_size") > 1).withColumn("pass_id", lit(pass_id))

    blocking_passes.append(blocked_df)

    # Track assigned records
    new_assigned = blocked_df.select(explode('records').alias('record')).select('record.origin', 'record.id')
    assigned_records = assigned_records.union(new_assigned).dropDuplicates(['origin', 'id'])

# Combine all passes
if blocking_passes:
    multiparty_struct = reduce(
        lambda df1, df2: df1.unionByName(df2), blocking_passes
    ).persist(StorageLevel.DISK_ONLY)
    if DEBUG:
        multiparty_struct.show(truncate=False)
else:
    # No pass produced anything: fall back to an empty frame that has the SAME columns as a
    # real struct (block_key, records, block_size, pass_id). The previous fallback used the
    # raw dataset schema, so downstream queries on `records` could not be resolved.
    multiparty_struct = spark.createDataFrame(
        [],
        schema="block_key string, records array<struct<origin:int,id:bigint>>, "
               "block_size bigint, pass_id int",
    )

# ===============================================================================================

In [None]:
# Preview the blocking struct built by whichever struct-building cell ran last.
multiparty_struct.show()

In [10]:
# ===============================================================================================
# 4. The struct is built; now configure one experiment and pull the candidate records.
# `to_party` is the party being linked and `from_parties` are the parties it is linked against.
# Example below: match party 0 against parties 1-4.

to_party = 0
from_parties = (1, 2, 3, 4)

party = processed_dfs[to_party]

# Union only the specified from_parties datasets
ground_truth_datasets = processed_dfs[from_parties[0]]
for i in from_parties[1:]:
    ground_truth_datasets = ground_truth_datasets.union(processed_dfs[i])

# Ground truth: every (to_party row, from_parties row) pair that shares the same `_c0`.
# Rows whose `_c0` is the literal 'fake' are left out of the from-side.
# NOTE(review): presumably `_c0` is the shared entity identifier — confirm against MyDatasets.
ground_truth = party.select('_c0').join(
    ground_truth_datasets
        .filter(col("_c0") != 'fake')
        .select('_c0'),
    on='_c0',
    how='inner'
)
gt_counted = ground_truth.count()

# Convert from_parties to a Spark SQL IN clause string
from_parties_str = ','.join(str(p) for p in from_parties)

# Keep blocks that contain at least one to_party record AND at least one from_parties record,
# then flatten them into individual (origin, id) references.
flattened_df = multiparty_struct.filter(
    expr(f"""
        exists(records, x -> x.origin = {to_party})
        AND exists(records, x -> x.origin IN ({from_parties_str}))
    """)
).select(
    explode(col("records")).alias("record")
).select(
    col("record.origin").alias("origin"),
    col("record.id").alias("id")
)

# A record can appear in several blocks/passes; count it once.
filtered_df = flattened_df.dropDuplicates(["origin", "id"])

# Resolve references back to `_c0` only for the parties that take part in this experiment.
# The list is derived from the configuration above (order kept, repeats removed) rather than
# hard-coded: a matching block can also hold records of parties outside the experiment, and
# those must not be reported as predictions when from_parties is a subset of all parties.
parties_of_interest = tuple(dict.fromkeys((to_party, *from_parties)))

results = []
for i in parties_of_interest:
    # Filter only the relevant (origin, id) pairs for this dataset
    df_filtered = filtered_df.filter(col("origin") == i)

    # Join to get _c0. processed_dfs[i] already carries origin == i from preprocessing.
    joined_df = processed_dfs[i].join(
        df_filtered,
        on=["origin", "id"],
        how="inner"
    ).select("origin", "id", "_c0")

    results.append(joined_df)

final_result = reduce(lambda df1, df2: df1.unionByName(df2), results)
if DEBUG:
    final_result.show()
# ===============================================================================================

In [11]:
# ===============================================================================================
# 5. Development-stage statistics only. This cell should not interfere with the production
# code path.

# Predicted records whose `_c0` also occurs in the ground truth.
ground_truth_ids = ground_truth.distinct().select("_c0")
filtered = final_result.join(ground_truth_ids, on="_c0", how="inner")

total = final_result.count()

gt = gt_counted
tp = filtered.select("_c0").count()
fp = total - tp
fn = gt - tp

# Guard every ratio against an empty denominator; an undefined ratio is reported as 0.
predicted_positives = tp + fp
actual_positives = tp + fn
precision = tp / predicted_positives if predicted_positives else 0
recall = tp / actual_positives if actual_positives else 0

precision_plus_recall = precision + recall
f1_score = 2 * (precision * recall) / precision_plus_recall if precision_plus_recall else 0

print(f"Ground Truth (GT): {gt}")
print(f"True Positives (TP): {tp}")
print(f"False Positives (FP): {fp}")
print(f"False Negatives (FN): {fn}")
print(f"Total Predictions: {total}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}")


# Structural metrics are also very important. To visualize the struct (blocks per pass):
# pass_block_counts = multiparty_struct.groupBy("pass_id").count().orderBy("pass_id")
# pass_block_counts.show(truncate=False)

# ===============================================================================================

Ground Truth (GT): 1000
True Positives (TP): 934
False Positives (FP): 7
False Negatives (FN): 66
Total Predictions: 941
Precision: 0.9926
Recall: 0.9340
F1 Score: 0.9624
