diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 458fab48fef5a..afc5ee742ba74 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -27,6 +27,7 @@ import scala.language.existentials
 import scala.reflect.ClassTag
 
 import net.razorvine.pickle._
 
+import org.apache.spark.SparkContext
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.python.SerDeUtil
@@ -641,6 +642,9 @@ private[python] class PythonMLLibAPI extends Serializable {
     def getVectors: JMap[String, JList[Float]] = {
       model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
     }
+
+    /** Saves the wrapped model by delegating to its own `save(sc, path)`. */
+    def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
   }
 
   /**
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index b5138773fd61b..e659af33f462b 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -36,6 +36,7 @@
 from pyspark.mllib.linalg import (
     Vector, Vectors, DenseVector, SparseVector, _convert_to_vector)
 from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.util import JavaLoader, JavaSaveable
 
 __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
            'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
@@ -416,7 +417,7 @@ def fit(self, dataset):
         return IDFModel(jmodel)
 
 
-class Word2VecModel(JavaVectorTransformer):
+class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
     """
     class for Word2Vec model
     """
@@ -455,6 +456,23 @@ def getVectors(self):
         """
         return self.call("getVectors")
 
+    @classmethod
+    def load(cls, sc, path):
+        """
+        Load a model previously written with save().
+
+        :param sc: SparkContext used to read the model.
+        :param path: Location the model was saved to.
+        :return: The restored model, as an instance of this class.
+        """
+        # NOTE(review): the JVM object returned below is presumably the
+        # plain Scala Word2VecModel, not the Python-facing wrapper in
+        # PythonMLLibAPI.scala that converts getVectors to Java collections;
+        # confirm every method routed through self.call() works after load.
+        jmodel = sc._jvm.org.apache.spark.mllib.feature \
+            .Word2VecModel.load(sc._jsc.sc(), path)
+        return cls(jmodel)
+
 
 @ignore_unicode_prefix
 class Word2Vec(object):
@@ -488,6 +506,15 @@ class Word2Vec(object):
     >>> syms = model.findSynonyms(vec, 2)
     >>> [s[0] for s in syms]
     [u'b', u'c']
+
+    >>> import shutil
+    >>> import tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = Word2VecModel.load(sc, path)
+    >>> model.transform("a") == sameModel.transform("a")
+    True
+    >>> shutil.rmtree(path, ignore_errors=True)
     """
     def __init__(self):
         """