Skip to content

Commit

Permalink
Refactor PairGrid bivariate plotting code
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Sep 4, 2019
1 parent 1cf824b commit b27d8e6
Showing 1 changed file with 71 additions and 125 deletions.
196 changes: 71 additions & 125 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,37 +1321,55 @@ def map(self, func, **kwargs):
called ``color`` and ``label``.
"""
kw_color = kwargs.pop("color", None)
for i, y_var in enumerate(self.y_vars):
for j, x_var in enumerate(self.x_vars):
hue_grouped = self.data.groupby(self.hue_vals)
for k, label_k in enumerate(self.hue_names):
row_indices, col_indices = np.indices(self.axes.shape)
indices = zip(row_indices.flat, col_indices.flat)
self._map_bivariate(func, indices, **kwargs)
return self

def map_lower(self, func, **kwargs):
"""Plot with a bivariate function on the lower diagonal subplots.
Parameters
----------
func : callable plotting function
Must take x, y arrays as positional arguments and draw onto the
"currently active" matplotlib Axes. Also needs to accept kwargs
called ``color`` and ``label``.
# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = pd.DataFrame(columns=self.data.columns,
dtype=np.float)
"""
indices = zip(*np.tril_indices_from(self.axes, -1))
self._map_bivariate(func, indices, **kwargs)
return self

ax = self.axes[i, j]
plt.sca(ax)
def map_upper(self, func, **kwargs):
"""Plot with a bivariate function on the upper diagonal subplots.
# Insert the other hue aesthetics if appropriate
for kw, val_list in self.hue_kws.items():
kwargs[kw] = val_list[k]
Parameters
----------
func : callable plotting function
Must take x, y arrays as positional arguments and draw onto the
"currently active" matplotlib Axes. Also needs to accept kwargs
called ``color`` and ``label``.
color = self.palette[k] if kw_color is None else kw_color
func(data_k[x_var], data_k[y_var],
label=label_k, color=color, **kwargs)
"""
indices = zip(*np.triu_indices_from(self.axes, 1))
self._map_bivariate(func, indices, **kwargs)
return self

self._clean_axis(ax)
self._update_legend_data(ax)
def map_offdiag(self, func, **kwargs):
"""Plot with a bivariate function on the off-diagonal subplots.
if kw_color is not None:
kwargs["color"] = kw_color
self._add_axis_labels()
Parameters
----------
func : callable plotting function
Must take x, y arrays as positional arguments and draw onto the
"currently active" matplotlib Axes. Also needs to accept kwargs
called ``color`` and ``label``.
"""

self.map_lower(func, **kwargs)
self.map_upper(func, **kwargs)
return self

def map_diag(self, func, **kwargs):
Expand Down Expand Up @@ -1417,114 +1435,42 @@ def map_diag(self, func, **kwargs):

return self

def map_lower(self, func, **kwargs):
"""Plot with a bivariate function on the lower diagonal subplots.
Parameters
----------
func : callable plotting function
Must take x, y arrays as positional arguments and draw onto the
"currently active" matplotlib Axes. Also needs to accept kwargs
called ``color`` and ``label``.
"""
kw_color = kwargs.pop("color", None)
for i, j in zip(*np.tril_indices_from(self.axes, -1)):
hue_grouped = self.data.groupby(self.hue_vals)
for k, label_k in enumerate(self.hue_names):

# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = pd.DataFrame(columns=self.data.columns,
dtype=np.float)

ax = self.axes[i, j]
plt.sca(ax)

x_var = self.x_vars[j]
y_var = self.y_vars[i]

# Insert the other hue aesthetics if appropriate
for kw, val_list in self.hue_kws.items():
kwargs[kw] = val_list[k]

color = self.palette[k] if kw_color is None else kw_color
func(data_k[x_var], data_k[y_var], label=label_k,
color=color, **kwargs)

self._clean_axis(ax)
self._update_legend_data(ax)

if kw_color is not None:
kwargs["color"] = kw_color
def _map_bivariate(self, func, indices, **kwargs):
"""Draw a bivariate plot on the axes indicated in indices."""
kws = kwargs.copy() # Use copy as we insert other kwargs
kw_color = kws.pop("color", None)
for i, j in indices:
x_var = self.x_vars[j]
y_var = self.y_vars[i]
ax = self.axes[i, j]
self._plot_bivariate(x_var, y_var, ax, func, kw_color, **kws)
self._add_axis_labels()

return self

def map_upper(self, func, **kwargs):
"""Plot with a bivariate function on the upper diagonal subplots.
Parameters
----------
func : callable plotting function
Must take x, y arrays as positional arguments and draw onto the
"currently active" matplotlib Axes. Also needs to accept kwargs
called ``color`` and ``label``.
"""
kw_color = kwargs.pop("color", None)
for i, j in zip(*np.triu_indices_from(self.axes, 1)):

hue_grouped = self.data.groupby(self.hue_vals)

for k, label_k in enumerate(self.hue_names):

# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = pd.DataFrame(columns=self.data.columns,
dtype=np.float)

ax = self.axes[i, j]
plt.sca(ax)

x_var = self.x_vars[j]
y_var = self.y_vars[i]

# Insert the other hue aesthetics if appropriate
for kw, val_list in self.hue_kws.items():
kwargs[kw] = val_list[k]

color = self.palette[k] if kw_color is None else kw_color
func(data_k[x_var], data_k[y_var], label=label_k,
color=color, **kwargs)

self._clean_axis(ax)
self._update_legend_data(ax)

if kw_color is not None:
kwargs["color"] = kw_color
def _plot_bivariate(self, x_var, y_var, ax, func, kw_color, **kwargs):
"""Draw a bivariate plot on the specified axes."""
plt.sca(ax)
hue_grouped = self.data.groupby(self.hue_vals)
for k, label_k in enumerate(self.hue_names):

return self
# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = pd.DataFrame(columns=self.data.columns,
dtype=np.float)

def map_offdiag(self, func, **kwargs):
"""Plot with a bivariate function on the off-diagonal subplots.
# Insert the other hue aesthetics if appropriate
for kw, val_list in self.hue_kws.items():
kwargs[kw] = val_list[k]

Parameters
----------
func : callable plotting function
Must take x, y arrays as positional arguments and draw onto the
"currently active" matplotlib Axes. Also needs to accept kwargs
called ``color`` and ``label``.
x = data_k[x_var]
y = data_k[y_var]

"""
color = self.palette[k] if kw_color is None else kw_color
func(x, y, label=label_k, color=color, **kwargs)

self.map_lower(func, **kwargs)
self.map_upper(func, **kwargs)
return self
self._clean_axis(ax)
self._update_legend_data(ax)

def _add_axis_labels(self):
"""Add labels to the left and bottom Axes."""
Expand Down

0 comments on commit b27d8e6

Please sign in to comment.