Skip to content

Commit

Permalink
breaking: absorb keywords sort and density_bins into hist_density_kwargs
Browse files Browse the repository at this point in the history
add color_bar (bool | dict, optional): Whether to add a color bar. Defaults to True. If dict, unpacked into ax.figure.colorbar(). E.g. dict(label="Density").
  • Loading branch information
janosh committed Dec 26, 2023
1 parent 36f4771 commit 64da6b2
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 48 deletions.
2 changes: 1 addition & 1 deletion pymatviz/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def marchenko_pastur(
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
Returns:
ax: The plot's matplotlib Axes.
plt.Axes: matplotlib Axes object
"""
ax = ax or plt.gca()

Expand Down
4 changes: 2 additions & 2 deletions pymatviz/cumulative.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def cumulative_residual(
**kwargs: Additional keyword arguments passed to ax.fill_between().
Returns:
ax: The plot's matplotlib Axes.
plt.Axes: matplotlib Axes object
"""
ax = ax or plt.gca()

Expand Down Expand Up @@ -75,7 +75,7 @@ def cumulative_error(
**kwargs: Additional keyword arguments passed to ax.plot().
Returns:
ax: The plot's matplotlib Axes.
plt.Axes: matplotlib Axes object
"""
ax = ax or plt.gca()

Expand Down
6 changes: 3 additions & 3 deletions pymatviz/histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def residual_hist(
**kwargs: Additional keyword arguments to pass to matplotlib.Axes.
Returns:
ax: The plot's matplotlib Axes.
plt.Axes: matplotlib Axes object
"""
ax = ax or plt.gca()

Expand Down Expand Up @@ -103,7 +103,7 @@ def true_pred_hist(
**kwargs: Additional keyword arguments to pass to ax.hist().
Returns:
ax: The plot's matplotlib Axes.
plt.Axes: matplotlib Axes object
"""
y_true, y_pred, y_std = df_to_arrays(df, y_true, y_pred, y_std)
y_true, y_pred, y_std = np.array([y_true, y_pred, y_std])
Expand Down Expand Up @@ -407,7 +407,7 @@ def hist_elemental_prevalence(
**kwargs (int): Keyword arguments passed to pandas.Series.plot.bar().
Returns:
ax: The plot's matplotlib Axes.
plt.Axes: matplotlib Axes object
"""
ax = ax or plt.gca()

Expand Down
67 changes: 37 additions & 30 deletions pymatviz/parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,36 +37,35 @@ def hist_density(
See scipy.interpolate.interpn() for options.
Returns:
tuple[array, array]: x and y values (sorted by density) and density itself
tuple[np.array, np.array, np.array]: x and y values (sorted by density) and
density itself
"""
x, y = df_to_arrays(df, x, y)
xs, ys = df_to_arrays(df, x, y)

data, x_e, y_e = np.histogram2d(x, y, bins=bins)
counts, x_bins, y_bins = np.histogram2d(xs, ys, bins=bins)

# get bin centers
points = 0.5 * (x_bins[1:] + x_bins[:-1]), 0.5 * (y_bins[1:] + y_bins[:-1])
zs = scipy.interpolate.interpn(
(0.5 * (x_e[1:] + x_e[:-1]), 0.5 * (y_e[1:] + y_e[:-1])),
data,
np.vstack([x, y]).T,
method=method,
bounds_error=False,
points, counts, np.vstack([xs, ys]).T, method=method, bounds_error=False
)

# Sort the points by density, so that the densest points are plotted last
# sort points by density, so that the densest points are plotted last
if sort:
idx = zs.argsort()
x, y, zs = x[idx], y[idx], zs[idx]
sort_idx = zs.argsort()
xs, ys, zs = xs[sort_idx], ys[sort_idx], zs[sort_idx]

return x, y, zs
return xs, ys, zs


def density_scatter(
x: ArrayLike | str,
y: ArrayLike | str,
df: pd.DataFrame | None = None,
ax: plt.Axes | None = None,
sort: bool = True,
log_cmap: bool = True,
density_bins: int = 100,
log_density: bool = True,
hist_density_kwargs: dict[str, Any] | None = None,
color_bar: bool | dict[str, Any] = True,
xlabel: str | None = None,
ylabel: str | None = None,
identity: bool = True,
Expand All @@ -81,21 +80,25 @@ def density_scatter(
df (pd.DataFrame, optional): DataFrame with x and y columns. Defaults to None.
ax (Axes, optional): matplotlib Axes on which to plot. Defaults to None.
sort (bool, optional): Whether to sort the data. Defaults to True.
log_cmap (bool, optional): Whether to log the color scale. Defaults to True.
density_bins (int, optional): How many density_bins to use for the density
histogram, i.e. granularity of the density color scale. Defaults to 100.
log_density (bool, optional): Whether to log the density color scale.
Defaults to True.
hist_density_kwargs (dict, optional): Passed to hist_density(). Use to change
sort (by density, default True), bins (default 100), or method (for
interpolation, default "nearest").
color_bar (bool | dict, optional): Whether to add a color bar. Defaults to True.
If dict, unpacked into ax.figure.colorbar(). E.g. dict(label="Density").
xlabel (str, optional): x-axis label. Defaults to "Actual".
ylabel (str, optional): y-axis label. Defaults to "Predicted".
identity (bool, optional): Whether to add an identity/parity line (y = x).
Defaults to True.
stats (bool | dict[str, Any], optional): Whether to display a text box with MAE
and R^2. Defaults to True. Can be dict to pass kwargs to annotate_metrics().
E.g. stats=dict(loc="upper left", prefix="Title", prop=dict(fontsize=16)).
**kwargs: Additional keyword arguments to pass to ax.scatter(). E.g. cmap to
change the color map.
**kwargs: Passed to ax.scatter(). Defaults to dict(s=6) to control marker size.
Other common keys are cmap, vmin, vmax, alpha, edgecolors, linewidths.
Returns:
ax: The plot's matplotlib Axes.
plt.Axes: matplotlib Axes object
"""
if not isinstance(stats, (bool, dict)):
raise TypeError(f"stats must be bool or dict, got {type(stats)} instead.")
Expand All @@ -104,23 +107,27 @@ def density_scatter(
if ylabel is None:
ylabel = getattr(y, "name", y if isinstance(y, str) else "Predicted")

x, y = df_to_arrays(df, x, y)
xs, ys = df_to_arrays(df, x, y)
ax = ax or plt.gca()

x, y, cs = hist_density(x, y, sort=sort, bins=density_bins)

norm = mpl.colors.LogNorm() if log_cmap else None
xs, ys, cs = hist_density(xs, ys, **(hist_density_kwargs or {}))

ax.scatter(x, y, c=cs, norm=norm, **kwargs)
# decrease marker size
defaults = dict(s=6, norm=mpl.colors.LogNorm() if log_density else None)
ax.scatter(xs, ys, c=cs, **defaults | kwargs)

if identity:
add_identity_line(ax)

if stats:
annotate_metrics(x, y, fig=ax, **(stats if isinstance(stats, dict) else {}))
annotate_metrics(xs, ys, fig=ax, **(stats if isinstance(stats, dict) else {}))

ax.set(xlabel=xlabel, ylabel=ylabel)

if color_bar:
kwds = dict(label="Density") if color_bar is True else color_bar
color_bar = ax.figure.colorbar(ax.collections[0], **kwds)

return ax


Expand Down Expand Up @@ -153,7 +160,7 @@ def scatter_with_err_bar(
**kwargs: Additional keyword arguments to pass to ax.errorbar().
Returns:
ax: The plot's matplotlib Axes.
plt.Axes: matplotlib Axes object
"""
x, y = df_to_arrays(df, x, y)
ax = ax or plt.gca()
Expand Down Expand Up @@ -203,7 +210,7 @@ def density_hexbin(
**kwargs: Additional keyword arguments to pass to ax.hexbin().
Returns:
ax: The plot's matplotlib Axes.
plt.Axes: matplotlib Axes object
"""
x, y = df_to_arrays(df, x, y)
ax = ax or plt.gca()
Expand Down Expand Up @@ -282,7 +289,7 @@ def residual_vs_actual(
**kwargs: Additional keyword arguments passed to plt.plot()
Returns:
ax: The plot's matplotlib Axes.
plt.Axes: matplotlib Axes object
"""
y_true, y_pred = df_to_arrays(df, y_true, y_pred)
assert isinstance(y_true, np.ndarray) # noqa: S101
Expand Down
4 changes: 2 additions & 2 deletions pymatviz/ptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def ptable_heatmap(
**kwargs: Additional keyword arguments passed to plt.figure().
Returns:
ax: matplotlib Axes with the heatmap.
plt.Axes: matplotlib Axes with the heatmap.
"""
if fmt is None:
fmt = lambda x, _: si_fmt(x, ".1%" if heat_mode == "percent" else ".0f")
Expand Down Expand Up @@ -455,7 +455,7 @@ def ptable_heatmap_ratio(
**kwargs: Additional keyword arguments passed to ptable_heatmap().
Returns:
ax: The plot's matplotlib Axes.
plt.Axes: matplotlib Axes object
"""
values_num = count_elements(values_num, count_mode)

Expand Down
4 changes: 2 additions & 2 deletions pymatviz/uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def qq_gaussian(
ax (Axes): matplotlib Axes on which to plot. Defaults to None.
Returns:
ax: The plot's matplotlib Axes.
plt.Axes: matplotlib Axes object
"""
if isinstance(y_std, (str, pd.Index)):
y_true, y_pred, y_std = df_to_arrays(df, y_true, y_pred, y_std)
Expand Down Expand Up @@ -222,7 +222,7 @@ def error_decay_with_uncert(
ax.get_ylim()[1]]).
Returns:
ax: matplotlib Axes object with plotted model error drop curve based on
plt.Axes: matplotlib Axes object with plotted model error drop curve based on
excluding data points by order of large to small model uncertainties.
"""
if isinstance(y_std, (str, pd.Index)):
Expand Down
2 changes: 1 addition & 1 deletion pymatviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def with_hist(
bins (int, optional): Resolution/bin count of the histograms. Defaults to 100.
Returns:
ax: The matplotlib Axes to be used for the main plot.
plt.Axes: The matplotlib Axes to be used for the main plot.
"""
fig = plt.gcf()

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,5 +116,5 @@ isort.split-on-trailing-comma = false
[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
"tests/*" = ["D103", "S101"]
"examples/*" = ["E402", "INP001", "T201"]
"examples/*" = ["INP001", "T201"] # T201: print found
"site/*" = ["INP001", "S602"]
18 changes: 12 additions & 6 deletions tests/test_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,29 @@
from tests.conftest import DfOrArrays


@pytest.mark.parametrize("log_cmap", [True, False])
@pytest.mark.parametrize("sort", [True, False])
@pytest.mark.parametrize("log_density", [True, False])
@pytest.mark.parametrize("hist_density_kwargs", [None, {}, dict(bins=20, sort=True)])
@pytest.mark.parametrize("cmap", [None, "Greens"])
@pytest.mark.parametrize(
"stats",
[False, True, dict(prefix="test", loc="lower right", prop=dict(fontsize=10))],
)
def test_density_scatter(
df_or_arrays: DfOrArrays,
log_cmap: bool,
sort: bool,
log_density: bool,
hist_density_kwargs: dict[str, int | bool | str] | None,
cmap: str | None,
stats: bool | dict[str, Any],
) -> None:
df, x, y = df_or_arrays
ax = density_scatter(
df=df, x=x, y=y, log_cmap=log_cmap, sort=sort, cmap=cmap, stats=stats
df=df,
x=x,
y=y,
log_density=log_density,
hist_density_kwargs=hist_density_kwargs,
cmap=cmap,
stats=stats,
)
assert isinstance(ax, plt.Axes)
assert ax.get_xlabel() == x if isinstance(x, str) else "Actual"
Expand All @@ -56,7 +62,7 @@ def test_density_scatter_raises_on_bad_stats_type(stats: Any) -> None:
def test_density_scatter_uses_series_name_as_label() -> None:
x = pd.Series(np.random.rand(5), name="x")
y = pd.Series(np.random.rand(5), name="y")
ax = density_scatter(x=x, y=y)
ax = density_scatter(x=x, y=y, log_density=False)

assert ax.get_xlabel() == "x"
assert ax.get_ylabel() == "y"
Expand Down

0 comments on commit 64da6b2

Please sign in to comment.