From 8a9429de0a4dd5a860e8276f9a588125f034b875 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Thu, 21 Mar 2024 14:02:11 +0000
Subject: [PATCH 1/2] use new compute and enqueue functions

---
 splink/accuracy.py               |  8 ++++--
 splink/blocking.py               |  9 +++---
 splink/em_training_session.py    |  4 ++-
 splink/estimate_u.py             |  4 ++-
 splink/linker.py                 | 49 ++++++++++----------------------
 splink/vertically_concatenate.py | 41 ++++++++++++++++++++++++++
 6 files changed, 73 insertions(+), 42 deletions(-)

diff --git a/splink/accuracy.py b/splink/accuracy.py
index 3978c23c99..5e5ca1c67d 100644
--- a/splink/accuracy.py
+++ b/splink/accuracy.py
@@ -7,6 +7,7 @@
 from .pipeline import CTEPipeline
 from .predict import predict_from_comparison_vectors_sqls_using_settings
 from .sql_transform import move_l_r_table_prefix_to_column_suffix
+from .vertically_concatenate import compute_df_concat_with_tf
 
 if TYPE_CHECKING:
     from .linker import Linker
@@ -168,7 +169,9 @@ def truth_space_table_from_labels_table(
     linker, labels_tablename, threshold_actual=0.5, match_weight_round_to_nearest=None
 ):
     pipeline = CTEPipeline(reusable=False)
-    pipeline = linker._enqueue_df_concat_with_tf(pipeline)
+
+    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
+    pipeline = CTEPipeline([nodes_with_tf], reusable=False)
 
     sqls = predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename)
     pipeline.enqueue_list_of_sqls(sqls)
@@ -269,7 +272,8 @@ def prediction_errors_from_labels_table(
     threshold=0.5,
 ):
     pipeline = CTEPipeline(reusable=False)
-    pipeline = linker._enqueue_df_concat_with_tf(pipeline)
+    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
+    pipeline = CTEPipeline([nodes_with_tf], reusable=False)
 
     sqls = predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename)
 
diff --git a/splink/blocking.py b/splink/blocking.py
index f147af570b..636d2a8127 100644
--- a/splink/blocking.py
+++ b/splink/blocking.py
@@ -13,6 +13,7 @@
 from .pipeline import CTEPipeline
 from .splink_dataframe import SplinkDataFrame
 from .unique_id_concat import _composite_unique_id_from_nodes_sql
+from .vertically_concatenate import compute_df_concat_with_tf
 
 
 logger = logging.getLogger(__name__)
@@ -397,13 +398,12 @@ def materialise_exploded_id_tables(linker: Linker):
     exploding_blocking_rules = [
         br for br in blocking_rules if isinstance(br, ExplodingBlockingRule)
     ]
+    if len(exploding_blocking_rules) == 0:
+        return []
 
     exploded_tables = []
 
     pipeline = CTEPipeline(reusable=False)
-    linker._enqueue_df_concat_with_tf(pipeline)
-    nodes_with_tf = linker._intermediate_table_cache.get_with_logging(
-        "__splink__df_concat_with_tf"
-    )
+    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
 
     input_colnames = {col.name for col in nodes_with_tf.columns}
@@ -434,6 +434,7 @@ def materialise_exploded_id_tables(linker: Linker):
         marginal_ids_table = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline)
         br.exploded_id_pair_table = marginal_ids_table
         exploded_tables.append(marginal_ids_table)
+    return exploding_blocking_rules
 
 
diff --git a/splink/em_training_session.py b/splink/em_training_session.py
index 28e055a2bf..aa970bcfa3 100644
--- a/splink/em_training_session.py
+++ b/splink/em_training_session.py
@@ -26,6 +26,7 @@
     Settings,
     TrainingSettings,
 )
+from .vertically_concatenate import compute_df_concat_with_tf
 
 
 logger = logging.getLogger(__name__)
@@ -178,7 +179,8 @@ def _comparison_vectors(self):
         self._training_log_message()
 
         pipeline = CTEPipeline()
-        pipeline = self._original_linker._enqueue_df_concat_with_tf(pipeline)
+        nodes_with_tf = compute_df_concat_with_tf(self._original_linker, pipeline)
+        pipeline = CTEPipeline([nodes_with_tf], reusable=False)
 
         sqls = block_using_rules_sqls(
             self._original_linker, [self._blocking_rule_for_training]
diff --git a/splink/estimate_u.py b/splink/estimate_u.py
index 499294733e..ce628c41f0 100644
--- a/splink/estimate_u.py
+++ b/splink/estimate_u.py
@@ -16,6 +16,7 @@
     m_u_records_to_lookup_dict,
 )
 from .pipeline import CTEPipeline
+from .vertically_concatenate import compute_df_concat_with_tf
 
 # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
 if TYPE_CHECKING:
@@ -57,7 +58,8 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None):
     logger.info("----- Estimating u probabilities using random sampling -----")
 
     pipeline = CTEPipeline(reusable=False)
-    pipeline = linker._enqueue_df_concat_with_tf(pipeline)
+    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
+    pipeline = CTEPipeline([nodes_with_tf], reusable=False)
 
     original_settings_obj = linker._settings_obj
 
diff --git a/splink/linker.py b/splink/linker.py
index fdea60ac59..c130b6a0c7 100644
--- a/splink/linker.py
+++ b/splink/linker.py
@@ -9,7 +9,7 @@
 from statistics import median
 from typing import Dict, Optional, Union
 
-
+from .vertically_concatenate import enqueue_df_concat_with_tf, compute_df_concat_with_tf
 from splink.input_column import InputColumn
 from splink.settings_validation.log_invalid_columns import (
     InvalidColumnsLogger,
@@ -536,35 +536,6 @@ def _initialise_df_concat_with_tf(self, materialise=True):
 
         return nodes_with_tf
 
-    def _enqueue_df_concat_with_tf(self, pipeline: CTEPipeline, materialise=True):
-
-        cache = self._intermediate_table_cache
-
-        if "__splink__df_concat_with_tf" in cache:
-            nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf")
-            pipeline.append_input_dataframe(nodes_with_tf)
-            return pipeline
-
-        # In duckdb, calls to random() in a CTE pipeline cause problems:
-        # https://gist.github.com/RobinL/d329e7004998503ce91b68479aa41139
-        if self._settings_obj.salting_required:
-            materialise = True
-
-        sql = vertically_concatenate_sql(self)
-        pipeline.enqueue_sql(sql, "__splink__df_concat")
-
-        sqls = compute_all_term_frequencies_sqls(self)
-        pipeline.enqueue_list_of_sqls(sqls)
-
-        if materialise:
-            # Can't use break lineage here because we need nodes_with_tf
-            # so it can be explicitly set to the named cache
-            nodes_with_tf = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
-            cache["__splink__df_concat_with_tf"] = nodes_with_tf
-            pipeline = CTEPipeline(input_dataframes=[nodes_with_tf])
-
-        return pipeline
-
     def _table_to_splink_dataframe(
         self, templated_name, physical_name
     ) -> SplinkDataFrame:
@@ -978,7 +949,10 @@ def deterministic_link(self) -> SplinkDataFrame:
         # to set the cluster threshold to 1
         self._deterministic_link_mode = True
 
-        pipeline = self._enqueue_df_concat_with_tf(pipeline)
+        df_concat_with_tf = compute_df_concat_with_tf(
+            self, pipeline, self._intermediate_table_cache
+        )
+        pipeline = CTEPipeline([df_concat_with_tf], reusable=False)
 
         exploding_br_with_id_tables = materialise_exploded_id_tables(self)
 
@@ -1321,9 +1295,16 @@ def predict(
         # calls predict, it runs as a single pipeline with no materialisation
         # of anything.
 
-        pipeline = self._enqueue_df_concat_with_tf(
-            pipeline, materialise=materialise_after_computing_term_frequencies
-        )
+        # In duckdb, calls to random() in a CTE pipeline cause problems:
+        # https://gist.github.com/RobinL/d329e7004998503ce91b68479aa41139
+        if (
+            materialise_after_computing_term_frequencies
+            or self._sql_dialect == "duckdb"
+        ):
+            df_concat_with_tf = compute_df_concat_with_tf(self, pipeline, self.db_api)
+            pipeline = CTEPipeline([df_concat_with_tf], reusable=False)
+        else:
+            pipeline = enqueue_df_concat_with_tf(self, pipeline)
 
         # If exploded blocking rules exist, we need to materialise
         # the tables of ID pairs
diff --git a/splink/vertically_concatenate.py b/splink/vertically_concatenate.py
index 68add6e51a..86fe8c7d9e 100644
--- a/splink/vertically_concatenate.py
+++ b/splink/vertically_concatenate.py
@@ -3,6 +3,11 @@
 import logging
 from typing import TYPE_CHECKING
 
+from .database_api import DatabaseAPI
+from .pipeline import CTEPipeline
+from .splink_dataframe import SplinkDataFrame
+from .term_frequencies import compute_all_term_frequencies_sqls
+
 logger = logging.getLogger(__name__)
 
 # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
@@ -74,3 +79,39 @@ def vertically_concatenate_sql(linker: Linker) -> str:
     """
 
     return sql
+
+
+def enqueue_df_concat_with_tf(linker: Linker, pipeline: CTEPipeline) -> CTEPipeline:
+
+    cache = linker._intermediate_table_cache
+    if "__splink__df_concat_with_tf" in cache:
+        nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf")
+        pipeline.append_input_dataframe(nodes_with_tf)
+        return pipeline
+
+    sql = vertically_concatenate_sql(linker)
+    pipeline.enqueue_sql(sql, "__splink__df_concat")
+
+    sqls = compute_all_term_frequencies_sqls(linker)
+    pipeline.enqueue_list_of_sqls(sqls)
+
+    return pipeline
+
+
+def compute_df_concat_with_tf(
+    linker: Linker, pipeline, db_api: DatabaseAPI
+) -> SplinkDataFrame:
+    cache = linker._intermediate_table_cache
+
+    if "__splink__df_concat_with_tf" in cache:
+        return cache.get_with_logging("__splink__df_concat_with_tf")
+
+    sql = vertically_concatenate_sql(linker)
+    pipeline.enqueue_sql(sql, "__splink__df_concat")
+
+    sqls = compute_all_term_frequencies_sqls(linker)
+    pipeline.enqueue_list_of_sqls(sqls)
+
+    nodes_with_tf = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+    cache["__splink__df_concat_with_tf"] = nodes_with_tf
+    return nodes_with_tf

From e70d66d5edf13f4e9f3cb2eb251c33fbccf84b03 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Thu, 21 Mar 2024 14:23:56 +0000
Subject: [PATCH 2/2] simplify api

---
 splink/labelling_tool.py         | 10 +++++---
 splink/linker.py                 |  6 ++----
 splink/vertically_concatenate.py |  6 ++----
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/splink/labelling_tool.py b/splink/labelling_tool.py
index 0089a1e3d5..49ea7d7260 100644
--- a/splink/labelling_tool.py
+++ b/splink/labelling_tool.py
@@ -8,7 +8,9 @@
 from jinja2 import Template
 
 from .misc import EverythingEncoder, read_resource
+from .pipeline import CTEPipeline
 from .splink_dataframe import SplinkDataFrame
+from .vertically_concatenate import compute_df_concat_with_tf
 
 # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
 if TYPE_CHECKING:
@@ -21,8 +23,10 @@ def generate_labelling_tool_comparisons(
     linker: "Linker", unique_id, source_dataset, match_weight_threshold=-4
 ):
     # ensure the tf table exists
-    concat_with_tf = linker._initialise_df_concat_with_tf()
+    pipeline = CTEPipeline(reusable=False)
+    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
+    pipeline = CTEPipeline([nodes_with_tf], reusable=False)
 
     settings = linker._settings_obj
     source_dataset_condition = ""
 
@@ -40,8 +44,8 @@
         {source_dataset_condition}
     """
 
-    linker._enqueue_sql(sql, "__splink__df_labelling_tool_record")
-    splink_df = linker._execute_sql_pipeline([concat_with_tf])
+    pipeline.enqueue_sql(sql, "__splink__df_labelling_tool_record")
+    splink_df = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline)
 
     matches = linker.find_matches_to_new_records(
         splink_df.physical_name, match_weight_threshold=match_weight_threshold
diff --git a/splink/linker.py b/splink/linker.py
index c130b6a0c7..071dbc1a77 100644
--- a/splink/linker.py
+++ b/splink/linker.py
@@ -949,9 +949,7 @@ def deterministic_link(self) -> SplinkDataFrame:
         # to set the cluster threshold to 1
         self._deterministic_link_mode = True
 
-        df_concat_with_tf = compute_df_concat_with_tf(
-            self, pipeline, self._intermediate_table_cache
-        )
+        df_concat_with_tf = compute_df_concat_with_tf(self, pipeline)
         pipeline = CTEPipeline([df_concat_with_tf], reusable=False)
 
         exploding_br_with_id_tables = materialise_exploded_id_tables(self)
@@ -1301,7 +1299,7 @@ def predict(
             materialise_after_computing_term_frequencies
             or self._sql_dialect == "duckdb"
         ):
-            df_concat_with_tf = compute_df_concat_with_tf(self, pipeline, self.db_api)
+            df_concat_with_tf = compute_df_concat_with_tf(self, pipeline)
             pipeline = CTEPipeline([df_concat_with_tf], reusable=False)
         else:
             pipeline = enqueue_df_concat_with_tf(self, pipeline)
diff --git a/splink/vertically_concatenate.py b/splink/vertically_concatenate.py
index 86fe8c7d9e..6592ade62b 100644
--- a/splink/vertically_concatenate.py
+++ b/splink/vertically_concatenate.py
@@ -3,7 +3,6 @@
 import logging
 from typing import TYPE_CHECKING
 
-from .database_api import DatabaseAPI
 from .pipeline import CTEPipeline
 from .splink_dataframe import SplinkDataFrame
 from .term_frequencies import compute_all_term_frequencies_sqls
@@ -98,10 +97,9 @@ def enqueue_df_concat_with_tf(linker: Linker, pipeline: CTEPipeline) -> CTEPipel
     return pipeline
 
 
-def compute_df_concat_with_tf(
-    linker: Linker, pipeline, db_api: DatabaseAPI
-) -> SplinkDataFrame:
+def compute_df_concat_with_tf(linker: Linker, pipeline) -> SplinkDataFrame:
     cache = linker._intermediate_table_cache
+    db_api = linker.db_api
 
     if "__splink__df_concat_with_tf" in cache:
         return cache.get_with_logging("__splink__df_concat_with_tf")
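
For orientation, a minimal sketch of how the two helpers introduced by these
patches are called once both commits are applied. The function names, the
CTEPipeline constructor arguments, and the cached table name are taken from
the diffs above; the surrounding setup (an existing `linker` object) and the
absolute import paths are assumptions inferred from the module layout, not
shown in the PR.

    # Usage sketch (not part of the patches). Assumes `linker` is an
    # already-constructed Linker instance.
    from splink.pipeline import CTEPipeline
    from splink.vertically_concatenate import (
        compute_df_concat_with_tf,
        enqueue_df_concat_with_tf,
    )

    # Eager path: compute the materialised __splink__df_concat_with_tf
    # table (or fetch it from linker._intermediate_table_cache), then
    # seed a fresh pipeline with that dataframe.
    pipeline = CTEPipeline(reusable=False)
    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
    pipeline = CTEPipeline([nodes_with_tf], reusable=False)

    # Lazy path: enqueue the vertical-concat and term-frequency SQL onto
    # an existing pipeline so everything runs as a single query, with no
    # materialisation. predict() takes this branch unless term-frequency
    # materialisation is requested or the dialect is duckdb, where
    # random() inside a CTE pipeline causes problems (see gist linked in
    # the diff).
    pipeline = enqueue_df_concat_with_tf(linker, pipeline)

The second commit also moves the database handle lookup inside
compute_df_concat_with_tf (db_api = linker.db_api), so call sites pass only
the linker and a pipeline rather than threading a db_api or cache argument
through every caller.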