From 613b90ee10d083c5eb15a39e1bfdd876a08cae31 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 30 Jun 2015 12:52:51 -0700 Subject: [PATCH] Changed GBTRegressor so it does NOT threshold the prediction --- .../spark/ml/regression/GBTRegressor.scala | 3 +-- .../ml/regression/GBTRegressorSuite.scala | 23 ++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 036e3acb07412..47c110d027d67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -172,8 +172,7 @@ final class GBTRegressionModel( // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predict(features)) - val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) - if (prediction > 0.0) 1.0 else 0.0 + blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } override def copy(extra: ParamMap): GBTRegressionModel = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 98fb3d3f5f22c..9682edcd9ba84 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** @@ -67,6 +68,26 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("GBTRegressor behaves reasonably on toy data") { + val df = sqlContext.createDataFrame(Seq( + LabeledPoint(10, Vectors.dense(1, 2, 3, 4)), + LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)), + LabeledPoint(11, Vectors.dense(2, 2, 3, 4)), + LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)), + LabeledPoint(9, Vectors.dense(1, 2, 6, 4)), + LabeledPoint(-4, Vectors.dense(6, 3, 2, 2)) + )) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(2) + val model = gbt.fit(df) + val preds = model.transform(df) + val predictions = preds.select("prediction").map(_.getDouble(0)) + // Checks based on SPARK-8736 (to ensure it is not doing classification) + assert(predictions.max() > 2) + assert(predictions.min() < -1) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") {