Skip to content

Commit

Permalink
be more explicit about which sklearn checks to run
Browse files Browse the repository at this point in the history
  • Loading branch information
Nico de Vos committed Jun 5, 2019
1 parent f8d3b25 commit 2bc7fce
Showing 1 changed file with 52 additions and 20 deletions.
72 changes: 52 additions & 20 deletions kmodes/tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,66 @@
"""
General sklearn tests for the estimators in kmodes.
"""
from sklearn.utils.testing import assert_false
from sklearn.utils.testing import assert_greater
from sklearn.utils.estimator_checks import (
_yield_all_checks,
check_parameters_default_constructible)
check_parameters_default_constructible
)

from kmodes.kmodes import KModes
from kmodes.kprototypes import KPrototypes
from kmodes.util.testing import _named_check

all_estimators = lambda: (('kmodes', KModes), ('kprototypes', KPrototypes))

KMODES_INCLUDE_CHECKS = (
'check_estimators_dtypes',
'check_fit_score_takes_y',
'check_sample_weights_pandas_series',
'check_sample_weights_list',
'check_sample_weights_invariance',
'check_estimators_fit_returns_self',
'check_complex_data',
'check_estimators_empty_data_messages',
'check_pipeline_consistency',
'check_estimators_nan_inf',
'check_estimators_overwrite_params',
'check_estimator_sparse_data',
'check_estimators_overwrite_params',
'check_estimators_pickle',
'check_fit2d_predict1d',
'check_methods_subset_invariance',
'check_fit2d_1sample',
'check_fit2d_1feature',
'check_fit1d',
'check_get_params_invariance',
'check_set_params',
'check_dict_unchanged',
'check_dont_overwrite_parameters',
'check_fit_idempotent',
'check_clusterer_compute_labels_predict',
'check_estimators_partial_fit_n_features',
'check_non_transformer_estimators_n_iter',
)

KPROTOTYPES_INCLUDE_CHECKS = (
'check_sample_weights_pandas_series',
'check_sample_weights_list',
'check_sample_weights_invariance',
'check_estimator_sparse_data',
'check_get_params_invariance',
'check_set_params',
'check_clusterer_compute_labels_predict',
'check_estimators_partial_fit_n_features',
)


def test_all_estimator_no_base_class():
# test that all_estimators doesn't find abstract classes.
for name, Estimator in all_estimators():
msg = ("Base estimators such as {0} should not be included"
" in all_estimators").format(name)
assert_false(name.lower().startswith('base'), msg=msg)
assert not name.lower().startswith('base'), msg


def test_all_estimators():
Expand All @@ -37,22 +78,13 @@ def test_all_estimators():

def test_non_meta_estimators():
for name, Estimator in all_estimators():
estimator = Estimator()
if name == 'kmodes':
for check in _yield_all_checks(name, Estimator):
# Skip these
if hasattr(check, '__name__'):
if check.__name__ not in ('check_clustering',
'check_dtype_object'):
yield _named_check(check, name), name, estimator
else:
yield check, name, estimator
relevant_checks = KMODES_INCLUDE_CHECKS
elif name == 'kprototypes':
for check in _yield_all_checks(name, Estimator):
# Only do these
if hasattr(check, '__name__') and check.__name__ in (
'check_estimator_sparse_data',
'check_clusterer_compute_labels_predict',
'check_estimators_partial_fit_n_features'
):
yield _named_check(check, name), name, estimator
relevant_checks = KPROTOTYPES_INCLUDE_CHECKS
else:
raise NotImplementedError
estimator = Estimator()
for check in _yield_all_checks(name, estimator):
if hasattr(check, '__name__') and check.__name__ in relevant_checks:
yield _named_check(check, name), name, estimator

0 comments on commit 2bc7fce

Please sign in to comment.