Skip to content

Commit

Permalink
Added a reverse option to overplot_* (#156)
Browse files Browse the repository at this point in the history
* Added a reverse option to overplot_*

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added tests for new overplotting functions (+fixed a bug)

* Extended images for test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* More reverse options

* Added reverse truths and a test for them

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Bugfix

* bugfix

* fixing baseline images

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Dan F-M <foreman.mackey@gmail.com>
  • Loading branch information
3 people committed Mar 18, 2021
1 parent 6dcfb7f commit ce1f6f6
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 14 deletions.
7 changes: 6 additions & 1 deletion src/corner/arviz_corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ def arviz_corner(
diverging_mask = np.squeeze(diverging_mask)
if divergences_kwargs is None:
divergences_kwargs = {"color": "C1", "ms": 1}
overplot_points(fig, samples[diverging_mask], **divergences_kwargs)
overplot_points(
fig,
samples[diverging_mask],
reverse=reverse,
**divergences_kwargs,
)

return fig
52 changes: 39 additions & 13 deletions src/corner/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,11 @@ def corner_impl(
)

if truths is not None:
overplot_lines(fig, truths, color=truth_color)
overplot_lines(fig, truths, reverse=reverse, color=truth_color)
overplot_points(
fig,
[[np.nan if t is None else t for t in truths]],
reverse=reverse,
marker="s",
color=truth_color,
)
Expand Down Expand Up @@ -683,7 +684,7 @@ def hist2d(
_set_ylim(new_fig, ax, range[1])


def overplot_lines(fig, xs, **kwargs):
def overplot_lines(fig, xs, reverse=False, **kwargs):
"""
Overplot lines on a figure generated by ``corner.corner``
Expand All @@ -698,24 +699,39 @@ def overplot_lines(fig, xs, **kwargs):
call that originally generated the figure. The entries can optionally
be ``None`` to omit the line in that axis.
reverse: bool
A boolean flag that should be set to 'True' if the corner plot itself
was plotted with 'reverse=True'.
**kwargs
Any remaining keyword arguments are passed to the ``ax.axvline``
method.
"""
K = len(xs)
axes = _get_fig_axes(fig, K)
for k1 in range(K):
if xs[k1] is not None:
axes[k1, k1].axvline(xs[k1], **kwargs)
for k2 in range(k1 + 1, K):
if reverse:
for k1 in range(K):
if xs[k1] is not None:
axes[K - k1 - 1, K - k1 - 1].axvline(xs[k1], **kwargs)
for k2 in range(k1 + 1, K):
if xs[k1] is not None:
axes[K - k2 - 1, K - k1 - 1].axvline(xs[k1], **kwargs)
if xs[k2] is not None:
axes[K - k2 - 1, K - k1 - 1].axhline(xs[k2], **kwargs)

else:
for k1 in range(K):
if xs[k1] is not None:
axes[k2, k1].axvline(xs[k1], **kwargs)
if xs[k2] is not None:
axes[k2, k1].axhline(xs[k2], **kwargs)
axes[k1, k1].axvline(xs[k1], **kwargs)
for k2 in range(k1 + 1, K):
if xs[k1] is not None:
axes[k2, k1].axvline(xs[k1], **kwargs)
if xs[k2] is not None:
axes[k2, k1].axhline(xs[k2], **kwargs)


def overplot_points(fig, xs, **kwargs):
def overplot_points(fig, xs, reverse=False, **kwargs):
"""
Overplot points on a figure generated by ``corner.corner``
Expand All @@ -729,6 +745,10 @@ def overplot_points(fig, xs, **kwargs):
that is compatible with the :func:`corner.corner` call that originally
generated the figure.
reverse: bool
A boolean flag that should be set to 'True' if the corner plot itself
was plotted with 'reverse=True'.
**kwargs
Any remaining keyword arguments are passed to the ``ax.plot``
method.
Expand All @@ -739,9 +759,15 @@ def overplot_points(fig, xs, **kwargs):
xs = _parse_input(xs)
K = len(xs)
axes = _get_fig_axes(fig, K)
for k1 in range(K):
for k2 in range(k1 + 1, K):
axes[k2, k1].plot(xs[k1], xs[k2], **kwargs)
if reverse:
for k1 in range(K):
for k2 in range(k1):
axes[K - k1 - 1, K - k2 - 1].plot(xs[k2], xs[k1], **kwargs)

else:
for k1 in range(K):
for k2 in range(k1 + 1, K):
axes[k2, k1].plot(xs[k1], xs[k2], **kwargs)


def _parse_input(xs):
Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
83 changes: 83 additions & 0 deletions tests/test_corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ def test_truths():
_run_corner(truths=[0.0, None, 0.15])


@image_comparison(
baseline_images=["reverse_truths"], remove_text=True, extensions=["png"]
)
def test_reverse_truths():
_run_corner(truths=[0.0, None, 0.15], reverse=True)


@image_comparison(
baseline_images=["no_fill_contours"], remove_text=True, extensions=["png"]
)
Expand All @@ -179,6 +186,82 @@ def test_reverse():
_run_corner(ndim=2, range=[(4, -4), (-5, 5)])


@image_comparison(
baseline_images=["extended_overplotting"],
remove_text=True,
extensions=["png"],
)
def test_extended_overplotting():
# Test overplotting a more complex plot
labels = [r"$\theta_1$", r"$\theta_2$", r"$\theta_3$", r"$\theta_4$"]

figure = _run_corner(ndim=4, reverse=False, labels=labels)

# Set same results:
ndim, nsamples = 4, 10000
np.random.seed(1234)

data1 = np.random.randn(ndim * 4 * nsamples // 5).reshape(
[4 * nsamples // 5, ndim]
)
mean = 4 * np.random.rand(ndim)
data2 = mean[None, :] + np.random.randn(ndim * nsamples // 5).reshape(
[nsamples // 5, ndim]
)
samples = np.vstack([data1, data2])

value1 = mean
# This is the empirical mean of the sample:
value2 = np.mean(data1, axis=0)

corner.overplot_lines(figure, value1, color="C1", reverse=False)
corner.overplot_points(
figure, value1[None], marker="s", color="C1", reverse=False
)
corner.overplot_lines(figure, value2, color="C2", reverse=False)
corner.overplot_points(
figure, value2[None], marker="s", color="C2", reverse=False
)


@image_comparison(
baseline_images=["reverse_overplotting"],
remove_text=True,
extensions=["png"],
)
def test_reverse_overplotting():
# Test overplotting with a reversed plot
labels = [r"$\theta_1$", r"$\theta_2$", r"$\theta_3$", r"$\theta_4$"]

figure = _run_corner(ndim=4, reverse=True, labels=labels)

# Set same results:
ndim, nsamples = 4, 10000
np.random.seed(1234)

data1 = np.random.randn(ndim * 4 * nsamples // 5).reshape(
[4 * nsamples // 5, ndim]
)
mean = 4 * np.random.rand(ndim)
data2 = mean[None, :] + np.random.randn(ndim * nsamples // 5).reshape(
[nsamples // 5, ndim]
)
samples = np.vstack([data1, data2])

value1 = mean
# This is the empirical mean of the sample:
value2 = np.mean(data1, axis=0)

corner.overplot_lines(figure, value1, color="C1", reverse=True)
corner.overplot_points(
figure, value1[None], marker="s", color="C1", reverse=True
)
corner.overplot_lines(figure, value2, color="C2", reverse=True)
corner.overplot_points(
figure, value2[None], marker="s", color="C2", reverse=True
)


@image_comparison(
baseline_images=["hist_bin_factor"], remove_text=True, extensions=["png"]
)
Expand Down

0 comments on commit ce1f6f6

Please sign in to comment.