Skip to content

Commit

Permalink
[SPARK-6529] [ML] Add Word2Vec transformer
Browse files Browse the repository at this point in the history
See JIRA issue [here](https://issues.apache.org/jira/browse/SPARK-6529).

There are some notes:

1. I add `learningRate` in sharedParams since it is a common parameter for ML algorithms.
2. We do not yet support a transform that finds synonyms from a `Vector`; that will be supported in a future JIRA issue.
3. Word2Vec differs from other ML models in that its training set and its transformed set are different. Its training set is an `RDD[Iterable[String]]` which represents documents, but the transformed set we want is an `RDD[String]` that represents unique words. So you have to switch your `inputCol` between these two stages.

Author: Xusen Yin <yinxusen@gmail.com>

Closes apache#5596 from yinxusen/SPARK-6529 and squashes the following commits:

ee2b37a [Xusen Yin] merge with former HEAD
4945462 [Xusen Yin] merge with apache#5626
3bc2cbd [Xusen Yin] change foldLeft to for loop and use blas
5dd4ee7 [Xusen Yin] fix scala style
743e0d5 [Xusen Yin] fix comments and code style
04c48e9 [Xusen Yin] ensure the functionality
a190f2c [Xusen Yin] fix code style and refine the transform function of word2vec
02848fa [Xusen Yin] refine comments
34a55c0 [Xusen Yin] fix errors
109d124 [Xusen Yin] add test suite and pass it
04dde06 [Xusen Yin] add shared params
c594095 [Xusen Yin] add word2vec transformer
23d77fa [Xusen Yin] merge with apache#5626
e8cfaf7 [Xusen Yin] fix conflict with master
66e7bd3 [Xusen Yin] change foldLeft to for loop and use blas
566ec20 [Xusen Yin] fix scala style
b54399f [Xusen Yin] fix comments and code style
1211e86 [Xusen Yin] ensure the functionality
6b97ec8 [Xusen Yin] fix code style and refine the transform function of word2vec
7cde18f [Xusen Yin] rm sharedParams
618abd0 [Xusen Yin] refine comments
e29680a [Xusen Yin] fix errors
fe3afe9 [Xusen Yin] add test suite and pass it
02767fb [Xusen Yin] add shared params
6a514f1 [Xusen Yin] add word2vec transformer
  • Loading branch information
yinxusen authored and mengxr committed Apr 29, 2015
1 parent 15995c8 commit c9d530e
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 1 deletion.
185 changes: 185 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}

/**
 * Parameters shared by the [[Word2Vec]] estimator and the [[Word2VecModel]] it produces.
 */
private[feature] trait Word2VecBase extends Params
  with HasInputCol with HasOutputCol with HasMaxIter with HasStepSize with HasSeed {

  /**
   * Dimension of the code (vector) that each word is transformed into. Default: 100.
   */
  final val vectorSize = new IntParam(
    this, "vectorSize", "the dimension of codes after transforming from words")
  setDefault(vectorSize -> 100)

  /** @group getParam */
  def getVectorSize: Int = getOrDefault(vectorSize)

  /**
   * Number of partitions used for the sentences of words. Default: 1.
   */
  final val numPartitions = new IntParam(
    this, "numPartitions", "number of partitions for sentences of words")
  setDefault(numPartitions -> 1)

  /** @group getParam */
  def getNumPartitions: Int = getOrDefault(numPartitions)

  /**
   * Minimum number of occurrences a token needs in order to be included in the word2vec
   * model's vocabulary. Default: 5.
   */
  final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " +
    "appear to be included in the word2vec model's vocabulary")
  setDefault(minCount -> 5)

  /** @group getParam */
  def getMinCount: Int = getOrDefault(minCount)

  // Defaults for the params inherited from the shared-param traits.
  setDefault(stepSize -> 0.025, maxIter -> 1, seed -> 42L)

  /**
   * Checks that the input column is an array of (nullable) strings and returns the schema
   * extended with the output column typed as a vector.
   */
  protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val resolved = extractParamMap(paramMap)
    val expectedInputType = new ArrayType(StringType, true)
    SchemaUtils.checkColumnType(schema, resolved(inputCol), expectedInputType)
    SchemaUtils.appendColumn(schema, resolved(outputCol), new VectorUDT)
  }
}

/**
 * :: AlphaComponent ::
 * Estimator that trains a `Map(String, Vector)` model, i.e. one that turns a word into a code
 * usable by later natural language processing or machine learning steps.
 */
@AlphaComponent
final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase {

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  def setVectorSize(value: Int): this.type = set(vectorSize, value)

  /** @group setParam */
  def setStepSize(value: Double): this.type = set(stepSize, value)

  /** @group setParam */
  def setNumPartitions(value: Int): this.type = set(numPartitions, value)

  /** @group setParam */
  def setMaxIter(value: Int): this.type = set(maxIter, value)

  /** @group setParam */
  def setSeed(value: Long): this.type = set(seed, value)

  /** @group setParam */
  def setMinCount(value: Int): this.type = set(minCount, value)

  /**
   * Trains the underlying MLlib word2vec implementation on the sentences found in the input
   * column and wraps the result in a [[Word2VecModel]] that inherits the resolved params.
   */
  override def fit(dataset: DataFrame, paramMap: ParamMap): Word2VecModel = {
    transformSchema(dataset.schema, paramMap, logging = true)
    val resolved = extractParamMap(paramMap)

    // One sentence (a sequence of words) per row of the input column.
    val sentences = dataset.select(resolved(inputCol)).map {
      case Row(words: Seq[String]) => words
    }

    // Forward every resolved param to the MLlib trainer.
    val trainer = new feature.Word2Vec()
      .setLearningRate(resolved(stepSize))
      .setMinCount(resolved(minCount))
      .setNumIterations(resolved(maxIter))
      .setNumPartitions(resolved(numPartitions))
      .setSeed(resolved(seed))
      .setVectorSize(resolved(vectorSize))

    val fitted = new Word2VecModel(this, resolved, trainer.fit(sentences))
    Params.inheritValues(resolved, this, fitted)
    fitted
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    validateAndTransformSchema(schema, paramMap)
  }
}

/**
 * :: AlphaComponent ::
 * Model fitted by [[Word2Vec]].
 */
@AlphaComponent
class Word2VecModel private[ml] (
    override val parent: Word2Vec,
    override val fittingParamMap: ParamMap,
    wordVectors: feature.Word2VecModel)
  extends Model[Word2VecModel] with Word2VecBase {

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Transform a sentence column to a vector column to represent the whole sentence. The transform
   * is performed by averaging all word vectors it contains.
   *
   * Words that are not in the model's vocabulary contribute nothing to the sum, but the sum is
   * still divided by the full sentence length. An empty sentence yields an empty sparse vector
   * of size `vectorSize`.
   */
  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    transformSchema(dataset.schema, paramMap, logging = true)
    val map = extractParamMap(paramMap)
    // Resolve the dimension on the driver. Referencing `vectorSize` inside the UDF would pull
    // `this` (and with it `wordVectors`) into the serialized closure, defeating the broadcast.
    val dim = map(vectorSize)
    val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
    val word2Vec = udf { sentence: Seq[String] =>
      if (sentence.isEmpty) {
        Vectors.sparse(dim, Array.empty[Int], Array.empty[Double])
      } else {
        val cum = Vectors.zeros(dim)
        val model = bWordVectors.value
        val vocabulary = model.getVectors
        // Skip words which do not belong to the model.
        for (word <- sentence if vocabulary.contains(word)) {
          axpy(1.0, model.transform(word), cum)
        }
        scal(1.0 / sentence.size, cum)
        cum
      }
    }
    dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol))))
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    validateAndTransformSchema(schema, paramMap)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")),
ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"),
ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"))
ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."))

val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,4 +310,21 @@ trait HasTol extends Params {
/** @group getParam */
final def getTol: Double = getOrDefault(tol)
}

/**
* :: DeveloperApi ::
* Trait for shared param stepSize.
*/
@DeveloperApi
trait HasStepSize extends Params {

/**
* Param for Step size to be used for each iteration of optimization.
* @group param
*/
final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization.")

/** @group getParam */
final def getStepSize: Double = getOrDefault(stepSize)
}
// scalastyle:on
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}

class Word2VecSuite extends FunSuite with MLlibTestSparkContext {

  test("Word2Vec") {
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Two identical documents built from a three-word vocabulary.
    val text = "a b " * 100 + "a c " * 10
    val tokenCount = text.split(" ").size
    val documents = sc.parallelize(Seq(text, text)).map(_.split(" "))

    // Codes the fitted model is expected to assign to each word.
    val codes = Map(
      "a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451),
      "b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342),
      "c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351)
    )

    // Expected document vector: element-wise sum of the word codes divided by the word count.
    val expected = documents.map { tokens =>
      val summed = tokens.map(codes.apply).reduce { (left, right) =>
        left.zip(right).map { case (l, r) => l + r }
      }
      Vectors.dense(summed.map(_ / tokenCount))
    }

    val docDF = documents.zip(expected).toDF("text", "expected")

    val word2Vec = new Word2Vec()
      .setVectorSize(3)
      .setInputCol("text")
      .setOutputCol("result")
    val model = word2Vec.fit(docDF)

    val rows = model.transform(docDF).select("result", "expected").collect()
    rows.foreach {
      case Row(actual: Vector, wanted: Vector) =>
        assert(actual ~== wanted absTol 1E-5, "Transformed vector is different with expected.")
    }
  }
}

0 comments on commit c9d530e

Please sign in to comment.