Commit

added test for prob density
mamonu committed Dec 3, 2020
1 parent 9279b6e commit 59f9a85
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions tests/test_diagnostics.py
@@ -89,21 +89,37 @@ def test_score_hist_output_json(spark, gamma_settings_4, params_4, sqlite_con_4)
    test chart exported as dictionary is in fact a valid dictionary
    """

    altair_installed = True
    try:
        import altair as alt
    except ImportError:
        altair_installed = False

    dfpd = pd.read_sql("select * from df", sqlite_con_4)
    df = spark.createDataFrame(dfpd)
    df = df.withColumn("tf_adjusted_match_prob", 1.0 - (f.rand() / 10))

    res3 = _calc_probability_density(df, spark=spark, buckets=5)

    if altair_installed:
        assert isinstance(_create_probability_density_plot(res3).to_dict(), dict)
    else:
        assert isinstance(_create_probability_density_plot(res3), dict)


def test_prob_density(spark, gamma_settings_4, params_4, sqlite_con_4):
    """
    A test that checks that the probability density is computed correctly:
    explicitly define a dataframe with tf_adjusted_match_prob = [0.1, 0.3, 0.5, 0.7, 0.9]
    and make sure that the probability density takes the expected value (0.2) in all 5 bins.
    """

    dfpd = pd.DataFrame([0.1, 0.3, 0.5, 0.7, 0.9], columns=["tf_adjusted_match_prob"])
    spdf = spark.createDataFrame(dfpd)

    res = _calc_probability_density(spdf, spark=spark, buckets=5)
    assert all(value == pytest.approx(0.2) for value in res.normalised.values)
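
For context, here is a minimal pandas-only sketch, independent of _calc_probability_density (whose normalised column is assumed here to hold the per-bucket share of rows), illustrating why each of the five equal-width buckets on [0, 1] is expected to come out at 0.2:

import pandas as pd

# Five evenly spread match probabilities, one per 0.2-wide bucket on [0, 1]
probs = pd.Series([0.1, 0.3, 0.5, 0.7, 0.9], name="tf_adjusted_match_prob")

# Split [0, 1] into 5 equal-width buckets and compute the share of rows per bucket
buckets = pd.cut(probs, bins=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
shares = buckets.value_counts(normalize=True, sort=False)

assert (shares == 0.2).all()  # each bucket holds exactly one of the five rows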
