In [8]:
from pyspark.sql.functions import lit
from pyspark.sql import SparkSession, functions as F
import pandas as pd
from packages.utils.transformations import MyTransformation as mtf
from packages.utils.spark_udfs import soundex_udf
from pyspark.sql.functions import col, array_contains


# Build the notebook's Spark session in local mode (getOrCreate reuses a live session if one exists).
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("example2")
    .getOrCreate()
)

# Read and rename columns as before
def read_with_origin(path, origin):
    """Load the first six columns of a headerless CSV and tag every row with its source.

    Uses the module-level ``spark`` session.

    Args:
        path: Location of the CSV file, handed unchanged to ``spark.read.csv``.
        origin: Value stored (as a literal) in the added ``origin`` column.

    Returns:
        A Spark DataFrame whose columns are ``"0"`` .. ``"5"`` followed by ``"origin"``.
    """
    source_names = [f"_c{idx}" for idx in range(6)]
    target_names = [str(idx) for idx in range(6)]
    raw = spark.read.csv(path, header=None, inferSchema=True)
    # Keep only the six leading positional columns, then rename them all in one step.
    renamed = raw.select(source_names).toDF(*target_names)
    return renamed.withColumn("origin", lit(origin))

In [None]:
transformer = mtf()
match_column = 0
cols = ['1', '2', '3', '4', '5']

# Reference counts that the metrics cell prints next to its own results.
_expected = {"gt": 2531, "tp":1775, "fp":92, "fn":756}


def _load_sample(path):
    """Read the first six columns of a headerless CSV and keep a reproducible 10% sample."""
    first_six = [0, 1, 2, 3, 4, 5]
    frame = pd.read_csv(path, usecols=first_six, header=None)[first_six]
    return frame.sample(frac=0.1, random_state=42)


df1 = _load_sample("data/df1.csv")
df2 = _load_sample("data/df2.csv")

# Route both calls through `match_column` instead of repeating the literal 0.
# The return value is ignored, as before.
# NOTE(review): presumably apply_soundex rewrites the match column in place — confirm.
transformer.apply_soundex(df1, match_column=match_column)
transformer.apply_soundex(df2, match_column=match_column)

# Tag every row with the dataset it came from.
df1['origin'] = 1
df2['origin'] = 2

# Later code addresses these columns by the string names "0" .. "5".
spark_df1 = spark.createDataFrame(df1)
spark_df2 = spark.createDataFrame(df2)

concatData = spark_df1.union(spark_df2)
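# NOTE(review): disabled experiment. `trimmer_udf` is not imported in the cells shown here, and if this
# is re-enabled it has to run before `concatData` is built so the union sees the trimmed columns.
# for col_name in spark_df1.columns:
#     if str(col_name) != str(match_column):
#         spark_df1 = spark_df1.withColumn(str(col_name), trimmer_udf(spark_df1[col_name]))
#         spark_df2 = spark_df2.withColumn(str(col_name), trimmer_udf(spark_df2[col_name]))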


# Ground truth: ids that occur in both datasets.
# (`cols` is already defined at the top of this cell, so it is not re-declared here.)
ground_truth_ids = (
    spark_df1
    .join(spark_df2, on=[str(match_column)], how="inner")
    .select(F.col(str(match_column)).alias("df1_id"))
    .distinct()
)
# Cached because the metrics cell joins against it again after this count.
ground_truth_ids.cache()
gt = ground_truth_ids.count()

def _with_entity_key(frame, id_alias):
    """Project `frame` to (id renamed to `id_alias`, entity_key, *cols).

    `entity_key` is the attribute columns in `cols` concatenated with an empty separator.
    The id column is looked up via `match_column`, so it is no longer hard-coded.
    """
    entity_key = F.concat_ws("", *[F.col(name) for name in cols])
    return frame.select(
        F.col(str(match_column)).alias(id_alias),
        entity_key.alias("entity_key"),
        *cols
    )


# 1. Create entity keys for df1
df1_entities = _with_entity_key(spark_df1, "df1_id")

# 2. Create entity keys for the bucket (use spark_df2 or concatData)
bucket_entities = _with_entity_key(spark_df2, "bucket_id")

# 3. Pair every df1 row with every bucket row and count the attribute columns that agree
# NOTE(review): under Spark's NULL semantics a NULL in any compared column presumably makes
# match_count NULL, so such a pair would never pass the filter below — confirm this is intended.
match_exprs = [
    (F.col(f"a.{name}") == F.col(f"b.{name}")).cast("int")
    for name in cols
]
joined = (
    df1_entities.alias("a")
    .crossJoin(bucket_entities.alias("b"))
    .withColumn("match_count", sum(match_exprs))
)

# 4. Keep only the pairs that agree on at least 3 columns
filtered = joined.filter(F.col("match_count") >= 3)

# 5. Highest match_count reached by each df1_id
# The id is renamed so the join condition below can address it separately from a.df1_id.
max_match = (
    filtered
    .groupBy("a.df1_id")
    .agg(F.max("match_count").alias("max_match_count"))
    .withColumnRenamed("df1_id", "max_df1_id")
)

# 6. Keep every (df1_id, bucket_id) pair whose score equals that df1_id's maximum
same_id = filtered["a.df1_id"] == max_match["max_df1_id"]
is_top_score = filtered["match_count"] == max_match["max_match_count"]
best_matches = filtered.join(max_match, same_id & is_top_score).select(
    F.col("a.df1_id").alias("df1_id"),
    F.col("b.bucket_id").alias("bucket_id"),
    F.col("match_count")
)

# 7. One row per bucket_id, carrying the list of df1_ids assigned to it
assigned_ids = F.collect_list("df1_id").alias("assigned_df1_ids")
buckets = best_matches.groupBy("bucket_id").agg(assigned_ids)

# Cached because the metrics cell and the export cell both read `buckets`.
buckets.cache()

  for column, series in pdf.iteritems():
  for column, series in pdf.iteritems():


In [16]:
# True Positives
# A bucket is a hit when its own id appears among the df1 ids assigned to it.
own_id_assigned = array_contains(col("assigned_df1_ids"), col("bucket_id"))

tp_arr = buckets.filter(own_id_assigned).join(
    ground_truth_ids,
    buckets.bucket_id == ground_truth_ids.df1_id,
    how="inner",
)
tp = tp_arr.count()

# False positives: buckets whose own id is missing from their assignment list.
fp = buckets.filter(~own_id_assigned).count()
# False negatives: ground-truth ids that were not recovered.
fn = gt - tp

predicted_total = tp + fp
actual_total = tp + fn
precision = tp / predicted_total if predicted_total > 0 else 0
recall = tp / actual_total if actual_total > 0 else 0

print(_expected)
print(f"Ground Truth (gt): {gt}")
print(f"True Positives (tp): {tp}")
print(f"False Positives (fp): {fp}")
print(f"False Negative (fn): {fn}")
print(f"precision: {precision}")
print(f"Recall: {recall}")

{'gt': 5, 'tp': 4, 'fp': 2, 'fn': 1}
Ground Truth (gt): 5
True Positives (tp): 4
False Positives (fp): 2
False Negative (fn): 1
precision: 0.6666666666666666
Recall: 0.8


In [13]:
# Persist the bucket assignments as Parquet, overwriting any earlier export at this path.
(
    buckets.write
    .mode("overwrite")
    .parquet("data/output/buckets1.parquet")
)