Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Dec 14, 2020
1 parent 63042ee commit 21ce588
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "splink"
version = "0.3.6"
version = "0.3.7"
description = "Implementation in Apache Spark of the EM algorithm to estimate parameters of Fellegi-Sunter's canonical model of record linkage."
authors = ["Robin Linacre <robinlinacre@hotmail.com>", "Sam Lindsay", "Theodore Manassis"]
license = "MIT"
Expand Down
119 changes: 119 additions & 0 deletions tests/test_fix_probs.py
@@ -0,0 +1,119 @@
import pytest

from pyspark.sql import Row

from splink import Splink


def test_fix_u(spark, link_dedupe_data):
settings = {
"link_type": "link_only",
"comparison_columns": [{"col_name": "first_name"}, {"col_name": "surname"}],
"blocking_rules": [],
}

# We expect u on the cartesian product of MoB to be around
df = [
{"unique_id": 1, "mob": "1", "first_name": "a", "surname": "a"},
{"unique_id": 2, "mob": "2", "first_name": "b", "surname": "b"},
{"unique_id": 3, "mob": "3", "first_name": "c", "surname": "c"},
{"unique_id": 4, "mob": "4", "first_name": "d", "surname": "d"},
{"unique_id": 5, "mob": "5", "first_name": "e", "surname": "e"},
{"unique_id": 6, "mob": "6", "first_name": "f", "surname": "f"},
{"unique_id": 7, "mob": "7", "first_name": "g", "surname": "g"},
{"unique_id": 9, "mob": "9", "first_name": "h", "surname": "h"},
{"unique_id": 10, "mob": "10", "first_name": "i", "surname": "i"},
{"unique_id": 10, "mob": "10", "first_name": "i", "surname": "i"},
]

df = spark.createDataFrame(Row(**x) for x in df)

settings = {
"link_type": "dedupe_only",
"proportion_of_matches": 0.1,
"comparison_columns": [
{
"col_name": "mob",
"num_levels": 2,
"u_probabilities": [0.8, 0.2],
"fix_u_probabilities": True,
},
{
"col_name": "first_name",
"u_probabilities": [0.8, 0.2],
},
{"col_name": "surname"},
],
"blocking_rules": [],
"max_iterations": 3,
}

linker = Splink(settings, spark, df=df)

df_e = linker.get_scored_comparisons()

# Want to check that the "u_probabilities" in the latest parameters are still 0.8, 0.2
mob = linker.params.params["π"]["gamma_mob"]["prob_dist_non_match"]
assert mob["level_0"]["probability"] == pytest.approx(0.8)
assert mob["level_1"]["probability"] == pytest.approx(0.2)

first_name = linker.params.params["π"]["gamma_first_name"]["prob_dist_non_match"]
assert first_name["level_0"]["probability"] != pytest.approx(0.8)
assert first_name["level_1"]["probability"] != pytest.approx(0.2)

settings = {
"link_type": "dedupe_only",
"proportion_of_matches": 0.1,
"comparison_columns": [
{
"col_name": "mob",
"num_levels": 2,
"u_probabilities": [0.8, 0.2],
"fix_u_probabilities": False,
},
{"col_name": "first_name"},
{"col_name": "surname"},
],
"blocking_rules": [],
"max_iterations": 3,
}

linker = Splink(settings, spark, df=df)

df_e = linker.get_scored_comparisons()

# Want to check that the "u_probabilities" in the latest parameters are no longer 0.8, 0.2
mob = linker.params.params["π"]["gamma_mob"]["prob_dist_non_match"]
assert mob["level_0"]["probability"] != pytest.approx(0.8)
assert mob["level_1"]["probability"] != pytest.approx(0.2)

settings = {
"link_type": "dedupe_only",
"proportion_of_matches": 0.1,
"comparison_columns": [
{
"col_name": "mob",
"num_levels": 2,
"m_probabilities": [0.04, 0.96],
"fix_m_probabilities": True,
"u_probabilities": [0.75, 0.25],
"fix_u_probabilities": False,
},
{"col_name": "first_name"},
{"col_name": "surname"},
],
"blocking_rules": [],
"max_iterations": 3,
}

linker = Splink(settings, spark, df=df)

df_e = linker.get_scored_comparisons()

mob = linker.params.params["π"]["gamma_mob"]["prob_dist_non_match"]
assert mob["level_0"]["probability"] != pytest.approx(0.75)
assert mob["level_1"]["probability"] != pytest.approx(0.25)

mob = linker.params.params["π"]["gamma_mob"]["prob_dist_match"]
assert mob["level_0"]["probability"] == pytest.approx(0.04)
assert mob["level_1"]["probability"] == pytest.approx(0.96)

0 comments on commit 21ce588

Please sign in to comment.