Skip to content

Commit

Permalink
Fixing handling of range arugment when empty figure is provided (#224)
Browse files Browse the repository at this point in the history
* fixing #223

* removing hanging comments

* updating test
  • Loading branch information
dfm committed Mar 27, 2023
1 parent 9ea2436 commit e9f5396
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 46 deletions.
5 changes: 4 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
@nox.session
def tests(session):
session.install("-e", ".[test]")
session.run("pytest", "-v", "tests")
if session.posargs:
session.run("pytest", *session.posargs)
else:
session.run("pytest", "-v", "tests")


@nox.session
Expand Down
35 changes: 19 additions & 16 deletions src/corner/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
LogFormatterMathtext,
LogLocator,
MaxNLocator,
NullFormatter,
NullLocator,
ScalarFormatter,
)
Expand Down Expand Up @@ -126,8 +125,7 @@ def corner_impl(
if fig is None:
fig, axes = pl.subplots(K, K, figsize=(dim, dim))
else:
new_fig = False
axes = _get_fig_axes(fig, K)
axes, new_fig = _get_fig_axes(fig, K)

# Format the figure.
lb = lbdim / dim
Expand All @@ -137,6 +135,7 @@ def corner_impl(
)

# Parse the parameter ranges.
force_range = False
if range is None:
if "extents" in hist2d_kwargs:
logging.warning(
Expand All @@ -161,6 +160,8 @@ def corner_impl(
)

else:
force_range = True

# If any of the extents are percentiles, convert them to ranges.
# Also make sure it's a normal list.
range = list(range)
Expand Down Expand Up @@ -285,14 +286,14 @@ def corner_impl(
ax.set_title(title, **title_kwargs)

# Set up the axes.
_set_xlim(new_fig, ax, range[i])
_set_xlim(force_range, new_fig, ax, range[i])
ax.set_xscale(axes_scale[i])
if scale_hist:
maxn = np.max(n)
_set_ylim(new_fig, ax, [-0.1 * maxn, 1.1 * maxn])
_set_ylim(force_range, new_fig, ax, [-0.1 * maxn, 1.1 * maxn])

else:
_set_ylim(new_fig, ax, [0, 1.1 * np.max(n)])
_set_ylim(force_range, new_fig, ax, [0, 1.1 * np.max(n)])

ax.set_yticklabels([])
if max_n_ticks == 0:
Expand Down Expand Up @@ -377,6 +378,7 @@ def corner_impl(
smooth=smooth,
bins=[bins[j], bins[i]],
new_fig=new_fig,
force_range=force_range,
**hist2d_kwargs,
)

Expand Down Expand Up @@ -537,6 +539,7 @@ def hist2d(
data_kwargs=None,
pcolor_kwargs=None,
new_fig=True,
force_range=False,
**kwargs,
):
"""
Expand Down Expand Up @@ -791,8 +794,8 @@ def hist2d(
contour_kwargs["colors"] = contour_kwargs.get("colors", color)
ax.contour(X2, Y2, H2.T, V, **contour_kwargs)

_set_xlim(new_fig, ax, range[0])
_set_ylim(new_fig, ax, range[1])
_set_xlim(force_range, new_fig, ax, range[0])
_set_ylim(force_range, new_fig, ax, range[1])
ax.set_xscale(axes_scale[0])
ax.set_yscale(axes_scale[1])

Expand Down Expand Up @@ -822,7 +825,7 @@ def overplot_lines(fig, xs, reverse=False, **kwargs):
"""
K = len(xs)
axes = _get_fig_axes(fig, K)
axes, _ = _get_fig_axes(fig, K)
if reverse:
for k1 in range(K):
if xs[k1] is not None:
Expand Down Expand Up @@ -871,7 +874,7 @@ def overplot_points(fig, xs, reverse=False, **kwargs):
kwargs["linestyle"] = kwargs.pop("linestyle", "none")
xs = _parse_input(xs)
K = len(xs)
axes = _get_fig_axes(fig, K)
axes, _ = _get_fig_axes(fig, K)
if reverse:
for k1 in range(K):
for k2 in range(k1):
Expand All @@ -895,9 +898,9 @@ def _parse_input(xs):

def _get_fig_axes(fig, K):
if not fig.axes:
return fig.subplots(K, K)
return fig.subplots(K, K), True
try:
return np.array(fig.axes).reshape((K, K))
return np.array(fig.axes).reshape((K, K)), False
except ValueError:
raise ValueError(
(
Expand All @@ -907,15 +910,15 @@ def _get_fig_axes(fig, K):
)


def _set_xlim(new_fig, ax, new_xlim):
if new_fig:
def _set_xlim(force, new_fig, ax, new_xlim):
if force or new_fig:
return ax.set_xlim(new_xlim)
xlim = ax.get_xlim()
return ax.set_xlim([min(xlim[0], new_xlim[0]), max(xlim[1], new_xlim[1])])


def _set_ylim(new_fig, ax, new_ylim):
if new_fig:
def _set_ylim(force, new_fig, ax, new_ylim):
if force or new_fig:
return ax.set_ylim(new_ylim)
ylim = ax.get_ylim()
return ax.set_ylim([min(ylim[0], new_ylim[0]), max(ylim[1], new_ylim[1])])
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
44 changes: 15 additions & 29 deletions tests/test_corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,6 @@

import corner

try:
import arviz as az
except ImportError:
az = None

try:
import pandas as pd
except ImportError:
pd = None

try:
import scipy # noqa
except ImportError:
scipy_installed = False
else:
scipy_installed = True


def _run_corner(
pandas=False,
Expand All @@ -49,11 +32,12 @@ def _run_corner(
if exp_data:
data = 10**data
if pandas:
# data = pd.DataFrame.from_items()
pd = pytest.importorskip("pandas")
data = pd.DataFrame.from_dict(
OrderedDict(zip(map("d{0}".format, range(ndim)), data.T))
)
elif arviz:
az = pytest.importorskip("arviz")
data = az.from_dict(
posterior={"x": data[None]},
sample_stats={"diverging": data[None, :, 0] < 0.0},
Expand Down Expand Up @@ -198,35 +182,35 @@ def test_bins_log():
_run_corner(exp_data=True, axes_scale="log", bins=25)


@pytest.mark.skipif(not scipy_installed, reason="requires scipy for smoothing")
@image_comparison(
baseline_images=["smooth"], remove_text=True, extensions=["png"]
)
def test_smooth():
pytest.importorskip("scipy")
_run_corner(bins=50, smooth=1.0)


@pytest.mark.skipif(not scipy_installed, reason="requires scipy for smoothing")
@image_comparison(
baseline_images=["smooth_log"], remove_text=True, extensions=["png"]
)
def test_smooth_log():
pytest.importorskip("scipy")
_run_corner(exp_data=True, axes_scale="log", bins=50, smooth=1.0)


@pytest.mark.skipif(not scipy_installed, reason="requires scipy for smoothing")
@image_comparison(
baseline_images=["smooth1d"], remove_text=True, extensions=["png"]
)
def test_smooth1d():
pytest.importorskip("scipy")
_run_corner(bins=50, smooth=1.0, smooth1d=1.0)


@pytest.mark.skipif(not scipy_installed, reason="requires scipy for smoothing")
@image_comparison(
baseline_images=["smooth1d_log"], remove_text=True, extensions=["png"]
)
def test_smooth1d_log():
pytest.importorskip("scipy")
_run_corner(
exp_data=True, axes_scale="log", bins=50, smooth=1.0, smooth1d=1.0
)
Expand All @@ -249,7 +233,6 @@ def test_top_ticks():
_run_corner(top_ticks=True)


@pytest.mark.skipif(pd is None, reason="requires pandas")
@image_comparison(baseline_images=["pandas"], extensions=["png"])
def test_pandas():
_run_corner(pandas=True)
Expand Down Expand Up @@ -356,13 +339,8 @@ def test_reverse_overplotting():
[4 * nsamples // 5, ndim]
)
mean = 4 * np.random.rand(ndim)
data2 = mean[None, :] + np.random.randn(ndim * nsamples // 5).reshape(
[nsamples // 5, ndim]
)
samples = np.vstack([data1, data2])

value1 = mean
# This is the empirical mean of the sample:
value2 = np.mean(data1, axis=0)

corner.overplot_lines(figure, value1, color="C1", reverse=True)
Expand Down Expand Up @@ -391,7 +369,15 @@ def test_hist_bin_factor_log():
_run_corner(exp_data=True, axes_scale="log", hist_bin_factor=4)


@pytest.mark.skipif(az is None, reason="requires arviz")
@image_comparison(baseline_images=["arviz"], extensions=["png"])
def test_arviz():
_run_corner(arviz=True)


@image_comparison(
baseline_images=["range_fig_arg"], remove_text=True, extensions=["png"]
)
def test_range_fig_arg():
fig = pl.figure()
ranges = [(-1.1, 1), 0.8, (-1, 1)]
_run_corner(N=100_000, range=ranges, fig=fig)

0 comments on commit e9f5396

Please sign in to comment.