Skip to content

Commit

Permalink
Add Est stat and Interval mark to show error bars (#2912)
Browse files Browse the repository at this point in the history
* Add a docstring for seaborn.objects namespace

* Add Est stat (mostly copied from EstimateAggregator)

* Handle cases where x or y are not defined better

* Improve datalim update with collections

* Handle matplotlib edge cases with line capstyles

* Add Interval mark

* Add Interval unit tests

* Revert Est to use EstimateAggregator and add (light) tests

* Pandas (?) backcompat
  • Loading branch information
mwaskom committed Jul 24, 2022
1 parent b5a85ff commit fb61ded
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 47 deletions.
9 changes: 5 additions & 4 deletions seaborn/_core/plot.py
Expand Up @@ -1338,10 +1338,11 @@ def split_generator(keep_na=False) -> Generator:
# Matplotlib (usually?) masks nan data, so this should "work".
# Downstream code can also drop these rows, at some speed cost.
present = axes_df.notna().all(axis=1)
axes_df = axes_df.assign(
x=axes_df["x"].where(present),
y=axes_df["y"].where(present),
)
nulled = {}
for axis in "xy":
if axis in axes_df:
nulled[axis] = axes_df[axis].where(present)
axes_df = axes_df.assign(**nulled)
else:
axes_df = axes_df.dropna()

Expand Down
5 changes: 4 additions & 1 deletion seaborn/_core/scales.py
Expand Up @@ -350,7 +350,10 @@ def normalize(x):
]

def spacer(x):
    # Return the smallest gap between adjacent unique values, which is used
    # as the spacing unit for this scale. With fewer than two distinct
    # non-null values there is no defined gap, so return nan rather than
    # letting np.min crash on an empty diff.
    x = x.dropna().unique()
    if len(x) < 2:
        return np.nan
    return np.min(np.diff(np.sort(x)))
new._spacer = spacer

# TODO How to allow disabling of legend for all uses of property?
Expand Down
6 changes: 2 additions & 4 deletions seaborn/_marks/bars.py
Expand Up @@ -200,10 +200,8 @@ def _plot(self, split_gen, scales, orient):
# Workaround for matplotlib autoscaling bug
# https://github.com/matplotlib/matplotlib/issues/11898
# https://github.com/matplotlib/matplotlib/issues/23129
xy = np.vstack([path.vertices for path in col.get_paths()])
ax.dataLim.update_from_data_xy(
xy, ax.ignore_existing_data_limits, updatex=True, updatey=True
)
xys = np.vstack([path.vertices for path in col.get_paths()])
ax.update_datalim(xys)

if "edgewidth" not in scales and isinstance(self.edgewidth, Mappable):

Expand Down
94 changes: 80 additions & 14 deletions seaborn/_marks/lines.py
Expand Up @@ -50,6 +50,9 @@ def _plot(self, split_gen, scales, orient):
if self._sort:
data = data.sort_values(orient)

artist_kws = self.artist_kws.copy()
self._handle_capstyle(artist_kws, vals)

line = mpl.lines.Line2D(
data["x"].to_numpy(),
data["y"].to_numpy(),
Expand All @@ -61,7 +64,7 @@ def _plot(self, split_gen, scales, orient):
markerfacecolor=vals["fillcolor"],
markeredgecolor=vals["edgecolor"],
markeredgewidth=vals["edgewidth"],
**self.artist_kws,
**artist_kws,
)
ax.add_line(line)

Expand All @@ -77,6 +80,9 @@ def _legend_artist(self, variables, value, scales):
if Version(mpl.__version__) < Version("3.3.0"):
vals["marker"] = vals["marker"]._marker

artist_kws = self.artist_kws.copy()
self._handle_capstyle(artist_kws, vals)

return mpl.lines.Line2D(
[], [],
color=vals["color"],
Expand All @@ -87,9 +93,17 @@ def _legend_artist(self, variables, value, scales):
markerfacecolor=vals["fillcolor"],
markeredgecolor=vals["edgecolor"],
markeredgewidth=vals["edgewidth"],
**self.artist_kws,
**artist_kws,
)

def _handle_capstyle(self, kws, vals):

# Work around for this matplotlib issue:
# https://github.com/matplotlib/matplotlib/issues/23437
if vals["linestyle"][1] is None:
capstyle = kws.get("solid_capstyle", mpl.rcParams["lines.solid_capstyle"])
kws["dash_capstyle"] = capstyle


@dataclass
class Line(Path):
Expand All @@ -111,7 +125,15 @@ class Paths(Mark):

_sort: ClassVar[bool] = False

def _plot(self, split_gen, scales, orient):
def __post_init__(self):
    """Seed default artist kwargs after dataclass initialization."""
    # LineCollection has a capstyle property but does not read its default
    # from the rc params, so mirror the rc value here ourselves. Because we
    # add only a single LineCollection per axes, one capstyle is shared by
    # all lines, even dashed ones — a slight inconsistency that looks fine.
    rc_capstyle = mpl.rcParams["lines.solid_capstyle"]
    self.artist_kws.setdefault("capstyle", rc_capstyle)

def _setup_lines(self, split_gen, scales, orient):

line_data = {}

Expand All @@ -131,36 +153,42 @@ def _plot(self, split_gen, scales, orient):
if self._sort:
data = data.sort_values(orient)

# TODO comment about block consolidation
# Column stack to avoid block consolidation
xy = np.column_stack([data["x"], data["y"]])
line_data[ax]["segments"].append(xy)
line_data[ax]["colors"].append(vals["color"])
line_data[ax]["linewidths"].append(vals["linewidth"])
line_data[ax]["linestyles"].append(vals["linestyle"])

return line_data

def _plot(self, split_gen, scales, orient):
    """Draw one LineCollection per axes from the prepared segment data."""
    line_data = self._setup_lines(split_gen, scales, orient)

    for ax, ax_data in line_data.items():
        lines = mpl.collections.LineCollection(**ax_data, **self.artist_kws)
        # Handle the datalim update manually because of a matplotlib
        # autoscaling bug with collections:
        # https://github.com/matplotlib/matplotlib/issues/23129
        ax.add_collection(lines, autolim=False)
        xy = np.concatenate(ax_data["segments"])
        ax.update_datalim(xy)

def _legend_artist(self, variables, value, scales):
    """Return a Line2D proxy representing this mark in the legend."""
    key = resolve_properties(self, {v: value for v in variables}, scales)

    # Line2D has no plain "capstyle" kwarg (that is collection-specific),
    # so translate it onto both of the Line2D capstyle parameters.
    artist_kws = self.artist_kws.copy()
    capstyle = artist_kws.pop("capstyle")
    artist_kws["solid_capstyle"] = capstyle
    artist_kws["dash_capstyle"] = capstyle

    return mpl.lines.Line2D(
        [], [],
        color=key["color"],
        linewidth=key["linewidth"],
        linestyle=key["linestyle"],
        **artist_kws,
    )


Expand All @@ -170,3 +198,41 @@ class Lines(Paths):
A faster but less-flexible mark for drawing many lines.
"""
_sort: ClassVar[bool] = True


@dataclass
class Interval(Paths):
    """
    An oriented line mark drawn between min/max values.
    """
    def _setup_lines(self, split_gen, scales, orient):
        # Build the per-axes LineCollection arguments: one segment spanning
        # {other}min..{other}max for each position along the orient axis.
        per_ax = {}

        value = {"x": "y", "y": "x"}[orient]

        for keys, data, ax in split_gen(keep_na=not self._sort):

            if ax not in per_ax:
                per_ax[ax] = {
                    "segments": [],
                    "colors": [],
                    "linewidths": [],
                    "linestyles": [],
                }
            entry = per_ax[ax]

            vals = resolve_properties(self, keys, scales)
            vals["color"] = resolve_color(self, keys, scales=scales)

            # Reshape the min/max columns into long form so each orient
            # position yields a two-point (x, y) segment.
            long = data[[orient, f"{value}min", f"{value}max"]].melt(
                orient, value_name=value
            )[["x", "y"]]
            segments = [grp.to_numpy() for _, grp in long.groupby(orient)]

            entry["segments"].extend(segments)

            count = len(segments)
            entry["colors"].extend([vals["color"]] * count)
            entry["linewidths"].extend([vals["linewidth"]] * count)
            entry["linestyles"].extend([vals["linestyle"]] * count)

        return per_ax
76 changes: 57 additions & 19 deletions seaborn/_stats/aggregation.py
@@ -1,14 +1,16 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Callable

import pandas as pd
from pandas import DataFrame

from seaborn._core.scales import Scale
from seaborn._core.groupby import GroupBy
from seaborn._stats.base import Stat
from seaborn._statistics import EstimateAggregator

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable
from numbers import Number
from seaborn._core.typing import Vector
from seaborn._core.typing import Vector


@dataclass
Expand All @@ -18,23 +20,22 @@ class Agg(Stat):
Parameters
----------
func
Name of a method understood by Pandas or an arbitrary vector -> scalar function.
func : str or callable
Name of a :class:`pandas.Series` method or a vector -> scalar function.
"""
# TODO In current practice we will always have a numeric x/y variable,
# but they may represent non-numeric values. Needs clear documentation.
func: str | Callable[[Vector], Number] = "mean"
func: str | Callable[[Vector], float] = "mean"

group_by_orient: ClassVar[bool] = True

def __call__(self, data, groupby, orient, scales):
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

var = {"x": "y", "y": "x"}.get(orient)
res = (
groupby
.agg(data, {var: self.func})
# TODO Could be an option not to drop NA?
.dropna()
.reset_index(drop=True)
)
Expand All @@ -43,19 +44,56 @@ def __call__(self, data, groupby, orient, scales):

@dataclass
class Est(Stat):
    """
    Calculate a point estimate and error bar interval.

    Parameters
    ----------
    func : str or callable
        Name of a :class:`numpy.ndarray` method or a vector -> scalar function.
    errorbar : str, (str, float) tuple, or callable
        Name of errorbar method (one of "ci", "pi", "se" or "sd"), or a tuple
        with a method name and a level parameter, or a function that maps from a
        vector to a (min, max) interval.
    n_boot : int
        Number of bootstrap samples to draw for "ci" errorbars.
    seed : int
        Seed for the PRNG used to draw bootstrap samples.
    """
    func: str | Callable[[Vector], float] = "mean"
    errorbar: str | tuple[str, float] = ("ci", 95)
    n_boot: int = 1000
    seed: int | None = None

    group_by_orient: ClassVar[bool] = True

    def _process(
        self, data: DataFrame, var: str, estimator: EstimateAggregator
    ) -> DataFrame:
        # Needed because GroupBy.apply assumes func is DataFrame -> DataFrame
        # which we could probably make more general to allow Series return
        res = estimator(data, var)
        return pd.DataFrame([res])

    def __call__(
        self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
    ) -> DataFrame:

        boot_kws = {"n_boot": self.n_boot, "seed": self.seed}
        engine = EstimateAggregator(self.func, self.errorbar, **boot_kws)

        # Estimate over the variable opposite the orientation axis
        var = {"x": "y", "y": "x"}.get(orient)
        res = (
            groupby
            .apply(data, self._process, var, engine)
            .dropna(subset=["x", "y"])
            .reset_index(drop=True)
        )

        # Degenerate groups produce a point estimate with no interval;
        # collapse the interval onto the estimate rather than leaving NA
        res = res.fillna({f"{var}min": res[var], f"{var}max": res[var]})

        return res


@dataclass
Expand Down
3 changes: 3 additions & 0 deletions seaborn/_stats/histograms.py
Expand Up @@ -31,6 +31,9 @@ class Hist(Stat):
# Q: would Discrete() scale imply binwidth=1 or bins centered on integers?
discrete: bool = False

# TODO Note that these methods are mostly copied from _statistics.Histogram,
# but it only computes univariate histograms. We should reconcile the code.

def _define_bin_edges(self, vals, weight, bins, binwidth, binrange, discrete):
"""Inner function that takes bin parameters as arguments."""
vals = vals.dropna()
Expand Down
4 changes: 2 additions & 2 deletions seaborn/objects.py
Expand Up @@ -31,11 +31,11 @@
from seaborn._marks.base import Mark # noqa: F401
from seaborn._marks.area import Area, Ribbon # noqa: F401
from seaborn._marks.bars import Bar, Bars # noqa: F401
from seaborn._marks.lines import Line, Lines, Path, Paths # noqa: F401
from seaborn._marks.lines import Line, Lines, Path, Paths, Interval # noqa: F401
from seaborn._marks.scatter import Dot, Scatter # noqa: F401

from seaborn._stats.base import Stat # noqa: F401
from seaborn._stats.aggregation import Agg # noqa: F401
from seaborn._stats.aggregation import Agg, Est # noqa: F401
from seaborn._stats.regression import OLSFit, PolyFit # noqa: F401
from seaborn._stats.histograms import Hist # noqa: F401

Expand Down

0 comments on commit fb61ded

Please sign in to comment.