Skip to content

Commit

Permalink
Fix PairGrid with non-square grid and non-marginal diagonal axes
Browse files Browse the repository at this point in the history
Fixes #2260
  • Loading branch information
mwaskom committed Sep 11, 2020
1 parent d7cc655 commit 257efe7
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 10 deletions.
5 changes: 5 additions & 0 deletions doc/releases/v0.11.1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

v0.11.1 (Unreleased)
--------------------

- Fix bug in :class:`PairGrid`/:func:`pairplot` where diagonal axes would be empty when the grid was not square and the diagonal axes did not contain the marginal plots. (:pr:`2270`)
15 changes: 11 additions & 4 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,10 +1292,17 @@ def map_offdiag(self, func, **kwargs):
called ``color`` and ``label``.
"""

self.map_lower(func, **kwargs)
if not self._corner:
self.map_upper(func, **kwargs)
if self.square_grid:
self.map_lower(func, **kwargs)
if not self._corner:
self.map_upper(func, **kwargs)
else:
indices = []
for i, (y_var) in enumerate(self.y_vars):
for j, (x_var) in enumerate(self.x_vars):
if x_var != y_var:
indices.append((i, j))
self._map_bivariate(func, indices, **kwargs)
return self

def map_diag(self, func, **kwargs):
Expand Down
34 changes: 28 additions & 6 deletions seaborn/tests/test_axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,40 +878,62 @@ def test_map_diag(self):
def test_map_diag_rectangular(self):

x_vars = ["x", "y"]
y_vars = ["x", "y", "z"]
y_vars = ["x", "z", "y"]
g1 = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
g1.map_diag(plt.hist)
g1.map_offdiag(plt.scatter)

assert set(g1.diag_vars) == (set(x_vars) & set(y_vars))

for var, ax in zip(g1.diag_vars, g1.diag_axes):
nt.assert_equal(len(ax.patches), 10)
assert pytest.approx(ax.patches[0].get_x()) == self.df[var].min()

for i, ax in enumerate(np.diag(g1.axes)):
assert ax.bbox.bounds == g1.diag_axes[i].bbox.bounds
for j, x_var in enumerate(x_vars):
for i, y_var in enumerate(y_vars):

ax = g1.axes[i, j]
if x_var == y_var:
diag_ax = g1.diag_axes[j] # because fewer x than y vars
assert ax.bbox.bounds == diag_ax.bbox.bounds

else:
x, y = ax.collections[0].get_offsets().T
assert_array_equal(x, self.df[x_var])
assert_array_equal(y, self.df[y_var])

g2 = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars, hue="a")
g2.map_diag(plt.hist)
g2.map_offdiag(plt.scatter)

assert set(g2.diag_vars) == (set(x_vars) & set(y_vars))

for ax in g2.diag_axes:
nt.assert_equal(len(ax.patches), 30)

x_vars = ["x", "y", "z"]
y_vars = ["x", "y"]
y_vars = ["x", "z"]
g3 = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
g3.map_diag(plt.hist)
g3.map_offdiag(plt.scatter)

assert set(g3.diag_vars) == (set(x_vars) & set(y_vars))

for var, ax in zip(g3.diag_vars, g3.diag_axes):
nt.assert_equal(len(ax.patches), 10)
assert pytest.approx(ax.patches[0].get_x()) == self.df[var].min()

for i, ax in enumerate(np.diag(g3.axes)):
assert ax.bbox.bounds == g3.diag_axes[i].bbox.bounds
for j, x_var in enumerate(x_vars):
for i, y_var in enumerate(y_vars):

ax = g3.axes[i, j]
if x_var == y_var:
diag_ax = g3.diag_axes[i] # because fewer y than x vars
assert ax.bbox.bounds == diag_ax.bbox.bounds
else:
x, y = ax.collections[0].get_offsets().T
assert_array_equal(x, self.df[x_var])
assert_array_equal(y, self.df[y_var])

def test_map_diag_color(self):

Expand Down

0 comments on commit 257efe7

Please sign in to comment.