Skip to content

Commit

Permalink
[SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement.
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Partial revert of apache#15277 to instead sort and store input to model rather than require sorted input

## How was this patch tested?

Existing tests.

Author: Sean Owen <sowen@cloudera.com>

Closes apache#15299 from srowen/SPARK-17704.2.
  • Loading branch information
srowen committed Oct 1, 2016
1 parent af6ece3 commit b88cb63
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
Expand Up @@ -193,7 +193,7 @@ final class ChiSqSelectorModel private[ml] (

import ChiSqSelectorModel._

/** list of indices to select (filter). Must be ordered asc */
/** list of indices to select (filter). */
@Since("1.6.0")
val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures

Expand Down
Expand Up @@ -35,14 +35,15 @@ import org.apache.spark.sql.{Row, SparkSession}
/**
* Chi Squared selector model.
*
* @param selectedFeatures list of indices to select (filter). Must be ordered asc
* @param selectedFeatures list of indices to select (filter).
*/
@Since("1.3.0")
class ChiSqSelectorModel @Since("1.3.0") (
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {

require(isSorted(selectedFeatures), "Array has to be sorted asc")
private val filterIndices = selectedFeatures.sorted

@deprecated("not intended for subclasses to use", "2.1.0")
protected def isSorted(array: Array[Int]): Boolean = {
var i = 1
val len = array.length
Expand All @@ -61,17 +62,16 @@ class ChiSqSelectorModel @Since("1.3.0") (
*/
@Since("1.3.0")
override def transform(vector: Vector): Vector = {
compress(vector, selectedFeatures)
compress(vector)
}

/**
* Returns a vector with features filtered.
* Preserves the order of filtered features the same as their indices are stored.
* Might be moved to Vector as .slice
* @param features vector
* @param filterIndices indices of features to filter, must be ordered asc
*/
private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
private def compress(features: Vector): Vector = {
features match {
case SparseVector(size, indices, values) =>
val newSize = filterIndices.length
Expand Down Expand Up @@ -230,23 +230,23 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
*/
@Since("1.3.0")
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
val chiSqTestResult = Statistics.chiSqTest(data)
val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex
val features = selectorType match {
case ChiSqSelector.KBest =>
chiSqTestResult.zipWithIndex
chiSqTestResult
.sortBy { case (res, _) => -res.statistic }
.take(numTopFeatures)
case ChiSqSelector.Percentile =>
chiSqTestResult.zipWithIndex
chiSqTestResult
.sortBy { case (res, _) => -res.statistic }
.take((chiSqTestResult.length * percentile).toInt)
case ChiSqSelector.FPR =>
chiSqTestResult.zipWithIndex
.filter{ case (res, _) => res.pValue < alpha }
chiSqTestResult
.filter { case (res, _) => res.pValue < alpha }
case errorType =>
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
}
val indices = features.map { case (_, indices) => indices }.sorted
val indices = features.map { case (_, index) => index }
new ChiSqSelectorModel(indices)
}
}
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/ml/feature.py
Expand Up @@ -2705,7 +2705,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
@since("2.0.0")
def selectedFeatures(self):
"""
List of indices to select (filter). Must be ordered asc.
List of indices to select (filter).
"""
return self._call_java("selectedFeatures")

Expand Down

0 comments on commit b88cb63

Please sign in to comment.