
Commit 4116314

Merge pull request #239 from moj-analytical-services/improve_histogram
Allow match weight to be used in the diagnostic histogram
RobinL committed Nov 8, 2021
2 parents 7607520 + 1b4112a commit 4116314
Showing 1 changed file with 37 additions and 15 deletions.
52 changes: 37 additions & 15 deletions splink/diagnostics.py
@@ -9,12 +9,20 @@
 from .charts import load_chart_definition, altair_if_installed_else_json
 
 
+def _equal_spaced_buckets(num_buckets, extent):
+    buckets = [x for x in range(num_buckets + 1)]
+    span = extent[1] - extent[0]
+    buckets = [extent[0] + span * (x / num_buckets) for x in buckets]
+    return buckets
+
+
 @typechecked
 def _calc_probability_density(
     df_e: DataFrame,
     spark: SparkSession,
     buckets=None,
     score_colname="match_probability",
+    symmetric=True,
 ):
 
     """perform splink score histogram calculations / internal function
@@ -25,33 +33,41 @@ def _calc_probability_density(
         df_e (DataFrame): A dataframe of record comparisons containing a
             splink score, e.g. as produced by the expectation step
         spark (SparkSession): SparkSession object
-        score_colname: is the score in another column? defaults to match_probability
+        score_colname: is the score in another column? defaults to match_probability. also try match_weight
         buckets: accepts either a list of split points or an integer number that is used
             to create equally spaced split points. It defaults to 100 equally
-            spaced split points from 0.0 to 1.0
+            spaced split points
+        symmetric: if True then the histogram extent is symmetric about zero
     Returns:
         (list): list of rows of histogram bins for the appropriate splink score variable, ready to be plotted
     """
 
-    # if splits a list then use it. if None... then create default. if integer then create equal bins
-
+    if score_colname == "match_probability":
+        extent = (0.0, 1.0)
+    else:
+        weight_max = df_e.agg({score_colname: "max"}).collect()[0][0]
+        weight_min = df_e.agg({score_colname: "min"}).collect()[0][0]
+        extent = (weight_min, weight_max)
+        if symmetric:
+            extent_max = max(abs(weight_max), abs(weight_min))
+            extent = (-extent_max, extent_max)
+
+    # if buckets is a list then use it; if None then create the default; if an integer then create equal bins
     if isinstance(buckets, int) and buckets != 0:
-        buckets = [(x / buckets) for x in list(range(buckets))]
+        buckets = _equal_spaced_buckets(buckets, extent)
     elif buckets is None:
-        buckets = [(x / 100) for x in list(range(100))]
-
-    # ensure 0.0 and 1.0 are included in histogram
-
-    if buckets[0] != 0:
-        buckets = [0.0] + buckets
-
-    if buckets[-1] != 1.0:
-        buckets = buckets + [1.0]
-    buckets.sort()
+        buckets = _equal_spaced_buckets(100, extent)
+
+    # ensure bucket splits are in ascending order
+    if score_colname == "match_probability":
+        if buckets[0] != 0:
+            buckets = [0.0] + buckets
+
+        if buckets[-1] != 1.0:
+            buckets = buckets + [1.0]
+
+    buckets.sort()
 
     hist = df_e.select(score_colname).rdd.flatMap(lambda x: x).histogram(buckets)
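
To see what symmetric=True does, suppose the match_weight column runs from -3
to 10. A worked example mirroring the extent logic above (illustrative, not
from the repo):

    weight_min, weight_max = -3.0, 10.0
    extent = (weight_min, weight_max)                   # (-3.0, 10.0) if symmetric=False
    extent_max = max(abs(weight_max), abs(weight_min))  # 10.0
    extent = (-extent_max, extent_max)                  # (-10.0, 10.0) if symmetric=True

Centring the extent on zero places a match weight of 0 in the middle of the plot.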

@@ -105,6 +121,7 @@ def splink_score_histogram(
     spark: SparkSession,
     buckets=None,
     score_colname=None,
+    symmetric=True,
 ):
 
     """splink score histogram diagnostic plot public API function
@@ -118,13 +135,18 @@
         score_colname : is the score in another column? defaults to None
         buckets : accepts either a list of split points or an integer number that is used to
             create equally spaced split points. It defaults to 100 equally spaced split points from 0.0 to 1.0
+        symmetric : if True then the histogram extent is symmetric about zero
     Returns:
         if the altair library is installed this function returns a histogram plot. if altair is not
         installed then it returns the vega lite chart spec as a dictionary
     """
 
     rows = _calc_probability_density(
-        df_e, spark=spark, buckets=buckets, score_colname=score_colname
+        df_e,
+        spark=spark,
+        buckets=buckets,
+        score_colname=score_colname,
+        symmetric=symmetric,
     )
 
     return _create_probability_density_plot(rows)
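
A minimal usage sketch of the changed public API, assuming an active
SparkSession spark and a dataframe of scored comparisons df_e that has a
match_weight column (this example is not part of the commit):

    from splink.diagnostics import splink_score_histogram

    hist = splink_score_histogram(
        df_e,
        spark,
        score_colname="match_weight",  # new: histogram match weights rather than probabilities
        symmetric=True,                # new: bucket extent is symmetric about zero
    )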
