Skip to content

Commit

Permalink
fix multi-class objective (softmax) (#3256)
Browse files Browse the repository at this point in the history
* Update multiclass_objective.hpp

* Apply suggestions from code review

* Update src/objective/multiclass_objective.hpp

* Apply suggestions from code review

* Update test_basic.R

* Update test_basic.R

* Update src/objective/multiclass_objective.hpp

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Apply suggestions from code review

Co-authored-by: James Lamb <jaylamb20@gmail.com>
  • Loading branch information
guolinke and jameslamb committed Aug 5, 2020
1 parent b5027de commit 4f28233
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
6 changes: 3 additions & 3 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ test_that("train and predict softmax", {
data = as.matrix(iris[, -5L])
, label = lb
, num_leaves = 4L
, learning_rate = 0.1
, learning_rate = 0.05
, nrounds = 20L
, min_data = 20L
, min_hessian = 20.0
, min_hessian = 10.0
, objective = "multiclass"
, metric = "multi_error"
, num_class = 3L
Expand All @@ -53,7 +53,7 @@ test_that("train and predict softmax", {

expect_false(is.null(bst$record_evals))
record_results <- lgb.get.eval.result(bst, "train", "multi_error")
expect_lt(min(record_results), 0.05)
expect_lt(min(record_results), 0.06)

pred <- predict(bst, as.matrix(iris[, -5L]))
expect_equal(length(pred), nrow(iris) * 3L)
Expand Down
10 changes: 8 additions & 2 deletions src/objective/multiclass_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ class MulticlassSoftmax: public ObjectiveFunction {
public:
explicit MulticlassSoftmax(const Config& config) {
num_class_ = config.num_class;
// This factor is to rescale the redundant form of K-classification, to the non-redundant form.
// In the traditional settings of K-classification, there is one redundant class, whose output is set to 0 (like the class 0 in binary classification).
// This is from the Friedman GBDT paper.
factor_ = static_cast<double>(num_class_) / (num_class_ - 1.0f);
}

explicit MulticlassSoftmax(const std::vector<std::string>& strs) {
Expand All @@ -40,6 +44,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
if (num_class_ < 0) {
Log::Fatal("Objective should contain num_class field");
}
factor_ = static_cast<double>(num_class_) / (num_class_ - 1.0f);
}

~MulticlassSoftmax() {
Expand Down Expand Up @@ -97,7 +102,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
} else {
gradients[idx] = static_cast<score_t>(p);
}
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p));
hessians[idx] = static_cast<score_t>(factor_ * p * (1.0f - p));
}
}
} else {
Expand All @@ -118,7 +123,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
} else {
gradients[idx] = static_cast<score_t>((p) * weights_[i]);
}
hessians[idx] = static_cast<score_t>((2.0f * p * (1.0f - p))* weights_[i]);
hessians[idx] = static_cast<score_t>((factor_ * p * (1.0f - p))* weights_[i]);
}
}
}
Expand Down Expand Up @@ -161,6 +166,7 @@ class MulticlassSoftmax: public ObjectiveFunction {
}

private:
double factor_;
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Number of classes */
Expand Down

0 comments on commit 4f28233

Please sign in to comment.