diff --git a/splink/analyse_blocking_rule.py b/splink/analyse_blocking_rule.py
index 203e8049c3..ea0f1b2f5e 100644
--- a/splink/analyse_blocking_rule.py
+++ b/splink/analyse_blocking_rule.py
@@ -216,7 +216,7 @@ def _get_largest_group(df, join_cols):
     largest_group_concat = collected[0]["concat_ws_expr"]
     return {
         "largest_group_expr": largest_group_concat,
-        "num_comparisons_generated_in_largest_group": largest_group_comparisons,
+        "comparisons_generated_in_largest_group_before_filter_applied": largest_group_comparisons,
     }
 
 
@@ -261,7 +261,9 @@ def analyse_blocking_rule(
         total_comparisons_generated = 1
         if compute_largest_group:
             results_dict["largest_group_expr"] = None
-            results_dict["num_comparisons_generated_in_largest_group"] = None
+            results_dict[
+                "comparisons_generated_in_largest_group_before_filter_applied"
+            ] = None
 
     if results_dict["join_strategy"] == "Cartesian":
         raw_count = df.count()
@@ -271,7 +273,9 @@
         ] = total_comparisons_generated
         if compute_largest_group:
             results_dict["largest_group_expr"] = None
-            results_dict["num_comparisons_generated_in_largest_group"] = None
+            results_dict[
+                "comparisons_generated_in_largest_group_before_filter_applied"
+            ] = None
             metric_message += "Blocking rule results in Cartesian join so largest groups not applicable.\n"
 
     if compute_exact_comparisons and total_comparisons_generated < compute_exact_limit:
@@ -303,7 +307,7 @@
         "total_comparisons_generated_before_filter_applied",
         "total_comparisons_generated_after_filters_applied",
         "largest_group_expr",
-        "num_comparisons_generated_in_largest_group",
+        "comparisons_generated_in_largest_group_before_filter_applied",
         "message",
         "join_strategy",
         "join_type",
diff --git a/splink/default_settings.py b/splink/default_settings.py
index 6f49527f7a..74e3fff389 100644
--- a/splink/default_settings.py
+++ b/splink/default_settings.py
@@ -36,7 +36,10 @@ def _get_default_case_statements_functions(spark):
     default_case_stmts["numeric"][3] = sql_gen_case_stmt_numeric_perc_3
     default_case_stmts["numeric"][4] = sql_gen_case_stmt_numeric_perc_4
 
-    jaro_exists = _check_jaro_registered(spark)
+    if spark:
+        jaro_exists = _check_jaro_registered(spark)
+    else:
+        jaro_exists = False
 
     if jaro_exists:
         default_case_stmts["string"][2] = sql_gen_case_smnt_strict_equality_2
@@ -147,7 +150,7 @@ def _complete_tf_adjustment_weights(col_settings: dict):
             )
         else:
             weights = [0.0] * col_settings["num_levels"]
-            weights[-1] = 1.0
+            weights[-1] = 1.0
 
         col_settings["tf_adjustment_weights"] = weights
 
diff --git a/tests/test_analyse_blocking_rules.py b/tests/test_analyse_blocking_rules.py
new file mode 100644
index 0000000000..64f3b916a9
--- /dev/null
+++ b/tests/test_analyse_blocking_rules.py
@@ -0,0 +1,134 @@
+from pyspark.sql import Row
+from splink.analyse_blocking_rule import analyse_blocking_rule
+from splink.default_settings import complete_settings_dict
+
+
+def test_analyse_blocking_rules(spark):
+
+    # fmt: off
+    rows = [
+        {"unique_id": 1, "mob": 10, "surname": "Linacre", "forename": "Robin", "source_dataset": "df1"},
+        {"unique_id": 2, "mob": 10, "surname": "Linacre", "forename": "Robin", "source_dataset": "df2"},
+        {"unique_id": 3, "mob": 10, "surname": "Linacer", "forename": "Robin", "source_dataset": "df1"},
+        {"unique_id": 4, "mob": 7, "surname": "Smith", "forename": "John", "source_dataset": "df1"},
+        {"unique_id": 5, "mob": 8, "surname": "Smith", "forename": "John", "source_dataset": "df2"},
+        {"unique_id": 6, "mob": 8, "surname": "Smith", "forename": "Jon", "source_dataset": "df1"},
+        {"unique_id": 7, "mob": 8, "surname": "Jones", "forename": "Robin", "source_dataset": "df2"},
+    ]
+    # fmt: on
+
+    df = spark.createDataFrame(Row(**x) for x in rows)
+
+    splink_settings = {
+        "link_type": "dedupe_only",
+        "comparison_columns": [
+            {"col_name": "surname"},
+            {"col_name": "mob"},
+            {"col_name": "forename"},
+        ],
+        "blocking_rules": [],
+    }
+
+    splink_settings = complete_settings_dict(splink_settings, None)
+
+    ############
+    # Test 1
+    ############
+    blocking_rule = "l.surname = r.surname"
+
+    results = analyse_blocking_rule(
+        df, blocking_rule, splink_settings, compute_exact_comparisons=True
+    )
+
+    df.createOrReplaceTempView("df")
+    sql = f"select count(*) from df as l inner join df as r on {blocking_rule}"
+    expected_count = spark.sql(sql).collect()[0][0]
+
+    expected = {
+        "total_comparisons_generated_before_filter_applied": expected_count,
+        "total_comparisons_generated_after_filters_applied": 4,
+        "largest_group_expr": "Smith",
+        "comparisons_generated_in_largest_group_before_filter_applied": 9,
+        "join_strategy": "SortMergeJoin",
+        "join_type": "Inner",
+        "join_hashpartition_columns_left": ["surname"],
+        "join_hashpartition_columns_right": ["surname"],
+    }
+    for key in expected.keys():
+        assert results[key] == expected[key]
+
+    ############
+    # Test 2 - Cartesian
+    ############
+    blocking_rule = "l.surname = r.surname or l.forename = r.forename"
+
+    results = analyse_blocking_rule(
+        df, blocking_rule, splink_settings, compute_exact_comparisons=True
+    )
+
+    df.createOrReplaceTempView("df")
+    sql = "select count(*) from df as l cross join df as r"
+    expected_count = spark.sql(sql).collect()[0][0]
+
+    expected = {
+        "total_comparisons_generated_before_filter_applied": expected_count,
+        "total_comparisons_generated_after_filters_applied": 9,
+        "largest_group_expr": None,
+        "comparisons_generated_in_largest_group_before_filter_applied": None,
+        "join_strategy": "Cartesian",
+        "join_type": "Cartesian",
+        "join_hashpartition_columns_left": [],
+        "join_hashpartition_columns_right": [],
+    }
+    for key in expected.keys():
+        assert results[key] == expected[key]
+
+    ############
+    # Test 3 - link only
+    ############
+    blocking_rule = "l.surname = r.surname and l.forename = r.forename"
+    splink_settings = {
+        "link_type": "link_only",
+        "comparison_columns": [
+            {"col_name": "surname"},
+            {"col_name": "mob"},
+            {"col_name": "forename"},
+        ],
+        "blocking_rules": [],
+    }
+
+    # fmt: off
+    rows = [
+        {"unique_id": 1, "mob": 10, "surname": "Linacre", "forename": "Robin", "source_dataset": "df1"},
+        {"unique_id": 2, "mob": 10, "surname": "Linacre", "forename": "Robin", "source_dataset": "df2"},
+        {"unique_id": 3, "mob": 10, "surname": "Linacre", "forename": "Robin", "source_dataset": "df1"},
+        {"unique_id": 4, "mob": 10, "surname": "Linacre", "forename": "Robin", "source_dataset": "df2"},
+        {"unique_id": 5, "mob": 10, "surname": "Smith", "forename": "John", "source_dataset": "df2"},
+
+    ]
+    # fmt: on
+
+    df = spark.createDataFrame(Row(**x) for x in rows)
+
+    splink_settings = complete_settings_dict(splink_settings, None)
+
+    results = analyse_blocking_rule(
+        df, blocking_rule, splink_settings, compute_exact_comparisons=True
+    )
+
+    df.createOrReplaceTempView("df")
+    sql = f"select count(*) from df as l inner join df as r on {blocking_rule}"
+    expected_count = spark.sql(sql).collect()[0][0]
+
+    expected = {
+        "total_comparisons_generated_before_filter_applied": expected_count,
+        "total_comparisons_generated_after_filters_applied": 4,
+        "largest_group_expr": "Linacre|Robin",
+        "comparisons_generated_in_largest_group_before_filter_applied": 16,
+        "join_strategy": "SortMergeJoin",
+        "join_type": "Inner",
+        "join_hashpartition_columns_left": ["surname", "forename"],
+        "join_hashpartition_columns_right": ["surname", "forename"],
+    }
+    for key in expected.keys():
+        assert results[key] == expected[key]
\ No newline at end of file