Skip to content

Commit

Permalink
Differentiate Line/Path and add Lines/Paths alternatives (#2822)
Browse files Browse the repository at this point in the history
* Add lines module and differentiate Path/Line

* Add markers to Line/Path and add Lines/Paths

* Implement unstatisfying but workable approach to keep_na

* Add tests for Line(s)/Path(s)

* Add backcompat for matplotlib<3.3.0
  • Loading branch information
mwaskom committed May 30, 2022
1 parent 7d1c50f commit fefd940
Show file tree
Hide file tree
Showing 8 changed files with 410 additions and 107 deletions.
3 changes: 3 additions & 0 deletions doc/nextgen/api.rst
Expand Up @@ -38,6 +38,9 @@ Marks
Bar
Dot
Line
Lines
Path
Paths
Ribbon
Scatter

Expand Down
26 changes: 21 additions & 5 deletions seaborn/_core/plot.py
Expand Up @@ -674,9 +674,14 @@ def __init__(self, pyplot=False):
] = []
self._scales: dict[str, Scale] = {}

def save(self, fname, **kwargs) -> Plotter:
def save(self, loc, **kwargs) -> Plotter: # TODO type args
kwargs.setdefault("dpi", 96)
self._figure.savefig(os.path.expanduser(fname), **kwargs)
try:
loc = os.path.expanduser(loc)
except TypeError:
# loc may be a buffer in which case that would not work
pass
self._figure.savefig(loc, **kwargs)
return self

def show(self, **kwargs) -> None:
Expand Down Expand Up @@ -1270,14 +1275,25 @@ def _setup_split_generator(
order = categorical_order(df[var])
grouping_keys.append(order)

def split_generator(dropna=True) -> Generator:
def split_generator(keep_na=False) -> Generator:

for view in subplots:

axes_df = self._filter_subplot_data(df, view)

if dropna:
with pd.option_context("mode.use_inf_as_null", True):
with pd.option_context("mode.use_inf_as_null", True):
if keep_na:
# The simpler thing to do would be x.dropna().reindex(x.index).
# But that doesn't work with the way that the subset iteration
# is written below, which assumes data for grouping vars.
# Matplotlib (usually?) masks nan data, so this should "work".
# Downstream code can also drop these rows, at some speed cost.
present = axes_df.notna().all(axis=1)
axes_df = axes_df.assign(
x=axes_df["x"].where(present),
y=axes_df["y"].where(present),
)
else:
axes_df = axes_df.dropna()

subplot_keys = {}
Expand Down
6 changes: 5 additions & 1 deletion seaborn/_marks/base.py
Expand Up @@ -261,7 +261,11 @@ def resolve_color(
"""
color = mark._resolve(data, f"{prefix}color", scales)
alpha = mark._resolve(data, f"{prefix}alpha", scales)

if f"{prefix}alpha" in mark._mappable_props:
alpha = mark._resolve(data, f"{prefix}alpha", scales)
else:
alpha = mark._resolve(data, "alpha", scales)

def visible(x, axis=None):
"""Detect "invisible" colors to set alpha appropriately."""
Expand Down
70 changes: 0 additions & 70 deletions seaborn/_marks/basic.py

This file was deleted.

172 changes: 172 additions & 0 deletions seaborn/_marks/lines.py
@@ -0,0 +1,172 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar

import numpy as np
import matplotlib as mpl

from seaborn._marks.base import (
Mark,
Mappable,
MappableFloat,
MappableString,
MappableColor,
resolve_properties,
resolve_color,
)
from seaborn.external.version import Version


@dataclass
class Path(Mark):
"""
A mark connecting data points in the order they appear.
"""
color: MappableColor = Mappable("C0")
alpha: MappableFloat = Mappable(1)
linewidth: MappableFloat = Mappable(rc="lines.linewidth")
linestyle: MappableString = Mappable(rc="lines.linestyle")
marker: MappableString = Mappable(rc="lines.marker")
pointsize: MappableFloat = Mappable(rc="lines.markersize")
fillcolor: MappableColor = Mappable(depend="color")
edgecolor: MappableColor = Mappable(depend="color")
edgewidth: MappableFloat = Mappable(rc="lines.markeredgewidth")

_sort: ClassVar[bool] = False

def _plot(self, split_gen, scales, orient):

for keys, data, ax in split_gen(keep_na=not self._sort):

vals = resolve_properties(self, keys, scales)
vals["color"] = resolve_color(self, keys, scales=scales)
vals["fillcolor"] = resolve_color(self, keys, prefix="fill", scales=scales)
vals["edgecolor"] = resolve_color(self, keys, prefix="edge", scales=scales)

# https://github.com/matplotlib/matplotlib/pull/16692
if Version(mpl.__version__) < Version("3.3.0"):
vals["marker"] = vals["marker"]._marker

if self._sort:
data = data.sort_values(orient)

line = mpl.lines.Line2D(
data["x"].to_numpy(),
data["y"].to_numpy(),
color=vals["color"],
linewidth=vals["linewidth"],
linestyle=vals["linestyle"],
marker=vals["marker"],
markersize=vals["pointsize"],
markerfacecolor=vals["fillcolor"],
markeredgecolor=vals["edgecolor"],
markeredgewidth=vals["edgewidth"],
**self.artist_kws,
)
ax.add_line(line)

def _legend_artist(self, variables, value, scales):

keys = {v: value for v in variables}
vals = resolve_properties(self, keys, scales)
vals["color"] = resolve_color(self, keys, scales=scales)
vals["fillcolor"] = resolve_color(self, keys, prefix="fill", scales=scales)
vals["edgecolor"] = resolve_color(self, keys, prefix="edge", scales=scales)

# https://github.com/matplotlib/matplotlib/pull/16692
if Version(mpl.__version__) < Version("3.3.0"):
vals["marker"] = vals["marker"]._marker

return mpl.lines.Line2D(
[], [],
color=vals["color"],
linewidth=vals["linewidth"],
linestyle=vals["linestyle"],
marker=vals["marker"],
markersize=vals["pointsize"],
markerfacecolor=vals["fillcolor"],
markeredgecolor=vals["edgecolor"],
markeredgewidth=vals["edgewidth"],
**self.artist_kws,
)


@dataclass
class Line(Path):
"""
A mark connecting data points with sorting along the orientation axis.
"""
_sort: ClassVar[bool] = True


@dataclass
class Paths(Mark):
"""
A faster but less-flexible mark for drawing many paths.
"""
color: MappableColor = Mappable("C0")
alpha: MappableFloat = Mappable(1)
linewidth: MappableFloat = Mappable(rc="lines.linewidth")
linestyle: MappableString = Mappable(rc="lines.linestyle")

_sort: ClassVar[bool] = False

def _plot(self, split_gen, scales, orient):

line_data = {}

for keys, data, ax in split_gen(keep_na=not self._sort):

if ax not in line_data:
line_data[ax] = {
"segments": [],
"colors": [],
"linewidths": [],
"linestyles": [],
}

vals = resolve_properties(self, keys, scales)
vals["color"] = resolve_color(self, keys, scales=scales)

if self._sort:
data = data.sort_values(orient)

# TODO comment about block consolidation
xy = np.column_stack([data["x"], data["y"]])
line_data[ax]["segments"].append(xy)
line_data[ax]["colors"].append(vals["color"])
line_data[ax]["linewidths"].append(vals["linewidth"])
line_data[ax]["linestyles"].append(vals["linestyle"])

for ax, ax_data in line_data.items():
lines = mpl.collections.LineCollection(
**ax_data,
**self.artist_kws,
)
ax.add_collection(lines, autolim=False)
# https://github.com/matplotlib/matplotlib/issues/23129
# TODO get paths from lines object?
xy = np.concatenate(ax_data["segments"])
ax.dataLim.update_from_data_xy(
xy, ax.ignore_existing_data_limits, updatex=True, updatey=True
)

def _legend_artist(self, variables, value, scales):

key = resolve_properties(self, {v: value for v in variables}, scales)

return mpl.lines.Line2D(
[], [],
color=key["color"],
linewidth=key["linewidth"],
linestyle=key["linestyle"],
**self.artist_kws,
)


@dataclass
class Lines(Paths):
"""
A faster but less-flexible mark for drawing many lines.
"""
_sort: ClassVar[bool] = True
2 changes: 1 addition & 1 deletion seaborn/objects.py
Expand Up @@ -4,9 +4,9 @@
from seaborn._core.plot import Plot # noqa: F401

from seaborn._marks.base import Mark # noqa: F401
from seaborn._marks.basic import Line # noqa: F401
from seaborn._marks.area import Area, Ribbon # noqa: F401
from seaborn._marks.bars import Bar # noqa: F401
from seaborn._marks.lines import Line, Lines, Path, Paths # noqa: F401
from seaborn._marks.scatter import Dot, Scatter # noqa: F401

from seaborn._stats.base import Stat # noqa: F401
Expand Down
30 changes: 0 additions & 30 deletions seaborn/tests/_marks/test_bars.py
Expand Up @@ -57,36 +57,6 @@ def test_numeric_positions_horizontal(self):
for i, bar in enumerate(bars):
self.check_bar(bar, 0, y[i] - w / 2, x[i], w)

@pytest.mark.xfail(reason="new dodge api")
def test_categorical_dodge_vertical(self):

x = ["a", "a", "b", "b"]
y = [1, 2, 3, 4]
group = ["x", "y", "x", "y"]
w = .8
bars = self.plot_bars(
{"x": x, "y": y, "group": group}, {"multiple": "dodge"}, {}
)
for i, bar in enumerate(bars[:2]):
self.check_bar(bar, i - w / 2, 0, w / 2, y[i * 2])
for i, bar in enumerate(bars[2:]):
self.check_bar(bar, i, 0, w / 2, y[i * 2 + 1])

@pytest.mark.xfail(reason="new dodge api")
def test_categorical_dodge_horizontal(self):

x = [1, 2, 3, 4]
y = ["a", "a", "b", "b"]
group = ["x", "y", "x", "y"]
w = .8
bars = self.plot_bars(
{"x": x, "y": y, "group": group}, {"multiple": "dodge"}, {}
)
for i, bar in enumerate(bars[:2]):
self.check_bar(bar, 0, i - w / 2, x[i * 2], w / 2)
for i, bar in enumerate(bars[2:]):
self.check_bar(bar, 0, i, x[i * 2 + 1], w / 2)

def test_direct_properties(self):

x = ["a", "b", "c"]
Expand Down

0 comments on commit fefd940

Please sign in to comment.