Commit a589e43: merge e8eb566 into 4a1704a (2 parents: 4a1704a + e8eb566)
RobinL committed Dec 22, 2021
Showing 1 changed file: splink/analyse_blocking_rule.py (319 additions, 0 deletions)
@@ -0,0 +1,319 @@
from copy import deepcopy
import re

from splink.blocking import _sql_gen_where_condition, block_using_rules


def _get_queryplan_text(df, blocking_rule, splink_settings):
    spark = df.sql_ctx.sparkSession

    # Temporarily override broadcast join threshold to ensure
    # that the query plan is a SortMergeJoin
    current_broadcast_setting = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

    link_type = splink_settings["link_type"]
    if link_type == "dedupe_only":
        source_dataset_col = None
    else:
        source_dataset_col = splink_settings["source_dataset_column_name"]

    unique_id_col = splink_settings["unique_id_column_name"]

    join_filter = _sql_gen_where_condition(link_type, source_dataset_col, unique_id_col)

    df.createOrReplaceTempView("df")
    sql = f"""
    select * from
    df as l
    inner join df as r
    on {blocking_rule} {join_filter}
    """

    df = spark.sql(sql)
    qe = df._jdf.queryExecution()
    sp = qe.sparkPlan()
    treestring = sp.treeString()
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", current_broadcast_setting)
    return treestring
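
# Illustrative only (not produced by this module): for a blocking rule such as
# "l.first_name = r.first_name", the first line of the plan text returned above
# typically looks something like
#   SortMergeJoin [first_name#32], [first_name#87], Inner
# The #32 / #87 suffixes are Spark's internal column ids and vary between runs;
# the parsing helpers below strip them out.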


def _get_join_line(queryplan_text):
    lines = queryplan_text.splitlines()
    return lines[0]


def _parse_join_line_sortmergejoin(join_line):
    parts = _split_by_commas_ignoring_within_brackets(join_line)
    hash_columns = _get_hash_columns(join_line)

    if len(parts) == 4:
        post_join_filters = _remove_col_ids(parts[3])
    else:
        post_join_filters = None

    return {
        "join_strategy": "SortMergeJoin",
        "join_type": "Inner",
        "join_hashpartition_columns_left": hash_columns["left"],
        "join_hashpartition_columns_right": hash_columns["right"],
        "post_join_filters": post_join_filters,
    }


def _parse_join_line_cartesian(join_line):
    fil = _extract_text_from_within_brackets_balanced(join_line, ["(", ")"])
    fil = _remove_col_ids(fil)
    return {
        "join_strategy": "Cartesian",
        "join_type": "Cartesian",
        "join_hashpartition_columns_left": [],
        "join_hashpartition_columns_right": [],
        "post_join_filters": fil,
    }


def _parse_join_line(join_line):
    if "SortMergeJoin" in join_line:
        return _parse_join_line_sortmergejoin(join_line)
    if "Cartesian" in join_line:
        return _parse_join_line_cartesian(join_line)
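
# Illustrative example of the parsed output, assuming a plan line in the shape
# shown after _get_queryplan_text above:
#   _parse_join_line("SortMergeJoin [first_name#32], [first_name#87], Inner")
#   -> {"join_strategy": "SortMergeJoin", "join_type": "Inner",
#       "join_hashpartition_columns_left": ["first_name"],
#       "join_hashpartition_columns_right": ["first_name"],
#       "post_join_filters": None}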


def _split_by_commas_ignoring_within_brackets(input_str):
    counter = 0
    captured_strings = []
    captured_string = ""
    for i in input_str:
        captured_string += i

        if i in ("(", "[", "{"):
            counter += 1

        if i in (")", "]", "}"):
            counter -= 1

        if counter == 0 and i == ",":
            captured_string = captured_string[:-1]
            captured_strings.append(captured_string)
            captured_string = ""

    captured_strings.append(captured_string)
    captured_strings = [s.strip() for s in captured_strings]
    return captured_strings


def _extract_text_from_within_brackets_balanced(input_str, bracket_type=["[", "]"]):
    bracket_counter = 0

    start_bracket = bracket_type[0]
    end_bracket = bracket_type[1]

    captured_string = ""

    if start_bracket not in input_str:
        return None

    for i in input_str:

        if i == start_bracket:
            bracket_counter += 1

        if i == end_bracket:
            bracket_counter -= 1
            if bracket_counter == 0:
                break

        if bracket_counter > 0:
            captured_string += i

    return captured_string[1:]
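
# For example (illustrative), using the default bracket_type:
#   _extract_text_from_within_brackets_balanced("SortMergeJoin [first_name#32], Inner")
#   -> "first_name#32"
# Only the first balanced bracketed span is captured; nested brackets within it
# are kept intact.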


def _remove_col_ids(input_str):
    return re.sub(r"#\d{1,6}[L]?", "", input_str)
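
# For example (illustrative), stripping Spark's internal column ids:
#   _remove_col_ids("(first_name#123L = surname#45)")
#   -> "(first_name = surname)"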


def get_total_comparisons_from_join_columns_that_will_be_hash_partitioned(
    df, join_cols: list
):
    """Compute the total number of comparisons that will be generated
    by an inner join on the join_cols.

    For instance, if join_cols = ["first_name", "dmetaphone(surname)"],
    this will compute the total number of comparisons generated by the blocking rule:
    ["l.first_name = r.first_name and dmetaphone(l.surname) = dmetaphone(r.surname)"]

    Args:
        df (DataFrame): Input dataframe
        join_cols (list): List of blocking columns e.g. ["first_name", "dmetaphone(surname)"]

    Returns:
        integer: Number of comparisons generated by the blocking rule
    """

    sel_expr = ", ".join(join_cols)
    concat_expr = f"concat({sel_expr})"
    spark = df.sql_ctx.sparkSession

    df.createOrReplaceTempView("df")

    sql = f"""
    with
    block_groups as (
        SELECT {concat_expr}, {sel_expr},
        count(*) * count(*) as num_comparisons
        FROM df
        where {concat_expr} is not null
        GROUP BY {concat_expr}, {sel_expr}
    )
    select sum(num_comparisons) as total_comparisons
    from block_groups
    """

    return spark.sql(sql).collect()[0]["total_comparisons"]
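
# A worked example of the arithmetic above (figures are illustrative): if
# blocking on first_name and 1,000 records share first_name = 'john', that
# block alone contributes 1,000 * 1,000 = 1,000,000 comparisons; the query
# sums this quantity over every block.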


def generate_and_count_num_comparisons_from_blocking_rule(
    df, blocking_rule, splink_settings
):
    spark = df.sql_ctx.sparkSession
    splink_settings = deepcopy(splink_settings)
    splink_settings["blocking_rules"] = [blocking_rule]

    df = block_using_rules(splink_settings, df, spark)

    return df.count()
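
# Illustrative usage: this gives an exact count by materialising the blocked
# pairs, so it can be slow on large inputs:
#   generate_and_count_num_comparisons_from_blocking_rule(
#       df, "l.first_name = r.first_name", splink_settings
#   )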


def _get_hash_columns(smj_line):
    comma_split = _split_by_commas_ignoring_within_brackets(smj_line)
    join_left = _extract_text_from_within_brackets_balanced(comma_split[0])
    join_left = _remove_col_ids(join_left)
    join_left = _split_by_commas_ignoring_within_brackets(join_left)

    join_right = _extract_text_from_within_brackets_balanced(comma_split[1])
    join_right = _remove_col_ids(join_right)
    join_right = _split_by_commas_ignoring_within_brackets(join_right)
    return {"left": join_left, "right": join_right}


def _get_largest_group(df, join_cols):

    sel_expr = ", ".join(join_cols)
    concat_expr = f"concat({sel_expr})"
    concat_ws_expr = f"concat_ws('|', {sel_expr})"
    spark = df.sql_ctx.sparkSession

    df.createOrReplaceTempView("df")

    sql = f"""
    SELECT {concat_expr} as concat_expr, {concat_ws_expr} as concat_ws_expr, {sel_expr},
    count(*) * count(*) as num_comparisons
    FROM df
    where {concat_expr} is not null
    GROUP BY {concat_expr}, {sel_expr}
    ORDER BY count(*) desc
    limit 1
    """

    collected = spark.sql(sql).collect()
    largest_group_comparisons = collected[0]["num_comparisons"]
    largest_group_concat = collected[0]["concat_ws_expr"]
    return {
        "largest_group_expr": largest_group_concat,
        "num_comparisons_generated_in_largest_group": largest_group_comparisons,
    }
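
# Illustrative output, blocking on first_name where 'john' is the most common
# value with 1,000 records:
#   {"largest_group_expr": "john",
#    "num_comparisons_generated_in_largest_group": 1000000}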


def analyse_blocking_rule(
    df,
    blocking_rule,
    splink_settings,
    compute_exact_comparisons=False,
    compute_exact_limit=1e9,
    compute_largest_group=True,
):
    spark = df.sql_ctx.sparkSession
    df.createOrReplaceTempView("df")
    splink_settings = deepcopy(splink_settings)
    splink_settings["additional_columns_to_retain"] = []
    queryplan_text = _get_queryplan_text(df, blocking_rule, splink_settings)
    results_dict = _parse_join_line(_get_join_line(queryplan_text))

    # Message detailing any metrics that could not be computed
    metric_message = ""

    if results_dict["join_strategy"] == "SortMergeJoin":

        jcl = results_dict["join_hashpartition_columns_left"]
        jcr = results_dict["join_hashpartition_columns_right"]
        balanced_join = jcl == jcr
        if balanced_join:
            total_comparisons_generated = (
                get_total_comparisons_from_join_columns_that_will_be_hash_partitioned(
                    df, jcl
                )
            )
            results_dict[
                "total_comparisons_generated_before_filter_applied"
            ] = total_comparisons_generated
            if compute_largest_group:
                group_stats = _get_largest_group(df, jcl)
                results_dict = {**results_dict, **group_stats}
        else:
            metric_message += (
                "Join columns include inversions, so total comparisons and"
                " largest group could not be computed efficiently.\n"
            )
            results_dict["total_comparisons_generated_before_filter_applied"] = None
            # Sentinel value: the true total is unknown, so use 1 so that the
            # exact count below is not blocked by the compute_exact_limit check
            total_comparisons_generated = 1
            if compute_largest_group:
                results_dict["largest_group_expr"] = None
                results_dict["num_comparisons_generated_in_largest_group"] = None

    if results_dict["join_strategy"] == "Cartesian":
        raw_count = df.count()
        total_comparisons_generated = raw_count * raw_count
        results_dict[
            "total_comparisons_generated_before_filter_applied"
        ] = total_comparisons_generated
        if compute_largest_group:
            results_dict["largest_group_expr"] = None
            results_dict["num_comparisons_generated_in_largest_group"] = None
            metric_message += (
                "Blocking rule results in Cartesian join so largest groups"
                " not applicable.\n"
            )

    if compute_exact_comparisons and total_comparisons_generated < compute_exact_limit:
        splink_settings["blocking_rules"] = [blocking_rule]
        blocked = block_using_rules(splink_settings, df, spark)
        total_with_filters = blocked.count()
        results_dict[
            "total_comparisons_generated_after_filters_applied"
        ] = total_with_filters
    else:
        results_dict["total_comparisons_generated_after_filters_applied"] = None
        if not compute_exact_comparisons:
            metric_message += (
                "Exact number of filtered comparisons not computed."
                " Set compute_exact_comparisons=True to compute.\n"
            )
        if (
            compute_exact_comparisons
            and total_comparisons_generated > compute_exact_limit
        ):
            metric_message += (
                f"Exact number of filtered comparisons not computed because total number"
                f" of comparisons ({total_comparisons_generated:,.0f}) exceeds limit"
                f" ({compute_exact_limit:,.0f}). Set compute_exact_limit to a larger value to compute."
            )
    results_dict["message"] = metric_message

    results_order = [
        "total_comparisons_generated_before_filter_applied",
        "total_comparisons_generated_after_filters_applied",
        "largest_group_expr",
        "num_comparisons_generated_in_largest_group",
        "message",
        "join_strategy",
        "join_type",
        "join_hashpartition_columns_left",
        "join_hashpartition_columns_right",
        "post_join_filters",
    ]

    results = {}
    for key in results_order:
        results[key] = results_dict[key]

    return results
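
# A usage sketch (the settings shown are minimal assumptions, not a full
# splink configuration):
#   splink_settings = {
#       "link_type": "dedupe_only",
#       "unique_id_column_name": "unique_id",
#       "blocking_rules": [],
#   }
#   results = analyse_blocking_rule(
#       df,
#       "l.first_name = r.first_name",
#       splink_settings,
#       compute_exact_comparisons=True,
#   )
# The returned dict reports the join strategy Spark chose, the comparison
# counts before and after filters, and any post-join filters, which helps
# diagnose blocking rules that accidentally trigger a Cartesian join.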
