-
Notifications
You must be signed in to change notification settings - Fork 127
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
319 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,319 @@ | ||
from copy import deepcopy | ||
import re | ||
|
||
from splink.blocking import _sql_gen_where_condition, block_using_rules | ||
|
||
|
||
def _get_queryplan_text(df, blocking_rule, splink_settings):
    """Return Spark's physical-plan tree string for the self-join implied
    by a blocking rule.

    Args:
        df (DataFrame): Input dataframe (registered as temp view "df")
        blocking_rule (str): A splink blocking rule e.g.
            "l.first_name = r.first_name"
        splink_settings (dict): Splink settings dictionary; "link_type",
            "unique_id_column_name" and (for non-dedupe jobs)
            "source_dataset_column_name" are read.

    Returns:
        str: The physical plan tree string, whose first line identifies the
            join strategy (e.g. SortMergeJoin or Cartesian).
    """
    spark = df.sql_ctx.sparkSession

    # Temporarily disable broadcast joins so the optimiser produces a
    # SortMergeJoin (or Cartesian) plan, which the parsers below understand.
    current_broadcast_setting = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

    try:
        link_type = splink_settings["link_type"]
        if link_type == "dedupe_only":
            source_dataset_col = None
        else:
            source_dataset_col = splink_settings["source_dataset_column_name"]

        unique_id_col = splink_settings["unique_id_column_name"]

        join_filter = _sql_gen_where_condition(
            link_type, source_dataset_col, unique_id_col
        )

        df.createOrReplaceTempView("df")
        sql = f"""
        select * from
        df as l
        inner join df as r
        on {blocking_rule} {join_filter}
        """

        planned_df = spark.sql(sql)
        qe = planned_df._jdf.queryExecution()
        treestring = qe.sparkPlan().treeString()
    finally:
        # Always restore the user's broadcast threshold, even if plan
        # generation fails (previously a failure left it disabled).
        spark.conf.set("spark.sql.autoBroadcastJoinThreshold", current_broadcast_setting)
    return treestring
|
||
|
||
def _get_join_line(queryplan_text): | ||
lines = queryplan_text.splitlines() | ||
return lines[0] | ||
|
||
|
||
def _parse_join_line_sortmergejoin(join_line):
    """Parse a 'SortMergeJoin ...' plan line into a dict of join metadata."""
    segments = _split_by_commas_ignoring_within_brackets(join_line)
    hash_columns = _get_hash_columns(join_line)

    # A fourth top-level segment, when present, holds the post-join filter.
    if len(segments) == 4:
        post_join_filters = _remove_col_ids(segments[3])
    else:
        post_join_filters = None

    return {
        "join_strategy": "SortMergeJoin",
        "join_type": "Inner",
        "join_hashpartition_columns_left": hash_columns["left"],
        "join_hashpartition_columns_right": hash_columns["right"],
        "post_join_filters": post_join_filters,
    }
|
||
|
||
def _parse_join_line_cartesian(join_line):
    """Parse a Cartesian-product plan line into a dict of join metadata.

    A Cartesian plan has no hash-partition columns; only the parenthesised
    filter expression is extracted.
    """
    filter_text = _extract_text_from_within_brackets_balanced(join_line, ["(", ")"])
    return {
        "join_strategy": "Cartesian",
        "join_type": "Cartesian",
        "join_hashpartition_columns_left": [],
        "join_hashpartition_columns_right": [],
        "post_join_filters": _remove_col_ids(filter_text),
    }
|
||
|
||
def _parse_join_line(join_line): | ||
if "SortMergeJoin" in join_line: | ||
return _parse_join_line_sortmergejoin(join_line) | ||
if "Cartesian" in join_line: | ||
return _parse_join_line_cartesian(join_line) | ||
|
||
|
||
def _split_by_commas_ignoring_within_brackets(input_str): | ||
counter = 0 | ||
captured_strings = [] | ||
captured_string = "" | ||
for i in input_str: | ||
captured_string += i | ||
|
||
if i in ("(", "[", "{"): | ||
counter += 1 | ||
|
||
if i in (")", "]", "}"): | ||
counter -= 1 | ||
if counter == 0 and i == ",": | ||
captured_string = captured_string[:-1] | ||
captured_strings.append(captured_string) | ||
captured_string = "" | ||
captured_strings.append(captured_string) | ||
captured_strings = [s.strip() for s in captured_strings] | ||
return captured_strings | ||
|
||
|
||
def _extract_text_from_within_brackets_balanced(input_str, bracket_type=["[", "]"]): | ||
bracket_counter = 0 | ||
|
||
start_bracket = bracket_type[0] | ||
end_bracket = bracket_type[1] | ||
|
||
captured_string = "" | ||
|
||
if start_bracket not in input_str: | ||
return None | ||
|
||
for i in input_str: | ||
|
||
if i == start_bracket: | ||
bracket_counter += 1 | ||
|
||
if i == end_bracket: | ||
bracket_counter -= 1 | ||
if bracket_counter == 0: | ||
break | ||
if bracket_counter > 0: | ||
captured_string += i | ||
|
||
return captured_string[1:] | ||
|
||
|
||
def _remove_col_ids(input_str): | ||
return re.sub(r"#\d{1,6}[L]?", "", input_str) | ||
|
||
|
||
def get_total_comparisons_from_join_columns_that_will_be_hash_partitioned(
    df, join_cols: list
):
    """Compute the total number of records that will be generated
    by an inner join on the join_cols.

    For instance, if join_cols = ["first_name", "dmetaphone(surname)"]
    will compute the total number of comparisons generated by the blocking rule:
    ["l.first_name = r.first_name and dmetaphone(l.surname) = dmetaphone(r.surname)"]

    Args:
        df (DataFrame): Input dataframe
        join_cols (list): List of blocking columns e.g. ["first_name", "dmetaphone(surname)"]

    Returns:
        integer: Number of comparisons generated by the blocking rule
            (None when no rows have all join columns non-null)
    """

    sel_expr = ", ".join(join_cols)
    concat_expr = f"concat({sel_expr})"
    spark = df.sql_ctx.sparkSession

    df.createOrReplaceTempView("df")

    # concat(...) is null if any component is null, so rows with a null in
    # any join column are excluded — they cannot match in an equi-join.
    # count(*) * count(*) counts all ordered pairs within each block,
    # mirroring the output size of the inner self-join.
    sql = f"""
    with
    block_groups as (
    SELECT {concat_expr}, {sel_expr},
    count(*) * count(*) as num_comparisons
    FROM df
    where {concat_expr} is not null
    GROUP BY {concat_expr}, {sel_expr}
    )
    select sum(num_comparisons) as total_comparisons
    from block_groups
    """

    return spark.sql(sql).collect()[0]["total_comparisons"]
|
||
|
||
def generate_and_count_num_comparisons_from_blocking_rule(
    df, blocking_rule, splink_settings
):
    """Apply a single blocking rule and return the exact number of
    comparison records it produces.

    The settings are deep-copied so the caller's dictionary is untouched.
    """
    spark = df.sql_ctx.sparkSession
    settings_copy = deepcopy(splink_settings)
    settings_copy["blocking_rules"] = [blocking_rule]

    blocked = block_using_rules(settings_copy, df, spark)

    return blocked.count()
|
||
|
||
def _get_hash_columns(smj_line):
    """Extract the left and right hash-partition column lists from a
    SortMergeJoin plan line.

    The first two top-level comma-separated segments hold the left and right
    join-key lists in square brackets.
    """

    def _columns_from(segment):
        # Pull the bracketed key list, strip Spark column ids, then split
        # into individual column expressions.
        inner = _extract_text_from_within_brackets_balanced(segment)
        inner = _remove_col_ids(inner)
        return _split_by_commas_ignoring_within_brackets(inner)

    top_level = _split_by_commas_ignoring_within_brackets(smj_line)
    return {
        "left": _columns_from(top_level[0]),
        "right": _columns_from(top_level[1]),
    }
|
||
|
||
def _get_largest_group(df, join_cols):
    """Find the single largest block implied by the join columns and the
    number of comparisons it generates on its own.

    Returns a dict with a human-readable '|'-separated key for the largest
    group and its count(*) * count(*) comparison total.
    """
    sel_expr = ", ".join(join_cols)
    concat_expr = f"concat({sel_expr})"
    concat_ws_expr = f"concat_ws('|', {sel_expr})"
    spark = df.sql_ctx.sparkSession

    df.createOrReplaceTempView("df")

    sql = f"""
    SELECT {concat_expr} as concat_expr, {concat_ws_expr} as concat_ws_expr, {sel_expr},
    count(*) * count(*) as num_comparisons
    FROM df
    where {concat_expr} is not null
    GROUP BY {concat_expr}, {sel_expr}
    ORDER BY count(*) desc
    limit 1
    """

    top_row = spark.sql(sql).collect()[0]
    return {
        "largest_group_expr": top_row["concat_ws_expr"],
        "num_comparisons_generated_in_largest_group": top_row["num_comparisons"],
    }
|
||
|
||
def analyse_blocking_rule(
    df,
    blocking_rule,
    splink_settings,
    compute_exact_comparisons=False,
    compute_exact_limit=1e9,
    compute_largest_group=True,
):
    """Analyse how many comparisons a blocking rule will generate.

    Inspects the Spark physical plan for the join implied by the rule.  For
    a SortMergeJoin with matching hash-partition columns on both sides, the
    pre-filter comparison total (and optionally the largest block) is
    computed cheaply with a GROUP BY; for a Cartesian join it is count**2.

    Args:
        df (DataFrame): Input dataframe.
        blocking_rule (str): Blocking rule to analyse.
        splink_settings (dict): Splink settings; deep-copied, not mutated.
        compute_exact_comparisons (bool): If True, also run the blocking to
            count comparisons after filters are applied.
        compute_exact_limit (float): Skip the exact count when the estimated
            total exceeds this limit.
        compute_largest_group (bool): If True, also report the largest block.

    Returns:
        dict: Metrics keyed as in results_order below; metrics that could
            not be computed are None, with an explanation in "message".

    Raises:
        ValueError: If the join strategy in the query plan is not recognised.
    """
    spark = df.sql_ctx.sparkSession
    df.createOrReplaceTempView("df")
    splink_settings = deepcopy(splink_settings)
    # Additional columns would widen the join output without changing counts.
    splink_settings["additional_columns_to_retain"] = []
    queryplan_text = _get_queryplan_text(df, blocking_rule, splink_settings)
    results_dict = _parse_join_line(_get_join_line(queryplan_text))

    if results_dict is None:
        # _parse_join_line only understands SortMergeJoin and Cartesian
        # plans; previously an unknown strategy caused a cryptic TypeError.
        raise ValueError(
            "Could not parse the join strategy from the Spark query plan"
        )

    # Message detailing any metrics that could not be computed
    metric_message = ""

    if results_dict["join_strategy"] == "SortMergeJoin":

        jcl = results_dict["join_hashpartition_columns_left"]
        jcr = results_dict["join_hashpartition_columns_right"]
        # Only when both sides hash-partition on identical expressions can
        # group sizes be counted directly instead of running the join.
        balanced_join = jcl == jcr
        if balanced_join:
            total_comparisons_generated = (
                get_total_comparisons_from_join_columns_that_will_be_hash_partitioned(
                    df, jcl
                )
            )
            results_dict[
                "total_comparisons_generated_before_filter_applied"
            ] = total_comparisons_generated
            if compute_largest_group:
                group_stats = _get_largest_group(df, jcl)
                results_dict = {**results_dict, **group_stats}
        else:
            metric_message += "Join columns include inversions, so total comparisons and largest group could not be computed efficiently.\n"
            results_dict["total_comparisons_generated_before_filter_applied"] = None
            # Sentinel below compute_exact_limit so the exact count below
            # can still run when requested.
            total_comparisons_generated = 1
            if compute_largest_group:
                results_dict["largest_group_expr"] = None
                results_dict["num_comparisons_generated_in_largest_group"] = None

    if results_dict["join_strategy"] == "Cartesian":
        raw_count = df.count()
        total_comparisons_generated = raw_count * raw_count
        results_dict[
            "total_comparisons_generated_before_filter_applied"
        ] = total_comparisons_generated
        if compute_largest_group:
            results_dict["largest_group_expr"] = None
            results_dict["num_comparisons_generated_in_largest_group"] = None
            metric_message += "Blocking rule results in Cartesian join so largest groups not applicable.\n"

    if compute_exact_comparisons and total_comparisons_generated < compute_exact_limit:
        splink_settings["blocking_rules"] = [blocking_rule]
        blocked = block_using_rules(splink_settings, df, spark)
        total_with_filters = blocked.count()
        results_dict[
            "total_comparisons_generated_after_filters_applied"
        ] = total_with_filters
    else:
        results_dict["total_comparisons_generated_after_filters_applied"] = None
        if not compute_exact_comparisons:
            metric_message += (
                "Exact number of filtered comparisons not computed."
                " Set compute_exact_comparisons=True to compute.\n"
            )
        if (
            compute_exact_comparisons
            and total_comparisons_generated > compute_exact_limit
        ):
            metric_message += (
                f"Exact number of filtered comparisons not computed because total number"
                f" of comparisons ({total_comparisons_generated:,.0f}) exceeds limit"
                f" ({compute_exact_limit:,.0f}). Set compute_exact_limit to a larger value to compute."
            )
    results_dict["message"] = metric_message

    results_order = [
        "total_comparisons_generated_before_filter_applied",
        "total_comparisons_generated_after_filters_applied",
        "largest_group_expr",
        "num_comparisons_generated_in_largest_group",
        "message",
        "join_strategy",
        "join_type",
        "join_hashpartition_columns_left",
        "join_hashpartition_columns_right",
        "post_join_filters",
    ]

    # .get() rather than [] because the largest-group keys are absent when
    # compute_largest_group=False on a balanced join (previously a KeyError);
    # missing metrics are reported as None.
    return {key: results_dict.get(key) for key in results_order}