Skip to content

Commit

Permalink
start adding fit_params in learning_curve
Browse files Browse the repository at this point in the history
  • Loading branch information
gxyd committed Dec 9, 2017
1 parent 2af3c9c commit d0aacf8
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions sklearn/model_selection/_validation.py
Expand Up @@ -975,7 +975,7 @@ def learning_curve(estimator, X, y, groups=None,
train_sizes=np.linspace(0.1, 1.0, 5), cv=None, scoring=None,
exploit_incremental_learning=False, n_jobs=1,
pre_dispatch="all", verbose=0, shuffle=False,
random_state=None):
random_state=None, fit_params=None):
"""Learning curve.
Determines cross-validated training and test scores for different training
Expand Down Expand Up @@ -1062,6 +1062,9 @@ def learning_curve(estimator, X, y, groups=None,
If None, the random number generator is the RandomState instance used
by `np.random`. Used when ``shuffle`` is True.
fit_params : dict, optional
Parameters to pass to the fit method of the estimator.
Returns
-------
train_sizes_abs : array, shape = (n_unique_ticks,), dtype int
Expand Down Expand Up @@ -1121,7 +1124,8 @@ def learning_curve(estimator, X, y, groups=None,

out = parallel(delayed(_fit_and_score)(
clone(estimator), X, y, scorer, train, test,
verbose, parameters=None, fit_params=None, return_train_score=True)
verbose, parameters=None, fit_params=fit_params,
return_train_score=True)
for train, test in train_test_proportions)
out = np.array(out)
n_cv_folds = out.shape[0] // n_unique_ticks
Expand Down

0 comments on commit d0aacf8

Please sign in to comment.