forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 16
/
_passive_aggressive.py
471 lines (372 loc) · 16.9 KB
/
_passive_aggressive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
# Authors: Rob Zinkov, Mathieu Blondel
# License: BSD 3 clause
from ..utils.validation import _deprecate_positional_args
from ._stochastic_gradient import BaseSGDClassifier
from ._stochastic_gradient import BaseSGDRegressor
from ._stochastic_gradient import DEFAULT_EPSILON
class PassiveAggressiveClassifier(BaseSGDClassifier):
    """Passive Aggressive Classifier

    Read more in the :ref:`User Guide <passive_aggressive>`.

    Parameters
    ----------
    C : float, default=1.0
        Maximum step size (regularization).

    fit_intercept : bool, default=True
        Whether the intercept should be estimated or not. If False, the
        data is assumed to be already centered.

    max_iter : int, default=1000
        The maximum number of passes over the training data (aka epochs).
        It only impacts the behavior in the ``fit`` method, and not the
        :meth:`partial_fit` method.

        .. versionadded:: 0.19

    tol : float or None, default=1e-3
        The stopping criterion. If it is not None, the iterations will stop
        when (loss > previous_loss - tol).

        .. versionadded:: 0.19

    early_stopping : bool, default=False
        Whether to use early stopping to terminate training when the
        validation score is not improving. If set to True, it will
        automatically set aside a stratified fraction of training data as
        validation and terminate training when validation score is not
        improving by at least tol for n_iter_no_change consecutive epochs.

        .. versionadded:: 0.20

    validation_fraction : float, default=0.1
        The proportion of training data to set aside as validation set for
        early stopping. Must be between 0 and 1.
        Only used if early_stopping is True.

        .. versionadded:: 0.20

    n_iter_no_change : int, default=5
        Number of iterations with no improvement to wait before early
        stopping.

        .. versionadded:: 0.20

    shuffle : bool, default=True
        Whether or not the training data should be shuffled after each
        epoch.

    verbose : int, default=0
        The verbosity level.

    loss : str, default="hinge"
        The loss function to be used:

        - hinge: equivalent to PA-I in the reference paper.
        - squared_hinge: equivalent to PA-II in the reference paper.

    n_jobs : int or None, default=None
        The number of CPUs to use to do the OVA (One Versus All, for
        multi-class problems) computation.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    random_state : int, RandomState instance, default=None
        Used to shuffle the training data, when ``shuffle`` is set to
        ``True``. Pass an int for reproducible output across multiple
        function calls.
        See :term:`Glossary <random_state>`.

    warm_start : bool, default=False
        When set to True, reuse the solution of the previous call to fit as
        initialization, otherwise, just erase the previous solution.
        See :term:`the Glossary <warm_start>`.

        Repeatedly calling fit or partial_fit when warm_start is True can
        result in a different solution than when calling fit a single time
        because of the way the data is shuffled.

    class_weight : dict, {class_label: weight} or "balanced", default=None
        Preset for the class_weight fit parameter.

        Weights associated with classes. If not given, all classes
        are supposed to have weight one.

        The "balanced" mode uses the values of y to automatically adjust
        weights inversely proportional to class frequencies in the input
        data as ``n_samples / (n_classes * np.bincount(y))``

        .. versionadded:: 0.17
           parameter *class_weight* to automatically weight samples.

    average : bool or int, default=False
        When set to True, computes the averaged SGD weights and stores the
        result in the ``coef_`` attribute. If set to an int greater than 1,
        averaging will begin once the total number of samples seen reaches
        average. So average=10 will begin averaging after seeing 10 samples.

        .. versionadded:: 0.19
           parameter *average* to use weights averaging in SGD

    Attributes
    ----------
    coef_ : ndarray of shape (1, n_features) or (n_classes, n_features)
        Weights assigned to the features. The first shape applies when
        ``n_classes == 2``, the second one otherwise.

    intercept_ : ndarray of shape (1,) or (n_classes,)
        Constants in decision function. The first shape applies when
        ``n_classes == 2``, the second one otherwise.

    n_iter_ : int
        The actual number of iterations to reach the stopping criterion.
        For multiclass fits, it is the maximum over every binary fit.

    classes_ : ndarray of shape (n_classes,)
        The unique classes labels.

    t_ : int
        Number of weight updates performed during training.
        Same as ``(n_iter_ * n_samples)``.

    loss_function_ : callable
        Loss function used by the algorithm.

    Examples
    --------
    >>> from sklearn.linear_model import PassiveAggressiveClassifier
    >>> from sklearn.datasets import make_classification

    >>> X, y = make_classification(n_features=4, random_state=0)
    >>> clf = PassiveAggressiveClassifier(max_iter=1000, random_state=0,
    ... tol=1e-3)
    >>> clf.fit(X, y)
    PassiveAggressiveClassifier(random_state=0)
    >>> print(clf.coef_)
    [[0.26642044 0.45070924 0.67251877 0.64185414]]
    >>> print(clf.intercept_)
    [1.84127814]
    >>> print(clf.predict([[0, 0, 0, 0]]))
    [1]

    See Also
    --------
    SGDClassifier
    Perceptron

    References
    ----------
    Online Passive-Aggressive Algorithms
    <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
    K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006)

    """
    @_deprecate_positional_args
    def __init__(self, *, C=1.0, fit_intercept=True, max_iter=1000, tol=1e-3,
                 early_stopping=False, validation_fraction=0.1,
                 n_iter_no_change=5, shuffle=True, verbose=0, loss="hinge",
                 n_jobs=None, random_state=None, warm_start=False,
                 class_weight=None, average=False):
        # Settings forwarded to the SGD base class. The penalty is switched
        # off and ``eta0`` is pinned to 1.0; the rest comes from the caller.
        sgd_settings = dict(
            penalty=None,
            eta0=1.0,
            fit_intercept=fit_intercept,
            max_iter=max_iter,
            tol=tol,
            shuffle=shuffle,
            verbose=verbose,
            n_jobs=n_jobs,
            random_state=random_state,
            early_stopping=early_stopping,
            validation_fraction=validation_fraction,
            n_iter_no_change=n_iter_no_change,
            class_weight=class_weight,
            warm_start=warm_start,
            average=average,
        )
        super().__init__(**sgd_settings)

        # Stored only once the base initializer has finished.
        self.C = C
        self.loss = loss

    def partial_fit(self, X, y, classes=None):
        """Fit linear model with Passive Aggressive algorithm.

        Runs a single pass (``max_iter=1``) over the given samples.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Subset of the training data.

        y : ndarray of shape (n_samples,)
            Subset of the target values.

        classes : ndarray of shape (n_classes,), default=None
            Classes across all calls to partial_fit.
            Can be obtained via `np.unique(y_all)`, where y_all is the
            target vector of the entire dataset.
            This argument is required for the first call to partial_fit
            and can be omitted in the subsequent calls.
            Note that y doesn't need to contain all labels in `classes`.

        Returns
        -------
        self : returns an instance of self.

        Raises
        ------
        ValueError
            If ``class_weight`` is set to ``'balanced'``.
        """
        self._validate_params(for_partial_fit=True)

        if self.class_weight == 'balanced':
            message = ("class_weight 'balanced' is not supported for "
                       "partial_fit. For 'balanced' weights, use "
                       "`sklearn.utils.compute_class_weight` with "
                       "`class_weight='balanced'`. In place of y you "
                       "can use a large enough subset of the full "
                       "training set target to properly estimate the "
                       "class frequency distributions. Pass the "
                       "resulting weights as the class_weight "
                       "parameter.")
            raise ValueError(message)

        # "hinge" selects the PA-I update rule; any other loss selects PA-II.
        if self.loss == "hinge":
            pa_variant = "pa1"
        else:
            pa_variant = "pa2"

        return self._partial_fit(
            X, y,
            classes=classes,
            loss="hinge", learning_rate=pa_variant,
            alpha=1.0, C=self.C, max_iter=1,
            sample_weight=None,
            coef_init=None, intercept_init=None)

    def fit(self, X, y, coef_init=None, intercept_init=None):
        """Fit linear model with Passive Aggressive algorithm.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training data.

        y : ndarray of shape (n_samples,)
            Target values.

        coef_init : ndarray of shape (n_classes, n_features), default=None
            The initial coefficients to warm-start the optimization.

        intercept_init : ndarray of shape (n_classes,), default=None
            The initial intercept to warm-start the optimization.

        Returns
        -------
        self : returns an instance of self.
        """
        self._validate_params()

        # "hinge" selects the PA-I update rule; any other loss selects PA-II.
        if self.loss == "hinge":
            pa_variant = "pa1"
        else:
            pa_variant = "pa2"

        return self._fit(
            X, y,
            loss="hinge", learning_rate=pa_variant,
            alpha=1.0, C=self.C,
            coef_init=coef_init, intercept_init=intercept_init)
class PassiveAggressiveRegressor(BaseSGDRegressor):
    """Passive Aggressive Regressor

    Read more in the :ref:`User Guide <passive_aggressive>`.

    Parameters
    ----------
    C : float, default=1.0
        Maximum step size (regularization).

    fit_intercept : bool, default=True
        Whether the intercept should be estimated or not. If False, the
        data is assumed to be already centered.

    max_iter : int, default=1000
        The maximum number of passes over the training data (aka epochs).
        It only impacts the behavior in the ``fit`` method, and not the
        :meth:`partial_fit` method.

        .. versionadded:: 0.19

    tol : float or None, default=1e-3
        The stopping criterion. If it is not None, the iterations will stop
        when (loss > previous_loss - tol).

        .. versionadded:: 0.19

    early_stopping : bool, default=False
        Whether to use early stopping to terminate training when the
        validation score is not improving. If set to True, it will
        automatically set aside a fraction of training data as validation
        and terminate training when validation score is not improving by at
        least tol for n_iter_no_change consecutive epochs.

        .. versionadded:: 0.20

    validation_fraction : float, default=0.1
        The proportion of training data to set aside as validation set for
        early stopping. Must be between 0 and 1.
        Only used if early_stopping is True.

        .. versionadded:: 0.20

    n_iter_no_change : int, default=5
        Number of iterations with no improvement to wait before early
        stopping.

        .. versionadded:: 0.20

    shuffle : bool, default=True
        Whether or not the training data should be shuffled after each
        epoch.

    verbose : int, default=0
        The verbosity level.

    loss : str, default="epsilon_insensitive"
        The loss function to be used:

        - epsilon_insensitive: equivalent to PA-I in the reference paper.
        - squared_epsilon_insensitive: equivalent to PA-II in the reference
          paper.

    epsilon : float, default=DEFAULT_EPSILON
        If the difference between the current prediction and the correct
        label is below this threshold, the model is not updated.

    random_state : int, RandomState instance, default=None
        Used to shuffle the training data, when ``shuffle`` is set to
        ``True``. Pass an int for reproducible output across multiple
        function calls.
        See :term:`Glossary <random_state>`.

    warm_start : bool, default=False
        When set to True, reuse the solution of the previous call to fit as
        initialization, otherwise, just erase the previous solution.
        See :term:`the Glossary <warm_start>`.

        Repeatedly calling fit or partial_fit when warm_start is True can
        result in a different solution than when calling fit a single time
        because of the way the data is shuffled.

    average : bool or int, default=False
        When set to True, computes the averaged SGD weights and stores the
        result in the ``coef_`` attribute. If set to an int greater than 1,
        averaging will begin once the total number of samples seen reaches
        average. So average=10 will begin averaging after seeing 10 samples.

        .. versionadded:: 0.19
           parameter *average* to use weights averaging in SGD

    Attributes
    ----------
    coef_ : ndarray of shape (n_features,)
        Weights assigned to the features.

    intercept_ : ndarray of shape (1,)
        Constants in decision function.

    n_iter_ : int
        The actual number of iterations to reach the stopping criterion.

    t_ : int
        Number of weight updates performed during training.
        Same as ``(n_iter_ * n_samples)``.

    Examples
    --------
    >>> from sklearn.linear_model import PassiveAggressiveRegressor
    >>> from sklearn.datasets import make_regression

    >>> X, y = make_regression(n_features=4, random_state=0)
    >>> regr = PassiveAggressiveRegressor(max_iter=100, random_state=0,
    ... tol=1e-3)
    >>> regr.fit(X, y)
    PassiveAggressiveRegressor(max_iter=100, random_state=0)
    >>> print(regr.coef_)
    [20.48736655 34.18818427 67.59122734 87.94731329]
    >>> print(regr.intercept_)
    [-0.02306214]
    >>> print(regr.predict([[0, 0, 0, 0]]))
    [-0.02306214]

    See Also
    --------
    SGDRegressor

    References
    ----------
    Online Passive-Aggressive Algorithms
    <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
    K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006)

    """
    @_deprecate_positional_args
    def __init__(self, *, C=1.0, fit_intercept=True, max_iter=1000, tol=1e-3,
                 early_stopping=False, validation_fraction=0.1,
                 n_iter_no_change=5, shuffle=True, verbose=0,
                 loss="epsilon_insensitive", epsilon=DEFAULT_EPSILON,
                 random_state=None, warm_start=False,
                 average=False):
        # Settings forwarded to the SGD base class. The penalty is switched
        # off, ``l1_ratio`` is zeroed and ``eta0`` is pinned to 1.0; the rest
        # comes from the caller.
        sgd_settings = dict(
            penalty=None,
            l1_ratio=0,
            eta0=1.0,
            epsilon=epsilon,
            fit_intercept=fit_intercept,
            max_iter=max_iter,
            tol=tol,
            shuffle=shuffle,
            verbose=verbose,
            random_state=random_state,
            early_stopping=early_stopping,
            validation_fraction=validation_fraction,
            n_iter_no_change=n_iter_no_change,
            warm_start=warm_start,
            average=average,
        )
        super().__init__(**sgd_settings)

        # Stored only once the base initializer has finished.
        self.C = C
        self.loss = loss

    def partial_fit(self, X, y):
        """Fit linear model with Passive Aggressive algorithm.

        Runs a single pass (``max_iter=1``) over the given samples.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Subset of training data.

        y : ndarray of shape (n_samples,)
            Subset of target values.

        Returns
        -------
        self : returns an instance of self.
        """
        self._validate_params(for_partial_fit=True)

        # "epsilon_insensitive" selects the PA-I update rule; any other loss
        # selects PA-II.
        if self.loss == "epsilon_insensitive":
            pa_variant = "pa1"
        else:
            pa_variant = "pa2"

        return self._partial_fit(
            X, y,
            loss="epsilon_insensitive", learning_rate=pa_variant,
            alpha=1.0, C=self.C, max_iter=1,
            sample_weight=None,
            coef_init=None, intercept_init=None)

    def fit(self, X, y, coef_init=None, intercept_init=None):
        """Fit linear model with Passive Aggressive algorithm.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training data.

        y : ndarray of shape (n_samples,)
            Target values.

        coef_init : ndarray of shape (n_features,), default=None
            The initial coefficients to warm-start the optimization.

        intercept_init : ndarray of shape (1,), default=None
            The initial intercept to warm-start the optimization.

        Returns
        -------
        self : returns an instance of self.
        """
        self._validate_params()

        # "epsilon_insensitive" selects the PA-I update rule; any other loss
        # selects PA-II.
        if self.loss == "epsilon_insensitive":
            pa_variant = "pa1"
        else:
            pa_variant = "pa2"

        return self._fit(
            X, y,
            loss="epsilon_insensitive", learning_rate=pa_variant,
            alpha=1.0, C=self.C,
            coef_init=coef_init, intercept_init=intercept_init)