Skip to content

Commit

Permalink
Flake8 formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
leoluecken committed Jan 29, 2024
1 parent fc55e33 commit f670959
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 35 deletions.
48 changes: 30 additions & 18 deletions seaborn/regression.py
Expand Up @@ -76,9 +76,9 @@ class _RegressionPlotter(_LinearPlotter):
def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
units=None, seed=None, order=1, logistic=False, lowess=False,
robust=False, logx=False, x_partial=None, y_partial=None,
truncate=False, dropna=True, x_jitter=None, y_jitter=None,
color=None, label=None):
lowess_kws=None, robust=False, logx=False,
x_partial=None, y_partial=None, truncate=False, dropna=True,
x_jitter=None, y_jitter=None, color=None, label=None):

# Set member attributes
self.x_estimator = x_estimator
Expand All @@ -91,6 +91,7 @@ def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
self.order = order
self.logistic = logistic
self.lowess = lowess
self.lowess_kws = {} if lowess_kws is None else lowess_kws
self.robust = robust
self.logx = logx
self.truncate = truncate
Expand Down Expand Up @@ -126,13 +127,20 @@ def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
self.x_discrete = self.x

# Disable regression in case of singleton inputs
if len(self.x) <= 1:
if len(self.x) <= 1 or self.lowess:
self.fit_reg = False

# Save the range of the x variable for the grid later
if self.fit_reg:
self.x_range = self.x.min(), self.x.max()

# Check lowess_kws
if self.lowess:
allowed_lowess_kws = ("frac", "it", "delta")
for k in self.lowess_kws:
if k not in allowed_lowess_kws:
raise ValueError(f"Unsupported parameter '{k}' for lowess.")

@property
def scatter_data(self):
"""Data where each observation is a point."""
Expand Down Expand Up @@ -306,7 +314,7 @@ def reg_func(_x, _y):
def fit_lowess(self):
"""Fit a locally-weighted regression, which returns its own grid."""
from statsmodels.nonparametric.smoothers_lowess import lowess
grid, yhat = lowess(self.y, self.x).T
grid, yhat = lowess(self.y, self.x, **self.lowess_kws).T
return grid, yhat

def fit_logx(self, grid):
Expand Down Expand Up @@ -533,6 +541,11 @@ def lineplot(self, ax, kws):
model (locally weighted linear regression). Note that confidence
intervals cannot currently be drawn for this kind of model.\
"""),
lowess_kws=dedent("""\
lowess_kws : dict, optional
Additional keyword arguments to pass to ``lowess()`` function \
from ``statsmodels``.
"""),
robust=dedent("""\
robust : bool, optional
If ``True``, use ``statsmodels`` to estimate a robust regression. This
Expand Down Expand Up @@ -581,7 +594,7 @@ def lmplot(
legend=True, legend_out=None, x_estimator=None, x_bins=None,
x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
units=None, seed=None, order=1, logistic=False, lowess=False,
robust=False, logx=False, x_partial=None, y_partial=None,
lowess_kws=None, robust=False, logx=False, x_partial=None, y_partial=None,
truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None,
line_kws=None, facet_kws=None,
):
Expand Down Expand Up @@ -646,7 +659,7 @@ def update_datalim(data, x, y, ax, **kws):
seed=seed, order=order, logistic=logistic, lowess=lowess,
robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial,
truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter,
scatter_kws=scatter_kws, line_kws=line_kws,
scatter_kws=scatter_kws, line_kws=line_kws, lowess_kws=lowess_kws
)
facets.map_dataframe(regplot, x=x, y=y, **regplot_kws)
facets.set_axis_labels(x, y)
Expand Down Expand Up @@ -720,6 +733,7 @@ def update_datalim(data, x, y, ax, **kws):
{order}
{logistic}
{lowess}
{lowess_kws}
{robust}
{logx}
{xy_partial}
Expand Down Expand Up @@ -753,17 +767,17 @@ def regplot(
data=None, *, x=None, y=None,
x_estimator=None, x_bins=None, x_ci="ci",
scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None,
seed=None, order=1, logistic=False, lowess=False, robust=False,
logx=False, x_partial=None, y_partial=None,
seed=None, order=1, logistic=False, lowess=False, lowess_kws=None,
robust=False, logx=False, x_partial=None, y_partial=None,
truncate=True, dropna=True, x_jitter=None, y_jitter=None,
label=None, color=None, marker="o",
scatter_kws=None, line_kws=None, ax=None
):

plotter = _RegressionPlotter(x, y, data, x_estimator, x_bins, x_ci,
scatter, fit_reg, ci, n_boot, units, seed,
order, logistic, lowess, robust, logx,
x_partial, y_partial, truncate, dropna,
order, logistic, lowess, lowess_kws, robust,
logx, x_partial, y_partial, truncate, dropna,
x_jitter, y_jitter, color, label)

if ax is None:
Expand Down Expand Up @@ -800,6 +814,7 @@ def regplot(
{order}
{logistic}
{lowess}
{lowess_kws}
{robust}
{logx}
{xy_partial}
Expand Down Expand Up @@ -853,7 +868,7 @@ def regplot(

def residplot(
data=None, *, x=None, y=None,
x_partial=None, y_partial=None, lowess=False,
x_partial=None, y_partial=None, lowess=False, lowess_kws=None,
order=1, robust=False, dropna=True, label=None, color=None,
scatter_kws=None, line_kws=None, ax=None
):
Expand All @@ -877,6 +892,8 @@ def residplot(
the `x` or `y` variables before plotting.
lowess : boolean, optional
Fit a lowess smoother to the residual scatterplot.
lowess_kws : dict, optional
Additional keyword arguments passed to lowess() from statsmodels.
order : int, optional
Order of the polynomial to fit when calculating the residuals.
robust : boolean, optional
Expand Down Expand Up @@ -915,6 +932,7 @@ def residplot(
plotter = _RegressionPlotter(x, y, data, ci=None,
order=order, robust=robust,
x_partial=x_partial, y_partial=y_partial,
lowess=lowess, lowess_kws=lowess_kws,
dropna=dropna, color=color, label=label)

if ax is None:
Expand All @@ -924,12 +942,6 @@ def residplot(
_, yhat, _ = plotter.fit_regression(grid=plotter.x)
plotter.y = plotter.y - yhat

# Set the regression option on the plotter
if lowess:
plotter.lowess = True
else:
plotter.fit_reg = False

# Plot a horizontal line at 0
ax.axhline(0, ls=":", c=".2")

Expand Down
35 changes: 18 additions & 17 deletions tests/test_regression.py
Expand Up @@ -427,26 +427,27 @@ def test_lowess_regression(self):

p = lm._RegressionPlotter("x", "y", data=self.df, lowess=True)
grid, yhat, err_bands = p.fit_regression(x_range=(-3, 3))

assert len(grid) == len(yhat)
assert err_bands is None

@pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
def test_lowess_regression_with_kws(self):
lowess_kws = dict(frac=2/3, it=1, delta=0.0)
p = lm._RegressionPlotter("x", "y", data=self.df, lowess=True,
lowess_kws = dict(frac=2 / 3, it=1, delta=0.0)
p = lm._RegressionPlotter("x", "y", data=self.df, lowess=True,
lowess_kws=lowess_kws)
grid, yhat, err_bands = p.fit_regression(x_range=(-3, 3))

assert len(grid) == len(yhat)
assert err_bands is None

@pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
def test_lowess_regression_with_bad_kw(self):

lowess_kws = dict(frac=2/3, it=3, delta=0.0, bad_kw=-1)
with pytest.raises(ValueError, match="Unsupported parameter 'bad_kw' for lowess\\."):
lm._RegressionPlotter("x", "y", data=self.df, lowess=True,

lowess_kws = dict(frac=2 / 3, it=3, delta=0.0, bad_kw=-1)
with pytest.raises(ValueError, match="Unsupported parameter "
"'bad_kw' for lowess\\."):
lm._RegressionPlotter("x", "y", data=self.df, lowess=True,
lowess_kws=lowess_kws)

def test_regression_options(self):
Expand Down Expand Up @@ -683,26 +684,26 @@ def test_residplot_lowess(self):

x, y = ax.lines[1].get_xydata().T
npt.assert_array_equal(x, np.sort(self.df.x))

@pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
def test_residplot_lowess_with_kws(self):

lowess_kws = dict(frac=2/3, it=3, delta=0.0)
ax = lm.residplot(x="x", y="y", data=self.df, lowess=True,
lowess_kws = dict(frac=2 / 3, it=3, delta=0.0)
ax = lm.residplot(x="x", y="y", data=self.df, lowess=True,
lowess_kws=lowess_kws)
assert len(ax.lines) == 2

x, y = ax.lines[1].get_xydata().T
npt.assert_array_equal(x, np.sort(self.df.x))

@pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
def test_residplot_lowess_bad_kw(self):

lowess_kws = dict(frac=2/3, it=3, delta=0.0, bad_kw=-1)
lowess_kws = dict(frac=2 / 3, it=3, delta=0.0, bad_kw=-1)
with pytest.raises(ValueError, match="Unsupported parameter"
" 'bad_kw' for lowess\\."):
lm.residplot(x="x", y="y", data=self.df, lowess=True,
lowess_kws=lowess_kws)
lm.residplot(x="x", y="y", data=self.df, lowess=True,
lowess_kws=lowess_kws)

@pytest.mark.parametrize("option", ["robust", "lowess"])
@pytest.mark.skipif(not _no_statsmodels, reason="statsmodels installed")
Expand Down

0 comments on commit f670959

Please sign in to comment.