Skip to content

Commit

Permalink
Don't fail in regplot on singletons, don't fit regression either (#1969)
Browse files Browse the repository at this point in the history
* FIX: don't squeeze singletons

* Disable regression fit in case of singleton inputs

* Update release notes

[ci skip]

(cherry picked from commit d59ab10)
  • Loading branch information
mwaskom committed Feb 22, 2020
1 parent aa35d00 commit 488c167
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
2 changes: 2 additions & 0 deletions doc/releases/v0.10.1.txt
Expand Up @@ -5,3 +5,5 @@ v0.10.1 (Unreleased)
This is minor release with bug fixes issues in 0.10.0.

- Fixed a bug that appeared within the bootstrapping algorithm on 32-bit systems.

- Fixed a bug where :func:`regplot` would crash on singleton inputs. Now a crash is avoided and regression estimation/plotting is skipped.
9 changes: 7 additions & 2 deletions seaborn/regression.py
Expand Up @@ -48,7 +48,7 @@ def establish_variables(self, data, **kws):
vector = np.asarray(val)
else:
vector = val
if vector is not None:
if vector is not None and vector.shape != (1,):
vector = np.squeeze(vector)
if np.ndim(vector) > 1:
err = "regplot inputs must be 1d"
Expand Down Expand Up @@ -127,8 +127,13 @@ def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
else:
self.x_discrete = self.x

# Disable regression in case of singleton inputs
if len(self.x) <= 1:
self.fit_reg = False

# Save the range of the x variable for the grid later
self.x_range = self.x.min(), self.x.max()
if self.fit_reg:
self.x_range = self.x.min(), self.x.max()

@property
def scatter_data(self):
Expand Down
8 changes: 8 additions & 0 deletions seaborn/tests/test_regression.py
Expand Up @@ -165,6 +165,14 @@ def test_dropna(self):
p = lm._RegressionPlotter("x", "y_na", data=self.df, dropna=False)
nt.assert_equal(len(p.x), len(self.df.y_na))

@pytest.mark.parametrize("x,y",
[([1.5], [2]),
(np.array([1.5]), np.array([2])),
(pd.Series(1.5), pd.Series(2))])
def test_singleton(self, x, y):
p = lm._RegressionPlotter(x, y)
assert not p.fit_reg

def test_ci(self):

p = lm._RegressionPlotter("x", "y", data=self.df, ci=95)
Expand Down

0 comments on commit 488c167

Please sign in to comment.