tidy up code and improve messages
RobinL committed Dec 22, 2021
1 parent eb72ce3 commit e8eb566
Showing 1 changed file with 125 additions and 127 deletions.
252 changes: 125 additions & 127 deletions splink/analyse_blocking_rule.py
@@ -1,16 +1,15 @@
from copy import deepcopy
import re
from splink.blocking import _sql_gen_where_condition, block_using_rules

from copy import deepcopy
from splink.blocking import _sql_gen_where_condition, block_using_rules

from pyspark.sql import DataFrame

def get_queryplan_text(df, blocking_rule, splink_settings):
def _get_queryplan_text(df, blocking_rule, splink_settings):
spark = df.sql_ctx.sparkSession

# Temporarily override broadcast join threshold to ensure
# that the query plan is a SortMergeJoin
current_setting = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
current_broadcast_setting = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

link_type = splink_settings["link_type"]
@@ -23,7 +22,7 @@ def get_queryplan_text(df, blocking_rule, splink_settings):

join_filter = _sql_gen_where_condition(link_type, source_dataset_col, unique_id_col)


df.createOrReplaceTempView("df")
sql = f"""
select * from
df as l
@@ -34,26 +33,25 @@ def get_queryplan_text(df, blocking_rule, splink_settings):
df = spark.sql(sql)
qe = df._jdf.queryExecution()
sp = qe.sparkPlan()
treestring= sp.treeString()
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", current_setting)
treestring = sp.treeString()
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", current_broadcast_setting)
return treestring



def get_join_line(queryplan_text):
def _get_join_line(queryplan_text):
lines = queryplan_text.splitlines()
return lines[0]
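# Illustrative first line of such a plan for an equality blocking rule
# (the column ids #10 etc. are hypothetical), which the helpers below parse:
#   SortMergeJoin [first_name#10, city#11], [first_name#55, city#56], Inner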

def parse_join_line_sortmergejoin(join_line):
parts = split_by_commas_ignoring_within_brackets(join_line)
hash_columns = get_hash_columns(join_line)

def _parse_join_line_sortmergejoin(join_line):
parts = _split_by_commas_ignoring_within_brackets(join_line)
hash_columns = _get_hash_columns(join_line)

if len(parts) == 4:
post_join_filters = remove_col_ids(parts[3])
post_join_filters = _remove_col_ids(parts[3])
else:
post_join_filters = None


return {
"join_strategy": "SortMergeJoin",
"join_type": "Inner",
@@ -62,9 +60,10 @@ def parse_join_line_sortmergejoin(join_line):
"post_join_filters": post_join_filters,
}

def parse_join_line_cartesian(join_line):
fil = extract_text_from_within_brackets_balanced(join_line, ["(", ")"])
fil = remove_col_ids(fil)

def _parse_join_line_cartesian(join_line):
fil = _extract_text_from_within_brackets_balanced(join_line, ["(", ")"])
fil = _remove_col_ids(fil)
return {
"join_strategy": "Cartesian",
"join_type": "Cartesian",
@@ -73,15 +72,15 @@ def parse_join_line_cartesian(join_line):
"post_join_filters": fil,
}

def parse_join_line(join_line):

def _parse_join_line(join_line):
if "SortMergeJoin" in join_line:
return parse_join_line_sortmergejoin(join_line)
return _parse_join_line_sortmergejoin(join_line)
if "Cartesian" in join_line:
return parse_join_line_cartesian(join_line)

return _parse_join_line_cartesian(join_line)


def split_by_commas_ignoring_within_brackets(input_str):
def _split_by_commas_ignoring_within_brackets(input_str):
counter = 0
captured_strings = []
captured_string = ""
@@ -102,7 +101,7 @@ def split_by_commas_ignoring_within_brackets(input_str):
return captured_strings
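# Assumed behaviour: split on top-level commas only, leaving commas
# nested inside brackets intact, e.g. (whitespace aside)
#   "[a#1, b#2], [a#3, b#4], Inner" -> ["[a#1, b#2]", "[a#3, b#4]", "Inner"]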


def extract_text_from_within_brackets_balanced(input_str, bracket_type=["[", "]"]):
def _extract_text_from_within_brackets_balanced(input_str, bracket_type=["[", "]"]):
bracket_counter = 0

start_bracket = bracket_type[0]
@@ -128,46 +127,13 @@ def extract_text_from_within_brackets_balanced(input_str, bracket_type=["[", "]"]):
return captured_string[1:]


def remove_col_ids(input_str):
def _remove_col_ids(input_str):
return re.sub(r"#\d{1,6}[L]?", "", input_str)
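# e.g. _remove_col_ids("first_name#123L = first_name#456")
#      -> "first_name = first_name"
# (Spark tags plan columns with a numeric #id, optionally suffixed with L)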


def _sorted_array(df: DataFrame, field_list: list):
"""Create a new field called _sorted_array
containing a sorted array populated with the values from field_list
Args:
df (DataFrame): Input dataframe
field_list (list): List of fields e.g. ["first_name", "surname"]
Returns:
df, with a new field called _sorted_array
"""

df = df.withColumn("_sorted_array", F.array(*field_list))
df = df.withColumn("_sorted_array", F.sort_array("_sorted_array"))
return df


def _sort_fields(df: DataFrame, field_list_to_sort: list):
"""
Take the fields in field_list and derive new fields
with the same values but sorted alphabetically.
The derived fields are named __sorted_{field}
Args:
df (DataFrame): Input dataframe
field_list_to_sort (list): list of fields e.g. ["first_name", "surname"]
Returns:
DataFrame: dataframe with new fields __sorted_{field}
"""

df = _sorted_array(df, field_list_to_sort)

for i, field in enumerate(field_list_to_sort):
df = df.withColumn(f"__sorted_{field}", F.col("_sorted_array")[i])
df = df.drop((f"_sorted_array"))
return df


def get_total_comparisons_from_join_columns_that_will_be_hash_partitioned(df, join_cols: list):
def get_total_comparisons_from_join_columns_that_will_be_hash_partitioned(
df, join_cols: list
):
"""Compute the total number of records that will be generated
by an inner join on the join_cols
For instance, if join_cols = ["first_name", "dmetaphone(surname)"]
@@ -201,7 +167,10 @@ def get_total_comparisons_from_join_columns_that_will_be_hash_partitioned(df, join_cols: list):

return spark.sql(sql).collect()[0]["total_comparisons"]
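# Sketch of the idea behind the (elided) SQL above: in a self-join, a
# join key appearing n times yields n * n comparisons, so the total is
#   sum(cnt * cnt) over (select count(*) as cnt from df group by join_cols)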

def generate_and_count_num_comparisons_from_blocking_rule(df, blocking_rule, splink_settings):

def generate_and_count_num_comparisons_from_blocking_rule(
df, blocking_rule, splink_settings
):
spark = df.sql_ctx.sparkSession
splink_settings = deepcopy(splink_settings)
splink_settings["blocking_rules"] = [blocking_rule]
Expand All @@ -211,111 +180,140 @@ def generate_and_count_num_comparisons_from_blocking_rule(df, blocking_rule, spl
return df.count()


def get_hash_columns(smj_line):
comma_split = split_by_commas_ignoring_within_brackets(smj_line)
join_left = extract_text_from_within_brackets_balanced(comma_split[0])
join_left = remove_col_ids(join_left)
join_left = split_by_commas_ignoring_within_brackets(join_left)
def _get_hash_columns(smj_line):
comma_split = _split_by_commas_ignoring_within_brackets(smj_line)
join_left = _extract_text_from_within_brackets_balanced(comma_split[0])
join_left = _remove_col_ids(join_left)
join_left = _split_by_commas_ignoring_within_brackets(join_left)

join_right = extract_text_from_within_brackets_balanced(comma_split[1] )
join_right = remove_col_ids(join_right)
join_right = split_by_commas_ignoring_within_brackets(join_right)
join_right = _extract_text_from_within_brackets_balanced(comma_split[1])
join_right = _remove_col_ids(join_right)
join_right = _split_by_commas_ignoring_within_brackets(join_right)
return {"left": join_left, "right": join_right}



def get_num_comparisons_from_blocking_rule(df, blocking_rule, splink_settings):
# Use n strategy if the blocking rule is symmetric
# Use n^2 strategy when inversions are present
# Results as dict with num comparisons, hash columns, filter columns, join strategy reported

smj_line = get_sortmergejoin_query_plan_text(df, blocking_rule)
hc = get_hash_columns(smj_line)
left_hash_cols = hc["left"]
right_hash_cols = hc["right"]

if left_hash_cols == right_hash_cols:
return get_total_comparisons_from_join_columns_that_will_be_hashed(df, left_hash_cols)
else:
return generate_and_count_num_comparisons_from_blocking_rule(df, blocking_rule, splink_settings)



def get_largest_group(df, join_cols):
def _get_largest_group(df, join_cols):

sel_expr = ", ".join(join_cols)
concat_expr = f"concat({sel_expr})"
concat_ws_expr = f"concat_ws('|', {sel_expr})"
spark = df.sql_ctx.sparkSession

df.createOrReplaceTempView("df")

sql = f"""
SELECT {concat_expr} as concat_expr, {sel_expr},
SELECT {concat_expr} as concat_expr, {concat_ws_expr} as concat_ws_expr, {sel_expr},
count(*) * count(*) as num_comparisons
FROM df
where {concat_expr} is not null
GROUP BY {concat_expr}, {sel_expr}
ORDER BY count(*) desc
limit 1
"""

collected = spark.sql(sql).collect()
largest_group_comparisons = collected[0]["num_comparisons"]
largest_group_concat = collected[0]["concat_expr"]
largest_group_concat = collected[0]["concat_ws_expr"]
return {
"largest_group_expr": largest_group_concat,
"num_comparisons_generated_in_largest_group": largest_group_comparisons,

}
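# Illustrative (hypothetical) result: if the most frequent value pair of
# the join columns is john|london with 50 rows, this returns
#   {"largest_group_expr": "john|london",
#    "num_comparisons_generated_in_largest_group": 2500}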


def analyse_blocking_rule(df, blocking_rule, splink_settings, compute_exact_comparisons=False, compute_exact_limit=1e9, compute_largest_group=True):
def analyse_blocking_rule(
df,
blocking_rule,
splink_settings,
compute_exact_comparisons=False,
compute_exact_limit=1e9,
compute_largest_group=True,
):
spark = df.sql_ctx.sparkSession
df.createOrReplaceTempView("df")
splink_settings = deepcopy(splink_settings)
splink_settings["additional_columns_to_retain"] = []
queryplan_text = get_queryplan_text(df, blocking_rule, splink_settings)
parsed = parse_join_line(get_join_line(queryplan_text))
queryplan_text = _get_queryplan_text(df, blocking_rule, splink_settings)
results_dict = _parse_join_line(_get_join_line(queryplan_text))

if parsed["join_strategy"] == 'SortMergeJoin':
# Message detailing any metrics that could not be computed
metric_message = ""

jcl = parsed["join_hashpartition_columns_left"]
jcr = parsed["join_hashpartition_columns_right"]
balanced_join = (jcl == jcr)
if results_dict["join_strategy"] == "SortMergeJoin":

jcl = results_dict["join_hashpartition_columns_left"]
jcr = results_dict["join_hashpartition_columns_right"]
balanced_join = jcl == jcr
if balanced_join:
total_comparisons_generated = get_total_comparisons_from_join_columns_that_will_be_hash_partitioned(df, jcl)
parsed["total_comparisons_generated_before_filter_applied"] = total_comparisons_generated
if compute_largest_group:
group_stats = get_largest_group(df, jcl)
parsed = {**parsed, **group_stats}
total_comparisons_generated = (
get_total_comparisons_from_join_columns_that_will_be_hash_partitioned(
df, jcl
)
)
results_dict[
"total_comparisons_generated_before_filter_applied"
] = total_comparisons_generated
if compute_largest_group:
group_stats = _get_largest_group(df, jcl)
results_dict = {**results_dict, **group_stats}
else:
msg = "Join columns include invesions, so cannot be computed"
parsed["total_comparisons_generated_before_filter_applied"] = msg
metric_message += "Join columns include inversions, so total comparisons and largest group could not be computed efficiently.\n"
results_dict["total_comparisons_generated_before_filter_applied"] = None
total_comparisons_generated = 1
if compute_largest_group:
parsed["largest_group_expr"]= msg
parsed["num_comparisons_generated_in_largest_group"]: msg

if compute_largest_group:
results_dict["largest_group_expr"] = None
results_dict["num_comparisons_generated_in_largest_group"] = None

if parsed["join_strategy"] == 'Cartesian':
if results_dict["join_strategy"] == "Cartesian":
raw_count = df.count()
total_comparisons_generated = raw_count * raw_count
parsed["total_comparisons_generated_before_filter_applied"] = total_comparisons_generated
if compute_largest_group:
parsed["largest_group_expr"]= "Cartesian join so not groups"
parsed["num_comparisons_generated_in_largest_group"] = "Cartesian join so not groups"

if compute_exact_comparisons and total_comparisons_generated<compute_exact_limit:
results_dict[
"total_comparisons_generated_before_filter_applied"
] = total_comparisons_generated
if compute_largest_group:
results_dict["largest_group_expr"] = None
results_dict["num_comparisons_generated_in_largest_group"] = None
metric_message += "Blocking rule results in Cartesian join so largest groups not applicable.\n"

if compute_exact_comparisons and total_comparisons_generated < compute_exact_limit:
splink_settings["blocking_rules"] = [blocking_rule]
blocked = block_using_rules(splink_settings, df, spark)
total_with_filters = blocked.count()
parsed["total_comparisons_generated_after_filters_applied"] = total_with_filters
results_dict[
"total_comparisons_generated_after_filters_applied"
] = total_with_filters
else:
parsed["total_comparisons_generated_after_filters_applied"] = "Not computed, set compute_exact_comparisons=True to compute."

return parsed


results_dict["total_comparisons_generated_after_filters_applied"] = None
if not compute_exact_comparisons:
metric_message += (
"Exact number of filtered comparisons not computed."
" Set compute_exact_comparisons=True to compute.\n"
)
if (
compute_exact_comparisons
and total_comparisons_generated > compute_exact_limit
):
metric_message += (
f"Exact number of filtered comparisons not computed because total number"
f" of comparisons ({total_comparisons_generated:,.0f}) exceeds limit"
f" ({compute_exact_limit:,.0f}). Set compute_exact_limit to a larger value to compute."
)
results_dict["message"] = metric_message

results_order = [
"total_comparisons_generated_before_filter_applied",
"total_comparisons_generated_after_filters_applied",
"largest_group_expr",
"num_comparisons_generated_in_largest_group",
"message",
"join_strategy",
"join_type",
"join_hashpartition_columns_left",
"join_hashpartition_columns_right",
"post_join_filters",
]

results = {}
for key in results_order:
results[key] = results_dict[key]

return results
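
A minimal usage sketch of the reworked entry point (the dataframe, settings values and blocking rule below are illustrative assumptions, not part of the commit):

from pyspark.sql import SparkSession
from splink.analyse_blocking_rule import analyse_blocking_rule

spark = SparkSession.builder.getOrCreate()
df = spark.read.parquet("records.parquet")  # hypothetical input

# Abbreviated splink settings; only the keys the analysis relies on are shown
splink_settings = {
    "link_type": "dedupe_only",
    "unique_id_column_name": "unique_id",
    "blocking_rules": [],
}

results = analyse_blocking_rule(
    df,
    "l.first_name = r.first_name and l.city = r.city",
    splink_settings,
    compute_exact_comparisons=False,
)
print(results["total_comparisons_generated_before_filter_applied"])
print(results["message"])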
