Skip to content

Commit

Permalink
[MRG + 2] FIX LogisticRegressionCV to correctly handle string labels (s…
Browse files Browse the repository at this point in the history
…cikit-learn#5874)

* TST if LogisticRegressionCV handles string labels properly
* TST Add a test with class_weight dict
* ENH Encode y and class_weight dict
* Better variable names
* TYPO casses --> classes
* FIX Use dict comprehension; classes_labels --> classes
* Revert dict comprehension (for Python 2.6 compat)
* MNT reorder validation to improve clarity
* Add whatsnew entry
  • Loading branch information
raghavrv authored and maskani-moh committed Nov 15, 2017
1 parent a46d105 commit c19c27e
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 46 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -97,6 +97,10 @@ Bug fixes
attribute in `transform()`. :issue:`7553` by :user:`Ekaterina
Krivich <kiote>`.

- :class:`sklearn.linear_model.LogisticRegressionCV` now correctly handles
string labels. :issue:`5874` by `Raghav RV`_.


.. _changes_0_18_1:

Version 0.18.1
Expand Down
93 changes: 47 additions & 46 deletions sklearn/linear_model/logistic.py
@@ -1,4 +1,3 @@

"""
Logistic Regression
"""
Expand Down Expand Up @@ -28,7 +27,6 @@
from ..utils.extmath import row_norms
from ..utils.optimize import newton_cg
from ..utils.validation import check_X_y
from ..exceptions import DataConversionWarning
from ..exceptions import NotFittedError
from ..utils.fixes import expit
from ..utils.multiclass import check_classification_targets
Expand Down Expand Up @@ -925,9 +923,6 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
y_test = np.ones(y_test.shape, dtype=np.float64)
y_test[~mask] = -1.

# To deal with object dtypes, we need to convert into an array of floats.
y_test = check_array(y_test, dtype=np.float64, ensure_2d=False)

scores = list()

if isinstance(scoring, six.string_types):
Expand Down Expand Up @@ -1561,64 +1556,64 @@ def fit(self, X, y, sample_weight=None):

X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64,
order="C")
check_classification_targets(y)

class_weight = self.class_weight
if class_weight and not(isinstance(class_weight, dict) or
class_weight in ['balanced', 'auto']):
# 'auto' is deprecated and will be removed in 0.19
raise ValueError("class_weight provided should be a "
"dict or 'balanced'")

# Encode for string labels
label_encoder = LabelEncoder().fit(y)
y = label_encoder.transform(y)
if isinstance(class_weight, dict):
class_weight = dict((label_encoder.transform([cls])[0], v)
for cls, v in class_weight.items())

# The original class labels
classes = self.classes_ = label_encoder.classes_
encoded_labels = label_encoder.transform(label_encoder.classes_)

if self.solver == 'sag':
max_squared_sum = row_norms(X, squared=True).max()
else:
max_squared_sum = None

check_classification_targets(y)

if y.ndim == 2 and y.shape[1] == 1:
warnings.warn(
"A column-vector y was passed when a 1d array was"
" expected. Please change the shape of y to "
"(n_samples, ), for example using ravel().",
DataConversionWarning)
y = np.ravel(y)

check_consistent_length(X, y)

# init cross-validation generator
cv = check_cv(self.cv, y, classifier=True)
folds = list(cv.split(X, y))

self._enc = LabelEncoder()
self._enc.fit(y)

labels = self.classes_ = np.unique(y)
n_classes = len(labels)
# Use the label encoded classes
n_classes = len(encoded_labels)

if n_classes < 2:
raise ValueError("This solver needs samples of at least 2 classes"
" in the data, but the data contains only one"
" class: %r" % self.classes_[0])
" class: %r" % classes[0])

if n_classes == 2:
# OvR in case of binary problems is as good as fitting
# the higher label
n_classes = 1
labels = labels[1:]
encoded_labels = encoded_labels[1:]
classes = classes[1:]

# We need this hack to iterate only once over labels, in the case of
# multi_class = multinomial, without changing the value of the labels.
iter_labels = labels
if self.multi_class == 'multinomial':
iter_labels = [None]

if self.class_weight and not(isinstance(self.class_weight, dict) or
self.class_weight in
['balanced', 'auto']):
# 'auto' is deprecated and will be removed in 0.19
raise ValueError("class_weight provided should be a "
"dict or 'balanced'")
iter_encoded_labels = iter_classes = [None]
else:
iter_encoded_labels = encoded_labels
iter_classes = classes

# compute the class weights for the entire dataset y
if self.class_weight in ("auto", "balanced"):
classes = np.unique(y)
class_weight = compute_class_weight(self.class_weight, classes, y)
class_weight = dict(zip(classes, class_weight))
else:
class_weight = self.class_weight
if class_weight in ("auto", "balanced"):
class_weight = compute_class_weight(class_weight,
np.arange(len(self.classes_)),
y)
class_weight = dict(enumerate(class_weight))

path_func = delayed(_log_reg_scoring_path)

Expand All @@ -1638,7 +1633,7 @@ def fit(self, X, y, sample_weight=None):
max_squared_sum=max_squared_sum,
sample_weight=sample_weight
)
for label in iter_labels
for label in iter_encoded_labels
for train, test in folds)

if self.multi_class == 'multinomial':
Expand Down Expand Up @@ -1669,9 +1664,9 @@ def fit(self, X, y, sample_weight=None):
self.n_iter_ = np.reshape(n_iter_, (n_classes, len(folds),
len(self.Cs_)))

self.coefs_paths_ = dict(zip(labels, coefs_paths))
self.coefs_paths_ = dict(zip(classes, coefs_paths))
scores = np.reshape(scores, (n_classes, len(folds), -1))
self.scores_ = dict(zip(labels, scores))
self.scores_ = dict(zip(classes, scores))

self.C_ = list()
self.coef_ = np.empty((n_classes, X.shape[1]))
Expand All @@ -1682,10 +1677,14 @@ def fit(self, X, y, sample_weight=None):
scores = multi_scores
coefs_paths = multi_coefs_paths

for index, label in enumerate(iter_labels):
for index, (cls, encoded_label) in enumerate(
zip(iter_classes, iter_encoded_labels)):

if self.multi_class == 'ovr':
scores = self.scores_[label]
coefs_paths = self.coefs_paths_[label]
# The scores_ / coefs_paths_ dict have unencoded class
# labels as their keys
scores = self.scores_[cls]
coefs_paths = self.coefs_paths_[cls]

if self.refit:
best_index = scores.sum(axis=0).argmax()
Expand All @@ -1698,8 +1697,10 @@ def fit(self, X, y, sample_weight=None):
else:
coef_init = np.mean(coefs_paths[:, best_index, :], axis=0)

# Note that y is label encoded and hence pos_class must be
# the encoded label / None (for 'multinomial')
w, _, _ = logistic_regression_path(
X, y, pos_class=label, Cs=[C_], solver=self.solver,
X, y, pos_class=encoded_label, Cs=[C_], solver=self.solver,
fit_intercept=self.fit_intercept, coef=coef_init,
max_iter=self.max_iter, tol=self.tol,
penalty=self.penalty, copy=False,
Expand Down
39 changes: 39 additions & 0 deletions sklearn/linear_model/tests/test_logistic.py
Expand Up @@ -27,6 +27,7 @@
from sklearn.model_selection import StratifiedKFold
from sklearn.datasets import load_iris, make_classification
from sklearn.metrics import log_loss
from sklearn.preprocessing import LabelEncoder

X = [[-1, 0], [0, 1], [1, 1]]
X_sp = sp.csr_matrix(X)
Expand Down Expand Up @@ -398,6 +399,44 @@ def test_logistic_cv():
assert_array_equal(scores.shape, (1, 3, 1))


def test_multinomial_logistic_regression_string_inputs():
# Test with string labels for LogisticRegression(CV)
n_samples, n_features, n_classes = 50, 5, 3
X_ref, y = make_classification(n_samples=n_samples, n_features=n_features,
n_classes=n_classes, n_informative=3)
y_str = LabelEncoder().fit(['bar', 'baz', 'foo']).inverse_transform(y)
# For numerical labels, let y values be taken from set (-1, 0, 1)
y = np.array(y) - 1
# Test for string labels
lr = LogisticRegression(solver='lbfgs', multi_class='multinomial')
lr_cv = LogisticRegressionCV(solver='lbfgs', multi_class='multinomial')
lr_str = LogisticRegression(solver='lbfgs', multi_class='multinomial')
lr_cv_str = LogisticRegressionCV(solver='lbfgs', multi_class='multinomial')

lr.fit(X_ref, y)
lr_cv.fit(X_ref, y)
lr_str.fit(X_ref, y_str)
lr_cv_str.fit(X_ref, y_str)

assert_array_almost_equal(lr.coef_, lr_str.coef_)
assert_equal(sorted(lr_str.classes_), ['bar', 'baz', 'foo'])
assert_array_almost_equal(lr_cv.coef_, lr_cv_str.coef_)
assert_equal(sorted(lr_str.classes_), ['bar', 'baz', 'foo'])
assert_equal(sorted(lr_cv_str.classes_), ['bar', 'baz', 'foo'])

# The predictions should be in original labels
assert_equal(sorted(np.unique(lr_str.predict(X_ref))),
['bar', 'baz', 'foo'])
assert_equal(sorted(np.unique(lr_cv_str.predict(X_ref))),
['bar', 'baz', 'foo'])

# Make sure class weights can be given with string labels
lr_cv_str = LogisticRegression(
solver='lbfgs', class_weight={'bar': 1, 'baz': 2, 'foo': 0},
multi_class='multinomial').fit(X_ref, y_str)
assert_equal(sorted(np.unique(lr_cv_str.predict(X_ref))), ['bar', 'baz'])


def test_logistic_cv_sparse():
X, y = make_classification(n_samples=50, n_features=5,
random_state=0)
Expand Down

0 comments on commit c19c27e

Please sign in to comment.