Skip to content

Commit

Permalink
MAINT: Improve user-facing error message
Browse files Browse the repository at this point in the history
Improve error message shown for incorrect use of predict after formula

closes statsmodels#3987
  • Loading branch information
bashtage committed May 6, 2019
1 parent fb6d268 commit 0333c68
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
10 changes: 9 additions & 1 deletion statsmodels/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,15 @@ def predict(self, exog=None, transform=True, *args, **kwargs):
exog = pd.DataFrame(exog).T
orig_exog_len = len(exog)
is_dict = isinstance(exog, dict)
exog = dmatrix(design_info, exog, return_type="dataframe")
try:
exog = dmatrix(design_info, exog, return_type="dataframe")
except Exception as exc:
msg = ('predict requires that you use a DataFrame when '
'predicting from a model\nthat was created using the '
'formula api.'
'\n\nThe original error message returned by patsy is:\n'
'{0}'.format(str(str(exc))))
raise exc.__class__(msg)
if orig_exog_len > len(exog) and not is_dict:
import warnings
if exog_index is None:
Expand Down
15 changes: 14 additions & 1 deletion statsmodels/formula/tests/test_formula.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from statsmodels.compat.python import iteritems, StringIO
from statsmodels.compat.python import iteritems, StringIO, PY3
import warnings

from statsmodels.formula.api import ols
Expand All @@ -10,6 +10,8 @@
import numpy.testing as npt
from statsmodels.tools.testing import assert_equal
import numpy as np
import pandas as pd
import patsy
import pytest


Expand Down Expand Up @@ -210,3 +212,14 @@ def test_patsy_missing_data():
assert 'nan values have been dropped' in repr(w[-1].message)
# First record will be dropped in both cases
assert_equal(res.fittedvalues, res2)


def test_predict_nondataframe():
    """Check that predict gives a helpful error for non-DataFrame input.

    A model built through the formula API is fit on a small DataFrame and
    then asked to predict from a plain list.  The call must raise, and the
    raised error must carry the user-facing hint that a DataFrame is
    required.
    """
    df = pd.DataFrame([[3, 0.030], [10, 0.060], [20, 0.120]],
                      columns=['BSA', 'Absorbance'])

    model = ols('Absorbance ~ BSA', data=df)
    fit = model.fit()
    # The exception class that is re-raised differs between Python 2 and 3.
    error = patsy.PatsyError if PY3 else TypeError
    # Matching on the message, not just the type, ensures the test fails if
    # the explanatory text is ever dropped: the bare exception type alone
    # would be raised even without the improved message.
    with pytest.raises(error,
                       match='predict requires that you use a DataFrame'):
        fit.predict([0.25])

0 comments on commit 0333c68

Please sign in to comment.