Skip to content

Commit

Permalink
trivial refine CheckElementsIntervalClosed & ObtainMinMaxSum (#1049)
Browse files Browse the repository at this point in the history
* refine common.h

* fix typo

* specify captured variables
  • Loading branch information
wxchan authored and guolinke committed Nov 14, 2017
1 parent bd5e5e3 commit 302f84b
Showing 1 changed file with 62 additions and 16 deletions.
78 changes: 62 additions & 16 deletions include/LightGBM/utils/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,11 +596,29 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
// Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not
template <typename T>
inline void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) {
for (int i = 0; i < ny; ++i) {
if (y[i] < ymin || y[i] > ymax) {
std::ostringstream os;
os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
Log::Fatal(os.str().c_str(), callername, i);
auto fatal_msg = [&y, &ymin, &ymax, &callername](int i) {
std::ostringstream os;
os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
Log::Fatal(os.str().c_str(), callername, i);
};
for (int i = 1; i < ny; i += 2) {
if (y[i - 1] < y[i]) {
if (y[i - 1] < ymin) {
fatal_msg(i - 1);
} else if (y[i] > ymax) {
fatal_msg(i);
}
} else {
if (y[i - 1] > ymax) {
fatal_msg(i - 1);
} else if (y[i] < ymin) {
fatal_msg(i);
}
}
}
if (ny & 1) { // odd
if (y[ny - 1] < ymin || y[ny - 1] > ymax) {
fatal_msg(ny - 1);
}
}
}
Expand All @@ -609,17 +627,45 @@ inline void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, cons
// this is useful for checking weight requirements.
template <typename T1, typename T2>
inline void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
T1 minw = w[0];
T1 maxw = w[0];
T2 sumw = static_cast<T2>(w[0]);
for (int i = 1; i < nw; ++i) {
sumw += w[i];
if (w[i] < minw) minw = w[i];
if (w[i] > maxw) maxw = w[i];
}
if (mi != nullptr) *mi = minw;
if (ma != nullptr) *ma = maxw;
if (su != nullptr) *su = sumw;
T1 minw;
T1 maxw;
T1 sumw;
int i;
if (nw & 1) { // odd
minw = w[0];
maxw = w[0];
sumw = w[0];
i = 2;
} else { // even
if (w[0] < w[1]) {
minw = w[0];
maxw = w[1];
} else {
minw = w[1];
maxw = w[0];
}
sumw = w[0] + w[1];
i = 3;
}
for (; i < nw; i += 2) {
if (w[i - 1] < w[i]) {
minw = std::min(minw, w[i - 1]);
maxw = std::max(maxw, w[i]);
} else {
minw = std::min(minw, w[i]);
maxw = std::max(maxw, w[i - 1]);
}
sumw += w[i - 1] + w[i];
}
if (mi != nullptr) {
*mi = minw;
}
if (ma != nullptr) {
*ma = maxw;
}
if (su != nullptr) {
*su = static_cast<T2>(sumw);
}
}

template<class T>
Expand Down

0 comments on commit 302f84b

Please sign in to comment.