In [1]:
from pyspark.sql.functions import lit
from pyspark.sql import SparkSession, functions as F
import pandas as pd
from packages.utils.transformations import MyTransformation as mtf
from packages.utils.spark_udfs import soundex_udf
from pyspark.sql.functions import col, array_contains


# Local Spark session shared by every cell below; "local[*]" runs on all available cores.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("example2")
    .getOrCreate()
)

# Read and rename columns as before
def read_with_origin(path, origin):
    """Read the first six columns of a header-less CSV and tag every row with `origin`.

    Args:
        path: Location of the CSV file, passed straight to `spark.read.csv`.
        origin: Literal value stored in the added "origin" column.

    Returns:
        A Spark DataFrame with columns "0".."5" (renamed from Spark's default
        "_c0".."_c5") followed by "origin".
    """
    raw = spark.read.csv(path, header=None, inferSchema=True)
    # Project and rename in one pass: "_c<i>" becomes "<i>".
    first_six = raw.select([F.col(f"_c{idx}").alias(str(idx)) for idx in range(6)])
    return first_six.withColumn("origin", lit(origin))

In [2]:
# Configuration for the sampled-CSV run.
transformer = mtf()
match_column = 0
cols = ['1', '2', '3', '4', '5']

# Load the first six columns of each header-less CSV and keep a reproducible 10% sample.
df1 = (
    pd.read_csv("data/df1.csv", usecols=[0, 1, 2, 3, 4, 5], header=None)
    .loc[:, [0, 1, 2, 3, 4, 5]]
    .sample(frac=0.1, random_state=42)
)
df2 = (
    pd.read_csv("data/df2.csv", usecols=[0, 1, 2, 3, 4, 5], header=None)
    .loc[:, [0, 1, 2, 3, 4, 5]]
    .sample(frac=0.1, random_state=42)
)

# Reference counts recorded for this sample.
_expected = {"gt": 2531, "tp":1775, "fp":92, "fn":756}

# The return value is discarded, so this presumably rewrites the frames in place — TODO confirm.
transformer.apply_soundex(df1, match_column=match_column)
transformer.apply_soundex(df2, match_column=match_column)

# Tag each row with the frame it came from.
df1['origin'], df2['origin'] = 1, 2

# Mirror both samples into Spark and stack them into one frame.
spark_df1, spark_df2 = spark.createDataFrame(df1), spark.createDataFrame(df2)
concatData = spark_df1.union(spark_df2)

  for column, series in pdf.iteritems():
  for column, series in pdf.iteritems():


In [41]:
# Data for df1
# Hand-built fixture: each row is [record id, code 1, ..., code 5]; the frames built
# from it further down use integer column labels 0-5.
# Trailing labels (TP1-TP4, FN1, FP1, FP2) pair a row here with the same-labelled row
# in data2; they appear to correspond to the counts in `_expected` below.
data1 = [
    ["ID00005", "N039", "E298", "Q412", "V409", "R232"],  # TP1
    ["ID00009", "R822", "W179", "H017", "P323", "F298"],  # TP2
    ["ID00007", "R449", "X716", "M948", "G667", "S702"],  # TP3
    ["ID00004", "N002", "E396", "N843", "I458", "S719"],  # TP4
    ["ID10004", "N002", "E396", "N853", "I623", "S569"],  # FN1
    ["ID50004", "J547", "B222", "G492", "R551", "S490"],  # FP1
    ["IDTIE00", "N322", "K685", "T443", "C225", "W947"],  # FP-tie: this should be skipped
    ["ID50008", "N322", "K685", "T443", "C225", "W967"],  # FP2
    # Unlabelled rows: none of these ids occurs in data2.
    ["ID00000", "W815", "L281", "R155", "F768", "B914"],
    ["ID00001", "C172", "B326", "X400", "M508", "O776"],
    ["ID00002", "V683", "C265", "J127", "D589", "F482"],
    ["ID00003", "E851", "P721", "F745", "D863", "K229"],
    ["ID00016", "T873", "D670", "U046", "Z181", "X621"],
    ["ID00017", "F327", "G856", "E567", "O929", "Q721"],
    ["ID00010", "O283", "T723", "Z034", "V319", "X338"],
]

# Data for df2
# Same row layout as data1. The first five rows reuse data1 ids; the rest use NEW* ids.
data2 = [
    ["ID00005", "R746", "E298", "Q412", "L291", "R232"],  # TP1
    ["ID00009", "R822", "W179", "H017", "P323", "F298"],  # TP2
    ["ID00007", "Z011", "X716", "M948", "W967", "S702"],  # TP3
    ["ID00004", "N002", "E396", "N843", "V935", "S719"],  # TP4
    ["ID10004", "N002", "E396", "N553", "I453", "S459"],  # FN1
    ["NEW80187", "J547", "B222", "G492", "W673", "S490"],  # FP1
    ["NEW30110", "N322", "K685", "T432", "C225", "W967"],  # FP2
    ["NEW72832", "F875", "Q768", "H822", "Z154", "X678"],
    ["NEW30110", "R560", "C434", "M687", "Q689", "Q863"],  # NOTE(review): id NEW30110 repeats the FP2 row two lines up — confirm the duplicate is intentional
    ["NEW81243", "R762", "N687", "A109", "K476", "R637"],
    ["NEW52689", "A089", "V733", "W158", "A640", "H331"],
    ["NEW67368", "Z079", "J617", "G878", "W111", "Q500"],
    ["NEW72348", "J547", "B222", "G492", "R551", "S490"],
    ["NEW34469", "Y990", "H898", "W673", "L967", "M829"],
    ["NEW34462", "Y990", "H898", "W673", "L967", "M829"],
]

# Expected evaluation counts for this fixture. gt = 5 equals the number of ids shared by
# data1 and data2 (the TP1-TP4 and FN1 rows). tp/fp/fn presumably count the labelled
# rows above — confirm against the scoring code, which is not visible in this section.
_expected = {'gt': 5, 'tp': 4, 'fp': 2, 'fn': 1}

#udfs 
# trim_udf = get_trimmer_udf(trim=0)

# Build the pandas fixtures with integer column labels 0-5, tagging each row with the
# frame it belongs to (1 = df1, 2 = df2).
columns = list(range(6))
df1 = pd.DataFrame(data1, columns=columns).assign(origin=1)
df2 = pd.DataFrame(data2, columns=columns).assign(origin=2)

# Mirror both fixtures into Spark.
spark_df1, spark_df2 = (spark.createDataFrame(fixture) for fixture in (df1, df2))

  for column, series in pdf.iteritems():
  for column, series in pdf.iteritems():


In [50]:
from pyspark.sql.functions import first, collect_set, struct

# Bucket assignment: pair every df1 record with the df2 record(s) that share the most
# code columns with it, then group those assignments per df2 record.
cols = ['1', '2', '3', '4', '5']

# Minimum number of equal code columns for a (df1, df2) pair to count as a candidate.
MIN_MATCH_COUNT = 3

# Ground truth: ids present in both frames (column "0" holds the record id).
ground_truth_ids = (
    spark_df1.join(spark_df2, on=["0"], how="inner")
    .select(F.col("0").alias("df1_id"))
    .distinct()
)
ground_truth_ids.cache()
gt = ground_truth_ids.count()

# 1. Create entity keys for df1 (the five code columns concatenated without a separator)
df1_entities = spark_df1.withColumn(
    "entity_key", F.concat_ws("", *[F.col(c) for c in cols])
).select(F.col("0").alias("df1_id"), "entity_key", *cols)

# 2. Create entity keys for the bucket (use spark_df2 or concatData)
bucket_entities = spark_df2.withColumn(
    "entity_key", F.concat_ws("", *[F.col(c) for c in cols])
).select(
    F.col("0").alias("bucket_id"),
    "entity_key",
    *cols
)

# 3. Cross join and count matches
joined = df1_entities.alias("a").crossJoin(bucket_entities.alias("b"))
# when/otherwise maps a NULL comparison to 0. A plain `(a == b).cast("int")` would be
# NULL when either side is NULL, nulling the whole sum and silently dropping the pair
# in step 4 even if every other column matched.
match_flags = [
    F.when(F.col(f"a.{c}") == F.col(f"b.{c}"), 1).otherwise(0)
    for c in cols
]
# Python's built-in sum folds the flags with `+`; starting from lit(0) keeps the result
# a Column even when `cols` is empty.
joined = joined.withColumn("match_count", sum(match_flags, F.lit(0)))

# 4. Filter pairs with at least MIN_MATCH_COUNT matching columns
filtered = joined.filter(F.col("match_count") >= MIN_MATCH_COUNT)

# 5. For each df1_id, find the max match_count
#    (the key is renamed so it cannot clash with a.df1_id in the join below)
max_match = (
    filtered.groupBy("a.df1_id")
    .agg(F.max("match_count").alias("max_match_count"))
    .withColumnRenamed("df1_id", "max_df1_id")
)

# 6. Join back to get all (df1_id, bucket_id) pairs with the max match_count.
#    A df1_id that ties across several buckets is kept for every one of them.
best_matches = filtered.join(
    max_match,
    (filtered["a.df1_id"] == max_match["max_df1_id"]) & (filtered["match_count"] == max_match["max_match_count"])
).select(
    F.col("a.df1_id").alias("df1_id"),
    F.col("b.bucket_id").alias("bucket_id"),
    F.col("b.entity_key").alias("entity_key"),
    F.col("match_count"),
    # Use SAME column names for both structs so they can be unioned into one array
    F.struct(*[F.col(f"a.{c}").alias(c) for c in cols]).alias("df1_columns"),
    F.struct(*[F.col(f"b.{c}").alias(c) for c in cols]).alias("df2_columns")
)

# 7. Group by bucket_id and entity_key, collecting all column variations
buckets_with_columns = best_matches.groupBy("bucket_id", "entity_key").agg(
    # collect_list order depends on how rows arrive after the shuffle; sort it so the
    # displayed result is reproducible between runs.
    F.sort_array(F.collect_list("df1_id")).alias("assigned_df1_ids"),
    F.collect_set("df1_columns").alias("df1_column_values"),
    F.collect_set("df2_columns").alias("df2_column_values")
).withColumn(
    "all_column_values",
    F.array_union(F.col("df1_column_values"), F.col("df2_column_values"))
).select(
    "entity_key",
    "bucket_id",
    "assigned_df1_ids",
    "all_column_values"
)

buckets_with_columns.cache()
buckets_with_columns.show(truncate=False)

+--------------------+---------+------------------+------------------------------------------------------------------------------------------------+
|entity_key          |bucket_id|assigned_df1_ids  |all_column_values                                                                               |
+--------------------+---------+------------------+------------------------------------------------------------------------------------------------+
|R822W179H017P323F298|ID00009  |[ID00009]         |[{R822, W179, H017, P323, F298}]                                                                |
|R746E298Q412L291R232|ID00005  |[ID00005]         |[{N039, E298, Q412, V409, R232}, {R746, E298, Q412, L291, R232}]                                |
|N002E396N843V935S719|ID00004  |[ID00004]         |[{N002, E396, N843, I458, S719}, {N002, E396, N843, V935, S719}]                                |
|Z011X716M948W967S702|ID00007  |[ID00007]         |[{R449, X716, M948, G667, S702}, {Z011, X716, M948, W96

In [48]:
# NOTE(review): `buckets` is not defined in any visible cell (this cell ran as In [48],
# before the In [50] cell above), so a fresh Restart & Run All would raise NameError.
# It is rebuilt here from `buckets_with_columns`; the three selected columns are the
# ones shown in the recorded output. Confirm this matches the original definition.
buckets = buckets_with_columns.select("bucket_id", "entity_key", "assigned_df1_ids")
buckets.show()

+---------+--------------------+------------------+
|bucket_id|          entity_key|  assigned_df1_ids|
+---------+--------------------+------------------+
|  ID00009|R822W179H017P323F298|         [ID00009]|
|  ID00005|R746E298Q412L291R232|         [ID00005]|
|  ID00004|N002E396N843V935S719|         [ID00004]|
|  ID00007|Z011X716M948W967S702|         [ID00007]|
| NEW30110|N322K685T432C225W967|[ID50008, IDTIE00]|
| NEW72348|J547B222G492R551S490|         [ID50004]|
+---------+--------------------+------------------+

