diff --git a/splink/cost_of_blocking_rules.py b/splink/cost_of_blocking_rules.py
index 770e930092..8117a67e92 100644
--- a/splink/cost_of_blocking_rules.py
+++ b/splink/cost_of_blocking_rules.py
@@ -1,8 +1,6 @@
 import logging
 from typing import Dict, List, Union
 
-import pandas as pd
-
 logger = logging.getLogger(__name__)
 
 
@@ -51,9 +49,9 @@ def calculate_field_freedom_cost(combination_of_brs: List[Dict]) -> int:
 
 
 def calculate_cost_of_combination_of_brs(
-    br_combination: pd.DataFrame,
+    br_combination: List[Dict],
     max_comparison_count: int,
-    complexity_weight: Union[int, float] = 1,
+    num_equi_join_weight: Union[int, float] = 1,
     field_freedom_weight: Union[int, float] = 1,
     num_brs_weight: Union[int, float] = 1,
     num_comparison_weight: Union[int, float] = 1,
@@ -61,15 +59,15 @@ def calculate_cost_of_combination_of_brs(
     """
     Calculates the cost for a given combination of blocking rules.
 
-    The cost is a weighted sum of the complexity of the rules, the count of rules,
-    the number of fields that are allowed to vary, and the number of rows.
+    The cost is a weighted sum of the number of equi joins in the rules, the count of
+    rules, the number of fields that are allowed to vary, and the number of rows.
 
     Args:
-        br_combination (pd.DataFrame): The combination of rows outputted by
+        br_combination (List[Dict]): The combination of rows outputted by
            find_blocking_rules_below_threshold_comparison_count.
        max_comparison_count (int): The maximum comparison count amongst the rules.
            This is needed to normalise the cost of more or fewer comparison rows.
-        complexity_weight (Union[int, float], optional): The weight for complexity.
+        num_equi_join_weight (Union[int, float], optional): The weight for num_equi_join.
            Defaults to 1.
        field_freedom_weight (Union[int, float], optional): The weight for field
            freedom. Defaults to 1.
@@ -81,7 +79,6 @@
 
    Returns:
        dict: The calculated cost and individual component costs.
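+
+    Example (illustrative sketch; the values are invented, but the keys are the
+    ones read by this function and, via the '__fixed__' columns, by
+    calculate_field_freedom_cost):
+
+        brs = [
+            {"num_equi_joins": 1, "comparison_count": 1000,
+             "__fixed__first_name": 1, "__fixed__surname": 0},
+            {"num_equi_joins": 2, "comparison_count": 500,
+             "__fixed__first_name": 0, "__fixed__surname": 1},
+        ]
+        costs = calculate_cost_of_combination_of_brs(brs, max_comparison_count=1000)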
""" - br_combination = br_combination.to_dict(orient="records") num_equi_join_cost = sum(row["num_equi_joins"] for row in br_combination) total_row_count = sum(row["comparison_count"] for row in br_combination) @@ -92,7 +89,7 @@ def calculate_cost_of_combination_of_brs( field_freedom_cost = calculate_field_freedom_cost(br_combination) num_brs_cost = len(br_combination) - num_equi_join_cost_weighted = complexity_weight * num_equi_join_cost + num_equi_join_cost_weighted = num_equi_join_weight * num_equi_join_cost field_freedom_cost_weighted = field_freedom_weight * field_freedom_cost num_brs_cost_weighted = num_brs_weight * num_brs_cost num_comparison_rows_cost_weighted = num_comparison_weight * normalised_row_count diff --git a/splink/find_brs_with_comparison_counts_below_threshold.py b/splink/find_brs_with_comparison_counts_below_threshold.py index b151ab72c4..1d71125606 100644 --- a/splink/find_brs_with_comparison_counts_below_threshold.py +++ b/splink/find_brs_with_comparison_counts_below_threshold.py @@ -81,7 +81,7 @@ def _generate_blocking_rule( block_on = block_on_module.block_on if len(cols_as_string) == 0: - return "1 = 1" + return block_on("1") br = block_on(cols_as_string) @@ -128,7 +128,7 @@ def _search_tree_for_blocking_rules_below_threshold_count( 'blocking_columns_sanitised':['first_name'], 'splink_blocking_rule':', comparison_count':4827, - 'complexity':1, + 'num_equi_join':1, '__fixed__first_name':1, '__fixed__surname':0, '__fixed__dob':0, diff --git a/splink/linker.py b/splink/linker.py index 67f1476cf0..8731f2f3e8 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -65,6 +65,9 @@ from .em_training_session import EMTrainingSession from .estimate_u import estimate_u_values from .exceptions import SplinkDeprecated, SplinkException +from .find_brs_with_comparison_counts_below_threshold import ( + find_blocking_rules_below_threshold_comparison_count, +) from .find_matches_to_new_records import add_unique_id_and_source_dataset_cols_if_needed from .labelling_tool import ( generate_labelling_tool_comparisons, @@ -86,6 +89,7 @@ prob_to_bayes_factor, ) from .missingness import completeness_data, missingness_data +from .optimise_cost_of_brs import suggest_blocking_rules from .pipeline import SQLPipeline from .predict import predict_from_comparison_vectors_sqls from .profile_data import profile_columns @@ -3751,3 +3755,172 @@ def _remove_splinkdataframe_from_cache(self, splink_dataframe: SplinkDataFrame): for k in keys_to_delete: del self._intermediate_table_cache[k] + + def _find_blocking_rules_below_threshold( + self, max_comparisons_per_rule, blocking_expressions=None + ): + return find_blocking_rules_below_threshold_comparison_count( + self, max_comparisons_per_rule, blocking_expressions + ) + + def _detect_blocking_rules_for_prediction( + self, + max_comparisons_per_rule, + blocking_expressions=None, + min_freedom=1, + num_runs=200, + num_equi_join_weight=0, + field_freedom_weight=1, + num_brs_weight=10, + num_comparison_weight=10, + return_as_df=False, + ): + """Find blocking rules for prediction below some given threshold of the + maximum number of comparisons that can be generated per blocking rule + (max_comparisons_per_rule). + Uses a heuristic cost algorithm to identify the 'best' set of blocking rules + Args: + max_comparisons_per_rule (int): The maximum number of comparisons that + each blocking rule is allowed to generate + blocking_expressions: By default, blocking rules will be equi-joins + on the columns used by the Splink model. 
+                This allows you to manually specify SQL expressions from which
+                combinations will be created. For example, if you specify
+                ["substr(dob, 1,4)", "surname", "dob"], blocking rules will be
+                chosen by blocking on combinations of those expressions.
+            min_freedom (int, optional): The minimum amount of freedom (number of
+                opportunities to vary) any column should be allowed. Defaults to 1.
+            num_runs (int, optional): Each run selects rows using a heuristic and
+                costs them. The more runs, the more likely you are to find the best
+                rule. Defaults to 200.
+            num_equi_join_weight (int, optional): Weight allocated to the number of
+                equi joins in the blocking rules. Defaults to 0, since this cost is
+                better captured by the other criteria.
+            field_freedom_weight (int, optional): Weight given to the cost of
+                having individual fields which don't have much flexibility. Assigning
+                a high weight here makes it more likely you'll generate combinations of
+                blocking rules for which most fields are allowed to vary more than
+                the minimum. Defaults to 1.
+            num_brs_weight (int, optional): Weight assigned to the cost of
+                additional blocking rules. A higher weight here will result in a
+                preference for fewer blocking rules. Defaults to 10.
+            num_comparison_weight (int, optional): Weight assigned to the cost of
+                larger numbers of comparisons, which happens when more of the blocking
+                rules are close to the max_comparisons_per_rule. A higher weight here
+                prefers sets of rules which generate fewer total comparisons.
+                Defaults to 10.
+            return_as_df (bool, optional): If False, assign the recommendation to the
+                settings. If True, return a dataframe containing details of the
+                suggested rules and their costs. Defaults to False.
+        """
+
+        df_br_below_thres = find_blocking_rules_below_threshold_comparison_count(
+            self, max_comparisons_per_rule, blocking_expressions
+        )
+
+        blocking_rule_suggestions = suggest_blocking_rules(
+            df_br_below_thres,
+            min_freedom=min_freedom,
+            num_runs=num_runs,
+            num_equi_join_weight=num_equi_join_weight,
+            field_freedom_weight=field_freedom_weight,
+            num_brs_weight=num_brs_weight,
+            num_comparison_weight=num_comparison_weight,
+        )
+
+        if return_as_df:
+            return blocking_rule_suggestions
+        else:
+            if blocking_rule_suggestions is None or len(blocking_rule_suggestions) == 0:
+                logger.warning("No set of blocking rules found within constraints")
+            else:
+                suggestion = blocking_rule_suggestions[
+                    "suggested_blocking_rules_as_splink_brs"
+                ].iloc[0]
+                self._settings_obj._blocking_rules_to_generate_predictions = suggestion
+
+                suggestion_str = blocking_rule_suggestions[
+                    "suggested_blocking_rules_for_prediction"
+                ].iloc[0]
+                msg = (
+                    "The following blocking_rules_to_generate_predictions were "
+                    "automatically detected and assigned to your settings:\n"
+                )
+                logger.info(f"{msg}{suggestion_str}")
+
+    def _detect_blocking_rules_for_em_training(
+        self,
+        max_comparisons_per_rule,
+        min_freedom=1,
+        num_runs=200,
+        num_equi_join_weight=0,
+        field_freedom_weight=1,
+        num_brs_weight=20,
+        num_comparison_weight=10,
+        return_as_df=False,
+    ):
+        """Find blocking rules for EM training below some given threshold of the
+        maximum number of comparisons that can be generated per blocking rule
+        (max_comparisons_per_rule).
+
+        Uses a heuristic cost algorithm to identify the 'best' set of blocking rules.
+
+        Args:
+            max_comparisons_per_rule (int): The maximum number of comparisons that
+                each blocking rule is allowed to generate.
+            min_freedom (int, optional): The minimum amount of freedom (number of
+                opportunities to vary) any column should be allowed. Defaults to 1.
+            num_runs (int, optional): Each run selects rows using a heuristic and
+                costs them. The more runs, the more likely you are to find the best
+                rule. Defaults to 200.
+            num_equi_join_weight (int, optional): Weight allocated to the number of
+                equi joins in the blocking rules. Defaults to 0, since this cost is
+                better captured by the other criteria.
+            field_freedom_weight (int, optional): Weight given to the cost of
+                having individual fields which don't have much flexibility. Assigning
+                a high weight here makes it more likely you'll generate combinations of
+                blocking rules for which most fields are allowed to vary more than
+                the minimum. Defaults to 1.
+            num_brs_weight (int, optional): Weight assigned to the cost of
+                additional blocking rules. A higher weight here will result in a
+                preference for fewer blocking rules. Defaults to 20.
+            num_comparison_weight (int, optional): Weight assigned to the cost of
+                larger numbers of comparisons, which happens when more of the blocking
+                rules are close to the max_comparisons_per_rule. A higher weight here
+                prefers sets of rules which generate fewer total comparisons.
+                Defaults to 10.
+            return_as_df (bool, optional): If False, return just the recommendation.
+                If True, return a dataframe containing details of the suggested rules
+                and their costs. Defaults to False.
+        """
+
+        df_br_below_thres = find_blocking_rules_below_threshold_comparison_count(
+            self, max_comparisons_per_rule
+        )
+
+        blocking_rule_suggestions = suggest_blocking_rules(
+            df_br_below_thres,
+            min_freedom=min_freedom,
+            num_runs=num_runs,
+            num_equi_join_weight=num_equi_join_weight,
+            field_freedom_weight=field_freedom_weight,
+            num_brs_weight=num_brs_weight,
+            num_comparison_weight=num_comparison_weight,
+        )
+
+        if return_as_df:
+            return blocking_rule_suggestions
+        else:
+            if blocking_rule_suggestions is None or len(blocking_rule_suggestions) == 0:
+                logger.warning("No set of blocking rules found within constraints")
+                return None
+            else:
+                suggestion_str = blocking_rule_suggestions[
+                    "suggested_EM_training_statements"
+                ].iloc[0]
+                msg = "The following EM training strategy was detected:\n"
+                msg = f"{msg}{suggestion_str}"
+                logger.info(msg)
+                suggestion = blocking_rule_suggestions[
+                    "suggested_blocking_rules_as_splink_brs"
+                ].iloc[0]
+                return suggestion
diff --git a/splink/optimise_cost_of_brs.py b/splink/optimise_cost_of_brs.py
new file mode 100644
index 0000000000..18277a5967
--- /dev/null
+++ b/splink/optimise_cost_of_brs.py
@@ -0,0 +1,214 @@
+import logging
+from random import randint
+
+import pandas as pd
+
+from .cost_of_blocking_rules import calculate_cost_of_combination_of_brs
+
+logger = logging.getLogger(__name__)
+
+
+def localised_shuffle(lst: list, window_percent: float) -> list:
+    """
+    Performs a localised shuffle on a list.
+
+    This is used to choose semi-randomly from a list of sorted rows, so you
+    tend to pick items from near the top.
+
+    Args:
+        lst (list): The list to shuffle.
+        window_percent (float): The size of the shuffle window as a proportion
+            of the list length, e.g. 0.3 to shuffle items within 30% of their
+            original position.
+
+    Returns:
+        list: A shuffled copy of the original list.
+    """
+    window_size = max(1, int(window_percent * len(lst)))
+    return sorted(lst, key=lambda x: lst.index(x) + randint(-window_size, window_size))
+
+
+def check_field_freedom(candidate_set, field_names, min_field_freedom):
+    """
+    Checks if each field in the candidate set is allowed to vary at least
+    'min_field_freedom' times.
+
+    Args:
+        candidate_set (list): The candidate set of rows.
+        field_names (list): The list of field names.
+        min_field_freedom (int): The minimum field freedom.
+
+    Returns:
+        bool: True if each field can vary at least 'min_field_freedom' times,
+            False otherwise.
+    """
+    covered_fields = {field: 0 for field in field_names}
+    for row in candidate_set:
+        for field in field_names:
+            if row[field] == 0:
+                covered_fields[field] += 1
+    return all(count >= min_field_freedom for count in covered_fields.values())
+
+
+def heuristic_select_brs_that_have_min_freedom(data, field_names, min_field_freedom):
+    """
+    A heuristic algorithm to select blocking rules that between them
+    ensure that each field is allowed to vary at least 'min_field_freedom' times.
+
+    Args:
+        data (list): The data rows.
+        field_names (list): The list of field names.
+        min_field_freedom (int): The minimum field freedom.
+
+    Returns:
+        list: The candidate set of rows.
+    """
+    data_sorted_randomised = localised_shuffle(data, 0.5)
+    candidate_rows = []
+
+    for row in data_sorted_randomised:
+        candidate_rows.append(row)
+        if check_field_freedom(candidate_rows, field_names, min_field_freedom):
+            break
+
+    sorted_candidate_rows = sorted(
+        candidate_rows, key=lambda x: x["blocking_columns_sanitised"]
+    )
+
+    return sorted_candidate_rows
+
+
+def get_block_on_string(br_rows):
+    block_on_strings = []
+
+    for row in br_rows:
+        quoted_args = []
+        for arg in row["blocking_columns_sanitised"]:
+            quoted_arg = f'"{arg}"'
+            quoted_args.append(quoted_arg)
+
+        block_on_args = ", ".join(quoted_args)
+        block_on_string = f"block_on([{block_on_args}])"
+        block_on_strings.append(block_on_string)
+
+    return " \n".join(block_on_strings)
+
+
+def get_em_training_string(br_rows):
+    block_on_strings = []
+
+    for row in br_rows:
+        quoted_args = []
+        for arg in row["blocking_columns_sanitised"]:
+            quoted_arg = f'"{arg}"'
+            quoted_args.append(quoted_arg)
+
+        block_on_args = ", ".join(quoted_args)
+        block_on_string = f"block_on([{block_on_args}])"
+        block_on_strings.append(block_on_string)
+
+    training_statements = []
+    for block_on_str in block_on_strings:
+        statement = (
+            f"linker.estimate_parameters_using_expectation_maximisation({block_on_str})"
+        )
+        training_statements.append(statement)
+
+    return " \n".join(training_statements)
+
+
+def suggest_blocking_rules(
+    df_block_stats,
+    min_freedom=1,
+    num_runs=100,
+    num_equi_join_weight=0,
+    field_freedom_weight=1,
+    num_brs_weight=10,
+    num_comparison_weight=10,
+):
+    """Use a cost optimiser to suggest blocking rules.
+
+    Args:
+        df_block_stats: Dataframe returned by
+            find_blocking_rules_below_threshold_comparison_count.
+        min_freedom (int, optional): Each column should have at least this many
+            opportunities to vary amongst the blocking rules. Defaults to 1.
+        num_runs (int, optional): How many random combinations of rules to try.
+            The best will be selected. Defaults to 100.
+        num_equi_join_weight (int, optional): The weight for the number of equi
+            joins. Defaults to 0.
+        field_freedom_weight (int, optional): The weight for field freedom.
+            Defaults to 1.
+        num_brs_weight (int, optional): The weight for the number of blocking
+            rules found. Defaults to 10.
+        num_comparison_weight (int, optional): The weight for the total number
+            of comparisons generated. Defaults to 10.
+
+    Returns:
+        pd.DataFrame: A DataFrame containing the results of the blocking rules
+            suggestion.
+            It includes columns such as
+            'suggested_blocking_rules_for_prediction',
+            'suggested_EM_training_statements', and various cost information.
+    """
+    if len(df_block_stats) == 0:
+        return None
+
+    max_comparison_count = df_block_stats["comparison_count"].max()
+
+    df_block_stats = df_block_stats.sort_values(
+        by=["num_equi_joins", "comparison_count"], ascending=[True, False]
+    )
+    blocks_found_recs = df_block_stats.to_dict(orient="records")
+
+    blocking_cols = list(blocks_found_recs[0].keys())
+    blocking_cols = [c for c in blocking_cols if c.startswith("__fixed__")]
+
+    results = []
+
+    for run in range(num_runs):
+        selected_rows = heuristic_select_brs_that_have_min_freedom(
+            blocks_found_recs, blocking_cols, min_field_freedom=min_freedom
+        )
+
+        cost_dict = {
+            "suggested_blocking_rules_for_prediction": get_block_on_string(
+                selected_rows
+            ),
+            "suggested_EM_training_statements": get_em_training_string(selected_rows),
+        }
+
+        costs = calculate_cost_of_combination_of_brs(
+            selected_rows,
+            max_comparison_count,
+            num_equi_join_weight,
+            field_freedom_weight,
+            num_brs_weight,
+            num_comparison_weight,
+        )
+
+        cost_dict.update(costs)
+        cost_dict.update(
+            {
+                "run_num": run,
+                "minimum_freedom_for_each_column": min_freedom,
+                "suggested_blocking_rules_as_splink_brs": [
+                    row["splink_blocking_rule"] for row in selected_rows
+                ],
+            }
+        )
+        results.append(cost_dict)
+
+    results_df = pd.DataFrame(results)
+
+    # Easier to read if we normalise the costs so the best is 0
+    min_ = results_df["field_freedom_cost"].min()
+    results_df["field_freedom_cost"] = results_df["field_freedom_cost"] - min_
+
+    min_ = results_df["field_freedom_cost_weighted"].min()
+    results_df["field_freedom_cost_weighted"] = (
+        results_df["field_freedom_cost_weighted"] - min_
+    )
+    # 'cost' includes the weighted field freedom term, so shift it by the same minimum
+    results_df["cost"] = results_df["cost"] - min_
+
+    min_scores_df = results_df.sort_values("cost")
+    min_scores_df = min_scores_df.drop_duplicates(
+        "suggested_blocking_rules_for_prediction"
+    )
+
+    return min_scores_df
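
A minimal usage sketch of the new detection methods (illustrative only: "df" and "settings" are assumed to already exist, the import path may differ between Splink versions, and the threshold values are arbitrary):

from splink.duckdb.linker import DuckDBLinker

linker = DuckDBLinker(df, settings)

# Return the scored rule sets as a dataframe rather than assigning the best
# set to the linker's settings (return_as_df=False assigns it directly)
suggestions = linker._detect_blocking_rules_for_prediction(
    max_comparisons_per_rule=1_000_000,
    return_as_df=True,
)

# Rows are sorted by total cost, so the first row is the recommended set
print(suggestions["suggested_blocking_rules_for_prediction"].iloc[0])

# For EM training, the default behaviour is to return the suggested rules
# as Splink BlockingRule objects
training_brs = linker._detect_blocking_rules_for_em_training(
    max_comparisons_per_rule=500_000
)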