forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-8455] [ML] Implement n-gram feature transformer
Implementation of n-gram feature transformer for ML. Author: Feynman Liang <fliang@databricks.com> Closes apache#6887 from feynmanliang/ngram-featurizer and squashes the following commits: d2c839f [Feynman Liang] Make n > input length yield empty output 9fadd36 [Feynman Liang] Add empty and corner test cases, fix names and spaces fe93873 [Feynman Liang] Implement n-gram feature transformer
- Loading branch information
Showing
2 changed files
with
163 additions
and
0 deletions.
There are no files selected for viewing
69 changes: 69 additions & 0 deletions
69
mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.feature | ||
|
||
import org.apache.spark.annotation.Experimental | ||
import org.apache.spark.ml.UnaryTransformer | ||
import org.apache.spark.ml.param._ | ||
import org.apache.spark.ml.util.Identifiable | ||
import org.apache.spark.sql.types.{ArrayType, DataType, StringType} | ||
|
||
/**
 * :: Experimental ::
 * Converts an input sequence of strings (e.g. tokens produced by a tokenizer) into a sequence
 * of n-grams, where each n-gram is a space-delimited string of n consecutive input words.
 * Null values in the input array are ignored.
 *
 * An empty input produces an empty output. Likewise, when the input holds fewer than n
 * elements, no n-grams are produced.
 */
@Experimental
class NGram(override val uid: String)
  extends UnaryTransformer[Seq[String], Seq[String], NGram] {

  def this() = this(Identifiable.randomUID("ngram"))

  /**
   * Minimum n-gram length, >= 1.
   * Default: 2, bigram features
   * @group param
   */
  val n: IntParam = new IntParam(this, "n", "number elements per n-gram (>=1)",
    ParamValidators.gtEq(1))

  setDefault(n -> 2)

  /** @group setParam */
  def setN(value: Int): this.type = set(n, value)

  /** @group getParam */
  def getN: Int = $(n)

  // Slide a window of length n over the tokens and join each window with spaces.
  // withPartial(false) drops a trailing window shorter than n, so inputs with fewer
  // than n tokens yield an empty result rather than a partial n-gram.
  override protected def createTransformFunc: Seq[String] => Seq[String] = { tokens =>
    tokens.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType.sameType(ArrayType(StringType)),
      s"Input type must be ArrayType(StringType) but got $inputType.")
  }

  // Output elements are the joined n-gram strings and are never null (containsNull = false).
  override protected def outputDataType: DataType = new ArrayType(StringType, false)
}
94 changes: 94 additions & 0 deletions
94
mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.feature | ||
|
||
import scala.beans.BeanInfo | ||
|
||
import org.apache.spark.SparkFunSuite | ||
import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
import org.apache.spark.sql.{DataFrame, Row} | ||
|
||
// Single test fixture row: the tokens fed to the NGram transformer ("inputTokens" column)
// paired with the n-grams expected back ("wantedNGrams" column). @BeanInfo generates
// JavaBean getters so Spark SQL can reflect the schema from this case class.
@BeanInfo
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
|
||
class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
  import org.apache.spark.ml.feature.NGramSuite._

  /** Builds a one-row DataFrame pairing input tokens with the n-grams we expect back. */
  private def makeDataset(input: Array[String], expected: Array[String]): DataFrame =
    sqlContext.createDataFrame(Seq(NGramTestData(input, expected)))

  test("default behavior yields bigram features") {
    val transformer = new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")
    val dataset = makeDataset(
      Array("Test", "for", "ngram", "."),
      Array("Test for", "for ngram", "ngram ."))
    testNGram(transformer, dataset)
  }

  test("NGramLength=4 yields length 4 n-grams") {
    val transformer = new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")
      .setN(4)
    val dataset = makeDataset(
      Array("a", "b", "c", "d", "e"),
      Array("a b c d", "b c d e"))
    testNGram(transformer, dataset)
  }

  test("empty input yields empty output") {
    val transformer = new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")
      .setN(4)
    // Both columns are empty: no tokens in, no n-grams out.
    val dataset = makeDataset(Array(), Array())
    testNGram(transformer, dataset)
  }

  test("input array < n yields empty output") {
    val transformer = new NGram()
      .setInputCol("inputTokens")
      .setOutputCol("nGrams")
      .setN(6)
    // Five tokens cannot form a 6-gram, so the expected output is empty.
    val dataset = makeDataset(Array("a", "b", "c", "d", "e"), Array())
    testNGram(transformer, dataset)
  }
}
|
||
object NGramSuite extends SparkFunSuite {

  /**
   * Runs the given transformer over `dataset` and asserts that every produced "nGrams"
   * value matches the fixture's "wantedNGrams" column.
   */
  def testNGram(t: NGram, dataset: DataFrame): Unit = {
    val results = t.transform(dataset)
      .select("nGrams", "wantedNGrams")
      .collect()
    for (Row(actualNGrams, wantedNGrams) <- results) {
      assert(actualNGrams === wantedNGrams)
    }
  }
}