[MRG+1] MLPRegressor quits fitting too soon due to `self._no_improvement_count` (scikit-learn#9457)
nnadeau authored and maskani-moh committed Nov 15, 2017
1 parent 6704dd3 commit f0574b9
Showing 4 changed files with 119 additions and 31 deletions.
26 changes: 14 additions & 12 deletions doc/modules/neural_networks_supervised.rst
@@ -91,12 +91,13 @@ training samples::
...
>>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
MLPClassifier(activation='relu', alpha=1e-05, batch_size='auto',
beta_1=0.9, beta_2=0.999, early_stopping=False,
epsilon=1e-08, hidden_layer_sizes=(5, 2), learning_rate='constant',
learning_rate_init=0.001, max_iter=200, momentum=0.9,
nesterovs_momentum=True, power_t=0.5, random_state=1, shuffle=True,
solver='lbfgs', tol=0.0001, validation_fraction=0.1, verbose=False,
warm_start=False)
beta_1=0.9, beta_2=0.999, early_stopping=False,
epsilon=1e-08, hidden_layer_sizes=(5, 2),
learning_rate='constant', learning_rate_init=0.001,
max_iter=200, momentum=0.9, n_iter_no_change=10,
nesterovs_momentum=True, power_t=0.5, random_state=1,
shuffle=True, solver='lbfgs', tol=0.0001,
validation_fraction=0.1, verbose=False, warm_start=False)

After fitting (training), the model can predict labels for new samples::

@@ -139,12 +140,13 @@ indices where the value is `1` represents the assigned classes of that sample::
...
>>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE
MLPClassifier(activation='relu', alpha=1e-05, batch_size='auto',
beta_1=0.9, beta_2=0.999, early_stopping=False,
epsilon=1e-08, hidden_layer_sizes=(15,), learning_rate='constant',
learning_rate_init=0.001, max_iter=200, momentum=0.9,
nesterovs_momentum=True, power_t=0.5, random_state=1, shuffle=True,
solver='lbfgs', tol=0.0001, validation_fraction=0.1, verbose=False,
warm_start=False)
beta_1=0.9, beta_2=0.999, early_stopping=False,
epsilon=1e-08, hidden_layer_sizes=(15,),
learning_rate='constant', learning_rate_init=0.001,
max_iter=200, momentum=0.9, n_iter_no_change=10,
nesterovs_momentum=True, power_t=0.5, random_state=1,
shuffle=True, solver='lbfgs', tol=0.0001,
validation_fraction=0.1, verbose=False, warm_start=False)
>>> clf.predict([[1., 2.]])
array([[1, 1]])
>>> clf.predict([[0., 0.]])
20 changes: 20 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -18,6 +18,9 @@ random sampling procedures.
- :class:`decomposition.IncrementalPCA` in Python 2 (bug fix)
- :class:`isotonic.IsotonicRegression` (bug fix)
- :class:`metrics.roc_auc_score` (bug fix)
- :class:`neural_network.BaseMultilayerPerceptron` (bug fix)
- :class:`neural_network.MLPRegressor` (bug fix)
- :class:`neural_network.MLPClassifier` (bug fix)

Details are listed in the changelog below.

@@ -65,6 +68,13 @@ Classifiers and regressors
:class:`sklearn.naive_bayes.GaussianNB` to give a precise control over
variances calculation. :issue:`9681` by :user:`Dmitry Mottl <Mottl>`.

- Add ``n_iter_no_change`` parameter in
  :class:`neural_network.BaseMultilayerPerceptron`,
  :class:`neural_network.MLPRegressor`, and
  :class:`neural_network.MLPClassifier` to give control over the
  maximum number of epochs that may pass without a ``tol`` improvement.
  :issue:`9456` by :user:`Nicholas Nadeau <nnadeau>`.
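For illustration only (not part of this commit), a minimal usage sketch of the new parameter, assuming the 0.20 API added by this diff::

    from sklearn.neural_network import MLPClassifier

    # Hypothetical toy data; any small binary problem behaves the same way.
    X = [[0., 0.], [0., 1.], [1., 0.], [1., 1.]]
    y = [0, 1, 1, 0]

    # Tolerate up to 25 epochs without a tol-sized improvement before SGD
    # stops, instead of the previously hardcoded limit of 2.
    clf = MLPClassifier(solver='sgd', tol=1e-4, max_iter=500,
                        n_iter_no_change=25, random_state=1)
    clf.fit(X, y)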

- A parameter ``check_inverse`` was added to :class:`FunctionTransformer`
to ensure that ``func`` and ``inverse_func`` are the inverse of each
other.
@@ -96,6 +106,16 @@ Classifiers and regressors
identical X values.
:issue:`9432` by :user:`Dallas Card <dallascard>`

- Fixed a bug in :class:`neural_network.BaseMultilayerPerceptron`,
  :class:`neural_network.MLPRegressor`, and
  :class:`neural_network.MLPClassifier`: the number of epochs tolerated
  without ``tol`` improvement is now governed by the new
  ``n_iter_no_change`` parameter (default 10) instead of a hardcoded 2.
  :issue:`9456` by :user:`Nicholas Nadeau <nnadeau>`.

- Fixed a bug in :class:`neural_network.MLPRegressor` where fitting
  quit unexpectedly early due to local minima or fluctuations.
  :issue:`9456` by :user:`Nicholas Nadeau <nnadeau>`.

- Fixed a bug in :class:`naive_bayes.GaussianNB` which incorrectly raised
error for prior list which summed to 1.
:issue:`10005` by :user:`Gaurav Dhingra <gxyd>`.
59 changes: 40 additions & 19 deletions sklearn/neural_network/multilayer_perceptron.py
@@ -51,7 +51,8 @@ def __init__(self, hidden_layer_sizes, activation, solver,
alpha, batch_size, learning_rate, learning_rate_init, power_t,
max_iter, loss, shuffle, random_state, tol, verbose,
warm_start, momentum, nesterovs_momentum, early_stopping,
validation_fraction, beta_1, beta_2, epsilon):
validation_fraction, beta_1, beta_2, epsilon,
n_iter_no_change):
self.activation = activation
self.solver = solver
self.alpha = alpha
@@ -74,6 +75,7 @@ def __init__(self, hidden_layer_sizes, activation, solver,
self.beta_1 = beta_1
self.beta_2 = beta_2
self.epsilon = epsilon
self.n_iter_no_change = n_iter_no_change

def _unpack(self, packed_parameters):
"""Extract the coefficients and intercepts from packed_parameters."""
@@ -415,6 +417,9 @@ def _validate_hyperparameters(self):
self.beta_2)
if self.epsilon <= 0.0:
raise ValueError("epsilon must be > 0, got %s." % self.epsilon)
if self.n_iter_no_change <= 0:
raise ValueError("n_iter_no_change must be > 0, got %s."
% self.n_iter_no_change)

# raise ValueError if not registered
supported_activations = ('identity', 'logistic', 'tanh', 'relu')
@@ -537,15 +542,17 @@ def _fit_stochastic(self, X, y, activations, deltas, coef_grads,
# for learning rate that needs to be updated at iteration end
self._optimizer.iteration_ends(self.t_)

if self._no_improvement_count > 2:
# not better than last two iterations by tol.
if self._no_improvement_count > self.n_iter_no_change:
# not better than last `n_iter_no_change` iterations by tol
# stop or decrease learning rate
if early_stopping:
msg = ("Validation score did not improve more than "
"tol=%f for two consecutive epochs." % self.tol)
"tol=%f for %d consecutive epochs." % (
self.tol, self.n_iter_no_change))
else:
msg = ("Training loss did not improve more than tol=%f"
" for two consecutive epochs." % self.tol)
" for %d consecutive epochs." % (
self.tol, self.n_iter_no_change))

is_stopping = self._optimizer.trigger_stopping(
msg, self.verbose)
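For context, a simplified paraphrase of the stopping rule this hunk adjusts; it mirrors the counting logic shown above but is a self-contained sketch, not the exact scikit-learn source::

    def should_stop(loss_curve, tol, n_iter_no_change):
        # Count consecutive epochs whose loss failed to beat the best loss
        # seen so far by more than tol; reset on any real improvement and
        # report the epoch at which the count first exceeds n_iter_no_change
        # (previously hardcoded to 2).
        best_loss = float('inf')
        no_improvement_count = 0
        for epoch, loss in enumerate(loss_curve):
            if loss > best_loss - tol:
                no_improvement_count += 1
            else:
                no_improvement_count = 0
            if loss < best_loss:
                best_loss = loss
            if no_improvement_count > n_iter_no_change:
                return epoch  # training would stop here
        return None  # the rule never triggered

For example, ``should_stop([1.0, 0.5, 0.5, 0.5, 0.5, 0.5], tol=0.01, n_iter_no_change=2)`` returns 4, while the same curve with ``n_iter_no_change=10`` returns None.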
@@ -780,9 +787,9 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):
tol : float, optional, default 1e-4
Tolerance for the optimization. When the loss or score is not improving
by at least tol for two consecutive iterations, unless `learning_rate`
is set to 'adaptive', convergence is considered to be reached and
training stops.
by at least ``tol`` for ``n_iter_no_change`` consecutive iterations,
unless ``learning_rate`` is set to 'adaptive', convergence is
considered to be reached and training stops.
verbose : bool, optional, default False
Whether to print progress messages to stdout.
@@ -804,8 +811,8 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):
Whether to use early stopping to terminate training when validation
score is not improving. If set to true, it will automatically set
aside 10% of training data as validation and terminate training when
validation score is not improving by at least tol for two consecutive
epochs.
validation score is not improving by at least tol for
``n_iter_no_change`` consecutive epochs.
Only effective when solver='sgd' or 'adam'
validation_fraction : float, optional, default 0.1
@@ -824,6 +831,12 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):
epsilon : float, optional, default 1e-8
Value for numerical stability in adam. Only used when solver='adam'
n_iter_no_change : int, optional, default 10
Maximum number of epochs to not meet ``tol`` improvement.
Only effective when solver='sgd' or 'adam'
.. versionadded:: 0.20
Attributes
----------
classes_ : array or list of array of shape (n_classes,)
@@ -890,7 +903,7 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
verbose=False, warm_start=False, momentum=0.9,
nesterovs_momentum=True, early_stopping=False,
validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
epsilon=1e-8):
epsilon=1e-8, n_iter_no_change=10):

sup = super(MLPClassifier, self)
sup.__init__(hidden_layer_sizes=hidden_layer_sizes,
@@ -903,7 +916,8 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
nesterovs_momentum=nesterovs_momentum,
early_stopping=early_stopping,
validation_fraction=validation_fraction,
beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)
beta_1=beta_1, beta_2=beta_2, epsilon=epsilon,
n_iter_no_change=n_iter_no_change)

def _validate_input(self, X, y, incremental):
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
@@ -1157,9 +1171,9 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
tol : float, optional, default 1e-4
Tolerance for the optimization. When the loss or score is not improving
by at least tol for two consecutive iterations, unless `learning_rate`
is set to 'adaptive', convergence is considered to be reached and
training stops.
by at least ``tol`` for ``n_iter_no_change`` consecutive iterations,
unless ``learning_rate`` is set to 'adaptive', convergence is
considered to be reached and training stops.
verbose : bool, optional, default False
Whether to print progress messages to stdout.
@@ -1181,8 +1195,8 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
Whether to use early stopping to terminate training when validation
score is not improving. If set to true, it will automatically set
aside 10% of training data as validation and terminate training when
validation score is not improving by at least tol for two consecutive
epochs.
validation score is not improving by at least ``tol`` for
``n_iter_no_change`` consecutive epochs.
Only effective when solver='sgd' or 'adam'
validation_fraction : float, optional, default 0.1
@@ -1201,6 +1215,12 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
epsilon : float, optional, default 1e-8
Value for numerical stability in adam. Only used when solver='adam'
n_iter_no_change : int, optional, default 10
Maximum number of epochs to not meet ``tol`` improvement.
Only effective when solver='sgd' or 'adam'
.. versionadded:: 0.20
Attributes
----------
loss_ : float
@@ -1265,7 +1285,7 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
verbose=False, warm_start=False, momentum=0.9,
nesterovs_momentum=True, early_stopping=False,
validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
epsilon=1e-8):
epsilon=1e-8, n_iter_no_change=10):

sup = super(MLPRegressor, self)
sup.__init__(hidden_layer_sizes=hidden_layer_sizes,
@@ -1278,7 +1298,8 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
nesterovs_momentum=nesterovs_momentum,
early_stopping=early_stopping,
validation_fraction=validation_fraction,
beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)
beta_1=beta_1, beta_2=beta_2, epsilon=epsilon,
n_iter_no_change=n_iter_no_change)

def predict(self, X):
"""Predict using the multi-layer perceptron model.
45 changes: 45 additions & 0 deletions sklearn/neural_network/tests/test_mlp.py
@@ -420,6 +420,7 @@ def test_params_errors():
assert_raises(ValueError, clf(beta_2=1).fit, X, y)
assert_raises(ValueError, clf(beta_2=-0.5).fit, X, y)
assert_raises(ValueError, clf(epsilon=-0.5).fit, X, y)
assert_raises(ValueError, clf(n_iter_no_change=-1).fit, X, y)

assert_raises(ValueError, clf(solver='hadoken').fit, X, y)
assert_raises(ValueError, clf(learning_rate='converge').fit, X, y)
@@ -588,3 +589,47 @@ def test_warm_start():
'classes as in the previous call to fit.'
' Previously got [0 1 2], `y` has %s' % np.unique(y_i))
assert_raise_message(ValueError, message, clf.fit, X, y_i)


def test_n_iter_no_change():
# test n_iter_no_change using binary data set
# the classifying fitting process is not prone to loss curve fluctuations
X = X_digits_binary[:100]
y = y_digits_binary[:100]
tol = 0.01
max_iter = 3000

# test multiple n_iter_no_change
for n_iter_no_change in [2, 5, 10, 50, 100]:
clf = MLPClassifier(tol=tol, max_iter=max_iter, solver='sgd',
n_iter_no_change=n_iter_no_change)
clf.fit(X, y)

# validate n_iter_no_change
assert_equal(clf._no_improvement_count, n_iter_no_change + 1)
assert_greater(max_iter, clf.n_iter_)


@ignore_warnings(category=ConvergenceWarning)
def test_n_iter_no_change_inf():
# test n_iter_no_change using binary data set
# the fitting process should go to max_iter iterations
X = X_digits_binary[:100]
y = y_digits_binary[:100]

# set a ridiculous tolerance
# this should always trigger _update_no_improvement_count()
tol = 1e9

# fit
n_iter_no_change = np.inf
max_iter = 3000
clf = MLPClassifier(tol=tol, max_iter=max_iter, solver='sgd',
n_iter_no_change=n_iter_no_change)
clf.fit(X, y)

# validate n_iter_no_change doesn't cause early stopping
assert_equal(clf.n_iter_, max_iter)

# validate _update_no_improvement_count() was always triggered
assert_equal(clf._no_improvement_count, clf.n_iter_ - 1)
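To see the behaviour described in the commit title, a rough sketch comparing the old hardcoded limit of 2 with the new default of 10; the dataset and exact iteration counts are assumptions and will vary with data and seed::

    from sklearn.datasets import make_regression
    from sklearn.neural_network import MLPRegressor

    # Noisy toy regression data; noise keeps the loss curve from improving
    # monotonically, which is what used to trigger the premature stop.
    X, y = make_regression(n_samples=200, n_features=10, noise=10.0,
                           random_state=0)

    for n in (2, 10):  # 2 mimics the old behaviour, 10 is the new default
        reg = MLPRegressor(solver='sgd', tol=1e-4, max_iter=2000,
                           n_iter_no_change=n, random_state=0)
        reg.fit(X, y)
        print(n, reg.n_iter_)  # the smaller limit typically stops sooner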
