Skip to content

Commit

Permalink
[Breaking] Change default evaluation metric for classification to logloss / mlogloss (#6183)
Browse files Browse the repository at this point in the history

* Change DefaultEvalMetric of classification from error to logloss

* Change default binary metric in plugin/example/custom_obj.cc

* Set old error metric in python tests

* Set old error metric in R tests

* Fix missed eval metrics and typos in R tests

* Fix setting eval_metric twice in R tests

* Add warning for empty eval_metric for classification

* Fix Dask tests

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
  • Loading branch information
lorentzenchr and hcho3 committed Oct 2, 2020
1 parent e0e4f15 commit cf4f019
Show file tree
Hide file tree
Showing 18 changed files with 56 additions and 32 deletions.
16 changes: 9 additions & 7 deletions R-package/tests/testthat/test_basic.R
Expand Up @@ -17,7 +17,8 @@ test_that("train and predict binary classification", {
nrounds <- 2
expect_output(
bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
eta = 1, nthread = 2, nrounds = nrounds, objective = "binary:logistic")
eta = 1, nthread = 2, nrounds = nrounds, objective = "binary:logistic",
eval_metric = "error")
, "train-error")
expect_equal(class(bst), "xgb.Booster")
expect_equal(bst$niter, nrounds)
Expand Down Expand Up @@ -122,7 +123,7 @@ test_that("train and predict softprob", {
expect_output(
bst <- xgboost(data = as.matrix(iris[, -5]), label = lb,
max_depth = 3, eta = 0.5, nthread = 2, nrounds = 5,
objective = "multi:softprob", num_class = 3)
objective = "multi:softprob", num_class = 3, eval_metric = "merror")
, "train-merror")
expect_false(is.null(bst$evaluation_log))
expect_lt(bst$evaluation_log[, min(train_merror)], 0.025)
Expand Down Expand Up @@ -150,7 +151,7 @@ test_that("train and predict softmax", {
expect_output(
bst <- xgboost(data = as.matrix(iris[, -5]), label = lb,
max_depth = 3, eta = 0.5, nthread = 2, nrounds = 5,
objective = "multi:softmax", num_class = 3)
objective = "multi:softmax", num_class = 3, eval_metric = "merror")
, "train-merror")
expect_false(is.null(bst$evaluation_log))
expect_lt(bst$evaluation_log[, min(train_merror)], 0.025)
Expand All @@ -167,7 +168,7 @@ test_that("train and predict RF", {
lb <- train$label
# single iteration
bst <- xgboost(data = train$data, label = lb, max_depth = 5,
nthread = 2, nrounds = 1, objective = "binary:logistic",
nthread = 2, nrounds = 1, objective = "binary:logistic", eval_metric = "error",
num_parallel_tree = 20, subsample = 0.6, colsample_bytree = 0.1)
expect_equal(bst$niter, 1)
expect_equal(xgb.ntree(bst), 20)
Expand All @@ -193,7 +194,8 @@ test_that("train and predict RF with softprob", {
set.seed(11)
bst <- xgboost(data = as.matrix(iris[, -5]), label = lb,
max_depth = 3, eta = 0.9, nthread = 2, nrounds = nrounds,
objective = "multi:softprob", num_class = 3, verbose = 0,
objective = "multi:softprob", eval_metric = "merror",
num_class = 3, verbose = 0,
num_parallel_tree = 4, subsample = 0.5, colsample_bytree = 0.5)
expect_equal(bst$niter, 15)
expect_equal(xgb.ntree(bst), 15 * 3 * 4)
Expand Down Expand Up @@ -274,7 +276,7 @@ test_that("xgb.cv works", {
expect_output(
cv <- xgb.cv(data = train$data, label = train$label, max_depth = 2, nfold = 5,
eta = 1., nthread = 2, nrounds = 2, objective = "binary:logistic",
verbose = TRUE)
eval_metric = "error", verbose = TRUE)
, "train-error:")
expect_is(cv, 'xgb.cv.synchronous')
expect_false(is.null(cv$evaluation_log))
Expand All @@ -299,7 +301,7 @@ test_that("xgb.cv works with stratified folds", {
eta = 1., nthread = 2, nrounds = 2, objective = "binary:logistic",
verbose = TRUE, stratified = TRUE)
# Stratified folds should result in different evaluation logs
expect_true(all(cv$evaluation_log[, test_error_mean] != cv2$evaluation_log[, test_error_mean]))
expect_true(all(cv$evaluation_log[, test_logloss_mean] != cv2$evaluation_log[, test_logloss_mean]))
})

test_that("train and predict with non-strict classes", {
Expand Down
8 changes: 5 additions & 3 deletions R-package/tests/testthat/test_callbacks.R
Expand Up @@ -26,7 +26,8 @@ watchlist <- list(train = dtrain, test = dtest)

err <- function(label, pr) sum((pr > 0.5) != label) / length(label)

param <- list(objective = "binary:logistic", max_depth = 2, nthread = 2)
param <- list(objective = "binary:logistic", eval_metric = "error",
max_depth = 2, nthread = 2)


test_that("cb.print.evaluation works as expected", {
Expand Down Expand Up @@ -105,7 +106,8 @@ test_that("cb.evaluation.log works as expected", {
})


param <- list(objective = "binary:logistic", max_depth = 4, nthread = 2)
param <- list(objective = "binary:logistic", eval_metric = "error",
max_depth = 4, nthread = 2)

test_that("can store evaluation_log without printing", {
expect_silent(
Expand Down Expand Up @@ -236,7 +238,7 @@ test_that("early stopping xgb.train works", {
test_that("early stopping using a specific metric works", {
set.seed(11)
expect_output(
bst <- xgb.train(param, dtrain, nrounds = 20, watchlist, eta = 0.6,
bst <- xgb.train(param[-2], dtrain, nrounds = 20, watchlist, eta = 0.6,
eval_metric = "logloss", eval_metric = "auc",
callbacks = list(cb.early.stop(stopping_rounds = 3, maximize = FALSE,
metric_name = 'test_logloss')))
Expand Down
2 changes: 1 addition & 1 deletion R-package/tests/testthat/test_glm.R
Expand Up @@ -8,7 +8,7 @@ test_that("gblinear works", {
dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label)
dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)

param <- list(objective = "binary:logistic", booster = "gblinear",
param <- list(objective = "binary:logistic", eval_metric = "error", booster = "gblinear",
nthread = 2, eta = 0.8, alpha = 0.0001, lambda = 0.0001)
watchlist <- list(eval = dtest, train = dtrain)

Expand Down
3 changes: 2 additions & 1 deletion demo/guide-python/custom_softmax.py
Expand Up @@ -142,7 +142,8 @@ def main(args):

native_results = {}
# Use the same objective function defined in XGBoost.
booster_native = xgb.train({'num_class': kClasses},
booster_native = xgb.train({'num_class': kClasses,
'eval_metric': 'merror'},
m,
num_boost_round=kRounds,
evals_result=native_results,
Expand Down
2 changes: 1 addition & 1 deletion doc/parameter.rst
Expand Up @@ -376,7 +376,7 @@ Specify the learning task and the corresponding learning objective. The objectiv

* ``eval_metric`` [default according to objective]

- Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and error for classification, mean average precision for ranking)
- Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, mean average precision for ranking)
- User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous one
- The choices are listed below:

Expand Down
Expand Up @@ -154,10 +154,10 @@ class XGBoostClassifier (
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
if ($(objective).startsWith("multi")) {
// multi
"merror"
"mlogloss"
} else {
// binary
"error"
"logloss"
}
}

Expand Down
2 changes: 1 addition & 1 deletion plugin/example/custom_obj.cc
Expand Up @@ -56,7 +56,7 @@ class MyLogistic : public ObjFunction {
}
}
const char* DefaultEvalMetric() const override {
return "error";
return "logloss";
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
// transform margin value to probability.
Expand Down
2 changes: 1 addition & 1 deletion plugin/updater_oneapi/regression_loss_oneapi.h
Expand Up @@ -103,7 +103,7 @@ struct LogisticRegressionOneAPI {

// logistic loss for binary classification task
struct LogisticClassificationOneAPI : public LogisticRegressionOneAPI {
static const char* DefaultEvalMetric() { return "error"; }
static const char* DefaultEvalMetric() { return "logloss"; }
static const char* Name() { return "binary:logistic_oneapi"; }
};

Expand Down
12 changes: 12 additions & 0 deletions src/learner.cc
Expand Up @@ -1031,6 +1031,18 @@ class LearnerImpl : public LearnerIO {
std::ostringstream os;
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
auto warn_default_eval_metric = [](const std::string& objective, const std::string& before,
const std::string& after) {
LOG(WARNING) << "Starting in XGBoost 1.3.0, the default evaluation metric used with the "
<< "objective '" << objective << "' was changed from '" << before
<< "' to '" << after << "'. Explicitly set eval_metric if you'd like to "
<< "restore the old behavior.";
};
if (tparam_.objective == "binary:logistic") {
warn_default_eval_metric(tparam_.objective, "error", "logloss");
} else if ((tparam_.objective == "multi:softmax" || tparam_.objective == "multi:softprob")) {
warn_default_eval_metric(tparam_.objective, "merror", "mlogloss");
}
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &generic_parameters_));
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
}
Expand Down
2 changes: 1 addition & 1 deletion src/objective/multiclass_obj.cu
Expand Up @@ -125,7 +125,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
this->Transform(io_preds, true);
}
const char* DefaultEvalMetric() const override {
return "merror";
return "mlogloss";
}

inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) {
Expand Down
2 changes: 1 addition & 1 deletion src/objective/regression_loss.h
Expand Up @@ -131,7 +131,7 @@ struct PseudoHuberError {

// logistic loss for binary classification task
struct LogisticClassification : public LogisticRegression {
static const char* DefaultEvalMetric() { return "error"; }
static const char* DefaultEvalMetric() { return "logloss"; }
static const char* Name() { return "binary:logistic"; }
};

Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/plugin/test_example_objective.cc
Expand Up @@ -8,7 +8,7 @@ namespace xgboost {
TEST(Plugin, ExampleObjective) {
xgboost::GenericParameter tparam = CreateEmptyGenericParam(GPUIDX);
auto * obj = xgboost::ObjFunction::Create("mylogistic", &tparam);
ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"error"});
ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"logloss"});
delete obj;
}

Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_basic.py
Expand Up @@ -81,7 +81,7 @@ def test_record_results(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
'objective': 'binary:logistic'}
'objective': 'binary:logistic', 'eval_metric': 'error'}
# specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 2
Expand Down
8 changes: 5 additions & 3 deletions tests/python/test_basic_models.py
Expand Up @@ -117,7 +117,8 @@ def run_eta_decay(self, tree_method):
# learning_rates as a list
# init eta with 0 to check whether learning_rates work
param = {'max_depth': 2, 'eta': 0, 'verbosity': 0,
'objective': 'binary:logistic', 'tree_method': tree_method}
'objective': 'binary:logistic', 'eval_metric': 'error',
'tree_method': tree_method}
evals_result = {}
bst = xgb.train(param, dtrain, num_round, watchlist,
callbacks=[xgb.callback.reset_learning_rate([
Expand All @@ -131,7 +132,8 @@ def run_eta_decay(self, tree_method):

# init learning_rate with 0 to check whether learning_rates work
param = {'max_depth': 2, 'learning_rate': 0, 'verbosity': 0,
'objective': 'binary:logistic', 'tree_method': tree_method}
'objective': 'binary:logistic', 'eval_metric': 'error',
'tree_method': tree_method}
evals_result = {}
bst = xgb.train(param, dtrain, num_round, watchlist,
callbacks=[xgb.callback.reset_learning_rate(
Expand All @@ -145,7 +147,7 @@ def run_eta_decay(self, tree_method):
# check if learning_rates override default value of eta/learning_rate
param = {
'max_depth': 2, 'verbosity': 0, 'objective': 'binary:logistic',
'tree_method': tree_method
'eval_metric': 'error', 'tree_method': tree_method
}
evals_result = {}
bst = xgb.train(param, dtrain, num_round, watchlist,
Expand Down
12 changes: 8 additions & 4 deletions tests/python/test_dmatrix.py
Expand Up @@ -115,7 +115,9 @@ def test_slice(self):

eval_res_0 = {}
booster = xgb.train(
{'num_class': 3, 'objective': 'multi:softprob'}, d,
{'num_class': 3, 'objective': 'multi:softprob',
'eval_metric': 'merror'},
d,
num_boost_round=2, evals=[(d, 'd')], evals_result=eval_res_0)

predt = booster.predict(d)
Expand All @@ -130,9 +132,11 @@ def test_slice(self):
assert sliced_margin.shape[0] == len(ridxs) * 3

eval_res_1 = {}
xgb.train({'num_class': 3, 'objective': 'multi:softprob'}, sliced,
num_boost_round=2, evals=[(sliced, 'd')],
evals_result=eval_res_1)
xgb.train(
{'num_class': 3, 'objective': 'multi:softprob',
'eval_metric': 'merror'},
sliced,
num_boost_round=2, evals=[(sliced, 'd')], evals_result=eval_res_1)

eval_res_0 = eval_res_0['d']['merror']
eval_res_1 = eval_res_1['d']['merror']
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_early_stopping.py
Expand Up @@ -58,7 +58,7 @@ def test_cv_early_stopping(self):
y = digits['target']
dm = xgb.DMatrix(X, label=y)
params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
'objective': 'binary:logistic'}
'objective': 'binary:logistic', 'eval_metric': 'error'}

cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
early_stopping_rounds=10)
Expand Down
5 changes: 3 additions & 2 deletions tests/python/test_with_dask.py
Expand Up @@ -274,7 +274,7 @@ def test_dask_classifier():
X, y = generate_array()
y = (y * 10).astype(np.int32)
classifier = xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2)
verbosity=1, n_estimators=2, eval_metric='merror')
classifier.client = client
classifier.fit(X, y, eval_set=[(X, y)])
prediction = classifier.predict(X)
Expand Down Expand Up @@ -386,6 +386,7 @@ def _check_outputs(out, predictions):
y = dd.from_array(np.random.randint(low=0, high=n_classes, size=kRows))
dtrain = xgb.dask.DaskDMatrix(client, X, y)
parameters['objective'] = 'multi:softprob'
parameters['eval_metric'] = 'merror'
parameters['num_class'] = n_classes

out = xgb.dask.train(client, parameters,
Expand Down Expand Up @@ -482,7 +483,7 @@ async def run_dask_classifier_asyncio(scheduler_address):
X, y = generate_array()
y = (y * 10).astype(np.int32)
classifier = await xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2)
verbosity=1, n_estimators=2, eval_metric='merror')
classifier.client = client
await classifier.fit(X, y, eval_set=[(X, y)])
prediction = await classifier.predict(X)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_with_pandas.py
Expand Up @@ -174,7 +174,7 @@ def test_pandas_weight(self):
def test_cv_as_pandas(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
'objective': 'binary:logistic'}
'objective': 'binary:logistic', 'eval_metric': 'error'}

cv = xgb.cv(params, dm, num_boost_round=10, nfold=10)
assert isinstance(cv, pd.DataFrame)
Expand Down

0 comments on commit cf4f019

Please sign in to comment.