Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OutlierRemover and tests #148

Merged
merged 9 commits into from Jun 28, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .appveyor.yml
Expand Up @@ -17,7 +17,7 @@ install:
- conda config --set always_yes yes --set changeps1 no
- conda create -n sklego-env -c pytorch %REQUIREMENTS% python=%PYTHON_VERSION%
- activate sklego-env
- pip install -e .
- pip install -e .[dev]

test_script:
- activate sklego-env
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -9,7 +9,7 @@
docs_packages = ["sphinx>=1.8.5", "sphinx_rtd_theme>=0.4.3"]
dev_packages = docs_packages + ["flake8>=3.6.0", "matplotlib>=3.0.2", "pytest>=4.0.2",
"nbval>=0.9.1", "plotnine>=0.5.1", "jupyter>=1.0.0",
"jupyterlab>=0.35.4", "pytest-cov>=2.6.1"]
"jupyterlab>=0.35.4", "pytest-cov>=2.6.1", "pytest-mock>=1.6.3"]


def read(fname):
Expand Down
7 changes: 5 additions & 2 deletions sklego/common.py
Expand Up @@ -45,9 +45,12 @@ class TrainOnlyTransformerMixin(TransformerMixin):

}

def fit(self, X, y):
def fit(self, X, y=None):
"""Calculates the hash of X_train"""
check_X_y(X, y, estimator=self)
if y is None:
check_array(X, estimator=self)
else:
check_X_y(X, y, estimator=self)
self.X_hash_ = self._hash(X)
self.dim_ = X.shape[1]
return self
Expand Down
28 changes: 27 additions & 1 deletion sklego/meta.py
Expand Up @@ -4,7 +4,7 @@
from sklearn.base import BaseEstimator, TransformerMixin, MetaEstimatorMixin
from sklearn.utils.validation import check_is_fitted, check_X_y, check_array, FLOAT_DTYPES

from sklego.common import as_list
from sklego.common import as_list, TrainOnlyTransformerMixin


class EstimatorTransformer(TransformerMixin, MetaEstimatorMixin, BaseEstimator):
Expand Down Expand Up @@ -119,6 +119,32 @@ def predict(self, X):
raise ValueError(f"found a group(s) {culprits} in `.predict` that was not in `.fit`")


class OutlierRemover(TrainOnlyTransformerMixin, BaseEstimator):
"""
Removes outliers (train-time only) using the supplied removal model.

:param outlier_detector: must implement `fit` and `predict` methods
:param refit: If True, fits the estimator during pipeline.fit().

"""
def __init__(self, outlier_detector, refit=True):
self.outlier_detector = outlier_detector
self.refit = refit

def fit(self, X, y=None):
self.estimator_ = clone(self.outlier_detector)
if self.refit:
super().fit(X, y)
self.estimator_.fit(X, y)
return self

def transform_train(self, X):
check_is_fitted(self, 'estimator_')
predictions = self.estimator_.predict(X)
check_array(predictions, estimator=self.outlier_detector)
return X[predictions.squeeze() != -1]


class DecayEstimator(BaseEstimator):
"""
Morphs an estimator suchs that the training weights can be
Expand Down
48 changes: 48 additions & 0 deletions tests/test_meta/test_outlier_remover.py
@@ -0,0 +1,48 @@
import pytest
from pandas.tests.extension.numpy_.test_numpy_nested import np
from sklearn.utils import estimator_checks

from sklego.common import flatten
from sklego.meta import OutlierRemover
from sklego.mixture import GMMOutlierDetector


@pytest.mark.parametrize("test_fn", flatten([
estimator_checks.check_transformers_unfitted,
estimator_checks.check_fit2d_predict1d,
estimator_checks.check_fit2d_1sample,
estimator_checks.check_fit2d_1feature,
estimator_checks.check_fit1d,
estimator_checks.check_get_params_invariance,
estimator_checks.check_set_params,
estimator_checks.check_dont_overwrite_parameters,
estimator_checks.check_transformers_unfitted
]))
def test_estimator_checks(test_fn):
outlier_remover = OutlierRemover(outlier_detector=GMMOutlierDetector(), refit=True)
jcshoekstra marked this conversation as resolved.
Show resolved Hide resolved
test_fn(OutlierRemover.__name__, outlier_remover)


@pytest.fixture
def mock_outlier_detector(mocker):
return mocker.Mock()


def test_no_outliers(mock_outlier_detector, mocker):
mock_outlier_detector.fit.return_value = None
mock_outlier_detector.predict.return_value = np.array([[1, 1]])
mocker.patch('sklego.meta.clone').return_value = mock_outlier_detector

outlier_remover = OutlierRemover(outlier_detector=mock_outlier_detector, refit=True)
outlier_remover.fit(X=np.array([[1, 1], [2, 2]]))
assert len(outlier_remover.transform_train(np.array([[1, 1], [2, 2]]))) == 2


jcshoekstra marked this conversation as resolved.
Show resolved Hide resolved
def test_remove_outlier(mock_outlier_detector, mocker):
mock_outlier_detector.fit.return_value = None
mock_outlier_detector.predict.return_value = np.array([[-1]])
mocker.patch('sklego.meta.clone').return_value = mock_outlier_detector

outlier_remover = OutlierRemover(outlier_detector=mock_outlier_detector, refit=True)
outlier_remover.fit(X=np.array([[5, 5]]))
assert len(outlier_remover.transform_train(np.array([[0, 0]]))) == 0