Skip to content

Commit

Permalink
Allow FacetGrid legend to work with nested labels (#1909)
Browse files Browse the repository at this point in the history
* Add test to reproduce GH#1560

* Allow FacetGrid legend to work with nested labels

Fixes #1560

* Update comments and release notes
  • Loading branch information
mwaskom committed Dec 30, 2019
1 parent fdf0a2b commit dc10c4c
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 9 deletions.
2 changes: 2 additions & 0 deletions doc/releases/v0.9.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ New features

- It is now possible to force a categorical interpretation of the ``hue`` varaible in a relational plot by passing the name of a categorical palette (e.g. ``"deep"``, or ``"Set2"``). This complements the (previously supported) option of passig a list/dict of colors.

- Added the ability to pass hierarchical label names to the :class:`FacetGrid` legend, which also fixes a bug in :func:`relplot` when the same label appeared in diffent semantics.

Bug fixes and adaptations
~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
19 changes: 15 additions & 4 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def add_legend(self, legend_data=None, title=None, label_order=None,
Parameters
----------
legend_data : dict, optional
Dictionary mapping label names to matplotlib artist handles. The
Dictionary mapping label names (or two-element tuples where the
second element is a label name) to matplotlib artist handles. The
default reads from ``self._legend_data``.
title : string, optional
Title for the legend. The default reads from ``self._hue_var``.
Expand All @@ -61,7 +62,8 @@ def add_legend(self, legend_data=None, title=None, label_order=None,
"""
# Find the data for the legend
legend_data = self._legend_data if legend_data is None else legend_data
if legend_data is None:
legend_data = self._legend_data
if label_order is None:
if self.hue_names is None:
label_order = list(legend_data.keys())
Expand All @@ -76,6 +78,15 @@ def add_legend(self, legend_data=None, title=None, label_order=None,
except TypeError: # labelsize is something like "large"
title_size = mpl.rcParams["axes.labelsize"]

# Unpack nested labels from a hierarchical legend
labels = []
for entry in label_order:
if isinstance(entry, tuple):
_, label = entry
else:
label = entry
labels.append(label)

# Set default legend kwargs
kwargs.setdefault("scatterpoints", 1)

Expand All @@ -84,7 +95,7 @@ def add_legend(self, legend_data=None, title=None, label_order=None,
kwargs.setdefault("frameon", False)

# Draw a full-figure legend outside the grid
figlegend = self.fig.legend(handles, label_order, "center right",
figlegend = self.fig.legend(handles, labels, "center right",
**kwargs)
self._legend = figlegend
figlegend.set_title(title, prop={"size": title_size})
Expand Down Expand Up @@ -115,7 +126,7 @@ def add_legend(self, legend_data=None, title=None, label_order=None,
else:
# Draw a legend in the first axis
ax = self.axes.flat[0]
leg = ax.legend(handles, label_order, loc="best", **kwargs)
leg = ax.legend(handles, labels, loc="best", **kwargs)
leg.set_title(title, prop={"size": title_size})

return self
Expand Down
4 changes: 2 additions & 2 deletions seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,8 @@ def update(var_name, val_name, **kws):
artist = func([], [], label=label, **use_kws)
if self._legend_func == "plot":
artist = artist[0]
legend_data[label] = artist
legend_order.append(label)
legend_data[key] = artist
legend_order.append(key)

self.legend_data = legend_data
self.legend_order = legend_order
Expand Down
12 changes: 12 additions & 0 deletions seaborn/tests/test_axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,18 @@ def test_get_boolean_legend_data(self):
for label, level in zip(labels, b_levels):
nt.assert_equal(label.get_text(), level)

def test_legend_tuples(self):

g = ag.FacetGrid(self.df, hue="a")
g.map(plt.plot, "x", "y")

handles, labels = g.ax.get_legend_handles_labels()
label_tuples = [("", l) for l in labels]
legend_data = dict(zip(label_tuples, handles))
g.add_legend(legend_data, label_tuples)
for entry, label in zip(g._legend.get_texts(), labels):
assert entry.get_text() == label

def test_legend_options(self):

g1 = ag.FacetGrid(self.df, hue="b")
Expand Down
18 changes: 15 additions & 3 deletions seaborn/tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,6 +1706,12 @@ def test_relplot_styles(self, long_df):
expected_paths = [paths[val] for val in grp_df["a"]]
assert self.paths_equal(points.get_paths(), expected_paths)

def test_relplot_stringy_numerics(self, long_df):

long_df["x_str"] = long_df["x"].astype(str)
g = rel.relplot(x="x", y="y", hue="x_str", data=long_df)
assert g._legend.texts[0].get_text() == "x_str"

def test_relplot_legend(self, long_df):

g = rel.relplot(x="x", y="y", data=long_df)
Expand All @@ -1723,6 +1729,12 @@ def test_relplot_legend(self, long_df):
g = rel.relplot(x="x", y="y", hue="a", legend=False, data=long_df)
assert g._legend is None

long_df["x_str"] = long_df["x"].astype(str)
g = rel.relplot(x="x", y="y", hue="x_str", data=long_df)
assert g._legend.texts[0].get_text() == "x_str"
palette = color_palette("deep", len(long_df["b"].unique()))
a_like_b = dict(zip(long_df["a"].unique(), long_df["b"].unique()))
long_df["a_like_b"] = long_df["a"].map(a_like_b)
g = rel.relplot(x="x", y="y", hue="b", style="a_like_b",
palette=palette, kind="line", estimator=None,
data=long_df)
lines = g._legend.get_lines()[1:] # Chop off title dummy
for line, color in zip(lines, palette):
assert line.get_color() == color

0 comments on commit dc10c4c

Please sign in to comment.