Skip to content

Commit

Permalink
allow comparative roc
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Oct 28, 2021
1 parent e7eb996 commit 81a986c
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions splink/truth.py
Expand Up @@ -3,6 +3,8 @@
from pyspark.sql import Window
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.session import SparkSession
import pyspark
from typing import Union

altair_installed = True
try:
Expand Down Expand Up @@ -572,7 +574,7 @@ def truth_space_table(


def roc_chart(
df_labels_with_splink_scores: DataFrame,
df_labels_with_splink_scores: Union[DataFrame, dict],
spark: SparkSession,
threshold_actual: float = 0.5,
x_domain: list = None,
Expand All @@ -582,8 +584,10 @@ def roc_chart(
"""Create a ROC chart from labelled data
Args:
df_labels_with_splink_scores (DataFrame): A dataframe of labels and associated splink scores
usually the output of the truth.labels_with_splink_scores function
df_labels_with_splink_scores (Union[DataFrame, dict]): A dataframe of labels and associated splink scores,
usually the output of the truth.labels_with_splink_scores function. Alternatively, a dict mapping
a model label to one such dataframe, e.g. {'model 1': df1, 'model 2': df2}. If a dict is provided, the
ROC curves for each model will be plotted on the same figure, coloured by model label.
spark (SparkSession): SparkSession object
threshold_actual (float, optional): Threshold to use in categorising clerical match
scores into match or no match. Defaults to 0.5.
Expand All @@ -594,7 +598,6 @@ def roc_chart(
"""

roc_chart_def = {
"$schema": "https://vega.github.io/schema/vega-lite/v4.8.1.json",
"config": {"view": {"continuousWidth": 400, "continuousHeight": 300}},
"data": {"name": "data-fadd0e93e9546856cbc745a99e65285d", "values": None},
"mark": {"type": "line", "clip": True, "point": True},
Expand Down Expand Up @@ -622,6 +625,10 @@ def roc_chart(
"sort": ["truth_threshold"],
"title": "True Positive Rate amongst clerically reviewed records",
},
"color": {
"type": "nominal",
"field": "roc_label",
},
},
"selection": {
"selector076": {
Expand All @@ -635,9 +642,18 @@ def roc_chart(
"width": width,
}

data = truth_space_table(
df_labels_with_splink_scores, spark, threshold_actual=threshold_actual
).toPandas()
if type(df_labels_with_splink_scores) == pyspark.sql.DataFrame:
del roc_chart_def["encoding"]["color"]
df_labels_with_splink_scores = {"model1": df_labels_with_splink_scores}

dfs = []
for key, df in df_labels_with_splink_scores.items():
data = truth_space_table(
df, spark, threshold_actual=threshold_actual
).toPandas()
data["roc_label"] = key

dfs.append(data)

if not x_domain:

Expand All @@ -649,9 +665,12 @@ def roc_chart(

roc_chart_def["encoding"]["x"]["scale"] = {"domain": x_domain}

data = data.to_dict(orient="records")
records = []
for df in dfs:
recs = df.to_dict(orient="records")
records.extend(recs)

roc_chart_def["data"]["values"] = data
roc_chart_def["data"]["values"] = records

if altair_installed:
return alt.Chart.from_dict(roc_chart_def)
Expand Down

0 comments on commit 81a986c

Please sign in to comment.