Skip to content

Commit

Permalink
[jvm-packages] handle NaN as missing value explicitly (#4309)
Browse files Browse the repository at this point in the history
* handle nan

* handle nan explicitly

* make code better and handle sparse vector in spark

* Update XGBoostGeneralSuite.scala
  • Loading branch information
CodingCat committed Mar 30, 2019
1 parent 7ea5b77 commit ad4de0d
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 27 deletions.
Expand Up @@ -70,30 +70,53 @@ private[spark] case class XGBLabeledPointGroup(
object XGBoost extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")

private[spark] def removeMissingValues(
private def verifyMissingSetting(xgbLabelPoints: Iterator[XGBLabeledPoint], missing: Float):
Iterator[XGBLabeledPoint] = {
if (missing != 0.0f) {
xgbLabelPoints.map(labeledPoint => {
if (labeledPoint.indices != null) {
throw new RuntimeException("you can only specify missing value as 0.0 when you have" +
" SparseVector as your feature format")
}
labeledPoint
})
} else {
xgbLabelPoints
}
}

private def removeMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float,
keepCondition: Float => Boolean): Iterator[XGBLabeledPoint] = {
xgbLabelPoints.map { labeledPoint =>
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
for ((value, i) <- labeledPoint.values.zipWithIndex if keepCondition(value)) {
indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
valuesBuilder += value
}
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
}
}

private[spark] def processMissingValues(
xgbLabelPoints: Iterator[XGBLabeledPoint],
missing: Float): Iterator[XGBLabeledPoint] = {
if (!missing.isNaN) {
xgbLabelPoints.map { labeledPoint =>
val indicesBuilder = new mutable.ArrayBuilder.ofInt()
val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
for ((value, i) <- labeledPoint.values.zipWithIndex if value != missing) {
indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
valuesBuilder += value
}
labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
}
removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing),
missing, (v: Float) => v != missing)
} else {
xgbLabelPoints
removeMissingValues(xgbLabelPoints, missing, (v: Float) => !v.isNaN)
}
}

private def removeMissingValuesWithGroup(
private def processMissingValuesWithGroup(
xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
missing: Float): Iterator[Array[XGBLabeledPoint]] = {
if (!missing.isNaN) {
xgbLabelPointGroups.map {
labeledPoints => XGBoost.removeMissingValues(labeledPoints.iterator, missing).toArray
labeledPoints => XGBoost.processMissingValues(labeledPoints.iterator, missing).toArray
}
} else {
xgbLabelPointGroups
Expand Down Expand Up @@ -310,7 +333,7 @@ object XGBoost extends Serializable {
if (evalSetsMap.isEmpty) {
trainingData.mapPartitions(labeledPoints => {
val watches = Watches.buildWatches(params,
removeMissingValues(labeledPoints, missing),
processMissingValues(labeledPoints, missing),
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
obj, eval, prevBooster)
Expand All @@ -320,7 +343,7 @@ object XGBoost extends Serializable {
nameAndLabeledPointSets =>
val watches = Watches.buildWatches(
nameAndLabeledPointSets.map {
case (name, iter) => (name, removeMissingValues(iter, missing))},
case (name, iter) => (name, processMissingValues(iter, missing))},
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound,
obj, eval, prevBooster)
Expand All @@ -340,7 +363,7 @@ object XGBoost extends Serializable {
if (evalSetsMap.isEmpty) {
trainingData.mapPartitions(labeledPointGroups => {
val watches = Watches.buildWatchesWithGroup(params,
removeMissingValuesWithGroup(labeledPointGroups, missing),
processMissingValuesWithGroup(labeledPointGroups, missing),
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster)
}).cache()
Expand All @@ -349,7 +372,7 @@ object XGBoost extends Serializable {
labeledPointGroupSets => {
val watches = Watches.buildWatchesWithGroup(
labeledPointGroupSets.map {
case (name, iter) => (name, removeMissingValuesWithGroup(iter, missing))
case (name, iter) => (name, processMissingValuesWithGroup(iter, missing))
},
getCacheDirName(useExternalMemory))
buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval,
Expand Down
Expand Up @@ -256,7 +256,7 @@ class XGBoostClassificationModel private[ml](
*/
override def predict(features: Vector): Double = {
import DataUtils._
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
val dm = new DMatrix(XGBoost.processMissingValues(Iterator(features.asXGB), $(missing)))
val probability = _booster.predict(data = dm)(0).map(_.toDouble)
if (numClasses == 2) {
math.round(probability(0))
Expand Down Expand Up @@ -303,7 +303,7 @@ class XGBoostClassificationModel private[ml](
}
}
val dm = new DMatrix(
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
cacheInfo)
try {
val Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr) =
Expand Down
Expand Up @@ -247,7 +247,7 @@ class XGBoostRegressionModel private[ml] (
*/
override def predict(features: Vector): Double = {
import DataUtils._
val dm = new DMatrix(XGBoost.removeMissingValues(Iterator(features.asXGB), $(missing)))
val dm = new DMatrix(XGBoost.processMissingValues(Iterator(features.asXGB), $(missing)))
_booster.predict(data = dm)(0)(0)
}

Expand Down Expand Up @@ -275,7 +275,7 @@ class XGBoostRegressionModel private[ml] (
}
}
val dm = new DMatrix(
XGBoost.removeMissingValues(featuresIterator.map(_.asXGB), $(missing)),
XGBoost.processMissingValues(featuresIterator.map(_.asXGB), $(missing)),
cacheInfo)
try {
val Array(originalPredictionItr, predLeafItr, predContribItr) =
Expand Down
Expand Up @@ -33,6 +33,8 @@ import scala.util.Random

import ml.dmlc.xgboost4j.java.Rabit

import org.apache.spark.ml.feature.VectorAssembler

class XGBoostGeneralSuite extends FunSuite with PerTest {

test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
Expand Down Expand Up @@ -227,26 +229,45 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
def buildDenseDataFrame(): DataFrame = {
val numRows = 100
val numCols = 5

val data = (0 until numRows).map { x =>
val label = Random.nextInt(2)
val values = Array.tabulate[Double](numCols) { c =>
if (c == numCols - 1) -0.1 else Random.nextDouble
if (c == numCols - 1) 0 else Random.nextDouble
}

(label, Vectors.dense(values))
}

ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
}

val denseDF = buildDenseDataFrame().repartition(4)
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "missing" -> -0.1f, "num_workers" -> numWorkers).toMap
"objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
val model = new XGBoostClassifier(paramMap).fit(denseDF)
model.transform(denseDF).collect()
}

test("handle Float.NaN as missing value correctly") {
val spark = ss
import spark.implicits._
val testDF = Seq(
(1.0f, 0.0f, Float.NaN, 1.0),
(1.0f, 0.0f, 1.0f, 1.0),
(0.0f, 1.0f, 0.0f, 0.0),
(1.0f, 0.0f, 1.0f, 1.0),
(1.0f, Float.NaN, 0.0f, 0.0),
(0.0f, 0.0f, 0.0f, 0.0),
(0.0f, 1.0f, 0.0f, 1.0),
(Float.NaN, 0.0f, 0.0f, 1.0)
).toDF("col1", "col2", "col3", "label")
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("col1", "col2", "col3"))
.setOutputCol("features")
val inputDF = vectorAssembler.transform(testDF).select("features", "label")
val paramMap = List("eta" -> "1", "max_depth" -> "2",
"objective" -> "binary:logistic", "num_workers" -> 1).toMap
val model = new XGBoostClassifier(paramMap).fit(inputDF)
model.transform(inputDF).collect()
}

test("training with spark parallelism checks disabled") {
val eval = new EvalError()
val training = buildDataFrame(Classification.train)
Expand Down

0 comments on commit ad4de0d

Please sign in to comment.