From e04e41103863b2fdaaadeb929c5c77b13d5446cf Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Fri, 11 Dec 2020 14:18:47 +0000 Subject: [PATCH 1/4] faster and more granular truth space --- splink/truth.py | 147 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 145 insertions(+), 2 deletions(-) diff --git a/splink/truth.py b/splink/truth.py index 36126b595e..7ae435ff72 100644 --- a/splink/truth.py +++ b/splink/truth.py @@ -393,7 +393,7 @@ def df_e_with_truth_categories( return spark.sql(sql) -def truth_space_table( +def truth_space_table_old( df_labels_with_splink_scores: DataFrame, spark: SparkSession, threshold_actual: float = 0.5, @@ -439,6 +439,148 @@ def truth_space_table( return all_roc_df +def truth_space_table( + df_labels_with_splink_scores: DataFrame, + spark: SparkSession, + threshold_actual: float = 0.5, + score_colname: str = None, +): + """Create a table of the ROC space i.e. truth table statistics + for each discrimination threshold + + Args: + df_labels_with_splink_scores (DataFrame): A dataframe of labels and associated splink scores + usually the output of the truth.labels_with_splink_scores function + threshold_actual (float, optional): Threshold to use in categorising clerical match + scores into match or no match. Defaults to 0.5. + score_colname (str, optional): Allows user to explicitly state the column name + in the Splink dataset containing the Splink score. If None, it will be inferred + + Returns: + DataFrame: Table of 'truth space' i.e. truth categories for each threshold level + """ + + # At a truth threshold of 1.0, we say a splink score of 1.0 is a positive in ROC space. i.e. it's inclusive, so if there are splink scores of exactly 1.0 it's not possible to have zero positives in the truth table. + # This means that at a truth threshold of 0.0 we say a splink score of 0.0 is positive, so it's possible to have zero negatives in the truth table.
+ + # This code provides an efficient way to compute the truth space + # It's more complex than the previous code, but executes much faster because only a single SQL query/PySpark Action is needed + # The previous implementation, which is easier to understand, is [here](https://github.com/moj-analytical-services/splink/blob/b4f601e6d180c6abfd64ab40775dca3e3513c0b5/splink/truth.py#L396) + + # We start with df_labels_with_splink_scores + # This is a table of each pair of clerically labelled records accompanied by the Splink match score. + # It is sorted in order of Splink score, low to high + + # This means for any row, if we picked a threshold equal to the row's Splink score, all rows _above_ (in the table) are categorised by splink as non-matches. + + # For instance, if a row's Splink score is 0.25, then any records above this in the table have a score of <0.25, and are therefore negative. We categorise a score of exactly 0.25 as positive. + + # In addition, we can categorise any individual row as containing a false positive or false negative _at_ the Splink score for the row. + + # This allows us to say things like: Of the records above this row, we have _classified_ them all as negative, but we have _seen_ two true (clerically labelled) positives. Therefore these must be false negatives.
+ + # In particular, the calculations are as follows: + # True positives: The cumulative total of positive labels in records BELOW this row, INCLUSIVE (because this one is being counted as positive) + # False positives: The number of records at or below this row, minus true positives + + # False negatives: The number of records above this row, minus true negatives + # True negatives: The cumulative total of negative labels in records above this row + + # We want percentiles of score to compute + score_colname = _get_score_colname(df_labels_with_splink_scores, score_colname) + + df_labels_with_splink_scores.createOrReplaceTempView("df_labels_with_splink_scores") + sql = f""" + select + *, + {score_colname} as truth_threshold, + case when clerical_match_score > {threshold_actual} then 1 + else 0 + end + as c_P, + case when clerical_match_score > {threshold_actual} then 0 + else 1 + end + as c_N + from df_labels_with_splink_scores + order by {score_colname} + """ + df_with_labels = spark.sql(sql) + df_with_labels.createOrReplaceTempView("df_with_labels") + + sql = """ + select truth_threshold, count(*) as num_records_in_row, sum(c_P) as c_P, sum(c_N) as c_N + from + df_with_labels + group by truth_threshold + order by truth_threshold + """ + df_with_labels_grouped = spark.sql(sql) + df_with_labels_grouped.createOrReplaceTempView("df_with_labels_grouped") + + sql = """ + select + truth_threshold, + + (sum(c_P) over (order by truth_threshold desc)) as cum_clerical_P, + (sum(c_N) over (order by truth_threshold)) - c_N as cum_clerical_N, + + (select sum(c_P) from df_with_labels_grouped) as total_clerical_P, + (select sum(c_N) from df_with_labels_grouped) as total_clerical_N, + (select sum(num_records_in_row) from df_with_labels_grouped) as row_count, + + -num_records_in_row + sum(num_records_in_row) over (order by truth_threshold) as N_labels, + sum(num_records_in_row) over (order by truth_threshold desc) as P_labels + from df_with_labels_grouped + order by truth_threshold +
""" + df_with_cumulative_labels = spark.sql(sql) + df_with_cumulative_labels.createOrReplaceTempView("df_with_cumulative_labels") + + sql = """ + select + truth_threshold, + row_count, + total_clerical_P as P, + total_clerical_N as N, + + P_labels - cum_clerical_P as FP, + cum_clerical_P as TP, + + N_labels - cum_clerical_N as FN, + cum_clerical_N as TN + + from df_with_cumulative_labels + """ + df_with_truth_cats = spark.sql(sql) + df_with_truth_cats.createOrReplaceTempView("df_with_truth_cats") + df_with_truth_cats.toPandas() + + sql = """ + select + truth_threshold, + row_count, + P, + N, + TP, + TN, + FP, + FN, + P/row_count as P_rate, + N/row_count as N_rate, + TP/P as TP_rate, + TN/N as TN_rate, + FP/N as FP_rate, + FN/P as FN_rate, + TP/(TP+FP) as precision, + TP/(TP+FN) as recall + from df_with_truth_cats + """ + df_truth_space = spark.sql(sql) + + return df_truth_space + + def roc_chart( df_labels_with_splink_scores: DataFrame, spark: SparkSession, @@ -446,6 +588,7 @@ def roc_chart( x_domain: list = None, width: int = 400, height: int = 400, + truth_space_fn=truth_space_table, ): """Create a ROC chart from labelled data @@ -496,7 +639,7 @@ def roc_chart( "width": width, } - data = truth_space_table( + data = truth_space_fn( df_labels_with_splink_scores, spark, threshold_actual=threshold_actual ).toPandas() From d826695f9f1453bda0a96e82291b8a51f820774d Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Fri, 11 Dec 2020 16:33:14 +0000 Subject: [PATCH 2/4] ensure results are consistent with previous implementation --- splink/truth.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/splink/truth.py b/splink/truth.py index 7ae435ff72..432af6413e 100644 --- a/splink/truth.py +++ b/splink/truth.py @@ -494,11 +494,11 @@ def truth_space_table( select *, {score_colname} as truth_threshold, - case when clerical_match_score > {threshold_actual} then 1 + case when clerical_match_score >= {threshold_actual} then 1 else 0 end as c_P, - case when 
clerical_match_score > {threshold_actual} then 0 + case when clerical_match_score >= {threshold_actual} then 0 else 1 end as c_N @@ -588,7 +588,6 @@ def roc_chart( x_domain: list = None, width: int = 400, height: int = 400, - truth_space_fn=truth_space_table, ): """Create a ROC chart from labelled data @@ -639,7 +638,7 @@ def roc_chart( "width": width, } - data = truth_space_fn( + data = truth_space_table( df_labels_with_splink_scores, spark, threshold_actual=threshold_actual ).toPandas() From 27f0bd9f1cbe78d3cec47c8db81450ea2279590e Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Fri, 11 Dec 2020 16:59:28 +0000 Subject: [PATCH 3/4] private function for old implementation --- splink/truth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splink/truth.py b/splink/truth.py index 432af6413e..2ee4a2330c 100644 --- a/splink/truth.py +++ b/splink/truth.py @@ -393,7 +393,7 @@ def df_e_with_truth_categories( return spark.sql(sql) -def truth_space_table_old( +def _truth_space_table_old( df_labels_with_splink_scores: DataFrame, spark: SparkSession, threshold_actual: float = 0.5, From 23ed5f5900f17a1a6bd41abfe58a56dfdba328cc Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Fri, 11 Dec 2020 19:47:19 +0000 Subject: [PATCH 4/4] bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index da0a817601..b2d57af634 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "splink" -version = "0.3.4" +version = "0.3.5" description = "Implementation in Apache Spark of the EM algorithm to estimate parameters of Fellegi-Sunter's canonical model of record linkage." authors = ["Robin Linacre ", "Sam Lindsay", "Theodore Manassis"] license = "MIT"