ENH: make it possible to pass class_weight='auto' as constructor param for SGDClassifier
ogrisel committed Nov 8, 2011
1 parent 4538b7e commit 9c1c99b
Showing 5 changed files with 66 additions and 39 deletions.
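In practice, this change means the weighting strategy can be chosen when the estimator is constructed rather than only at fit time. A minimal sketch of the two call styles, assuming a scikit-learn version from this era of the API (`n_iter` and the `class_weight` fit parameter have since been removed upstream); the toy data is made up for illustration:

```python
import numpy as np
from sklearn.linear_model import SGDClassifier

# Made-up imbalanced toy data: 90 samples of class 0, 10 of class 1.
rng = np.random.RandomState(0)
X = np.r_[rng.randn(90, 2) - 1, rng.randn(10, 2) + 1]
y = np.r_[np.zeros(90), np.ones(10)]

# New in this commit: the weighting scheme as a constructor parameter.
clf = SGDClassifier(class_weight="auto", n_iter=1000)
clf.fit(X, y)

# The pre-existing fit parameter still works and overrides the
# constructor value when it is not None (see _set_class_weight below).
clf = SGDClassifier(class_weight=None, n_iter=1000)
clf.fit(X, y, class_weight="auto")
```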
7 changes: 5 additions & 2 deletions doc/modules/sgd.rst
@@ -38,7 +38,10 @@ The disadvantages of Stochastic Gradient Descent include:
Classification
==============

.. warning:: Make sure you permute (shuffle) your training data before fitting the model or use `shuffle=True` to shuffle after each iteration.
.. warning::

Make sure you permute (shuffle) your training data before fitting the
model or use `shuffle=True` to shuffle after each iteration.

The class :class:`SGDClassifier` implements a plain stochastic gradient
descent learning routine which supports different loss functions and
@@ -59,7 +62,7 @@ for the training samples::
>>> y = [0, 1]
>>> clf = SGDClassifier(loss="hinge", penalty="l2")
>>> clf.fit(X, y)
SGDClassifier(alpha=0.0001, eta0=0.0, fit_intercept=True,
SGDClassifier(alpha=0.0001, class_weight=None, eta0=0.0, fit_intercept=True,
learning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1,
penalty='l2', power_t=0.5, rho=1.0, seed=0, shuffle=False,
verbose=0)
27 changes: 14 additions & 13 deletions sklearn/linear_model/base.py
@@ -161,7 +161,7 @@ class BaseSGD(BaseEstimator):
def __init__(self, loss, penalty='l2', alpha=0.0001,
rho=0.85, fit_intercept=True, n_iter=5, shuffle=False,
verbose=0, seed=0, learning_rate="optimal", eta0=0.0,
power_t=0.5):
power_t=0.5, class_weight=None):
self.loss = str(loss)
self.penalty = str(penalty)
self.alpha = float(alpha)
@@ -185,6 +185,7 @@ def __init__(self, loss, penalty='l2', alpha=0.0001,
if self.learning_rate != "optimal":
if eta0 <= 0.0:
raise ValueError("eta0 must be greater than 0.0")
self.class_weight = class_weight

def _set_learning_rate(self, learning_rate):
learning_rate_codes = {"constant": 1, "optimal": 2, "invscaling": 3}
@@ -273,20 +274,20 @@ def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,


class BaseSGDClassifier(BaseSGD, ClassifierMixin):
"""Base class for dense and sparse classification using SGD.
"""
"""Base class for dense and sparse classification using SGD."""

def __init__(self, loss="hinge", penalty='l2', alpha=0.0001,
rho=0.85, fit_intercept=True, n_iter=5, shuffle=False,
verbose=0, n_jobs=1, seed=0, learning_rate="optimal",
eta0=0.0, power_t=0.5):
eta0=0.0, power_t=0.5, class_weight=None):
super(BaseSGDClassifier, self).__init__(loss=loss, penalty=penalty,
alpha=alpha, rho=rho,
fit_intercept=fit_intercept,
n_iter=n_iter, shuffle=shuffle,
verbose=verbose, seed=seed,
learning_rate=learning_rate,
eta0=eta0, power_t=power_t)
eta0=eta0, power_t=power_t,
class_weight=class_weight)
self.n_jobs = int(n_jobs)

def _set_loss_function(self, loss):
@@ -303,8 +304,9 @@ def _set_loss_function(self, loss):

def _set_class_weight(self, class_weight, classes, y):
"""Estimate class weights for unbalanced datasets."""
class_weight = {} if class_weight is None else class_weight
if class_weight == {}:
if class_weight is None:
class_weight = self.class_weight
if class_weight is None or len(class_weight) == 0:
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
elif class_weight == 'auto':
weight = np.array([1.0 / np.sum(y == i) for i in classes],
@@ -313,15 +315,16 @@
else:
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
if not isinstance(class_weight, dict):
raise ValueError("class_weight must be dict, 'auto', or None.")
raise ValueError("class_weight must be dict, 'auto', or None,"
" got: %r" % class_weight)
for c in class_weight:
i = np.searchsorted(classes, c)
if classes[i] != c:
raise ValueError("Class label %d not present." % c)
else:
weight[i] = class_weight[c]

self.class_weight = weight
self._expanded_class_weight = weight

def fit(self, X, y, coef_init=None, intercept_init=None,
class_weight=None, sample_weight=None):
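
The hunk above is the heart of the change: an explicit fit-time argument wins, then the constructor value, then uniform weights, with "auto" weighting each class by the inverse of its frequency in y. Restated as a standalone sketch of that precedence logic (the diff elides a few lines right after the "auto" branch, so any normalization step there is not shown):

```python
import numpy as np

def resolve_class_weight(fit_param, ctor_param, classes, y):
    """Sketch of the precedence logic from _set_class_weight above."""
    class_weight = fit_param if fit_param is not None else ctor_param
    if class_weight is None or len(class_weight) == 0:
        # Uniform: every class gets weight one.
        weight = np.ones(classes.shape[0], dtype=np.float64)
    elif class_weight == 'auto':
        # Inversely proportional to the class frequency in y.
        weight = np.array([1.0 / np.sum(y == c) for c in classes])
    elif isinstance(class_weight, dict):
        weight = np.ones(classes.shape[0], dtype=np.float64)
        for c, w in class_weight.items():
            i = np.searchsorted(classes, c)
            if classes[i] != c:
                raise ValueError("Class label %d not present." % c)
            weight[i] = w
    else:
        raise ValueError("class_weight must be dict, 'auto', or None,"
                         " got: %r" % class_weight)
    return weight

# resolve_class_weight(None, 'auto', np.array([-1, 1]),
#                      np.array([-1, -1, -1, 1]))  -> [0.333..., 1.0]
```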
@@ -454,8 +457,7 @@ def predict_proba(self, X):


class BaseSGDRegressor(BaseSGD, RegressorMixin):
"""Base class for dense and sparse regression using SGD.
"""
"""Base class for dense and sparse regression using SGD."""
def __init__(self, loss="squared_loss", penalty="l2", alpha=0.0001,
rho=0.85, fit_intercept=True, n_iter=5, shuffle=False,
verbose=0, p=0.1, seed=0, learning_rate="invscaling",
@@ -541,8 +543,7 @@ def predict(self, X):


class CoefSelectTransformerMixin(TransformerMixin):
"""Mixin for linear models that can find sparse solutions.
"""
"""Mixin for linear models that can find sparse solutions."""

def transform(self, X, threshold=1e-10):
if len(self.coef_.shape) == 1 or self.coef_.shape[1] == 1:
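The truncated `transform` above keeps only the input columns whose learned coefficients are non-negligible. The diff cuts off before the method body, so the following is a guess at the intent rather than the actual code; `coef_select_transform` is a hypothetical standalone helper:

```python
import numpy as np

def coef_select_transform(coef, X, threshold=1e-10):
    """Keep the features whose coefficient magnitude exceeds threshold."""
    if coef.ndim == 1 or coef.shape[1] == 1:
        # Binary / single-output case: one coefficient per feature.
        mask = np.abs(coef.ravel()) > threshold
    else:
        # Multiclass case: keep a feature if any class selects it.
        mask = (np.abs(coef) > threshold).any(axis=0)
    return X[:, mask]
```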
26 changes: 17 additions & 9 deletions sklearn/linear_model/sparse/stochastic_gradient.py
@@ -93,6 +93,14 @@ class SGDClassifier(BaseSGDClassifier):
power_t : double, optional
The exponent for inverse scaling learning rate [default 0.25].
class_weight : dict, {class_label : weight} or "auto" or None, optional
Preset for the class_weight fit parameter.
Weights associated with classes. If not given, all classes
are assumed to have weight one.
The "auto" mode uses the values of y to automatically adjust
weights inversely proportional to class frequencies.
Attributes
----------
@@ -115,7 +123,7 @@ class SGDClassifier(BaseSGDClassifier):
>>> y = np.array([1, 1, 2, 2])
>>> clf = linear_model.sparse.SGDClassifier()
>>> clf.fit(X, y)
SGDClassifier(alpha=0.0001, eta0=0.0, fit_intercept=True,
SGDClassifier(alpha=0.0001, class_weight=None, eta0=0.0, fit_intercept=True,
learning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1,
penalty='l2', power_t=0.5, rho=1.0, seed=0, shuffle=False,
verbose=0)
@@ -129,8 +137,7 @@ class SGDClassifier(BaseSGDClassifier):
"""

def _fit_binary(self, X, y):
"""Fit a binary classifier.
"""
"""Fit a binary classifier."""
X = _tocsr(X)

# encode original class labels as 1 (classes[1]) or -1 (classes[0]).
@@ -155,8 +162,8 @@ def _fit_binary(self, X, y):
int(self.verbose),
int(self.shuffle),
int(self.seed),
self.class_weight[1],
self.class_weight[0],
self._expanded_class_weight[1],
self._expanded_class_weight[0],
self.sample_weight,
self.learning_rate_code,
self.eta0, self.power_t)
@@ -166,9 +173,10 @@ def _fit_binary(self, X, y):
self.intercept_ = np.asarray(intercept_)

def _fit_multiclass(self, X, y):
"""Fit a multi-class classifier with a combination
of binary classifiers, each predicts one class versus
all others (OVA: One Versus All).
"""Fit a multi-class classifier as a combination of binary classifiers
Each binary classifier predicts one class versus all others
(OVA: One Versus All).
"""
X = _tocsr(X)

@@ -187,7 +195,7 @@ def _fit_multiclass(self, X, y):
self.fit_intercept,
self.verbose, self.shuffle,
self.seed,
self.class_weight[i],
self._expanded_class_weight[i],
self.sample_weight,
self.learning_rate_code,
self.eta0, self.power_t)
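The `_fit_multiclass` docstring above describes the One-Versus-All scheme: one binary classifier is trained per class, each weighted by its entry in `_expanded_class_weight`, and prediction picks the class with the largest decision value. A schematic illustration of the prediction side (not the Cython training loop the diff actually calls):

```python
import numpy as np

def ova_predict(coefs, intercepts, X):
    """One-Versus-All prediction: score X against every binary
    classifier and return the index of the highest-scoring class."""
    # coefs: (n_classes, n_features), intercepts: (n_classes,)
    scores = np.dot(X, coefs.T) + intercepts
    return np.argmax(scores, axis=1)
```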
16 changes: 12 additions & 4 deletions sklearn/linear_model/stochastic_gradient.py
@@ -85,6 +85,14 @@ class SGDClassifier(BaseSGDClassifier):
power_t : double
The exponent for inverse scaling learning rate [default 0.25].
class_weight : dict, {class_label : weight} or "auto" or None, optional
Preset for the class_weight fit parameter.
Weights associated with classes. If not given, all classes
are assumed to have weight one.
The "auto" mode uses the values of y to automatically adjust
weights inversely proportional to class frequencies.
Attributes
----------
@@ -103,7 +111,7 @@ class SGDClassifier(BaseSGDClassifier):
>>> Y = np.array([1, 1, 2, 2])
>>> clf = linear_model.SGDClassifier()
>>> clf.fit(X, Y)
SGDClassifier(alpha=0.0001, eta0=0.0, fit_intercept=True,
SGDClassifier(alpha=0.0001, class_weight=None, eta0=0.0, fit_intercept=True,
learning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1,
penalty='l2', power_t=0.5, rho=1.0, seed=0, shuffle=False,
verbose=0)
@@ -137,8 +145,8 @@ def _fit_binary(self, X, y):
int(self.verbose),
int(self.shuffle),
self.seed,
self.class_weight[1],
self.class_weight[0],
self._expanded_class_weight[1],
self._expanded_class_weight[0],
self.sample_weight,
self.learning_rate_code, self.eta0,
self.power_t)
@@ -164,7 +172,7 @@ def _fit_multiclass(self, X, y):
self.fit_intercept,
self.verbose, self.shuffle,
self.seed,
self.class_weight[i],
self._expanded_class_weight[i],
self.sample_weight,
self.learning_rate_code,
self.eta0, self.power_t)
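Both `_fit_binary` implementations pass the two per-class weights down to the underlying plain SGD routine, where each sample's update is scaled by the weight of its class. As a simplified illustration (not the actual `plain_sgd` Cython code), a pure-NumPy sketch of one epoch of L2-regularized hinge-loss SGD with class-weighted steps:

```python
import numpy as np

def hinge_sgd_epoch(w, b, X, y, class_weight, eta=0.01, alpha=0.0001):
    """One SGD epoch on the hinge loss; y is in {-1, +1} and
    class_weight maps {0: weight of the -1 class, 1: weight of the
    +1 class}, mirroring _expanded_class_weight[0] and [1] above."""
    for i in range(X.shape[0]):
        cw = class_weight[1] if y[i] == 1 else class_weight[0]
        margin = y[i] * (np.dot(w, X[i]) + b)
        if margin < 1:                    # hinge loss is active
            w += eta * cw * y[i] * X[i]   # class-weighted gradient step
            b += eta * cw * y[i]
        w *= 1.0 - eta * alpha            # L2 shrinkage
    return w, b
```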
29 changes: 18 additions & 11 deletions sklearn/linear_model/tests/test_sgd.py
@@ -222,19 +222,21 @@ def test_sgd_l1(self):
pred = clf.predict(X)
assert_array_equal(pred, Y)

def test_class_weight(self):
def test_class_weights(self):
"""
Test class weights.
"""
X = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0],
[1.0, 1.0], [1.0, 0.0]])
y = [1, 1, 1, -1, -1]

clf = self.factory(alpha=0.1, n_iter=1000, fit_intercept=False)
clf = self.factory(alpha=0.1, n_iter=1000, fit_intercept=False,
class_weight=None)
clf.fit(X, y)
assert_array_equal(clf.predict([[0.2, -1.0]]), np.array([1]))

# we give a small weight to class 1
clf = self.factory(alpha=0.1, n_iter=1000, fit_intercept=False)
clf.fit(X, y, class_weight={1: 0.001})

# now the hyperplane should rotate clockwise and
@@ -245,7 +247,7 @@ def test_equal_class_weight(self):
"""Test if equal class weights approx. equals no class weights. """
X = [[1, 0], [1, 0], [0, 1], [0, 1]]
y = [0, 0, 1, 1]
clf = self.factory(alpha=0.1, n_iter=1000)
clf = self.factory(alpha=0.1, n_iter=1000, class_weight=None)
clf.fit(X, y)

X = [[1, 0], [0, 1]]
@@ -279,12 +281,13 @@ def test_auto_weight(self):
np.random.shuffle(idx)
X = X[idx]
y = y[idx]
clf = self.factory(alpha=0.0001, n_iter=1000).fit(X, y)
clf = self.factory(alpha=0.0001, n_iter=1000,
class_weight=None).fit(X, y)
assert_approx_equal(metrics.f1_score(y, clf.predict(X)), 0.96, 2)

# make the same prediction using automated class_weight
clf_auto = self.factory(alpha=0.0001,
n_iter=1000).fit(X, y, class_weight="auto")
clf_auto = self.factory(alpha=0.0001, n_iter=1000,
class_weight="auto").fit(X, y)
assert_approx_equal(metrics.f1_score(y, clf_auto.predict(X)), 0.96, 2)

# Make sure that in the balanced case it does not change anything
@@ -299,21 +302,25 @@ def test_auto_weight(self):
y_imbalanced = np.concatenate([y] + [y_0] * 10)

# fit a model on the imbalanced data without class weight info
clf = self.factory(n_iter=1000)
clf = self.factory(n_iter=1000, class_weight=None)
clf.fit(X_imbalanced, y_imbalanced)
y_pred = clf.predict(X)
assert metrics.f1_score(y, y_pred) < 0.96

# fit a model with auto class_weight enabled
clf = self.factory(n_iter=1000)
clf = self.factory(n_iter=1000, class_weight="auto")
clf.fit(X_imbalanced, y_imbalanced)
y_pred = clf.predict(X)
assert metrics.f1_score(y, y_pred) > 0.96

# fit another using a fit parameter override
clf = self.factory(n_iter=1000, class_weight=None)
clf.fit(X_imbalanced, y_imbalanced, class_weight="auto")
y_pred = clf.predict(X)
assert metrics.f1_score(y, y_pred) > 0.96

def test_sample_weights(self):
"""
Test weights on individual samples
"""
"""Test weights on individual samples"""
X = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0],
[1.0, 1.0], [1.0, 0.0]])
y = [1, 1, 1, -1, -1]
