Top k multi error (#2178)
* Implement top-k multiclass error metric. Add new parameter top_k_threshold.

* Add test for multiclass metrics

* Make test less sensitive to avoid floating-point issues.

* Change tabs to spaces.

* Fix problem with test in Python 2. Refactor to use np.testing. Decrease number of training rounds so loss is larger and easier to compare.

* Move multiclass tests into test_engine.py

* Change parameter name from top_k_threshold to multi_error_top_k.

* Fix top-k error metric to handle case where scores are equal. Update tests and docs.

* Change name of top-k metric to multi_error@k.

* Change tabs to spaces.

* Fix formatting.

* Fix minor issues in docs.
btrotta authored and StrikerRUS committed May 26, 2019
1 parent 19de2be commit b3db9e9
Showing 5 changed files with 100 additions and 15 deletions.
12 changes: 12 additions & 0 deletions docs/Parameters.rst
@@ -843,6 +843,18 @@ Metric Parameters

- `NDCG <https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG>`__ and `MAP <https://makarandtapaswi.wordpress.com/2012/07/02/intuition-behind-average-precision-and-map/>`__ evaluation positions, separated by ``,``

- ``multi_error_top_k`` :raw-html:`<a id="multi_error_top_k" title="Permalink to this parameter" href="#multi_error_top_k">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, constraints: ``multi_error_top_k > 0``

- used only with ``multi_error`` metric

- threshold for top-k multi-error metric

- the error on each sample is ``0`` if the true class is among the top ``multi_error_top_k`` predictions, and ``1`` otherwise

- more precisely, the error on a sample is ``0`` if there are at least ``num_classes - multi_error_top_k`` predictions strictly less than the prediction on the true class

- when ``multi_error_top_k=1`` this is equivalent to the usual multi-error metric

Network Parameters
------------------

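To make the new parameter's definition above concrete, here is a minimal NumPy sketch of the documented behaviour (an illustrative helper written for this review, not code from the commit):

import numpy as np

def multi_error_top_k(y_true, y_pred, k):
    # Per the docs above: the error on a row is 0 iff at least
    # num_classes - k predictions are strictly less than the
    # prediction on the true class, and 1 otherwise.
    true_scores = y_pred[np.arange(len(y_true)), y_true]
    num_strictly_less = np.sum(y_pred < true_scores[:, None], axis=1)
    errors = num_strictly_less < y_pred.shape[1] - k
    return np.mean(errors)

# Ties count against the true class: for scores [[0.5, 0.5]] and true class 0,
# the top-1 error is 1.0 but the top-2 error is 0.0.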
8 changes: 8 additions & 0 deletions include/LightGBM/config.h
@@ -747,6 +747,14 @@ struct Config {
// desc = `NDCG <https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG>`__ and `MAP <https://makarandtapaswi.wordpress.com/2012/07/02/intuition-behind-average-precision-and-map/>`__ evaluation positions, separated by ``,``
std::vector<int> eval_at;

// check = >0
// desc = used only with ``multi_error`` metric
// desc = threshold for top-k multi-error metric
// desc = the error on each sample is ``0`` if the true class is among the top ``multi_error_top_k`` predictions, and ``1`` otherwise
// descl2 = more precisely, the error on a sample is ``0`` if there are at least ``num_classes - multi_error_top_k`` predictions strictly less than the prediction on the true class
// desc = when ``multi_error_top_k=1`` this is equivalent to the usual multi-error metric
int multi_error_top_k = 1;

#pragma endregion

#pragma region Network Parameters
5 changes: 5 additions & 0 deletions src/io/config_auto.cpp
@@ -260,6 +260,7 @@ std::unordered_set<std::string> Config::parameter_set({
"metric_freq",
"is_provide_training_metric",
"eval_at",
"multi_error_top_k",
"num_machines",
"local_listen_port",
"time_out",
@@ -521,6 +522,9 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {
eval_at = Common::StringToArray<int>(tmp_str, ',');
}

GetInt(params, "multi_error_top_k", &multi_error_top_k);
CHECK(multi_error_top_k > 0);

GetInt(params, "num_machines", &num_machines);
CHECK(num_machines > 0);

@@ -637,6 +641,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[metric_freq: " << metric_freq << "]\n";
str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n";
str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n";
str_buf << "[num_machines: " << num_machines << "]\n";
str_buf << "[local_listen_port: " << local_listen_port << "]\n";
str_buf << "[time_out: " << time_out << "]\n";
33 changes: 18 additions & 15 deletions src/metric/multiclass_metric.hpp
@@ -20,15 +20,15 @@ namespace LightGBM {
template<typename PointWiseLossCalculator>
class MulticlassMetric: public Metric {
public:
explicit MulticlassMetric(const Config& config) {
explicit MulticlassMetric(const Config& config) : config_(config) {
num_class_ = config.num_class;
}

virtual ~MulticlassMetric() {
}

void Init(const Metadata& metadata, data_size_t num_data) override {
name_.emplace_back(PointWiseLossCalculator::Name());
name_.emplace_back(PointWiseLossCalculator::Name(config_));
num_data_ = num_data;
// get label
label_ = metadata.label();
@@ -72,7 +72,7 @@ class MulticlassMetric: public Metric {
std::vector<double> rec(num_pred_per_row);
objective->ConvertOutput(raw_score.data(), rec.data());
// add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec, config_);
}
} else {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
@@ -85,7 +85,7 @@
std::vector<double> rec(num_pred_per_row);
objective->ConvertOutput(raw_score.data(), rec.data());
// add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec, config_) * weights_[i];
}
}
} else {
@@ -98,7 +98,7 @@
rec[k] = static_cast<double>(score[idx]);
}
// add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec, config_);
}
} else {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
@@ -109,7 +109,7 @@
rec[k] = static_cast<double>(score[idx]);
}
// add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec, config_) * weights_[i];
}
}
}
@@ -129,25 +129,28 @@
/*! \brief Name of this test set */
std::vector<std::string> name_;
int num_class_;
/*! \brief config parameters */
Config config_;
};

/*! \brief L2 loss for multiclass task */
/*! \brief top-k error for multiclass task; if k=1 (default) this is the usual multi-error */
class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
public:
explicit MultiErrorMetric(const Config& config) :MulticlassMetric<MultiErrorMetric>(config) {}

inline static double LossOnPoint(label_t label, std::vector<double>& score) {
inline static double LossOnPoint(label_t label, std::vector<double>& score, const Config& config) {
size_t k = static_cast<size_t>(label);
int num_larger = 0;
for (size_t i = 0; i < score.size(); ++i) {
if (i != k && score[i] >= score[k]) {
return 1.0f;
}
if (score[i] >= score[k]) ++num_larger;
if (num_larger > config.multi_error_top_k) return 1.0f;
}
return 0.0f;
}

inline static const char* Name() {
return "multi_error";
inline static const std::string Name(const Config& config) {
if (config.multi_error_top_k == 1) return "multi_error";
else return "multi_error@" + std::to_string(config.multi_error_top_k);
}
};

@@ -156,7 +159,7 @@ class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetric> {
public:
explicit MultiSoftmaxLoglossMetric(const Config& config) :MulticlassMetric<MultiSoftmaxLoglossMetric>(config) {}

inline static double LossOnPoint(label_t label, std::vector<double>& score) {
inline static double LossOnPoint(label_t label, std::vector<double>& score, const Config&) {
size_t k = static_cast<size_t>(label);
if (score[k] > kEpsilon) {
return static_cast<double>(-std::log(score[k]));
@@ -165,7 +168,7 @@ class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetric> {
}
}

inline static const char* Name() {
inline static const std::string Name(const Config&) {
return "multi_logloss";
}
};
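To spell out the tie handling in MultiErrorMetric::LossOnPoint above: num_larger counts every prediction greater than or equal to the true-class score, including the true class itself, so the loss becomes 1 as soon as that count exceeds multi_error_top_k. A hypothetical Python transcription of the loop, for illustration only:

def loss_on_point(label, score, top_k):
    # Count predictions >= the true-class score; the true class itself
    # always contributes one, so ties with other classes push the count
    # past top_k and yield an error of 1.
    k = int(label)
    num_larger = 0
    for s in score:
        if s >= score[k]:
            num_larger += 1
        if num_larger > top_k:
            return 1.0
    return 0.0

# loss_on_point(0, [0.5, 0.5], top_k=1) -> 1.0
# loss_on_point(0, [0.5, 0.5], top_k=2) -> 0.0
# These match the equal-predictions cases in the tests below.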
57 changes: 57 additions & 0 deletions tests/python_package_test/test_engine.py
@@ -26,6 +26,13 @@ def multi_logloss(y_true, y_pred):
    return np.mean([-math.log(y_pred[i][y]) for i, y in enumerate(y_true)])


def top_k_error(y_true, y_pred, k):
    # Independent top-k error calculation: a row counts as correct iff the
    # true-class score is strictly greater than the (k+1)-th largest score.
    if k == y_pred.shape[1]:
        return 0
    # Partition each row in descending order; columns k: hold the scores outside
    # the top k, so their row-wise max is the (k+1)-th largest score.
    max_rest = np.max(-np.partition(-y_pred, k)[:, k:], axis=1)
    return 1 - np.mean(y_pred[np.arange(len(y_true)), y_true] > max_rest)


class TestEngine(unittest.TestCase):
    def test_binary(self):
        X, y = load_breast_cancer(True)
@@ -363,6 +370,56 @@ def test_multiclass_prediction_early_stopping(self):
        ret = multi_logloss(y_test, gbm.predict(X_test, **pred_parameter))
        self.assertLess(ret, 0.2)

    def test_multi_class_error(self):
        X, y = load_digits(return_X_y=True)
        params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error', 'num_leaves': 4, 'seed': 0,
                  'num_rounds': 30, 'verbose': -1}
        lgb_data = lgb.Dataset(X, label=y)
        results = {}
        est = lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
        predict_default = est.predict(X)
        params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error', 'multi_error_top_k': 1,
                  'num_leaves': 4, 'seed': 0, 'num_rounds': 30, 'verbose': -1, 'metric_freq': 10}
        results = {}
        est = lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
        predict_1 = est.predict(X)
        # check that default gives same result as k = 1
        np.testing.assert_array_almost_equal(predict_1, predict_default, 5)
        # check against independent calculation for k = 1
        err = top_k_error(y, predict_1, 1)
        np.testing.assert_almost_equal(results['train']['multi_error'][-1], err, 5)
        # check against independent calculation for k = 2
        params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error', 'multi_error_top_k': 2,
                  'num_leaves': 4, 'seed': 0, 'num_rounds': 30, 'verbose': -1, 'metric_freq': 10}
        results = {}
        est = lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
        predict_2 = est.predict(X)
        err = top_k_error(y, predict_2, 2)
        np.testing.assert_almost_equal(results['train']['multi_error@2'][-1], err, 5)
        # check against independent calculation for k = 10
        params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error', 'multi_error_top_k': 10,
                  'num_leaves': 4, 'seed': 0, 'num_rounds': 30, 'verbose': -1, 'metric_freq': 10}
        results = {}
        est = lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
        predict_3 = est.predict(X)
        err = top_k_error(y, predict_3, 10)
        np.testing.assert_almost_equal(results['train']['multi_error@10'][-1], err, 5)
        # check case where predictions are equal
        X = np.array([[0, 0], [0, 0]])
        y = np.array([0, 1])
        lgb_data = lgb.Dataset(X, label=y)
        params = {'objective': 'multiclass', 'num_classes': 2, 'metric': 'multi_error', 'multi_error_top_k': 1,
                  'num_leaves': 4, 'seed': 0, 'num_rounds': 1, 'verbose': -1, 'metric_freq': 10}
        results = {}
        lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
        np.testing.assert_almost_equal(results['train']['multi_error'][-1], 1, 5)
        lgb_data = lgb.Dataset(X, label=y)
        params = {'objective': 'multiclass', 'num_classes': 2, 'metric': 'multi_error', 'multi_error_top_k': 2,
                  'num_leaves': 4, 'seed': 0, 'num_rounds': 1, 'verbose': -1, 'metric_freq': 10}
        results = {}
        lgb.train(params, lgb_data, valid_sets=[lgb_data], valid_names=['train'], evals_result=results)
        np.testing.assert_almost_equal(results['train']['multi_error@2'][-1], 0, 5)

    def test_early_stopping(self):
        X, y = load_breast_cancer(True)
        params = {
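
For reference, a minimal usage sketch distilled from the tests above (the lgb.train call mirrors the test code; the dataset and parameter values are illustrative):

import lightgbm as lgb
from sklearn.datasets import load_digits

X, y = load_digits(return_X_y=True)
lgb_data = lgb.Dataset(X, label=y)
params = {'objective': 'multiclass', 'num_classes': 10, 'metric': 'multi_error',
          'multi_error_top_k': 2, 'num_leaves': 4, 'verbose': -1}
results = {}
lgb.train(params, lgb_data, num_boost_round=30, valid_sets=[lgb_data],
          valid_names=['train'], evals_result=results)
# With multi_error_top_k > 1 the metric key includes the k value:
print(results['train']['multi_error@2'][-1])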
