Skip to content

Commit

Permalink
Merge pull request #152 from dfm/arviz-opt
Browse files Browse the repository at this point in the history
Making arviz an optional dep
  • Loading branch information
dfm committed Mar 9, 2021
2 parents 66205f5 + 45913eb commit c0c6264
Show file tree
Hide file tree
Showing 7 changed files with 1,044 additions and 868 deletions.
13 changes: 10 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@ jobs:
runs-on: "ubuntu-latest"
strategy:
matrix:
python-version: [3.9]
python-version: ["3.9"]
arviz-version:
- ""
- "arviz~=0.9"
- "arviz~=0.10"
- "arviz~=0.11"
- "https://github.com/arviz-devs/arviz/archive/main.zip"
include:
- python-version: "3.6"
arviz-version: "arviz"
- python-version: "3.7"
arviz-version: "arviz"
- python-version: "3.8"
arviz-version: "arviz"

steps:
- uses: actions/checkout@v2
Expand All @@ -31,8 +39,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install -U pip
python -m pip install $ARVIZ
python -m pip install ".[test]"
python -m pip install $ARVIZ ".[test]"
env:
ARVIZ: ${{ matrix.arviz-version }}

Expand Down
10 changes: 8 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,23 @@
"Programming Language :: Python",
"Programming Language :: Python :: 3",
]
INSTALL_REQUIRES = ["arviz>=0.9"]
INSTALL_REQUIRES = ["matplotlib>=2.1"]
EXTRA_REQUIRE = {
"arviz": ["arviz>=0.9"],
"test": [
"pytest>=3.6",
"pytest-cov>=2.6.1",
"black",
"isort",
"toml",
],
"docs": ["sphinx>=1.7.5", "pandoc", "myst-nb", "sphinx-book-theme"],
}
EXTRA_REQUIRE["docs"] = EXTRA_REQUIRE["arviz"] + [
"sphinx>=1.7.5",
"pandoc",
"myst-nb",
"sphinx-book-theme",
]
EXTRA_REQUIRE["dev"] = (
EXTRA_REQUIRE["test"]
+ EXTRA_REQUIRE["docs"]
Expand Down
3 changes: 2 additions & 1 deletion src/corner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

__all__ = ["corner", "hist2d", "quantile", "overplot_lines", "overplot_points"]

from .corner import corner, hist2d, overplot_lines, overplot_points, quantile
from .core import hist2d, overplot_lines, overplot_points, quantile
from .corner import corner
from .corner_version import __version__ # noqa

__author__ = "Dan Foreman-Mackey"
Expand Down
180 changes: 180 additions & 0 deletions src/corner/arviz_corner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-

__all__ = ["arviz_corner"]

import logging
from collections.abc import Mapping

import numpy as np
from arviz.data import convert_to_dataset
from arviz.utils import _var_names, get_coords

# Support multiple versions of arviz
try:
from arviz.plots.plot_utils import (
make_label,
xarray_to_ndarray,
xarray_var_iter,
)

def _get_labels(plotters, labeller=None):
return [
make_label(var_name, selection)
for var_name, selection, _ in plotters
]


except ImportError:
from arviz.labels import BaseLabeller
from arviz.sel_utils import xarray_to_ndarray, xarray_var_iter

def _get_labels(plotters, labeller=None):
if labeller is None:
labeller = BaseLabeller()
return [
labeller.make_label_vert(var_name, sel, isel)
for var_name, sel, isel, _ in plotters
]


from .core import corner_impl, overplot_points


def arviz_corner(
data,
bins=20,
*,
# Original corner parameters
range=None,
weights=None,
color="k",
hist_bin_factor=1,
smooth=None,
smooth1d=None,
labels=None,
label_kwargs=None,
titles=None,
show_titles=False,
title_fmt=".2f",
title_kwargs=None,
truths=None,
truth_color="#4682b4",
scale_hist=False,
quantiles=None,
verbose=False,
fig=None,
max_n_ticks=5,
top_ticks=False,
use_math_text=False,
reverse=False,
labelpad=0.0,
hist_kwargs=None,
# Arviz parameters
group="posterior",
var_names=None,
filter_vars=None,
coords=None,
divergences=False,
divergences_kwargs=None,
labeller=None,
**hist2d_kwargs
):
is_np = False
if isinstance(data, np.ndarray):
is_np = True
if data.ndim == 1:
data = data[None, :, :]
elif data.ndim == 2:
data = data[None, :, :]
elif data.ndim != 3:
raise ValueError("invalid input dimensions")
if data.__class__.__name__ == "DataFrame":
logging.warning(
"Pandas support in corner is deprecated; use ArviZ directly"
)
data = {k: np.asarray(data[k])[None] for k in list(data.columns)}

if coords is None:
coords = {}

# Get posterior draws and combine chains
dataset = convert_to_dataset(data, group=group)
var_names = _var_names(var_names, dataset, filter_vars)
plotters = list(
xarray_var_iter(
get_coords(dataset, coords), var_names=var_names, combined=True
)
)
if labels is None and not is_np:
labels = _get_labels(plotters, labeller=labeller)
if var_names is None:
var_names = dataset.data_vars

divergent_data = None
diverging_mask = None

# Assigning divergence group based on group param
if group == "posterior":
divergent_group = "sample_stats"
elif group == "prior":
divergent_group = "sample_stats_prior"
else:
divergences = False

# Reformat truths and titles as lists if they are mappings
if isinstance(truths, Mapping):
truths = np.concatenate(
[np.asarray(truths[k]).flatten() for k in var_names]
)
if isinstance(titles, Mapping):
titles = np.concatenate(
[np.asarray(titles[k]).flatten() for k in var_names]
)

# Coerce the samples into the expected format
samples = np.stack([x[-1].flatten() for x in plotters], axis=-1)
fig = corner_impl(
samples,
bins=bins,
range=range,
weights=weights,
color=color,
hist_bin_factor=hist_bin_factor,
smooth=smooth,
smooth1d=smooth1d,
labels=labels,
label_kwargs=label_kwargs,
titles=titles,
show_titles=show_titles,
title_fmt=title_fmt,
title_kwargs=title_kwargs,
truths=truths,
truth_color=truth_color,
scale_hist=scale_hist,
quantiles=quantiles,
verbose=verbose,
fig=fig,
max_n_ticks=max_n_ticks,
top_ticks=top_ticks,
use_math_text=use_math_text,
reverse=reverse,
labelpad=labelpad,
hist_kwargs=hist_kwargs,
**hist2d_kwargs,
)

# Get diverging draws and combine chains
if divergences:
if hasattr(data, divergent_group) and hasattr(
getattr(data, divergent_group), "diverging"
):
divergent_data = convert_to_dataset(data, group=divergent_group)
_, diverging_mask = xarray_to_ndarray(
divergent_data, var_names=("diverging",), combined=True
)
diverging_mask = np.squeeze(diverging_mask)
if divergences_kwargs is None:
divergences_kwargs = {"color": "C1", "ms": 1}
overplot_points(fig, samples[diverging_mask], **divergences_kwargs)

return fig

0 comments on commit c0c6264

Please sign in to comment.