Skip to content

Commit bf8e5e2

Browse files
authored
[ML] Add offset in the MSLE computation (#1200)
While adding the additional function parameter in #1168, I wired it into the constructor of the MSLE loss function, but not into the computation of the objective. This PR fixes that: it basically substitutes log(1+x) with log(offset+x) in many different places. I mark it as a non-issue since the MSLE loss function has not been released yet.
1 parent e8d46a7 commit bf8e5e2

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

include/maths/CBoostedTreeLoss.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class MATHS_EXPORT CArgMinMsleImpl final : public CArgMinLossImpl {
7676
using TObjective = std::function<double(double)>;
7777

7878
public:
79-
CArgMinMsleImpl(double lambda);
79+
CArgMinMsleImpl(double lambda, double offset = 1.0);
8080
std::unique_ptr<CArgMinLossImpl> clone() const override;
8181
bool nextPass() override;
8282
void add(const TMemoryMappedFloatVector& prediction, double actual, double weight = 1.0) override;
@@ -122,6 +122,7 @@ class MATHS_EXPORT CArgMinMsleImpl final : public CArgMinLossImpl {
122122

123123
private:
124124
std::size_t m_CurrentPass = 0;
125+
double m_Offset = 1.0;
125126
TMinMaxAccumulator m_ExpPredictionMinMax;
126127
TMinMaxAccumulator m_LogActualMinMax;
127128
TVectorMeanAccumulatorVecVec m_Buckets;

lib/maths/CBoostedTreeLoss.cc

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,8 @@ CArgMinMultinomialLogisticLossImpl::objectiveGradient() const {
388388
};
389389
}
390390

391-
CArgMinMsleImpl::CArgMinMsleImpl(double lambda)
392-
: CArgMinLossImpl{lambda}, m_Buckets(MSLE_BUCKET_SIZE) {
391+
CArgMinMsleImpl::CArgMinMsleImpl(double lambda, double offset)
392+
: CArgMinLossImpl{lambda}, m_Offset{offset}, m_Buckets(MSLE_BUCKET_SIZE) {
393393
for (auto& bucket : m_Buckets) {
394394
bucket.resize(MSLE_BUCKET_SIZE);
395395
}
@@ -406,7 +406,7 @@ bool CArgMinMsleImpl::nextPass() {
406406

407407
void CArgMinMsleImpl::add(const TMemoryMappedFloatVector& prediction, double actual, double weight) {
408408
double expPrediction{CTools::stableExp(prediction[0])};
409-
double logActual{CTools::fastLog(1.0 + actual)};
409+
double logActual{CTools::fastLog(m_Offset + actual)};
410410
switch (m_CurrentPass) {
411411
case 0: {
412412
m_ExpPredictionMinMax.add(expPrediction);
@@ -415,7 +415,7 @@ void CArgMinMsleImpl::add(const TMemoryMappedFloatVector& prediction, double act
415415
break;
416416
}
417417
case 1: {
418-
double logError{logActual - CTools::fastLog(1.0 + expPrediction)};
418+
double logError{logActual - CTools::fastLog(m_Offset + expPrediction)};
419419
TVector example;
420420
example(MSLE_PREDICTION_INDEX) = expPrediction;
421421
example(MSLE_ACTUAL_INDEX) = logActual;
@@ -497,7 +497,7 @@ CArgMinMsleImpl::TObjective CArgMinMsleImpl::objective() const {
497497
if (this->bucketWidth().first == 0.0) {
498498
// prediction is constant
499499
double expPrediction{m_ExpPredictionMinMax.max()};
500-
double logPrediction{CTools::fastLog(1.0 + expPrediction * weight)};
500+
double logPrediction{CTools::fastLog(m_Offset + expPrediction * weight)};
501501
double meanLogActual{CBasicStatistics::mean(m_MeanLogActual)};
502502
double meanLogActualSquared{CBasicStatistics::variance(m_MeanLogActual) +
503503
CTools::pow2(meanLogActual)};
@@ -514,7 +514,7 @@ CArgMinMsleImpl::TObjective CArgMinMsleImpl::objective() const {
514514
const auto& bucketMean{CBasicStatistics::mean(bucketActual)};
515515
double expPrediction{bucketMean(MSLE_PREDICTION_INDEX)};
516516
double logActual{bucketMean(MSLE_ACTUAL_INDEX)};
517-
double logPrediction{CTools::fastLog(1.0 + expPrediction * weight)};
517+
double logPrediction{CTools::fastLog(m_Offset + expPrediction * weight)};
518518
loss += count * CTools::pow2(logActual - logPrediction);
519519
totalCount += count;
520520
}
@@ -776,22 +776,22 @@ std::size_t CMsle::numberParameters() const {
776776

777777
double CMsle::value(const TMemoryMappedFloatVector& logPrediction, double actual, double weight) const {
778778
double prediction{CTools::stableExp(logPrediction(0))};
779-
double log1PlusPrediction{CTools::fastLog(1.0 + prediction)};
779+
double logOffsetPrediction{CTools::stableLog(m_Offset + prediction)};
780780
if (actual < 0.0) {
781781
HANDLE_FATAL(<< "Input error: target value needs to be non-negative to use "
782782
<< "with MSLE loss, received: " << actual)
783783
}
784-
double log1PlusActual{CTools::fastLog(1.0 + actual)};
785-
return weight * CTools::pow2(log1PlusPrediction - log1PlusActual);
784+
double logOffsetActual{CTools::stableLog(m_Offset + actual)};
785+
return weight * CTools::pow2(logOffsetPrediction - logOffsetActual);
786786
}
787787

788788
void CMsle::gradient(const TMemoryMappedFloatVector& logPrediction,
789789
double actual,
790790
TWriter writer,
791791
double weight) const {
792792
double prediction{CTools::stableExp(logPrediction(0))};
793-
double log1PlusPrediction{CTools::fastLog(1.0 + prediction)};
794-
double log1PlusActual{CTools::fastLog(1.0 + actual)};
793+
double log1PlusPrediction{CTools::stableLog(m_Offset + prediction)};
794+
double log1PlusActual{CTools::stableLog(m_Offset + actual)};
795795
writer(0, 2.0 * weight * (log1PlusPrediction - log1PlusActual) / (prediction + 1.0));
796796
}
797797

@@ -800,12 +800,13 @@ void CMsle::curvature(const TMemoryMappedFloatVector& logPrediction,
800800
TWriter writer,
801801
double weight) const {
802802
double prediction{CTools::stableExp(logPrediction(0))};
803-
double log1PlusPrediction{CTools::fastLog(1.0 + prediction)};
804-
double log1PlusActual{CTools::fastLog(1.0 + actual)};
803+
double logOffsetPrediction{CTools::stableLog(m_Offset + prediction)};
804+
double logOffsetActual{CTools::stableLog(m_Offset + actual)};
805805
// Apply L'Hopital's rule in the limit prediction -> actual.
806-
writer(0, prediction == actual ? 0.0
807-
: 2.0 * weight * (log1PlusPrediction - log1PlusActual) /
808-
((prediction + 1) * (prediction - actual)));
806+
writer(0, prediction == actual
807+
? 0.0
808+
: 2.0 * weight * (logOffsetPrediction - logOffsetActual) /
809+
((prediction + m_Offset) * (prediction - actual)));
809810
}
810811

811812
bool CMsle::isCurvatureConstant() const {

0 commit comments

Comments
 (0)