Enqueue and compute methods #2086

Merged: 2 commits, Mar 21, 2024
8 changes: 6 additions & 2 deletions splink/accuracy.py
@@ -7,6 +7,7 @@
from .pipeline import CTEPipeline
from .predict import predict_from_comparison_vectors_sqls_using_settings
from .sql_transform import move_l_r_table_prefix_to_column_suffix
from .vertically_concatenate import compute_df_concat_with_tf

if TYPE_CHECKING:
from .linker import Linker
@@ -168,7 +169,9 @@ def truth_space_table_from_labels_table(
linker, labels_tablename, threshold_actual=0.5, match_weight_round_to_nearest=None
):
pipeline = CTEPipeline(reusable=False)
pipeline = linker._enqueue_df_concat_with_tf(pipeline)

nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
pipeline = CTEPipeline([nodes_with_tf], reusable=False)

sqls = predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename)
pipeline.enqueue_list_of_sqls(sqls)
@@ -269,7 +272,8 @@ def prediction_errors_from_labels_table(
threshold=0.5,
):
pipeline = CTEPipeline(reusable=False)
pipeline = linker._enqueue_df_concat_with_tf(pipeline)
nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
pipeline = CTEPipeline([nodes_with_tf], reusable=False)

sqls = predictions_from_sample_of_pairwise_labels_sql(linker, labels_tablename)

9 changes: 5 additions & 4 deletions splink/blocking.py
@@ -13,6 +13,7 @@
from .pipeline import CTEPipeline
from .splink_dataframe import SplinkDataFrame
from .unique_id_concat import _composite_unique_id_from_nodes_sql
from .vertically_concatenate import compute_df_concat_with_tf

logger = logging.getLogger(__name__)

@@ -397,13 +398,12 @@ def materialise_exploded_id_tables(linker: Linker):
exploding_blocking_rules = [
br for br in blocking_rules if isinstance(br, ExplodingBlockingRule)
]
if len(exploding_blocking_rules) == 0:
return []
exploded_tables = []

pipeline = CTEPipeline(reusable=False)
linker._enqueue_df_concat_with_tf(pipeline)
nodes_with_tf = linker._intermediate_table_cache.get_with_logging(
"__splink__df_concat_with_tf"
)
nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)

input_colnames = {col.name for col in nodes_with_tf.columns}

@@ -434,6 +434,7 @@ def materialise_exploded_id_tables(linker: Linker):
marginal_ids_table = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline)
br.exploded_id_pair_table = marginal_ids_table
exploded_tables.append(marginal_ids_table)

return exploding_blocking_rules


4 changes: 3 additions & 1 deletion splink/em_training_session.py
@@ -26,6 +26,7 @@
Settings,
TrainingSettings,
)
from .vertically_concatenate import compute_df_concat_with_tf

logger = logging.getLogger(__name__)

@@ -178,7 +179,8 @@ def _comparison_vectors(self):
self._training_log_message()

pipeline = CTEPipeline()
pipeline = self._original_linker._enqueue_df_concat_with_tf(pipeline)
nodes_with_tf = compute_df_concat_with_tf(self._original_linker, pipeline)
pipeline = CTEPipeline([nodes_with_tf], reusable=False)

sqls = block_using_rules_sqls(
self._original_linker, [self._blocking_rule_for_training]
4 changes: 3 additions & 1 deletion splink/estimate_u.py
@@ -16,6 +16,7 @@
m_u_records_to_lookup_dict,
)
from .pipeline import CTEPipeline
from .vertically_concatenate import compute_df_concat_with_tf

# https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
if TYPE_CHECKING:
@@ -57,7 +58,8 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None):
logger.info("----- Estimating u probabilities using random sampling -----")
pipeline = CTEPipeline(reusable=False)

pipeline = linker._enqueue_df_concat_with_tf(pipeline)
nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
pipeline = CTEPipeline([nodes_with_tf], reusable=False)

original_settings_obj = linker._settings_obj

10 changes: 7 additions & 3 deletions splink/labelling_tool.py
@@ -8,7 +8,9 @@
from jinja2 import Template

from .misc import EverythingEncoder, read_resource
from .pipeline import CTEPipeline
from .splink_dataframe import SplinkDataFrame
from .vertically_concatenate import compute_df_concat_with_tf

# https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
if TYPE_CHECKING:
@@ -21,8 +23,10 @@ def generate_labelling_tool_comparisons(
linker: "Linker", unique_id, source_dataset, match_weight_threshold=-4
):
# ensure the tf table exists
concat_with_tf = linker._initialise_df_concat_with_tf()
pipeline = CTEPipeline(reusable=False)
nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)

pipeline = CTEPipeline([nodes_with_tf], reusable=False)
settings = linker._settings_obj

source_dataset_condition = ""
@@ -40,8 +44,8 @@
{source_dataset_condition}
"""

linker._enqueue_sql(sql, "__splink__df_labelling_tool_record")
splink_df = linker._execute_sql_pipeline([concat_with_tf])
pipeline.enqueue_sql(sql, "__splink__df_labelling_tool_record")
splink_df = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline)

matches = linker.find_matches_to_new_records(
splink_df.physical_name, match_weight_threshold=match_weight_threshold
47 changes: 13 additions & 34 deletions splink/linker.py
@@ -9,7 +9,7 @@
from statistics import median
from typing import Dict, Optional, Union


from .vertically_concatenate import enqueue_df_concat_with_tf, compute_df_concat_with_tf
from splink.input_column import InputColumn
from splink.settings_validation.log_invalid_columns import (
InvalidColumnsLogger,
@@ -536,35 +536,6 @@ def _initialise_df_concat_with_tf(self, materialise=True):

return nodes_with_tf

def _enqueue_df_concat_with_tf(self, pipeline: CTEPipeline, materialise=True):

cache = self._intermediate_table_cache

if "__splink__df_concat_with_tf" in cache:
nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf")
pipeline.append_input_dataframe(nodes_with_tf)
return pipeline

# In duckdb, calls to random() in a CTE pipeline cause problems:
# https://gist.github.com/RobinL/d329e7004998503ce91b68479aa41139
if self._settings_obj.salting_required:
materialise = True

sql = vertically_concatenate_sql(self)
pipeline.enqueue_sql(sql, "__splink__df_concat")

sqls = compute_all_term_frequencies_sqls(self)
pipeline.enqueue_list_of_sqls(sqls)

if materialise:
# Can't use break lineage here because we need nodes_with_tf
# so it can be explicitly set to the named cache
nodes_with_tf = self.db_api.sql_pipeline_to_splink_dataframe(pipeline)
cache["__splink__df_concat_with_tf"] = nodes_with_tf
pipeline = CTEPipeline(input_dataframes=[nodes_with_tf])

return pipeline

def _table_to_splink_dataframe(
self, templated_name, physical_name
) -> SplinkDataFrame:
@@ -978,7 +949,8 @@ def deterministic_link(self) -> SplinkDataFrame:
# to set the cluster threshold to 1
self._deterministic_link_mode = True

pipeline = self._enqueue_df_concat_with_tf(pipeline)
df_concat_with_tf = compute_df_concat_with_tf(self, pipeline)
pipeline = CTEPipeline([df_concat_with_tf], reusable=False)

exploding_br_with_id_tables = materialise_exploded_id_tables(self)

@@ -1321,9 +1293,16 @@ def predict(
# calls predict, it runs as a single pipeline with no materialisation
# of anything.

pipeline = self._enqueue_df_concat_with_tf(
pipeline, materialise=materialise_after_computing_term_frequencies
)
# In duckdb, calls to random() in a CTE pipeline cause problems:
# https://gist.github.com/RobinL/d329e7004998503ce91b68479aa41139
if (
materialise_after_computing_term_frequencies
or self._sql_dialect == "duckdb"
):
df_concat_with_tf = compute_df_concat_with_tf(self, pipeline)
pipeline = CTEPipeline([df_concat_with_tf], reusable=False)
else:
pipeline = enqueue_df_concat_with_tf(self, pipeline)

# If exploded blocking rules exist, we need to materialise
# the tables of ID pairs
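An illustrative sketch, not taken from the PR itself, of the lazy path predict() keeps when term frequencies do not need materialising and the backend is not DuckDB: the concat-with-tf SQL is only enqueued, so the whole job runs as a single pipeline. Assumptions: a configured linker object already exists; the helper names are those added in splink/vertically_concatenate.py below; the top-level import paths are guessed (in-package code uses relative imports).

    from splink.pipeline import CTEPipeline  # import path assumed
    from splink.vertically_concatenate import enqueue_df_concat_with_tf  # import path assumed

    pipeline = CTEPipeline(reusable=False)
    # Queue the vertical concatenation and term frequency SQL; nothing is executed yet
    pipeline = enqueue_df_concat_with_tf(linker, pipeline)
    # Downstream SQL can reference __splink__df_concat_with_tf inside the same CTE chain
    pipeline.enqueue_sql(
        "select count(*) as n from __splink__df_concat_with_tf",
        "__splink__row_count",  # hypothetical output table name, for illustration only
    )
    # A single round trip to the database executes the whole pipeline
    df = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline)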
39 changes: 39 additions & 0 deletions splink/vertically_concatenate.py
@@ -3,6 +3,10 @@
import logging
from typing import TYPE_CHECKING

from .pipeline import CTEPipeline
from .splink_dataframe import SplinkDataFrame
from .term_frequencies import compute_all_term_frequencies_sqls

logger = logging.getLogger(__name__)

# https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
@@ -74,3 +78,38 @@ def vertically_concatenate_sql(linker: Linker) -> str:
"""

return sql


def enqueue_df_concat_with_tf(linker: Linker, pipeline: CTEPipeline) -> CTEPipeline:

cache = linker._intermediate_table_cache
if "__splink__df_concat_with_tf" in cache:
nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf")
pipeline.append_input_dataframe(nodes_with_tf)
return pipeline

sql = vertically_concatenate_sql(linker)
pipeline.enqueue_sql(sql, "__splink__df_concat")

sqls = compute_all_term_frequencies_sqls(linker)
pipeline.enqueue_list_of_sqls(sqls)

return pipeline


def compute_df_concat_with_tf(linker: Linker, pipeline) -> SplinkDataFrame:
cache = linker._intermediate_table_cache
db_api = linker.db_api

if "__splink__df_concat_with_tf" in cache:
return cache.get_with_logging("__splink__df_concat_with_tf")

sql = vertically_concatenate_sql(linker)
pipeline.enqueue_sql(sql, "__splink__df_concat")

sqls = compute_all_term_frequencies_sqls(linker)
pipeline.enqueue_list_of_sqls(sqls)

nodes_with_tf = db_api.sql_pipeline_to_splink_dataframe(pipeline)
cache["__splink__df_concat_with_tf"] = nodes_with_tf
return nodes_with_tf
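
The eager counterpart, compute_df_concat_with_tf, is what the other call sites in this PR switch to. A minimal sketch of that pattern, again assuming a configured linker and the same guessed import paths as the sketch above:

    from splink.pipeline import CTEPipeline  # import path assumed
    from splink.vertically_concatenate import compute_df_concat_with_tf  # import path assumed

    pipeline = CTEPipeline(reusable=False)
    # Materialise __splink__df_concat_with_tf now (or fetch it from the intermediate table cache)
    nodes_with_tf = compute_df_concat_with_tf(linker, pipeline)
    # Start a fresh pipeline seeded with the materialised table, so later SQL does not
    # re-run the concatenation and term frequency steps
    pipeline = CTEPipeline([nodes_with_tf], reusable=False)
    pipeline.enqueue_sql(
        "select * from __splink__df_concat_with_tf",
        "__splink__example_output",  # hypothetical output table name, for illustration only
    )
    result = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline)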