add a test case for 'validation_curve'
gxyd committed Dec 8, 2017
1 parent 5fe685c commit 2af3c9c
Showing 1 changed file with 81 additions and 4 deletions: sklearn/model_selection/tests/test_validation.py
@@ -125,13 +125,56 @@ def partial_fit(self, X, y=None, **params):

 class MockEstimatorWithParameter(BaseEstimator):
     """Dummy classifier to test the validation curve"""
-    def __init__(self, param=0.5):
+    def __init__(self, param=0.5, allow_nd=False):
         self.X_subset = None
         self.param = param
+        self.allow_nd = allow_nd
 
-    def fit(self, X_subset, y_subset):
-        self.X_subset = X_subset
-        self.train_sizes = X_subset.shape[0]
+    def fit(self, X, y=None, sample_weight=None, class_prior=None,
+            sparse_sample_weight=None, sparse_param=None, dummy_int=None,
+            dummy_str=None, dummy_obj=None, callback=None):
+        """The dummy arguments are to test that this fit function can
+        accept non-array arguments through cross-validation, such as:
+            - int
+            - str (this is actually array-like)
+            - object
+            - function
+        """
+        self.X_subset = X
+        self.train_size = X.shape[0]
+        self.y = y
+        self.dummy_int = dummy_int
+        self.dummy_str = dummy_str
+        self.dummy_obj = dummy_obj
+        if callback is not None:
+            callback(self)
+
+        if self.allow_nd:
+            X = X.reshape(len(X), -1)
+        if X.ndim >= 3 and not self.allow_nd:
+            raise ValueError('X cannot have more than 2 dimensions')
+        if sample_weight is not None:
+            assert_true(sample_weight.shape[0] == X.shape[0],
+                        'MockClassifier extra fit_param sample_weight.shape[0]'
+                        ' is {0}, should be {1}'.format(sample_weight.shape[0],
+                                                        X.shape[0]))
+        if class_prior is not None:
+            assert_true(class_prior.shape[0] == len(np.unique(y)),
+                        'MockClassifier extra fit_param class_prior.shape[0]'
+                        ' is {0}, should be {1}'.format(class_prior.shape[0],
+                                                        len(np.unique(y))))
+        if sparse_sample_weight is not None:
+            fmt = ('MockClassifier extra fit_param sparse_sample_weight'
+                   '.shape[0] is {0}, should be {1}')
+            assert_true(sparse_sample_weight.shape[0] == X.shape[0],
+                        fmt.format(sparse_sample_weight.shape[0], X.shape[0]))
+        if sparse_param is not None:
+            fmt = ('MockClassifier extra fit_param sparse_param.shape '
+                   'is ({0}, {1}), should be ({2}, {3})')
+            assert_true(sparse_param.shape == P_sparse.shape,
+                        fmt.format(sparse_param.shape[0],
+                                   sparse_param.shape[1],
+                                   P_sparse.shape[0], P_sparse.shape[1]))
         return self
 
     def predict(self, X):
@@ -601,6 +644,40 @@ def assert_fit_params(clf):
     cross_val_score(clf, X, y, fit_params=fit_params)
 
 
+def test_validation_curve_fit_params():
+    clf = MockEstimatorWithParameter()
+    n_samples = X.shape[0]
+    n_classes = len(np.unique(y))
+
+    W_sparse = coo_matrix((np.array([1]), (np.array([1]), np.array([0]))),
+                          shape=(10, 1))
+    P_sparse = coo_matrix(np.eye(5))
+
+    DUMMY_INT = 42
+    DUMMY_STR = '42'
+    DUMMY_OBJ = object()
+
+    def assert_fit_params(clf):
+        # Function to test that the values are passed correctly to the
+        # classifier arguments for non-array type
+
+        assert_equal(clf.dummy_int, DUMMY_INT)
+        assert_equal(clf.dummy_str, DUMMY_STR)
+        assert_equal(clf.dummy_obj, DUMMY_OBJ)
+
+    fit_params = {'sample_weight': np.ones(n_samples),
+                  'class_prior': np.ones(n_classes) / n_classes,
+                  'sparse_sample_weight': W_sparse,
+                  'sparse_param': P_sparse,
+                  'dummy_int': DUMMY_INT,
+                  'dummy_str': DUMMY_STR,
+                  'dummy_obj': DUMMY_OBJ,
+                  'callback': assert_fit_params}
+    param_range = np.linspace(0, 1, 10)
+    validation_curve(clf, X, y, param_name="param", param_range=param_range,
+                     fit_params=fit_params)
+
+
 def test_cross_val_score_score_func():
     clf = MockClassifier()
     _score_func_args = []
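For context, the behaviour this new test pins down is that keyword arguments supplied via fit_params are forwarded unchanged to the estimator's fit call for every value in param_range. Below is a minimal, self-contained sketch of that contract, separate from the diff above; the RecordingEstimator class and its dummy_int argument are illustrative only, and the example assumes a scikit-learn release in which validation_curve accepts a fit_params keyword (later releases route such arguments through params/metadata routing instead).

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import validation_curve


class RecordingEstimator(BaseEstimator):
    """Toy estimator that records the extra keyword forwarded to fit."""

    def __init__(self, param=0.5):
        self.param = param

    def fit(self, X, y=None, dummy_int=None):
        # validation_curve should hand fit_params through unchanged,
        # so dummy_int arrives here exactly as given by the caller.
        assert dummy_int == 42
        self.dummy_int_ = dummy_int
        return self

    def score(self, X, y=None):
        # The score depends only on the hyper-parameter being varied.
        return self.param


X = np.random.RandomState(0).rand(30, 2)
y = np.zeros(30)
train_scores, test_scores = validation_curve(
    RecordingEstimator(), X, y,
    param_name="param", param_range=np.linspace(0.1, 1.0, 5),
    cv=3, fit_params={"dummy_int": 42})

If the forwarding were broken, the assert inside fit would fail on every split; the committed test checks the same property through the assert_fit_params callback and the attributes stored by MockEstimatorWithParameter.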
