From 53091a19353c361ce0762e896769bef08f7f73eb Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 13 Dec 2016 16:47:31 -0800 Subject: [PATCH] Removes labels from tree data generation (#82) * changes * removes labels * reset scala version * adding metadata * bumping spark release --- build.sbt | 2 +- .../sql/perf/mllib/BenchmarkAlgorithm.scala | 18 +++++++++++++++++- .../DecisionTreeClassification.scala | 2 +- .../classification/GBTClassification.scala | 2 +- src/main/scala/configs/mllib-small.yaml | 2 +- .../scala/org/apache/spark/ml/TreeUtils.scala | 13 +------------ 6 files changed, 22 insertions(+), 17 deletions(-) diff --git a/build.sbt b/build.sbt index 609635dd..332bc7c7 100644 --- a/build.sbt +++ b/build.sbt @@ -14,7 +14,7 @@ sparkPackageName := "databricks/spark-sql-perf" // All Spark Packages need a license licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0")) -sparkVersion := "2.0.0" +sparkVersion := "2.0.1" sparkComponents ++= Seq("sql", "hive", "mllib") diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala index a17e1937..858f911f 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala @@ -2,6 +2,7 @@ package com.databricks.spark.sql.perf.mllib import com.typesafe.scalalogging.slf4j.{LazyLogging => Logging} +import org.apache.spark.ml.attribute.{NominalAttribute, NumericAttribute} import org.apache.spark.ml.{Estimator, Transformer} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.sql._ @@ -76,7 +77,22 @@ trait TrainingSetFromTransformer { final override def trainingDataSet(ctx: MLBenchContext): DataFrame = { val initial = initialData(ctx) val model = trueModel(ctx) - model.transform(initial).select(col("features"), col("prediction").as("label")) + val fCol = col("features") + // Special case for the trees: we need to set the number of labels. + // numClasses is set? We will add the number of classes to the final column. + val lCol = ctx.params.numClasses match { + case Some(numClasses) => + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName("label") + } else { + NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + col("prediction").as("label", labelMetadata) + case None => + col("prediction").as("label") + } + model.transform(initial).select(fCol, lCol) } } diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala index 45ce7f8a..47cf4c42 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala @@ -20,7 +20,7 @@ abstract class TreeOrForestClassification extends BenchmarkAlgorithm val featureArity: Array[Int] = getFeatureArity(ctx) val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, featureArity) - TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity) + TreeUtils.setMetadata(data, "features", featureArity) } override protected def trueModel(ctx: MLBenchContext): Transformer = { diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala index dfd172d0..547a0502 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala @@ -20,7 +20,7 @@ object GBTClassification extends BenchmarkAlgorithm val featureArity: Array[Int] = getFeatureArity(ctx) val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, featureArity) - TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity) + TreeUtils.setMetadata(data, "features", featureArity) } override protected def trueModel(ctx: MLBenchContext): Transformer = { diff --git a/src/main/scala/configs/mllib-small.yaml b/src/main/scala/configs/mllib-small.yaml index 6beeedde..3e574b48 100644 --- a/src/main/scala/configs/mllib-small.yaml +++ b/src/main/scala/configs/mllib-small.yaml @@ -60,7 +60,7 @@ benchmarks: numExamples: 100 numTestExamples: 10 depth: 3 - numClasses: 4 + numClasses: 2 numFeatures: 5 maxIter: 3 - name: regression.LinearRegression diff --git a/src/main/scala/org/apache/spark/ml/TreeUtils.scala b/src/main/scala/org/apache/spark/ml/TreeUtils.scala index 1bd3c127..badef4fd 100644 --- a/src/main/scala/org/apache/spark/ml/TreeUtils.scala +++ b/src/main/scala/org/apache/spark/ml/TreeUtils.scala @@ -9,8 +9,6 @@ object TreeUtils { * * @param data Dataset. Categorical features and labels must already have 0-based indices. * This must be non-empty. - * @param labelColName Name of the label column on which to set the metadata. - * @param numClasses Number of classes label can take. If 0, mark as continuous. * @param featuresColName Name of the features column * @param featureArity Array of length numFeatures, where 0 indicates continuous feature and * value > 0 indicates a categorical feature of that arity. @@ -18,16 +16,8 @@ object TreeUtils { */ def setMetadata( data: DataFrame, - labelColName: String, - numClasses: Int, featuresColName: String, featureArity: Array[Int]): DataFrame = { - val labelAttribute = if (numClasses == 0) { - NumericAttribute.defaultAttr.withName(labelColName) - } else { - NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses) - } - val labelMetadata = labelAttribute.toMetadata() val featuresAttributes = featureArity.zipWithIndex.map { case (arity: Int, feature: Int) => if (arity > 0) { NominalAttribute.defaultAttr.withIndex(feature).withNumValues(arity) @@ -36,7 +26,6 @@ object TreeUtils { } } val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata() - data.select(data(featuresColName).as(featuresColName, featuresMetadata), - data(labelColName).as(labelColName, labelMetadata)) + data.select(data(featuresColName).as(featuresColName, featuresMetadata)) } }