Skip to content

Commit

Permalink
scatter_density() use x, y args as axis labels if strings
Browse files Browse the repository at this point in the history
update examples/mp_bimodal_e_form.ipynb with correct entry counts
(MP API was at one point returning all entries twice)
upgrade pre-commit hooks and apply new black format
  • Loading branch information
janosh committed Feb 5, 2023
1 parent c550332 commit 0f2386a
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 66 deletions.
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/PyCQA/isort
rev: 5.11.4
rev: 5.12.0
hooks:
- id: isort

- repo: https://github.com/psf/black
rev: 22.12.0
rev: 23.1.0
hooks:
- id: black-jupyter

Expand Down Expand Up @@ -59,23 +59,23 @@ repos:
exclude_types: [csv, svg, html, yaml, jupyter]

- repo: https://github.com/PyCQA/autoflake
rev: v2.0.0
rev: v2.0.1
hooks:
- id: autoflake

- repo: https://github.com/PyCQA/pydocstyle
rev: 6.1.1
rev: 6.3.0
hooks:
- id: pydocstyle
exclude: tests

- repo: https://github.com/PyCQA/docformatter
rev: v1.5.1
rev: v1.6.0.rc1
hooks:
- id: docformatter

- repo: https://github.com/nbQA-dev/nbQA
rev: 1.6.0
rev: 1.6.1
hooks:
- id: nbqa-pyupgrade
args: [--py38-plus]
Expand All @@ -90,7 +90,7 @@ repos:
args: [--drop-empty-cells, --keep-output]

- repo: https://github.com/jendrikseipp/vulture
rev: v2.6
rev: v2.7
hooks:
- id: vulture
args: [., assets/vulture_whitelist]
204 changes: 166 additions & 38 deletions examples/mp_bimodal_e_form.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion pymatviz/histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def true_pred_hist(
)

for xmin, xmax, rect in zip(bin_edges, bin_edges[1:], bars.patches):

y_preds_in_rect = np.logical_and(y_pred > xmin, y_pred < xmax).nonzero()

color_value = y_std[y_preds_in_rect].mean()
Expand Down
9 changes: 7 additions & 2 deletions pymatviz/parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def density_scatter(
sort: bool = True,
log_cmap: bool = True,
density_bins: int = 100,
xlabel: str = "Actual",
ylabel: str = "Predicted",
xlabel: str = None,
ylabel: str = None,
identity: bool = True,
stats: bool | dict[str, Any] = True,
**kwargs: Any,
Expand Down Expand Up @@ -90,6 +90,11 @@ def density_scatter(
Returns:
ax: The plot's matplotlib Axes.
"""
if xlabel is None:
xlabel = getattr(x, "name", x if isinstance(x, str) else "Actual")
if ylabel is None:
ylabel = getattr(y, "name", y if isinstance(y, str) else "Predicted")

x, y = df_to_arrays(df, x, y)
ax = ax or plt.gca()

Expand Down
2 changes: 0 additions & 2 deletions pymatviz/ptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ def ptable_heatmap(
text_style = dict(horizontalalignment="center", fontsize=16, fontweight="semibold")

for symbol, row, column, *_ in df_ptable.itertuples():

row = n_rows - row # makes periodic table right side up
heat_val = elem_values.get(symbol)

Expand Down Expand Up @@ -272,7 +271,6 @@ def ptable_heatmap(
ax.add_patch(rect)

if heat_mode is not None:

# colorbar position and size: [x, y, width, height]
# anchored at lower left corner
cb_ax = ax.inset_axes([0.18, 0.8, 0.42, 0.05], transform=ax.transAxes)
Expand Down
1 change: 0 additions & 1 deletion pymatviz/uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def qq_gaussian(

lines = [] # collect plotted lines to show second legend with miscalibration areas
for key, std in y_std.items():

z_scored = (np.array(res) / std).reshape(-1, 1)

exp_proportions = np.linspace(0, 1, resolution)
Expand Down
11 changes: 9 additions & 2 deletions pymatviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def annotate_bars(
y_max = 0

for rect, label in zip(ax.patches, labels):

y_pos = rect.get_height()
x_pos = rect.get_x() + rect.get_width() / 2 + h_offset

Expand Down Expand Up @@ -290,7 +289,15 @@ def save_fig(
if path.lower().endswith((".svelte", ".html")):
config = dict(
showTips=False,
modeBarButtonsToRemove=["lasso2d", "select2d", "autoScale2d", "toImage"],
modeBarButtonsToRemove=[
"lasso2d",
"select2d",
"autoScale2d",
"toImage",
"toggleSpikelines",
"hoverClosestCartesian",
"hoverCompareCartesian",
],
responsive=True,
displaylogo=False,
)
Expand Down
24 changes: 12 additions & 12 deletions site/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
"make-api-docs": "python ../assets/make_api_docs.py"
},
"devDependencies": {
"@iconify/svelte": "^3.0.1",
"@sveltejs/adapter-static": "^1.0.4",
"@sveltejs/kit": "^1.1.1",
"@iconify/svelte": "^3.1.0",
"@sveltejs/adapter-static": "^1.0.5",
"@sveltejs/kit": "^1.3.9",
"@sveltejs/vite-plugin-svelte": "^2.0.2",
"@typescript-eslint/eslint-plugin": "^5.48.2",
"@typescript-eslint/parser": "^5.48.2",
"eslint": "^8.32.0",
"@typescript-eslint/eslint-plugin": "^5.50.0",
"@typescript-eslint/parser": "^5.50.0",
"eslint": "^8.33.0",
"eslint-plugin-svelte3": "^4.0.0",
"hastscript": "^7.2.0",
"highlight.js": "^11.7.0",
Expand All @@ -31,13 +31,13 @@
"rehype-autolink-headings": "^6.1.1",
"rehype-slug": "^5.1.0",
"svelte": "^3.55.1",
"svelte-check": "^3.0.2",
"svelte-preprocess": "^5.0.0",
"svelte-check": "^3.0.3",
"svelte-preprocess": "^5.0.1",
"svelte-toc": "^0.5.2",
"svelte-zoo": "^0.2.2",
"svelte2tsx": "^0.6.0",
"tslib": "^2.4.1",
"typescript": "^4.9.4",
"svelte-zoo": "^0.2.4",
"svelte2tsx": "^0.6.1",
"tslib": "^2.5.0",
"typescript": "^4.9.5",
"vite": "^4.0.4"
},
"prettier": {
Expand Down
16 changes: 15 additions & 1 deletion tests/test_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -34,9 +36,21 @@ def test_density_scatter(
cmap: str | None,
stats: bool | dict[str, Any],
) -> None:
density_scatter(
ax = density_scatter(
df=df, x=x, y=y, log_cmap=log_cmap, sort=sort, cmap=cmap, stats=stats
)
assert isinstance(ax, plt.Axes)
assert ax.get_xlabel() == x if isinstance(x, str) else "Actual"
assert ax.get_ylabel() == y if isinstance(y, str) else "Predicted"


def test_density_scatter_uses_series_name_as_label() -> None:
x = pd.Series(np.random.rand(5), name="x")
y = pd.Series(np.random.rand(5), name="y")
ax = density_scatter(x=x, y=y)

assert ax.get_xlabel() == "x"
assert ax.get_ylabel() == "y"


@pytest.mark.parametrize("df, x, y", df_x_y)
Expand Down

0 comments on commit 0f2386a

Please sign in to comment.