Merge pull request #148 from moj-analytical-services/roc2
Make ROC faster and more granular
RobinL committed Dec 12, 2020
2 parents b4f601e + 23ed5f5 commit d354b38
Showing 2 changed files with 144 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "splink"
version = "0.3.4"
version = "0.3.5"
description = "Implementation in Apache Spark of the EM algorithm to estimate parameters of Fellegi-Sunter's canonical model of record linkage."
authors = ["Robin Linacre <robinlinacre@hotmail.com>", "Sam Lindsay", "Theodore Manassis"]
license = "MIT"
144 changes: 143 additions & 1 deletion splink/truth.py
@@ -393,7 +393,7 @@ def df_e_with_truth_categories(
return spark.sql(sql)


def truth_space_table(
def _truth_space_table_old(
df_labels_with_splink_scores: DataFrame,
spark: SparkSession,
threshold_actual: float = 0.5,
@@ -439,6 +439,148 @@ def truth_space_table(
return all_roc_df


def truth_space_table(
df_labels_with_splink_scores: DataFrame,
spark: SparkSession,
threshold_actual: float = 0.5,
score_colname: str = None,
):
"""Create a table of the ROC space i.e. truth table statistics
for each discrimination threshold
Args:
df_labels_with_splink_scores (DataFrame): A dataframe of labels and associated splink scores
usually the output of the truth.labels_with_splink_scores function
threshold_actual (float, optional): Threshold to use in categorising clerical match
scores into match or no match. Defaults to 0.5.
score_colname (float, optional): Allows user to explicitly state the column name
in the Splink dataset containing the Splink score. If none will be inferred
Returns:
DataFrame: Table of 'truth space' i.e. truth categories for each threshold level
"""

# At a truth threshold of 1.0, a splink score of 1.0 counts as a positive in ROC space, i.e. the threshold is inclusive. So if any splink scores are exactly 1.0, it is not possible to have zero positives in the truth table.
# Likewise, at a truth threshold of 0.0, a splink score of 0.0 counts as a positive, so it is possible to have zero negatives in the truth table.

# This code provides an efficient way to compute the truth space
# It's more complex than the previous code, but executes much faster because only a single SQL query/PySpark Action is needed
# The previous implementation, which is easier to understand, is [here](https://github.com/moj-analytical-services/splink/blob/b4f601e6d180c6abfd64ab40775dca3e3513c0b5/splink/truth.py#L396)

# We start with df_labels_with_splink_scores
# This is a table of each pair of clerically labelled records accompanied by the Splink match score.
# It is sorted by the splink score, low to high

# This means that for any row, if we picked a threshold equal to that row's splink score, all rows _above_ it (in the table) are categorised by splink as non-matches.

# For instance, if a row's splink score is 0.25, then any records above it in the table have a score of <0.25 and are therefore classified negative. We categorise a score of exactly 0.25 as positive.

# In addition, we can categorise any individual row as containing a false positive or false negative _at_ the truth threshold given by that row's splink score.

# This allows us to say things like: of the records above this row, we have _classified_ them all as negative, but we have _seen_ two true (clerically labelled) positives. Therefore those two must be false negatives.

# In particular, the calculations are as follows:
# True positives: the cumulative total of positive labels in records BELOW this row, INCLUSIVE (because this row is counted as positive)
# False positives: the number of records below this row, inclusive, minus the true positives

# True negatives: the cumulative total of negative labels in records above this row
# False negatives: the number of records above this row (all classified negative) minus the true negatives
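
# As a worked illustration of the above (illustrative numbers, not from the data):
# suppose five labelled pairs with splink scores [0.1, 0.3, 0.6, 0.8, 0.9] and
# clerical labels [N, N, P, P, P]. At a truth threshold of 0.6, the three rows with
# scores >= 0.6 (inclusive) are classified positive; all three are labelled P, so
# TP = 3 and FP = 0. The two rows above are classified negative and both are
# labelled N, so TN = 2 and FN = 0. At a threshold of 0.8, the row with score 0.6
# moves above the cut: it is classified negative but labelled P, giving FN = 1 and TP = 2.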

# Identify the column containing the splink score; each distinct score will serve as a truth threshold
score_colname = _get_score_colname(df_labels_with_splink_scores, score_colname)

df_labels_with_splink_scores.createOrReplaceTempView("df_labels_with_splink_scores")
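# Stage 1: tag each labelled pair with indicator columns c_P / c_N according to
# whether its clerical_match_score meets threshold_actual, sorted by splink score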
sql = f"""
select
*,
{score_colname} as truth_threshold,
case when clerical_match_score >= {threshold_actual} then 1
else 0
end
as c_P,
case when clerical_match_score >= {threshold_actual} then 0
else 1
end
as c_N
from df_labels_with_splink_scores
order by {score_colname}
"""
df_with_labels = spark.sql(sql)
df_with_labels.createOrReplaceTempView("df_with_labels")

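# Stage 2: collapse ties, so each distinct splink score (truth threshold) appears
# once, with counts of rows and of positive/negative clerical labels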
sql = """
select truth_threshold, count(*) as num_records_in_row, sum(c_P) as c_P, sum(c_N) as c_N
from
df_with_labels
group by truth_threshold
order by truth_threshold
"""
df_with_labels_grouped = spark.sql(sql)
df_with_labels_grouped.createOrReplaceTempView("df_with_labels_grouped")

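# Stage 3: window functions accumulate label counts above and below each
# threshold in a single pass over the grouped table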
sql = """
select
truth_threshold,
(sum(c_P) over (order by truth_threshold desc)) as cum_clerical_P,
(sum(c_N) over (order by truth_threshold)) - c_N as cum_clerical_N,
(select sum(c_P) from df_with_labels_grouped) as total_clerical_P,
(select sum(c_N) from df_with_labels_grouped) as total_clerical_N,
(select sum(num_records_in_row) from df_with_labels_grouped) as row_count,
-num_records_in_row + sum(num_records_in_row) over (order by truth_threshold) as N_labels,
sum(num_records_in_row) over (order by truth_threshold desc) as P_labels
from df_with_labels_grouped
order by truth_threshold
"""
df_with_cumulative_labels = spark.sql(sql)
df_with_cumulative_labels.createOrReplaceTempView("df_with_cumulative_labels")

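# Stage 4: convert the cumulative counts into the four truth categories
# (TP, FP, FN, TN) using the formulas described above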
sql = """
select
truth_threshold,
row_count,
total_clerical_P as P,
total_clerical_N as N,
P_labels - cum_clerical_P as FP,
cum_clerical_P as TP,
N_labels - cum_clerical_N as FN,
cum_clerical_N as TN
from df_with_cumulative_labels
"""
df_with_truth_cats = spark.sql(sql)
df_with_truth_cats.createOrReplaceTempView("df_with_truth_cats")

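# Stage 5: derive rates, precision and recall from the truth categories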
sql = """
select
truth_threshold,
row_count,
P,
N,
TP,
TN,
FP,
FN,
P/row_count as P_rate,
N/row_count as N_rate,
TP/P as TP_rate,
TN/N as TN_rate,
FP/N as FP_rate,
FN/P as FN_rate,
TP/(TP+FP) as precision,
TP/(TP+FN) as recall
from df_with_truth_cats
"""
df_truth_space = spark.sql(sql)

return df_truth_space


def roc_chart(
df_labels_with_splink_scores: DataFrame,
spark: SparkSession,

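For reference, a minimal sketch of how the new function might be called (an assumed usage example, not part of this commit; it presumes an existing SparkSession named spark and a dataframe of clerical labels already joined to splink scores, e.g. the output of truth.labels_with_splink_scores):

from splink.truth import truth_space_table

# df_labels_with_splink_scores: clerical labels joined to splink scores,
# e.g. the output of truth.labels_with_splink_scores (assumed to exist here)
truth_space = truth_space_table(
    df_labels_with_splink_scores,
    spark,
    threshold_actual=0.5,  # clerical scores >= 0.5 are treated as true matches
)

# One row per distinct splink score, with TP/FP/FN/TN and the derived rates
truth_space.show()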