From 84de3fb8acc9c777cc194fb849b5b126229da4d5 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 3 Mar 2020 17:36:38 +0000
Subject: [PATCH] add to tests including cartesian option and no blocking rules

---
 splink/blocking.py  |  4 +--
 tests/test_spark.py | 70 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 71 insertions(+), 3 deletions(-)

diff --git a/splink/blocking.py b/splink/blocking.py
index 1263202101..e37aee3b5e 100644
--- a/splink/blocking.py
+++ b/splink/blocking.py
@@ -117,7 +117,6 @@ def sql_gen_block_using_rules(
         # Where a record from left and right are being compared, you want the left record to end up in the _l fields, and the right record to end up in _r fields.
         where_condition = f"where (l._source_table < r._source_table) or (l.{unique_id_col} < r.{unique_id_col} and l._source_table = r._source_table)"
-
     sqls = []
     previous_rules =[]
     for rule in blocking_rules:
@@ -268,10 +267,8 @@ def cartesian_block(
         pyspark.sql.dataframe.DataFrame: A dataframe of each record comparison
     """
-
     link_type = settings["link_type"]
-
     columns_to_retain = _get_columns_to_retain_blocking(settings)

     unique_id_col = settings["unique_id_column_name"]

@@ -283,6 +280,7 @@ def cartesian_block(
     df_r.createOrReplaceTempView("df_r")

     if link_type == "link_and_dedupe":
+        columns_to_retain.append("_source_table")
         df_concat = vertically_concatenate_datasets(df_l, df_r, settings, spark=spark)
         df_concat.createOrReplaceTempView("df")
         df_concat.persist()
diff --git a/tests/test_spark.py b/tests/test_spark.py
index 815a182703..7072c73f36 100644
--- a/tests/test_spark.py
+++ b/tests/test_spark.py
@@ -454,6 +454,53 @@ def test_link_option_link_dedupe(spark, link_dedupe_data_repeat_ids):
     assert list(df["u_l"]) == ['2l', '1l', '1l', '2l', '2l', '3l', '3l', '1r', '2r']
     assert list(df["u_r"]) == ['3l', '1r', '3r', '2r', '3r', '2r', '3r', '3r', '3r']

+    # Same for no blocking rules = cartesian product
+
+    settings = {
+        "link_type": "link_and_dedupe",
+        "comparison_columns": [{"col_name": "first_name"},
+                               {"col_name": "surname"}],
+        "blocking_rules": [
+        ]
+    }
+    settings = complete_settings_dict(settings, spark=None)
+    dfpd_l = pd.read_sql("select * from df_l", link_dedupe_data_repeat_ids)
+    df_l = spark.createDataFrame(dfpd_l)
+    dfpd_r = pd.read_sql("select * from df_r", link_dedupe_data_repeat_ids)
+    df_r = spark.createDataFrame(dfpd_r)
+    df = block_using_rules(settings, spark, df_l=df_l, df_r=df_r)
+    df = df.toPandas()
+
+    df["u_l"] = df["unique_id_l"].astype(str) + df["_source_table_l"].str.slice(0,1)
+    df["u_r"] = df["unique_id_r"].astype(str) + df["_source_table_r"].str.slice(0,1)
+    df = df.sort_values(["_source_table_l", "unique_id_l","_source_table_r", "unique_id_r"])
+
+    assert list(df["u_l"]) == ['1l', '1l', '1l', '1l', '1l', '2l', '2l', '2l', '2l', '3l', '3l', '3l', '1r', '1r', '2r']
+    assert list(df["u_r"]) == ['2l', '3l', '1r', '2r', '3r', '3l', '1r', '2r', '3r', '1r', '2r', '3r', '2r', '3r', '3r']
+
+
+
+    # Same for cartesian product
+
+    settings = {
+        "link_type": "link_and_dedupe",
+        "comparison_columns": [{"col_name": "first_name"},
+                               {"col_name": "surname"}],
+        "cartesian_product": True
+    }
+    settings = complete_settings_dict(settings, spark=None)
+    dfpd_l = pd.read_sql("select * from df_l", link_dedupe_data_repeat_ids)
+    df_l = spark.createDataFrame(dfpd_l)
+    dfpd_r = pd.read_sql("select * from df_r", link_dedupe_data_repeat_ids)
+    df_r = spark.createDataFrame(dfpd_r)
+    df = block_using_rules(settings, spark, df_l=df_l, df_r=df_r)
+    df = df.toPandas()
+    df["u_l"] = df["unique_id_l"].astype(str) + df["_source_table_l"].str.slice(0,1)
+    df["u_r"] = df["unique_id_r"].astype(str) + df["_source_table_r"].str.slice(0,1)
+    df = df.sort_values(["_source_table_l", "unique_id_l","_source_table_r", "unique_id_r"])
+
+    assert list(df["u_l"]) == ['1l', '1l', '1l', '1l', '1l', '2l', '2l', '2l', '2l', '3l', '3l', '3l', '1r', '1r', '2r']
+    assert list(df["u_r"]) == ['2l', '3l', '1r', '2r', '3r', '3l', '1r', '2r', '3r', '1r', '2r', '3r', '2r', '3r', '3r']

 def test_link_option_link(spark, link_dedupe_data_repeat_ids):
     settings = {
@@ -478,6 +525,29 @@ def test_link_option_link(spark, link_dedupe_data_repeat_ids):
     assert list(df["unique_id_l"]) == [1, 1, 2, 2, 3, 3]
     assert list(df["unique_id_r"]) == [1, 3, 2, 3, 2, 3]

+    # Test cartesian version
+
+    settings = {
+        "link_type": "link_only",
+        "comparison_columns": [{"col_name": "first_name"},
+                               {"col_name": "surname"}],
+        "blocking_rules": [
+
+        ]
+    }
+    settings = complete_settings_dict(settings, spark=None)
+    dfpd_l = pd.read_sql("select * from df_l", link_dedupe_data_repeat_ids)
+    df_l = spark.createDataFrame(dfpd_l)
+    dfpd_r = pd.read_sql("select * from df_r", link_dedupe_data_repeat_ids)
+    df_r = spark.createDataFrame(dfpd_r)
+    df = block_using_rules(settings, spark, df_l=df_l, df_r=df_r)
+    df = df.toPandas()
+
+    df = df.sort_values(["unique_id_l", "unique_id_r"])
+
+    assert list(df["unique_id_l"]) == [1, 1, 1, 2, 2, 2, 3, 3, 3]
+    assert list(df["unique_id_r"]) == [1, 2, 3, 1, 2, 3, 1, 2, 3]
+

 def test_link_option_dedupe_only(spark, link_dedupe_data_repeat_ids):