From d40c0b0147412f37db7ec95c620b5e223c398765 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 30 Aug 2020 21:04:43 -0400 Subject: [PATCH] Delegate hue in PairGrid to plotting functions (#2234) * Add provisional support for delegating hue in PairGrid * Use histplot on pairplot diagonal * Update Pairgrid tests * Return self from map_diag * Improve test coverage * Convert PairGrid docstring to notebook and update * Fix test * Improve support for legends and markers with new plots * Make color/label injection optional * Update markers test * More flexibility in PairGrid * Convert pairplot API examples to notebook * Add public access to Grid legend object * Fix iterative plot_bivariate and improve test coverage * Don't cast diagonal data to array (fixes #1663) --- doc/docstrings/PairGrid.ipynb | 271 +++++++++++++++++++++ doc/docstrings/pairplot.ipynb | 225 ++++++++++++++++++ doc/tools/extract_examples.py | 2 +- seaborn/axisgrid.py | 413 ++++++++++++++------------------- seaborn/distributions.py | 5 +- seaborn/tests/test_axisgrid.py | 84 ++++++- 6 files changed, 747 insertions(+), 253 deletions(-) create mode 100644 doc/docstrings/PairGrid.ipynb create mode 100644 doc/docstrings/pairplot.ipynb diff --git a/doc/docstrings/PairGrid.ipynb b/doc/docstrings/PairGrid.ipynb new file mode 100644 index 0000000000..72268cfda9 --- /dev/null +++ b/doc/docstrings/PairGrid.ipynb @@ -0,0 +1,271 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns; sns.set()\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Calling the constructor sets up a blank grid of subplots with each row and one column corresponding to a numeric variable in the dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "penguins = 
sns.load_dataset(\"penguins\")\n", + "g = sns.PairGrid(penguins)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Passing a bivariate function to :meth:`PairGrid.map` will draw a bivariate plot on every axes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins)\n", + "g.map(sns.scatterplot)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Passing separate functions to :meth:`PairGrid.map_diag` and :meth:`PairGrid.map_offdiag` will show each variable's marginal distribution on the diagonal:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins)\n", + "g.map_diag(sns.histplot)\n", + "g.map_offdiag(sns.scatterplot)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "It's also possible to use different functions on the upper and lower triangles of the plot (which are otherwise redundant):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, diag_sharey=False)\n", + "g.map_upper(sns.scatterplot)\n", + "g.map_lower(sns.kdeplot)\n", + "g.map_diag(sns.kdeplot)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Or to avoid the redundancy altogether:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, diag_sharey=False, corner=True)\n", + "g.map_lower(sns.scatterplot)\n", + "g.map_diag(sns.kdeplot)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The :class:`PairGrid` constructor accepts a ``hue`` variable. 
This variable is passed directly to functions that understand it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, hue=\"species\")\n", + "g.map_diag(sns.histplot)\n", + "g.map_offdiag(sns.scatterplot)\n", + "g.add_legend()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "But you can also pass matplotlib functions, in which case a groupby is performed internally and a separate plot is drawn for each level:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, hue=\"species\")\n", + "g.map_diag(plt.hist)\n", + "g.map_offdiag(plt.scatter)\n", + "g.add_legend()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Additional semantic variables can be assigned by passing data vectors directly while mapping the function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, hue=\"species\")\n", + "g.map_diag(sns.histplot)\n", + "g.map_offdiag(sns.scatterplot, size=penguins[\"sex\"])\n", + "g.add_legend(title=\"\", adjust_subtitles=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "When using seaborn functions that can implement a numeric hue mapping, you will want to disable mapping of the variable on the diagonal axes. 
Note that the ``hue`` variable is excluded from the list of variables shown by default:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, hue=\"body_mass_g\")\n", + "g.map_diag(sns.histplot, hue=None, color=\".3\")\n", + "g.map_offdiag(sns.scatterplot)\n", + "g.add_legend()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The ``vars`` parameter can be used to control exactly which variables are used:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "variables = [\"body_mass_g\", \"bill_length_mm\", \"flipper_length_mm\"]\n", + "g = sns.PairGrid(penguins, hue=\"body_mass_g\", vars=variables)\n", + "g.map_diag(sns.histplot, hue=None, color=\".3\")\n", + "g.map_offdiag(sns.scatterplot)\n", + "g.add_legend()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The plot need not be square: separate variables can be used to define the rows and columns:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x_vars = [\"body_mass_g\", \"bill_length_mm\", \"bill_depth_mm\", \"flipper_length_mm\"]\n", + "y_vars = [\"body_mass_g\"]\n", + "g = sns.PairGrid(penguins, hue=\"species\", x_vars=x_vars, y_vars=y_vars)\n", + "g.map_diag(sns.histplot, color=\".3\")\n", + "g.map_offdiag(sns.scatterplot)\n", + "g.add_legend()" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "It can be useful to explore different approaches to resolving multiple distributions on the diagonal axes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.PairGrid(penguins, hue=\"species\")\n", + "g.map_diag(sns.histplot, multiple=\"stack\", element=\"step\")\n", + "g.map_offdiag(sns.scatterplot)\n", + "g.add_legend()" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/docstrings/pairplot.ipynb b/doc/docstrings/pairplot.ipynb new file mode 100644 index 0000000000..97f509c365 --- /dev/null +++ b/doc/docstrings/pairplot.ipynb @@ -0,0 +1,225 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "sns.set(style=\"ticks\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The simplest invocation uses :func:`scatterplot` for each pairing of the variables and :func:`histplot` for the marginal plots along the diagonal:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "penguins = sns.load_dataset(\"penguins\")\n", + "sns.pairplot(penguins)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Assigning a ``hue`` variable adds a semantic mapping and changes the default marginal plot to a layered kernel density estimate (KDE):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, hue=\"species\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "It's possible to force marginal histograms:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, hue=\"species\", diag_kind=\"hist\")" + ] + }, + { + 
"cell_type": "raw", + "metadata": {}, + "source": [ + "The ``kind`` parameter determines both the diagonal and off-diagonal plotting style. Several options are available, including using :func:`kdeplot` to draw KDEs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, kind=\"kde\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Or :func:`histplot` to draw both bivariate and univariate histograms:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, kind=\"hist\")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The ``markers`` parameter applies a style mapping on the off-diagonal axes. Currently, it will be redundant with the ``hue`` variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, hue=\"species\", markers=[\"o\", \"s\", \"D\"])" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "As with other figure-level functions, the size of the figure is controlled by setting the ``height`` of each individual subplot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, height=1.5)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Use ``vars`` or ``x_vars`` and ``y_vars`` to select the variables to plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(\n", + " penguins,\n", + " x_vars=[\"bill_length_mm\", \"bill_depth_mm\", \"flipper_length_mm\"],\n", + " y_vars=[\"bill_length_mm\", \"bill_depth_mm\"],\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "Set ``corner=True`` to plot only the lower triangle:" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(penguins, corner=True)" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The ``plot_kws`` and ``diag_kws`` parameters accept dicts of keyword arguments to customize the off-diagonal and diagonal plots, respectively:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.pairplot(\n", + " penguins,\n", + " plot_kws=dict(marker=\"+\", linewidth=1),\n", + " diag_kws=dict(fill=False),\n", + ")" + ] + }, + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "The return object is the underlying :class:`PairGrid`, which can be used to further customize the plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.pairplot(penguins, diag_kind=\"kde\")\n", + "g.map_lower(sns.kdeplot, levels=4, color=\".2\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py38-latest", + "language": "python", + "name": "seaborn-py38-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/doc/tools/extract_examples.py b/doc/tools/extract_examples.py index 88fb3db7ef..36b0eff626 100644 --- a/doc/tools/extract_examples.py +++ b/doc/tools/extract_examples.py @@ -32,7 +32,7 @@ def add_cell(nb, lines, cell_type): # Parse the docstring and get the examples section obj = getattr(seaborn, name) - if obj.__class__ != "function": + if obj.__class__.__name__ != "function": obj = obj.__init__ lines = NumpyDocString(pydoc.getdoc(obj))["Examples"] diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index 7d28934cdf..dcde8ccc75 100644 --- 
a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -37,6 +37,10 @@ def __init__(self): self._tight_layout_rect = [0, 0, 1, 1] + # This attribute is set externally and is a hack to handle newer functions that + # don't add proxy artists onto the Axes. We need an overall cleaner approach. + self._extract_legend_handles = False + def set(self, **kwargs): """Set attributes on each subplot Axes.""" for ax in self.axes.flat: @@ -168,6 +172,14 @@ def add_legend(self, legend_data=None, title=None, label_order=None, return self + @property + def legend(self): + """Access to the legend object.""" + try: + return self._legend + except AttributeError: + return None + def _clean_axis(self, ax): """Turn off axis labels and legend.""" ax.set_xlabel("") @@ -177,8 +189,15 @@ def _clean_axis(self, ax): def _update_legend_data(self, ax): """Extract the legend data from an axes object and save it.""" + data = {} + if ax.legend_ is not None and self._extract_legend_handles: + handles = ax.legend_.legendHandles + labels = [t.get_text() for t in ax.legend_.texts] + data.update({l: h for h, l in zip(handles, labels)}) + handles, labels = ax.get_legend_handles_labels() - data = {l: h for h, l in zip(handles, labels)} + data.update({l: h for h, l in zip(handles, labels)}) + self._legend_data.update(data) def _get_palette(self, data, hue, hue_order, palette): @@ -1191,17 +1210,13 @@ def _not_bottom_axes(self): class PairGrid(Grid): """Subplot grid for plotting pairwise relationships in a dataset. - This class maps each variable in a dataset onto a column and row in a + This object maps each variable in a dataset onto a column and row in a grid of multiple axes. Different axes-level plotting functions can be used to draw bivariate plots in the upper and lower triangles, and the the marginal distribution of each variable can be shown on the diagonal. 
- It can also represent an additional level of conditionalization with the - ``hue`` parameter, which plots different subsets of data in different - colors. This uses color to resolve elements on a third dimension, but - only draws subsets on top of each other and will not tailor the ``hue`` - parameter for the specific visualization the way that axes-level functions - that accept ``hue`` will. + Several different common plots can be generated in a single line using + :func:`pairplot`. Use :class:`PairGrid` when you need more flexibility. See the :ref:`tutorial ` for more information. @@ -1261,96 +1276,7 @@ def __init__( Examples -------- - Draw a scatterplot for each pairwise relationship: - - .. plot:: - :context: close-figs - - >>> import matplotlib.pyplot as plt - >>> import seaborn as sns; sns.set() - >>> iris = sns.load_dataset("iris") - >>> g = sns.PairGrid(iris) - >>> g = g.map(sns.scatterplot) - - Show a univariate distribution on the diagonal: - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris) - >>> g = g.map_diag(plt.hist) - >>> g = g.map_offdiag(sns.scatterplot) - - (It's not actually necessary to catch the return value every time, - as it is the same object, but it makes it easier to deal with the - doctests). - - Color the points using a categorical variable: - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris, hue="species") - >>> g = g.map_diag(plt.hist) - >>> g = g.map_offdiag(sns.scatterplot) - >>> g = g.add_legend() - - Use a different style to show multiple histograms: - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris, hue="species") - >>> g = g.map_diag(plt.hist, histtype="step", linewidth=3) - >>> g = g.map_offdiag(sns.scatterplot) - >>> g = g.add_legend() - - Plot a subset of variables - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris, vars=["sepal_length", "sepal_width"]) - >>> g = g.map(sns.scatterplot) - - Pass additional keyword arguments to the functions - - .. 
plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris) - >>> g = g.map_diag(plt.hist, edgecolor="w") - >>> g = g.map_offdiag(sns.scatterplot) - - Use different variables for the rows and columns: - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris, - ... x_vars=["sepal_length", "sepal_width"], - ... y_vars=["petal_length", "petal_width"]) - >>> g = g.map(sns.scatterplot) - - Use different functions on the upper and lower triangles: - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris) - >>> g = g.map_upper(sns.scatterplot) - >>> g = g.map_lower(sns.kdeplot, color="C0") - >>> g = g.map_diag(sns.kdeplot, lw=2) - - Use different colors and markers for each categorical level: - - .. plot:: - :context: close-figs - - >>> g = sns.PairGrid(iris, hue="species", palette="Set2", - ... hue_kws={"marker": ["o", "s", "D"]}) - >>> g = g.map(sns.scatterplot) - >>> g = g.add_legend() + .. include:: ../docstrings/PairGrid.rst """ @@ -1364,17 +1290,15 @@ def __init__( warnings.warn(UserWarning(msg)) # Sort out the variables that define the grid + numeric_cols = self._find_numeric_cols(data) + if hue in numeric_cols: + numeric_cols.remove(hue) if vars is not None: x_vars = list(vars) y_vars = list(vars) - elif (x_vars is not None) or (y_vars is not None): - if (x_vars is None) or (y_vars is None): - raise ValueError("Must specify `x_vars` and `y_vars`") - else: - numeric_cols = self._find_numeric_cols(data) - if hue in numeric_cols: - numeric_cols.remove(hue) + if x_vars is None: x_vars = numeric_cols + if y_vars is None: y_vars = numeric_cols if np.isscalar(x_vars): @@ -1437,6 +1361,8 @@ def __init__( # Additional dict of kwarg -> list of values for mapping the hue var self.hue_kws = hue_kws if hue_kws is not None else {} + self._orig_palette = palette + self._hue_order = hue_order self.palette = self._get_palette(data, hue, hue_order, palette) self._legend_data = {} @@ -1460,6 +1386,7 @@ def map(self, func, **kwargs): row_indices, col_indices = 
np.indices(self.axes.shape) indices = zip(row_indices.flat, col_indices.flat) self._map_bivariate(func, indices, **kwargs) + return self def map_lower(self, func, **kwargs): @@ -1559,6 +1486,40 @@ def map_diag(self, func, **kwargs): self.diag_vars = np.array(diag_vars, np.object) self.diag_axes = np.array(diag_axes, np.object) + if "hue" not in signature(func).parameters: + return self._map_diag_iter_hue(func, **kwargs) + + # Loop over diagonal variables and axes, making one plot in each + for var, ax in zip(self.diag_vars, self.diag_axes): + + plt.sca(ax) + plot_kwargs = kwargs.copy() + + vector = self.data[var] + if self._hue_var is not None: + hue = self.data[self._hue_var] + else: + hue = None + + if self._dropna: + not_na = vector.notna() + if hue is not None: + not_na &= hue.notna() + vector = vector[not_na] + if hue is not None: + hue = hue[not_na] + + plot_kwargs.setdefault("hue", hue) + plot_kwargs.setdefault("hue_order", self._hue_order) + plot_kwargs.setdefault("palette", self._orig_palette) + func(x=vector, **plot_kwargs) + self._clean_axis(ax) + + self._add_axis_labels() + return self + + def _map_diag_iter_hue(self, func, **kwargs): + """Put marginal plot on each diagonal axes, iterating over hue.""" # Plot on each of the diagonal axes fixed_color = kwargs.pop("color", None) @@ -1571,10 +1532,9 @@ def map_diag(self, func, **kwargs): # Attempt to get data for this level, allowing for empty try: - # TODO newer matplotlib(?) doesn't need array for hist - data_k = np.asarray(hue_grouped.get_group(label_k)) + data_k = hue_grouped.get_group(label_k) except KeyError: - data_k = np.array([]) + data_k = pd.Series([], dtype=float) if fixed_color is None: color = self.palette[k] @@ -1597,25 +1557,73 @@ def map_diag(self, func, **kwargs): def _map_bivariate(self, func, indices, **kwargs): """Draw a bivariate plot on the indicated axes.""" + # This is a hack to handle the fact that new distribution plots don't add + # their artists onto the axes. 
This is probably superior in general, but + # we'll need a better way to handle it in the axisgrid functions. + from .distributions import histplot, kdeplot + if func is histplot or func is kdeplot: + self._extract_legend_handles = True + kws = kwargs.copy() # Use copy as we insert other kwargs - kw_color = kws.pop("color", None) for i, j in indices: x_var = self.x_vars[j] y_var = self.y_vars[i] ax = self.axes[i, j] - self._plot_bivariate(x_var, y_var, ax, func, kw_color, **kws) + self._plot_bivariate(x_var, y_var, ax, func, **kws) self._add_axis_labels() - def _plot_bivariate(self, x_var, y_var, ax, func, kw_color, **kwargs): + if "hue" in signature(func).parameters: + self.hue_names = list(self._legend_data) + + def _plot_bivariate(self, x_var, y_var, ax, func, **kwargs): """Draw a bivariate plot on the specified axes.""" + if "hue" not in signature(func).parameters: + self._plot_bivariate_iter_hue(x_var, y_var, ax, func, **kwargs) + return + + plt.sca(ax) + kwargs = kwargs.copy() + + if x_var == y_var: + axes_vars = [x_var] + else: + axes_vars = [x_var, y_var] + + if self._hue_var is not None and self._hue_var not in axes_vars: + axes_vars.append(self._hue_var) + + data = self.data[axes_vars] + if self._dropna: + data = data.dropna() + + x = data[x_var] + y = data[y_var] + if self._hue_var is None: + hue = None + else: + hue = data.get(self._hue_var) + + kwargs.setdefault("hue", hue) + kwargs.setdefault("hue_order", self._hue_order) + kwargs.setdefault("palette", self._orig_palette) + func(x=x, y=y, **kwargs) + + self._update_legend_data(ax) + self._clean_axis(ax) + + def _plot_bivariate_iter_hue(self, x_var, y_var, ax, func, **kwargs): + """Draw a bivariate plot while iterating over hue subsets.""" plt.sca(ax) if x_var == y_var: axes_vars = [x_var] else: axes_vars = [x_var, y_var] + hue_grouped = self.data.groupby(self.hue_vals) for k, label_k in enumerate(self.hue_names): + kws = kwargs.copy() + # Attempt to get data for this level, allowing for empty try: 
data_k = hue_grouped.get_group(label_k) @@ -1630,16 +1638,18 @@ def _plot_bivariate(self, x_var, y_var, ax, func, kw_color, **kwargs): y = data_k[y_var] for kw, val_list in self.hue_kws.items(): - kwargs[kw] = val_list[k] - color = self.palette[k] if kw_color is None else kw_color + kws[kw] = val_list[k] + kws.setdefault("color", self.palette[k]) + if self._hue_var is not None: + kws["label"] = label_k if str(func.__module__).startswith("seaborn"): - func(x=x, y=y, label=label_k, color=color, **kwargs) + func(x=x, y=y, **kws) else: - func(x, y, label=label_k, color=color, **kwargs) + func(x, y, **kws) - self._clean_axis(ax) self._update_legend_data(ax) + self._clean_axis(ax) def _add_axis_labels(self): """Add labels to the left and bottom Axes.""" @@ -1958,10 +1968,10 @@ def pairplot( """Plot pairwise relationships in a dataset. By default, this function will create a grid of Axes such that each numeric - variable in ``data`` will by shared in the y-axis across a single row and - in the x-axis across a single column. The diagonal Axes are treated - differently, drawing a plot to show the univariate distribution of the data - for the variable in that column. + variable in ``data`` will be shared across the y-axes across a single row and + the x-axes across a single column. The diagonal plots are treated + differently: a univariate distribution plot is drawn to show the marginal + distribution of the data in each column. It is also possible to show a subset of variables or plot different variables on the rows and columns. @@ -1972,10 +1982,10 @@ def pairplot( Parameters ---------- - data : DataFrame + data : `pandas.DataFrame` Tidy (long-form) dataframe where each column is a variable and each row is an observation. - hue : string (variable name) + hue : name of variable in ``data`` Variable in ``data`` to map plot aspects to different colors. 
hue_order : list of strings Order for the levels of the hue variable in the palette @@ -1988,14 +1998,14 @@ def pairplot( {x, y}_vars : lists of variable names Variables within ``data`` to use separately for the rows and columns of the figure; i.e. to make a non-square plot. - kind : {'scatter', 'reg'} - Kind of plot for the non-identity relationships. + kind : {'scatter', 'kde', 'hist', 'reg'} + Kind of plot to make. diag_kind : {'auto', 'hist', 'kde', None} - Kind of plot for the diagonal subplots. The default depends on whether - ``"hue"`` is used or not. + Kind of plot for the diagonal subplots. If 'auto', choose based on + whether or not ``hue`` is used. markers : single matplotlib marker code or list - Either the marker to use for all datapoints or a list of markers with - a length the same as the number of levels in the hue variable so that + Either the marker to use for all scatterplot points or a list of markers + with a length the same as the number of levels in the hue variable so that differently colored points will also have different scatterplot markers. height : scalar @@ -2020,102 +2030,17 @@ def pairplot( See Also -------- - PairGrid : Subplot grid for more flexible plotting of pairwise - relationships. + PairGrid : Subplot grid for more flexible plotting of pairwise relationships. + JointGrid : Grid for plotting joint and marginal distributions of two variables. Examples -------- - Draw scatterplots for joint relationships and histograms for univariate - distributions: - - .. plot:: - :context: close-figs - - >>> import seaborn as sns; sns.set(style="ticks", color_codes=True) - >>> iris = sns.load_dataset("iris") - >>> g = sns.pairplot(iris) - - Show different levels of a categorical variable by the color of plot - elements: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, hue="species") - - Use a different color palette: - - .. 
plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, hue="species", palette="husl") - - Use different markers for each level of the hue variable: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, hue="species", markers=["o", "s", "D"]) - - Plot a subset of variables: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, vars=["sepal_width", "sepal_length"]) - - Draw larger plots: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, height=3, - ... vars=["sepal_width", "sepal_length"]) - - Plot different variables in the rows and columns: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, - ... x_vars=["sepal_width", "sepal_length"], - ... y_vars=["petal_width", "petal_length"]) - - Plot only the lower triangle of bivariate axes: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, corner=True) - - Use kernel density estimates for univariate plots: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, diag_kind="kde") - - Fit linear regression models to the scatter plots: - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, kind="reg") - - Pass keyword arguments down to the underlying functions (it may be easier - to use :class:`PairGrid` directly): - - .. plot:: - :context: close-figs - - >>> g = sns.pairplot(iris, diag_kind="kde", markers="+", - ... plot_kws=dict(s=50, edgecolor="b", linewidth=1), - ... diag_kws=dict(fill=True)) + .. 
include:: ../docstrings/pairplot.rst """ # Avoid circular import - from .distributions import kdeplot # TODO histplot + from .distributions import histplot, kdeplot # Handle deprecations if size is not None: @@ -2142,32 +2067,42 @@ def pairplot( # Add the markers here as PairGrid has figured out how many levels of the # hue variable are needed and we don't want to duplicate that process if markers is not None: - if grid.hue_names is None: - n_markers = 1 - else: - n_markers = len(grid.hue_names) - if not isinstance(markers, list): - markers = [markers] * n_markers - if len(markers) != n_markers: - raise ValueError(("markers must be a singleton or a list of " - "markers for each level of the hue variable")) - grid.hue_kws = {"marker": markers} + if kind == "reg": + # Needed until regplot supports style + if grid.hue_names is None: + n_markers = 1 + else: + n_markers = len(grid.hue_names) + if not isinstance(markers, list): + markers = [markers] * n_markers + if len(markers) != n_markers: + raise ValueError(("markers must be a singleton or a list of " + "markers for each level of the hue variable")) + grid.hue_kws = {"marker": markers} + elif kind == "scatter": + if isinstance(markers, str): + plot_kws["marker"] = markers + elif hue is not None: + plot_kws["style"] = data[hue] + plot_kws["markers"] = markers # Maybe plot on the diagonal if diag_kind == "auto": - diag_kind = "hist" if hue is None else "kde" + if hue is None: + diag_kind = "kde" if kind == "kde" else "hist" + else: + diag_kind = "hist" if kind == "hist" else "kde" diag_kws = diag_kws.copy() - if grid.square_grid: - if diag_kind == "hist": - grid.map_diag(plt.hist, **diag_kws) - elif diag_kind == "kde": - diag_kws.setdefault("fill", True) - diag_kws["legend"] = False - grid.map_diag(kdeplot, **diag_kws) + diag_kws.setdefault("legend", False) + if diag_kind == "hist": + grid.map_diag(histplot, **diag_kws) + elif diag_kind == "kde": + diag_kws.setdefault("fill", True) + grid.map_diag(kdeplot, **diag_kws) 
# Maybe plot on the off-diagonals - if grid.square_grid and diag_kind is not None: + if diag_kind is not None: plotter = grid.map_offdiag else: plotter = grid.map @@ -2178,6 +2113,12 @@ def pairplot( elif kind == "reg": from .regression import regplot # Avoid circular import plotter(regplot, **plot_kws) + elif kind == "kde": + from .distributions import kdeplot # Avoid circular import + plotter(kdeplot, **plot_kws) + elif kind == "hist": + from .distributions import histplot # Avoid circular import + plotter(histplot, **plot_kws) # Add a legend if hue is not None: diff --git a/seaborn/distributions.py b/seaborn/distributions.py index 00a66afac9..f932582f7f 100644 --- a/seaborn/distributions.py +++ b/seaborn/distributions.py @@ -498,9 +498,7 @@ def plot_univariate_histogram( # Avoid drawing empty fill_between on date axis # https://github.com/matplotlib/matplotlib/issues/17586 scout = None - default_color = plot_kws.pop( - "color", plot_kws.pop("facecolor", None) - ) + default_color = plot_kws.pop("facecolor", color) if default_color is None: default_color = "C0" else: @@ -508,7 +506,6 @@ def plot_univariate_histogram( plot_kws = _normalize_kwargs(plot_kws, artist) scout = self.ax.fill_between([], [], color=color, **plot_kws) default_color = tuple(scout.get_facecolor().squeeze()) - plot_kws.pop("color", None) else: artist = mpl.lines.Line2D plot_kws = _normalize_kwargs(plot_kws, artist) diff --git a/seaborn/tests/test_axisgrid.py b/seaborn/tests/test_axisgrid.py index ed2c215634..494f6bd7f6 100644 --- a/seaborn/tests/test_axisgrid.py +++ b/seaborn/tests/test_axisgrid.py @@ -934,12 +934,13 @@ def test_map_diag_color(self): def test_map_diag_palette(self): - pal = color_palette(n_colors=len(self.df.a.unique())) - g = ag.PairGrid(self.df, hue="a") + palette = "muted" + pal = color_palette(palette, n_colors=len(self.df.a.unique())) + g = ag.PairGrid(self.df, hue="a", palette=palette) g.map_diag(kdeplot) for ax in g.diag_axes: - for line, color in zip(ax.lines, pal): + 
for line, color in zip(ax.lines[::-1], pal): assert line.get_color() == color def test_map_diag_and_offdiag(self): @@ -1131,10 +1132,11 @@ def test_nondefault_index(self): x_in_k = x_in[self.df.a == k_level] y_in_k = y_in[self.df.a == k_level] x_out, y_out = ax.collections[k].get_offsets().T - npt.assert_array_equal(x_in_k, x_out) - npt.assert_array_equal(y_in_k, y_out) + npt.assert_array_equal(x_in_k, x_out) + npt.assert_array_equal(y_in_k, y_out) - def test_dropna(self): + @pytest.mark.parametrize("func", [scatterplot, plt.scatter]) + def test_dropna(self, func): df = self.df.copy() n_null = 20 @@ -1143,7 +1145,7 @@ def test_dropna(self): plot_vars = ["x", "y", "z"] g1 = ag.PairGrid(df, vars=plot_vars, dropna=True) - g1.map(plt.scatter) + g1.map(func) for i, axes_i in enumerate(g1.axes): for j, ax in enumerate(axes_i): @@ -1156,6 +1158,21 @@ def test_dropna(self): assert n_valid == len(x_out) assert n_valid == len(y_out) + g1.map_diag(histplot) + for i, ax in enumerate(g1.diag_axes): + var = plot_vars[i] + count = sum([p.get_height() for p in ax.patches]) + assert count == df[var].notna().sum() + + def test_histplot_legend(self): + + # Tests _extract_legend_handles + g = ag.PairGrid(self.df, vars=["x", "y"], hue="a") + g.map_offdiag(histplot) + g.add_legend() + + assert len(g._legend.legendHandles) == len(self.df["a"].unique()) + def test_pairplot(self): vars = ["x", "y", "z"] @@ -1196,7 +1213,7 @@ def test_pairplot_reg(self): g = ag.pairplot(self.df, diag_kind="hist", kind="reg") for ax in g.diag_axes: - nt.assert_equal(len(ax.patches), 10) + assert len(ax.patches) for i, j in zip(*np.triu_indices_from(g.axes, 1)): ax = g.axes[i, j] @@ -1224,7 +1241,21 @@ def test_pairplot_reg(self): ax = g.axes[i, j] nt.assert_equal(len(ax.collections), 0) - def test_pairplot_kde(self): + def test_pairplot_reg_hue(self): + + markers = ["o", "s", "d"] + g = ag.pairplot(self.df, kind="reg", hue="a", markers=markers) + + ax = g.axes[-1, 0] + c1 = ax.collections[0] + c2 = 
ax.collections[2] + + assert not np.array_equal(c1.get_facecolor(), c2.get_facecolor()) + assert not np.array_equal( + c1.get_paths()[0].vertices, c2.get_paths()[0].vertices, + ) + + def test_pairplot_diag_kde(self): vars = ["x", "y", "z"] g = ag.pairplot(self.df, diag_kind="kde") @@ -1252,13 +1283,34 @@ def test_pairplot_kde(self): ax = g.axes[i, j] nt.assert_equal(len(ax.collections), 0) + def test_pairplot_kde(self): + + f, ax1 = plt.subplots() + kdeplot(data=self.df, x="x", y="y", ax=ax1) + + g = ag.pairplot(self.df, kind="kde") + ax2 = g.axes[1, 0] + + assert_plots_equal(ax1, ax2, labels=False) + + def test_pairplot_hist(self): + + f, ax1 = plt.subplots() + histplot(data=self.df, x="x", y="y", ax=ax1) + + g = ag.pairplot(self.df, kind="hist") + ax2 = g.axes[1, 0] + + assert_plots_equal(ax1, ax2, labels=False) + def test_pairplot_markers(self): vars = ["x", "y", "z"] - markers = ["o", "x", "s"] + markers = ["o", "X", "s"] g = ag.pairplot(self.df, hue="a", vars=vars, markers=markers) - assert g.hue_kws["marker"] == markers - plt.close("all") + m1 = g._legend.legendHandles[0].get_paths()[0] + m2 = g._legend.legendHandles[1].get_paths()[0] + assert m1 != m2 with pytest.raises(ValueError): g = ag.pairplot(self.df, hue="a", vars=vars, markers=markers[:-2]) @@ -1275,6 +1327,14 @@ def test_corner_set(self): g.set(xlim=(0, 10)) assert g.axes[-1, 0].get_xlim() == (0, 10) + def test_legend(self): + + g1 = ag.pairplot(self.df, hue="a") + assert isinstance(g1.legend, mpl.legend.Legend) + + g2 = ag.pairplot(self.df) + assert g2.legend is None + class TestJointGrid(object):