ENH add quantile to reliability diagram (#79)
* MNT replace assert_almost_equal by assert_allclose

* ENH add quantile to plot_reliability_diagram
lorentzenchr committed Jul 10, 2023
1 parent a66cb76 commit 54252f6
Showing 4 changed files with 2,447 additions and 39 deletions.
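
For orientation before the diff, here is a minimal usage sketch of what this change enables. Only the parameter names visible in the diff below (functional, level, n_bootstrap) come from the commit; the import path is assumed from the package layout and the data are made up for illustration.

import numpy as np
from model_diagnostics.calibration import plot_reliability_diagram  # assumed import path

rng = np.random.default_rng(0)
y_obs = rng.gamma(shape=2.0, scale=1.0, size=1000)
# Hypothetical predictions standing in for a model's 80%-quantile forecasts.
y_pred = np.quantile(y_obs, 0.8) * rng.uniform(0.9, 1.1, size=1000)

# New in this commit: reliability diagrams for quantiles (and expectiles), not just the mean.
plot_reliability_diagram(
    y_obs=y_obs,
    y_pred=y_pred,
    functional="quantile",
    level=0.8,
    n_bootstrap=100,
)
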
68 changes: 48 additions & 20 deletions src/model_diagnostics/calibration/plots.py
@@ -8,7 +8,7 @@
import numpy.typing as npt
import polars as pl
from scipy.stats import bootstrap
from sklearn.isotonic import IsotonicRegression
from sklearn.isotonic import IsotonicRegression as IsotonicRegression_skl

from model_diagnostics._utils._array import (
array_name,
@@ -17,6 +17,7 @@
get_sorted_array_names,
length_of_second_dimension,
)
from model_diagnostics._utils.isotonic import IsotonicRegression

from .identification import compute_bias

@@ -26,6 +27,8 @@ def plot_reliability_diagram(
y_pred: npt.ArrayLike,
weights: Optional[npt.ArrayLike] = None,
*,
functional: str = "mean",
level: float = 0.5,
n_bootstrap: Optional[int] = None,
confidence_level: float = 0.9,
diagram_type: str = "reliability",
@@ -49,6 +52,17 @@
Predicted values of the conditional expectation of Y, `E(Y|X)`.
weights : array-like of shape (n_obs) or None
Case weights.
functional : str
The functional that is induced by the identification function `V`. Options are:
- `"mean"`. Argument `level` is neglected.
- `"median"`. Argument `level` is neglected.
- `"expectile"`
- `"quantile"`
level : float
The level of the expectile or quantile. (Often called \(\alpha\).)
It must be `0 <= level <= 1`.
`level=0.5` and `functional="expectile"` gives the mean.
`level=0.5` and `functional="quantile"` gives the median.
n_bootstrap : int or None
If not `None`, then `scipy.stats.bootstrap` with `n_resamples=n_bootstrap`
is used to calculate confidence intervals at level `confidence_level`.
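
As a quick sanity check of the level semantics documented above (this sketch is illustrative and not part of the commit): the 0.5-expectile of a sample coincides with its mean, and the 0.5-quantile with its median. The expectile solver below is a plain fixed-point iteration written for this example.

import numpy as np

def expectile(y, level, n_iter=1000, tol=1e-12):
    # Fixed-point iteration for the asymmetric least-squares (expectile) estimate:
    # the level-expectile t solves sum(w * (y - t)) = 0 with w = level where y > t
    # and w = 1 - level where y <= t.
    t = np.mean(y)
    for _ in range(n_iter):
        w = np.where(y > t, level, 1.0 - level)
        t_new = np.sum(w * y) / np.sum(w)
        if abs(t_new - t) < tol:
            break
        t = t_new
    return t

y = np.array([0.2, 1.0, 1.5, 3.7, 10.0])
print(np.isclose(expectile(y, 0.5), np.mean(y)))       # True: 0.5-expectile = mean
print(np.isclose(np.quantile(y, 0.5), np.median(y)))   # True: 0.5-quantile = median
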
@@ -117,26 +131,40 @@
ax.hlines(0, xmin=y_min, xmax=y_max, color="k", linestyle="dotted")

if n_bootstrap is not None:
if functional == "mean":

def iso_statistic(y_obs, y_pred, weights=None, x_values=None):
iso_b = (
IsotonicRegression(out_of_bounds="clip")
.set_output(transform="default")
.fit(y_pred, y_obs, sample_weight=weights)
)
return iso_b.predict(x_values)
def iso_statistic(y_obs, y_pred, weights=None, x_values=None):
iso_b = (
IsotonicRegression_skl(out_of_bounds="clip")
.set_output(transform="default")
.fit(y_pred, y_obs, sample_weight=weights)
)
return iso_b.predict(x_values)

else:

def iso_statistic(y_obs, y_pred, weights=None, x_values=None):
iso_b = IsotonicRegression(functional=functional, level=level).fit(
y_pred, y_obs, sample_weight=weights
)
return iso_b.predict(x_values)

n_pred = length_of_second_dimension(y_pred)
pred_names, _ = get_sorted_array_names(y_pred)

for i in range(len(pred_names)):
y_pred_i = y_pred if n_pred == 0 else get_second_dimension(y_pred, i)

iso = (
IsotonicRegression()
.set_output(transform="default")
.fit(y_pred_i, y_obs, sample_weight=weights)
)
if functional == "mean":
iso = (
IsotonicRegression_skl()
.set_output(transform="default")
.fit(y_pred_i, y_obs, sample_weight=weights)
)
else:
iso = IsotonicRegression(functional=functional, level=level).fit(
y_pred_i, y_obs, sample_weight=weights
)

# confidence intervals
if n_bootstrap is not None:
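
For context on the bootstrap branch above (and the iso_statistic helpers earlier in this hunk): scipy.stats.bootstrap resamples the paired (y_obs, y_pred) data and re-evaluates the statistic on every resample. The sketch below only illustrates that general pattern with a toy statistic; it is an assumption about the mechanism, not the exact call made in this module.

import numpy as np
from scipy.stats import bootstrap

rng = np.random.default_rng(0)
y_obs = rng.normal(size=200)
y_pred = y_obs + rng.normal(scale=0.5, size=200)

def statistic(y_obs, y_pred):
    # Toy stand-in for iso_statistic: mean residual of the resample.
    return np.mean(y_obs - y_pred)

res = bootstrap(
    data=(y_obs, y_pred),
    statistic=statistic,
    n_resamples=100,
    confidence_level=0.9,
    paired=True,          # resample (y_obs, y_pred) pairs jointly
    vectorized=False,
    method="percentile",
    random_state=rng,
)
print(res.confidence_interval)
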
@@ -229,20 +257,20 @@ def plot_bias(
Predicted values of the conditional expectation of Y, `E(Y|X)`.
feature : array-like of shape (n_obs) or None
Some feature column.
functional : str
The functional that is induced by the identification function `V`. Options are:
- `"mean"`. Argument `level` is neglected.
- `"median"`. Argument `level` is neglected.
- `"expectile"`
- `"quantile"`
weights : array-like of shape (n_obs) or None
Case weights. If given, the bias is calculated as weighted average of the
identification function with these weights.
Note that the standard errors and p-values in the output are based on the
assumption that the variance of the bias is inverse proportional to the
weights. See the Notes section for details.
functional : str
The functional that is induced by the identification function `V`. Options are:
- `"mean"`. Argument `level` is neglected.
- `"median"`. Argument `level` is neglected.
- `"expectile"`
- `"quantile"`
level : float
The level of the expectile of quantile. (Often called \(\alpha\).)
The level of the expectile or quantile. (Often called \(\alpha\).)
It must be `0 <= level <= 1`.
`level=0.5` and `functional="expectile"` gives the mean.
`level=0.5` and `functional="quantile"` gives the median.
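
A similarly hedged sketch for plot_bias with the functional/level pair documented above (the feature values and data are made up; the import path is assumed from the package layout):

import numpy as np
from model_diagnostics.calibration import plot_bias  # assumed import path

rng = np.random.default_rng(1)
feature = rng.integers(0, 3, size=500)             # hypothetical grouping feature
y_obs = rng.gamma(shape=2.0, scale=1.0, size=500)
y_pred = np.quantile(y_obs, 0.8) * rng.uniform(0.9, 1.1, size=500)  # pretend 80%-quantile predictions

plot_bias(
    y_obs=y_obs,
    y_pred=y_pred,
    feature=feature,
    functional="quantile",
    level=0.8,
)
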
20 changes: 16 additions & 4 deletions src/model_diagnostics/calibration/tests/test_plots.py
@@ -13,14 +13,21 @@

@pytest.mark.parametrize(
("param", "value", "msg"),
[("diagram_type", "XXX", "Parameter diagram_type must be either.*XXX")],
[
("diagram_type", "XXX", "Parameter diagram_type must be either.*XXX"),
("functional", "XXX", "Argument functional must be one of.*XXX"),
("level", 2, "Argument level must fulfil 0 < level < 1, got 2"),
],
)
def test_plot_reliability_diagram_raises(param, value, msg):
"""Test that plot_reliability_diagram raises errors."""
y_obs = [0, 1]
y_pred = [-1, 1]
d = {param: value}
if "functional" not in d.keys():
d["functional"] = "quantile" # as a default
with pytest.raises(ValueError, match=msg):
plot_reliability_diagram(y_obs=y_obs, y_pred=y_pred, **{param: value})
plot_reliability_diagram(y_obs=y_obs, y_pred=y_pred, **d)


def test_plot_reliability_diagram_raises_y_obs_multdim():
@@ -34,17 +41,20 @@ def test_plot_reliability_diagram_raises_y_obs_multdim():


@pytest.mark.parametrize("diagram_type", ["reliability", "bias"])
@pytest.mark.parametrize("functional", ["mean", "expectile", "quantile"])
@pytest.mark.parametrize("n_bootstrap", [None, 10])
@pytest.mark.parametrize("weights", [None, True])
@pytest.mark.parametrize("ax", [None, plt.subplots()[1]])
def test_plot_reliability_diagram(diagram_type, n_bootstrap, weights, ax):
def test_plot_reliability_diagram(diagram_type, functional, n_bootstrap, weights, ax):
"""Test that plot_reliability_diagram works."""
X, y = make_classification(random_state=42, n_classes=2)
if weights is None:
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
w_train, w_test = None, None
elif functional == "quantile":
pytest.skip("Weighted quantiles are not implemented.")
else:
weights = np.random.default_rng(42).integers(low=0, high=10, size=y.shape)
weights = np.random.default_rng(42).integers(low=1, high=10, size=y.shape)
X_train, X_test, y_train, y_test, w_train, w_test = train_test_split(
X, y, weights, random_state=0
)
@@ -56,6 +66,8 @@ def test_plot_reliability_diagram(diagram_type, n_bootstrap, weights, ax):
y_pred=y_pred,
weights=w_test,
ax=ax,
functional=functional,
level=0.8,
n_bootstrap=n_bootstrap,
diagram_type=diagram_type,
)
20 changes: 5 additions & 15 deletions src/model_diagnostics/scoring/plots.py
@@ -28,11 +28,10 @@ def plot_murphy_diagram(
):
r"""Plot a Murphy diagram.
A reliability diagram or calibration curve assesses auto-calibration. It plots the
conditional expectation given the predictions `E(y_obs|y_pred)` (y-axis) vs the
predictions `y_pred` (x-axis).
The conditional expectation is estimated via isotonic regression (PAV algorithm)
of `y_obs` on `y_pred`.
A Murphy diagram plots the scores of elementary scoring functions `ElementaryScore`
over a range of their free parameter `eta`. This shows whether a model dominates all
others over a wide class of scoring functions or whether the ranking depends strongly
on the choice of scoring function.
See Notes for further details.
Parameters
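
To make the notion of an elementary score concrete, here is a small self-contained sketch of Murphy curves for the mean functional. The score used, S_theta(x, y) = |y - theta| * 1{theta lies between x and y}, is one common parameterization from the Ehm et al. reference cited in this docstring's Notes ([Ehm2015]); the library's ElementaryScore may normalize differently, so treat this as an illustration rather than the package's implementation.

import numpy as np

def elementary_score_mean(y_pred, y_obs, theta):
    # Elementary score for the mean at parameter theta: a prediction is penalized
    # only when theta lies between the prediction and the observation.
    lo = np.minimum(y_pred, y_obs)
    hi = np.maximum(y_pred, y_obs)
    return np.abs(y_obs - theta) * ((lo <= theta) & (theta < hi))

rng = np.random.default_rng(2)
y_obs = rng.normal(size=500)
model_a = y_obs + rng.normal(scale=0.3, size=500)  # lower-noise predictions
model_b = y_obs + rng.normal(scale=0.8, size=500)  # higher-noise predictions

etas = np.linspace(-3.0, 3.0, 101)
curve_a = np.array([elementary_score_mean(model_a, y_obs, t).mean() for t in etas])
curve_b = np.array([elementary_score_mean(model_b, y_obs, t).mean() for t in etas])
# model_a dominates model_b if curve_a <= curve_b for every eta.
print(np.mean(curve_a <= curve_b))
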
@@ -67,16 +66,7 @@
Notes
-----
The expectation conditional on the predictions is \(E(Y|y_{pred})\). This object is
estimated by the pool-adjacent violator (PAV) algorithm, which has very desirable
properties:
- It is non-parametric without any tuning parameter. Thus, the results are
easily reproducible.
- Optimal selection of bins
- Statistical consistent estimator
For details, refer to [Dimitriadis2021].
For details, refer to [Ehm2015].
References
----------
2,378 changes: 2,378 additions & 0 deletions untitled.ipynb

Large diffs are not rendered by default.
