From 6777de589d1200c6a0d067fc4d31f1e2d5105274 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Sat, 13 Mar 2021 19:51:15 +0000
Subject: [PATCH 1/9] start fix

---
 splink/default_settings.py | 103 +++++++++++++++++--------------------
 splink/settings.py         |   4 +-
 2 files changed, 49 insertions(+), 58 deletions(-)

diff --git a/splink/default_settings.py b/splink/default_settings.py
index fdf32ca6c9..48c6aec110 100644
--- a/splink/default_settings.py
+++ b/splink/default_settings.py
@@ -4,6 +4,7 @@
 from copy import deepcopy
 
+
 from .validate import validate_settings, _get_default_value
 from .case_statements import (
     _check_jaro_registered,
@@ -120,48 +121,55 @@ def _complete_case_expression(col_settings, spark):
     col_settings["case_expression"] = new_case_stmt
 
 
-def _complete_probabilities(col_settings: dict, setting_name: str):
+def _complete_probabilities(cc, setting_name: str):
     """
 
     Args:
-        col_settings (dict): Column settings dictionary
+        cc (ComparisonColumn): ComparisonColumn
         setting_name (str): Either 'm_probabilities' or 'u_probabilities'
     """
 
-    if setting_name not in col_settings:
-        levels = col_settings["num_levels"]
+    if setting_name not in cc.column_dict:
+        levels = cc["num_levels"]
         probs = _get_default_probabilities(setting_name, levels)
-        col_settings[setting_name] = probs
-    else:
-        levels = col_settings["num_levels"]
-        probs = col_settings[setting_name]
+        cc.column_dict[setting_name] = probs
 
-    # Check for m and u manually set to zero (https://github.com/moj-analytical-services/splink/issues/161)
-    if not all(col_settings[setting_name]):
-        if "custom_name" in col_settings:
-            col_name = col_settings["custom_name"]
-        else:
-            col_name = col_settings["col_name"]
+    # Normalise probabilities if possible
+    if None in cc.column_dict[setting_name]:
+        warnings.warn(
+            "Your m probabilities contain a None value "
+            "so could not be normalised to 1"
+        )
+    elif sum(cc.column_dict[setting_name]) == 0:
+        raise ValueError(
+            f"Your {setting_name} for {cc.name } sum to zero and cannot be used "
+            "They should sum to 1"
+        )
+    else:
+        cc.column_dict[setting_name] = _normalise_prob_list(
+            cc.column_dict[setting_name]
+        )
 
-        if setting_name == "m_probabilities":
-            letter = "m"
-        elif setting_name == "u_probabilities":
-            letter = "u"
+    # Check for m and u manually set to zero (https://github.com/moj-analytical-services/splink/issues/161)
+    if 0 in cc.column_dict[setting_name]:
 
-        warnings.warn(
-            f"Your {setting_name} for {col_name} include zeroes. "
-            f"Where {letter}=0 for a given level, it remains fixed rather than being estimated "
-            "along with other model parameters, and all comparisons at this level "
-            f"are assigned a match score of {1. if letter=='u' else 0.}, regardless of other comparisons columns."
-        )
+        if setting_name == "m_probabilities":
+            letter = "m"
+        elif setting_name == "u_probabilities":
+            letter = "u"
 
-        if len(probs) != levels:
-            raise ValueError(
-                f"Number of {setting_name} provided is not equal to number of levels specified"
-            )
+        warnings.warn(
+            f"Your {setting_name} for {cc.name} include zeroes. "
+            f"Where {letter}=0 for a given level, it remains fixed rather than being estimated "
+            "along with other model parameters, and all comparisons at this level "
+            f"are assigned a match score of {1. if letter=='u' else 0.}, regardless of other comparisons columns."
+        )
 
-    col_settings[setting_name] = col_settings[setting_name]
+    if len(probs) != levels:
+        raise ValueError(
+            f"Number of {setting_name} provided is not equal to number of levels specified"
+        )
 
 
 def complete_settings_dict(settings_dict: dict, spark: SparkSession):
@@ -203,9 +211,12 @@ def complete_settings_dict(settings_dict: dict, spark: SparkSession):
                 "because it will generate comparisons equal to the number of rows squared."
             )
 
-    c_cols = settings_dict["comparison_columns"]
-    for gamma_index, col_settings in enumerate(c_cols):
+    settings_obj = Settings(settings_dict)
+    # c_cols = settings_dict["comparison_columns"]
+    for gamma_index, cc in enumerate(settings_obj.comparison_columns_list):
 
+        # Gamma index refers to the position in the comparison vector
+        # i.e. it's a counter for comparison columns
         col_settings["gamma_index"] = gamma_index
 
         # Populate non-existing keys from defaults
@@ -218,33 +229,13 @@ def complete_settings_dict(settings_dict: dict, spark: SparkSession):
         ]
 
         for key in keys_for_defaults:
-            if key not in col_settings:
+            if key not in cc.column_dict:
                 default = _get_default_value(key, is_column_setting=True)
-                col_settings[key] = default
+                cc.column_dict[key] = default
 
         # Doesn't need assignment because we're modify the col_settings dictionary
-        _complete_case_expression(col_settings, spark)
-        _complete_probabilities(col_settings, "m_probabilities")
-        _complete_probabilities(col_settings, "u_probabilities")
-
-        if None not in col_settings["m_probabilities"]:
-            col_settings["m_probabilities"] = _normalise_prob_list(
-                col_settings["m_probabilities"]
-            )
-        else:
-            warnings.warn(
-                "Your m probabilities contain a None value "
-                "so could not be normalised to 1"
-            )
-
-        if None not in col_settings["u_probabilities"]:
-            col_settings["u_probabilities"] = _normalise_prob_list(
-                col_settings["u_probabilities"]
-            )
-        else:
-            warnings.warn(
-                "Your u probabilities contain a None value "
-                "so could not be normalised to 1"
-            )
+        _complete_case_expression(cc.column_dict, spark)
+        _complete_probabilities(cc.columns_dict, "m_probabilities")
+        _complete_probabilities(cc.column_dict, "u_probabilities")
 
     return settings_dict
diff --git a/splink/settings.py b/splink/settings.py
index c21cdef288..47887b0ee7 100644
--- a/splink/settings.py
+++ b/splink/settings.py
@@ -127,11 +127,11 @@ def reset_probabilities(self, force: bool = False):
         fixed_u = self._dict_key_else_default_value("fix_u_probabilities")
         if not fixed_m or force:
             if "m_probabilities" in cd:
-                cd["m_probabilities"] = [0 for c in cd["m_probabilities"]]
+                cd["m_probabilities"] = [None for c in cd["m_probabilities"]]
 
         if not fixed_u or force:
             if "u_probabilities" in cd:
-                cd["u_probabilities"] = [0 for c in cd["u_probabilities"]]
+                cd["u_probabilities"] = [None for c in cd["u_probabilities"]]
 
     def level_as_dict(self, gamma_index, proportion_of_matches=None):

From 796ad83d9d02bb5b7dff96b40aef521ce086cc36 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Sat, 13 Mar 2021 21:14:00 +0000
Subject: [PATCH 2/9] demos work

---
 splink/__init__.py         |  11 +++-
 splink/default_settings.py | 112 ++++++++++++++++---------------------
 splink/estimate.py         |   5 +-
 splink/expectation_step.py |  10 +++-
 splink/model.py            |   5 +-
 splink/settings.py         |   4 +-
 splink/term_frequencies.py |   8 ++-
 splink/validate.py         |  37 +++++++++++-
 8 files changed, 112 insertions(+), 80 deletions(-)

diff --git a/splink/__init__.py b/splink/__init__.py
index 5e192d2af6..420a76382f 100644
--- a/splink/__init__.py
+++ b/splink/__init__.py
@@ -4,9 +4,10 @@
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.session import SparkSession
 from splink.validate import (
-    validate_settings,
+    validate_settings_against_schema,
     validate_input_datasets,
     validate_link_type,
+    validate_probabilities,
 )
 from splink.model import Model, load_model_from_json
 from splink.case_statements import _check_jaro_registered
@@ -23,6 +24,8 @@
     default_break_lineage_scored_comparisons,
 )
 
+from splink.default_settings import normalise_probabilities
+
 
 @typechecked
 class Splink:
@@ -53,11 +56,13 @@ def __init__(
         self.break_lineage_scored_comparisons = break_lineage_scored_comparisons
 
         _check_jaro_registered(spark)
-        validate_settings(settings)
+        validate_settings_against_schema(settings)
         validate_link_type(df_or_dfs, settings)
+
         self.model = Model(settings, spark)
         self.settings_dict = self.model.current_settings_obj.settings_dict
-
+        self.settings_dict = normalise_probabilities(self.settings_dict)
+        validate_probabilities(self.settings_dict)
         # dfs is a list of dfs irrespective of whether input was a df or list of dfs
         if type(df_or_dfs) == DataFrame:
             dfs = [df_or_dfs]
diff --git a/splink/default_settings.py b/splink/default_settings.py
index 48c6aec110..9c7d4df843 100644
--- a/splink/default_settings.py
+++ b/splink/default_settings.py
@@ -4,8 +4,8 @@
 from copy import deepcopy
 
+from .validate import get_default_value_from_schema
 
-from .validate import validate_settings, _get_default_value
 from .case_statements import (
     _check_jaro_registered,
     sql_gen_case_smnt_strict_equality_2,
@@ -67,9 +67,9 @@ def _get_default_case_statement_fn(default_statements, data_type, levels):
 
 
 def _get_default_probabilities(m_or_u, levels):
-    if levels > 5:
+    if levels > 6:
         raise ValueError(
-            f"No default m and u probabilities available when levels > 4, "
+            f"No default m and u probabilities available when levels > 6, "
             "please specify custom values for 'm_probabilities' and 'u_probabilities' "
             "within your settings dictionary"
         )
@@ -77,16 +77,18 @@ def _get_default_probabilities(m_or_u, levels):
 
     # Note all m and u probabilities are automatically normalised to sum to 1
     default_m_u_probabilities = {
         "m_probabilities": {
-            2: [1, 9],
-            3: [1, 2, 7],
-            4: [1, 1, 1, 7],
-            5: [0.33, 0.67, 1, 2, 6],
+            2: _normalise_prob_list([1, 9]),
+            3: _normalise_prob_list([1, 2, 7]),
+            4: _normalise_prob_list([1, 1, 1, 7]),
+            5: _normalise_prob_list([0.33, 0.67, 1, 2, 6]),
+            6: _normalise_prob_list([0.33, 0.67, 1, 2, 3, 6]),
         },
         "u_probabilities": {
-            2: [9, 1],
-            3: [7, 2, 1],
-            4: [7, 1, 1, 1],
-            5: [6, 2, 1, 0.33, 0.67],
+            2: _normalise_prob_list([9, 1]),
+            3: _normalise_prob_list([7, 2, 1]),
+            4: _normalise_prob_list([7, 1, 1, 1]),
+            5: _normalise_prob_list([6, 2, 1, 0.33, 0.67]),
+            6: _normalise_prob_list([6, 3, 2, 1, 0.33, 0.67]),
         },
     }
 
@@ -121,55 +123,19 @@ def _complete_case_expression(col_settings, spark):
     col_settings["case_expression"] = new_case_stmt
 
 
-def _complete_probabilities(cc, setting_name: str):
+def _complete_probabilities(col_settings: dict, mu_probabilities: str):
     """
 
     Args:
-        cc (ComparisonColumn): ComparisonColumn
-        setting_name (str): Either 'm_probabilities' or 'u_probabilities'
+        col_settings (dict): Column settings dictionary
+        mu_probabilities (str): Either 'm_probabilities' or 'u_probabilities'
     """
 
-    if setting_name not in cc.column_dict:
-        levels = cc["num_levels"]
-        probs = _get_default_probabilities(setting_name, levels)
-        cc.column_dict[setting_name] = probs
-
-    # Normalise probabilities if possible
-    if None in cc.column_dict[setting_name]:
-        warnings.warn(
-            "Your m probabilities contain a None value "
"so could not be normalised to 1" - ) - elif sum(cc.column_dict[setting_name]) == 0: - raise ValueError( - f"Your {setting_name} for {cc.name } sum to zero and cannot be used " - "They should sum to 1" - ) - else: - cc.column_dict[setting_name] = _normalise_prob_list( - cc.column_dict[setting_name] - ) - - # Check for m and u manually set to zero (https://github.com/moj-analytical-services/splink/issues/161) - if 0 in cc.column_dict[setting_name]: - - if setting_name == "m_probabilities": - letter = "m" - elif setting_name == "u_probabilities": - letter = "u" - - warnings.warn( - f"Your {setting_name} for {cc.name} include zeroes. " - f"Where {letter}=0 for a given level, it remains fixed rather than being estimated " - "along with other model parameters, and all comparisons at this level " - f"are assigned a match score of {1. if letter=='u' else 0.}, regardless of other comparisons columns." - ) - - if len(probs) != levels: - raise ValueError( - f"Number of {setting_name} provided is not equal to number of levels specified" - ) + if mu_probabilities not in col_settings: + levels = col_settings["num_levels"] + probs = _get_default_probabilities(mu_probabilities, levels) + col_settings[mu_probabilities] = probs def complete_settings_dict(settings_dict: dict, spark: SparkSession): @@ -184,7 +150,6 @@ def complete_settings_dict(settings_dict: dict, spark: SparkSession): dict: A `splink` settings dictionary with all keys populated. """ settings_dict = deepcopy(settings_dict) - validate_settings(settings_dict) # Complete non-column settings from their default values if not exist non_col_keys = [ @@ -200,7 +165,9 @@ def complete_settings_dict(settings_dict: dict, spark: SparkSession): ] for key in non_col_keys: if key not in settings_dict: - settings_dict[key] = _get_default_value(key, is_column_setting=False) + settings_dict[key] = get_default_value_from_schema( + key, is_column_setting=False + ) if "blocking_rules" in settings_dict: if len(settings_dict["blocking_rules"]) == 0: @@ -211,9 +178,8 @@ def complete_settings_dict(settings_dict: dict, spark: SparkSession): "because it will generate comparisons equal to the number of rows squared." ) - settings_obj = Settings(settings_dict) - # c_cols = settings_dict["comparison_columns"] - for gamma_index, cc in enumerate(settings_obj.comparison_columns_list): + c_cols = settings_dict["comparison_columns"] + for gamma_index, col_settings in enumerate(c_cols): # Gamma index refers to the position in the comparison vector # i.e. 
@@ -229,13 +195,29 @@ def complete_settings_dict(settings_dict: dict, spark: SparkSession):
         ]
 
         for key in keys_for_defaults:
-            if key not in cc.column_dict:
-                default = _get_default_value(key, is_column_setting=True)
-                cc.column_dict[key] = default
+            if key not in col_settings:
+                default = get_default_value_from_schema(key, is_column_setting=True)
+                col_settings[key] = default
 
         # Doesn't need assignment because we're modify the col_settings dictionary
-        _complete_case_expression(cc.column_dict, spark)
-        _complete_probabilities(cc.columns_dict, "m_probabilities")
-        _complete_probabilities(cc.column_dict, "u_probabilities")
+        _complete_case_expression(col_settings, spark)
+        _complete_probabilities(col_settings, "m_probabilities")
+        _complete_probabilities(col_settings, "u_probabilities")
 
     return settings_dict
+
+
+def normalise_probabilities(settings_dict: dict):
+    """Normalise all probabilities in a settings dictionary to sum
+    to one, of possible
+
+    Args:
+        settings_dict (dict): Splink settings dictionary
+    """
+
+    c_cols = settings_dict["comparison_columns"]
+    for col_settings in c_cols:
+        for p in ["m_probabilities", "u_probabilities"]:
+            if p in col_settings:
+                col_settings[p] = _normalise_prob_list(col_settings[p])
+    return settings_dict
\ No newline at end of file
diff --git a/splink/estimate.py b/splink/estimate.py
index 0ded653d0f..3e9d36d056 100644
--- a/splink/estimate.py
+++ b/splink/estimate.py
@@ -1,4 +1,5 @@
 from copy import deepcopy
+from splink.default_settings import normalise_probabilities, _normalise_prob_list
 
 from .blocking import block_using_rules
 from .gammas import add_gammas
@@ -109,9 +110,9 @@ def estimate_u_values(
         u_probs = new_settings["comparison_columns"][i]["u_probabilities"]
         # Ensure non-zero u (https://github.com/moj-analytical-services/splink/issues/161)
         u_probs = [u or 1 / target_rows for u in u_probs]
-
+        u_probs = _normalise_prob_list(u_probs)
         col["u_probabilities"] = u_probs
         if fix_u_probabilities:
             col["fix_u_probabilities"] = True
-    return orig_settings
+    return orig_settings
diff --git a/splink/expectation_step.py b/splink/expectation_step.py
index a2a83267ba..092aba0c39 100644
--- a/splink/expectation_step.py
+++ b/splink/expectation_step.py
@@ -213,9 +213,13 @@ def _sql_gen_gamma_case_when(comparison_column, match):
     case_statements.append(f"WHEN {cc.gamma_name} = -1 THEN cast(1 as double)")
 
     for gamma_index, prob in enumerate(probs):
-        case_stmt = (
-            f"when {cc.gamma_name} = {gamma_index} then cast({prob:.35f} as double)"
-        )
+        if prob is not None:
+            case_stmt = (
+                f"when {cc.gamma_name} = {gamma_index} then cast({prob:.35f} as double)"
+            )
+        else:
+            case_stmt = f"when {cc.gamma_name} = {gamma_index} then null"
+
         case_statements.append(case_stmt)
 
     case_statements = "\n".join(case_statements)
diff --git a/splink/model.py b/splink/model.py
index 81878d3eb5..c8dea4661a 100644
--- a/splink/model.py
+++ b/splink/model.py
@@ -95,7 +95,10 @@ def is_converged(self):
             for gamma_index in range(c_latest.num_levels):
                 val_latest = c_latest[m_or_u][gamma_index]
                 val_previous = c_previous[m_or_u][gamma_index]
-                diff = abs(val_latest - val_previous)
+                if val_latest is not None:
+                    diff = abs(val_latest - val_previous)
+                else:
+                    diff = 0
                 diffs.append(
                     {"col_name": c_latest.name, "diff": diff, "level": gamma_index}
                 )
diff --git a/splink/settings.py b/splink/settings.py
index 47887b0ee7..6f4c8d513a 100644
--- a/splink/settings.py
+++ b/splink/settings.py
@@ -1,5 +1,5 @@
 from .default_settings import complete_settings_dict
-from .validate import _get_default_value
+from .validate import get_default_value_from_schema
 from copy import deepcopy
 from math import log2
 from .charts import load_chart_definition, altair_if_installed_else_json
@@ -18,7 +18,7 @@ def _dict_key_else_default_value(self, key):
         if key in cd:
             return cd[key]
         else:
-            return _get_default_value(key, True)
+            return get_default_value_from_schema(key, True)
 
     @property
     def custom_comparison(self):
diff --git a/splink/term_frequencies.py b/splink/term_frequencies.py
index 4863669928..2499ef5b7b 100644
--- a/splink/term_frequencies.py
+++ b/splink/term_frequencies.py
@@ -59,11 +59,17 @@ def sql_gen_generate_adjusted_lambda(column_name, model, table_name="df_e"):
     u = cc["u_probabilities"][max_level]
 
     # ensure average adj calculation doesnt divide by zero (see issue 118)
-    if math.isclose((m + u), 0.0, rel_tol=1e-9, abs_tol=0.0):
+
+    is_none = m is None or u is None
+
+    no_adjust = is_none or math.isclose((m + u), 0.0, rel_tol=1e-9, abs_tol=0.0)
+
+    if no_adjust:
         average_adjustment = 0.5
         warnings.warn(
             f"There were no comparisons in column {column_name} which were in the highest level of similarity, so no adjustment could be made"
         )
+
     else:
         average_adjustment = m / (m + u)
 
diff --git a/splink/validate.py b/splink/validate.py
index e261d36652..89ea5f299d 100644
--- a/splink/validate.py
+++ b/splink/validate.py
@@ -1,6 +1,6 @@
 import pkg_resources
 from functools import lru_cache
-
+import math
 from jsonschema import validate, ValidationError
 import json
 
@@ -39,7 +39,7 @@ def _get_schema(setting_dict_should_be_complete=False):
 
 
 @typechecked
-def validate_settings(settings_dict: dict):
+def validate_settings_against_schema(settings_dict: dict):
     """Validate a splink settings object against its jsonschema
 
     Args:
@@ -75,7 +75,7 @@ def validate_settings(settings_dict: dict):
         raise ValidationError(message)
 
 
-def _get_default_value(key, is_column_setting):
+def get_default_value_from_schema(key, is_column_setting):
     schema = _get_schema()
 
     if is_column_setting:
@@ -123,3 +123,34 @@ def validate_link_type(df_or_dfs, settings):
             "If you provide a list of dfs, link_type must be "
             "link_only or link_and_dedupe, not dedupe_only"
         )
+
+
+def validate_probabilities(settings_dict):
+    from .settings import Settings
+
+    settings_obj = Settings(settings_dict)
+
+    for cc in settings_obj.comparison_columns_list:
+
+        for mu_probabilities in ["m_probabilities", "u_probabilities"]:
+
+            if mu_probabilities in cc.column_dict:
+                if None in cc[mu_probabilities]:
+                    raise ValueError(
+                        f"Your {mu_probabilities} for {cc.name} contain None "
+                        "They should all be populated and sum to 1"
+                    )
+
+                sum_p = sum(cc[mu_probabilities])
+
+                if not math.isclose(sum_p, 1.0, rel_tol=1e-9, abs_tol=0.0):
+                    raise ValueError(
+                        f"Your {mu_probabilities} for {cc.name} do not sum to 1 "
+                        "They should all be populated and sum to 1"
+                    )
+
+                if len(cc[mu_probabilities]) != cc["num_levels"]:
+                    raise ValueError(
+                        f"Number of probs provided in {mu_probabilities} in {cc.name} "
+                        "is not equal to number of levels specified"
+                    )

From a81e33c92edcc8b0dcc7a0cbbbf7b6f80e30d056 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Sat, 13 Mar 2021 21:19:15 +0000
Subject: [PATCH 3/9] demos work

---
 splink/default_settings.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/splink/default_settings.py b/splink/default_settings.py
index 9c7d4df843..726057f987 100644
--- a/splink/default_settings.py
+++ b/splink/default_settings.py
@@ -219,5 +219,7 @@ def normalise_probabilities(settings_dict: dict):
     for col_settings in c_cols:
         for p in ["m_probabilities", "u_probabilities"]:
             if p in col_settings:
-                col_settings[p] = _normalise_prob_list(col_settings[p])
-    return settings_dict
\ No newline at end of file
+                if None not in col_settings[p]:
+                    if sum(col_settings[p]) != 0:
+                        col_settings[p] = _normalise_prob_list(col_settings[p])
+    return settings_dict

From a37d6a92a5e73f1133c9e9c725df44c59836909d Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Sun, 14 Mar 2021 16:44:40 +0000
Subject: [PATCH 4/9] lineage

---
 splink/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/splink/__init__.py b/splink/__init__.py
index 420a76382f..782da64668 100644
--- a/splink/__init__.py
+++ b/splink/__init__.py
@@ -81,6 +81,8 @@ def manually_apply_fellegi_sunter_weights(self):
         """
         df_comparison = block_using_rules(self.settings_dict, self.df, self.spark)
         df_gammas = add_gammas(df_comparison, self.settings_dict, self.spark)
+        # see https://github.com/moj-analytical-services/splink/issues/187
+        df_gammas = self.break_lineage_blocked_comparisons(df_gammas, self.spark)
         return run_expectation_step(df_gammas, self.model, self.spark)
 
     def get_scored_comparisons(self, compute_ll=False):

From 1cf098bd214705d58106ab990606e74623015532 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Sun, 14 Mar 2021 16:48:14 +0000
Subject: [PATCH 5/9] fix chart

---
 splink/settings.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/splink/settings.py b/splink/settings.py
index 6f4c8d513a..f7aa8d1d9b 100644
--- a/splink/settings.py
+++ b/splink/settings.py
@@ -150,7 +150,11 @@ def level_as_dict(self, gamma_index, proportion_of_matches=None):
             lam = proportion_of_matches
             m = d["m_probability"]
             u = d["u_probability"]
-            d["level_proportion"] = m * lam + u * (1 - lam)
+            # Check they both not None
+            if m and u:
+                d["level_proportion"] = m * lam + u * (1 - lam)
+            else:
+                d["level_proportion"] = None
 
         return d

From d50e8ce012a4a9ad67b6ca6a481eb41683c8a1fa Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 16 Mar 2021 17:05:53 +0000
Subject: [PATCH 6/9] fix tests and aggregation

---
 splink/combine_models.py        | 14 ++++++++++++--
 tests/test_combine_estimates.py |  3 ++-
 tests/test_model.py             |  2 +-
 3 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/splink/combine_models.py b/splink/combine_models.py
index 0ecaed30e0..b0e39af17d 100644
--- a/splink/combine_models.py
+++ b/splink/combine_models.py
@@ -16,8 +16,6 @@ def _apply_aggregate_function(zipped_probs, aggregate_function):
             "The aggregation function produced an error when "
             f"operating on the following data: {probs_list}. "
             "The result of this aggreation has been set to None. "
-            "You may wish to provide a aggreation function that is robust to nulls "
-            "or check why there's a None in your parameter estimates. "
             f"The error was {e}"
         )
         reduced = None
@@ -33,6 +31,13 @@ def _format_probs_for_report(probs):
     return f"{probs_as_strings}"
 
 
+def _filter_nones(list_of_lists):
+    def filter_none(sublist):
+        return [item for item in sublist if item is not None]
+
+    return [filter_none(sublist) for sublist in list_of_lists]
+
+
 def _zip_m_and_u_probabilities(cc_estimates: list):
     """Groups together the different estimates of the same parameter.
@@ -47,7 +52,12 @@
     """
     zipped_m_probs = zip(*[cc["m_probabilities"] for cc in cc_estimates])
+    zipped_m_probs = _filter_nones(zipped_m_probs)
     zipped_u_probs = zip(*[cc["u_probabilities"] for cc in cc_estimates])
+    zipped_u_probs = _filter_nones(zipped_u_probs)
+
+    print({"zipped_m": zipped_m_probs, "zipped_u": zipped_u_probs})
+
     return {"zipped_m": zipped_m_probs, "zipped_u": zipped_u_probs}
diff --git a/tests/test_combine_estimates.py b/tests/test_combine_estimates.py
index 832d288b10..a44319dfa8 100644
--- a/tests/test_combine_estimates.py
+++ b/tests/test_combine_estimates.py
@@ -190,11 +190,12 @@ def test_average_calc_m_u(spark):
         # "comparison_columns_for_global_lambda": [dob_cc],
     }
 
-    mc = ModelCombiner([dict1, dict2, dict3, dict4])
+    mc = ModelCombiner([dict4])
 
     with pytest.warns(UserWarning):
         settings_dict = mc.get_combined_settings_dict(median)
 
+    mc = ModelCombiner([dict1, dict2, dict3, dict4])
     settings = Settings(settings_dict)
     forename = settings.get_comparison_column("forename")
     actual = forename["m_probabilities"][0]
diff --git a/tests/test_model.py b/tests/test_model.py
index a53ec839c4..cd185064df 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -102,7 +102,7 @@ def test_update(model_example):
         model_example.current_settings_obj.get_comparison_column("fname")[
             "m_probabilities"
         ][0]
-        == 0
+        is None
     )
 
     model_example._populate_model_from_maximisation_step(0.2, pi_df_collected)

From 1d92c633cebcc3fc5312156da46a8326219bb720 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 16 Mar 2021 17:07:15 +0000
Subject: [PATCH 7/9] bump version, changelog

---
 CHANGELOG.md   | 6 +++++-
 pyproject.toml | 2 +-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index aaa5dd8951..8ad104bbcd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,7 +4,11 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
-## [Unreleased]
+## [1.0.5]
+
+### Changed
+
+- `m` and `u` probabilities are now reset to `None` rather than `0` in EM iteration when they cannot be estimated
 
 ## [1.0.3] - 2020-02-04
 
diff --git a/pyproject.toml b/pyproject.toml
index 69537eb436..9b381a56d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "splink"
-version = "1.0.4"
+version = "1.0.5"
 description = "Implementation in Apache Spark of the EM algorithm to estimate parameters of Fellegi-Sunter's canonical model of record linkage."
 authors = ["Robin Linacre ", "Sam Lindsay", "Theodore Manassis"]
 license = "MIT"

From e2f7a6c717e8911e79f9e8d82580ee0354bd99d0 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 16 Mar 2021 17:16:08 +0000
Subject: [PATCH 8/9] fix typo

---
 tests/test_fix_probs.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/test_fix_probs.py b/tests/test_fix_probs.py
index d0d21e6ea0..594038312a 100644
--- a/tests/test_fix_probs.py
+++ b/tests/test_fix_probs.py
@@ -52,9 +52,7 @@ def test_fix_u(spark):
     assert mob["u_probabilities"][0] == pytest.approx(0.8)
     assert mob["u_probabilities"][1] == pytest.approx(0.2)
 
-    first_name = mob = linker.model.current_settings_obj.get_comparison_column(
-        "first_name"
-    )
+    first_name = linker.model.current_settings_obj.get_comparison_column("first_name")
 
     assert first_name["u_probabilities"][0] != pytest.approx(0.8)
     assert first_name["u_probabilities"][1] != pytest.approx(0.2)

From 29e7e6fd78979e4b0aaa9cd3b8ff792ce4cd66f7 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Tue, 16 Mar 2021 17:35:59 +0000
Subject: [PATCH 9/9] fix tests

---
 tests/test_fix_probs.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/test_fix_probs.py b/tests/test_fix_probs.py
index 594038312a..94d38a60c6 100644
--- a/tests/test_fix_probs.py
+++ b/tests/test_fix_probs.py
@@ -53,8 +53,8 @@ def test_fix_u(spark):
     assert mob["u_probabilities"][1] == pytest.approx(0.2)
 
     first_name = linker.model.current_settings_obj.get_comparison_column("first_name")
-    assert first_name["u_probabilities"][0] != pytest.approx(0.8)
-    assert first_name["u_probabilities"][1] != pytest.approx(0.2)
+    assert first_name["u_probabilities"][0] != 0.8
+    assert first_name["u_probabilities"][1] != 0.2
 
     settings = {
         "link_type": "dedupe_only",
@@ -79,8 +79,8 @@ def test_fix_u(spark):
 
     # Want to check that the "u_probabilities" in the latest parameters are no longer 0.8, 0.2
     mob = linker.model.current_settings_obj.get_comparison_column("mob")
-    assert mob["u_probabilities"][0] != pytest.approx(0.8)
-    assert mob["u_probabilities"][0] != pytest.approx(0.2)
+    assert mob["u_probabilities"][0] != 0.8
+    assert mob["u_probabilities"][0] != 0.2
 
     settings = {
         "link_type": "dedupe_only",
@@ -106,8 +106,8 @@ def test_fix_u(spark):
     linker.get_scored_comparisons()
 
     mob = linker.model.current_settings_obj.get_comparison_column("mob")
-    assert mob["u_probabilities"][0] != pytest.approx(0.75)
-    assert mob["u_probabilities"][1] != pytest.approx(0.25)
+    assert mob["u_probabilities"][0] != 0.75
+    assert mob["u_probabilities"][1] != 0.25
 
     mob = linker.model.current_settings_obj.get_comparison_column("mob")
     assert mob["m_probabilities"][0] == pytest.approx(0.04)