Permalink
Browse files

refactor duplicate evaluation implementation (#1852)

  • Loading branch information...
1 parent 2b6aa77 · commit d9584ab82e888de78e1d33829d6de5689d46cd54 · @fromradio committed with CodingCat on Dec 9, 2016
@@ -82,15 +82,6 @@ abstract class XGBoostModel(protected var _booster: Booster)
def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait = null,
iter: Int = -1, useExternalCache: Boolean = false): String = {
require(evalFunc != null || iter != -1, "you have to specify the value of either eval or iter")
- if (evalFunc == null) {
- eval(evalDataset, evalName, iter)
- } else {
- eval(evalDataset, evalName, evalFunc)
- }
- }
-
- // TODO: refactor to remove duplicate code in two variations of eval()
- private def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, iter: Int): String = {
val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
val appName = evalDataset.context.appName
@@ -109,43 +100,19 @@ abstract class XGBoostModel(protected var _booster: Booster)
}
import DataUtils._
val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
- val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
- val Array(evName, predNumeric) = predStr.split(":")
- Rabit.shutdown()
- Iterator(Some(evName, predNumeric.toFloat))
- } else {
- Iterator(None)
- }
- }.filter(_.isDefined).collect()
- val evalPrefix = allEvalMetrics.map(_.get._1).head
- val evalMetricMean = allEvalMetrics.map(_.get._2).sum / allEvalMetrics.length
- s"$evalPrefix = $evalMetricMean"
- }
-
- private def eval(evalDataset: RDD[MLLabeledPoint], evalName: String, evalFunc: EvalTrait):
- String = {
- require(evalFunc != null, "you have to specify the value of either eval or iter")
- val broadcastBooster = evalDataset.sparkContext.broadcast(_booster)
- val broadcastUseExternalCache = evalDataset.sparkContext.broadcast($(useExternalMemory))
- val appName = evalDataset.context.appName
- val allEvalMetrics = evalDataset.mapPartitions {
- labeledPointsPartition =>
- if (labeledPointsPartition.hasNext) {
- val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
- Rabit.init(rabitEnv.asJava)
- val cacheFileName = {
- if (broadcastUseExternalCache.value) {
- s"$appName-${TaskContext.get().stageId()}-$evalName" +
- s"-deval_cache-${TaskContext.getPartitionId()}"
- } else {
- null
+ (evalFunc, iter) match {
+ case (null, _) => {
+ val predStr = broadcastBooster.value.evalSet(Array(dMatrix), Array(evalName), iter)
+ val Array(evName, predNumeric) = predStr.split(":")
+ Rabit.shutdown()
+ Iterator(Some(evName, predNumeric.toFloat))
+ }
+ case _ => {
+ val predictions = broadcastBooster.value.predict(dMatrix)
+ Rabit.shutdown()
+ Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
}
}
- import DataUtils._
- val dMatrix = new DMatrix(labeledPointsPartition, cacheFileName)
- val predictions = broadcastBooster.value.predict(dMatrix)
- Rabit.shutdown()
- Iterator(Some((evalName, evalFunc.eval(predictions, dMatrix))))
} else {
Iterator(None)
}

0 comments on commit d9584ab

Please sign in to comment.