Skip to content

Commit

Permalink
Add support for multi_class=='multinomial'
Browse files Browse the repository at this point in the history
  • Loading branch information
massich committed May 30, 2017
1 parent e6a3c23 commit 4ac33e8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
10 changes: 6 additions & 4 deletions sklearn/linear_model/logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,12 @@ def _multinomial_loss_grad(w, X, Y, alpha, sample_weight):
n_classes = Y.shape[1]
n_features = X.shape[1]
fit_intercept = (w.size == n_classes * (n_features + 1))
grad = np.zeros((n_classes, n_features + bool(fit_intercept)))
grad = np.zeros((n_classes, n_features + bool(fit_intercept)),
dtype=X.dtype)
loss, p, w = _multinomial_loss(w, X, Y, alpha, sample_weight)
sample_weight = sample_weight[:, np.newaxis]
diff = sample_weight * (p - Y)
diff = diff.astype(X.dtype)
grad[:, :n_features] = safe_sparse_dot(diff.T, X)
grad[:, :n_features] += alpha * w
if fit_intercept:
Expand Down Expand Up @@ -608,10 +610,10 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
# and check length
# Otherwise set them to 1 for all examples
if sample_weight is not None:
sample_weight = np.array(sample_weight, dtype=np.float64, order='C')
sample_weight = np.array(sample_weight, dtype=X.dtype, order='C')
check_consistent_length(y, sample_weight)
else:
sample_weight = np.ones(X.shape[0])
sample_weight = np.ones(X.shape[0], dtype=X.dtype)

# If class_weights is a dict (provided by the user), the weights
# are assigned to the original labels. If it is "balanced", then
Expand Down Expand Up @@ -648,7 +650,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
Y_multi = le.fit_transform(y)

w0 = np.zeros((classes.size, n_features + int(fit_intercept)),
order='F')
order='F', dtype=X.dtype)

if coef is not None:
# it must work both giving the bias term and not
Expand Down
8 changes: 4 additions & 4 deletions sklearn/utils/class_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def compute_class_weight(class_weight, classes, y):
"be in y")
if class_weight is None or len(class_weight) == 0:
# uniform class weights
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
weight = np.ones(classes.shape[0], dtype=y.dtype, order='C')
elif class_weight == 'balanced':
# Find the weight of each class as present in y.
le = LabelEncoder()
Expand All @@ -55,11 +55,11 @@ def compute_class_weight(class_weight, classes, y):
raise ValueError("classes should have valid labels that are in y")

recip_freq = len(y) / (len(le.classes_) *
bincount(y_ind).astype(np.float64))
bincount(y_ind).astype(y.dtype))
weight = recip_freq[le.transform(classes)]
else:
# user-defined dictionary
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
weight = np.ones(classes.shape[0], dtype=y.dtype, order='C')
if not isinstance(class_weight, dict):
raise ValueError("class_weight must be dict, 'balanced', or None,"
" got: %r" % class_weight)
Expand Down Expand Up @@ -176,6 +176,6 @@ def compute_sample_weight(class_weight, y, indices=None):

expanded_class_weight = np.prod(expanded_class_weight,
axis=0,
dtype=np.float64)
dtype=y.dtype)

return expanded_class_weight

0 comments on commit 4ac33e8

Please sign in to comment.