From 21ce588c65aa19d1f477a410177739d5becaabee Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Mon, 14 Dec 2020 16:38:52 +0000
Subject: [PATCH] add tests

---
 pyproject.toml          |   2 +-
 tests/test_fix_probs.py | 135 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 136 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_fix_probs.py

diff --git a/pyproject.toml b/pyproject.toml
index 926d052f14..c2001ff0df 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "splink"
-version = "0.3.6"
+version = "0.3.7"
 description = "Implementation in Apache Spark of the EM algorithm to estimate parameters of Fellegi-Sunter's canonical model of record linkage."
 authors = ["Robin Linacre ", "Sam Lindsay", "Theodore Manassis"]
 license = "MIT"
diff --git a/tests/test_fix_probs.py b/tests/test_fix_probs.py
new file mode 100644
index 0000000000..386db051ff
--- /dev/null
+++ b/tests/test_fix_probs.py
@@ -0,0 +1,135 @@
+import pytest
+
+from pyspark.sql import Row
+
+from splink import Splink
+
+
+def test_fix_u(spark, link_dedupe_data):
+    """Check that fixed m/u probabilities are kept and unfixed ones differ.
+
+    Three Splink linkers are built on the same ten-row frame. After each
+    ``get_scored_comparisons()`` call the probabilities in ``linker.params``
+    are compared with the starting values given in the settings:
+
+    1. ``mob`` u fixed: must stay at 0.8/0.2; ``first_name`` starts from the
+       same values without being fixed and must differ from them.
+    2. ``mob`` u explicitly not fixed: must differ from 0.8/0.2.
+    3. ``mob`` m fixed, u not fixed: m must stay at 0.04/0.96 and u must
+       differ from 0.75/0.25.
+
+    NOTE(review): ``link_dedupe_data`` is requested but not used in the body;
+    it is left in the signature so the fixtures set up for this test do not
+    change.
+    """
+    # NOTE(review): unique_id 10 appears twice and 8 is skipped - confirm the
+    # repeated id is intentional.
+    records = [
+        {"unique_id": 1, "mob": "1", "first_name": "a", "surname": "a"},
+        {"unique_id": 2, "mob": "2", "first_name": "b", "surname": "b"},
+        {"unique_id": 3, "mob": "3", "first_name": "c", "surname": "c"},
+        {"unique_id": 4, "mob": "4", "first_name": "d", "surname": "d"},
+        {"unique_id": 5, "mob": "5", "first_name": "e", "surname": "e"},
+        {"unique_id": 6, "mob": "6", "first_name": "f", "surname": "f"},
+        {"unique_id": 7, "mob": "7", "first_name": "g", "surname": "g"},
+        {"unique_id": 9, "mob": "9", "first_name": "h", "surname": "h"},
+        {"unique_id": 10, "mob": "10", "first_name": "i", "surname": "i"},
+        {"unique_id": 10, "mob": "10", "first_name": "i", "surname": "i"},
+    ]
+
+    df = spark.createDataFrame(Row(**x) for x in records)
+
+    settings = {
+        "link_type": "dedupe_only",
+        "proportion_of_matches": 0.1,
+        "comparison_columns": [
+            {
+                "col_name": "mob",
+                "num_levels": 2,
+                "u_probabilities": [0.8, 0.2],
+                "fix_u_probabilities": True,
+            },
+            {
+                "col_name": "first_name",
+                "u_probabilities": [0.8, 0.2],
+            },
+            {"col_name": "surname"},
+        ],
+        "blocking_rules": [],
+        "max_iterations": 3,
+    }
+
+    linker = Splink(settings, spark, df=df)
+
+    # The returned comparisons are not inspected; the assertions below read
+    # linker.params instead.
+    linker.get_scored_comparisons()
+
+    # The fixed "u_probabilities" must still be 0.8, 0.2 in the latest parameters
+    mob = linker.params.params["π"]["gamma_mob"]["prob_dist_non_match"]
+    assert mob["level_0"]["probability"] == pytest.approx(0.8)
+    assert mob["level_1"]["probability"] == pytest.approx(0.2)
+
+    # first_name started from the same values unfixed, so it must differ from them
+    first_name = linker.params.params["π"]["gamma_first_name"]["prob_dist_non_match"]
+    assert first_name["level_0"]["probability"] != pytest.approx(0.8)
+    assert first_name["level_1"]["probability"] != pytest.approx(0.2)
+
+    settings = {
+        "link_type": "dedupe_only",
+        "proportion_of_matches": 0.1,
+        "comparison_columns": [
+            {
+                "col_name": "mob",
+                "num_levels": 2,
+                "u_probabilities": [0.8, 0.2],
+                "fix_u_probabilities": False,
+            },
+            {"col_name": "first_name"},
+            {"col_name": "surname"},
+        ],
+        "blocking_rules": [],
+        "max_iterations": 3,
+    }
+
+    linker = Splink(settings, spark, df=df)
+
+    linker.get_scored_comparisons()
+
+    # The unfixed "u_probabilities" must no longer be 0.8, 0.2
+    mob = linker.params.params["π"]["gamma_mob"]["prob_dist_non_match"]
+    assert mob["level_0"]["probability"] != pytest.approx(0.8)
+    assert mob["level_1"]["probability"] != pytest.approx(0.2)
+
+    settings = {
+        "link_type": "dedupe_only",
+        "proportion_of_matches": 0.1,
+        "comparison_columns": [
+            {
+                "col_name": "mob",
+                "num_levels": 2,
+                "m_probabilities": [0.04, 0.96],
+                "fix_m_probabilities": True,
+                "u_probabilities": [0.75, 0.25],
+                "fix_u_probabilities": False,
+            },
+            {"col_name": "first_name"},
+            {"col_name": "surname"},
+        ],
+        "blocking_rules": [],
+        "max_iterations": 3,
+    }
+
+    linker = Splink(settings, spark, df=df)
+
+    linker.get_scored_comparisons()
+
+    # u was not fixed, so it must differ from its starting values
+    mob_u = linker.params.params["π"]["gamma_mob"]["prob_dist_non_match"]
+    assert mob_u["level_0"]["probability"] != pytest.approx(0.75)
+    assert mob_u["level_1"]["probability"] != pytest.approx(0.25)
+
+    # m was fixed, so it must be unchanged
+    mob_m = linker.params.params["π"]["gamma_mob"]["prob_dist_match"]
+    assert mob_m["level_0"]["probability"] == pytest.approx(0.04)
+    assert mob_m["level_1"]["probability"] == pytest.approx(0.96)