Skip to content

Commit

Permalink
[SPARK-21481][ML] Add indexOf method in ml.feature.HashingTF
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Add indexOf method for ml.feature.HashingTF.

## How was this patch tested?

Add Unit test.

Closes apache#25250 from huaxingao/spark-21481.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
  • Loading branch information
huaxingao authored and srowen committed Jul 28, 2019
1 parent d943ee0 commit 70f82fd
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 8 deletions.
36 changes: 28 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@

package org.apache.spark.ml.feature

import scala.collection.mutable

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
import org.apache.spark.mllib.feature.HashingTF.murmur3Hash
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StructType}
import org.apache.spark.util.Utils

/**
* Maps a sequence of terms to their term frequencies using the hashing trick.
Expand All @@ -41,6 +44,8 @@ import org.apache.spark.sql.types.{ArrayType, StructType}
class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {

private[this] val hashFunc: Any => Int = murmur3Hash

@Since("1.2.0")
def this() = this(Identifiable.randomUID("hashingTF"))

Expand Down Expand Up @@ -94,15 +99,22 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema)

val hashingTF = new OldHashingTF($(numFeatures)).setBinary($(binary))
val func = (terms: Seq[_]) => {
val seq = hashingTF.transformImpl(terms)
Vectors.sparse(hashingTF.numFeatures, seq)
val hashUDF = udf { terms: Seq[_] =>
val numOfFeatures = $(numFeatures)
val isBinary = $(binary)
val termFrequencies = mutable.HashMap.empty[Int, Double].withDefaultValue(0.0)
terms.foreach { term =>
val i = indexOf(term)
if (isBinary) {
termFrequencies(i) = 1.0
} else {
termFrequencies(i) += 1.0
}
}
Vectors.sparse($(numFeatures), termFrequencies.toSeq)
}

val transformer = udf(func)
dataset.withColumn($(outputCol), transformer(col($(inputCol))),
dataset.withColumn($(outputCol), hashUDF(col($(inputCol))),
outputSchema($(outputCol)).metadata)
}

Expand All @@ -115,6 +127,14 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
SchemaUtils.appendColumn(schema, attrGroup.toStructField())
}

/**
* Returns the index of the input term.
*/
@Since("3.0.0")
def indexOf(term: Any): Int = {
Utils.nonNegativeMod(hashFunc(term), $(numFeatures))
}

@Since("1.4.1")
override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@ class HashingTFSuite extends MLTest with DefaultReadWriteTest {
assert(features ~== expected absTol 1e-14)
}

test("indexOf method") {
val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words")
val n = 100
val hashingTF = new HashingTF()
.setInputCol("words")
.setOutputCol("features")
.setNumFeatures(n)
val mLlibHashingTF = new MLlibHashingTF(n)
assert(hashingTF.indexOf("a") === mLlibHashingTF.indexOf("a"))
assert(hashingTF.indexOf("b") === mLlibHashingTF.indexOf("b"))
assert(hashingTF.indexOf("c") === mLlibHashingTF.indexOf("c"))
assert(hashingTF.indexOf("d") === mLlibHashingTF.indexOf("d"))
}

test("read/write") {
val t = new HashingTF()
.setInputCol("myInputCol")
Expand Down
10 changes: 10 additions & 0 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,8 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
>>> loadedHashingTF = HashingTF.load(hashingTFPath)
>>> loadedHashingTF.getNumFeatures() == hashingTF.getNumFeatures()
True
>>> hashingTF.indexOf("b")
1
.. versionadded:: 1.3.0
"""
Expand Down Expand Up @@ -956,6 +958,14 @@ def getBinary(self):
"""
return self.getOrDefault(self.binary)

@since("3.0.0")
def indexOf(self, term):
"""
Returns the index of the input term.
"""
self._transfer_params_to_java()
return self._java_obj.indexOf(term)


@inherit_doc
class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
Expand Down

0 comments on commit 70f82fd

Please sign in to comment.