grid_search: add sample_weight support
ndawe committed May 14, 2014
commit 3da7fb7 (parent: edbe7b7)
Showing 4 changed files with 93 additions and 29 deletions.
sklearn/cross_validation.py (24 additions & 7 deletions)
@@ -1146,7 +1146,8 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
     return np.array(scores)[:, 0]
 
 
-def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
+def _fit_and_score(estimator, X, y, sample_weight,
+                   scorer, train, test, verbose, parameters,
                    fit_params, return_train_score=False,
                    return_parameters=False):
     """Fit estimator and compute scores for a given dataset split.
@@ -1159,10 +1160,13 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
     X : array-like of shape at least 2D
         The data to fit.
 
-    y : array-like, optional, default: None
+    y : array-like or None
         The target variable to try to predict in the case of
         supervised learning.
 
+    sample_weight : array-like or None
+        Sample weights.
+
     scoring : callable
         A scorer callable object / function with signature
         ``scorer(estimator, X, y)``.
@@ -1227,13 +1231,26 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
 
     X_train, y_train = _safe_split(estimator, X, y, train)
     X_test, y_test = _safe_split(estimator, X, y, test, train)
+
+    test_score_params = dict()
+    train_score_params = dict()
+    if sample_weight is not None:
+        # move to _safe_split?
+        sample_weight_train = sample_weight[safe_mask(sample_weight, train)]
+        sample_weight_test = sample_weight[safe_mask(sample_weight, test)]
+        fit_params['sample_weight'] = sample_weight_train
+        test_score_params['sample_weight'] = sample_weight_test
+        train_score_params['sample_weight'] = sample_weight_train
+
     if y_train is None:
         estimator.fit(X_train, **fit_params)
     else:
         estimator.fit(X_train, y_train, **fit_params)
-    test_score = _score(estimator, X_test, y_test, scorer)
+    test_score = _score(estimator, X_test, y_test, scorer,
+                        **test_score_params)
     if return_train_score:
-        train_score = _score(estimator, X_train, y_train, scorer)
+        train_score = _score(estimator, X_train, y_train, scorer,
+                             **train_score_params)
 
     scoring_time = time.time() - start_time
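This is the core of the change: when weights are supplied, _fit_and_score slices them with the same train/test indices used for X and y, hands the train slice to fit, and routes the matching slices to the scoring calls. A minimal standalone sketch of that slicing (illustrative only, not the commit's code):

    import numpy as np

    # Index arrays as a CV iterator would yield them.
    train, test = np.array([0, 1]), np.array([2, 3])
    sample_weight = np.array([1.0, 2.0, 3.0, 4.0])

    # Mirrors the hunk above: weights follow the same split as X and y.
    sample_weight_train = sample_weight[train]
    sample_weight_test = sample_weight[test]

    assert sample_weight_train.tolist() == [1.0, 2.0]
    assert sample_weight_test.tolist() == [3.0, 4.0]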
@@ -1282,12 +1299,12 @@ def _safe_split(estimator, X, y, indices, train_indices=None):
     return X_subset, y_subset
 
 
-def _score(estimator, X_test, y_test, scorer):
+def _score(estimator, X_test, y_test, scorer, **params):
     """Compute the score of an estimator on a given test set."""
     if y_test is None:
-        score = scorer(estimator, X_test)
+        score = scorer(estimator, X_test, **params)
     else:
-        score = scorer(estimator, X_test, y_test)
+        score = scorer(estimator, X_test, y_test, **params)
     if not isinstance(score, numbers.Number):
         raise ValueError("scoring must return a number, got %s (%s) instead."
                          % (str(score), type(score)))
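_score now forwards arbitrary keyword arguments to the scorer, so a scorer whose signature includes sample_weight can put them to use. A hand-rolled example of such a scorer (hypothetical; the commit does not add this helper):

    import numpy as np

    def weighted_accuracy(estimator, X, y, sample_weight=None):
        # Matches the extended call _score can now make:
        # scorer(estimator, X_test, y_test, sample_weight=...)
        correct = (estimator.predict(X) == y).astype(float)
        if sample_weight is None:
            return correct.mean()
        return float(np.average(correct, weights=sample_weight))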
sklearn/grid_search.py (40 additions & 16 deletions)

@@ -8,6 +8,7 @@
 # Gael Varoquaux <gael.varoquaux@normalesup.org>
 # Andreas Mueller <amueller@ais.uni-bonn.de>
 # Olivier Grisel <olivier.grisel@ensta.org>
+# Noel Dawe <noel@dawe.me>
 # License: BSD 3 clause
 
 from abc import ABCMeta, abstractmethod
@@ -228,7 +229,8 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
     n_samples_test : int
         Number of test samples in this split.
     """
-    score, n_samples_test, _ = _fit_and_score(estimator, X, y, scorer, train,
+    score, n_samples_test, _ = _fit_and_score(estimator, X, y, None,
+                                              scorer, train,
                                               test, verbose, parameters,
                                               fit_params)
     return score, parameters, n_samples_test
@@ -295,7 +297,7 @@ def __init__(self, estimator, scoring=None, loss_func=None,
         self.verbose = verbose
         self.pre_dispatch = pre_dispatch
 
-    def score(self, X, y=None):
+    def score(self, X, y=None, sample_weight=None):
         """Returns the score on the given test data and labels, if the search
         estimator has been refit. The ``score`` function of the best estimator
         is used, or the ``scoring`` parameter where unavailable.
@@ -310,18 +312,24 @@ def score(self, X, y=None):
             Target relative to X for classification or regression;
             None for unsupervised learning.
 
+        sample_weight : array-like, shape = [n_samples], optional
+            Sample weights.
+
         Returns
         -------
         score : float
 
         """
+        kwargs = {}
+        if sample_weight is not None:
+            kwargs['sample_weight'] = sample_weight
         if hasattr(self.best_estimator_, 'score'):
-            return self.best_estimator_.score(X, y)
+            return self.best_estimator_.score(X, y, **kwargs)
         if self.scorer_ is None:
             raise ValueError("No score function explicitly defined, "
                              "and the estimator doesn't provide one %s"
                              % self.best_estimator_)
-        return self.scorer_(self.best_estimator_, X, y)
+        return self.scorer_(self.best_estimator_, X, y, **kwargs)
 
     @property
     def predict(self):
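Note the backward-compatible dispatch in score above: the keyword is forwarded only when the caller actually passes weights, so estimators whose score method takes no sample_weight argument keep working. A toy illustration of the pattern (not the commit's code):

    def legacy_score(X, y):  # no sample_weight parameter
        return 1.0

    def weighted_score(X, y, sample_weight=None):
        return 1.0 if sample_weight is None else float(sum(sample_weight))

    def call(score_fn, X, y, sample_weight=None):
        kwargs = {}
        if sample_weight is not None:
            kwargs['sample_weight'] = sample_weight
        return score_fn(X, y, **kwargs)

    assert call(legacy_score, None, None) == 1.0            # old signature unaffected
    assert call(weighted_score, None, None, [1, 2]) == 3.0  # weights forwarded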
@@ -339,7 +347,7 @@ def decision_function(self):
     def transform(self):
         return self.best_estimator_.transform
 
-    def _fit(self, X, y, parameter_iterable):
+    def _fit(self, X, y, sample_weight, parameter_iterable):
         """Actual fitting, performing the search over parameters."""
 
         estimator = self.estimator
@@ -349,15 +357,21 @@ def _fit(self, X, y, parameter_iterable):
                                  score_func=self.score_func)
 
         n_samples = _num_samples(X)
-        X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr',
-                            allow_nans=True)
+        X, y, sample_weight = check_arrays(X, y, sample_weight,
+                                           allow_lists=True,
+                                           sparse_format='csr',
+                                           allow_nans=True)
 
         if y is not None:
             if len(y) != n_samples:
                 raise ValueError('Target variable (y) has a different number '
                                  'of samples (%i) than data (X: %i samples)'
                                  % (len(y), n_samples))
             y = np.asarray(y)
 
+        if sample_weight is not None:
+            sample_weight = np.asarray(sample_weight)
+
         cv = check_cv(cv, X, y, classifier=is_classifier(estimator))
 
         if self.verbose > 0:
@@ -375,9 +389,10 @@ def _fit(self, X, y, parameter_iterable):
                 n_jobs=self.n_jobs, verbose=self.verbose,
                 pre_dispatch=pre_dispatch
             )(
-            delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
-                                    train, test, self.verbose, parameters,
-                                    self.fit_params, return_parameters=True)
+            delayed(_fit_and_score)(clone(base_estimator), X, y, sample_weight,
+                                    self.scorer_, train, test,
+                                    self.verbose, parameters, self.fit_params,
+                                    return_parameters=True)
                 for parameters in parameter_iterable
                 for train, test in cv)
@@ -419,14 +434,18 @@ def _fit(self, X, y, parameter_iterable):
         self.best_score_ = best.mean_validation_score
 
         if self.refit:
+            fit_params = self.fit_params
+            if sample_weight is not None:
+                fit_params = fit_params.copy()
+                fit_params['sample_weight'] = sample_weight
             # fit the best estimator using the entire dataset
             # clone first to work around broken estimators
             best_estimator = clone(base_estimator).set_params(
                 **best.parameters)
             if y is not None:
-                best_estimator.fit(X, y, **self.fit_params)
+                best_estimator.fit(X, y, **fit_params)
             else:
-                best_estimator.fit(X, **self.fit_params)
+                best_estimator.fit(X, **fit_params)
             self.best_estimator_ = best_estimator
         return self

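On refit the full weight vector is used rather than a fold slice, and self.fit_params is copied before being augmented so the caller's dict is never mutated. The pattern in isolation:

    fit_params = {'some_flag': True}   # stand-in for self.fit_params
    sample_weight = [1.0, 2.0, 3.0]

    refit_params = fit_params.copy()
    refit_params['sample_weight'] = sample_weight

    assert 'sample_weight' not in fit_params  # original dict untouched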
@@ -581,7 +600,7 @@ def __init__(self, estimator, param_grid, scoring=None, loss_func=None,
         self.param_grid = param_grid
         _check_param_grid(param_grid)
 
-    def fit(self, X, y=None):
+    def fit(self, X, y=None, sample_weight=None):
         """Run fit with all sets of parameters.
 
         Parameters
@@ -595,8 +614,10 @@ def fit(self, X, y=None):
             Target relative to X for classification or regression;
             None for unsupervised learning.
 
+        sample_weight : array-like, shape = [n_samples], optional
+            Sample weights.
+
         """
-        return self._fit(X, y, ParameterGrid(self.param_grid))
+        return self._fit(X, y, sample_weight, ParameterGrid(self.param_grid))
 
 
 class RandomizedSearchCV(BaseSearchCV):
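End to end, GridSearchCV.fit's new keyword would be used as below. A sketch against this branch, with a minimal weight-aware estimator in the spirit of the MockClassifier used by the tests (any estimator whose fit and score accept sample_weight would do):

    import numpy as np
    from sklearn.grid_search import GridSearchCV

    class WeightAwareClassifier(object):
        # Hypothetical estimator: fit and score both take sample_weight.
        def __init__(self, foo=0):
            self.foo = foo

        def get_params(self, deep=False):
            return {'foo': self.foo}

        def set_params(self, **params):
            self.foo = params['foo']
            return self

        def fit(self, X, y, sample_weight=None):
            return self

        def predict(self, X):
            return np.zeros(len(X))

        def score(self, X=None, y=None, sample_weight=None):
            return float(self.foo)

    X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
    y = np.array([1, 1, 2, 2])
    w = np.array([1.0, 2.0, 3.0, 4.0])

    search = GridSearchCV(WeightAwareClassifier(), {'foo': [1, 2, 3]}, cv=2)
    search.fit(X, y, sample_weight=w)  # fold slices of w reach fit and score
    print(search.best_params_)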
@@ -732,7 +753,7 @@ def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
                 n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
                 pre_dispatch=pre_dispatch)
 
-    def fit(self, X, y=None):
+    def fit(self, X, y=None, sample_weight=None):
         """Run fit on the estimator with randomly drawn parameters.
 
         Parameters
@@ -745,8 +766,11 @@ def fit(self, X, y=None):
             Target relative to X for classification or regression;
             None for unsupervised learning.
 
+        sample_weight : array-like, shape = [n_samples], optional
+            Sample weights.
+
         """
         sampled_params = ParameterSampler(self.param_distributions,
                                           self.n_iter,
                                           random_state=self.random_state)
-        return self._fit(X, y, sampled_params)
+        return self._fit(X, y, sample_weight, sampled_params)
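RandomizedSearchCV gains the identical keyword; only the way candidate parameters are drawn differs. A fragment reusing the hypothetical WeightAwareClassifier, X, y, and w names from the sketch above:

    from sklearn.grid_search import RandomizedSearchCV

    search = RandomizedSearchCV(WeightAwareClassifier(), {'foo': [1, 2, 3]},
                                n_iter=2, cv=2, random_state=0)
    search.fit(X, y, sample_weight=w)  # weights handled exactly as in GridSearchCV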
sklearn/tests/test_cross_validation.py (4 additions & 4 deletions)

@@ -866,12 +866,12 @@ def test_safe_split_with_precomputed_kernel():
     cv = cval.ShuffleSplit(X.shape[0], test_size=0.25, random_state=0)
     tr, te = list(cv)[0]
 
-    X_tr, y_tr = cval._safe_split(clf, X, y, tr)
-    K_tr, y_tr2 = cval._safe_split(clfp, K, y, tr)
+    X_tr, y_tr, _ = cval._safe_split(clf, X, y, None, tr)
+    K_tr, y_tr2, _ = cval._safe_split(clfp, K, y, None, tr)
     assert_array_almost_equal(K_tr, np.dot(X_tr, X_tr.T))
 
-    X_te, y_te = cval._safe_split(clf, X, y, te, tr)
-    K_te, y_te2 = cval._safe_split(clfp, K, y, te, tr)
+    X_te, y_te, _ = cval._safe_split(clf, X, y, None, te, tr)
+    K_te, y_te2, _ = cval._safe_split(clfp, K, y, None, te, tr)
     assert_array_almost_equal(K_te, np.dot(X_te, X_tr.T))

sklearn/tests/test_grid_search.py (25 additions & 2 deletions)

@@ -50,8 +50,13 @@ class MockClassifier(object):
     def __init__(self, foo_param=0):
        self.foo_param = foo_param
 
-    def fit(self, X, Y):
+    def fit(self, X, Y, sample_weight=None):
         assert_true(len(X) == len(Y))
+        if sample_weight is not None:
+            assert_true(len(sample_weight) == len(X),
+                        'MockClassifier sample_weight.shape[0]'
+                        ' is {0}, should be {1}'.format(len(sample_weight),
+                                                        len(X)))
         return self
 
     def predict(self, T):
@@ -61,7 +66,12 @@ def predict(self, T):
     decision_function = predict
     transform = predict
 
-    def score(self, X=None, Y=None):
+    def score(self, X=None, Y=None, sample_weight=None):
+        if X is not None and sample_weight is not None:
+            assert_true(len(sample_weight) == len(X),
+                        'MockClassifier sample_weight.shape[0]'
+                        ' is {0}, should be {1}'.format(len(sample_weight),
+                                                        len(X)))
         if self.foo_param > 1:
             score = 1.
         else:
@@ -115,6 +125,7 @@ def score(self):
 
 X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
 y = np.array([1, 1, 2, 2])
+sample_weight = np.array([1, 2, 3, 4])
 
 
 def test_parameter_grid():
@@ -668,3 +679,15 @@ def test_grid_search_allows_nans():
         ('classifier', MockClassifier()),
     ])
     gs = GridSearchCV(p, {'classifier__foo_param': [1, 2, 3]}, cv=2).fit(X, y)
+
+
+def test_grid_search_with_sample_weights():
+    """Test grid searching with sample weights"""
+    est_parameters = {"foo_param": [1, 2, 3]}
+    cv = KFold(y.shape[0], n_folds=2, random_state=0)
+    for search_cls in (GridSearchCV, RandomizedSearchCV):
+        grid_search = search_cls(MockClassifier(), est_parameters, cv=cv)
+        grid_search.fit(X, y, sample_weight=sample_weight)
+    # check that sample_weight can be a list
+    grid_search = GridSearchCV(MockClassifier(), est_parameters, cv=cv)
+    grid_search.fit(X, y, sample_weight=sample_weight.tolist())
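With KFold(4, n_folds=2) each fold contains two samples, so MockClassifier's assertions verify that exactly two weights accompany every per-fold fit and score call, while the final refit sees all four. A quick standalone check of those fold sizes (old cross_validation API, matching the test):

    from sklearn.cross_validation import KFold

    for train, test in KFold(4, n_folds=2, random_state=0):
        assert len(train) == 2 and len(test) == 2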
