Skip to content

Commit

Permalink
Copy props from old cmap when creating new cmap in heatmap (#1948)
Browse files Browse the repository at this point in the history
* preserve cmap props when centering

* preserve explicitly set extremes

(cherry picked from commit e760254)
  • Loading branch information
MaozGelbart authored and mwaskom committed Feb 22, 2020
1 parent 488c167 commit 4b2297b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
18 changes: 18 additions & 0 deletions seaborn/matrix.py
Expand Up @@ -216,11 +216,29 @@ def _determine_cmap_params(self, plot_data, vmin, vmax,

# Recenter a divergent colormap
if center is not None:

# Copy bad values
# in mpl<3.2 only masked values are honored with "bad" color spec
# (see https://github.com/matplotlib/matplotlib/pull/14257)
bad = self.cmap(np.ma.masked_invalid([np.nan]))[0]

# under/over values are set for sure when cmap extremes
# do not map to the same color as +-inf
under = self.cmap(-np.inf)
over = self.cmap(np.inf)
under_set = under != self.cmap(0)
over_set = over != self.cmap(self.cmap.N - 1)

vrange = max(vmax - center, center - vmin)
normlize = mpl.colors.Normalize(center - vrange, center + vrange)
cmin, cmax = normlize([vmin, vmax])
cc = np.linspace(cmin, cmax, 256)
self.cmap = mpl.colors.ListedColormap(self.cmap(cc))
self.cmap.set_bad(bad)
if under_set:
self.cmap.set_under(under)
if over_set:
self.cmap.set_over(over)

def _annotate_heatmap(self, ax, mesh):
"""Add textual labels with the value in each cell."""
Expand Down
40 changes: 40 additions & 0 deletions seaborn/tests/test_matrix.py
@@ -1,5 +1,6 @@
import itertools
import tempfile
import copy

import numpy as np
import matplotlib as mpl
Expand Down Expand Up @@ -200,6 +201,45 @@ def test_custom_center_colors(self):
fc = ax.collections[0].get_facecolors()
npt.assert_array_almost_equal(fc, cmap(vals), 2)

def test_cmap_with_properties(self):

kws = self.default_kws.copy()
cmap = copy.copy(mpl.cm.get_cmap("BrBG"))
cmap.set_bad("red")
kws["cmap"] = cmap
hm = mat._HeatMapper(self.df_unif, **kws)
npt.assert_array_equal(
cmap(np.ma.masked_invalid([np.nan])),
hm.cmap(np.ma.masked_invalid([np.nan])))

kws["center"] = 0.5
hm = mat._HeatMapper(self.df_unif, **kws)
npt.assert_array_equal(
cmap(np.ma.masked_invalid([np.nan])),
hm.cmap(np.ma.masked_invalid([np.nan])))

kws = self.default_kws.copy()
cmap = copy.copy(mpl.cm.get_cmap("BrBG"))
cmap.set_under("red")
kws["cmap"] = cmap
hm = mat._HeatMapper(self.df_unif, **kws)
npt.assert_array_equal(cmap(-np.inf), hm.cmap(-np.inf))

kws["center"] = .5
hm = mat._HeatMapper(self.df_unif, **kws)
npt.assert_array_equal(cmap(-np.inf), hm.cmap(-np.inf))

kws = self.default_kws.copy()
cmap = copy.copy(mpl.cm.get_cmap("BrBG"))
cmap.set_over("red")
kws["cmap"] = cmap
hm = mat._HeatMapper(self.df_unif, **kws)
npt.assert_array_equal(cmap(-np.inf), hm.cmap(-np.inf))

kws["center"] = .5
hm = mat._HeatMapper(self.df_unif, **kws)
npt.assert_array_equal(cmap(np.inf), hm.cmap(np.inf))

def test_tickabels_off(self):
kws = self.default_kws.copy()
kws['xticklabels'] = False
Expand Down

0 comments on commit 4b2297b

Please sign in to comment.