Skip to content

Commit

Permalink
Several fixes/improvements related to lmplot axis scaling (#2576)
Browse files Browse the repository at this point in the history
* Only set sticky edges on regression line when not truncating

* Improve how datalimits are established for non-truncated lines

Fixes #2509

* Add facet_kws to lmplot

Closes #2518

* Update release notes

* Skip test that fails because of matplotlib bug

xref matplotlib/matplotlib#15967

* Always set datalims using float data

* Deprecate sharex/sharey/legend_out from lmplot signature

* Tweak release note [skip ci]

(cherry picked from commit 3750373)
  • Loading branch information
mwaskom committed Aug 6, 2021
1 parent 5f69eeb commit 0877be4
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 14 deletions.
8 changes: 7 additions & 1 deletion doc/releases/v0.12.0.txt
Expand Up @@ -27,10 +27,16 @@ v0.12.0 (Unreleased)

- |Enhancement| |Fix| Improved integration with the matplotlib color cycle in most axes-level functions (:pr:`2449`).

- |API| In :func:`lmplot`, the `sharex`, `sharey`, and `legend_out` parameters have been deprecated from the function signature, but they can be passed using the new `facet_kws` parameter (:pr:`2576`).

- |Fix| In :func:`lineplot`, allowed the `dashes` keyword to set the style of a line without mapping a `style` variable (:pr:`2449`).

- |Fix| In :func:`rugplot`, fixed a bug that prevented the use of datetime data (:pr:`2458`).

- |Fix| In :func:`lmplot`, fixed a bug where the x axis was clamped to the data limits with `truncate=True` (:pr:`2576`).

- |Fix| In :func:`lmplot`, fixed a bug where `sharey=False` did not always work as expected (:pr:`2576`).

- |Fix| In :func:`histplot` and :func:`kdeplot`, fixed a bug where the `alpha` parameter was ignored when `fill=False` (:pr:`2460`).

- |Fix| In :func:`histplot` and :func:`kdeplot`, fixed a bug where the `multiple` was ignored when `hue` was provided as a vector without a name (:pr:`2462`).
Expand All @@ -39,7 +45,7 @@ v0.12.0 (Unreleased)

- |Fix| In :func:`histplot`, fixed a bug where using `shrink` with non-discrete bins shifted bar positions inaccurately (:pr:`2477`).

- |Fix| In :func:`histplot`, fixed two bugs where automatically computed edge widths were too thick for log-scaled histograms and categorical histograms on the y axis (:pr:2522`).
- |Fix| In :func:`histplot`, fixed two bugs where automatically computed edge widths were too thick for log-scaled histograms and categorical histograms on the y axis (:pr:`2522`).

- |Fix| In :func:`displot`, fixed a bug where `common_norm` was ignored when `kind="hist"` and faceting was used without assigning `hue` (:pr:`2468`).

Expand Down
51 changes: 38 additions & 13 deletions seaborn/regression.py
Expand Up @@ -420,7 +420,8 @@ def lineplot(self, ax, kws):

# Draw the regression line and confidence interval
line, = ax.plot(grid, yhat, **kws)
line.sticky_edges.x[:] = edges # Prevent mpl from adding margin
if not self.truncate:
line.sticky_edges.x[:] = edges # Prevent mpl from adding margin
if err_bands is not None:
ax.fill_between(grid, *err_bands, facecolor=fill_color, alpha=.15)

Expand Down Expand Up @@ -563,13 +564,13 @@ def lmplot(
data=None,
hue=None, col=None, row=None, # TODO move before data once * is enforced
palette=None, col_wrap=None, height=5, aspect=1, markers="o",
sharex=True, sharey=True, hue_order=None, col_order=None, row_order=None,
legend=True, legend_out=True, x_estimator=None, x_bins=None,
sharex=None, sharey=None, hue_order=None, col_order=None, row_order=None,
legend=True, legend_out=None, x_estimator=None, x_bins=None,
x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
units=None, seed=None, order=1, logistic=False, lowess=False,
robust=False, logx=False, x_partial=None, y_partial=None,
truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None,
line_kws=None, size=None
line_kws=None, facet_kws=None, size=None,
):

# Handle deprecations
Expand All @@ -579,6 +580,22 @@ def lmplot(
"please update your code.")
warnings.warn(msg, UserWarning)

if facet_kws is None:
facet_kws = {}

def facet_kw_deprecation(key, val):
msg = (
f"{key} is deprecated from the `lmplot` function signature. "
"Please update your code to pass it using `facet_kws`."
)
if val is not None:
warnings.warn(msg, UserWarning)
facet_kws[key] = val

facet_kw_deprecation("sharex", sharex)
facet_kw_deprecation("sharey", sharey)
facet_kw_deprecation("legend_out", legend_out)

if data is None:
raise TypeError("Missing required keyword argument `data`.")

Expand All @@ -593,7 +610,7 @@ def lmplot(
palette=palette,
row_order=row_order, col_order=col_order, hue_order=hue_order,
height=height, aspect=aspect, col_wrap=col_wrap,
sharex=sharex, sharey=sharey, legend_out=legend_out
**facet_kws,
)

# Add the markers here as FacetGrid has figured out how many levels of the
Expand All @@ -609,12 +626,12 @@ def lmplot(
"for each level of the hue variable"))
facets.hue_kws = {"marker": markers}

# Hack to set the x limits properly, which needs to happen here
# because the extent of the regression estimate is determined
# by the limits of the plot
if sharex:
for ax in facets.axes.flat:
ax.scatter(data[x], np.ones(len(data)) * data[y].mean()).remove()
def update_datalim(data, x, y, ax, **kws):
xys = data[[x, y]].to_numpy().astype(float)
ax.update_datalim(xys, updatey=False)
ax.autoscale_view(scaley=False)

facets.map_dataframe(update_datalim, x=x, y=y)

# Draw the regression plot on each facet
regplot_kws = dict(
Expand All @@ -626,8 +643,6 @@ def lmplot(
scatter_kws=scatter_kws, line_kws=line_kws,
)
facets.map_dataframe(regplot, x=x, y=y, **regplot_kws)

# TODO this will need to change when we relax string requirement
facets.set_axis_labels(x, y)

# Add a legend
Expand Down Expand Up @@ -672,13 +687,21 @@ def lmplot(
Markers for the scatterplot. If a list, each marker in the list will be
used for each level of the ``hue`` variable.
{share_xy}
.. deprecated:: 0.12.0
Pass using the `facet_kws` dictionary.
{{hue,col,row}}_order : lists, optional
Order for the levels of the faceting variables. By default, this will
be the order that the levels appear in ``data`` or, if the variables
are pandas categoricals, the category order.
legend : bool, optional
If ``True`` and there is a ``hue`` variable, add a legend.
{legend_out}
.. deprecated:: 0.12.0
Pass using the `facet_kws` dictionary.
{x_estimator}
{x_bins}
{x_ci}
Expand All @@ -697,6 +720,8 @@ def lmplot(
{truncate}
{xy_jitter}
{scatter_line_kws}
facet_kws : dict
Dictionary of keyword arguments for :class:`FacetGrid`.
See Also
--------
Expand Down
39 changes: 39 additions & 0 deletions seaborn/tests/test_regression.py
@@ -1,3 +1,4 @@
from distutils.version import LooseVersion
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -596,6 +597,44 @@ def test_lmplot_scatter_kws(self):
npt.assert_array_equal(red, red_scatter.get_facecolors()[0, :3])
npt.assert_array_equal(blue, blue_scatter.get_facecolors()[0, :3])

@pytest.mark.skipif(LooseVersion(mpl.__version__) < "3.4",
reason="MPL bug #15967")
@pytest.mark.parametrize("sharex", [True, False])
def test_lmplot_facet_truncate(self, sharex):

g = lm.lmplot(
data=self.df, x="x", y="y", hue="g", col="h",
truncate=False, facet_kws=dict(sharex=sharex),
)

for ax in g.axes.flat:
for line in ax.lines:
xdata = line.get_xdata()
assert ax.get_xlim() == tuple(xdata[[0, -1]])

def test_lmplot_sharey(self):

df = pd.DataFrame(dict(
x=[0, 1, 2, 0, 1, 2],
y=[1, -1, 0, -100, 200, 0],
z=["a", "a", "a", "b", "b", "b"],
))

with pytest.warns(UserWarning):
g = lm.lmplot(data=df, x="x", y="y", col="z", sharey=False)
ax1, ax2 = g.axes.flat
assert ax1.get_ylim()[0] > ax2.get_ylim()[0]
assert ax1.get_ylim()[1] < ax2.get_ylim()[1]

def test_lmplot_facet_kws(self):

xlim = -4, 20
g = lm.lmplot(
data=self.df, x="x", y="y", col="h", facet_kws={"xlim": xlim}
)
for ax in g.axes.flat:
assert ax.get_xlim() == xlim

def test_residplot(self):

x, y = self.df.x, self.df.y
Expand Down

0 comments on commit 0877be4

Please sign in to comment.