Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add interface function for updating learning_rate per each iteration in LightGBMDelegate #849

Merged
merged 10 commits into from
Apr 7, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ abstract class LightGBMDelegate extends Serializable {
boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean, isFinished: Boolean,
trainEvalResults: Option[Map[String, Double]],
validEvalResults: Option[Map[String, Double]]): Unit

def getLearningRate(partitionId: Int, curIters: Int, log: Logger, trainParams: TrainParams,
previousLearningRate: Double): Double = previousLearningRate
}
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,19 @@ private object TrainUtils extends Serializable {
val bestIter = new Array[Int](evalCounts)
val delegate = trainParams.delegate
val partitionId = TaskContext.getPartitionId
var learningRate: Double = trainParams.learningRate
while (!isFinished && iters < trainParams.numIterations) {

if (delegate.isDefined) {
delegate.get.beforeTrainIteration(partitionId, iters, log, trainParams, boosterPtr, hasValid)
val newLearningRate = delegate.get.getLearningRate(partitionId, iters, log, trainParams, learningRate)
if (newLearningRate != learningRate) {
log.info(s"LightGBM worker calling LGBM_BoosterResetParameter to reset learningRate" +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice logging!

s" (newLearningRate: $newLearningRate)")
LightGBMUtils.validate(lightgbmlib.LGBM_BoosterResetParameter(boosterPtr.get,
s"learning_rate=$newLearningRate"), "Booster Reset learning_rate Param")
learningRate = newLearningRate
}
}

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package com.microsoft.ml.spark.lightgbm.split1
import java.io.File
import java.nio.file.{Files, Path, Paths}

import com.microsoft.ml.lightgbm.SWIGTYPE_p_void
import com.microsoft.ml.spark.core.test.base.TestBase
import com.microsoft.ml.spark.core.test.benchmarks.{Benchmarks, DatasetUtils}
import com.microsoft.ml.spark.core.test.fuzzing.{EstimatorFuzzing, TestObject}
Expand All @@ -23,6 +24,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions._
import org.slf4j.Logger

// scalastyle:off magic.number
trait LightGBMTestUtils extends TestBase {
Expand Down Expand Up @@ -360,6 +362,47 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
assert(metric > 0.8)
}

test("Verify LightGBM Classifier updating learning_rate on training by using LightGBMDelegate") {

class TrainDelegate extends LightGBMDelegate {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe the test is failing because this class is not serializable?


override def beforeTrainIteration(partitionId: Int, curIters: Int, log: Logger, trainParams: TrainParams,
boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean): Unit = {
// nothing
}

override def afterTrainIteration(partitionId: Int, curIters: Int, log: Logger, trainParams: TrainParams,
boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean, isFinished: Boolean,
trainEvalResults: Option[Map[String, Double]],
validEvalResults: Option[Map[String, Double]]): Unit = {
// nothing
}

override def getLearningRate(partitionId: Int, curIters: Int, log: Logger, trainParams: TrainParams,
previousLearningRate: Double): Double = {
if (curIters == 0) {
previousLearningRate
} else {
previousLearningRate * 0.05
}
}

}

val Array(train, _) = indexedBankTrainDF.randomSplit(Array(0.8, 0.2), seed)
val delegate = new TrainDelegate()
val untrainedModel = baseModel
.setCategoricalSlotNames(indexedBankTrainDF.columns.filter(_.startsWith("c_")))
.setDelegate(delegate)
.setLearningRate(0.1)
.setNumIterations(2) // expected learning_rate: iters 0 => 0.1, iters 1 => 0.005

val model = untrainedModel.fit(train)

// Verify updating learning_rate
assert(model.getModel.model.contains("learning_rate: 0.005"))
}

test("Verify LightGBM Classifier leaf prediction") {
val Array(train, test) = indexedBankTrainDF.randomSplit(Array(0.8, 0.2), seed)
val untrainedModel = baseModel
Expand Down