Skip to content

Commit

Permalink
Removes labels from tree data generation (#82)
Browse files Browse the repository at this point in the history
* changes

* removes labels

* reset scala version

* adding metadata

* bumping spark release
  • Loading branch information
thunterdb committed Dec 14, 2016
1 parent 685c50d commit 53091a1
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 17 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Expand Up @@ -14,7 +14,7 @@ sparkPackageName := "databricks/spark-sql-perf"
// All Spark Packages need a license
licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0"))

sparkVersion := "2.0.0"
sparkVersion := "2.0.1"

sparkComponents ++= Seq("sql", "hive", "mllib")

Expand Down
Expand Up @@ -2,6 +2,7 @@ package com.databricks.spark.sql.perf.mllib

import com.typesafe.scalalogging.slf4j.{LazyLogging => Logging}

import org.apache.spark.ml.attribute.{NominalAttribute, NumericAttribute}
import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.sql._
Expand Down Expand Up @@ -76,7 +77,22 @@ trait TrainingSetFromTransformer {
final override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
val initial = initialData(ctx)
val model = trueModel(ctx)
model.transform(initial).select(col("features"), col("prediction").as("label"))
val fCol = col("features")
// Special case for the trees: we need to set the number of labels.
// numClasses is set? We will add the number of classes to the final column.
val lCol = ctx.params.numClasses match {
case Some(numClasses) =>
val labelAttribute = if (numClasses == 0) {
NumericAttribute.defaultAttr.withName("label")
} else {
NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
}
val labelMetadata = labelAttribute.toMetadata()
col("prediction").as("label", labelMetadata)
case None =>
col("prediction").as("label")
}
model.transform(initial).select(fCol, lCol)
}
}

Expand Down
Expand Up @@ -20,7 +20,7 @@ abstract class TreeOrForestClassification extends BenchmarkAlgorithm
val featureArity: Array[Int] = getFeatureArity(ctx)
val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
ctx.seed(), numPartitions, featureArity)
TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity)
TreeUtils.setMetadata(data, "features", featureArity)
}

override protected def trueModel(ctx: MLBenchContext): Transformer = {
Expand Down
Expand Up @@ -20,7 +20,7 @@ object GBTClassification extends BenchmarkAlgorithm
val featureArity: Array[Int] = getFeatureArity(ctx)
val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
ctx.seed(), numPartitions, featureArity)
TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity)
TreeUtils.setMetadata(data, "features", featureArity)
}

override protected def trueModel(ctx: MLBenchContext): Transformer = {
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/configs/mllib-small.yaml
Expand Up @@ -60,7 +60,7 @@ benchmarks:
numExamples: 100
numTestExamples: 10
depth: 3
numClasses: 4
numClasses: 2
numFeatures: 5
maxIter: 3
- name: regression.LinearRegression
Expand Down
13 changes: 1 addition & 12 deletions src/main/scala/org/apache/spark/ml/TreeUtils.scala
Expand Up @@ -9,25 +9,15 @@ object TreeUtils {
*
* @param data Dataset. Categorical features and labels must already have 0-based indices.
* This must be non-empty.
* @param labelColName Name of the label column on which to set the metadata.
* @param numClasses Number of classes label can take. If 0, mark as continuous.
* @param featuresColName Name of the features column
* @param featureArity Array of length numFeatures, where 0 indicates continuous feature and
* value > 0 indicates a categorical feature of that arity.
* @return DataFrame with metadata
*/
def setMetadata(
data: DataFrame,
labelColName: String,
numClasses: Int,
featuresColName: String,
featureArity: Array[Int]): DataFrame = {
val labelAttribute = if (numClasses == 0) {
NumericAttribute.defaultAttr.withName(labelColName)
} else {
NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses)
}
val labelMetadata = labelAttribute.toMetadata()
val featuresAttributes = featureArity.zipWithIndex.map { case (arity: Int, feature: Int) =>
if (arity > 0) {
NominalAttribute.defaultAttr.withIndex(feature).withNumValues(arity)
Expand All @@ -36,7 +26,6 @@ object TreeUtils {
}
}
val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata()
data.select(data(featuresColName).as(featuresColName, featuresMetadata),
data(labelColName).as(labelColName, labelMetadata))
data.select(data(featuresColName).as(featuresColName, featuresMetadata))
}
}

0 comments on commit 53091a1

Please sign in to comment.