Skip to content

Commit

Permalink
Added option for log scaled axes. (#174)
Browse files Browse the repository at this point in the history
* Added option for log scaled axes.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed bug with bins in hist2d.

* Added formatter for log scale axes.

* Added tests for log scale.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed default of axes_scale in arviz_corner.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
castillohair and pre-commit-ci[bot] committed Oct 6, 2022
1 parent 0d43ad4 commit 6e61b95
Show file tree
Hide file tree
Showing 13 changed files with 225 additions and 32 deletions.
2 changes: 2 additions & 0 deletions src/corner/arviz_corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def arviz_corner(
*,
# Original corner parameters
range=None,
axes_scale="linear",
weights=None,
color=None,
hist_bin_factor=1,
Expand Down Expand Up @@ -136,6 +137,7 @@ def arviz_corner(
samples,
bins=bins,
range=range,
axes_scale=axes_scale,
weights=weights,
color=color,
hist_bin_factor=hist_bin_factor,
Expand Down
155 changes: 123 additions & 32 deletions src/corner/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@
import numpy as np
from matplotlib import pyplot as pl
from matplotlib.colors import LinearSegmentedColormap, colorConverter
from matplotlib.ticker import MaxNLocator, NullLocator, ScalarFormatter
from matplotlib.ticker import (
LogFormatterMathtext,
LogLocator,
MaxNLocator,
NullFormatter,
NullLocator,
ScalarFormatter,
)

try:
from scipy.ndimage import gaussian_filter
Expand All @@ -27,6 +34,7 @@ def corner_impl(
xs,
bins=20,
range=None,
axes_scale="linear",
weights=None,
color=None,
hist_bin_factor=1,
Expand Down Expand Up @@ -105,6 +113,14 @@ def corner_impl(
plotdim = factor * K + factor * (K - 1.0) * whspace
dim = lbdim + plotdim + trdim

# Make axes_scale into a list if necessary, otherwise check length
if isinstance(axes_scale, str):
axes_scale = [axes_scale] * K
else:
assert (
len(axes_scale) == K
), "'axes_scale' should contain as many elements as data dimensions"

# Create a new figure if one wasn't provided.
new_fig = True
if fig is None:
Expand Down Expand Up @@ -197,23 +213,29 @@ def corner_impl(
ax = axes[i, i]

# Plot the histograms.
if smooth1d is None:
bins_1d = int(max(1, np.round(hist_bin_factor[i] * bins[i])))
n, _, _ = ax.hist(
x,
bins=bins_1d,
weights=weights,
range=np.sort(range[i]),
**hist_kwargs,
n_bins_1d = int(max(1, np.round(hist_bin_factor[i] * bins[i])))
if axes_scale[i] == "linear":
bins_1d = np.linspace(min(range[i]), max(range[i]), n_bins_1d + 1)
elif axes_scale[i] == "log":
bins_1d = np.logspace(
np.log10(min(range[i])), np.log10(max(range[i])), n_bins_1d + 1
)
else:
raise ValueError(
"Scale "
+ axes_scale[i]
+ "for dimension "
+ str(i)
+ "not supported. Use 'linear' or 'log'"
)
if smooth1d is None:
n, _, _ = ax.hist(x, bins=bins_1d, weights=weights, **hist_kwargs)
else:
if gaussian_filter is None:
raise ImportError("Please install scipy for smoothing")
n, b = np.histogram(
x, bins=bins[i], weights=weights, range=np.sort(range[i])
)
n, _ = np.histogram(x, bins=bins_1d, weights=weights)
n = gaussian_filter(n, smooth1d)
x0 = np.array(list(zip(b[:-1], b[1:]))).flatten()
x0 = np.array(list(zip(bins_1d[:-1], bins_1d[1:]))).flatten()
y0 = np.array(list(zip(n, n))).flatten()
ax.plot(x0, y0, **hist_kwargs)

Expand Down Expand Up @@ -264,6 +286,7 @@ def corner_impl(

# Set up the axes.
_set_xlim(new_fig, ax, range[i])
ax.set_xscale(axes_scale[i])
if scale_hist:
maxn = np.max(n)
_set_ylim(new_fig, ax, [-0.1 * maxn, 1.1 * maxn])
Expand All @@ -276,19 +299,27 @@ def corner_impl(
ax.xaxis.set_major_locator(NullLocator())
ax.yaxis.set_major_locator(NullLocator())
else:
ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="lower"))
if axes_scale[i] == "linear":
ax.xaxis.set_major_locator(
MaxNLocator(max_n_ticks, prune="lower")
)
elif axes_scale[i] == "log":
ax.xaxis.set_major_locator(LogLocator(numticks=max_n_ticks))
ax.yaxis.set_major_locator(NullLocator())

if i < K - 1:
if top_ticks:
ax.xaxis.set_ticks_position("top")
[l.set_rotation(45) for l in ax.get_xticklabels()]
[l.set_rotation(45) for l in ax.get_xticklabels(minor=True)]
else:
ax.set_xticklabels([])
ax.set_xticklabels([], minor=True)
else:
if reverse:
ax.xaxis.tick_top()
[lbl.set_rotation(45) for lbl in ax.get_xticklabels()]
[l.set_rotation(45) for l in ax.get_xticklabels()]
[l.set_rotation(45) for l in ax.get_xticklabels(minor=True)]
if labels is not None:
if reverse:
if "labelpad" in label_kwargs.keys():
Expand All @@ -308,9 +339,12 @@ def corner_impl(
ax.xaxis.set_label_coords(0.5, -0.3 - labelpad)

# use MathText for axes ticks
ax.xaxis.set_major_formatter(
ScalarFormatter(useMathText=use_math_text)
)
if axes_scale[i] == "linear":
ax.xaxis.set_major_formatter(
ScalarFormatter(useMathText=use_math_text)
)
elif axes_scale[i] == "log":
ax.xaxis.set_major_formatter(LogFormatterMathtext())

for j, y in enumerate(xs):
if np.shape(xs)[0] == 1:
Expand All @@ -337,6 +371,7 @@ def corner_impl(
x,
ax=ax,
range=[range[j], range[i]],
axes_scale=[axes_scale[j], axes_scale[i]],
weights=weights,
color=color,
smooth=smooth,
Expand All @@ -349,19 +384,32 @@ def corner_impl(
ax.xaxis.set_major_locator(NullLocator())
ax.yaxis.set_major_locator(NullLocator())
else:
ax.xaxis.set_major_locator(
MaxNLocator(max_n_ticks, prune="lower")
)
ax.yaxis.set_major_locator(
MaxNLocator(max_n_ticks, prune="lower")
)
if axes_scale[j] == "linear":
ax.xaxis.set_major_locator(
MaxNLocator(max_n_ticks, prune="lower")
)
elif axes_scale[j] == "log":
ax.xaxis.set_major_locator(
LogLocator(numticks=max_n_ticks)
)

if axes_scale[i] == "linear":
ax.yaxis.set_major_locator(
MaxNLocator(max_n_ticks, prune="lower")
)
elif axes_scale[i] == "log":
ax.yaxis.set_major_locator(
LogLocator(numticks=max_n_ticks)
)

if i < K - 1:
ax.set_xticklabels([])
ax.set_xticklabels([], minor=True)
else:
if reverse:
ax.xaxis.tick_top()
[l.set_rotation(45) for l in ax.get_xticklabels()]
[l.set_rotation(45) for l in ax.get_xticklabels(minor=True)]
if labels is not None:
ax.set_xlabel(labels[j], **label_kwargs)
if reverse:
Expand All @@ -370,16 +418,21 @@ def corner_impl(
ax.xaxis.set_label_coords(0.5, -0.3 - labelpad)

# use MathText for axes ticks
ax.xaxis.set_major_formatter(
ScalarFormatter(useMathText=use_math_text)
)
if axes_scale[j] == "linear":
ax.xaxis.set_major_formatter(
ScalarFormatter(useMathText=use_math_text)
)
elif axes_scale[j] == "log":
ax.xaxis.set_major_formatter(LogFormatterMathtext())

if j > 0:
ax.set_yticklabels([])
ax.set_yticklabels([], minor=True)
else:
if reverse:
ax.yaxis.tick_right()
[l.set_rotation(45) for l in ax.get_yticklabels()]
[l.set_rotation(45) for l in ax.get_yticklabels(minor=True)]
if labels is not None:
if reverse:
ax.set_ylabel(labels[i], rotation=-90, **label_kwargs)
Expand All @@ -389,9 +442,12 @@ def corner_impl(
ax.yaxis.set_label_coords(-0.3 - labelpad, 0.5)

# use MathText for axes ticks
ax.yaxis.set_major_formatter(
ScalarFormatter(useMathText=use_math_text)
)
if axes_scale[i] == "linear":
ax.yaxis.set_major_formatter(
ScalarFormatter(useMathText=use_math_text)
)
elif axes_scale[i] == "log":
ax.yaxis.set_major_formatter(LogFormatterMathtext())

if truths is not None:
overplot_lines(fig, truths, reverse=reverse, color=truth_color)
Expand Down Expand Up @@ -464,6 +520,7 @@ def hist2d(
y,
bins=20,
range=None,
axes_scale=["linear", "linear"],
weights=None,
levels=None,
smooth=None,
Expand Down Expand Up @@ -494,6 +551,9 @@ def hist2d(
y : array_like[nsamples,]
The samples.
axes_scale : iterable (2,)
Scale (``"linear"``, ``"log"``) to use for each dimension.
quiet : bool
If true, suppress warnings for small datasets.
Expand Down Expand Up @@ -577,13 +637,42 @@ def hist2d(
for i, l in enumerate(levels):
contour_cmap[i][-1] *= float(i) / (len(levels) + 1)

# Parse the bin specifications.
try:
bins = [int(bins) for _ in range]
except TypeError:
if len(bins) != len(range):
raise ValueError("Dimension mismatch between bins and range")

# We'll make the 2D histogram to directly estimate the density.
bins_2d = []
if axes_scale[0] == "linear":
bins_2d.append(np.linspace(min(range[0]), max(range[0]), bins[0] + 1))
elif axes_scale[0] == "log":
bins_2d.append(
np.logspace(
np.log10(min(range[0])),
np.log10(max(range[0])),
bins[0] + 1,
)
)

if axes_scale[1] == "linear":
bins_2d.append(np.linspace(min(range[1]), max(range[1]), bins[1] + 1))
elif axes_scale[1] == "log":
bins_2d.append(
np.logspace(
np.log10(min(range[1])),
np.log10(max(range[1])),
bins[1] + 1,
)
)

try:
H, X, Y = np.histogram2d(
x.flatten(),
y.flatten(),
bins=bins,
range=list(map(np.sort, range)),
bins=bins_2d,
weights=weights,
)
except ValueError:
Expand Down Expand Up @@ -705,6 +794,8 @@ def hist2d(

_set_xlim(new_fig, ax, range[0])
_set_ylim(new_fig, ax, range[1])
ax.set_xscale(axes_scale[0])
ax.set_yscale(axes_scale[1])


def overplot_lines(fig, xs, reverse=False, **kwargs):
Expand Down
7 changes: 7 additions & 0 deletions src/corner/corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def corner(
*,
# Original corner parameters
range=None,
axes_scale="linear",
weights=None,
color=None,
hist_bin_factor=1,
Expand Down Expand Up @@ -156,6 +157,10 @@ def corner(
[(0.,10.), (1.,5), 0.999, etc.].
If a fraction, the bounds are chosen to be equal-tailed.
axes_scale : str or iterable (ndim,)
Scale (``"linear"``, ``"log"``) to use for each data dimension. If only
one scale is specified, use that for all dimensions.
truths : iterable (ndim,)
A list of reference values to indicate on the plots. Individual
values can be omitted by using ``None``.
Expand Down Expand Up @@ -238,6 +243,7 @@ def corner(
data,
bins=bins,
range=range,
axes_scale=axes_scale,
weights=weights,
color=color,
hist_bin_factor=hist_bin_factor,
Expand Down Expand Up @@ -269,6 +275,7 @@ def corner(
data,
bins=bins,
range=range,
axes_scale=axes_scale,
weights=weights,
color=color,
hist_bin_factor=hist_bin_factor,
Expand Down
Binary file added tests/baseline_images/test_corner/basic_log.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline_images/test_corner/reverse_log.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline_images/test_corner/smooth2_log.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 6e61b95

Please sign in to comment.