forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-6529] [ML] Add Word2Vec transformer
See JIRA issue [here](https://issues.apache.org/jira/browse/SPARK-6529). There are some notes: 1. I add `learningRate` in sharedParams since it is a common parameter for ML algorithms. 2. We will not support transform of finding synonyms from a `Vector`, which will support in further JIRA issues. 3. Word2Vec is different with other ML models that its training set and transformed set are different. Its training set is an `RDD[Iterable[String]]` which represents documents, but the transformed set we want is an `RDD[String]` that represents unique words. So you have to switch your `inputCol` in these two stages. Author: Xusen Yin <yinxusen@gmail.com> Closes apache#5596 from yinxusen/SPARK-6529 and squashes the following commits: ee2b37a [Xusen Yin] merge with former HEAD 4945462 [Xusen Yin] merge with apache#5626 3bc2cbd [Xusen Yin] change foldLeft to for loop and use blas 5dd4ee7 [Xusen Yin] fix scala style 743e0d5 [Xusen Yin] fix comments and code style 04c48e9 [Xusen Yin] ensure the functionality a190f2c [Xusen Yin] fix code style and refine the transform function of word2vec 02848fa [Xusen Yin] refine comments 34a55c0 [Xusen Yin] fix errors 109d124 [Xusen Yin] add test suite and pass it 04dde06 [Xusen Yin] add shared params c594095 [Xusen Yin] add word2vec transformer 23d77fa [Xusen Yin] merge with apache#5626 e8cfaf7 [Xusen Yin] fix conflict with master 66e7bd3 [Xusen Yin] change foldLeft to for loop and use blas 566ec20 [Xusen Yin] fix scala style b54399f [Xusen Yin] fix comments and code style 1211e86 [Xusen Yin] ensure the functionality 6b97ec8 [Xusen Yin] fix code style and refine the transform function of word2vec 7cde18f [Xusen Yin] rm sharedParams 618abd0 [Xusen Yin] refine comments e29680a [Xusen Yin] fix errors fe3afe9 [Xusen Yin] add test suite and pass it 02767fb [Xusen Yin] add shared params 6a514f1 [Xusen Yin] add word2vec transformer
- Loading branch information
Showing
4 changed files
with
267 additions
and
1 deletion.
There are no files selected for viewing
185 changes: 185 additions & 0 deletions
185
mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.feature | ||
|
||
import org.apache.spark.annotation.AlphaComponent | ||
import org.apache.spark.ml.param._ | ||
import org.apache.spark.ml.param.shared._ | ||
import org.apache.spark.ml.util.SchemaUtils | ||
import org.apache.spark.ml.{Estimator, Model} | ||
import org.apache.spark.mllib.feature | ||
import org.apache.spark.mllib.linalg.BLAS._ | ||
import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} | ||
import org.apache.spark.sql.functions._ | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.sql.{DataFrame, Row} | ||
|
||
/** | ||
* Params for [[Word2Vec]] and [[Word2VecModel]]. | ||
*/ | ||
private[feature] trait Word2VecBase extends Params | ||
with HasInputCol with HasOutputCol with HasMaxIter with HasStepSize with HasSeed { | ||
|
||
/** | ||
* The dimension of the code that you want to transform from words. | ||
*/ | ||
final val vectorSize = new IntParam( | ||
this, "vectorSize", "the dimension of codes after transforming from words") | ||
setDefault(vectorSize -> 100) | ||
|
||
/** @group getParam */ | ||
def getVectorSize: Int = getOrDefault(vectorSize) | ||
|
||
/** | ||
* Number of partitions for sentences of words. | ||
*/ | ||
final val numPartitions = new IntParam( | ||
this, "numPartitions", "number of partitions for sentences of words") | ||
setDefault(numPartitions -> 1) | ||
|
||
/** @group getParam */ | ||
def getNumPartitions: Int = getOrDefault(numPartitions) | ||
|
||
/** | ||
* The minimum number of times a token must appear to be included in the word2vec model's | ||
* vocabulary. | ||
*/ | ||
final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + | ||
"appear to be included in the word2vec model's vocabulary") | ||
setDefault(minCount -> 5) | ||
|
||
/** @group getParam */ | ||
def getMinCount: Int = getOrDefault(minCount) | ||
|
||
setDefault(stepSize -> 0.025) | ||
setDefault(maxIter -> 1) | ||
setDefault(seed -> 42L) | ||
|
||
/** | ||
* Validate and transform the input schema. | ||
*/ | ||
protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { | ||
val map = extractParamMap(paramMap) | ||
SchemaUtils.checkColumnType(schema, map(inputCol), new ArrayType(StringType, true)) | ||
SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT) | ||
} | ||
} | ||
|
||
/** | ||
* :: AlphaComponent :: | ||
* Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further | ||
* natural language processing or machine learning process. | ||
*/ | ||
@AlphaComponent | ||
final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase { | ||
|
||
/** @group setParam */ | ||
def setInputCol(value: String): this.type = set(inputCol, value) | ||
|
||
/** @group setParam */ | ||
def setOutputCol(value: String): this.type = set(outputCol, value) | ||
|
||
/** @group setParam */ | ||
def setVectorSize(value: Int): this.type = set(vectorSize, value) | ||
|
||
/** @group setParam */ | ||
def setStepSize(value: Double): this.type = set(stepSize, value) | ||
|
||
/** @group setParam */ | ||
def setNumPartitions(value: Int): this.type = set(numPartitions, value) | ||
|
||
/** @group setParam */ | ||
def setMaxIter(value: Int): this.type = set(maxIter, value) | ||
|
||
/** @group setParam */ | ||
def setSeed(value: Long): this.type = set(seed, value) | ||
|
||
/** @group setParam */ | ||
def setMinCount(value: Int): this.type = set(minCount, value) | ||
|
||
override def fit(dataset: DataFrame, paramMap: ParamMap): Word2VecModel = { | ||
transformSchema(dataset.schema, paramMap, logging = true) | ||
val map = extractParamMap(paramMap) | ||
val input = dataset.select(map(inputCol)).map { case Row(v: Seq[String]) => v } | ||
val wordVectors = new feature.Word2Vec() | ||
.setLearningRate(map(stepSize)) | ||
.setMinCount(map(minCount)) | ||
.setNumIterations(map(maxIter)) | ||
.setNumPartitions(map(numPartitions)) | ||
.setSeed(map(seed)) | ||
.setVectorSize(map(vectorSize)) | ||
.fit(input) | ||
val model = new Word2VecModel(this, map, wordVectors) | ||
Params.inheritValues(map, this, model) | ||
model | ||
} | ||
|
||
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { | ||
validateAndTransformSchema(schema, paramMap) | ||
} | ||
} | ||
|
||
/** | ||
* :: AlphaComponent :: | ||
* Model fitted by [[Word2Vec]]. | ||
*/ | ||
@AlphaComponent | ||
class Word2VecModel private[ml] ( | ||
override val parent: Word2Vec, | ||
override val fittingParamMap: ParamMap, | ||
wordVectors: feature.Word2VecModel) | ||
extends Model[Word2VecModel] with Word2VecBase { | ||
|
||
/** @group setParam */ | ||
def setInputCol(value: String): this.type = set(inputCol, value) | ||
|
||
/** @group setParam */ | ||
def setOutputCol(value: String): this.type = set(outputCol, value) | ||
|
||
/** | ||
* Transform a sentence column to a vector column to represent the whole sentence. The transform | ||
* is performed by averaging all word vectors it contains. | ||
*/ | ||
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { | ||
transformSchema(dataset.schema, paramMap, logging = true) | ||
val map = extractParamMap(paramMap) | ||
val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors) | ||
val word2Vec = udf { sentence: Seq[String] => | ||
if (sentence.size == 0) { | ||
Vectors.sparse(map(vectorSize), Array.empty[Int], Array.empty[Double]) | ||
} else { | ||
val cum = Vectors.zeros(map(vectorSize)) | ||
val model = bWordVectors.value.getVectors | ||
for (word <- sentence) { | ||
if (model.contains(word)) { | ||
axpy(1.0, bWordVectors.value.transform(word), cum) | ||
} else { | ||
// pass words which not belong to model | ||
} | ||
} | ||
scal(1.0 / sentence.size, cum) | ||
cum | ||
} | ||
} | ||
dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol)))) | ||
} | ||
|
||
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { | ||
validateAndTransformSchema(schema, paramMap) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
63 changes: 63 additions & 0 deletions
63
mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.feature | ||
|
||
import org.scalatest.FunSuite | ||
|
||
import org.apache.spark.mllib.linalg.{Vector, Vectors} | ||
import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
import org.apache.spark.mllib.util.TestingUtils._ | ||
import org.apache.spark.sql.{Row, SQLContext} | ||
|
||
class Word2VecSuite extends FunSuite with MLlibTestSparkContext { | ||
|
||
test("Word2Vec") { | ||
val sqlContext = new SQLContext(sc) | ||
import sqlContext.implicits._ | ||
|
||
val sentence = "a b " * 100 + "a c " * 10 | ||
val numOfWords = sentence.split(" ").size | ||
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) | ||
|
||
val codes = Map( | ||
"a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451), | ||
"b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342), | ||
"c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351) | ||
) | ||
|
||
val expected = doc.map { sentence => | ||
Vectors.dense(sentence.map(codes.apply).reduce((word1, word2) => | ||
word1.zip(word2).map { case (v1, v2) => v1 + v2 } | ||
).map(_ / numOfWords)) | ||
} | ||
|
||
val docDF = doc.zip(expected).toDF("text", "expected") | ||
|
||
val model = new Word2Vec() | ||
.setVectorSize(3) | ||
.setInputCol("text") | ||
.setOutputCol("result") | ||
.fit(docDF) | ||
|
||
model.transform(docDF).select("result", "expected").collect().foreach { | ||
case Row(vector1: Vector, vector2: Vector) => | ||
assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") | ||
} | ||
} | ||
} | ||
|