Skip to content

Commit

Permalink
Merge pull request #138 from dfm/overplot-points
Browse files Browse the repository at this point in the history
Adding functions to overplot points on the figures
  • Loading branch information
dfm committed Mar 5, 2021
2 parents 7174b5f + b3979cc commit 35be332
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 48 deletions.
2 changes: 2 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ documented here.
.. autofunction:: corner.corner
.. autofunction:: corner.hist2d
.. autofunction:: corner.quantile
.. autofunction:: corner.overplot_lines
.. autofunction:: corner.overplot_points
4 changes: 2 additions & 2 deletions src/corner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-

__all__ = ["corner", "hist2d", "quantile"]
__all__ = ["corner", "hist2d", "quantile", "overplot_lines", "overplot_points"]

from .corner import corner, hist2d, quantile
from .corner import corner, hist2d, overplot_lines, overplot_points, quantile
from .corner_version import __version__ # noqa

__author__ = "Dan Foreman-Mackey"
Expand Down
170 changes: 124 additions & 46 deletions src/corner/corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
except ImportError:
gaussian_filter = None

__all__ = ["corner", "hist2d", "quantile"]
__all__ = ["corner", "hist2d", "quantile", "overplot_lines", "overplot_points"]


def corner(
Expand Down Expand Up @@ -161,8 +161,8 @@ def corner(
Any extra keyword arguments to send to the 1-D histogram plots.
**hist2d_kwargs
Any remaining keyword arguments are sent to `corner.hist2d` to generate
the 2-D histogram plots.
Any remaining keyword arguments are sent to :func:`corner.hist2d` to
generate the 2-D histogram plots.
"""
if quantiles is None:
Expand All @@ -184,12 +184,7 @@ def corner(
titles = labels

# Deal with 1D sample lists.
xs = np.atleast_1d(xs)
if len(xs.shape) == 1:
xs = np.atleast_2d(xs)
else:
assert len(xs.shape) == 2, "The input sample array must be 1- or 2-D."
xs = xs.T
xs = _parse_input(xs)
assert xs.shape[0] <= xs.shape[1], (
"I don't believe that you want more " "dimensions than samples!"
)
Expand Down Expand Up @@ -220,19 +215,8 @@ def corner(
if fig is None:
fig, axes = pl.subplots(K, K, figsize=(dim, dim))
else:
if not fig.axes:
axes = fig.subplots(K, K)
else:
new_fig = False
try:
axes = np.array(fig.axes).reshape((K, K))
except ValueError:
raise ValueError(
(
"Provided figure has {0} axes, but data has "
"dimensions K={1}"
).format(len(fig.axes), K)
)
new_fig = False
axes = _get_fig_axes(fig, K)

# Format the figure.
lb = lbdim / dim
Expand Down Expand Up @@ -333,8 +317,8 @@ def corner(
y0 = np.array(list(zip(n, n))).flatten()
ax.plot(x0, y0, **hist_kwargs)

if truths is not None and truths[i] is not None:
ax.axvline(truths[i], color=truth_color)
# if truths is not None and truths[i] is not None:
# ax.axvline(truths[i], color=truth_color)

# Plot quantiles if wanted.
if len(quantiles) > 0:
Expand Down Expand Up @@ -407,7 +391,7 @@ def corner(
else:
if reverse:
ax.xaxis.tick_top()
[l.set_rotation(45) for l in ax.get_xticklabels()]
[lbl.set_rotation(45) for lbl in ax.get_xticklabels()]
if labels is not None:
if reverse:
if "labelpad" in label_kwargs.keys():
Expand Down Expand Up @@ -464,13 +448,13 @@ def corner(
**hist2d_kwargs
)

if truths is not None:
if truths[i] is not None and truths[j] is not None:
ax.plot(truths[j], truths[i], "s", color=truth_color)
if truths[j] is not None:
ax.axvline(truths[j], color=truth_color)
if truths[i] is not None:
ax.axhline(truths[i], color=truth_color)
# if truths is not None:
# if truths[i] is not None and truths[j] is not None:
# ax.plot(truths[j], truths[i], "s", color=truth_color)
# if truths[j] is not None:
# ax.axvline(truths[j], color=truth_color)
# if truths[i] is not None:
# ax.axhline(truths[i], color=truth_color)

if max_n_ticks == 0:
ax.xaxis.set_major_locator(NullLocator())
Expand Down Expand Up @@ -520,21 +504,16 @@ def corner(
ScalarFormatter(useMathText=use_math_text)
)

return fig


def _set_xlim(new_fig, ax, new_xlim):
if new_fig:
return ax.set_xlim(new_xlim)
xlim = ax.get_xlim()
return ax.set_xlim([min(xlim[0], new_xlim[0]), max(xlim[1], new_xlim[1])])

if truths is not None:
overplot_lines(fig, truths, color=truth_color)
overplot_points(
fig,
[[np.nan if t is None else t for t in truths]],
marker="s",
color=truth_color,
)

def _set_ylim(new_fig, ax, new_ylim):
if new_fig:
return ax.set_ylim(new_ylim)
ylim = ax.get_ylim()
return ax.set_ylim([min(ylim[0], new_ylim[0]), max(ylim[1], new_ylim[1])])
return fig


def quantile(x, q, weights=None):
Expand Down Expand Up @@ -828,3 +807,102 @@ def hist2d(

_set_xlim(new_fig, ax, range[0])
_set_ylim(new_fig, ax, range[1])


def overplot_lines(fig, xs, **kwargs):
"""
Overplot lines on a figure generated by ``corner.corner``
Parameters
----------
fig : Figure
The figure generated by a call to :func:`corner.corner`.
xs : array_like[ndim]
The values where the lines should be plotted. This must have ``ndim``
entries, where ``ndim`` is compatible with the :func:`corner.corner`
call that originally generated the figure. The entries can optionally
be ``None`` to omit the line in that axis.
**kwargs
Any remaining keyword arguments are passed to the ``ax.axvline``
method.
"""
K = len(xs)
axes = _get_fig_axes(fig, K)
for k1 in range(K):
if xs[k1] is not None:
axes[k1, k1].axvline(xs[k1], **kwargs)
for k2 in range(k1 + 1, K):
if xs[k1] is not None:
axes[k2, k1].axvline(xs[k1], **kwargs)
if xs[k2] is not None:
axes[k2, k1].axhline(xs[k2], **kwargs)


def overplot_points(fig, xs, **kwargs):
"""
Overplot points on a figure generated by ``corner.corner``
Parameters
----------
fig : Figure
The figure generated by a call to :func:`corner.corner`.
xs : array_like[nsamples, ndim]
The coordinates of the points to be plotted. This must have an ``ndim``
that is compatible with the :func:`corner.corner` call that originally
generated the figure.
**kwargs
Any remaining keyword arguments are passed to the ``ax.plot``
method.
"""
kwargs["marker"] = kwargs.pop("marker", ".")
kwargs["linestyle"] = kwargs.pop("linestyle", "none")
xs = _parse_input(xs)
K = len(xs)
axes = _get_fig_axes(fig, K)
for k1 in range(K):
for k2 in range(k1 + 1, K):
axes[k2, k1].plot(xs[k1], xs[k2], **kwargs)


def _parse_input(xs):
xs = np.atleast_1d(xs)
if len(xs.shape) == 1:
xs = np.atleast_2d(xs)
else:
assert len(xs.shape) == 2, "The input sample array must be 1- or 2-D."
xs = xs.T
return xs


def _get_fig_axes(fig, K):
if not fig.axes:
return fig.subplots(K, K)
try:
return np.array(fig.axes).reshape((K, K))
except ValueError:
raise ValueError(
(
"Provided figure has {0} axes, but data has "
"dimensions K={1}"
).format(len(fig.axes), K)
)


def _set_xlim(new_fig, ax, new_xlim):
if new_fig:
return ax.set_xlim(new_xlim)
xlim = ax.get_xlim()
return ax.set_xlim([min(xlim[0], new_xlim[0]), max(xlim[1], new_xlim[1])])


def _set_ylim(new_fig, ax, new_ylim):
if new_fig:
return ax.set_ylim(new_ylim)
ylim = ax.get_ylim()
return ax.set_ylim([min(ylim[0], new_ylim[0]), max(ylim[1], new_ylim[1])])
Binary file modified tests/baseline_images/test_corner/truths.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 35be332

Please sign in to comment.