Skip to content

Commit

Permalink
fix: added multiclass init score support (#805)
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Feb 28, 2020
1 parent e745784 commit df0244c
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
(get(weightCol), Seq(DoubleType)),
(getOptGroupCol, Seq(IntegerType, LongType, StringType)),
(get(validationIndicatorCol), Seq(BooleanType)),
(get(initScoreCol), Seq(DoubleType)))
(get(initScoreCol), Seq(DoubleType, VectorType)))

colsToCheck.flatMap { case (col: Option[String], colType: Seq[DataType]) => {
if (col.isDefined) Some(col.get, colType) else None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class LightGBMDataset(val dataset: SWIGTYPE_p_void) extends AutoCloseable {
// Generate the column and add to dataset
var colArray: Option[SWIGTYPE_p_double] = None
try {
colArray = Some(lightgbmlib.new_doubleArray(numRows))
colArray = Some(lightgbmlib.new_doubleArray(field.length))
field.zipWithIndex.foreach(ri =>
lightgbmlib.doubleArray_setitem(colArray.get, ri._2, ri._1))
val colAsVoidPtr = lightgbmlib.double_to_voidp_ptr(colArray.get)
Expand Down
30 changes: 26 additions & 4 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.apache.spark.BarrierTaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.slf4j.Logger
Expand Down Expand Up @@ -57,11 +58,9 @@ private object TrainUtils extends Serializable {
val weights = rows.map(row => row.getDouble(schema.fieldIndex(col)))
datasetPtr.get.addFloatField(weights, "weight", numRows)
}
addInitScoreColumn(rows, initScoreColumn, datasetPtr, numRows, schema)
addGroupColumn(rows, groupColumn, datasetPtr, numRows, schema)
initScoreColumn.foreach { col =>
val initScores = rows.map(row => row.getDouble(schema.fieldIndex(col)))
datasetPtr.get.addDoubleField(initScores, "init_score", numRows)
}

datasetPtr
}

Expand All @@ -79,6 +78,29 @@ private object TrainUtils extends Serializable {

import CardinalityTypes._

/** Adds the optional initial score column to the native dataset as the "init_score" field.
  *
  * A vector-typed column (one score per class for every row) is flattened so that the score of
  * class `c` for row `r` lands at index `c * rows.length + r`. Any other column is read as a
  * single Double per row.
  *
  * @param rows            Rows to read the initial scores from.
  * @param initScoreColumn Name of the initial score column; nothing happens when it is None.
  * @param datasetPtr      Dataset receiving the field; must be defined when a column is given.
  * @param numRows         Row count forwarded unchanged to `addDoubleField`.
  * @param schema          Schema used to resolve the column's position and data type.
  * @throws IllegalArgumentException if a vector column holds an unsupported value or the
  *                                  vectors do not all have the same length.
  */
def addInitScoreColumn(rows: Array[Row], initScoreColumn: Option[String],
                       datasetPtr: Option[LightGBMDataset], numRows: Int, schema: StructType): Unit = {
  initScoreColumn.foreach { col =>
    // Resolve the column position once instead of once per row
    val fieldIdx = schema.fieldIndex(col)
    val initScores: Array[Double] =
      if (schema.fields(fieldIdx).dataType == VectorType) {
        // A VectorType column may hold sparse as well as dense vectors, so densify by pattern
        // match rather than hard-casting to DenseVector (which throws on sparse input)
        val scoresPerRow: Array[Array[Double]] = rows.map { row =>
          row.get(fieldIdx) match {
            case dense: DenseVector => dense.values
            case sparse: SparseVector => sparse.toArray
            case other => throw new IllegalArgumentException(
              s"Init score column '$col' must contain vectors but found: $other")
          }
        }
        val numScoreRows = scoresPerRow.length
        // An empty partition yields an empty field, matching the scalar branch below
        val numClasses = scoresPerRow.headOption.map(_.length).getOrElse(0)
        // Calculate # rows * # classes in multiclass case
        val flattened = new Array[Double](numScoreRows * numClasses)
        scoresPerRow.zipWithIndex.foreach { case (rowScores, rowIndex) =>
          require(rowScores.length == numClasses,
            s"Init score column '$col' has a vector of length ${rowScores.length} at row " +
              s"$rowIndex but expected $numClasses")
          rowScores.zipWithIndex.foreach { case (score, classIndex) =>
            // All scores of one class are stored contiguously, class after class
            flattened(classIndex * numScoreRows + rowIndex) = score
          }
        }
        flattened
      } else {
        rows.map(row => row.getDouble(fieldIdx))
      }
    datasetPtr.get.addDoubleField(initScores, "init_score", numRows)
  }
}

def addGroupColumn(rows: Array[Row], groupColumn: Option[String],
datasetPtr: Option[LightGBMDataset], numRows: Int, schema: StructType): Unit = {
groupColumn.foreach { col =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
assert(binaryEvaluator.evaluate(sdf1) < binaryEvaluator.evaluate(sdf2))
}

/** Asserts that `multiclassEvaluator` scores `sdf2` strictly higher than `sdf1`. */
def assertMulticlassImprovement(sdf1: DataFrame, sdf2: DataFrame): Unit = {
  val baselineMetric = multiclassEvaluator.evaluate(sdf1)
  val improvedMetric = multiclassEvaluator.evaluate(sdf2)
  assert(baselineMetric < improvedMetric)
}

def assertBinaryImprovement(v1: LightGBMClassifier, train1: DataFrame, test1: DataFrame,
v2: LightGBMClassifier, train2: DataFrame, test2: DataFrame
): Unit = {
Expand All @@ -257,6 +261,15 @@ class VerifyLightGBMClassifier extends Benchmarks with EstimatorFuzzing[LightGBM
assertBinaryImprovement(scoredDF1, scoredDF2)
}

test("Verify LightGBM Multiclass Classifier with vector initial score") {
  // First pass: fit and score without any initial scores
  val baselineScored = baseModel.fit(breastTissueDF).transform(breastTissueDF)

  // Copy the first pass's raw prediction column into the init score column,
  // then remove the model output columns so the frame can be scored again
  val withInitScore = baselineScored.withColumn(initScoreCol, col(rawPredCol))
  val secondPassInput = withInitScore.drop(predCol, rawPredCol, probCol, leafPredCol)

  // Second pass: same estimator, now configured to read the init score column
  val warmStartEstimator = baseModel.setInitScoreCol(initScoreCol)
  val warmStartScored = warmStartEstimator.fit(secondPassInput).transform(secondPassInput)

  assertMulticlassImprovement(baselineScored, warmStartScored)
}

test("Verify LightGBM Classifier with min gain to split parameter") {
// If the min gain to split is too high, assert AUC lower for training data (assert parameter works)
val scoredDF1 = baseModel.setMinGainToSplit(99999).fit(pimaDF).transform(pimaDF)
Expand Down

0 comments on commit df0244c

Please sign in to comment.