Merge pull request #2106 from moj-analytical-services/further_br_fixes
Further br fixes
RobinL committed Mar 28, 2024
2 parents 37854de + 5ae52c9 commit a35d4bb
Showing 3 changed files with 36 additions and 36 deletions.
splink/analyse_blocking.py (18 changes: 17 additions & 1 deletion)
@@ -8,7 +8,7 @@
 from .blocking import BlockingRule, _sql_gen_where_condition, block_using_rules_sqls
 from .misc import calculate_cartesian, calculate_reduction_ratio
 from .pipeline import CTEPipeline
-from .vertically_concatenate import compute_df_concat
+from .vertically_concatenate import compute_df_concat, enqueue_df_concat

 # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
 if TYPE_CHECKING:
@@ -258,3 +258,19 @@ def count_comparisons_from_blocking_rule_pre_filter_conditions_sqls(
     sqls.append({"sql": sql, "output_table_name": "__splink__total_of_block_counts"})

     return sqls
+
+
+def count_comparisons_from_blocking_rule_pre_filter_conditions(
+    linker: "Linker", blocking_rule: Union[str, "BlockingRule"]
+):
+    pipeline = CTEPipeline()
+    pipeline = enqueue_df_concat(linker, pipeline)
+
+    sqls = count_comparisons_from_blocking_rule_pre_filter_conditions_sqls(
+        linker, blocking_rule
+    )
+    pipeline.enqueue_list_of_sqls(sqls)
+
+    df_res = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline)
+    res = df_res.as_record_dict()[0]
+    return int(res["count_of_pairwise_comparisons_generated"])
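As an aside, the new module-level helper can be called directly rather than through the private Linker method. A minimal usage sketch, assuming linker is an already-configured Splink Linker and using an illustrative blocking rule string (the signature above accepts Union[str, "BlockingRule"]):

from splink.analyse_blocking import (
    count_comparisons_from_blocking_rule_pre_filter_conditions,
)

# linker is assumed to be a configured Linker instance (illustrative setup)
n_comparisons = count_comparisons_from_blocking_rule_pre_filter_conditions(
    linker, "l.first_name = r.first_name"
)
print(f"Pairwise comparisons before non equi-join filters: {n_comparisons}")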
splink/find_brs_with_comparison_counts_below_threshold.py (35 changes: 12 additions & 23 deletions)
@@ -6,7 +6,12 @@

 import pandas as pd

+from .analyse_blocking import (
+    count_comparisons_from_blocking_rule_pre_filter_conditions,
+)
 from .blocking import BlockingRule
+from .blocking_rule_creator import BlockingRuleCreator
+from .blocking_rule_library import CustomRule, block_on
 from .input_column import InputColumn

 if TYPE_CHECKING:
@@ -64,30 +69,13 @@ def _generate_blocking_rule(
     """Generate a Splink blocking rule given a list of column names which
     are provided as as string"""

-    # TODO: Refactor in Splink4
-    dialect = linker._sql_dialect
-
-    module_mapping = {
-        "presto": "splink.athena.blocking_rule_library",
-        "duckdb": "splink.duckdb.blocking_rule_library",
-        "postgres": "splink.postgres.blocking_rule_library",
-        "spark": "splink.spark.blocking_rule_library",
-        "sqlite": "splink.sqlite.blocking_rule_library",
-    }
-
-    if dialect not in module_mapping:
-        raise ValueError(f"Unsupported SQL dialect: {dialect}")
-
-    module_name = module_mapping[dialect]
-    block_on_module = __import__(module_name, fromlist=["block_on"])
-    block_on = block_on_module.block_on
-
     if len(cols_as_string) == 0:
-        return block_on("1")
+        br: BlockingRuleCreator = CustomRule("1=1", linker._sql_dialect)
     else:
-
-        br = block_on(cols_as_string)
+        br = block_on(*cols_as_string)

-    return br
+    return br.get_blocking_rule(linker._sql_dialect)


 def _search_tree_for_blocking_rules_below_threshold_count(
@@ -165,8 +153,9 @@ def _search_tree_for_blocking_rules_below_threshold_count(
         return results  # All fields included, meaning we're at a leaf so exit recursion

     br = _generate_blocking_rule(linker, current_combination)
-    comparison_count = (
-        linker._count_num_comparisons_from_blocking_rule_pre_filter_conditions(br)
+
+    comparison_count = count_comparisons_from_blocking_rule_pre_filter_conditions(
+        linker, br
     )

     already_visited.add(frozenset(current_combination))
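The net effect of the _generate_blocking_rule change is that the per-dialect block_on imports are replaced by the dialect-agnostic BlockingRuleCreator API. A rough sketch of the new flow, assuming a configured linker and illustrative column names:

from splink.blocking_rule_library import CustomRule, block_on

cols_as_string = ["first_name", "dob"]  # illustrative column names

if len(cols_as_string) == 0:
    # No columns supplied: use a rule whose condition is always true,
    # so every pairwise comparison is generated
    br = CustomRule("1=1", linker._sql_dialect)
else:
    # block_on takes the column names as positional arguments
    br = block_on(*cols_as_string)

# Convert the dialect-agnostic creator into a concrete blocking rule
blocking_rule = br.get_blocking_rule(linker._sql_dialect)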
splink/linker.py (19 changes: 7 additions & 12 deletions)
@@ -32,7 +32,7 @@
     truth_space_table_from_labels_table,
 )
 from .analyse_blocking import (
-    count_comparisons_from_blocking_rule_pre_filter_conditions_sqls,
+    count_comparisons_from_blocking_rule_pre_filter_conditions,
     cumulative_comparisons_generated_by_blocking_rules,
     number_of_comparisons_generated_by_blocking_rule_post_filters_sql,
 )
@@ -2766,7 +2766,7 @@ def count_num_comparisons_from_blocking_rule(

     def _count_num_comparisons_from_blocking_rule_pre_filter_conditions(
         self,
-        blocking_rule: BlockingRule,
+        blocking_rule: BlockingRuleCreator | str | dict,
     ) -> int:
         """Compute the number of pairwise record comparisons that would be generated by
         a blocking rule, prior to any filters (non equi-join conditions) being applied
@@ -2782,17 +2782,12 @@ def _count_num_comparisons_from_blocking_rule_pre_filter_conditions(
            int: The number of comparisons generated by the blocking rule
         """

-        pipeline = CTEPipeline()
-        pipeline = enqueue_df_concat(self, pipeline)
-
-        sqls = count_comparisons_from_blocking_rule_pre_filter_conditions_sqls(
-            self, blocking_rule
+        blocking_rule_obj = to_blocking_rule_creator(blocking_rule).get_blocking_rule(
+            self._sql_dialect
         )
-        pipeline.enqueue_list_of_sqls(sqls)
-
-        df_res = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
-        res = df_res.as_record_dict()[0]
-        return int(res["count_of_pairwise_comparisons_generated"])
+        return count_comparisons_from_blocking_rule_pre_filter_conditions(
+            self, blocking_rule_obj
+        )

     def cumulative_comparisons_from_blocking_rules_records(
         self,
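With this change the private Linker method becomes a thin wrapper: it normalises its argument via to_blocking_rule_creator and delegates to the module-level function shown in the first file. A usage sketch, assuming linker is a configured Linker and using an illustrative rule string (one of the accepted input types):

# Illustrative call site; linker is assumed to be a configured Linker.
# The method now accepts a BlockingRuleCreator, a SQL string, or a dict.
n = linker._count_num_comparisons_from_blocking_rule_pre_filter_conditions(
    "l.surname = r.surname and l.dob = r.dob"
)
print(f"Comparisons generated before filter conditions: {n}")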
