Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix for #158 #159

Merged
merged 13 commits into from
Jul 25, 2019
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
55 changes: 52 additions & 3 deletions sklego/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,52 @@ def __init__(self, estimator, groups, use_fallback=True):
self.groups = groups
self.use_fallback = use_fallback

def __check_group_cols_exist(self, X):
try:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking the type of X is probably better to do via isinstance instead waiting for an AttributeError.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that in this case the intent is clearer with an isinstance check.

x_cols = set(X.columns)
except AttributeError:
try:
ncols = X.shape[1]
except IndexError:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is X ever only 1d?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it could be. groupby-mean as a model?

ncols = 1
x_cols = set(range(ncols))

# Check whether grouping columns exist
diff = set(as_list(self.groups)) - x_cols

if len(diff) > 0:
raise KeyError(f'{diff} not in columns of X ({x_cols})')

@staticmethod
def __check_missing_and_inf(X):
"""Check that all elements of X are non-missing and finite"""
koaning marked this conversation as resolved.
Show resolved Hide resolved
if np.any(pd.isnull(X)):
raise ValueError("X has NaN values")
try: # if X cannot be converted to numeric, checking infinites does not make sense
if np.any(np.isinf(X)):
raise ValueError("X has infinite values")
except TypeError:
pass

def __validate(self, X, y=None):
# Split the model data from the grouping columns
X_data = self.__remove_groups_from_x(X)

# We want to use __validate in both fit and predict, so y can be None
if y is not None:
koaning marked this conversation as resolved.
Show resolved Hide resolved
check_X_y(X_data, y)
else:
check_array(X_data)

self.__check_missing_and_inf(X)
self.__check_group_cols_exist(X)

def __remove_groups_from_x(self, X):
try:
return X.drop(columns=self.groups, inplace=False)
except AttributeError: # np.array
return np.delete(X, self.groups, axis=1)

def fit(self, X, y):
"""
Fit the model using X, y as training data. Will also learn the groups
Expand All @@ -63,15 +109,17 @@ def fit(self, X, y):
:param y: array-like, shape=(n_samples,) training data.
:return: Returns an instance of self.
"""
check_X_y(X, y)
self.__validate(X, y)

pred_col = 'the-column-that-i-want-to-predict-but-dont-have-the-name-for'
if isinstance(X, np.ndarray):
X = pd.DataFrame(X, columns=[str(_) for _ in range(X.shape[1])])
X = X.assign(**{pred_col: y})

self.group_colnames_ = [str(_) for _ in as_list(self.groups)]
if any([c not in X.columns for c in self.group_colnames_]):
raise ValueError(f"{self.group_colnames_} not in {X.columns}")
raise KeyError(f"{self.group_colnames_} not in {X.columns}")

self.X_colnames_ = [_ for _ in X.columns if _ not in self.group_colnames_ and _ is not pred_col]
self.fallback_ = None
if self.use_fallback:
Expand All @@ -93,7 +141,8 @@ def predict(self, X):
:param X: array-like, shape=(n_columns, n_samples,) training data.
:return: array, shape=(n_samples,) the predicted data
"""
check_array(X)
self.__validate(X)

check_is_fitted(self, ['estimators_', 'groups_', 'group_colnames_',
'X_colnames_', 'fallback_'])
if isinstance(X, np.ndarray):
Expand Down
21 changes: 18 additions & 3 deletions tests/test_meta/test_grouped_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import pandas as pd
import numpy as np
from sklearn.utils import estimator_checks
from sklearn.linear_model import LinearRegression

Expand All @@ -15,9 +16,8 @@
estimator_checks.check_estimators_nan_inf,
estimator_checks.check_estimators_overwrite_params,
estimator_checks.check_estimators_pickle,
estimator_checks.check_fit2d_predict1d,
estimator_checks.check_fit2d_1sample,
estimator_checks.check_fit1d,
koaning marked this conversation as resolved.
Show resolved Hide resolved
# estimator_checks.check_fit1d not tested because in 1d we cannot have both groups and data
estimator_checks.check_dont_overwrite_parameters,
estimator_checks.check_sample_weights_invariance,
estimator_checks.check_get_params_invariance,
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_chickweight_raise_error_cols_missing1():
df = load_chicken(give_pandas=True)
mod = GroupedEstimator(estimator=LinearRegression(), groups="diet")
mod.fit(df[['time', 'diet']], df['weight'])
with pytest.raises(ValueError):
with pytest.raises(KeyError):
mod.predict(df[['time', 'chick']])


Expand All @@ -90,3 +90,18 @@ def test_chickweight_np_keys():
mod.fit(df[['time', 'chick', 'diet']].values, df['weight'].values)
# there should still only be 50 groups on this dataset
assert len(mod.estimators_.keys()) == 50


def test_chickweigt_string_groups():

df = load_chicken(give_pandas=True)
df['diet'] = ['omgomgomg' + s for s in df['diet'].astype(str)]

X = df[['time', 'diet']]
X_np = np.array(X)

y = df['weight']

# This should NOT raise errors
GroupedEstimator(LinearRegression(), groups=['diet']).fit(X, y).predict(X)
GroupedEstimator(LinearRegression(), groups=1).fit(X_np, y).predict(X_np)