diff --git a/seaborn/tests/test_palettes.py b/seaborn/tests/test_palettes.py index 0761897dea..6ddebcc05f 100644 --- a/seaborn/tests/test_palettes.py +++ b/seaborn/tests/test_palettes.py @@ -5,7 +5,6 @@ import pytest import nose.tools as nt import numpy.testing as npt -import matplotlib.pyplot as plt from .. import palettes, utils, rcmod from ..external import husl @@ -354,11 +353,3 @@ def test_preserved_palette_length(self): pal_in = palettes.color_palette("Set1", 10) pal_out = palettes.color_palette(pal_in) nt.assert_equal(pal_in, pal_out) - - def test_get_color_cycle(self): - - colors = [(1., 0., 0.), (0, 1., 0.)] - prop_cycle = plt.cycler(color=colors) - with plt.rc_context({"axes.prop_cycle": prop_cycle}): - result = utils.get_color_cycle() - assert result == colors diff --git a/seaborn/tests/test_utils.py b/seaborn/tests/test_utils.py index e86aa4c843..ee2cc28540 100644 --- a/seaborn/tests/test_utils.py +++ b/seaborn/tests/test_utils.py @@ -5,6 +5,7 @@ import pandas as pd import matplotlib as mpl import matplotlib.pyplot as plt +from cycler import cycler import pytest import nose import nose.tools as nt @@ -374,6 +375,23 @@ def test_locator_to_legend_entries(): assert str_levels == ['1e-07', '1e-05', '1e-03', '1e-01', '10'] +@pytest.mark.parametrize( + "cycler,result", + [ + (cycler(color=["y"]), ["y"]), + (cycler(color=["k"]), ["k"]), + (cycler(color=["k", "y"]), ["k", "y"]), + (cycler(color=["y", "k"]), ["y", "k"]), + (cycler(color=["b", "r"]), ["b", "r"]), + (cycler(color=["r", "b"]), ["r", "b"]), + (cycler(lw=[1, 2]), [".15"]), # no color in cycle + ], +) +def test_get_color_cycle(cycler, result): + with mpl.rc_context(rc={"axes.prop_cycle": cycler}): + assert utils.get_color_cycle() == result + + def check_load_dataset(name): ds = load_dataset(name, cache=False) assert(isinstance(ds, pd.DataFrame)) diff --git a/seaborn/utils.py b/seaborn/utils.py index 0aae2c25c0..f2615606f5 100644 --- a/seaborn/utils.py +++ b/seaborn/utils.py @@ -555,8 +555,20 @@ def get_view_interval(self): def get_color_cycle(): - """Return the list of colors in the current matplotlib color cycle.""" - return [x['color'] for x in mpl.rcParams['axes.prop_cycle']] + """Return the list of colors in the current matplotlib color cycle + + Parameters + ---------- + None + + Returns + ------- + colors : list + List of matplotlib colors in the current cycle, or dark gray if + the current color cycle is empty. + """ + cycler = mpl.rcParams['axes.prop_cycle'] + return cycler.by_key()['color'] if 'color' in cycler.keys else [".15"] def relative_luminance(color):