Skip to content

Commit

Permalink
Fix broken color tests by manually backporting assert_colors_equal
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Aug 7, 2021
1 parent e241cf1 commit 71eb8d1
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 4 deletions.
19 changes: 19 additions & 0 deletions seaborn/_testing.py
@@ -1,5 +1,6 @@
import numpy as np
import matplotlib as mpl
from matplotlib.colors import to_rgb, to_rgba
from numpy.testing import assert_array_equal


Expand Down Expand Up @@ -38,6 +39,24 @@
]


def assert_colors_equal(a, b, check_alpha=True):

def handle_array(x):

if isinstance(x, np.ndarray):
if x.ndim > 1:
x = np.unique(x, axis=0).squeeze()
if x.ndim > 1:
raise ValueError("Color arrays must be 1 dimensional")
return x

a = handle_array(a)
b = handle_array(b)

f = to_rgba if check_alpha else to_rgb
assert f(a) == f(b)


def assert_artists_equal(list1, list2, properties):

assert len(list1) == len(list2)
Expand Down
9 changes: 5 additions & 4 deletions seaborn/tests/test_axisgrid.py
Expand Up @@ -20,6 +20,7 @@
from .. import axisgrid as ag
from .._testing import (
assert_plots_equal,
assert_colors_equal,
)

rs = np.random.RandomState(0)
Expand Down Expand Up @@ -965,21 +966,20 @@ def test_map_diag_rectangular(self):
def test_map_diag_color(self):

color = "red"
rgb_color = mpl.colors.colorConverter.to_rgba(color)

g1 = ag.PairGrid(self.df)
g1.map_diag(plt.hist, color=color)

for ax in g1.diag_axes:
for patch in ax.patches:
assert patch.get_facecolor() == rgb_color
assert_colors_equal(patch.get_facecolor(), color)

g2 = ag.PairGrid(self.df)
g2.map_diag(kdeplot, color='red')

for ax in g2.diag_axes:
for line in ax.lines:
assert line.get_color() == color
assert_colors_equal(line.get_color(), color)

def test_map_diag_palette(self):

Expand All @@ -990,7 +990,7 @@ def test_map_diag_palette(self):

for ax in g.diag_axes:
for line, color in zip(ax.lines[::-1], pal):
assert line.get_color() == color
assert_colors_equal(line.get_color(), color)

def test_map_diag_and_offdiag(self):

Expand Down Expand Up @@ -1609,6 +1609,7 @@ def test_refline(self):
npt.assert_array_equal(g.ax_joint.lines[-1].get_xydata(), hline)
assert len(g.ax_marg_x.lines) == len(g.ax_marg_y.lines)


class TestJointPlot:

rs = np.random.RandomState(sum(map(ord, "jointplot")))
Expand Down
3 changes: 3 additions & 0 deletions seaborn/tests/test_matrix.py
Expand Up @@ -28,6 +28,9 @@
_no_fastcluster = True


# Copied from master onto v0.11 here to fix break introduced by
# cherry pick commit 49fbd353

class TestHeatmap:
rs = np.random.RandomState(sum(map(ord, "heatmap")))

Expand Down

0 comments on commit 71eb8d1

Please sign in to comment.