Skip to content

Commit

Permalink
FIX: Ignore masked cells when finding heatmap data limits (#1956)
Browse files Browse the repository at this point in the history
* respect mask when setting heatmap limits

* respect mask when setting heatmap limits

* improve code style

(cherry picked from commit 4757065)
  • Loading branch information
MaozGelbart authored and mwaskom committed Feb 26, 2020
1 parent b5afc80 commit ceea0e1
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
14 changes: 11 additions & 3 deletions seaborn/matrix.py
Expand Up @@ -194,11 +194,19 @@ def __init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt,
def _determine_cmap_params(self, plot_data, vmin, vmax,
cmap, center, robust):
"""Use some heuristics to set good defaults for colorbar and range."""
calc_data = plot_data.data[~np.isnan(plot_data.data)]

# plot_data is a np.ma.array instance
calc_data = plot_data.filled(np.nan)
if vmin is None:
vmin = np.percentile(calc_data, 2) if robust else calc_data.min()
if robust:
vmin = np.nanpercentile(calc_data, 2)
else:
vmin = np.nanmin(calc_data)
if vmax is None:
vmax = np.percentile(calc_data, 98) if robust else calc_data.max()
if robust:
vmax = np.nanpercentile(calc_data, 98)
else:
vmax = np.nanmax(calc_data)
self.vmin, self.vmax = vmin, vmax

# Choose default colormaps if not provided
Expand Down
19 changes: 19 additions & 0 deletions seaborn/tests/test_matrix.py
Expand Up @@ -97,6 +97,25 @@ def test_mask_input(self):

npt.assert_array_equal(p.plot_data, plot_data)

def test_mask_limits(self):
"""Make sure masked cells are not used to calculate extremes"""

kws = self.default_kws.copy()

mask = self.x_norm > 0
kws['mask'] = mask
p = mat._HeatMapper(self.x_norm, **kws)

assert p.vmax == np.ma.array(self.x_norm, mask=mask).max()
assert p.vmin == np.ma.array(self.x_norm, mask=mask).min()

mask = self.x_norm < 0
kws['mask'] = mask
p = mat._HeatMapper(self.x_norm, **kws)

assert p.vmin == np.ma.array(self.x_norm, mask=mask).min()
assert p.vmax == np.ma.array(self.x_norm, mask=mask).max()

def test_default_vlims(self):

p = mat._HeatMapper(self.df_unif, **self.default_kws)
Expand Down

0 comments on commit ceea0e1

Please sign in to comment.