Skip to content

Commit

Permalink
Improve relational/categorical legends to show non-semantic properties (
Browse files Browse the repository at this point in the history
#3467)

* Begin refactoring relational legend setup

* Move legend creation methods down to base  VectorPlotter

* Convert RelationalPlotter legend approach to method rather than named attribute

* Refactor and update scatterplot legend tests

* Add utility function for translating scatter kws to Line2D legend artist

* Use line artist for swarmplot legend

* Fix accidental change to default stripplot edge color

* Update some tests based on new legend artists

* Update relational legends

* Add some custom kwarg support for relplot legends

* Add tests for new legend behavior in categorical module

* Add tests for legend syncing in relational module

* Make CategoricalPlotter directly inherit VectorPlotter

* Handle face/edge better in scatter legends

* Allow plotting methods to specify legend attributes

* Matplotlib backcompat

* Fix legends with semantic_kws

* Remove unnecessary legend modification from example

* Remove legend_func class attribute
  • Loading branch information
mwaskom committed Sep 11, 2023
1 parent e71d95b commit fbc44d5
Show file tree
Hide file tree
Showing 9 changed files with 580 additions and 493 deletions.
3 changes: 0 additions & 3 deletions examples/heat_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""
import seaborn as sns
from seaborn._compat import get_legend_handles
sns.set_theme(style="whitegrid")

# Load the brain networks dataset, select subset, and collapse the multi-index
Expand Down Expand Up @@ -38,5 +37,3 @@
g.ax.margins(.02)
for label in g.ax.get_xticklabels():
label.set_rotation(90)
for artist in get_legend_handles(g.legend):
artist.set_edgecolor(".7")
133 changes: 133 additions & 0 deletions seaborn/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
)
from seaborn.utils import (
_check_argument,
_version_predates,
desaturate,
locator_to_legend_entries,
get_color_cycle,
remove_na,
)
Expand Down Expand Up @@ -1200,6 +1202,137 @@ def _add_axis_labels(self, ax, default_x="", default_y=""):
y_visible = any(t.get_visible() for t in ax.get_yticklabels())
ax.set_ylabel(self.variables.get("y", default_y), visible=y_visible)

def add_legend_data(
self, ax, func, common_kws=None, attrs=None, semantic_kws=None,
):
"""Add labeled artists to represent the different plot semantics."""
verbosity = self.legend
if isinstance(verbosity, str) and verbosity not in ["auto", "brief", "full"]:
err = "`legend` must be 'auto', 'brief', 'full', or a boolean."
raise ValueError(err)
elif verbosity is True:
verbosity = "auto"

keys = []
legend_kws = {}
common_kws = {} if common_kws is None else common_kws.copy()
semantic_kws = {} if semantic_kws is None else semantic_kws.copy()

# Assign a legend title if there is only going to be one sub-legend,
# otherwise, subtitles will be inserted into the texts list with an
# invisible handle (which is a hack)
titles = {
title for title in
(self.variables.get(v, None) for v in ["hue", "size", "style"])
if title is not None
}
title = "" if len(titles) != 1 else titles.pop()
title_kws = dict(
visible=False, color="w", s=0, linewidth=0, marker="", dashes=""
)

def update(var_name, val_name, **kws):

key = var_name, val_name
if key in legend_kws:
legend_kws[key].update(**kws)
else:
keys.append(key)
legend_kws[key] = dict(**kws)

if attrs is None:
attrs = {"hue": "color", "size": ["linewidth", "s"], "style": None}
for var, names in attrs.items():
self._update_legend_data(
update, var, verbosity, title, title_kws, names, semantic_kws.get(var),
)

legend_data = {}
legend_order = []

# Don't allow color=None so we can set a neutral color for size/style legends
if common_kws.get("color", False) is None:
common_kws.pop("color")

for key in keys:

_, label = key
kws = legend_kws[key]
level_kws = {}
use_attrs = [
*self._legend_attributes,
*common_kws,
*[attr for var_attrs in semantic_kws.values() for attr in var_attrs],
]
for attr in use_attrs:
if attr in kws:
level_kws[attr] = kws[attr]
artist = func(label=label, **{"color": ".2", **common_kws, **level_kws})
if _version_predates(mpl, "3.5.0"):
if isinstance(artist, mpl.lines.Line2D):
ax.add_line(artist)
elif isinstance(artist, mpl.patches.Patch):
ax.add_patch(artist)
elif isinstance(artist, mpl.collections.Collection):
ax.add_collection(artist)
else:
ax.add_artist(artist)
legend_data[key] = artist
legend_order.append(key)

self.legend_title = title
self.legend_data = legend_data
self.legend_order = legend_order

def _update_legend_data(
self,
update,
var,
verbosity,
title,
title_kws,
attr_names,
other_props,
):
"""Generate legend tick values and formatted labels."""
brief_ticks = 6
mapper = getattr(self, f"_{var}_map", None)
if mapper is None:
return

brief = mapper.map_type == "numeric" and (
verbosity == "brief"
or (verbosity == "auto" and len(mapper.levels) > brief_ticks)
)
if brief:
if isinstance(mapper.norm, mpl.colors.LogNorm):
locator = mpl.ticker.LogLocator(numticks=brief_ticks)
else:
locator = mpl.ticker.MaxNLocator(nbins=brief_ticks)
limits = min(mapper.levels), max(mapper.levels)
levels, formatted_levels = locator_to_legend_entries(
locator, limits, self.plot_data[var].infer_objects().dtype
)
elif mapper.levels is None:
levels = formatted_levels = []
else:
levels = formatted_levels = mapper.levels

if not title and self.variables.get(var, None) is not None:
update((self.variables[var], "title"), self.variables[var], **title_kws)

other_props = {} if other_props is None else other_props

for level, formatted_level in zip(levels, formatted_levels):
if level is not None:
attr = mapper(level)
if isinstance(attr_names, list):
attr = {name: attr for name in attr_names}
elif attr_names is not None:
attr = {attr_names: attr}
attr.update({k: v[level] for k, v in other_props.items() if level in v})
update(self.variables[var], formatted_level, **attr)

# XXX If the scale_* methods are going to modify the plot_data structure, they
# can't be called twice. That means that if they are called twice, they should
# raise. Alternatively, we could store an original version of plot_data and each
Expand Down
56 changes: 25 additions & 31 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@
import matplotlib.pyplot as plt

from seaborn._core.typing import default, deprecated
from seaborn._base import infer_orient, categorical_order
from seaborn._base import VectorPlotter, infer_orient, categorical_order
from seaborn._stats.density import KDE
from seaborn.relational import _RelationalPlotter
from seaborn import utils
from seaborn.utils import (
desaturate,
_check_argument,
_draw_figure,
_default_color,
_get_patch_legend_artist,
_get_transform_functions,
_normalize_kwargs,
_scatter_legend_artist,
_version_predates,
)
from seaborn._statistics import EstimateAggregator, LetterValues
Expand All @@ -39,14 +40,11 @@
]


# Subclassing _RelationalPlotter for the legend machinery,
# but probably should move that more centrally
class _CategoricalPlotter(_RelationalPlotter):
class _CategoricalPlotter(VectorPlotter):

wide_structure = {"x": "@columns", "y": "@values"}
flat_structure = {"y": "@values"}

_legend_func = "scatter"
_legend_attributes = ["color"]

def __init__(
Expand All @@ -63,13 +61,13 @@ def __init__(

# This method takes care of some bookkeeping that is necessary because the
# original categorical plots (prior to the 2021 refactor) had some rules that
# don't fit exactly into the logic of _core. It may be wise to have a second
# don't fit exactly into VectorPlotter logic. It may be wise to have a second
# round of refactoring that moves the logic deeper, but this will keep things
# relatively sensible for now.

# For wide data, orient determines assignment to x/y differently from the
# wide_structure rules in _core. If we do decide to make orient part of the
# _core variable assignment, we'll want to figure out how to express that.
# default VectorPlotter rules. If we do decide to make orient part of the
# _base variable assignment, we'll want to figure out how to express that.
if self.input_format == "wide" and orient in ["h", "y"]:
self.plot_data = self.plot_data.rename(columns={"x": "y", "y": "x"})
orig_variables = set(self.variables)
Expand All @@ -85,7 +83,7 @@ def __init__(
self.var_types["x"] = orig_y_type

# The concept of an "orientation" is important to the original categorical
# plots, but there's no provision for it in _core, so we need to do it here.
# plots, but there's no provision for it in VectorPlotter, so we need it here.
# Note that it could be useful for the other functions in at least two ways
# (orienting a univariate distribution plot from long-form data and selecting
# the aggregation axis in lineplot), so we may want to eventually refactor it.
Expand Down Expand Up @@ -403,7 +401,7 @@ def _configure_legend(self, ax, func, common_kws=None, semantic_kws=None):
show_legend = bool(self.legend)

if show_legend:
self.add_legend_data(ax, func, common_kws, semantic_kws)
self.add_legend_data(ax, func, common_kws, semantic_kws=semantic_kws)
handles, _ = ax.get_legend_handles_labels()
if handles:
ax.legend(title=self.legend_title)
Expand Down Expand Up @@ -488,7 +486,7 @@ def plot_strips(
if "hue" in self.variables:
points.set_facecolors(self._hue_map(sub_data["hue"]))

self._configure_legend(ax, ax.scatter)
self._configure_legend(ax, _scatter_legend_artist, common_kws=plot_kws)

def plot_swarms(
self,
Expand Down Expand Up @@ -558,7 +556,7 @@ def draw(points, renderer, *, center=center):
points.draw = draw.__get__(points)

_draw_figure(ax.figure)
self._configure_legend(ax, ax.scatter)
self._configure_legend(ax, _scatter_legend_artist, plot_kws)

def plot_boxes(
self,
Expand Down Expand Up @@ -712,12 +710,8 @@ def get_props(element, artist=mpl.lines.Line2D):

ax.add_container(BoxPlotContainer(artists))

patch_kws = props["box"].copy()
if not fill:
patch_kws["facecolor"] = (1, 1, 1, 0)
else:
patch_kws["edgecolor"] = linecolor
self._configure_legend(ax, ax.fill_between, patch_kws)
legend_artist = _get_patch_legend_artist(fill)
self._configure_legend(ax, legend_artist, boxprops)

def plot_boxens(
self,
Expand Down Expand Up @@ -856,12 +850,9 @@ def plot_boxens(

ax.autoscale_view(scalex=self.orient == "y", scaley=self.orient == "x")

patch_kws = box_kws.copy()
if not fill:
patch_kws["facecolor"] = (1, 1, 1, 0)
else:
patch_kws["edgecolor"] = linecolor
self._configure_legend(ax, ax.fill_between, patch_kws)
legend_artist = _get_patch_legend_artist(fill)
common_kws = {**box_kws, "linewidth": linewidth, "edgecolor": linecolor}
self._configure_legend(ax, legend_artist, common_kws)

def plot_violins(
self,
Expand Down Expand Up @@ -1135,7 +1126,9 @@ def vars_to_key(sub_vars):
}
ax.plot(invx(x2), invy(y2), **dot_kws)

self._configure_legend(ax, ax.fill_between) # TODO, patch_kws)
legend_artist = _get_patch_legend_artist(fill)
common_kws = {**plot_kws, "linewidth": linewidth, "edgecolor": linecolor}
self._configure_legend(ax, legend_artist, common_kws)

def plot_points(
self,
Expand Down Expand Up @@ -1212,8 +1205,9 @@ def plot_points(
if aggregator.error_method is not None:
self.plot_errorbars(ax, agg_data, capsize, sub_err_kws)

legend_artist = partial(mpl.lines.Line2D, [], [])
semantic_kws = {"hue": {"marker": markers, "linestyle": linestyles}}
self._configure_legend(ax, ax.plot, sub_kws, semantic_kws)
self._configure_legend(ax, legend_artist, sub_kws, semantic_kws)

def plot_bars(
self,
Expand Down Expand Up @@ -1294,7 +1288,8 @@ def plot_bars(
{"color": ".26" if fill else main_color, **err_kws}
)

self._configure_legend(ax, ax.fill_between)
legend_artist = _get_patch_legend_artist(fill)
self._configure_legend(ax, legend_artist, plot_kws)

def plot_errorbars(self, ax, data, capsize, err_kws):

Expand Down Expand Up @@ -2041,7 +2036,7 @@ def boxenplot(
def stripplot(
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
jitter=True, dodge=False, orient=None, color=None, palette=None,
size=5, edgecolor="face", linewidth=0,
size=5, edgecolor=default, linewidth=0,
hue_norm=None, log_scale=None, native_scale=False, formatter=None, legend="auto",
ax=None, **kwargs
):
Expand Down Expand Up @@ -2810,8 +2805,7 @@ def catplot(
if saturation < 1:
color = desaturate(color, saturation)

edgecolor = kwargs.pop("edgecolor", "face" if kind == "strip" else "auto")
edgecolor = p._complement_color(edgecolor, color, p._hue_map)
edgecolor = p._complement_color(kwargs.pop("edgecolor", default), color, p._hue_map)

width = kwargs.pop("width", 0.8)
dodge = kwargs.pop("dodge", False if kind in undodged_kinds else "auto")
Expand Down

0 comments on commit fbc44d5

Please sign in to comment.