Skip to content

Commit

Permalink
[SPARK-7259] [ML] VectorIndexer: do not copy non-ML metadata to outpu…
Browse files Browse the repository at this point in the history
…t column

Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column.  Removed ml.util.TestingUtils since VectorIndexer was the only use.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes apache#5789 from jkbradley/vector-indexer-metadata and squashes the following commits:

b28e159 [Joseph K. Bradley] Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column.  Removed ml.util.TestingUtils since VectorIndexer was the only use.
  • Loading branch information
jkbradley committed Apr 29, 2015
1 parent f8cbb0a commit b1ef6a6
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ private object VectorIndexer {
* - Continuous features (columns) are left unchanged.
* This also appends metadata to the output column, marking features as Numeric (continuous),
* Nominal (categorical), or Binary (either continuous or categorical).
* Non-ML metadata is not carried over from the input to the output column.
*
* This maintains vector sparsity.
*
Expand Down Expand Up @@ -283,34 +284,40 @@ class VectorIndexerModel private[ml] (

// TODO: Check more carefully about whether this whole class will be included in a closure.

/** Per-vector transform function */
private val transformFunc: Vector => Vector = {
val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted
val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted
val localVectorMap = categoryMaps
val f: Vector => Vector = {
case dv: DenseVector =>
val tmpv = dv.copy
localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
}
tmpv
case sv: SparseVector =>
// We use the fact that categorical value 0 is always mapped to index 0.
val tmpv = sv.copy
var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices
var k = 0 // index into non-zero elements of sparse vector
while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) {
val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx)
if (featureIndex < tmpv.indices(k)) {
catFeatureIdx += 1
} else if (featureIndex > tmpv.indices(k)) {
k += 1
} else {
tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
catFeatureIdx += 1
k += 1
val localNumFeatures = numFeatures
val f: Vector => Vector = { (v: Vector) =>
assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" +
s" $numFeatures but found length ${v.size}")
v match {
case dv: DenseVector =>
val tmpv = dv.copy
localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
}
}
tmpv
tmpv
case sv: SparseVector =>
// We use the fact that categorical value 0 is always mapped to index 0.
val tmpv = sv.copy
var catFeatureIdx = 0 // index into sortedCatFeatureIndices
var k = 0 // index into non-zero elements of sparse vector
while (catFeatureIdx < sortedCatFeatureIndices.length && k < tmpv.indices.length) {
val featureIndex = sortedCatFeatureIndices(catFeatureIdx)
if (featureIndex < tmpv.indices(k)) {
catFeatureIdx += 1
} else if (featureIndex > tmpv.indices(k)) {
k += 1
} else {
tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
catFeatureIdx += 1
k += 1
}
}
tmpv
}
}
f
}
Expand All @@ -326,13 +333,6 @@ class VectorIndexerModel private[ml] (
val map = extractParamMap(paramMap)
val newField = prepOutputField(dataset.schema, map)
val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol)))
// For now, just check the first row of inputCol for vector length.
val firstRow = dataset.select(map(inputCol)).take(1)
if (firstRow.length != 0) {
val actualNumFeatures = firstRow(0).getAs[Vector](0).size
require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" +
s" $numFeatures but found length $actualNumFeatures")
}
dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata))
}

Expand All @@ -345,6 +345,7 @@ class VectorIndexerModel private[ml] (
s"VectorIndexerModel requires output column parameter: $outputCol")
SchemaUtils.checkColumnType(schema, map(inputCol), dataType)

// If the input metadata specifies numFeatures, compare with expected numFeatures.
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
Some(origAttrGroup.attributes.get.length)
Expand All @@ -364,7 +365,7 @@ class VectorIndexerModel private[ml] (
* Prepare the output column field, including per-feature metadata.
* @param schema Input schema
* @param map Parameter map (with this class' embedded parameter map folded in)
* @return Output column field
* @return Output column field. This field does not contain non-ML metadata.
*/
private def prepOutputField(schema: StructType, map: ParamMap): StructField = {
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
Expand All @@ -391,6 +392,6 @@ class VectorIndexerModel private[ml] (
partialFeatureAttributes
}
val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes)
newAttributeGroup.toStructField(schema(map(inputCol)).metadata)
newAttributeGroup.toStructField()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import org.scalatest.FunSuite

import org.apache.spark.SparkException
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.util.TestingUtils
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -111,8 +110,8 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
val model = vectorIndexer.fit(densePoints1) // vectors of length 3
model.transform(densePoints1) // should work
model.transform(sparsePoints1) // should work
intercept[IllegalArgumentException] {
model.transform(densePoints2)
intercept[SparkException] {
model.transform(densePoints2).collect()
println("Did not throw error when fit, transform were called on vectors of different lengths")
}
intercept[SparkException] {
Expand Down Expand Up @@ -245,8 +244,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
// TODO: Once input features marked as categorical are handled correctly, check that here.
}
}
// Check that non-ML metadata are preserved.
TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed")
}
}

Expand Down
60 changes: 0 additions & 60 deletions mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala

This file was deleted.

0 comments on commit b1ef6a6

Please sign in to comment.