Commit 26b164e

Merge 30b3fce into b264206

RobinL committed Mar 22, 2021
2 parents b264206 + 30b3fce

Showing 14 changed files with 188 additions and 128 deletions.
13 changes: 11 additions & 2 deletions CHANGELOG.md
@@ -4,11 +4,20 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
## [1.0.5]

### Fixed

- Bug that meant default numerical case statements were not available. See [here](https://github.com/moj-analytical-services/splink/issues/189). Thanks to [geobetts](https://github.com/geobetts)

### Changed

- `m` and `u` probabilities are now reset to `None` rather than `0` in EM iteration when they cannot be estimated
- Now use `_repr_pretty_` rather than `__repr__` so that objects display nicely in Jupyter Lab; `__repr__` had been interfering with the interpretation of stack trace errors

## [1.0.3] - 2020-02-04

### Fixed


- Bug whereby Splink lowercased case expressions, see [here](https://github.com/moj-analytical-services/splink/issues/174)
## [1.0.2] - 2020-02-02
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "splink"
version = "1.0.4"
version = "1.0.5"
description = "Implementation in Apache Spark of the EM algorithm to estimate parameters of Fellegi-Sunter's canonical model of record linkage."
authors = ["Robin Linacre <robinlinacre@hotmail.com>", "Sam Lindsay", "Theodore Manassis"]
license = "MIT"
13 changes: 10 additions & 3 deletions splink/__init__.py
@@ -4,9 +4,10 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.session import SparkSession
from splink.validate import (
validate_settings,
validate_settings_against_schema,
validate_input_datasets,
validate_link_type,
validate_probabilities,
)
from splink.model import Model, load_model_from_json
from splink.case_statements import _check_jaro_registered
@@ -23,6 +24,8 @@
default_break_lineage_scored_comparisons,
)

from splink.default_settings import normalise_probabilities


@typechecked
class Splink:
@@ -53,11 +56,13 @@ def __init__(
self.break_lineage_scored_comparisons = break_lineage_scored_comparisons
_check_jaro_registered(spark)

validate_settings(settings)
validate_settings_against_schema(settings)
validate_link_type(df_or_dfs, settings)

self.model = Model(settings, spark)
self.settings_dict = self.model.current_settings_obj.settings_dict

self.settings_dict = normalise_probabilities(self.settings_dict)
validate_probabilities(self.settings_dict)
# dfs is a list of dfs irrespective of whether input was a df or list of dfs
if type(df_or_dfs) == DataFrame:
dfs = [df_or_dfs]
@@ -76,6 +81,8 @@ def manually_apply_fellegi_sunter_weights(self):
"""
df_comparison = block_using_rules(self.settings_dict, self.df, self.spark)
df_gammas = add_gammas(df_comparison, self.settings_dict, self.spark)
# see https://github.com/moj-analytical-services/splink/issues/187
df_gammas = self.break_lineage_blocked_comparisons(df_gammas, self.spark)
return run_expectation_step(df_gammas, self.model, self.spark)

def get_scored_comparisons(self, compute_ll=False):
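
The extra lineage-break after `add_gammas` (issue 187, referenced in the code comment above) stops Spark from replaying a long transformation plan in the expectation step. A minimal sketch of what such a lineage-breaking hook could look like, assuming it simply caches and materialises the DataFrame; the function name is illustrative and this is not the actual splink default:

```python
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.session import SparkSession


def break_lineage_by_caching(df: DataFrame, spark: SparkSession) -> DataFrame:
    # Cache and force evaluation so later steps read the materialised result
    # rather than recomputing the full lineage of blocked comparisons.
    # The unused `spark` argument mirrors the (df, spark) signature splink
    # passes to these hooks.
    df = df.persist()
    df.count()
    return df
```
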
20 changes: 15 additions & 5 deletions splink/combine_models.py
@@ -16,8 +16,6 @@ def _apply_aggregate_function(zipped_probs, aggregate_function):
"The aggregation function produced an error when "
f"operating on the following data: {probs_list}. "
"The result of this aggregation has been set to None. "
"You may wish to provide an aggregation function that is robust to nulls "
"or check why there's a None in your parameter estimates. "
f"The error was {e}"
)
reduced = None
@@ -33,6 +31,13 @@ def _format_probs_for_report(probs):
return f"{probs_as_strings}"


def _filter_nones(list_of_lists):
def filter_none(sublist):
return [item for item in sublist if item is not None]

return [filter_none(sublist) for sublist in list_of_lists]


def _zip_m_and_u_probabilities(cc_estimates: list):
"""Groups together the different estimates of the same parameter.
@@ -47,7 +52,10 @@ def _zip_m_and_u_probabilities(cc_estimates: list):
"""

zipped_m_probs = zip(*[cc["m_probabilities"] for cc in cc_estimates])
zipped_m_probs = _filter_nones(zipped_m_probs)
zipped_u_probs = zip(*[cc["u_probabilities"] for cc in cc_estimates])
zipped_u_probs = _filter_nones(zipped_u_probs)

return {"zipped_m": zipped_m_probs, "zipped_u": zipped_u_probs}
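
As an illustration of what the zip-then-filter above produces, using invented estimates for a single two-level comparison column:

```python
cc_estimates = [
    {"m_probabilities": [0.2, 0.8], "u_probabilities": [0.9, 0.1]},
    {"m_probabilities": [None, 0.7], "u_probabilities": [0.8, 0.2]},
]

zipped = _zip_m_and_u_probabilities(cc_estimates)
# Each inner list gathers the estimates for one gamma level across models,
# with None estimates dropped by _filter_nones:
#   zipped["zipped_m"] -> [[0.2], [0.8, 0.7]]
#   zipped["zipped_u"] -> [[0.9, 0.8], [0.1, 0.2]]
```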


@@ -238,7 +246,9 @@ def comparison_chart(self):

return altair_if_installed_else_json(chart_def)

def __repr__(self):
return self.summary_report(
summary_name="harmonic_mean", aggregate_function=harmonic_mean
def _repr_pretty_(self, p, cycle):
p.text(
self.summary_report(
summary_name="harmonic_mean", aggregate_function=harmonic_mean
)
)
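
This class, and `Model` further down, switch from `__repr__` to IPython's `_repr_pretty_` hook, which Jupyter calls with a pretty-printer object when rendering a cell result. A minimal sketch of the protocol, with an invented class name:

```python
class Example:
    def _repr_pretty_(self, p, cycle):
        # Jupyter/IPython supplies the pretty-printer `p`; `cycle` is True if
        # the object is reached again while it is already being printed.
        if cycle:
            p.text("Example(...)")
        else:
            p.text("Example: summary shown in notebook output")
```
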
107 changes: 41 additions & 66 deletions splink/default_settings.py
@@ -4,7 +4,8 @@

from copy import deepcopy

from .validate import validate_settings, _get_default_value
from .validate import get_default_value_from_schema

from .case_statements import (
_check_jaro_registered,
sql_gen_case_smnt_strict_equality_2,
@@ -66,26 +67,28 @@ def _get_default_case_statement_fn(default_statements, data_type, levels):

def _get_default_probabilities(m_or_u, levels):

if levels > 5:
if levels > 6:
raise ValueError(
f"No default m and u probabilities available when levels > 4, "
f"No default m and u probabilities available when levels > 6, "
"please specify custom values for 'm_probabilities' and 'u_probabilities' "
"within your settings dictionary"
)

# Note all m and u probabilities are automatically normalised to sum to 1
default_m_u_probabilities = {
"m_probabilities": {
2: [1, 9],
3: [1, 2, 7],
4: [1, 1, 1, 7],
5: [0.33, 0.67, 1, 2, 6],
2: _normalise_prob_list([1, 9]),
3: _normalise_prob_list([1, 2, 7]),
4: _normalise_prob_list([1, 1, 1, 7]),
5: _normalise_prob_list([0.33, 0.67, 1, 2, 6]),
6: _normalise_prob_list([0.33, 0.67, 1, 2, 3, 6]),
},
"u_probabilities": {
2: [9, 1],
3: [7, 2, 1],
4: [7, 1, 1, 1],
5: [6, 2, 1, 0.33, 0.67],
2: _normalise_prob_list([9, 1]),
3: _normalise_prob_list([7, 2, 1]),
4: _normalise_prob_list([7, 1, 1, 1]),
5: _normalise_prob_list([6, 2, 1, 0.33, 0.67]),
6: _normalise_prob_list([6, 3, 2, 1, 0.33, 0.67]),
},
}
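
`_normalise_prob_list` itself is not shown in this diff; assuming it simply rescales a list so its entries sum to one, a minimal sketch would be:

```python
def _normalise_prob_list(prob_list: list):
    # Divide each value by the total so the list sums to 1
    # (assumes a non-zero total).
    total = sum(prob_list)
    return [p / total for p in prob_list]
```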

@@ -120,48 +123,19 @@ def _complete_case_expression(col_settings, spark):
col_settings["case_expression"] = new_case_stmt


def _complete_probabilities(col_settings: dict, setting_name: str):
def _complete_probabilities(col_settings: dict, mu_probabilities: str):
"""
Args:
col_settings (dict): Column settings dictionary
setting_name (str): Either 'm_probabilities' or 'u_probabilities'
mu_probabilities (str): Either 'm_probabilities' or 'u_probabilities'
"""

if setting_name not in col_settings:
levels = col_settings["num_levels"]
probs = _get_default_probabilities(setting_name, levels)
col_settings[setting_name] = probs
else:
if mu_probabilities not in col_settings:
levels = col_settings["num_levels"]
probs = col_settings[setting_name]

# Check for m and u manually set to zero (https://github.com/moj-analytical-services/splink/issues/161)
if not all(col_settings[setting_name]):
if "custom_name" in col_settings:
col_name = col_settings["custom_name"]
else:
col_name = col_settings["col_name"]

if setting_name == "m_probabilities":
letter = "m"
elif setting_name == "u_probabilities":
letter = "u"

warnings.warn(
f"Your {setting_name} for {col_name} include zeroes. "
f"Where {letter}=0 for a given level, it remains fixed rather than being estimated "
"along with other model parameters, and all comparisons at this level "
f"are assigned a match score of {1. if letter=='u' else 0.}, regardless of other comparisons columns."
)

if len(probs) != levels:
raise ValueError(
f"Number of {setting_name} provided is not equal to number of levels specified"
)

col_settings[setting_name] = col_settings[setting_name]
probs = _get_default_probabilities(mu_probabilities, levels)
col_settings[mu_probabilities] = probs


def complete_settings_dict(settings_dict: dict, spark: SparkSession):
Expand All @@ -176,7 +150,6 @@ def complete_settings_dict(settings_dict: dict, spark: SparkSession):
dict: A `splink` settings dictionary with all keys populated.
"""
settings_dict = deepcopy(settings_dict)
validate_settings(settings_dict)

# Complete non-column settings from their default values if not exist
non_col_keys = [
@@ -192,7 +165,9 @@
]
for key in non_col_keys:
if key not in settings_dict:
settings_dict[key] = _get_default_value(key, is_column_setting=False)
settings_dict[key] = get_default_value_from_schema(
key, is_column_setting=False
)

if "blocking_rules" in settings_dict:
if len(settings_dict["blocking_rules"]) == 0:
@@ -206,6 +181,8 @@ def complete_settings_dict(settings_dict: dict, spark: SparkSession):
c_cols = settings_dict["comparison_columns"]
for gamma_index, col_settings in enumerate(c_cols):

# Gamma index refers to the position in the comparison vector
# i.e. it's a counter for comparison columns
col_settings["gamma_index"] = gamma_index

# Populate non-existing keys from defaults
@@ -219,32 +196,30 @@

for key in keys_for_defaults:
if key not in col_settings:
default = _get_default_value(key, is_column_setting=True)
default = get_default_value_from_schema(key, is_column_setting=True)
col_settings[key] = default

# Doesn't need assignment because we're modifying the col_settings dictionary
_complete_case_expression(col_settings, spark)
_complete_probabilities(col_settings, "m_probabilities")
_complete_probabilities(col_settings, "u_probabilities")

if None not in col_settings["m_probabilities"]:
col_settings["m_probabilities"] = _normalise_prob_list(
col_settings["m_probabilities"]
)
else:
warnings.warn(
"Your m probabilities contain a None value "
"so could not be normalised to 1"
)
return settings_dict

if None not in col_settings["u_probabilities"]:
col_settings["u_probabilities"] = _normalise_prob_list(
col_settings["u_probabilities"]
)
else:
warnings.warn(
"Your u probabilities contain a None value "
"so could not be normalised to 1"
)

def normalise_probabilities(settings_dict: dict):
"""Normalise all probabilities in a settings dictionary to sum
to one, if possible
Args:
settings_dict (dict): Splink settings dictionary
"""

c_cols = settings_dict["comparison_columns"]
for col_settings in c_cols:
for p in ["m_probabilities", "u_probabilities"]:
if p in col_settings:
if None not in col_settings[p]:
if sum(col_settings[p]) != 0:
col_settings[p] = _normalise_prob_list(col_settings[p])
return settings_dict
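
A short usage sketch of `normalise_probabilities` with an invented settings fragment; columns containing a `None` estimate, or whose probabilities sum to zero, are left untouched:

```python
settings = {
    "comparison_columns": [
        {"col_name": "surname", "m_probabilities": [1, 9], "u_probabilities": [9, 1]},
        {"col_name": "dob", "m_probabilities": [None, 0.9]},  # skipped: contains None
    ]
}

settings = normalise_probabilities(settings)
# settings["comparison_columns"][0]["m_probabilities"] -> [0.1, 0.9]
# settings["comparison_columns"][0]["u_probabilities"] -> [0.9, 0.1]
# settings["comparison_columns"][1]["m_probabilities"] -> [None, 0.9]  (unchanged)
```
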
5 changes: 3 additions & 2 deletions splink/estimate.py
@@ -1,4 +1,5 @@
from copy import deepcopy
from splink.default_settings import normalise_probabilities, _normalise_prob_list

from .blocking import block_using_rules
from .gammas import add_gammas
@@ -109,9 +110,9 @@ def estimate_u_values(
u_probs = new_settings["comparison_columns"][i]["u_probabilities"]
# Ensure non-zero u (https://github.com/moj-analytical-services/splink/issues/161)
u_probs = [u or 1 / target_rows for u in u_probs]

u_probs = _normalise_prob_list(u_probs)
col["u_probabilities"] = u_probs
if fix_u_probabilities:
col["fix_u_probabilities"] = True

return orig_settings
return orig_settings
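
The `u or 1 / target_rows` expression replaces both `0` and `None` estimates with a small floor before renormalising (see issue 161 referenced in the code comment). An illustrative run with invented numbers:

```python
target_rows = 1e6
u_probs = [0.0, None, 0.75]

# Both the zero and the None estimate become 1 / target_rows
u_probs = [u or 1 / target_rows for u in u_probs]  # [1e-06, 1e-06, 0.75]
u_probs = _normalise_prob_list(u_probs)            # rescaled to sum to 1
```
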
10 changes: 7 additions & 3 deletions splink/expectation_step.py
@@ -213,9 +213,13 @@ def _sql_gen_gamma_case_when(comparison_column, match):
case_statements.append(f"WHEN {cc.gamma_name} = -1 THEN cast(1 as double)")

for gamma_index, prob in enumerate(probs):
case_stmt = (
f"when {cc.gamma_name} = {gamma_index} then cast({prob:.35f} as double)"
)
if prob is not None:
case_stmt = (
f"when {cc.gamma_name} = {gamma_index} then cast({prob:.35f} as double)"
)
else:
case_stmt = f"when {cc.gamma_name} = {gamma_index} then null"

case_statements.append(case_stmt)

case_statements = "\n".join(case_statements)
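
With estimates now allowed to be `None`, formatting them with `:.35f` would raise a `TypeError`, so the generator emits a SQL `null` for that level instead. A standalone sketch of the loop's output for invented probabilities on a three-level comparison:

```python
probs = [0.05, None, 0.9]  # invented m probabilities
gamma_name = "gamma_surname"

case_statements = [f"WHEN {gamma_name} = -1 THEN cast(1 as double)"]
for gamma_index, prob in enumerate(probs):
    if prob is not None:
        case_statements.append(
            f"when {gamma_name} = {gamma_index} then cast({prob:.35f} as double)"
        )
    else:
        # f"{None:.35f}" is not valid, so emit a SQL null for this level
        case_statements.append(f"when {gamma_name} = {gamma_index} then null")

print("\n".join(case_statements))
```
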
11 changes: 7 additions & 4 deletions splink/model.py
@@ -95,7 +95,10 @@ def is_converged(self):
for gamma_index in range(c_latest.num_levels):
val_latest = c_latest[m_or_u][gamma_index]
val_previous = c_previous[m_or_u][gamma_index]
diff = abs(val_latest - val_previous)
if val_latest is not None:
diff = abs(val_latest - val_previous)
else:
diff = 0
diffs.append(
{"col_name": c_latest.name, "diff": diff, "level": gamma_index}
)
@@ -273,9 +276,9 @@ def all_charts_write_html_file(
with open(filename, "w") as f:
f.write(template.format(**fmt_dict))

def __repr__(self): # pragma: no cover
p = self.current_settings_obj
return p.__repr__()
def _repr_pretty_(self, p, cycle): # pragma: no cover

return p.pretty(self.current_settings_obj)


def load_model_from_json(path):
