test analyse blocking rules
RobinL committed Dec 23, 2021
1 parent e8eb566 commit 64ad9dc
Showing 3 changed files with 147 additions and 6 deletions.
12 changes: 8 additions & 4 deletions splink/analyse_blocking_rule.py
@@ -216,7 +216,7 @@ def _get_largest_group(df, join_cols):
largest_group_concat = collected[0]["concat_ws_expr"]
return {
"largest_group_expr": largest_group_concat,
"num_comparisons_generated_in_largest_group": largest_group_comparisons,
"comparisons_generated_in_largest_group_before_filter_applied": largest_group_comparisons,
}


@@ -261,7 +261,9 @@ def analyse_blocking_rule(
total_comparisons_generated = 1
if compute_largest_group:
results_dict["largest_group_expr"] = None
results_dict["num_comparisons_generated_in_largest_group"] = None
results_dict[
"comparisons_generated_in_largest_group_before_filter_applied"
] = None

if results_dict["join_strategy"] == "Cartesian":
raw_count = df.count()
@@ -271,7 +273,9 @@
] = total_comparisons_generated
if compute_largest_group:
results_dict["largest_group_expr"] = None
results_dict["num_comparisons_generated_in_largest_group"] = None
results_dict[
"comparisons_generated_in_largest_group_before_filter_applied"
] = None
metric_message += "Blocking rule results in Cartesian join so largest groups not applicable.\n"

if compute_exact_comparisons and total_comparisons_generated < compute_exact_limit:
@@ -303,7 +307,7 @@
"total_comparisons_generated_before_filter_applied",
"total_comparisons_generated_after_filters_applied",
"largest_group_expr",
"num_comparisons_generated_in_largest_group",
"comparisons_generated_in_largest_group_before_filter_applied",
"message",
"join_strategy",
"join_type",
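The rename above surfaces directly in the dictionary returned to callers. A minimal sketch of the new key in use (assuming a Spark DataFrame df and a completed splink_settings dict, as in the new test file below); the old num_comparisons_generated_in_largest_group key is replaced throughout:

from splink.analyse_blocking_rule import analyse_blocking_rule

results = analyse_blocking_rule(
    df, "l.surname = r.surname", splink_settings, compute_exact_comparisons=True
)
# Renamed key: number of pairs generated by the single largest blocking group,
# counted before any link-type filters are applied
largest = results["comparisons_generated_in_largest_group_before_filter_applied"]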
7 changes: 5 additions & 2 deletions splink/default_settings.py
@@ -36,7 +36,10 @@ def _get_default_case_statements_functions(spark):
default_case_stmts["numeric"][3] = sql_gen_case_stmt_numeric_perc_3
default_case_stmts["numeric"][4] = sql_gen_case_stmt_numeric_perc_4

jaro_exists = _check_jaro_registered(spark)
if spark:
jaro_exists = _check_jaro_registered(spark)
else:
jaro_exists = False

if jaro_exists:
default_case_stmts["string"][2] = sql_gen_case_smnt_strict_equality_2
@@ -147,7 +150,7 @@ def _complete_tf_adjustment_weights(col_settings: dict):
)
else:
weights = [0.0] * col_settings["num_levels"]
weights[-1] = 1.0
weights[-1] = 1.0
col_settings["tf_adjustment_weights"] = weights


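The guard on spark above means a settings dictionary can now be completed without an active SparkSession; a brief sketch of that behaviour, mirroring how the new test calls it with None so the jaro-winkler UDF check is skipped and jaro_exists falls back to False:

from splink.default_settings import complete_settings_dict

settings = {
    "link_type": "dedupe_only",
    "comparison_columns": [{"col_name": "surname"}, {"col_name": "mob"}],
    "blocking_rules": [],
}

# With spark passed as None, _check_jaro_registered is never called
completed = complete_settings_dict(settings, None)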
134 changes: 134 additions & 0 deletions tests/test_analyse_blocking_rules.py
@@ -0,0 +1,134 @@
from pyspark.sql import Row
from splink.analyse_blocking_rule import analyse_blocking_rule
from splink.default_settings import complete_settings_dict


def test_analyse_blocking_rules(spark):

# fmt: off
rows = [
{"unique_id": 1, "mob": 10, "surname": "Linacre", "forename": "Robin", "source_dataset": "df1"},
{"unique_id": 2, "mob": 10, "surname": "Linacre", "forename": "Robin", "source_dataset": "df2"},
{"unique_id": 3, "mob": 10, "surname": "Linacer", "forename": "Robin", "source_dataset": "df1"},
{"unique_id": 4, "mob": 7, "surname": "Smith", "forename": "John", "source_dataset": "df1"},
{"unique_id": 5, "mob": 8, "surname": "Smith", "forename": "John", "source_dataset": "df2"},
{"unique_id": 6, "mob": 8, "surname": "Smith", "forename": "Jon", "source_dataset": "df1"},
{"unique_id": 7, "mob": 8, "surname": "Jones", "forename": "Robin", "source_dataset": "df2"},
]
# fmt: on

df = spark.createDataFrame(Row(**x) for x in rows)

splink_settings = {
"link_type": "dedupe_only",
"comparison_columns": [
{"col_name": "surname"},
{"col_name": "mob"},
{"col_name": "forename"},
],
"blocking_rules": [],
}

splink_settings = complete_settings_dict(splink_settings, None)

############
# Test 1
############
blocking_rule = "l.surname = r.surname"

results = analyse_blocking_rule(
df, blocking_rule, splink_settings, compute_exact_comparisons=True
)

df.createOrReplaceTempView("df")
sql = f"select count(*) from df as l inner join df as r on {blocking_rule}"
expected_count = spark.sql(sql).collect()[0][0]

expected = {
"total_comparisons_generated_before_filter_applied": expected_count,
"total_comparisons_generated_after_filters_applied": 4,
"largest_group_expr": "Smith",
"comparisons_generated_in_largest_group_before_filter_applied": 9,
"join_strategy": "SortMergeJoin",
"join_type": "Inner",
"join_hashpartition_columns_left": ["surname"],
"join_hashpartition_columns_right": ["surname"],
}
for key in expected.keys():
assert results[key] == expected[key]

############
# Test 2 - Cartesian
############
blocking_rule = "l.surname = r.surname or l.forename = r.forename"

results = analyse_blocking_rule(
df, blocking_rule, splink_settings, compute_exact_comparisons=True
)

df.createOrReplaceTempView("df")
sql = f"select count(*) from df as l cross join df as r"
expected_count = spark.sql(sql).collect()[0][0]

expected = {
"total_comparisons_generated_before_filter_applied": expected_count,
"total_comparisons_generated_after_filters_applied": 9,
"largest_group_expr": None,
"comparisons_generated_in_largest_group_before_filter_applied": None,
"join_strategy": "Cartesian",
"join_type": "Cartesian",
"join_hashpartition_columns_left": [],
"join_hashpartition_columns_right": [],
}
for key in expected.keys():
assert results[key] == expected[key]

############
# Test 3 - link only
############
blocking_rule = "l.surname = r.surname and l.forename = r.forename"
splink_settings = {
"link_type": "link_only",
"comparison_columns": [
{"col_name": "surname"},
{"col_name": "mob"},
{"col_name": "forename"},
],
"blocking_rules": [],
}

# fmt: off
rows = [
{"unique_id": 1, "mob": 10, "surname": "Linacre", "forename": "Robin", "source_dataset": "df1"},
{"unique_id": 2, "mob": 10, "surname": "Linacre", "forename": "Robin", "source_dataset": "df2"},
{"unique_id": 3, "mob": 10, "surname": "Linacre", "forename": "Robin", "source_dataset": "df1"},
{"unique_id": 4, "mob": 10, "surname": "Linacre", "forename": "Robin", "source_dataset": "df2"},
{"unique_id": 5, "mob": 10, "surname": "Smith", "forename": "John", "source_dataset": "df2"},

]
# fmt: on

df = spark.createDataFrame(Row(**x) for x in rows)

splink_settings = complete_settings_dict(splink_settings, None)

results = analyse_blocking_rule(
df, blocking_rule, splink_settings, compute_exact_comparisons=True
)

df.createOrReplaceTempView("df")
sql = f"select count(*) from df as l inner join df as r on {blocking_rule}"
expected_count = spark.sql(sql).collect()[0][0]

expected = {
"total_comparisons_generated_before_filter_applied": expected_count,
"total_comparisons_generated_after_filters_applied": 4,
"largest_group_expr": "Linacre|Robin",
"comparisons_generated_in_largest_group_before_filter_applied": 16,
"join_strategy": "SortMergeJoin",
"join_type": "Inner",
"join_hashpartition_columns_left": ["surname", "forename"],
"join_hashpartition_columns_right": ["surname", "forename"],
}
for key in expected.keys():
assert results[key] == expected[key]
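For reference, the expected counts follow directly from the fixture data (assuming the after-filter figures reflect the usual l.unique_id < r.unique_id dedupe condition and a cross-dataset restriction for link_only, which are not shown in this diff). In Test 1 the largest group is the three "Smith" rows, giving 3 × 3 = 9 comparisons before filtering, and 1 + 3 = 4 comparisons overall after it (one "Linacre" pair plus three "Smith" pairs). In Test 3 the largest group is the four "Linacre"/"Robin" rows, giving 4 × 4 = 16 comparisons before filtering and 2 × 2 = 4 cross-dataset pairs afterwards.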
