Skip to content

Commit

Permalink
fix problem with iteration history
Browse files — browse the repository at this point in the history
  • Loading branch information
RobinL committed Dec 14, 2020
1 parent cafbd37 commit 973c6e6
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 47 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "splink"
version = "0.3.8"
version = "0.3.9"
description = "Implementation in Apache Spark of the EM algorithm to estimate parameters of Fellegi-Sunter's canonical model of record linkage."
authors = ["Robin Linacre <robinlinacre@hotmail.com>", "Sam Lindsay", "Theodore Manassis"]
license = "MIT"
Expand Down
1 change: 1 addition & 0 deletions splink/expectation_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def run_expectation_step(
df_e.createOrReplaceTempView("df_e")

params.save_params_to_iteration_history()
params.iteration += 1

return df_e

Expand Down
12 changes: 5 additions & 7 deletions splink/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, settings: dict, spark: SparkSession):

self.param_history = []

self.iteration = 1
self.iteration = 0

self.settings_original = copy.deepcopy(settings)
self.settings = complete_settings_dict(settings, spark)
Expand Down Expand Up @@ -262,14 +262,14 @@ def _iteration_history_df_gammas(self):
data = []
for it_num, param_value in enumerate(self.param_history):
data.extend(self._convert_params_dict_to_dataframe(param_value, it_num))
data.extend(self._convert_params_dict_to_dataframe(self.params, it_num + 1))

return data

def _iteration_history_df_lambdas(self):
data = []
for it_num, param_value in enumerate(self.param_history):
data.append({"λ": param_value["λ"], "iteration": it_num})
data.append({"λ": self.params["λ"], "iteration": it_num + 1})

return data

def _iteration_history_df_log_likelihood(self):
Expand All @@ -278,9 +278,7 @@ def _iteration_history_df_log_likelihood(self):
data.append(
{"log_likelihood": param_value["log_likelihood"], "iteration": it_num}
)
data.append(
{"log_likelihood": self.params["log_likelihood"], "iteration": it_num + 1}
)

return data

def _reset_param_values_to_none(self):
Expand Down Expand Up @@ -360,7 +358,7 @@ def _update_params(self, lambda_value, pi_df_collected):

self._reset_param_values_to_none()
self._populate_params(lambda_value, pi_df_collected)
self.iteration += 1


def _to_dict(self):
p_dict = {}
Expand Down
119 changes: 80 additions & 39 deletions tests/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@

# Light testing at the moment. Focus on aspects that could break main algo

@pytest.fixture(scope='module')

@pytest.fixture(scope="module")
def param_example():
gamma_settings = {
"link_type": "dedupe_only",
"proportion_of_matches": 0.2,
"comparison_columns": [
{"col_name": "fname"},
{"col_name": "sname",
"num_levels": 3}
],
"blocking_rules": []
}
"link_type": "dedupe_only",
"proportion_of_matches": 0.2,
"comparison_columns": [
{"col_name": "fname"},
{"col_name": "sname", "num_levels": 3},
],
"blocking_rules": [],
}

params = Params(gamma_settings, spark="supress_warnings")

Expand All @@ -28,58 +28,95 @@ def test_prob_sum_one(param_example):

for m in ["prob_dist_match", "prob_dist_non_match"]:
for g in ["gamma_fname", "gamma_sname"]:
levels = p["π"][g][m]
levels = p["π"][g][m]

total = 0
for l in levels:
total += levels[l]["probability"]

assert total == pytest.approx(1.0)

def test_update(param_example):

def test_update(param_example):

pi_df_collected = [
{'gamma_value': 1, 'new_probability_match': 0.9, 'new_probability_non_match': 0.1, 'gamma_col': 'gamma_fname'},
{'gamma_value': 0, 'new_probability_match': 0.2, 'new_probability_non_match': 0.8, 'gamma_col': 'gamma_fname'},
{'gamma_value': 1, 'new_probability_match': 0.9, 'new_probability_non_match': 0.1, 'gamma_col': 'gamma_sname'},
{'gamma_value': 2, 'new_probability_match': 0.7, 'new_probability_non_match': 0.3, 'gamma_col': 'gamma_sname'},
{'gamma_value': 0, 'new_probability_match': 0.5, 'new_probability_non_match': 0.5, 'gamma_col': 'gamma_sname'}]

param_example._save_params_to_iteration_history()
{
"gamma_value": 1,
"new_probability_match": 0.9,
"new_probability_non_match": 0.1,
"gamma_col": "gamma_fname",
},
{
"gamma_value": 0,
"new_probability_match": 0.2,
"new_probability_non_match": 0.8,
"gamma_col": "gamma_fname",
},
{
"gamma_value": 1,
"new_probability_match": 0.9,
"new_probability_non_match": 0.1,
"gamma_col": "gamma_sname",
},
{
"gamma_value": 2,
"new_probability_match": 0.7,
"new_probability_non_match": 0.3,
"gamma_col": "gamma_sname",
},
{
"gamma_value": 0,
"new_probability_match": 0.5,
"new_probability_non_match": 0.5,
"gamma_col": "gamma_sname",
},
]

param_example.save_params_to_iteration_history()
param_example._reset_param_values_to_none()
assert param_example.params["π"]["gamma_fname"]["prob_dist_match"]["level_0"]["probability"] is None
assert (
param_example.params["π"]["gamma_fname"]["prob_dist_match"]["level_0"][
"probability"
]
is None
)
param_example._populate_params(0.2, pi_df_collected)

new_params = param_example.params

assert new_params["π"]["gamma_fname"]["prob_dist_match"]["level_0"]["probability"] == 0.2
assert new_params["π"]["gamma_fname"]["prob_dist_non_match"]["level_0"]["probability"] == 0.8
assert (
new_params["π"]["gamma_fname"]["prob_dist_match"]["level_0"]["probability"]
== 0.2
)
assert (
new_params["π"]["gamma_fname"]["prob_dist_non_match"]["level_0"]["probability"]
== 0.8
)


def test_update_settings():

old_settings = {
"link_type": "dedupe_only",
"proportion_of_matches": 0.2,
"comparison_columns": [
{"col_name": "fname"},
{"col_name": "sname",
"num_levels": 3}
{"col_name": "sname", "num_levels": 3},
],
"blocking_rules": []
"blocking_rules": [],
}

params = Params(old_settings, spark="supress_warnings")

new_settings = {
"link_type": "dedupe_only",
"blocking_rules": [],
"comparison_columns": [
{
"col_name": "fname",
"num_levels": 3,
"m_probabilities": [0.02,0.03,0.95],
"u_probabilities": [0.92,0.05,0.03]
"m_probabilities": [0.02, 0.03, 0.95],
"u_probabilities": [0.92, 0.05, 0.03],
},
{
"custom_name": "sname",
Expand All @@ -89,17 +126,21 @@ def test_update_settings():
case when concat(fname_l, sname_l) = concat(fname_r, sname_r) then 1
else 0 end
""",
"m_probabilities": [0.01,0.02,0.97],
"u_probabilities": [0.9,0.05,0.05]
"m_probabilities": [0.01, 0.02, 0.97],
"u_probabilities": [0.9, 0.05, 0.05],
},
{"col_name": "dob"}
]
{"col_name": "dob"},
],
}

update = get_or_update_settings(params, new_settings)

# new settings used due to num_levels mismatch
assert update["comparison_columns"][0]["m_probabilities"] == new_settings["comparison_columns"][0]["m_probabilities"]
# new settings updated with old settings
assert update["comparison_columns"][1]["u_probabilities"] == pytest.approx(params.settings["comparison_columns"][1]["u_probabilities"])

assert (
update["comparison_columns"][0]["m_probabilities"]
== new_settings["comparison_columns"][0]["m_probabilities"]
)
# new settings updated with old settings
assert update["comparison_columns"][1]["u_probabilities"] == pytest.approx(
params.settings["comparison_columns"][1]["u_probabilities"]
)

0 comments on commit 973c6e6

Please sign in to comment.