Skip to content

Commit

Permalink
Add some tests for pandas index alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Dec 26, 2019
1 parent e8b3c40 commit 604fc0d
Showing 1 changed file with 139 additions and 1 deletion.
140 changes: 139 additions & 1 deletion seaborn/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,14 @@ def test_longform_groupby(self):
for group, hues in zip(["a", "b", "c"], p.plot_hues):
npt.assert_array_equal(hues, self.h[self.g == group])

# Test grouped data that matches on index
p1 = cat._CategoricalPlotter()
p1.establish_variables(self.g, self.y, self.h)
p2 = cat._CategoricalPlotter()
p2.establish_variables(self.g, self.y[::-1], self.h)
for i, (d1, d2) in enumerate(zip(p1.plot_data, p2.plot_data)):
assert np.array_equal(d1.sort_index(), d2.sort_index())

def test_input_validation(self):

p = cat._CategoricalPlotter()
Expand All @@ -260,6 +268,12 @@ def test_input_validation(self):
with nt.assert_raises(ValueError):
p.establish_variables(**input_kws)

g_prime = pd.Series(self.g.values, np.roll(self.g.index, 2))
with pytest.warns(UserWarning):
p.establish_variables(x=g_prime, y=self.y)
with pytest.warns(UserWarning):
p.establish_variables(x=self.g, y=self.y, hue=g_prime)

def test_order(self):

p = cat._CategoricalPlotter()
Expand Down Expand Up @@ -831,6 +845,22 @@ def test_missing_data(self):

plt.close("all")

def test_unaligned_index(self):

f, (ax1, ax2) = plt.subplots(2)
cat.boxplot(self.g, self.y, ax=ax1)
cat.boxplot(self.g, self.y.sample(frac=1), ax=ax2)
for l1, l2 in zip(ax1.lines, ax2.lines):
assert np.array_equal(l1.get_xydata(), l2.get_xydata())

f, (ax1, ax2) = plt.subplots(2)
hue_order = self.h.unique()
cat.boxplot(self.g, self.y, self.h, hue_order=hue_order, ax=ax1)
cat.boxplot(self.g, self.y.sample(frac=1), self.h.sample(frac=1),
hue_order=hue_order, ax=ax2)
for l1, l2 in zip(ax1.lines, ax2.lines):
assert np.array_equal(l1.get_xydata(), l2.get_xydata())

def test_boxplots(self):

# Smoke test the high level boxplot options
Expand Down Expand Up @@ -1536,7 +1566,7 @@ def test_hue_point_colors(self):

for i, group_colors in enumerate(point_colors):
for j, point_color in enumerate(group_colors):
hue_level = p.plot_hues[i][j]
hue_level = np.asarray(p.plot_hues[i])[j]
nt.assert_equal(tuple(point_color),
deep_colors[hue_order.index(hue_level)])

Expand Down Expand Up @@ -1672,6 +1702,29 @@ def test_three_strip_points(self):
nt.assert_equal(facecolors.shape, (3, 4))
npt.assert_array_equal(facecolors[0], facecolors[1])

def test_unaligned_index(self):

f, (ax1, ax2) = plt.subplots(2)
cat.stripplot(self.g, self.y, ax=ax1)
cat.stripplot(self.g, self.y.sample(frac=1), ax=ax2)
for p1, p2 in zip(ax1.collections, ax2.collections):
y1, y2 = p1.get_offsets()[:, 1], p2.get_offsets()[:, 1]
assert np.array_equal(np.sort(y1), np.sort(y2))
assert np.array_equal(p1.get_facecolors()[np.argsort(y1)],
p2.get_facecolors()[np.argsort(y2)])

f, (ax1, ax2) = plt.subplots(2)
hue_order = self.h.unique()
cat.stripplot(self.g, self.y, self.h,
hue_order=hue_order, ax=ax1)
cat.stripplot(self.g, self.y.sample(frac=1), self.h.sample(frac=1),
hue_order=hue_order, ax=ax2)
for p1, p2 in zip(ax1.collections, ax2.collections):
y1, y2 = p1.get_offsets()[:, 1], p2.get_offsets()[:, 1]
assert np.array_equal(np.sort(y1), np.sort(y2))
assert np.array_equal(p1.get_facecolors()[np.argsort(y1)],
p2.get_facecolors()[np.argsort(y2)])


class TestSwarmPlotter(CategoricalFixture):

Expand Down Expand Up @@ -1823,6 +1876,29 @@ def test_nested_swarmplot_horizontal(self):

npt.assert_equal(fc[:3], pal[hue_names.index(hue)])

def test_unaligned_index(self):

f, (ax1, ax2) = plt.subplots(2)
cat.swarmplot(self.g, self.y, ax=ax1)
cat.swarmplot(self.g, self.y.sample(frac=1), ax=ax2)
for p1, p2 in zip(ax1.collections, ax2.collections):
assert np.allclose(p1.get_offsets()[:, 1],
p2.get_offsets()[:, 1])
assert np.array_equal(p1.get_facecolors(),
p2.get_facecolors())

f, (ax1, ax2) = plt.subplots(2)
hue_order = self.h.unique()
cat.swarmplot(self.g, self.y, self.h,
hue_order=hue_order, ax=ax1)
cat.swarmplot(self.g, self.y.sample(frac=1), self.h.sample(frac=1),
hue_order=hue_order, ax=ax2)
for p1, p2 in zip(ax1.collections, ax2.collections):
assert np.allclose(p1.get_offsets()[:, 1],
p2.get_offsets()[:, 1])
assert np.array_equal(p1.get_facecolors(),
p2.get_facecolors())


class TestBarPlotter(CategoricalFixture):

Expand Down Expand Up @@ -1992,6 +2068,31 @@ def test_draw_missing_bars(self):

plt.close("all")

def test_unaligned_index(self):

f, (ax1, ax2) = plt.subplots(2)
cat.barplot(self.g, self.y, ci="sd", ax=ax1)
cat.barplot(self.g, self.y.sample(frac=1), ci="sd", ax=ax2)
for l1, l2 in zip(ax1.lines, ax2.lines):
assert pytest.approx(l1.get_xydata()) == l2.get_xydata()
for p1, p2 in zip(ax1.patches, ax2.patches):
assert pytest.approx(p1.get_height()) == p2.get_height()
assert p1.get_width() == p2.get_width()
assert p1.get_xy() == p2.get_xy()

f, (ax1, ax2) = plt.subplots(2)
hue_order = self.h.unique()
cat.barplot(self.g, self.y, self.h, hue_order=hue_order, ci="sd",
ax=ax1)
cat.barplot(self.g, self.y.sample(frac=1), self.h.sample(frac=1),
hue_order=hue_order, ci="sd", ax=ax2)
for l1, l2 in zip(ax1.lines, ax2.lines):
assert pytest.approx(l1.get_xydata()) == l2.get_xydata()
for p1, p2 in zip(ax1.patches, ax2.patches):
assert pytest.approx(p1.get_height()) == p2.get_height()
assert p1.get_width() == p2.get_width()
assert p1.get_xy() == p2.get_xy()

def test_barplot_colors(self):

# Test unnested palette colors
Expand Down Expand Up @@ -2222,6 +2323,27 @@ def test_draw_missing_points(self):
f, ax = plt.subplots()
p.draw_points(ax)

def test_unaligned_index(self):

f, (ax1, ax2) = plt.subplots(2)
cat.pointplot(self.g, self.y, ci="sd", ax=ax1)
cat.pointplot(self.g, self.y.sample(frac=1), ci="sd", ax=ax2)
for l1, l2 in zip(ax1.lines, ax2.lines):
assert pytest.approx(l1.get_xydata()) == l2.get_xydata()
for p1, p2 in zip(ax1.collections, ax2.collections):
assert pytest.approx(p1.get_offsets()) == p2.get_offsets()

f, (ax1, ax2) = plt.subplots(2)
hue_order = self.h.unique()
cat.pointplot(self.g, self.y, self.h,
hue_order=hue_order, ci="sd", ax=ax1)
cat.pointplot(self.g, self.y.sample(frac=1), self.h.sample(frac=1),
hue_order=hue_order, ci="sd", ax=ax2)
for l1, l2 in zip(ax1.lines, ax2.lines):
assert pytest.approx(l1.get_xydata()) == l2.get_xydata()
for p1, p2 in zip(ax1.collections, ax2.collections):
assert pytest.approx(p1.get_offsets()) == p2.get_offsets()

def test_pointplot_colors(self):

# Test a single-color unnested plot
Expand Down Expand Up @@ -2622,6 +2744,22 @@ def test_draw_missing_boxes(self):
nt.assert_equal(len(list(patches)), 3)
plt.close("all")

def test_unaligned_index(self):

f, (ax1, ax2) = plt.subplots(2)
cat.boxenplot(self.g, self.y, ax=ax1)
cat.boxenplot(self.g, self.y.sample(frac=1), ax=ax2)
for l1, l2 in zip(ax1.lines, ax2.lines):
assert np.array_equal(l1.get_xydata(), l2.get_xydata())

f, (ax1, ax2) = plt.subplots(2)
hue_order = self.h.unique()
cat.boxenplot(self.g, self.y, self.h, hue_order=hue_order, ax=ax1)
cat.boxenplot(self.g, self.y.sample(frac=1), self.h.sample(frac=1),
hue_order=hue_order, ax=ax2)
for l1, l2 in zip(ax1.lines, ax2.lines):
assert np.array_equal(l1.get_xydata(), l2.get_xydata())

def test_missing_data(self):

x = ["a", "a", "b", "b", "c", "c", "d", "d"]
Expand Down

0 comments on commit 604fc0d

Please sign in to comment.