From b3db9e924a8d48a0b4d67097cf7dcc79bc50b9d9 Mon Sep 17 00:00:00 2001
From: Belinda Trotta
Date: Sun, 26 May 2019 21:08:45 +1000
Subject: [PATCH] Top k multi error (#2178)

* Implement top-k multiclass error metric. Add new parameter top_k_threshold.

* Add test for multiclass metrics

* Make test less sensitive to avoid floating-point issues.

* Change tabs to spaces.

* Fix problem with test in Python 2. Refactor to use np.testing. Decrease number of training rounds so loss is larger and easier to compare.

* Move multiclass tests into test_engine.py

* Change parameter name from top_k_threshold to multi_error_top_k.

* Fix top-k error metric to handle case where scores are equal. Update tests and docs.

* Change name of top-k metric to multi_error@k.

* Change tabs to spaces.

* Fix formatting.

* Fix minor issues in docs.
---
 docs/Parameters.rst                      | 12 ++++++
 include/LightGBM/config.h                |  8 ++++
 src/io/config_auto.cpp                   |  5 +++
 src/metric/multiclass_metric.hpp         | 33 +++++++-------
 tests/python_package_test/test_engine.py | 57 ++++++++++++++++++++++++
 5 files changed, 100 insertions(+), 15 deletions(-)

diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index d7bfb6d325e..08696347654 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -843,6 +843,18 @@ Metric Parameters
 
    -  `NDCG <https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG>`__ and `MAP <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision>`__ evaluation positions, separated by ``,``
 
+-  ``multi_error_top_k`` :raw-html:`<a id="multi_error_top_k" title="Permalink to this parameter" href="#multi_error_top_k">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, constraints: ``multi_error_top_k > 0``
+
+   -  used only with ``multi_error`` metric
+
+   -  threshold for top-k multi-error metric
+
+   -  the error on each sample is ``0`` if the true class is among the top ``multi_error_top_k`` predictions, and ``1`` otherwise
+
+   -  more precisely, the error on a sample is ``0`` if there are at least ``num_classes - multi_error_top_k`` predictions strictly less than the prediction on the true class
+
+   -  when ``multi_error_top_k=1`` this is equivalent to the usual multi-error metric
+
 Network Parameters
 ------------------
 
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index a4996e9af34..7ddb7d7e5e4 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -747,6 +747,14 @@ struct Config {
   // desc = `NDCG <https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG>`__ and `MAP <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision>`__ evaluation positions, separated by ``,``
   std::vector<int> eval_at;
 
+  // check = >0
+  // desc = used only with ``multi_error`` metric
+  // desc = threshold for top-k multi-error metric
+  // desc = the error on each sample is ``0`` if the true class is among the top ``multi_error_top_k`` predictions, and ``1`` otherwise
+  // descl2 = more precisely, the error on a sample is ``0`` if there are at least ``num_classes - multi_error_top_k`` predictions strictly less than the prediction on the true class
+  // desc = when ``multi_error_top_k=1`` this is equivalent to the usual multi-error metric
+  int multi_error_top_k = 1;
+
 #pragma endregion
 
 #pragma region Network Parameters
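The per-sample rule documented above is easiest to check in isolation. Below is a minimal NumPy sketch of that definition; it is illustrative only, not part of the patch, and the names (`top_k_sample_error`, `scores`, `true_class`) are hypothetical:

    import numpy as np

    def top_k_sample_error(scores, true_class, k):
        # Error is 0 if at least (num_classes - k) predictions are strictly
        # less than the prediction on the true class, 1 otherwise; ties with
        # the true-class score therefore count against the sample.
        num_strictly_less = np.sum(scores < scores[true_class])
        return 0.0 if num_strictly_less >= scores.size - k else 1.0

    print(top_k_sample_error(np.array([0.2, 0.5, 0.3]), 1, 1))  # 0.0: class 1 is the top prediction
    print(top_k_sample_error(np.array([0.2, 0.5, 0.3]), 2, 1))  # 1.0: class 2 is only second
    print(top_k_sample_error(np.array([0.2, 0.5, 0.3]), 2, 2))  # 0.0: class 2 is within the top 2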
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index 9d5c3f85352..774387b397d 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -260,6 +260,7 @@ std::unordered_set<std::string> Config::parameter_set({
   "metric_freq",
   "is_provide_training_metric",
   "eval_at",
+  "multi_error_top_k",
   "num_machines",
   "local_listen_port",
   "time_out",
@@ -521,6 +522,9 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {
     eval_at = Common::StringToArray<int>(tmp_str, ',');
   }
 
+  GetInt(params, "multi_error_top_k", &multi_error_top_k);
+  CHECK(multi_error_top_k > 0);
+
   GetInt(params, "num_machines", &num_machines);
   CHECK(num_machines > 0);
 
@@ -637,6 +641,7 @@ std::string Config::SaveMembersToString() const {
   str_buf << "[metric_freq: " << metric_freq << "]\n";
   str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
   str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n";
+  str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n";
   str_buf << "[num_machines: " << num_machines << "]\n";
   str_buf << "[local_listen_port: " << local_listen_port << "]\n";
   str_buf << "[time_out: " << time_out << "]\n";
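Once wired through the config plumbing above, the parameter is set like any other scalar option. A hedged usage sketch from the Python package (`X` and `y` stand in for any multiclass dataset; this snippet is not part of the patch):

    import lightgbm as lgb

    params = {'objective': 'multiclass', 'num_classes': 10,
              'metric': 'multi_error', 'multi_error_top_k': 3}
    bst = lgb.train(params, lgb.Dataset(X, label=y))
    # With multi_error_top_k > 1 the metric is reported as multi_error@3 in
    # evaluation output; with the default of 1 it keeps the name multi_error.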
diff --git a/src/metric/multiclass_metric.hpp b/src/metric/multiclass_metric.hpp
index c94660ac236..e700480f037 100644
--- a/src/metric/multiclass_metric.hpp
+++ b/src/metric/multiclass_metric.hpp
@@ -20,7 +20,7 @@ namespace LightGBM {
 template <typename PointWiseLossCalculator>
 class MulticlassMetric: public Metric {
  public:
-  explicit MulticlassMetric(const Config& config) {
+  explicit MulticlassMetric(const Config& config) : config_(config) {
     num_class_ = config.num_class;
   }
 
@@ -28,7 +28,7 @@ class MulticlassMetric: public Metric {
   }
 
   void Init(const Metadata& metadata, data_size_t num_data) override {
-    name_.emplace_back(PointWiseLossCalculator::Name());
+    name_.emplace_back(PointWiseLossCalculator::Name(config_));
     num_data_ = num_data;
     // get label
     label_ = metadata.label();
@@ -72,7 +72,7 @@ class MulticlassMetric: public Metric {
         std::vector<double> rec(num_pred_per_row);
         objective->ConvertOutput(raw_score.data(), rec.data());
         // add loss
-        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
+        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec, config_);
       }
     } else {
       #pragma omp parallel for schedule(static) reduction(+:sum_loss)
@@ -85,7 +85,7 @@ class MulticlassMetric: public Metric {
         std::vector<double> rec(num_pred_per_row);
         objective->ConvertOutput(raw_score.data(), rec.data());
         // add loss
-        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
+        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec, config_) * weights_[i];
       }
     }
   } else {
@@ -98,7 +98,7 @@ class MulticlassMetric: public Metric {
           rec[k] = static_cast<double>(score[idx]);
         }
         // add loss
-        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
+        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec, config_);
       }
     } else {
       #pragma omp parallel for schedule(static) reduction(+:sum_loss)
@@ -109,7 +109,7 @@ class MulticlassMetric: public Metric {
           rec[k] = static_cast<double>(score[idx]);
         }
         // add loss
-        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
+        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec, config_) * weights_[i];
       }
     }
   }
@@ -129,25 +129,28 @@ class MulticlassMetric: public Metric {
   /*! \brief Name of this test set */
   std::vector<std::string> name_;
   int num_class_;
+  /*! \brief config parameters*/
+  Config config_;
 };
 
-/*! \brief L2 loss for multiclass task */
+/*! \brief top-k error for multiclass task; if k=1 (default) this is the usual multi-error */
 class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
  public:
   explicit MultiErrorMetric(const Config& config) : MulticlassMetric<MultiErrorMetric>(config) {}
 
-  inline static double LossOnPoint(label_t label, std::vector<double>& score) {
+  inline static double LossOnPoint(label_t label, std::vector<double>& score, const Config& config) {
     size_t k = static_cast<size_t>(label);
+    int num_larger = 0;
     for (size_t i = 0; i < score.size(); ++i) {
-      if (i != k && score[i] >= score[k]) {
-        return 1.0f;
-      }
+      if (score[i] >= score[k]) ++num_larger;
+      if (num_larger > config.multi_error_top_k) return 1.0f;
     }
     return 0.0f;
   }
 
-  inline static const char* Name() {
-    return "multi_error";
+  inline static const std::string Name(const Config& config) {
+    if (config.multi_error_top_k == 1) return "multi_error";
+    else return "multi_error@" + std::to_string(config.multi_error_top_k);
   }
 };
 
@@ -156,7 +159,7 @@ class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetric> {
  public:
   explicit MultiSoftmaxLoglossMetric(const Config& config) : MulticlassMetric<MultiSoftmaxLoglossMetric>(config) {}
 
-  inline static double LossOnPoint(label_t label, std::vector<double>& score) {
+  inline static double LossOnPoint(label_t label, std::vector<double>& score, const Config&) {
     size_t k = static_cast<size_t>(label);
     if (score[k] > kEpsilon) {
       return static_cast<double>(-std::log(score[k]));
@@ -165,7 +168,7 @@ class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetric> {
     }
   }
 
-  inline static const char* Name() {
+  inline static const std::string Name(const Config&) {
     return "multi_logloss";
   }
 };
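A note on the tie handling in `LossOnPoint` above: `num_larger` also counts the true class itself (since `score[k] >= score[k]` always holds), so a sample is an error exactly when more than `multi_error_top_k` classes score at least as high as the true class. A line-for-line Python rendering of that loop, for illustration only:

    def loss_on_point(label, score, top_k):
        num_larger = 0
        for s in score:
            if s >= score[label]:   # the true class always counts itself
                num_larger += 1
            if num_larger > top_k:  # more than k classes tie or beat it
                return 1.0
        return 0.0

    # Two classes with identical scores: an error for k=1, correct for k=2,
    # matching the equal-prediction cases exercised in the test below.
    assert loss_on_point(0, [0.5, 0.5], 1) == 1.0
    assert loss_on_point(0, [0.5, 0.5], 2) == 0.0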
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -17,6 +17,13 @@ def multi_logloss(y_true, y_pred):
     return np.mean([-math.log(y_pred[i][y]) for i, y in enumerate(y_true)])
 
 
+def top_k_error(y_true, y_pred, k):
+    if k == y_pred.shape[1]:
+        return 0
+    max_rest = np.max(-np.partition(-y_pred, k)[:, k:], axis=1)
+    return 1 - np.mean((y_pred[np.arange(len(y_true)), y_true] > max_rest))
+
+
 class TestEngine(unittest.TestCase):
     def test_binary(self):
         X, y = load_breast_cancer(True)
@@ -363,6 +370,56 @@ def test_multiclass_prediction_early_stopping(self):
         ret = multi_logloss(y_test, gbm.predict(X_test, **pred_parameter))
         self.assertLess(ret, 0.2)
 
+    def test_multi_class_error(self):
+        X, y = load_digits(return_X_y=True)
+        params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error', 'num_leaves': 4, 'seed': 0,
+                  'num_rounds': 30, 'verbose': -1}
+        lgb_data = lgb.Dataset(X, label=y)
+        results = {}
+        est = lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
+        predict_default = est.predict(X)
+        params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error', 'multi_error_top_k': 1,
+                  'num_leaves': 4, 'seed': 0, 'num_rounds': 30, 'verbose': -1, 'metric_freq': 10}
+        results = {}
+        est = lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
+        predict_1 = est.predict(X)
+        # check that default gives same result as k = 1
+        np.testing.assert_array_almost_equal(predict_1, predict_default, 5)
+        # check against independent calculation for k = 1
+        err = top_k_error(y, predict_1, 1)
+        np.testing.assert_almost_equal(results['train']['multi_error'][-1], err, 5)
+        # check against independent calculation for k = 2
+        params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error', 'multi_error_top_k': 2,
+                  'num_leaves': 4, 'seed': 0, 'num_rounds': 30, 'verbose': -1, 'metric_freq': 10}
+        results = {}
+        est = lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
+        predict_2 = est.predict(X)
+        err = top_k_error(y, predict_2, 2)
+        np.testing.assert_almost_equal(results['train']['multi_error@2'][-1], err, 5)
+        # check against independent calculation for k = 10
+        params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error', 'multi_error_top_k': 10,
+                  'num_leaves': 4, 'seed': 0, 'num_rounds': 30, 'verbose': -1, 'metric_freq': 10}
+        results = {}
+        est = lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
+        predict_3 = est.predict(X)
+        err = top_k_error(y, predict_3, 10)
+        np.testing.assert_almost_equal(results['train']['multi_error@10'][-1], err, 5)
+        # check case where predictions are equal
+        X = np.array([[0, 0], [0, 0]])
+        y = np.array([0, 1])
+        lgb_data = lgb.Dataset(X, label=y)
+        params = {'objective': 'multiclass', 'num_classes': 2, 'metric': 'multi_error', 'multi_error_top_k': 1,
+                  'num_leaves': 4, 'seed': 0, 'num_rounds': 1, 'verbose': -1, 'metric_freq': 10}
+        results = {}
+        lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
+        np.testing.assert_almost_equal(results['train']['multi_error'][-1], 1, 5)
+        lgb_data = lgb.Dataset(X, label=y)
+        params = {'objective': 'multiclass', 'num_classes': 2, 'metric': 'multi_error', 'multi_error_top_k': 2,
+                  'num_leaves': 4, 'seed': 0, 'num_rounds': 1, 'verbose': -1, 'metric_freq': 10}
+        results = {}
+        lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
+        np.testing.assert_almost_equal(results['train']['multi_error@2'][-1], 0, 5)
+
     def test_early_stopping(self):
         X, y = load_breast_cancer(True)
         params = {
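A quick worked example of the `top_k_error` helper used in the tests above (illustrative only, not part of the patch):

    import numpy as np

    y_true = np.array([0, 1])
    y_pred = np.array([[0.6, 0.3, 0.1],
                       [0.2, 0.3, 0.5]])
    # Sample 0: class 0 is the top-1 prediction -> correct at k=1.
    # Sample 1: class 1 is only the second-highest score -> wrong at k=1,
    # correct at k=2.
    print(top_k_error(y_true, y_pred, 1))  # 0.5
    print(top_k_error(y_true, y_pred, 2))  # 0.0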