Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further br fixes #2106

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion splink/analyse_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .blocking import BlockingRule, _sql_gen_where_condition, block_using_rules_sqls
from .misc import calculate_cartesian, calculate_reduction_ratio
from .pipeline import CTEPipeline
from .vertically_concatenate import compute_df_concat
from .vertically_concatenate import compute_df_concat, enqueue_df_concat

# https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
if TYPE_CHECKING:
Expand Down Expand Up @@ -258,3 +258,19 @@ def count_comparisons_from_blocking_rule_pre_filter_conditions_sqls(
sqls.append({"sql": sql, "output_table_name": "__splink__total_of_block_counts"})

return sqls


def count_comparisons_from_blocking_rule_pre_filter_conditions(
Copy link
Member Author

@RobinL RobinL Mar 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved out of linker to reduce line count and so it can be more easily called from elsewhere

linker: "Linker", blocking_rule: Union[str, "BlockingRule"]
):
pipeline = CTEPipeline()
pipeline = enqueue_df_concat(linker, pipeline)

sqls = count_comparisons_from_blocking_rule_pre_filter_conditions_sqls(
linker, blocking_rule
)
pipeline.enqueue_list_of_sqls(sqls)

df_res = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline)
res = df_res.as_record_dict()[0]
return int(res["count_of_pairwise_comparisons_generated"])
35 changes: 12 additions & 23 deletions splink/find_brs_with_comparison_counts_below_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@

import pandas as pd

from .analyse_blocking import (
count_comparisons_from_blocking_rule_pre_filter_conditions,
)
from .blocking import BlockingRule
from .blocking_rule_creator import BlockingRuleCreator
from .blocking_rule_library import CustomRule, block_on
from .input_column import InputColumn

if TYPE_CHECKING:
Expand Down Expand Up @@ -64,30 +69,13 @@ def _generate_blocking_rule(
"""Generate a Splink blocking rule given a list of column names which
are provided as as string"""

# TODO: Refactor in Splink4
dialect = linker._sql_dialect

module_mapping = {
"presto": "splink.athena.blocking_rule_library",
"duckdb": "splink.duckdb.blocking_rule_library",
"postgres": "splink.postgres.blocking_rule_library",
"spark": "splink.spark.blocking_rule_library",
"sqlite": "splink.sqlite.blocking_rule_library",
}

if dialect not in module_mapping:
raise ValueError(f"Unsupported SQL dialect: {dialect}")

module_name = module_mapping[dialect]
block_on_module = __import__(module_name, fromlist=["block_on"])
block_on = block_on_module.block_on

if len(cols_as_string) == 0:
return block_on("1")
br: BlockingRuleCreator = CustomRule("1=1", linker._sql_dialect)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type hint was needed for mypy, without it you got an splink/find_brs_with_comparison_counts_below_threshold.py:75: error: Incompatible types in assignment (expression has type "BlockingRuleCreator", variable has type "CustomRule") [assignment] error

else:

br = block_on(cols_as_string)
br = block_on(*cols_as_string)

return br
return br.get_blocking_rule(linker._sql_dialect)


def _search_tree_for_blocking_rules_below_threshold_count(
Expand Down Expand Up @@ -165,8 +153,9 @@ def _search_tree_for_blocking_rules_below_threshold_count(
return results # All fields included, meaning we're at a leaf so exit recursion

br = _generate_blocking_rule(linker, current_combination)
comparison_count = (
linker._count_num_comparisons_from_blocking_rule_pre_filter_conditions(br)

comparison_count = count_comparisons_from_blocking_rule_pre_filter_conditions(
linker, br
)

already_visited.add(frozenset(current_combination))
Expand Down
19 changes: 7 additions & 12 deletions splink/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
truth_space_table_from_labels_table,
)
from .analyse_blocking import (
count_comparisons_from_blocking_rule_pre_filter_conditions_sqls,
count_comparisons_from_blocking_rule_pre_filter_conditions,
cumulative_comparisons_generated_by_blocking_rules,
number_of_comparisons_generated_by_blocking_rule_post_filters_sql,
)
Expand Down Expand Up @@ -2787,7 +2787,7 @@ def count_num_comparisons_from_blocking_rule(

def _count_num_comparisons_from_blocking_rule_pre_filter_conditions(
self,
blocking_rule: BlockingRule,
blocking_rule: BlockingRuleCreator | str | dict,
) -> int:
"""Compute the number of pairwise record comparisons that would be generated by
a blocking rule, prior to any filters (non equi-join conditions) being applied
Expand All @@ -2803,17 +2803,12 @@ def _count_num_comparisons_from_blocking_rule_pre_filter_conditions(
int: The number of comparisons generated by the blocking rule
"""

pipeline = CTEPipeline()
pipeline = enqueue_df_concat(self, pipeline)

sqls = count_comparisons_from_blocking_rule_pre_filter_conditions_sqls(
self, blocking_rule
blocking_rule_obj = to_blocking_rule_creator(blocking_rule).get_blocking_rule(
self._sql_dialect
)
return count_comparisons_from_blocking_rule_pre_filter_conditions(
self, blocking_rule_obj
)
pipeline.enqueue_list_of_sqls(sqls)

df_res = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
res = df_res.as_record_dict()[0]
return int(res["count_of_pairwise_comparisons_generated"])

def cumulative_comparisons_from_blocking_rules_records(
self,
Expand Down
Loading