Merge pull request #153 from moj-analytical-services/fix_iteration_history

Fix iteration history
RobinL committed Dec 14, 2020
2 parents 019cd1f + 973c6e6 commit 9a55f6a
Showing 4 changed files with 125 additions and 75 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "splink"
version = "0.3.8"
version = "0.3.9"
description = "Implementation in Apache Spark of the EM algorithm to estimate parameters of Fellegi-Sunter's canonical model of record linkage."
authors = ["Robin Linacre <robinlinacre@hotmail.com>", "Sam Lindsay", "Theodore Manassis"]
license = "MIT"
63 changes: 37 additions & 26 deletions splink/expectation_step.py
@@ -22,12 +22,15 @@
from .params import Params
from .check_types import check_types


@check_types
def run_expectation_step(df_with_gamma: DataFrame,
params: Params,
settings: dict,
spark: SparkSession,
compute_ll=False):
def run_expectation_step(
df_with_gamma: DataFrame,
params: Params,
settings: dict,
spark: SparkSession,
compute_ll=False,
):
"""Run the expectation step of the EM algorithm described in the fastlink paper:
http://imai.fas.harvard.edu/research/files/linkage.pdf
@@ -40,8 +43,7 @@ def run_expectation_step(df_with_gamma: DataFrame,
Returns:
DataFrame: Spark dataframe with a match_probability column
"""

"""

sql = _sql_gen_gamma_prob_columns(params, settings)

@@ -63,6 +65,10 @@ def run_expectation_step(df_with_gamma: DataFrame,
df_e = spark.sql(sql)

df_e.createOrReplaceTempView("df_e")

params.save_params_to_iteration_history()
params.iteration += 1

return df_e
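
For reference, the match probability that the generated SQL (see _sql_gen_expected_match_prob further down this file) assigns to each pair is the Fellegi-Sunter posterior: λ times the product of the per-column match probabilities, normalised against that same term plus (1 − λ) times the product of the non-match probabilities. A rough pure-Python sketch of the same quantity; the column names and probability values here are made up, not splink defaults:

# Illustrative only: pure-Python version of the match probability the SQL computes.
lambda_ = 0.3  # prior probability that a comparison pair is a match

# m = P(observed gamma level | match), u = P(observed gamma level | non-match)
m_probs = {"gamma_first_name": 0.9, "gamma_surname": 0.8}
u_probs = {"gamma_first_name": 0.2, "gamma_surname": 0.1}

numerator = lambda_
non_match_term = 1 - lambda_
for col in m_probs:
    numerator *= m_probs[col]
    non_match_term *= u_probs[col]

match_probability = numerator / (numerator + non_match_term)
print(match_probability)  # about 0.939 for these example values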


@@ -80,7 +86,6 @@ def _sql_gen_gamma_prob_columns(params, settings, table_name="df_with_gamma"):
case_statement = _sql_gen_gamma_case_when(gamma_str, match, params)
case_statements[alias] = case_statement


# Column order for case statement. We want orig_col_l, orig_col_r, gamma_orig_col, prob_gamma_u, prob_gamma_m
select_cols = OrderedDict()
select_cols = _add_left_right(select_cols, settings["unique_id_column_name"])
@@ -104,21 +109,24 @@ def _sql_gen_gamma_prob_columns(params, settings, table_name="df_with_gamma"):

select_cols["gamma_" + col_name] = "gamma_" + col_name

select_cols[f"prob_gamma_{col_name}_non_match"] = case_statements[f"prob_gamma_{col_name}_non_match"]
select_cols[f"prob_gamma_{col_name}_match"] = case_statements[f"prob_gamma_{col_name}_match"]
select_cols[f"prob_gamma_{col_name}_non_match"] = case_statements[
f"prob_gamma_{col_name}_non_match"
]
select_cols[f"prob_gamma_{col_name}_match"] = case_statements[
f"prob_gamma_{col_name}_match"
]

if settings["link_type"] == 'link_and_dedupe':
if settings["link_type"] == "link_and_dedupe":
select_cols = _add_left_right(select_cols, "_source_table")

for c in settings["additional_columns_to_retain"]:
select_cols = _add_left_right(select_cols, c)

if 'blocking_rules' in settings:
if "blocking_rules" in settings:
if len(settings["blocking_rules"]) > 0:
select_cols['match_key'] = 'match_key'

select_expr = ", ".join(select_cols.values())
select_cols["match_key"] = "match_key"

select_expr = ", ".join(select_cols.values())

sql = f"""
-- We use case statements for these lookups rather than joins for performance and simplicity
@@ -153,32 +161,33 @@ def _column_order_df_e_select_expr(settings, tf_adj_cols=False):
select_cols["gamma_" + col_name] = "gamma_" + col_name

if settings["retain_intermediate_calculation_columns"]:
select_cols[f"prob_gamma_{col_name}_non_match"] = f"prob_gamma_{col_name}_non_match"
select_cols[
f"prob_gamma_{col_name}_non_match"
] = f"prob_gamma_{col_name}_non_match"
select_cols[f"prob_gamma_{col_name}_match"] = f"prob_gamma_{col_name}_match"
if tf_adj_cols:
if col["term_frequency_adjustments"]:
select_cols[col_name+"_tf_adj"] = col_name+"_tf_adj"


select_cols[col_name + "_tf_adj"] = col_name + "_tf_adj"

if settings["link_type"] == 'link_and_dedupe':
if settings["link_type"] == "link_and_dedupe":
select_cols = _add_left_right(select_cols, "_source_table")

for c in settings["additional_columns_to_retain"]:
select_cols = _add_left_right(select_cols, c)

if 'blocking_rules' in settings:
if "blocking_rules" in settings:
if len(settings["blocking_rules"]) > 0:
select_cols['match_key'] = 'match_key'
select_cols["match_key"] = "match_key"
return ", ".join(select_cols.values())


def _sql_gen_expected_match_prob(params, settings, table_name="df_with_gamma_probs"):
gamma_cols = params._gamma_cols

numerator = " * ".join([f"prob_{g}_match" for g in gamma_cols])
denom_part = " * ".join([f"prob_{g}_non_match" for g in gamma_cols])

λ = params.params['λ']
λ = params.params["λ"]
castλ = f"cast({λ} as double)"
castoneminusλ = f"cast({1-λ} as double)"
match_prob_expression = f"({castλ} * {numerator})/(( {castλ} * {numerator}) + ({castoneminusλ} * {denom_part})) as match_probability"
@@ -192,6 +201,7 @@ def _sql_gen_expected_match_prob(params, settings, table_name="df_with_gamma_pro

return sql


def _case_when_col_alias(gamma_str, match):

if match == 1:
@@ -201,6 +211,7 @@ def _case_when_col_alias(gamma_str, match):

return f"prob_{gamma_str}{name_suffix}"


def _sql_gen_gamma_case_when(gamma_str, match, params):
"""
Create the case statements that look up the correct probabilities in the
@@ -217,8 +228,8 @@ def _sql_gen_gamma_case_when(gamma_str, match, params):
case_statements = []
case_statements.append(f"WHEN {gamma_str} = -1 THEN cast(1 as double)")
for level in levels.values():
case_stmt = f"when {gamma_str} = {level['value']} then cast({level['probability']:.35f} as double)"
case_statements.append(case_stmt)
case_stmt = f"when {gamma_str} = {level['value']} then cast({level['probability']:.35f} as double)"
case_statements.append(case_stmt)

case_statements = "\n".join(case_statements)

@@ -238,7 +249,7 @@ def _calculate_log_likelihood_df(df_with_gamma_probs, params, spark):

gamma_cols = params._gamma_cols

λ = params.params['λ']
λ = params.params["λ"]

match_prob = " * ".join([f"prob_{g}_match" for g in gamma_cols])
match_prob = f"({λ} * {match_prob})"
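The tail of _calculate_log_likelihood_df is collapsed in this view, so the following sketch assumes the standard observed-data likelihood of the Fellegi-Sunter mixture model, in which each pair contributes λ·∏m + (1 − λ)·∏u; the per-column probabilities below are invented example values:

import math

# Toy sketch of the log-likelihood the function evaluates in SQL (assumed form).
lambda_ = 0.3

# One dict of per-column m/u probabilities per candidate record pair.
pairs = [
    {"m": [0.9, 0.8], "u": [0.2, 0.1]},   # pair that looks like a match
    {"m": [0.1, 0.05], "u": [0.8, 0.9]},  # pair that looks like a non-match
]

log_likelihood = 0.0
for pair in pairs:
    prob_match = lambda_ * math.prod(pair["m"])
    prob_non_match = (1 - lambda_) * math.prod(pair["u"])
    log_likelihood += math.log(prob_match + prob_non_match)

print(log_likelihood)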
16 changes: 7 additions & 9 deletions splink/params.py
@@ -53,7 +53,7 @@ def __init__(self, settings: dict, spark: SparkSession):

self.param_history = []

self.iteration = 1
self.iteration = 0

self.settings_original = copy.deepcopy(settings)
self.settings = complete_settings_dict(settings, spark)
@@ -262,14 +262,14 @@ def _iteration_history_df_gammas(self):
data = []
for it_num, param_value in enumerate(self.param_history):
data.extend(self._convert_params_dict_to_dataframe(param_value, it_num))
data.extend(self._convert_params_dict_to_dataframe(self.params, it_num + 1))

return data

def _iteration_history_df_lambdas(self):
data = []
for it_num, param_value in enumerate(self.param_history):
data.append({"λ": param_value["λ"], "iteration": it_num})
data.append({"λ": self.params["λ"], "iteration": it_num + 1})

return data

def _iteration_history_df_log_likelihood(self):
@@ -278,9 +278,7 @@ def _iteration_history_df_log_likelihood(self):
data.append(
{"log_likelihood": param_value["log_likelihood"], "iteration": it_num}
)
data.append(
{"log_likelihood": self.params["log_likelihood"], "iteration": it_num + 1}
)

return data

def _reset_param_values_to_none(self):
Expand All @@ -297,7 +295,7 @@ def _reset_param_values_to_none(self):
].values():
level_value["probability"] = None

def _save_params_to_iteration_history(self):
def save_params_to_iteration_history(self):
"""
Take current params and
"""
@@ -357,10 +355,10 @@ def _update_params(self, lambda_value, pi_df_collected):
Reset values
Then update the parameters from the dataframe
"""
self._save_params_to_iteration_history()

self._reset_param_values_to_none()
self._populate_params(lambda_value, pi_df_collected)
self.iteration += 1


def _to_dict(self):
p_dict = {}
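The net effect of the params.py changes above is that the parameter snapshot is now taken by the expectation step itself and the counter starts at 0, so each entry of param_history corresponds to the parameter values actually used in that iteration. A toy illustration of the ordering, using a hypothetical stand-in class rather than splink's Params:

import copy

# Hypothetical class for illustration only: each expectation step snapshots the
# parameters it used and then bumps the counter, so param_history[i] holds the
# parameters of iteration i.
class ToyParams:
    def __init__(self, lam):
        self.params = {"λ": lam}
        self.param_history = []
        self.iteration = 0  # starts at 0 after this commit, previously 1

    def save_params_to_iteration_history(self):
        self.param_history.append(copy.deepcopy(self.params))

    def update(self, new_lambda):  # stand-in for the maximisation step
        self.params["λ"] = new_lambda

p = ToyParams(0.3)
for new_lambda in [0.42, 0.55, 0.61]:  # pretend maximisation outputs
    # expectation step: record the parameters used, then advance the counter
    p.save_params_to_iteration_history()
    p.iteration += 1
    # maximisation step overwrites the current parameters
    p.update(new_lambda)

print([d["λ"] for d in p.param_history])  # [0.3, 0.42, 0.55]
print(p.iteration)                        # 3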
