Skip to content

Commit

Permalink
Merge 73875d5 into b15d1f2
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Oct 27, 2021
2 parents b15d1f2 + 73875d5 commit f140016
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 30 deletions.
17 changes: 3 additions & 14 deletions splink/__init__.py
Expand Up @@ -19,10 +19,7 @@
from splink.vertically_concat import (
vertically_concatenate_datasets,
)
from splink.break_lineage import (
default_break_lineage_blocked_comparisons,
default_break_lineage_scored_comparisons,
)
from splink.break_lineage import default_break_lineage_blocked_comparisons

from splink.default_settings import normalise_probabilities

Expand All @@ -36,7 +33,6 @@ def __init__(
spark: SparkSession,
save_state_fn: Callable = None,
break_lineage_blocked_comparisons: Callable = default_break_lineage_blocked_comparisons,
break_lineage_scored_comparisons: Callable = default_break_lineage_scored_comparisons,
):
"""Splink data linker
Expand All @@ -57,15 +53,12 @@ def __init__(
unless the lineage is broken after blocking. This is a user-provided function that takes one argument
- df - and allows the user to break lineage. For example, the function might save df to the AWS s3
file system, and then reload it from the saved files.
break_lineage_scored_comparisons (function, optional): Large jobs will likely run into memory errors unless
the lineage is broken after comparisons are scored and before term frequency adjustments. This is a
user-provided function that takes one argument - df - and allows the user to break lineage. For
example, the function might save df to the AWS s3 file system, and then reload it from the saved files.
"""

self.spark = spark
self.break_lineage_blocked_comparisons = break_lineage_blocked_comparisons
self.break_lineage_scored_comparisons = break_lineage_scored_comparisons

_check_jaro_registered(spark)

validate_settings_against_schema(settings)
Expand Down Expand Up @@ -121,10 +114,6 @@ def get_scored_comparisons(self):
# In case the user's break lineage function has persisted it
df_gammas.unpersist()

df_e = self.break_lineage_scored_comparisons(df_e, self.spark)

df_e.unpersist()

return df_e

def save_model_as_json(self, path: str, overwrite=False):
Expand Down
6 changes: 0 additions & 6 deletions splink/break_lineage.py
Expand Up @@ -30,9 +30,3 @@ def default_break_lineage_blocked_comparisons(df_gammas, spark):
df_gammas = cutLineage(df_gammas)
df_gammas.persist()
return df_gammas


def default_break_lineage_scored_comparisons(df_e, spark):
    """Default lineage-breaking hook applied to the scored comparisons.

    Passes ``df_e`` through ``cutLineage``, calls ``persist()`` on the
    result and returns it.

    Args:
        df_e: The scored comparisons dataframe.
        spark: Accepted so the hook signature matches how it is called;
            not used by this default implementation.

    Returns:
        The dataframe returned by ``cutLineage`` after ``persist()`` has
        been called on it.
    """
    # NOTE(review): cutLineage is defined elsewhere in this module;
    # presumably it rebuilds the dataframe without its query plan - confirm.
    df_e = cutLineage(df_e)
    df_e.persist()
    return df_e
4 changes: 1 addition & 3 deletions splink/expectation_step.py
Expand Up @@ -89,7 +89,6 @@ def _sql_gen_gamma_bf_columns(
select_cols = _add_left_right(select_cols, col_name)

if col["term_frequency_adjustments"]:
select_cols = _add_left_right(select_cols, cc.name)
select_cols.add(case_statements[f"bf_tf_adj_{cc.name}"])

select_cols.add("gamma_" + cc.name)
Expand Down Expand Up @@ -131,8 +130,6 @@ def _column_order_df_e_select_expr(
if settings["retain_matching_columns"]:
for col_name in cc.columns_used:
select_cols = _add_left_right(select_cols, col_name)
if col["term_frequency_adjustments"]:
select_cols = _add_left_right(select_cols, cc.name)

select_cols.add(f"gamma_{cc.name}")

Expand All @@ -150,6 +147,7 @@ def _column_order_df_e_select_expr(
if "blocking_rules" in settings:
if len(settings["blocking_rules"]) > 1:
select_cols.add("match_key")

return ", ".join(select_cols)


Expand Down
19 changes: 12 additions & 7 deletions splink/gammas.py
Expand Up @@ -32,7 +32,9 @@ def _add_unique_id_and_source_dataset(
return cols_set


def _get_select_expression_gammas(settings: dict, retain_source_dataset_col: bool, retain_tf_cols: bool):
def _get_select_expression_gammas(
settings: dict, retain_source_dataset_col: bool, retain_tf_cols: bool
):
"""Get a select expression which picks which columns to keep in df_gammas
Args:
Expand Down Expand Up @@ -60,7 +62,6 @@ def _get_select_expression_gammas(settings: dict, retain_source_dataset_col: boo
for col_name in cc.columns_used:
select_columns = _add_left_right(select_columns, col_name)
if col["term_frequency_adjustments"]:
select_columns = _add_left_right(select_columns, cc.name)
if retain_tf_cols:
select_columns = _add_left_right(select_columns, f"tf_{cc.name}")
select_columns.add(col["case_expression"])
Expand Down Expand Up @@ -90,20 +91,22 @@ def _retain_source_dataset_column(settings_dict, df):
else:
return False


def _retain_tf_columns(settings_dict, df):
    """Report whether every expected term-frequency column is in ``df``.

    If all necessary TF columns are in the data, the caller can make sure
    they are retained.

    For each entry of ``settings_dict["comparison_columns"]`` whose
    ``term_frequency_adjustments`` value is truthy, the expected base name
    is ``tf_<col_name>``, or ``tf_<custom_name>`` when the entry has no
    ``col_name`` key. Each base name is expanded via ``_add_left_right``
    and the expanded names are checked against ``df.columns``.

    Args:
        settings_dict (dict): Settings containing a ``comparison_columns``
            list; every entry must carry ``term_frequency_adjustments``.
        df: Object exposing a ``columns`` collection.

    Returns:
        bool: True when every expanded name is in ``df.columns``. This is
        also True when no names were collected at all.
    """
    tf_cols = [
        f"tf_{cc['col_name']}" if "col_name" in cc else f"tf_{cc['custom_name']}"
        for cc in settings_dict["comparison_columns"]
        if cc["term_frequency_adjustments"]
    ]

    cols = OrderedSet()
    for tf_col in tf_cols:
        # Called for its in-place effect on ``cols``; the return value is
        # deliberately ignored, exactly as before.
        # NOTE(review): presumably adds the left/right variants of the
        # name - confirm against the helper's definition.
        _add_left_right(cols, tf_col)

    return all(col in df.columns for col in cols)


def _sql_gen_add_gammas(
settings: dict,
df_comparison: DataFrame,
Expand All @@ -122,7 +125,9 @@ def _sql_gen_add_gammas(

retain_source_dataset = _retain_source_dataset_column(settings, df_comparison)
retain_tf_cols = _retain_tf_columns(settings, df_comparison)
select_cols_expr = _get_select_expression_gammas(settings, retain_source_dataset, retain_tf_cols)
select_cols_expr = _get_select_expression_gammas(
settings, retain_source_dataset, retain_tf_cols
)

sql = f"""
select {select_cols_expr}
Expand Down

0 comments on commit f140016

Please sign in to comment.