From cafbd37bd6567cdd68db022fa9e00ffe84a651f9 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Mon, 14 Dec 2020 19:25:31 +0000
Subject: [PATCH 1/2] first go at fixing iteration history

---
 splink/expectation_step.py | 62 ++++++++++++++++++++++----------------
 splink/params.py           |  4 +--
 2 files changed, 38 insertions(+), 28 deletions(-)

diff --git a/splink/expectation_step.py b/splink/expectation_step.py
index dc81df5f71..a52591c472 100644
--- a/splink/expectation_step.py
+++ b/splink/expectation_step.py
@@ -22,12 +22,15 @@
 from .params import Params
 from .check_types import check_types

+
 @check_types
-def run_expectation_step(df_with_gamma: DataFrame,
-                         params: Params,
-                         settings: dict,
-                         spark: SparkSession,
-                         compute_ll=False):
+def run_expectation_step(
+    df_with_gamma: DataFrame,
+    params: Params,
+    settings: dict,
+    spark: SparkSession,
+    compute_ll=False,
+):
     """Run the expectation step of the EM algorithm described in the fastlink paper:
     http://imai.fas.harvard.edu/research/files/linkage.pdf
@@ -40,8 +43,7 @@ def run_expectation_step(df_with_gamma: DataFrame,

     Returns:
         DataFrame: Spark dataframe with a match_probability column
-    """
-
+    """

     sql = _sql_gen_gamma_prob_columns(params, settings)
@@ -63,6 +65,9 @@ def run_expectation_step(df_with_gamma: DataFrame,
     df_e = spark.sql(sql)

     df_e.createOrReplaceTempView("df_e")
+
+    params.save_params_to_iteration_history()
+
     return df_e
@@ -80,7 +85,6 @@ def _sql_gen_gamma_prob_columns(params, settings, table_name="df_with_gamma"):
             case_statement = _sql_gen_gamma_case_when(gamma_str, match, params)
             case_statements[alias] = case_statement

-
     # Column order for case statement. We want orig_col_l, orig_col_r, gamma_orig_col, prob_gamma_u, prob_gamma_m
     select_cols = OrderedDict()
     select_cols = _add_left_right(select_cols, settings["unique_id_column_name"])
@@ -104,21 +108,24 @@ def _sql_gen_gamma_prob_columns(params, settings, table_name="df_with_gamma"):

         select_cols["gamma_" + col_name] = "gamma_" + col_name

-        select_cols[f"prob_gamma_{col_name}_non_match"] = case_statements[f"prob_gamma_{col_name}_non_match"]
-        select_cols[f"prob_gamma_{col_name}_match"] = case_statements[f"prob_gamma_{col_name}_match"]
+        select_cols[f"prob_gamma_{col_name}_non_match"] = case_statements[
+            f"prob_gamma_{col_name}_non_match"
+        ]
+        select_cols[f"prob_gamma_{col_name}_match"] = case_statements[
+            f"prob_gamma_{col_name}_match"
+        ]

-    if settings["link_type"] == 'link_and_dedupe':
+    if settings["link_type"] == "link_and_dedupe":
         select_cols = _add_left_right(select_cols, "_source_table")

     for c in settings["additional_columns_to_retain"]:
         select_cols = _add_left_right(select_cols, c)

-    if 'blocking_rules' in settings:
+    if "blocking_rules" in settings:
         if len(settings["blocking_rules"]) > 0:
-            select_cols['match_key'] = 'match_key'
-
-    select_expr = ", ".join(select_cols.values())
+            select_cols["match_key"] = "match_key"

+    select_expr = ", ".join(select_cols.values())

     sql = f"""
     -- We use case statements for these lookups rather than joins for performance and simplicity
@@ -153,32 +160,33 @@ def _column_order_df_e_select_expr(settings, tf_adj_cols=False):

         select_cols["gamma_" + col_name] = "gamma_" + col_name

         if settings["retain_intermediate_calculation_columns"]:
-            select_cols[f"prob_gamma_{col_name}_non_match"] = f"prob_gamma_{col_name}_non_match"
+            select_cols[
+                f"prob_gamma_{col_name}_non_match"
+            ] = f"prob_gamma_{col_name}_non_match"
             select_cols[f"prob_gamma_{col_name}_match"] = f"prob_gamma_{col_name}_match"

         if tf_adj_cols:
             if col["term_frequency_adjustments"]:
-                select_cols[col_name+"_tf_adj"] = col_name+"_tf_adj"
-
-
+                select_cols[col_name + "_tf_adj"] = col_name + "_tf_adj"

-    if settings["link_type"] == 'link_and_dedupe':
+    if settings["link_type"] == "link_and_dedupe":
         select_cols = _add_left_right(select_cols, "_source_table")

     for c in settings["additional_columns_to_retain"]:
         select_cols = _add_left_right(select_cols, c)

-    if 'blocking_rules' in settings:
+    if "blocking_rules" in settings:
         if len(settings["blocking_rules"]) > 0:
-            select_cols['match_key'] = 'match_key'
+            select_cols["match_key"] = "match_key"

     return ", ".join(select_cols.values())

+
 def _sql_gen_expected_match_prob(params, settings, table_name="df_with_gamma_probs"):

     gamma_cols = params._gamma_cols

     numerator = " * ".join([f"prob_{g}_match" for g in gamma_cols])
     denom_part = " * ".join([f"prob_{g}_non_match" for g in gamma_cols])

-    λ = params.params['λ']
+    λ = params.params["λ"]
     castλ = f"cast({λ} as double)"
     castoneminusλ = f"cast({1-λ} as double)"
     match_prob_expression = f"({castλ} * {numerator})/(( {castλ} * {numerator}) + ({castoneminusλ} * {denom_part})) as match_probability"
@@ -192,6 +200,7 @@ def _sql_gen_expected_match_prob(params, settings, table_name="df_with_gamma_pro

     return sql

+
 def _case_when_col_alias(gamma_str, match):

     if match == 1:
@@ -201,6 +210,7 @@ def _case_when_col_alias(gamma_str, match):

     return f"prob_{gamma_str}{name_suffix}"

+
 def _sql_gen_gamma_case_when(gamma_str, match, params):
     """
     Create the case statements that look up the correct probabilities in the
@@ -217,8 +227,8 @@ def _sql_gen_gamma_case_when(gamma_str, match, params):
     case_statements = []
     case_statements.append(f"WHEN {gamma_str} = -1 THEN cast(1 as double)")
     for level in levels.values():
-        case_stmt = f"when {gamma_str} = {level['value']} then cast({level['probability']:.35f} as double)"
-        case_statements.append(case_stmt)
+        case_stmt = f"when {gamma_str} = {level['value']} then cast({level['probability']:.35f} as double)"
+        case_statements.append(case_stmt)

     case_statements = "\n".join(case_statements)
@@ -238,7 +248,7 @@ def _calculate_log_likelihood_df(df_with_gamma_probs, params, spark):

     gamma_cols = params._gamma_cols

-    λ = params.params['λ']
+    λ = params.params["λ"]

     match_prob = " * ".join([f"prob_{g}_match" for g in gamma_cols])
     match_prob = f"({λ} * {match_prob})"
diff --git a/splink/params.py b/splink/params.py
index 0d4cb41312..2f17706f06 100644
--- a/splink/params.py
+++ b/splink/params.py
@@ -297,7 +297,7 @@ def _reset_param_values_to_none(self):
             ].values():
                 level_value["probability"] = None

-    def _save_params_to_iteration_history(self):
+    def save_params_to_iteration_history(self):
         """
         Take current params and
         """
@@ -357,7 +357,7 @@ def _update_params(self, lambda_value, pi_df_collected):
         Reset values
         Then update the parameters from the dataframe
         """
-        self._save_params_to_iteration_history()
+
         self._reset_param_values_to_none()
         self._populate_params(lambda_value, pi_df_collected)
         self.iteration += 1

From 973c6e6cb8df343265832ce1747cecbf9c46dcd3 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Mon, 14 Dec 2020 19:43:16 +0000
Subject: [PATCH 2/2] fix problem with iteration history

---
 pyproject.toml             |   2 +-
 splink/expectation_step.py |   1 +
 splink/params.py           |  12 ++--
 tests/test_params.py       | 119 +++++++++++++++++++++++++------------
 4 files changed, 87 insertions(+), 47 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index da2c6258a3..c715d3ad45 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "splink"
-version = "0.3.8"
+version = "0.3.9"
= "0.3.9" description = "Implementation in Apache Spark of the EM algorithm to estimate parameters of Fellegi-Sunter's canonical model of record linkage." authors = ["Robin Linacre ", "Sam Lindsay", "Theodore Manassis"] license = "MIT" diff --git a/splink/expectation_step.py b/splink/expectation_step.py index a52591c472..5466edd3db 100644 --- a/splink/expectation_step.py +++ b/splink/expectation_step.py @@ -67,6 +67,7 @@ def run_expectation_step( df_e.createOrReplaceTempView("df_e") params.save_params_to_iteration_history() + params.iteration += 1 return df_e diff --git a/splink/params.py b/splink/params.py index 2f17706f06..9726ea01ba 100644 --- a/splink/params.py +++ b/splink/params.py @@ -53,7 +53,7 @@ def __init__(self, settings: dict, spark: SparkSession): self.param_history = [] - self.iteration = 1 + self.iteration = 0 self.settings_original = copy.deepcopy(settings) self.settings = complete_settings_dict(settings, spark) @@ -262,14 +262,14 @@ def _iteration_history_df_gammas(self): data = [] for it_num, param_value in enumerate(self.param_history): data.extend(self._convert_params_dict_to_dataframe(param_value, it_num)) - data.extend(self._convert_params_dict_to_dataframe(self.params, it_num + 1)) + return data def _iteration_history_df_lambdas(self): data = [] for it_num, param_value in enumerate(self.param_history): data.append({"λ": param_value["λ"], "iteration": it_num}) - data.append({"λ": self.params["λ"], "iteration": it_num + 1}) + return data def _iteration_history_df_log_likelihood(self): @@ -278,9 +278,7 @@ def _iteration_history_df_log_likelihood(self): data.append( {"log_likelihood": param_value["log_likelihood"], "iteration": it_num} ) - data.append( - {"log_likelihood": self.params["log_likelihood"], "iteration": it_num + 1} - ) + return data def _reset_param_values_to_none(self): @@ -360,7 +358,7 @@ def _update_params(self, lambda_value, pi_df_collected): self._reset_param_values_to_none() self._populate_params(lambda_value, pi_df_collected) - self.iteration += 1 + def _to_dict(self): p_dict = {} diff --git a/tests/test_params.py b/tests/test_params.py index 7948cffa0a..49f7f45d22 100644 --- a/tests/test_params.py +++ b/tests/test_params.py @@ -3,18 +3,18 @@ # Light testing at the moment. 

-@pytest.fixture(scope='module')
+
+@pytest.fixture(scope="module")
 def param_example():
     gamma_settings = {
-        "link_type": "dedupe_only",
-        "proportion_of_matches": 0.2,
-        "comparison_columns": [
-            {"col_name": "fname"},
-            {"col_name": "sname",
-            "num_levels": 3}
-        ],
-        "blocking_rules": []
-    }
+        "link_type": "dedupe_only",
+        "proportion_of_matches": 0.2,
+        "comparison_columns": [
+            {"col_name": "fname"},
+            {"col_name": "sname", "num_levels": 3},
+        ],
+        "blocking_rules": [],
+    }

     params = Params(gamma_settings, spark="supress_warnings")
@@ -28,7 +28,7 @@ def test_prob_sum_one(param_example):

     for m in ["prob_dist_match", "prob_dist_non_match"]:
         for g in ["gamma_fname", "gamma_sname"]:
-            levels = p["π"][g][m]
+            levels = p["π"][g][m]

             total = 0
             for l in levels:
@@ -36,41 +36,78 @@ def test_prob_sum_one(param_example):

     assert total == pytest.approx(1.0)

-def test_update(param_example):
+def test_update(param_example):

     pi_df_collected = [
-        {'gamma_value': 1, 'new_probability_match': 0.9, 'new_probability_non_match': 0.1, 'gamma_col': 'gamma_fname'},
-        {'gamma_value': 0, 'new_probability_match': 0.2, 'new_probability_non_match': 0.8, 'gamma_col': 'gamma_fname'},
-        {'gamma_value': 1, 'new_probability_match': 0.9, 'new_probability_non_match': 0.1, 'gamma_col': 'gamma_sname'},
-        {'gamma_value': 2, 'new_probability_match': 0.7, 'new_probability_non_match': 0.3, 'gamma_col': 'gamma_sname'},
-        {'gamma_value': 0, 'new_probability_match': 0.5, 'new_probability_non_match': 0.5, 'gamma_col': 'gamma_sname'}]
-
-    param_example._save_params_to_iteration_history()
+        {
+            "gamma_value": 1,
+            "new_probability_match": 0.9,
+            "new_probability_non_match": 0.1,
+            "gamma_col": "gamma_fname",
+        },
+        {
+            "gamma_value": 0,
+            "new_probability_match": 0.2,
+            "new_probability_non_match": 0.8,
+            "gamma_col": "gamma_fname",
+        },
+        {
+            "gamma_value": 1,
+            "new_probability_match": 0.9,
+            "new_probability_non_match": 0.1,
+            "gamma_col": "gamma_sname",
+        },
+        {
+            "gamma_value": 2,
+            "new_probability_match": 0.7,
+            "new_probability_non_match": 0.3,
+            "gamma_col": "gamma_sname",
+        },
+        {
+            "gamma_value": 0,
+            "new_probability_match": 0.5,
+            "new_probability_non_match": 0.5,
+            "gamma_col": "gamma_sname",
+        },
+    ]
+
+    param_example.save_params_to_iteration_history()

     param_example._reset_param_values_to_none()

-    assert param_example.params["π"]["gamma_fname"]["prob_dist_match"]["level_0"]["probability"] is None
+    assert (
+        param_example.params["π"]["gamma_fname"]["prob_dist_match"]["level_0"][
+            "probability"
+        ]
+        is None
+    )

     param_example._populate_params(0.2, pi_df_collected)

     new_params = param_example.params

-    assert new_params["π"]["gamma_fname"]["prob_dist_match"]["level_0"]["probability"] == 0.2
-    assert new_params["π"]["gamma_fname"]["prob_dist_non_match"]["level_0"]["probability"] == 0.8
+    assert (
+        new_params["π"]["gamma_fname"]["prob_dist_match"]["level_0"]["probability"]
+        == 0.2
+    )
+    assert (
+        new_params["π"]["gamma_fname"]["prob_dist_non_match"]["level_0"]["probability"]
+        == 0.8
+    )
+

 def test_update_settings():
-    
+
     old_settings = {
         "link_type": "dedupe_only",
         "proportion_of_matches": 0.2,
         "comparison_columns": [
             {"col_name": "fname"},
             {"col_name": "sname", "num_levels": 3},
         ],
         "blocking_rules": [],
     }

     params = Params(old_settings, spark="supress_warnings")
-    
+
     new_settings = {
         "link_type": "dedupe_only",
         "blocking_rules": [],
         "comparison_columns": [
             {
                 "col_name": "fname",
"num_levels": 3, - "m_probabilities": [0.02,0.03,0.95], - "u_probabilities": [0.92,0.05,0.03] + "m_probabilities": [0.02, 0.03, 0.95], + "u_probabilities": [0.92, 0.05, 0.03], }, { "custom_name": "sname", @@ -89,17 +126,21 @@ def test_update_settings(): case when concat(fname_l, sname_l) = concat(fname_r, sname_r) then 1 else 0 end """, - "m_probabilities": [0.01,0.02,0.97], - "u_probabilities": [0.9,0.05,0.05] + "m_probabilities": [0.01, 0.02, 0.97], + "u_probabilities": [0.9, 0.05, 0.05], }, - {"col_name": "dob"} - ] + {"col_name": "dob"}, + ], } - + update = get_or_update_settings(params, new_settings) - + # new settings used due to num_levels mismatch - assert update["comparison_columns"][0]["m_probabilities"] == new_settings["comparison_columns"][0]["m_probabilities"] - # new settings updated with old settings - assert update["comparison_columns"][1]["u_probabilities"] == pytest.approx(params.settings["comparison_columns"][1]["u_probabilities"]) - \ No newline at end of file + assert ( + update["comparison_columns"][0]["m_probabilities"] + == new_settings["comparison_columns"][0]["m_probabilities"] + ) + # new settings updated with old settings + assert update["comparison_columns"][1]["u_probabilities"] == pytest.approx( + params.settings["comparison_columns"][1]["u_probabilities"] + )