Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hue_order in PairGrid #547

Merged
merged 2 commits into from
May 9, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/releases/v0.6.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,5 @@ Bug fixes
- Fixed a bug in :class:`FacetGrid` and :class:`PairGrid` that lead to incorrect legend labels when levels of the ``hue`` variable appeared in ``hue_order`` but not in the data.

- Fixed a bug in :meth:`FacetGrid.set_xticklabels` or :meth:`FacetGrid.set_yticklabels` when ``col_wrap`` is being used.

- Fixed a bug in :class:`PairGrid` where the ``hue_order`` parameter was ignored.
48 changes: 38 additions & 10 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,14 +915,11 @@ def __init__(self, data, hue=None, hue_order=None, palette=None,
# Sort out the hue variable
self._hue_var = hue
if hue is None:
self.hue_names = None
self.hue_names = ["_nolegend_"]
self.hue_vals = pd.Series(["_nolegend_"] * len(data),
index=data.index)
else:
if hue_order is None:
hue_names = np.unique(np.sort(data[hue]))
else:
hue_names = hue_order
hue_names = utils.categorical_order(data[hue], hue_order)
if dropna:
# Filter NA from the list of unique hue names
hue_names = list(filter(pd.notnull, hue_names))
Expand Down Expand Up @@ -954,7 +951,14 @@ def map(self, func, **kwargs):
for i, y_var in enumerate(self.y_vars):
for j, x_var in enumerate(self.x_vars):
hue_grouped = self.data.groupby(self.hue_vals)
for k, (label_k, data_k) in enumerate(hue_grouped):
for k, label_k in enumerate(self.hue_names):

# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = np.array([])

ax = self.axes[i, j]
plt.sca(ax)

Expand Down Expand Up @@ -1008,11 +1012,22 @@ def map_diag(self, func, **kwargs):
# Special-case plt.hist with stacked bars
if func is plt.hist:
plt.sca(ax)
vals = [v.values for g, v in hue_grouped]
vals = []
for label in self.hue_names:
# Attempt to get data for this level, allowing for empty
try:
vals.append(hue_grouped.get_group(label))
except KeyError:
vals.append(np.array([]))
func(vals, color=self.palette, histtype="barstacked",
**kwargs)
else:
for k, (label_k, data_k) in enumerate(hue_grouped):
for k, label_k in enumerate(self.hue_names):
# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = np.array([])
plt.sca(ax)
func(data_k, label=label_k,
color=self.palette[k], **kwargs)
Expand All @@ -1034,7 +1049,13 @@ def map_lower(self, func, **kwargs):
kw_color = kwargs.pop("color", None)
for i, j in zip(*np.tril_indices_from(self.axes, -1)):
hue_grouped = self.data.groupby(self.hue_vals)
for k, (label_k, data_k) in enumerate(hue_grouped):
for k, label_k in enumerate(self.hue_names):

# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = np.array([])

ax = self.axes[i, j]
plt.sca(ax)
Expand Down Expand Up @@ -1071,7 +1092,14 @@ def map_upper(self, func, **kwargs):
for i, j in zip(*np.triu_indices_from(self.axes, 1)):

hue_grouped = self.data.groupby(self.hue_vals)
for k, (label_k, data_k) in enumerate(hue_grouped):

for k, label_k in enumerate(self.hue_names):

# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = np.array([])

ax = self.axes[i, j]
plt.sca(ax)
Expand Down
53 changes: 53 additions & 0 deletions seaborn/tests/test_axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,59 @@ def test_hue_kws(self):
for line, marker in zip(g.axes[0, 0].lines, kws["marker"]):
nt.assert_equal(line.get_marker(), marker)

g = ag.PairGrid(self.df, hue="a", hue_kws=kws,
hue_order=list("dcab"))
g.map(plt.plot)

for line, marker in zip(g.axes[0, 0].lines, kws["marker"]):
nt.assert_equal(line.get_marker(), marker)

plt.close("all")

@skipif(old_matplotlib)
def test_hue_order(self):

order = list("dcab")
g = ag.PairGrid(self.df, hue="a", hue_order=order)
g.map(plt.plot)

for line, level in zip(g.axes[1, 0].lines, order):
x, y = line.get_xydata().T
npt.assert_array_equal(x, self.df.loc[self.df.a == level, "x"])
npt.assert_array_equal(y, self.df.loc[self.df.a == level, "y"])

plt.close("all")

g = ag.PairGrid(self.df, hue="a", hue_order=order)
g.map_diag(plt.plot)

for line, level in zip(g.axes[0, 0].lines, order):
x, y = line.get_xydata().T
npt.assert_array_equal(x, self.df.loc[self.df.a == level, "x"])
npt.assert_array_equal(y, self.df.loc[self.df.a == level, "x"])

plt.close("all")

g = ag.PairGrid(self.df, hue="a", hue_order=order)
g.map_lower(plt.plot)

for line, level in zip(g.axes[1, 0].lines, order):
x, y = line.get_xydata().T
npt.assert_array_equal(x, self.df.loc[self.df.a == level, "x"])
npt.assert_array_equal(y, self.df.loc[self.df.a == level, "y"])

plt.close("all")

g = ag.PairGrid(self.df, hue="a", hue_order=order)
g.map_upper(plt.plot)

for line, level in zip(g.axes[0, 1].lines, order):
x, y = line.get_xydata().T
npt.assert_array_equal(x, self.df.loc[self.df.a == level, "y"])
npt.assert_array_equal(y, self.df.loc[self.df.a == level, "x"])

plt.close("all")

def test_nondefault_index(self):

df = self.df.copy().set_index("b")
Expand Down