
Finds blocking rules which return a comparison count below a given threshold #1665

Merged
249 changes: 249 additions & 0 deletions splink/find_brs_with_comparison_counts_below_threshold.py
@@ -0,0 +1,249 @@
import logging
import string
from typing import TYPE_CHECKING, Dict, List, Set

import pandas as pd

from .blocking import BlockingRule
from .input_column import InputColumn

if TYPE_CHECKING:
    from .linker import Linker

logger = logging.getLogger(__name__)


def sanitise_column_name_for_one_hot_encoding(column_name) -> str:
allowed_chars = string.ascii_letters + string.digits + "_"
sanitised_name = "".join(c for c in column_name if c in allowed_chars)
return sanitised_name
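# e.g. sanitise_column_name_for_one_hot_encoding("substr(surname, 1,1)") returns
# "substrsurname11": only letters, digits and underscores survive, so the result
# is safe to embed in the "__fixed__" one-hot column names below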


def _generate_output_combinations_table_row(
blocking_columns, splink_blocking_rule, comparison_count, all_columns
) -> dict:
row = {}

blocking_columns = [
sanitise_column_name_for_one_hot_encoding(c) for c in blocking_columns
]
all_columns = [sanitise_column_name_for_one_hot_encoding(c) for c in all_columns]

row["blocking_columns_sanitised"] = blocking_columns
row["splink_blocking_rule"] = splink_blocking_rule
row["comparison_count"] = comparison_count
row["num_equi_joins"] = len(blocking_columns)

for col in all_columns:
row[f"__fixed__{col}"] = 1 if col in blocking_columns else 0

return row
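# e.g. blocking_columns=["first_name"] with all_columns=["first_name", "surname"]
# produces {"blocking_columns_sanitised": ["first_name"], "num_equi_joins": 1,
# "__fixed__first_name": 1, "__fixed__surname": 0, ...}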


def _generate_combinations(
all_columns, current_combination, already_visited: Set[frozenset]
) -> list:
"""Generate combinations of columns to visit that haven't been visited already
irrespective of order
"""

combinations = []
for col in all_columns:
if col not in current_combination:
next_combination = current_combination + [col]
if frozenset(next_combination) not in already_visited:
combinations.append(next_combination)

return combinations
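# e.g. all_columns=["c1", "c2", "c3"], current_combination=["c1"] and
# already_visited={frozenset({"c1", "c2"})} yields [["c1", "c3"]]: the
# ["c1", "c2"] branch is pruned because that pair has already been counted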


def _generate_blocking_rule(
linker: "Linker", cols_as_string: List[str]
) -> BlockingRule:
"""Generate a Splink blocking rule given a list of column names which
are provided as as string"""

dialect = linker._sql_dialect
[Review comment — Contributor: It might be worthwhile adding in a "# TODO: remove in Splink4" tag here.]
[Review comment — Contributor: There's a VSCode plugin that allows you to see all outstanding TODO comments, which would ensure we don't forget about them.]

module_mapping = {
"presto": "splink.athena.blocking_rule_library",
"duckdb": "splink.duckdb.blocking_rule_library",
"postgres": "splink.postgres.blocking_rule_library",
"spark": "splink.spark.blocking_rule_library",
"sqlite": "splink.sqlite.blocking_rule_library",
}

if dialect not in module_mapping:
raise ValueError(f"Unsupported SQL dialect: {dialect}")

module_name = module_mapping[dialect]
block_on_module = __import__(module_name, fromlist=["block_on"])
block_on = block_on_module.block_on

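    # No blocking columns: return the catch-all condition "1 = 1", under which
    # every pairwise comparison is generated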
if len(cols_as_string) == 0:
return "1 = 1"

br = block_on(cols_as_string)

return br
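# e.g. for a linker whose _sql_dialect is "duckdb",
# _generate_blocking_rule(linker, ["first_name", "surname"]) resolves block_on
# from splink.duckdb.blocking_rule_library, so the call is equivalent to
# block_on(["first_name", "surname"])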


def _search_tree_for_blocking_rules_below_threshold_count(
linker: "Linker",
all_columns: List[str],
threshold: float,
current_combination: List[str] = None,
already_visited: Set[frozenset] = None,
results: List[Dict[str, str]] = None,
) -> List[Dict[str, str]]:
"""
Recursively search combinations of fields to find ones that result in a count less
than the threshold.

Uses the new, fast counting function
linker._count_num_comparisons_from_blocking_rule_pre_filter_conditions
to count

The full tree looks like this, where c1 c2 are columns:
c1 count_comparisons(c1)
├── c2 count_comparisons(c1, c2)
│ └── c3 count_comparisons(c1, c2, c3)
├── c3 count_comparisons(c1, c3)
│ └── c2 count_comparisons(c1, c3, c2)
c2 count_comparisons(c2)
├── c1 count_comparisons(c2, c1)
│ └── c3 count_comparisons(c2, c1, c3)
├── c3 count_comparisons(c2, c3)
│ └── c1 count_comparisons(c2, c3, c1)

[Review thread on lines +114 to +115 — ThomasHepworth (Contributor), Oct 25, 2023: Is it worth making it clear in this diagram that these have already been visited? They're stored in the hashset, which should quickly trim them from next_combinations. (I may have misunderstood the order in which your loops work.) — Author (Member): Done]

But many nodes do not need to be visited:
- Once the count is below the threshold, no branches from the node are explored.
- If a combination has already been evaluated, it is not evaluated again. For
example, c2 -> c1 will not be evaluated because c1 -> c2 has already been
counted

When a count is below the threshold, create a dictionary with the relevant stats
like:
{
'blocking_columns_sanitised': ['first_name'],
'splink_blocking_rule': <Custom rule>,
'comparison_count': 4827,
'num_equi_joins': 1,
'__fixed__first_name': 1,
'__fixed__surname': 0,
'__fixed__dob': 0,
'__fixed__city': 0,
'__fixed__email': 0,
'__fixed__cluster': 0,
}

Return a list of these dicts.


Args:
linker: splink.Linker
all_columns (List[str]): List of fields to combine.
threshold (float): The count threshold.
current_combination (List[str], optional): Current combination of fields.
already_visited (Set[frozenset], optional): Set of visited combinations.
results (List[Dict[str, str]], optional): List of results. Defaults to None.

Returns:
List[Dict]: List of results. Each result is a dict with statistics like
the number of comparisons, the blocking rule etc.
"""
if current_combination is None:
current_combination = []
if already_visited is None:
already_visited = set()
if results is None:
results = []

if len(current_combination) == len(all_columns):
return results # All fields included, meaning we're at a leaf so exit recursion

br = _generate_blocking_rule(linker, current_combination)
comparison_count = (
linker._count_num_comparisons_from_blocking_rule_pre_filter_conditions(br)
)
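    # Uses the fast pre-filter-conditions count, so each node in the search
    # tree is cheap to evaluate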

already_visited.add(frozenset(current_combination))
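    # Stored as a frozenset so the combination is recognised irrespective of
    # order: once c1 -> c2 has been counted, c2 -> c1 is pruned from the search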

if comparison_count > threshold:
# Generate all valid combinations and continue the search
combinations = _generate_combinations(
all_columns, current_combination, already_visited
)
for next_combination in combinations:
_search_tree_for_blocking_rules_below_threshold_count(
linker,
all_columns,
threshold,
next_combination,
already_visited,
results,
)
else:
row = _generate_output_combinations_table_row(
current_combination,
br,
comparison_count,
all_columns,
)
results.append(row)

return results


def find_blocking_rules_below_threshold_comparison_count(
    linker: "Linker",
    max_comparisons_per_rule: int,
    column_expressions: List[str] = None,
) -> pd.DataFrame:
"""
Finds blocking rules which return a comparison count below a given threshold.

In addition to returning blocking rules, returns the comparison count and
'num_equi_joins', which refers to the number of equi-joins used by the rule;
e.g. equality on first_name and surname has num_equi_joins of 2.

Also returns a one-hot encoding that describes which columns are __fixed__
by the blocking rule.

Args:
linker (Linker): The Linker object
max_comparisons_per_rule (int): Max comparisons allowed per blocking rule.
column_expressions (List[str], optional): The algorithm will find combinations
of these column expressions to use as blocking rules. If None, uses all
columns used by the ComparisonLevels of the Linker. Column expressions can
be SQL expressions, not just column names, e.g. 'substr(surname, 1,1)' is a
valid entry in this list.

Returns:
pd.DataFrame: DataFrame with blocking rules, comparison_count and num_equi_joins
"""

if not column_expressions:
column_expressions = linker._input_columns(
include_unique_id_col_names=False,
include_additional_columns_to_retain=False,
)

column_expressions_as_strings = []

for c in column_expressions:
if isinstance(c, InputColumn):
column_expressions_as_strings.append(c.quote().name)
else:
column_expressions_as_strings.append(c)
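    # InputColumns are converted to their quoted names so that column names
    # containing spaces or reserved words remain valid SQL in a blocking rule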

results = _search_tree_for_blocking_rules_below_threshold_count(
linker, column_expressions_as_strings, max_comparisons_per_rule
)

if not results:
raise ValueError(
"No blocking rules could be found that produce a comparison count below "
"your chosen max_comparisons_per_rule threshold of "
f"{max_comparisons_per_rule}. Try increasing the threshold."
)

return pd.DataFrame(results)
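
For reference, a minimal usage sketch (names and the threshold are illustrative; assumes an already-configured linker object, e.g. a DuckDBLinker):

from splink.find_brs_with_comparison_counts_below_threshold import (
    find_blocking_rules_below_threshold_comparison_count,
)

# One row per blocking rule whose comparison count falls below the threshold,
# with comparison_count, num_equi_joins and the one-hot __fixed__* columns
df_rules = find_blocking_rules_below_threshold_comparison_count(
    linker, max_comparisons_per_rule=100_000
)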
34 changes: 30 additions & 4 deletions splink/linker.py
@@ -248,11 +248,26 @@ def __init__(

self.debug_mode = False

- @property
def _input_columns(
self,
include_unique_id_col_names=True,
include_additional_columns_to_retain=True,
) -> list[InputColumn]:
"""Retrieve the column names from the input dataset(s)"""
"""Retrieve the column names from the input dataset(s) as InputColumns

Args:
include_unique_id_col_names (bool, optional): Whether to include unique ID
column names. Defaults to True.
include_additional_columns_to_retain (bool, optional): Whether to include
additional columns to retain. Defaults to True.

Raises:
SplinkException: If the input frames have different sets of columns.

Returns:
list[InputColumn]
"""

input_dfs = self._input_tables_dict.values()

# get a list of the column names for each input frame
@@ -280,13 +295,24 @@ def _input_columns(
+ ", ".join(problem_names)
)

- return next(iter(input_dfs)).columns
+ columns = next(iter(input_dfs)).columns

remove_columns = []
if not include_unique_id_col_names:
remove_columns.extend(self._settings_obj._unique_id_input_columns)
if not include_additional_columns_to_retain:
remove_columns.extend(self._settings_obj._additional_columns_to_retain)

remove_id_cols = [c.unquote().name for c in remove_columns]
columns = [col for col in columns if col.unquote().name not in remove_id_cols]
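        # Compare unquoted names so that quoting style cannot cause a column
        # to be missed when filtering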

return columns

@property
def _source_dataset_column_already_exists(self):
if self._settings_obj_ is None:
return False
- input_cols = [c.unquote().name for c in self._input_columns]
+ input_cols = [c.unquote().name for c in self._input_columns()]
return self._settings_obj._source_dataset_column_name in input_cols

@property
2 changes: 1 addition & 1 deletion splink/missingness.py
@@ -40,7 +40,7 @@ def missingness_sqls(columns, input_tablename):


def missingness_data(linker, input_tablename):
- columns = linker._input_columns
+ columns = linker._input_columns()
if input_tablename is None:
splink_dataframe = linker._initialise_df_concat(materialise=True)
else:
2 changes: 1 addition & 1 deletion splink/profile_data.py
@@ -232,7 +232,7 @@ def profile_columns(linker, column_expressions=None, top_n=10, bottom_n=10):
"""

if not column_expressions:
- column_expressions = [col.name for col in linker._input_columns]
+ column_expressions = [col.name for col in linker._input_columns()]

df_concat = linker._initialise_df_concat()
