From 5d52181c5af3da9d11755b1b01a19a7dc059961e Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 30 Aug 2022 15:26:13 -0700 Subject: [PATCH] rename cum_{res,err} to cumulative_{residual,error} add better test coverage for test_residual_hist() and test_true_pred_hist() remove add_dropdown() helper func --- assets/_generate_assets.py | 10 ++--- ...g => error-decay-with-uncert-multiple.svg} | 0 ...-decay.svg => error-decay-with-uncert.svg} | 0 pymatviz/__init__.py | 5 +-- pymatviz/cumulative.py | 41 +++++++------------ pymatviz/histograms.py | 8 ++-- pymatviz/uncertainty.py | 2 +- readme.md | 13 +++--- tests/test_cumulative.py | 18 ++++---- tests/test_histograms.py | 21 +++++++--- 10 files changed, 58 insertions(+), 60 deletions(-) rename assets/{err-decay-multiple.svg => error-decay-with-uncert-multiple.svg} (100%) rename assets/{err-decay.svg => error-decay-with-uncert.svg} (100%) diff --git a/assets/_generate_assets.py b/assets/_generate_assets.py index 29a5ef55..b38e57d9 100644 --- a/assets/_generate_assets.py +++ b/assets/_generate_assets.py @@ -6,7 +6,7 @@ from pymatgen.ext.matproj import MPRester from pymatviz.correlation import marchenko_pastur -from pymatviz.cumulative import cum_err, cum_res +from pymatviz.cumulative import cumulative_error, cumulative_residual from pymatviz.histograms import ( hist_elemental_prevalence, residual_hist, @@ -143,20 +143,20 @@ error_decay_with_uncert(y_true, y_pred, y_std) -save_and_compress_svg("err-decay") +save_and_compress_svg("error-decay-with-uncert") eps = 0.2 * np.random.randn(*y_std.shape) error_decay_with_uncert(y_true, y_pred, {"better": y_std, "worse": y_std + eps}) -save_and_compress_svg("err-decay-multiple") +save_and_compress_svg("error-decay-with-uncert-multiple") # %% Cumulative Plots -cum_err(y_pred, y_true) +cumulative_error(y_pred, y_true) save_and_compress_svg("cumulative-error") -cum_res(y_pred, y_true) +cumulative_residual(y_pred, y_true) save_and_compress_svg("cumulative-residual") diff --git a/assets/err-decay-multiple.svg b/assets/error-decay-with-uncert-multiple.svg similarity index 100% rename from assets/err-decay-multiple.svg rename to assets/error-decay-with-uncert-multiple.svg diff --git a/assets/err-decay.svg b/assets/error-decay-with-uncert.svg similarity index 100% rename from assets/err-decay.svg rename to assets/error-decay-with-uncert.svg diff --git a/pymatviz/__init__.py b/pymatviz/__init__.py index 71215bd6..aa72296e 100644 --- a/pymatviz/__init__.py +++ b/pymatviz/__init__.py @@ -3,9 +3,8 @@ # and https://peps.python.org/pep-0484/#stub-files 'Additional notes on stub files' from pymatviz.correlation import marchenko_pastur as marchenko_pastur from pymatviz.correlation import marchenko_pastur_pdf as marchenko_pastur_pdf -from pymatviz.cumulative import add_dropdown as add_dropdown -from pymatviz.cumulative import cum_err as cum_err -from pymatviz.cumulative import cum_res as cum_res +from pymatviz.cumulative import cumulative_error as cumulative_error +from pymatviz.cumulative import cumulative_residual as cumulative_residual from pymatviz.histograms import hist_elemental_prevalence as hist_elemental_prevalence from pymatviz.histograms import residual_hist as residual_hist from pymatviz.histograms import spacegroup_hist as spacegroup_hist diff --git a/pymatviz/cumulative.py b/pymatviz/cumulative.py index fc4ffc41..9c2637d6 100644 --- a/pymatviz/cumulative.py +++ b/pymatviz/cumulative.py @@ -6,22 +6,7 @@ from pymatviz.utils import Array -def add_dropdown(ax: plt.Axes, percentile: int, err: Array) -> None: - """Add a dashed drop-down line at a given percentile. - - Args: - ax (Axes): matplotlib Axes on which to add the dropdown. - percentile (int): Integer in range(100) at which to display dropdown line. - err (array): Numpy array of errors = abs(preds - targets). - """ - percent = int(percentile * (len(err) - 1) / 100 + 0.5) - ax.plot((0, err[percent]), (percentile, percentile), "--", color="grey", alpha=0.4) - ax.plot( - (err[percent], err[percent]), (0, percentile), "--", color="grey", alpha=0.4 - ) - - -def cum_res( +def cumulative_residual( preds: Array, targets: Array, ax: plt.Axes = None, **kwargs: Any ) -> plt.Axes: """Plot the empirical cumulative distribution for the residuals (y - mu). @@ -70,15 +55,15 @@ def cum_res( # Label the plot ax.set(xlabel="Residual", ylabel="Percentile", title="Cumulative Residual") - ax.legend(frameon=False) return ax -def cum_err( +def cumulative_error( preds: Array, targets: Array, ax: plt.Axes = None, **kwargs: Any ) -> plt.Axes: - """Plot the empirical cumulative distribution for the absolute errors abs(y - y_hat). + """Plot the empirical cumulative distribution of the absolute errors + abs(y_true - y_pred). Args: preds (array): Numpy array of predictions. @@ -92,22 +77,24 @@ def cum_err( if ax is None: ax = plt.gca() - err = np.sort(np.abs(preds - targets)) - n_data = len(err) + errors = np.sort(np.abs(preds - targets)) + n_data = len(errors) # Plot the empirical distribution - ax.plot(err, np.arange(n_data) / n_data * 100, **kwargs) + ax.plot(errors, np.arange(n_data) / n_data * 100, **kwargs) - # Get robust (and symmetrical) x axis limits - lim = np.percentile(err, 98) + # Get robust (and symmetrical) x-axis limits + lim = np.percentile(errors, 98) ax.set(xlim=(0, lim), ylim=(0, 100)) + line_kwargs = dict(linestyle="--", color="grey", alpha=0.4) # Add some visual guidelines - add_dropdown(ax, 50, err) - add_dropdown(ax, 75, err) + for percentile in [50, 75]: + percent = int(percentile * (n_data - 1) / 100 + 0.5) + ax.plot((0, errors[percent]), (percentile, percentile), **line_kwargs) + ax.plot((errors[percent], errors[percent]), (0, percentile), **line_kwargs) # Label the plot ax.set(xlabel="Absolute Error", ylabel="Percentile", title="Cumulative Error") - ax.legend(frameon=False) return ax diff --git a/pymatviz/histograms.py b/pymatviz/histograms.py index b57771a6..83f6db4a 100644 --- a/pymatviz/histograms.py +++ b/pymatviz/histograms.py @@ -60,7 +60,7 @@ def residual_hist( ax.plot(x_range, kde(x_range), linewidth=3, color="red", label=label) ax.set(xlabel=xlabel) - ax.legend(loc=2, framealpha=0.5, handlelength=1) + ax.legend(loc="upper left", framealpha=0.5, handlelength=1) return ax @@ -71,7 +71,6 @@ def true_pred_hist( y_std: Array, ax: plt.Axes = None, cmap: str = "hot", - bins: int = 50, truth_color: str = "blue", **kwargs: Any, ) -> plt.Axes: @@ -85,7 +84,6 @@ def true_pred_hist( y_std (array): model uncertainty ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None. cmap (str, optional): string identifier of a plt colormap. Defaults to 'hot'. - bins (int, optional): Histogram resolution. Defaults to 50. truth_color (str, optional): Face color to use for y_true bars. Defaults to 'blue'. **kwargs: Additional keyword arguments to pass to ax.hist(). @@ -100,9 +98,9 @@ def true_pred_hist( y_true, y_pred, y_std = np.array([y_true, y_pred, y_std]) _, bin_edges, bars = ax.hist( - y_pred, bins=bins, alpha=0.8, label=r"$y_\mathrm{pred}$", **kwargs + y_pred, alpha=0.8, label=r"$y_\mathrm{pred}$", **kwargs ) - ax.figure.set + kwargs.pop("bins", None) ax.hist( y_true, bins=bin_edges, diff --git a/pymatviz/uncertainty.py b/pymatviz/uncertainty.py index 05dbe34d..9dfa51e8 100644 --- a/pymatviz/uncertainty.py +++ b/pymatviz/uncertainty.py @@ -239,7 +239,7 @@ def error_decay_with_uncert( ) ax.set(ylim=[0, rand_mean.mean() * 1.3], ylabel="MAE") - # n: Number of remaining points in err calculation after discarding the + # n: Number of remaining points in error calculation after discarding the # (len(y_true) - n) most uncertain/hightest-error points ax.set(xlabel="Confidence percentiles" if percentiles else "Excluded samples") ax.legend(loc="lower left") diff --git a/readme.md b/readme.md index 11699c0c..cd626594 100644 --- a/readme.md +++ b/readme.md @@ -96,15 +96,15 @@ See [`pymatviz/uncertainty.py`](pymatviz/uncertainty.py). | :-------------------------------------------------------------------------: | :-------------------------------------------------------------------------------: | | ![normal-prob-plot] | ![normal-prob-plot-multiple] | | [`error_decay_with_uncert(y_true, y_pred, y_std)`](pymatviz/uncertainty.py) | [`error_decay_with_uncert(y_true, y_pred, y_std: dict)`](pymatviz/uncertainty.py) | -| ![err-decay] | ![err-decay-multiple] | +| ![error-decay-with-uncert] | ![error-decay-with-uncert-multiple] | ## Cumulative Error and Residual See [`pymatviz/cumulative.py`](pymatviz/cumulative.py). -| [`cum_err(preds, targets)`](pymatviz/cumulative.py) | [`cum_res(preds, targets)`](pymatviz/cumulative.py) | -| :-------------------------------------------------: | :-------------------------------------------------: | -| ![cumulative-error] | ![cumulative-residual] | +| [`cumulative_error(preds, targets)`](pymatviz/cumulative.py) | [`cumulative_residual(preds, targets)`](pymatviz/cumulative.py) | +| :----------------------------------------------------------: | :-------------------------------------------------------------: | +| ![cumulative-error] | ![cumulative-residual] | ## Classification Metrics @@ -138,8 +138,8 @@ For the time being, Google Colab only supports Python 3.7. `pymatviz` uses Pytho [density-hexbin]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/density-hexbin.svg [density-scatter-with-hist]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/density-scatter-with-hist.svg [density-scatter]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/density-scatter.svg -[err-decay-multiple]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/err-decay-multiple.svg -[err-decay]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/err-decay.svg +[error-decay-with-uncert-multiple]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/error-decay-with-uncert-multiple.svg +[error-decay-with-uncert]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/error-decay-with-uncert.svg [hist-elemental-prevalence]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/hist-elemental-prevalence.svg [marchenko-pastur-significant-eval]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/marchenko-pastur-significant-eval.svg [marchenko-pastur]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/marchenko-pastur.svg @@ -161,6 +161,5 @@ For the time being, Google Colab only supports Python 3.7. `pymatviz` uses Pytho [spg-symbol-sunburst]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/spg-symbol-sunburst.svg [struct-2d-mp-12712-Hf3Zr3Pd8-disordered]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/struct-2d-mp-12712-Hf3Zr3Pd8-disordered.svg [struct-2d-mp-19017-Li4Fe3P4CO16-disordered]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/struct-2d-mp-19017-Li4Fe3P4CO16-disordered.svg -[true-pred-hist]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/true-pred-hist.svg [sankey-from-2-df-cols-randints]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/sankey-from-2-df-cols-randints.svg [sankey-spglib-vs-aflow-spacegroups]: https://raw.githubusercontent.com/janosh/pymatviz/main/assets/sankey-spglib-vs-aflow-spacegroups.svg diff --git a/tests/test_cumulative.py b/tests/test_cumulative.py index df199023..9447780a 100644 --- a/tests/test_cumulative.py +++ b/tests/test_cumulative.py @@ -3,17 +3,21 @@ import matplotlib.pyplot as plt import pytest -from pymatviz import cum_err, cum_res - -from .conftest import y_pred, y_true +from pymatviz import cumulative_error, cumulative_residual +from tests.conftest import y_pred, y_true @pytest.mark.parametrize("alpha", [None, 0.5]) -def test_cum_err(alpha: float) -> None: - ax = cum_err(y_pred, y_true, alpha=alpha) +def test_cumulative_error(alpha: float) -> None: + ax = cumulative_error(y_pred, y_true, alpha=alpha) assert isinstance(ax, plt.Axes) -def test_cum_res(): - ax = cum_res(y_pred, y_true) +def test_cumulative_residual(): + ax = cumulative_residual(y_pred, y_true) assert isinstance(ax, plt.Axes) + assert len(ax.lines) == 3 + assert ax.get_xlabel() == "Residual" + assert ax.get_ylabel() == "Percentile" + assert ax.get_title() == "Cumulative Residual" + assert ax.get_ylim() == (0, 100) diff --git a/tests/test_histograms.py b/tests/test_histograms.py index 32546ffe..317d670d 100644 --- a/tests/test_histograms.py +++ b/tests/test_histograms.py @@ -7,18 +7,29 @@ from pymatgen.core import Structure from pymatviz import residual_hist, spacegroup_hist, true_pred_hist - -from .conftest import y_pred, y_true +from tests.conftest import y_pred, y_true @pytest.mark.parametrize("bins", [None, 1, 100]) @pytest.mark.parametrize("xlabel", [None, "foo"]) def test_residual_hist(bins: int | None, xlabel: str | None) -> None: - residual_hist(y_true, y_pred, bins=bins, xlabel=xlabel) + ax = residual_hist(y_true, y_pred, bins=bins, xlabel=xlabel) + + assert isinstance(ax, plt.Axes) + assert ( + ax.get_xlabel() == xlabel or r"Residual ($y_\mathrm{test} - y_\mathrm{pred}$)" + ) + assert len(ax.lines) == 1 + legend = ax.get_legend() + assert len(ax.patches) == bins or 10 + assert legend._get_loc() == 2 # 2 meaning 'upper left' -def test_true_pred_hist(): - true_pred_hist(y_true, y_pred, y_true - y_pred) +@pytest.mark.parametrize("bins", [None, 1, 100]) +@pytest.mark.parametrize("cmap", ["hot", "Blues"]) +def test_true_pred_hist(bins: int | None, cmap: str) -> None: + ax = true_pred_hist(y_true, y_pred, y_true - y_pred, bins=bins, cmap=cmap) + assert isinstance(ax, plt.Axes) @pytest.mark.parametrize("xticks", ["all", "crys_sys_edges", 1, 50])