Skip to content

Commit

Permalink
Add formula to stat_smooth
Browse files Browse the repository at this point in the history
closes #311
  • Loading branch information
has2k1 committed Apr 11, 2020
1 parent abcc966 commit eeb8f66
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 7 deletions.
3 changes: 3 additions & 0 deletions doc/changelog.rst
Expand Up @@ -22,6 +22,9 @@ New Features
makes it easy to change the ordering of a discrete variable according
to some other variable/column.

- :class:`~plotnine.stats.stat_smooth` can now use formulae for linear
models.


Bug Fixes
*********
Expand Down
3 changes: 2 additions & 1 deletion doc/conf.py
Expand Up @@ -410,7 +410,8 @@
'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None),
'sklearn': ('http://scikit-learn.org/stable/', None),
'skmisc': ('https://has2k1.github.io/scikit-misc/', None),
'adjustText': ('https://adjusttext.readthedocs.io/en/latest/', None)
'adjustText': ('https://adjusttext.readthedocs.io/en/latest/', None),
'patsy': ('https://patsy.readthedocs.io/en/stable', None)
}


Expand Down
153 changes: 148 additions & 5 deletions plotnine/stats/smoothers.py
Expand Up @@ -5,6 +5,8 @@
import pandas as pd
import scipy.stats as stats
import statsmodels.api as sm
import statsmodels.formula.api as smf
from patsy import dmatrices

from ..exceptions import PlotnineError, PlotnineWarning
from ..utils import get_valid_kwargs
Expand Down Expand Up @@ -47,17 +49,25 @@ def lm(data, xseq, **params):
"""
Fit OLS / WLS if data has weight
"""
if params['formula']:
return lm_formula(data, xseq, **params)

X = sm.add_constant(data['x'])
Xseq = sm.add_constant(xseq)
weights = data.get('weights', None)

if 'weight' in data:
init_kwargs, fit_kwargs = separate_method_kwargs(
params['method_args'], sm.WLS, sm.WLS.fit)
model = sm.WLS(data['y'], X, weights=data['weight'], **init_kwargs)
else:
if weights is None:
init_kwargs, fit_kwargs = separate_method_kwargs(
params['method_args'], sm.OLS, sm.OLS.fit)
model = sm.OLS(data['y'], X, **init_kwargs)
else:
if np.any(weights < 0):
raise ValueError(
"All weights must be greater than zero."
)
init_kwargs, fit_kwargs = separate_method_kwargs(
params['method_args'], sm.WLS, sm.WLS.fit)
model = sm.WLS(data['y'], X, weights=data['weight'], **init_kwargs)

results = model.fit(**fit_kwargs)
data = pd.DataFrame({'x': xseq})
Expand All @@ -74,10 +84,60 @@ def lm(data, xseq, **params):
return data


def lm_formula(data, xseq, **params):
    """
    Fit OLS / WLS using a formula

    A weighted model is fit when ``data`` has a ``weight`` entry,
    otherwise an ordinary least squares model is fit.

    Parameters
    ----------
    data : dataframe
        Data handed to the statsmodels formula interface. An
        optional ``weight`` column supplies the weights.
    xseq : array_like
        Values of ``x`` at which the fitted model is evaluated.
    **params : dict
        The keys read here are ``formula``, ``enviroment`` (spelled
        this way by the code that fills in the parameters),
        ``method_args``, ``se`` and ``level``.

    Returns
    -------
    out : dataframe
        Columns ``x`` and ``y``. When ``params['se']`` is true, the
        columns ``se``, ``ymin`` and ``ymax`` are added.

    Raises
    ------
    ValueError
        If any weight is negative.
    """
    formula = params['formula']
    eval_env = params['enviroment']
    weights = data.get('weight', None)

    if weights is None:
        init_kwargs, fit_kwargs = separate_method_kwargs(
            params['method_args'], sm.OLS, sm.OLS.fit)
        model = smf.ols(
            formula,
            data,
            eval_env=eval_env,
            **init_kwargs
        )
    else:
        # NOTE: only negative weights are rejected; a weight of
        # exactly zero passes this check.
        if np.any(weights < 0):
            raise ValueError(
                "All weights must be greater than zero."
            )
        # Split the user arguments against the WLS signatures, since
        # a WLS model is what gets built in this branch.
        init_kwargs, fit_kwargs = separate_method_kwargs(
            params['method_args'], sm.WLS, sm.WLS.fit)
        model = smf.wls(
            formula,
            data,
            weights=weights,
            eval_env=eval_env,
            **init_kwargs
        )

    results = model.fit(**fit_kwargs)
    data = pd.DataFrame({'x': xseq})
    data['y'] = results.predict(data)

    if params['se']:
        # The formula is re-evaluated on the prediction frame (which
        # now holds both x and y) to get the design matrix for xseq.
        _, predictors = dmatrices(formula, data, eval_env=eval_env)
        alpha = 1 - params['level']
        prstd, iv_l, iv_u = wls_prediction_std(
            results, predictors, alpha=alpha)
        data['se'] = prstd
        data['ymin'] = iv_l
        data['ymax'] = iv_u
    return data


def rlm(data, xseq, **params):
"""
Fit RLM
"""
if params['formula']:
return rlm_formula(data, xseq, **params)

X = sm.add_constant(data['x'])
Xseq = sm.add_constant(xseq)

Expand All @@ -96,10 +156,38 @@ def rlm(data, xseq, **params):
return data


def rlm_formula(data, xseq, **params):
    """
    Fit RLM using a formula

    Parameters
    ----------
    data : dataframe
        Data handed to the statsmodels formula interface.
    xseq : array_like
        Values of ``x`` at which the fitted model is evaluated.
    **params : dict
        The keys read here are ``enviroment`` (spelled this way by
        the code that fills in the parameters), ``formula``,
        ``method_args`` and ``se``.

    Returns
    -------
    out : dataframe
        Columns ``x`` and ``y``. No interval columns are added; if
        ``params['se']`` is true a ``PlotnineWarning`` is issued
        instead.
    """
    eval_env = params['enviroment']
    formula = params['formula']
    init_kwargs, fit_kwargs = separate_method_kwargs(
        params['method_args'], sm.RLM, sm.RLM.fit)
    model = smf.rlm(
        formula,
        data,
        eval_env=eval_env,
        **init_kwargs
    )
    results = model.fit(**fit_kwargs)
    data = pd.DataFrame({'x': xseq})
    data['y'] = results.predict(data)

    if params['se']:
        # The two literals are concatenated, so the first one must
        # end with a space to keep the words apart.
        warnings.warn("Confidence intervals are not yet implemented "
                      "for RLM smoothing.", PlotnineWarning)

    return data


def gls(data, xseq, **params):
"""
Fit GLS
"""
if params['formula']:
return gls_formula(data, xseq, **params)

X = sm.add_constant(data['x'])
Xseq = sm.add_constant(xseq)

Expand All @@ -122,10 +210,42 @@ def gls(data, xseq, **params):
return data


def gls_formula(data, xseq, **params):
    """
    Fit GLS using a formula

    Parameters
    ----------
    data : dataframe
        Data handed to the statsmodels formula interface.
    xseq : array_like
        Values of ``x`` at which the fitted model is evaluated.
    **params : dict
        The keys read here are ``enviroment`` (spelled this way by
        the code that fills in the parameters), ``formula``,
        ``method_args``, ``se`` and ``level``.

    Returns
    -------
    out : dataframe
        Columns ``x`` and ``y``. When ``params['se']`` is true, the
        columns ``se``, ``ymin`` and ``ymax`` are added.
    """
    eval_env = params['enviroment']
    formula = params['formula']
    init_kwargs, fit_kwargs = separate_method_kwargs(
        params['method_args'], sm.GLS, sm.GLS.fit)

    fitted = smf.gls(
        formula, data, eval_env=eval_env, **init_kwargs
    ).fit(**fit_kwargs)

    out = pd.DataFrame({'x': xseq})
    out['y'] = fitted.predict(out)

    # Nothing more to compute when no interval is requested
    if not params['se']:
        return out

    # Design matrix of the formula evaluated on the prediction frame
    _, design = dmatrices(formula, out, eval_env=eval_env)
    out['se'], out['ymin'], out['ymax'] = wls_prediction_std(
        fitted, design, alpha=1 - params['level'])
    return out


def glm(data, xseq, **params):
"""
Fit GLM
"""
if params['formula']:
return glm_formula(data, xseq, **params)

X = sm.add_constant(data['x'])
Xseq = sm.add_constant(xseq)

Expand All @@ -146,6 +266,29 @@ def glm(data, xseq, **params):
return data


def glm_formula(data, xseq, **params):
    """
    Fit GLM using a formula

    Parameters
    ----------
    data : dataframe
        Data handed to the statsmodels formula interface.
    xseq : array_like
        Values of ``x`` at which the fitted model is evaluated.
    **params : dict
        The keys read here are ``enviroment`` (spelled this way by
        the code that fills in the parameters), ``method_args``,
        ``formula``, ``se`` and ``level``.

    Returns
    -------
    out : dataframe
        Columns ``x`` and ``y``. When ``params['se']`` is true, the
        columns ``ymin`` and ``ymax`` are added (no ``se`` column
        is produced here).
    """
    eval_env = params['enviroment']
    init_kwargs, fit_kwargs = separate_method_kwargs(
        params['method_args'], sm.GLM, sm.GLM.fit)

    fitted = smf.glm(
        params['formula'], data, eval_env=eval_env, **init_kwargs
    ).fit(**fit_kwargs)

    out = pd.DataFrame({'x': xseq})
    out['y'] = fitted.predict(out)

    # Nothing more to compute when no interval is requested
    if not params['se']:
        return out

    # A fresh frame holding only x is used for the interval
    grid = pd.DataFrame({'x': xseq})
    bounds = fitted.get_prediction(grid).conf_int(1 - params['level'])
    out['ymin'] = bounds[:, 0]
    out['ymax'] = bounds[:, 1]
    return out


def lowess(data, xseq, **params):
for k in ('is_sorted', 'return_sorted'):
with suppress(KeyError):
Expand Down
16 changes: 16 additions & 0 deletions plotnine/stats/stat_smooth.py
Expand Up @@ -76,6 +76,12 @@ def my_smoother(data, xseq, **params):
data['ymax'] = high
return data
formula : formula_like
An object that can be used to construct a patsy design matrix.
This is usually a string. You can only use a formula if ``method``
is one of *lm*, *ols*, *wls*, *glm*, *rlm* or *gls*, and in the
:ref:`formula <patsy:formulas>` you may refer to the ``x`` and
``y`` aesthetic variables.
se : bool (default: True)
If :py:`True` draw confidence interval around the smooth line.
n : int (default: 80)
Expand Down Expand Up @@ -131,6 +137,7 @@ def my_smoother(data, xseq, **params):
DEFAULT_PARAMS = {'geom': 'smooth', 'position': 'identity',
'na_rm': False,
'method': 'auto', 'se': True, 'n': 80,
'formula': None,
'fullrange': False, 'level': 0.95,
'span': 0.75, 'method_args': {}}
CREATES = {'se', 'ymin', 'ymax'}
Expand Down Expand Up @@ -168,6 +175,15 @@ def setup_params(self, data):
"facets".format(window), PlotnineWarning)
params['method_args']['window'] = window

if params['formula']:
allowed = {'lm', 'ols', 'wls', 'glm', 'rlm', 'gls'}
if params['method'] not in allowed:
raise ValueError(
"You can only use a formula with `method` is "
"one of {}".format(allowed)
)
params['enviroment'] = self.environment

return params

@classmethod
Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
68 changes: 67 additions & 1 deletion plotnine/tests/test_geom_smooth.py
Expand Up @@ -4,7 +4,7 @@
import statsmodels.api as sm


from plotnine import ggplot, aes, geom_point, geom_smooth
from plotnine import ggplot, aes, geom_point, geom_smooth, stat_smooth
from plotnine.exceptions import PlotnineWarning


Expand Down Expand Up @@ -185,3 +185,69 @@ def test_init_and_fit_kwargs():
)

assert p == 'init_and_fit_kwargs'


# Shared data for the formula tests: a sine curve plus seeded
# gaussian noise, so the generated values are reproducible.
# NOTE(review): `x` is not defined in this span; presumably it is
# defined earlier in this module with length `n` -- verify.
n = 100
random_state = np.random.RandomState(123)
mu = 0
sigma = 0.065
noise = random_state.randn(n)*sigma + mu
df = pd.DataFrame({
    'x': x,
    'y': np.sin(x) + noise,
})


class TestFormula:
    """
    Smoothing with a formula, one test per supported method

    Every test adds a ``stat_smooth`` layer that uses the formula
    ``y ~ np.sin(x)`` to the shared base plot and compares the
    result with a string.
    NOTE(review): the string is presumably the name of a baseline
    image and the comparison is done by the project's test
    harness -- confirm.
    """

    # Base plot shared by all the tests in this class
    p = ggplot(df, aes('x', 'y')) + geom_point()

    def test_lm(self):
        """Linear model with a formula and a confidence band"""
        p = (self.p
             + stat_smooth(
                 method='lm',
                 formula='y ~ np.sin(x)',
                 fill='red',
                 se=True
             ))
        assert p == 'lm_formula'

    def test_lm_weights(self):
        """Linear model with a formula and a weight aesthetic"""
        p = (self.p
             + aes(weight='x.abs()')
             + stat_smooth(
                 method='lm',
                 formula='y ~ np.sin(x)',
                 fill='red',
                 se=True
             ))
        assert p == 'lm_formula_weights'

    def test_glm(self):
        """Generalized linear model with a formula"""
        p = (self.p
             + stat_smooth(
                 method='glm',
                 formula='y ~ np.sin(x)',
                 fill='red',
                 se=True
             ))
        assert p == 'glm_formula'

    def test_rlm(self):
        """Robust linear model with a formula; ``se`` is not passed"""
        p = (self.p
             + stat_smooth(
                 method='rlm',
                 formula='y ~ np.sin(x)',
                 fill='red',
             ))
        assert p == 'rlm_formula'

    def test_gls(self):
        """Generalized least squares with a formula"""
        p = (self.p
             + stat_smooth(
                 method='gls',
                 formula='y ~ np.sin(x)',
                 fill='red',
                 se=True
             ))
        assert p == 'gls_formula'

0 comments on commit eeb8f66

Please sign in to comment.