Skip to content

Commit

Permalink
Added numClasses and objective, infer actualNumClasses from objective (
Browse files Browse the repository at this point in the history
…microsoft#348)

* Added numClasses and objective, infer actualNumClasses from objective

* Update LightGBM notebook example

* Remove numClasses since it is now inferred from dataset
  • Loading branch information
terrytangyuan authored and drdarshan committed Aug 4, 2018
1 parent bdf354f commit 7dfcc96
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
"outputs": [],
"source": [
"from mmlspark import LightGBMRegressor\n",
"model = LightGBMRegressor(application='quantile',\n",
"model = LightGBMRegressor(objective='quantile',\n",
" alpha=0.2,\n",
" learningRate=0.3,\n",
" numLeaves=31).fit(train)"
Expand Down
12 changes: 0 additions & 12 deletions src/lightgbm/src/main/scala/LightGBMBooster.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,6 @@ class LightGBMBooster(val model: String) extends Serializable {
@transient
var boosterPtr: SWIGTYPE_p_void = null

def numClasses(): Int = {
if (boosterPtr == null) {
LightGBMUtils.initializeNativeLibrary()
boosterPtr = getModel()
}
val numClasses = lightgbmlib.new_intp()
LightGBMUtils.validate(
lightgbmlib.LGBM_BoosterGetNumClasses(boosterPtr, numClasses),
"Booster GetNumClasses")
lightgbmlib.intp_value(numClasses)
}

def score(features: Vector, raw: Boolean): Double = {
// Reload booster on each node
if (boosterPtr == null) {
Expand Down
17 changes: 11 additions & 6 deletions src/lightgbm/src/main/scala/LightGBMClassifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ class LightGBMClassifier(override val uid: String)
log.info(s"Nodes used for LightGBM: ${nodes.mkString(",")}")
val trainParams = ClassifierTrainParams(getParallelism, getNumIterations, getLearningRate, getNumLeaves,
getMaxBin, getBaggingFraction, getBaggingFreq, getBaggingSeed, getEarlyStoppingRound,
getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numWorkers)
getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numWorkers, getObjective)
/* The native code for getting numClasses is always 1 unless it is multiclass-classification problem
* so we infer the actual numClasses from the dataset here
*/
val actualNumClasses = getNumClasses(dataset)
val networkParams = NetworkParams(nodes.toMap, getDefaultListenPort, inetAddress, port)
val lightGBMBooster = df
.mapPartitions(TrainUtils.trainLightGBM(networkParams, getLabelCol, getFeaturesCol,
Expand All @@ -60,7 +64,7 @@ class LightGBMClassifier(override val uid: String)
Await.result(future, Duration(getTimeout, SECONDS))
new LightGBMClassificationModel(uid, lightGBMBooster, getLabelCol, getFeaturesCol,
getPredictionCol, getProbabilityCol, getRawPredictionCol,
if (isDefined(thresholds)) Some(getThresholds) else None)
if (isDefined(thresholds)) Some(getThresholds) else None, actualNumClasses)
}

override def copy(extra: ParamMap): LightGBMClassifier = defaultCopy(extra)
Expand All @@ -71,7 +75,8 @@ class LightGBMClassifier(override val uid: String)
class LightGBMClassificationModel(
override val uid: String, model: LightGBMBooster, labelColName: String,
featuresColName: String, predictionColName: String, probColName: String,
rawPredictionColName: String, thresholdValues: Option[Array[Double]])
rawPredictionColName: String, thresholdValues: Option[Array[Double]],
actualNumClasses: Int)
extends ProbabilisticClassificationModel[Vector, LightGBMClassificationModel]
with ConstructorWritable[LightGBMClassificationModel] {

Expand All @@ -96,7 +101,7 @@ class LightGBMClassificationModel(
}
}

override def numClasses: Int = model.numClasses()
override def numClasses: Int = this.actualNumClasses

override protected def predictRaw(features: Vector): Vector = {
val prediction = model.score(features, true)
Expand All @@ -105,14 +110,14 @@ class LightGBMClassificationModel(

override def copy(extra: ParamMap): LightGBMClassificationModel =
new LightGBMClassificationModel(uid, model, labelColName, featuresColName, predictionColName, probColName,
rawPredictionColName, thresholdValues)
rawPredictionColName, thresholdValues, actualNumClasses)

override val ttag: TypeTag[LightGBMClassificationModel] =
typeTag[LightGBMClassificationModel]

override def objectsToSave: List[Any] =
List(uid, model, getLabelCol, getFeaturesCol, getPredictionCol,
getProbabilityCol, getRawPredictionCol, thresholdValues)
getProbabilityCol, getRawPredictionCol, thresholdValues, actualNumClasses)

def saveNativeModel(session: SparkSession, filename: String): Unit = {
model.saveNativeModel(session, filename)
Expand Down
8 changes: 8 additions & 0 deletions src/lightgbm/src/main/scala/LightGBMParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ trait LightGBMParams extends MMLParams {
def getNumLeaves: Int = $(numLeaves)
def setNumLeaves(value: Int): this.type = set(numLeaves, value)

val objective = StringParam(this, "objective",
"The Objective. For regression applications, this can be: " +
"regression_l2, regression_l1, huber, fair, poisson, quantile, mape, gamma or tweedie. " +
"For classification applications, this can be: binary, multiclass, or multiclassova. ", "regression")

def getObjective: String = $(objective)
def setObjective(value: String): this.type = set(objective, value)

val maxBin = IntParam(this, "maxBin", "Max bin", 255)

def getMaxBin: Int = $(maxBin)
Expand Down
9 changes: 1 addition & 8 deletions src/lightgbm/src/main/scala/LightGBMRegressor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,6 @@ class LightGBMRegressor(override val uid: String)
def getAlpha: Double = $(alpha)
def setAlpha(value: Double): this.type = set(alpha, value)

val application = StringParam(this, "application",
"Regression application, regression_l2, regression_l1, huber, fair, poisson, quantile, mape, gamma or tweedie",
"regression")

def getApplication: String = $(application)
def setApplication(value: String): this.type = set(application, value)

/** Trains the LightGBM Regression model.
*
* @param dataset The input dataset to train.
Expand All @@ -72,7 +65,7 @@ class LightGBMRegressor(override val uid: String)
val encoder = Encoders.kryo[LightGBMBooster]
log.info(s"Nodes used for LightGBM: ${nodes.mkString(",")}")
val trainParams = RegressorTrainParams(getParallelism, getNumIterations, getLearningRate, getNumLeaves,
getApplication, getAlpha, getMaxBin, getBaggingFraction, getBaggingFreq, getBaggingSeed, getEarlyStoppingRound,
getObjective, getAlpha, getMaxBin, getBaggingFraction, getBaggingFreq, getBaggingSeed, getEarlyStoppingRound,
getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numWorkers)
val networkParams = NetworkParams(nodes.toMap, getDefaultListenPort, inetAddress, port)
val lightGBMBooster = df
Expand Down
12 changes: 7 additions & 5 deletions src/lightgbm/src/main/scala/TrainParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ abstract class TrainParams extends Serializable {
def maxDepth: Int
def minSumHessianInLeaf: Double
def numMachines: Int
def objective: String

override def toString(): String = {
s"is_pre_partition=True boosting_type=gbdt tree_learner=$parallelism num_iterations=$numIterations " +
s"learning_rate=$learningRate num_leaves=$numLeaves " +
s"max_bin=$maxBin bagging_fraction=$baggingFraction bagging_freq=$baggingFreq " +
s"bagging_seed=$baggingSeed early_stopping_round=$earlyStoppingRound " +
s"feature_fraction=$featureFraction max_depth=$maxDepth min_sum_hessian_in_leaf=$minSumHessianInLeaf " +
s"num_machines=$numMachines"
s"num_machines=$numMachines objective=$objective"
}
}

Expand All @@ -35,22 +36,23 @@ abstract class TrainParams extends Serializable {
case class ClassifierTrainParams(val parallelism: String, val numIterations: Int, val learningRate: Double,
val numLeaves: Int, val maxBin: Int, val baggingFraction: Double, val baggingFreq: Int,
val baggingSeed: Int, val earlyStoppingRound: Int, val featureFraction: Double,
val maxDepth: Int, val minSumHessianInLeaf: Double, val numMachines: Int)
val maxDepth: Int, val minSumHessianInLeaf: Double,
val numMachines: Int, val objective: String)
extends TrainParams {
override def toString(): String = {
s"objective=binary metric=binary_logloss,auc ${super.toString()}"
s"metric=binary_logloss,auc ${super.toString()}"
}
}

/** Defines the Booster parameters passed to the LightGBM regressor.
*/
case class RegressorTrainParams(val parallelism: String, val numIterations: Int, val learningRate: Double,
val numLeaves: Int, val application: String, val alpha: Double, val maxBin: Int,
val numLeaves: Int, val objective: String, val alpha: Double, val maxBin: Int,
val baggingFraction: Double, val baggingFreq: Int,
val baggingSeed: Int, val earlyStoppingRound: Int, val featureFraction: Double,
val maxDepth: Int, val minSumHessianInLeaf: Double, val numMachines: Int)
extends TrainParams {
override def toString(): String = {
s"objective=$application alpha=$alpha ${super.toString()}"
s"alpha=$alpha ${super.toString()}"
}
}
8 changes: 7 additions & 1 deletion src/lightgbm/src/test/scala/VerifyLightGBMClassifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
lazy val moduleName = "lightgbm"
var portIndex = 30
val numPartitions = 2
val objective = "binary"

// TODO: Need to add multiclass param with objective function
// verifyLearnerOnMulticlassCsvFile("abalone.csv", "Rings", 2)
Expand Down Expand Up @@ -51,6 +52,7 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
.setDefaultListenPort(LightGBMConstants.defaultLocalListenPort + portIndex)
.setNumLeaves(5)
.setNumIterations(10)
.setObjective(objective)

val paramGrid = new ParamGridBuilder()
.addGrid(lgbm.numLeaves, Array(5, 10))
Expand Down Expand Up @@ -102,6 +104,7 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
.setDefaultListenPort(LightGBMConstants.defaultLocalListenPort + portIndex)
.setNumLeaves(5)
.setNumIterations(10)
.setObjective(objective)
.fit(trainData)
val scoredResult = model.transform(trainData).drop(featuresColumn)
val splitFeatureImportances = model.getFeatureImportances("split")
Expand Down Expand Up @@ -134,6 +137,7 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
.setDefaultListenPort(LightGBMConstants.defaultLocalListenPort + portIndex)
.setNumLeaves(5)
.setNumIterations(10)
.setObjective(objective)
.fit(trainData)
val scoredResult = model.transform(trainData).drop(featuresColumn)
val splitFeatureImportances = model.getFeatureImportances("split")
Expand Down Expand Up @@ -161,7 +165,8 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
new LightGBMClassifier()
.setLabelCol(labelCol)
.setFeaturesCol(featuresCol)
.setNumLeaves(5),
.setNumLeaves(5)
.setObjective(objective),
train))
}

Expand All @@ -183,6 +188,7 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
.setDefaultListenPort(LightGBMConstants.defaultLocalListenPort + portIndex)
.setNumLeaves(5)
.setNumIterations(10)
.setObjective(objective)
.fit(featurizer.transform(dataset))

val targetDir: Path = Paths.get(getClass.getResource("/").toURI)
Expand Down

0 comments on commit 7dfcc96

Please sign in to comment.