Skip to content

Commit

Permalink
rename cum_{res,err} to cumulative_{residual,error}
Browse files Browse the repository at this point in the history
add better test coverage for test_residual_hist() and test_true_pred_hist()
remove add_dropdown() helper func
  • Loading branch information
janosh committed Aug 30, 2022
1 parent d21e6b4 commit 5d52181
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 60 deletions.
10 changes: 5 additions & 5 deletions assets/_generate_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pymatgen.ext.matproj import MPRester

from pymatviz.correlation import marchenko_pastur
from pymatviz.cumulative import cum_err, cum_res
from pymatviz.cumulative import cumulative_error, cumulative_residual
from pymatviz.histograms import (
hist_elemental_prevalence,
residual_hist,
Expand Down Expand Up @@ -143,20 +143,20 @@


error_decay_with_uncert(y_true, y_pred, y_std)
save_and_compress_svg("err-decay")
save_and_compress_svg("error-decay-with-uncert")

eps = 0.2 * np.random.randn(*y_std.shape)

error_decay_with_uncert(y_true, y_pred, {"better": y_std, "worse": y_std + eps})
save_and_compress_svg("err-decay-multiple")
save_and_compress_svg("error-decay-with-uncert-multiple")


# %% Cumulative Plots
cum_err(y_pred, y_true)
cumulative_error(y_pred, y_true)
save_and_compress_svg("cumulative-error")


cum_res(y_pred, y_true)
cumulative_residual(y_pred, y_true)
save_and_compress_svg("cumulative-residual")


Expand Down
File renamed without changes
File renamed without changes
5 changes: 2 additions & 3 deletions pymatviz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
# and https://peps.python.org/pep-0484/#stub-files 'Additional notes on stub files'
from pymatviz.correlation import marchenko_pastur as marchenko_pastur
from pymatviz.correlation import marchenko_pastur_pdf as marchenko_pastur_pdf
from pymatviz.cumulative import add_dropdown as add_dropdown
from pymatviz.cumulative import cum_err as cum_err
from pymatviz.cumulative import cum_res as cum_res
from pymatviz.cumulative import cumulative_error as cumulative_error
from pymatviz.cumulative import cumulative_residual as cumulative_residual
from pymatviz.histograms import hist_elemental_prevalence as hist_elemental_prevalence
from pymatviz.histograms import residual_hist as residual_hist
from pymatviz.histograms import spacegroup_hist as spacegroup_hist
Expand Down
41 changes: 14 additions & 27 deletions pymatviz/cumulative.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,7 @@
from pymatviz.utils import Array


def add_dropdown(ax: plt.Axes, percentile: int, err: Array) -> None:
"""Add a dashed drop-down line at a given percentile.
Args:
ax (Axes): matplotlib Axes on which to add the dropdown.
percentile (int): Integer in range(100) at which to display dropdown line.
err (array): Numpy array of errors = abs(preds - targets).
"""
percent = int(percentile * (len(err) - 1) / 100 + 0.5)
ax.plot((0, err[percent]), (percentile, percentile), "--", color="grey", alpha=0.4)
ax.plot(
(err[percent], err[percent]), (0, percentile), "--", color="grey", alpha=0.4
)


def cum_res(
def cumulative_residual(
preds: Array, targets: Array, ax: plt.Axes = None, **kwargs: Any
) -> plt.Axes:
"""Plot the empirical cumulative distribution for the residuals (y - mu).
Expand Down Expand Up @@ -70,15 +55,15 @@ def cum_res(

# Label the plot
ax.set(xlabel="Residual", ylabel="Percentile", title="Cumulative Residual")
ax.legend(frameon=False)

return ax


def cum_err(
def cumulative_error(
preds: Array, targets: Array, ax: plt.Axes = None, **kwargs: Any
) -> plt.Axes:
"""Plot the empirical cumulative distribution for the absolute errors abs(y - y_hat).
"""Plot the empirical cumulative distribution of the absolute errors
abs(y_true - y_pred).
Args:
preds (array): Numpy array of predictions.
Expand All @@ -92,22 +77,24 @@ def cum_err(
if ax is None:
ax = plt.gca()

err = np.sort(np.abs(preds - targets))
n_data = len(err)
errors = np.sort(np.abs(preds - targets))
n_data = len(errors)

# Plot the empirical distribution
ax.plot(err, np.arange(n_data) / n_data * 100, **kwargs)
ax.plot(errors, np.arange(n_data) / n_data * 100, **kwargs)

# Get robust (and symmetrical) x axis limits
lim = np.percentile(err, 98)
# Get robust (and symmetrical) x-axis limits
lim = np.percentile(errors, 98)
ax.set(xlim=(0, lim), ylim=(0, 100))

line_kwargs = dict(linestyle="--", color="grey", alpha=0.4)
# Add some visual guidelines
add_dropdown(ax, 50, err)
add_dropdown(ax, 75, err)
for percentile in [50, 75]:
percent = int(percentile * (n_data - 1) / 100 + 0.5)
ax.plot((0, errors[percent]), (percentile, percentile), **line_kwargs)
ax.plot((errors[percent], errors[percent]), (0, percentile), **line_kwargs)

# Label the plot
ax.set(xlabel="Absolute Error", ylabel="Percentile", title="Cumulative Error")
ax.legend(frameon=False)

return ax
8 changes: 3 additions & 5 deletions pymatviz/histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def residual_hist(
ax.plot(x_range, kde(x_range), linewidth=3, color="red", label=label)

ax.set(xlabel=xlabel)
ax.legend(loc=2, framealpha=0.5, handlelength=1)
ax.legend(loc="upper left", framealpha=0.5, handlelength=1)

return ax

Expand All @@ -71,7 +71,6 @@ def true_pred_hist(
y_std: Array,
ax: plt.Axes = None,
cmap: str = "hot",
bins: int = 50,
truth_color: str = "blue",
**kwargs: Any,
) -> plt.Axes:
Expand All @@ -85,7 +84,6 @@ def true_pred_hist(
y_std (array): model uncertainty
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
cmap (str, optional): string identifier of a plt colormap. Defaults to 'hot'.
bins (int, optional): Histogram resolution. Defaults to 50.
truth_color (str, optional): Face color to use for y_true bars.
Defaults to 'blue'.
**kwargs: Additional keyword arguments to pass to ax.hist().
Expand All @@ -100,9 +98,9 @@ def true_pred_hist(
y_true, y_pred, y_std = np.array([y_true, y_pred, y_std])

_, bin_edges, bars = ax.hist(
y_pred, bins=bins, alpha=0.8, label=r"$y_\mathrm{pred}$", **kwargs
y_pred, alpha=0.8, label=r"$y_\mathrm{pred}$", **kwargs
)
ax.figure.set
kwargs.pop("bins", None)
ax.hist(
y_true,
bins=bin_edges,
Expand Down
2 changes: 1 addition & 1 deletion pymatviz/uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def error_decay_with_uncert(
)
ax.set(ylim=[0, rand_mean.mean() * 1.3], ylabel="MAE")

# n: Number of remaining points in err calculation after discarding the
# n: Number of remaining points in error calculation after discarding the
# (len(y_true) - n) most uncertain/hightest-error points
ax.set(xlabel="Confidence percentiles" if percentiles else "Excluded samples")
ax.legend(loc="lower left")
Expand Down
13 changes: 6 additions & 7 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ See [`pymatviz/uncertainty.py`](pymatviz/uncertainty.py).
| :-------------------------------------------------------------------------: | :-------------------------------------------------------------------------------: |
| ![normal-prob-plot] | ![normal-prob-plot-multiple] |
| [`error_decay_with_uncert(y_true, y_pred, y_std)`](pymatviz/uncertainty.py) | [`error_decay_with_uncert(y_true, y_pred, y_std: dict)`](pymatviz/uncertainty.py) |
| ![err-decay] | ![err-decay-multiple] |
| ![error-decay-with-uncert] | ![error-decay-with-uncert-multiple] |

## Cumulative Error and Residual

See [`pymatviz/cumulative.py`](pymatviz/cumulative.py).

| [`cum_err(preds, targets)`](pymatviz/cumulative.py) | [`cum_res(preds, targets)`](pymatviz/cumulative.py) |
| :-------------------------------------------------: | :-------------------------------------------------: |
| ![cumulative-error] | ![cumulative-residual] |
| [`cumulative_error(preds, targets)`](pymatviz/cumulative.py) | [`cumulative_residual(preds, targets)`](pymatviz/cumulative.py) |
| :----------------------------------------------------------: | :-------------------------------------------------------------: |
| ![cumulative-error] | ![cumulative-residual] |

## Classification Metrics

Expand Down Expand Up @@ -138,8 +138,8 @@ For the time being, Google Colab only supports Python 3.7. `pymatviz` uses Pytho
[density-hexbin]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/density-hexbin.svg
[density-scatter-with-hist]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/density-scatter-with-hist.svg
[density-scatter]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/density-scatter.svg
[err-decay-multiple]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/err-decay-multiple.svg
[err-decay]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/err-decay.svg
[error-decay-with-uncert-multiple]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/error-decay-with-uncert-multiple.svg
[error-decay-with-uncert]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/error-decay-with-uncert.svg
[hist-elemental-prevalence]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/hist-elemental-prevalence.svg
[marchenko-pastur-significant-eval]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/marchenko-pastur-significant-eval.svg
[marchenko-pastur]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/marchenko-pastur.svg
Expand All @@ -161,6 +161,5 @@ For the time being, Google Colab only supports Python 3.7. `pymatviz` uses Pytho
[spg-symbol-sunburst]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/spg-symbol-sunburst.svg
[struct-2d-mp-12712-Hf3Zr3Pd8-disordered]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/struct-2d-mp-12712-Hf3Zr3Pd8-disordered.svg
[struct-2d-mp-19017-Li4Fe3P4CO16-disordered]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/struct-2d-mp-19017-Li4Fe3P4CO16-disordered.svg
[true-pred-hist]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/true-pred-hist.svg
[sankey-from-2-df-cols-randints]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/sankey-from-2-df-cols-randints.svg
[sankey-spglib-vs-aflow-spacegroups]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/sankey-spglib-vs-aflow-spacegroups.svg
18 changes: 11 additions & 7 deletions tests/test_cumulative.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,21 @@
import matplotlib.pyplot as plt
import pytest

from pymatviz import cum_err, cum_res

from .conftest import y_pred, y_true
from pymatviz import cumulative_error, cumulative_residual
from tests.conftest import y_pred, y_true


@pytest.mark.parametrize("alpha", [None, 0.5])
def test_cum_err(alpha: float) -> None:
ax = cum_err(y_pred, y_true, alpha=alpha)
def test_cumulative_error(alpha: float) -> None:
ax = cumulative_error(y_pred, y_true, alpha=alpha)
assert isinstance(ax, plt.Axes)


def test_cum_res():
ax = cum_res(y_pred, y_true)
def test_cumulative_residual():
ax = cumulative_residual(y_pred, y_true)
assert isinstance(ax, plt.Axes)
assert len(ax.lines) == 3
assert ax.get_xlabel() == "Residual"
assert ax.get_ylabel() == "Percentile"
assert ax.get_title() == "Cumulative Residual"
assert ax.get_ylim() == (0, 100)
21 changes: 16 additions & 5 deletions tests/test_histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,29 @@
from pymatgen.core import Structure

from pymatviz import residual_hist, spacegroup_hist, true_pred_hist

from .conftest import y_pred, y_true
from tests.conftest import y_pred, y_true


@pytest.mark.parametrize("bins", [None, 1, 100])
@pytest.mark.parametrize("xlabel", [None, "foo"])
def test_residual_hist(bins: int | None, xlabel: str | None) -> None:
residual_hist(y_true, y_pred, bins=bins, xlabel=xlabel)
ax = residual_hist(y_true, y_pred, bins=bins, xlabel=xlabel)

assert isinstance(ax, plt.Axes)
assert (
ax.get_xlabel() == xlabel or r"Residual ($y_\mathrm{test} - y_\mathrm{pred}$)"
)
assert len(ax.lines) == 1
legend = ax.get_legend()
assert len(ax.patches) == bins or 10
assert legend._get_loc() == 2 # 2 meaning 'upper left'


def test_true_pred_hist():
true_pred_hist(y_true, y_pred, y_true - y_pred)
@pytest.mark.parametrize("bins", [None, 1, 100])
@pytest.mark.parametrize("cmap", ["hot", "Blues"])
def test_true_pred_hist(bins: int | None, cmap: str) -> None:
ax = true_pred_hist(y_true, y_pred, y_true - y_pred, bins=bins, cmap=cmap)
assert isinstance(ax, plt.Axes)


@pytest.mark.parametrize("xticks", ["all", "crys_sys_edges", 1, 50])
Expand Down

0 comments on commit 5d52181

Please sign in to comment.