Skip to content

Commit

Permalink
Change functions in common.h into template functions (#969) (#973)
Browse files Browse the repository at this point in the history
* Fix coding style (#969)

Function names must be in the "Pascal Case" style.

* check_elements_interval_closed to CheckElementsIntervalClosed

* obtain_min_max_sum to ObtainMinMaxSum

* Change functions in common.h into template functions (#969)

* CheckElementsIntervalClosed

* ObtainMinMaxSum

These two functions were changed into template functions.

* Remove an unpreferable overload

* remove an overload of the function ObtainMinMaxSum

* Use stringstream to format T type
  • Loading branch information
Tony-Y authored and guolinke committed Oct 8, 2017
1 parent 6d34fb8 commit 87fa8b5
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 12 deletions.
16 changes: 10 additions & 6 deletions include/LightGBM/utils/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,20 +580,24 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
}

// Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not
inline void CheckElementsIntervalClosed(const float *y, float ymin, float ymax, int ny, const char *callername) {
template <typename T>
inline void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) {
for (int i = 0; i < ny; ++i) {
if (y[i] < ymin || y[i] > ymax) {
Log::Fatal("[%s]: does not tolerate element [#%i = %f] outside [%f, %f]", callername, i, y[i], ymin, ymax);
std::ostringstream os;
os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
Log::Fatal(os.str().c_str(), callername, i);
}
}
}

// One-pass scan over array w with nw elements: find min, max and sum of elements;
// this is useful for checking weight requirements.
inline void ObtainMinMaxSum(const float *w, int nw, float *mi, float *ma, double *su) {
float minw = w[0];
float maxw = w[0];
double sumw = static_cast<double>(w[0]);
template <typename T1, typename T2>
inline void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
T1 minw = w[0];
T1 maxw = w[0];
T2 sumw = static_cast<T2>(w[0]);
for (int i = 1; i < nw; ++i) {
sumw += w[i];
if (w[i] < minw) minw = w[i];
Expand Down
6 changes: 3 additions & 3 deletions src/metric/xentropy_metric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class CrossEntropyMetric : public Metric {
sum_weights_ = static_cast<double>(num_data_);
} else {
float minw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, nullptr, &sum_weights_);
Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, &sum_weights_);
if (minw < 0.0f) {
Log::Fatal("[%s:%s]: (metric) weights not allowed to be negative", GetName()[0].c_str(), __func__);
}
Expand Down Expand Up @@ -178,7 +178,7 @@ class CrossEntropyLambdaMetric : public Metric {
// check all weights are strictly positive; throw error if not
if (weights_ != nullptr) {
float minw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, nullptr, nullptr);
Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, (float*)nullptr);
if (minw <= 0.0f) {
Log::Fatal("[%s:%s]: (metric) all weights must be positive", GetName()[0].c_str(), __func__);
}
Expand Down Expand Up @@ -263,7 +263,7 @@ class KullbackLeiblerDivergence : public Metric {
sum_weights_ = static_cast<double>(num_data_);
} else {
float minw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, nullptr, &sum_weights_);
Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, &sum_weights_);
if (minw < 0.0f) {
Log::Fatal("[%s:%s]: (metric) at least one weight is negative", GetName()[0].c_str(), __func__);
}
Expand Down
2 changes: 1 addition & 1 deletion src/objective/regression_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ class RegressionPoissonLoss: public ObjectiveFunction {
// Safety check of labels
float miny;
double sumy;
Common::ObtainMinMaxSum(label_, num_data_, &miny, nullptr, &sumy);
Common::ObtainMinMaxSum(label_, num_data_, &miny, (float*)nullptr, &sumy);
if (miny < 0.0f) {
Log::Fatal("[%s]: at least one target label is negative.", GetName());
}
Expand Down
4 changes: 2 additions & 2 deletions src/objective/xentropy_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class CrossEntropy: public ObjectiveFunction {
if (weights_ != nullptr) {
float minw;
double sumw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, nullptr, &sumw);
Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, &sumw);
if (minw < 0.0f) {
Log::Fatal("[%s]: at least one weight is negative.", GetName());
}
Expand Down Expand Up @@ -163,7 +163,7 @@ class CrossEntropyLambda: public ObjectiveFunction {

if (weights_ != nullptr) {

Common::ObtainMinMaxSum(weights_, num_data_, &min_weight_, &max_weight_, nullptr);
Common::ObtainMinMaxSum(weights_, num_data_, &min_weight_, &max_weight_, (float*)nullptr);
if (min_weight_ <= 0.0f) {
Log::Fatal("[%s]: at least one weight is non-positive.", GetName());
}
Expand Down

0 comments on commit 87fa8b5

Please sign in to comment.