From 022e4bd75bd838ca52b2df1a34bebd6438b5efd4 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 13 Jul 2022 20:26:33 -0400 Subject: [PATCH] Add Plot.label method (#2902) * Add Plot.label method * Satisfy mypy (I'm not sure I understand why it's confused here) * Test legend title customization * Refactor label resolution --- seaborn/_core/plot.py | 65 +++++++++++++++++++++++++++++++++------- tests/_core/test_plot.py | 35 ++++++++++++++++++++-- 2 files changed, 88 insertions(+), 12 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index e42a201e8a..b85132335d 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -10,7 +10,7 @@ import textwrap from collections import abc from collections.abc import Callable, Generator, Hashable -from typing import Any, cast +from typing import Any, Optional, cast import pandas as pd from pandas import DataFrame, Series, Index @@ -147,6 +147,7 @@ class Plot: _scales: dict[str, Scale] _limits: dict[str, tuple[Any, Any]] + _labels: dict[str, str | Callable[[str], str] | None] _subplot_spec: dict[str, Any] # TODO values type _facet_spec: FacetSpec @@ -172,6 +173,7 @@ def __init__( self._scales = {} self._limits = {} + self._labels = {} self._subplot_spec = {} self._facet_spec = {} @@ -552,8 +554,8 @@ def limit(self, **limits: tuple[Any, Any]) -> Plot: Keywords correspond to variables defined in the plot, and values are a (min, max) tuple (where either can be `None` to leave unset). - Limits apply only to the axis scale; data outside the limits are still - used in any stat transforms and added to the plot. + Limits apply only to the axis; data outside the visible range are + still used for any stat transforms and added to the plot. Behavior for non-coordinate variables is currently undefined. @@ -562,6 +564,25 @@ def limit(self, **limits: tuple[Any, Any]) -> Plot: new._limits.update(limits) return new + def label(self, **labels: str | Callable[[str], str] | None) -> Plot: + """ + Control the labels used for variables in the plot. + + For coordinate variables, this sets the axis label. + For semantic variables, it sets the legend title. + + Keywords correspond to variables defined in the plot. + Values can be one of the following types:: + + - string (used literally) + - function (called on the default label) + - None (disables the label for this variable) + + """ + new = self._clone() + new._labels.update(labels) + return new + def configure( self, figsize: tuple[float, float] | None = None, @@ -768,6 +789,20 @@ def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]: return common_data, layers + def _resolve_label(self, p: Plot, var: str, auto_label: str | None) -> str | None: + + label: str | None + if var in p._labels: + manual_label = p._labels[var] + if callable(manual_label) and auto_label is not None: + label = manual_label(auto_label) + else: + # mypy needs a lot of help here, I'm not sure why + label = cast(Optional[str], manual_label) + else: + label = auto_label + return label + def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: # --- Parsing the faceting/pairing parameterization to specify figure grid @@ -797,6 +832,9 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: ax = sub["ax"] for axis in "xy": axis_key = sub[axis] + + # ~~ Axis labels + # TODO Should we make it possible to use only one x/y label for # all rows/columns in a faceted plot? Maybe using sub{axis}label, # although the alignments of the labels from that method leaves @@ -805,9 +843,12 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: common.names.get(axis_key), *(layer["data"].names.get(axis_key) for layer in layers) ] - label = next((name for name in names if name is not None), None) + auto_label = next((name for name in names if name is not None), None) + label = self._resolve_label(p, axis_key, auto_label) ax.set(**{f"{axis}label": label}) + # ~~ Decoration visibility + # TODO there should be some override (in Plot.configure?) so that # tick labels can be shown on interior shared axes axis_obj = getattr(ax, f"{axis}axis") @@ -1151,9 +1192,7 @@ def get_order(var): df = self._unscale_coords(subplots, df, orient) grouping_vars = mark._grouping_props + default_grouping_vars - split_generator = self._setup_split_generator( - grouping_vars, df, subplots - ) + split_generator = self._setup_split_generator(grouping_vars, df, subplots) mark._plot(split_generator, scales, orient) @@ -1162,7 +1201,7 @@ def get_order(var): view["ax"].autoscale_view() if layer["legend"]: - self._update_legend_contents(mark, data, scales) + self._update_legend_contents(p, mark, data, scales) def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: # TODO stricter type on subplots @@ -1357,7 +1396,11 @@ def split_generator(keep_na=False) -> Generator: return split_generator def _update_legend_contents( - self, mark: Mark, data: PlotData, scales: dict[str, Scale] + self, + p: Plot, + mark: Mark, + data: PlotData, + scales: dict[str, Scale], ) -> None: """Add legend artists / labels for one layer in the plot.""" if data.frame.empty and data.frames: @@ -1382,7 +1425,9 @@ def _update_legend_contents( part_vars.append(var) break else: - entry = (data.names[var], data.ids[var]), [var], (values, labels) + auto_title = data.names[var] + title = self._resolve_label(p, var, auto_title) + entry = (title, data.ids[var]), [var], (values, labels) schema.append(entry) # Second pass, generate an artist corresponding to each value diff --git a/tests/_core/test_plot.py b/tests/_core/test_plot.py index 9cf7999be3..093e7e2447 100644 --- a/tests/_core/test_plot.py +++ b/tests/_core/test_plot.py @@ -999,8 +999,8 @@ def test_limits(self, long_df): limit = (-2, 24) p = Plot(long_df, x="x", y="y").limit(x=limit).plot() - ax1 = p._figure.axes[0] - assert ax1.get_xlim() == limit + ax = p._figure.axes[0] + assert ax.get_xlim() == limit limit = (np.datetime64("2005-01-01"), np.datetime64("2008-01-01")) p = Plot(long_df, x="d", y="y").limit(x=limit).plot() @@ -1012,6 +1012,30 @@ def test_limits(self, long_df): ax = p._figure.axes[0] assert ax.get_xlim() == (0.5, 2.5) + def test_labels_axis(self, long_df): + + label = "Y axis" + p = Plot(long_df, x="x", y="y").label(y=label).plot() + ax = p._figure.axes[0] + assert ax.get_ylabel() == label + + label = str.capitalize + p = Plot(long_df, x="x", y="y").label(y=label).plot() + ax = p._figure.axes[0] + assert ax.get_ylabel() == "Y" + + def test_labels_legend(self, long_df): + + m = MockMark() + + label = "A" + p = Plot(long_df, x="x", y="y", color="a").add(m).label(color=label).plot() + assert p._figure.legends[0].get_title().get_text() == label + + func = str.capitalize + p = Plot(long_df, x="x", y="y", color="a").add(m).label(color=func).plot() + assert p._figure.legends[0].get_title().get_text() == label + class TestFacetInterface: @@ -1406,6 +1430,13 @@ def test_limits(self, long_df): ax1 = p._figure.axes[1] assert ax1.get_xlim() == limit + def test_labels(self, long_df): + + label = "Z" + p = Plot(long_df, y="y").pair(x=["x", "z"]).label(x1=label).plot() + ax1 = p._figure.axes[1] + assert ax1.get_xlabel() == label + class TestLabelVisibility: