In [1]:
import os
import sys
BASE_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.insert(0, os.path.join(BASE_DIR))
from offline import SparkSessionBase
from pyspark.ml.feature import Word2Vec

In [2]:
class TrainWord2VecModel(SparkSessionBase):
    
    SPARK_APP_NAME = "Word2Vec"
    ENABLE_HIVE_SUPPORT = True
    SPARK_EXECUTOR_MEMORY = "4g"

    
    def __init__(self):
        self.spark = self._create_spark_session()
        
w2v = TrainWord2VecModel()

In [3]:
def segmentation(partition):
    import os
    import re
    import jieba
    import jieba.analyse
    import jieba.posseg as pseg
    import codecs
    
    abspath = "/Users/hycao/text"
    
    userDict_path = os.path.join(abspath, "ITKeywords.txt")
    jieba.load_userdict(userDict_path)
    stopwords_path = os.path.join(abspath, "stopwords.txt")
    
    def get_stopwords_list():
        stopwords_list = [i.strip() for i in codecs.open(stopwords_path).readlines()]
        return stopwords_list
    
    stopwords_list = get_stopwords_list()
    
    def cut_sentence(sentence):
        seg_list = pseg.lcut(sentence)
        seg_list = [i for i in seg_list if i.flag not in stopwords_list]
        filtered_words_list = []
        for seg in seg_list:
            if len(seg.word) <= 1:
                continue
            elif seg.flag == "eng":
                if len(seg.word) <= 2:
                    continue
                    
                else:
                    filtered_words_list.append(seg.word)
            elif seg.flag.startswith("n"):
                filtered_words_list.append(seg.word)
            elif seg.flag in ["X", "eng"]:
                filtered_words_list.append(seg.word)
        return filtered_words_list
    
    for row in partition:
        sentence = re.sub("<.*?>", "", row.sentence)
        words = cut_sentence(sentence)
        yield row.article_id, row.channel_id, words
        

In [5]:
w2v.spark.sql("use fytang")
article_data = w2v.spark.sql("select * from article_data where channel_id = 18")
words_df = article_data.rdd.mapPartitions(segmentation).toDF(['article_id', 'channel_id', 'words'])


In [9]:
wv = Word2Vec(vectorSize=100, inputCol='words', outputCol='model')
wv_model = wv.fit(words_df)
wv_model.save("hdfs://localhost:9000/fytang/models/test.word2vec")

In [10]:
from pyspark.ml.feature import Word2VecModel
word2vec = Word2VecModel.load("hdfs://localhost:9000/fytang/models/test.word2vec")
vectors = word2vec.getVectors()

In [11]:
vectors.show()

+----------+--------------------+
|      word|              vector|
+----------+--------------------+
|        广义|[0.14468304812908...|
|        伙伴|[0.00758749945089...|
|        箭头|[0.17873702943325...|
|      COCO|[0.11621176451444...|
|        拜拜|[0.00776594784110...|
|  quotient|[0.04783434420824...|
|        货币|[0.07925040274858...|
|        人物|[0.21718211472034...|
|       wsy|[0.00973777379840...|
|fromParams|[0.05492971837520...|
|ershoufang|[-0.0702918618917...|
|        热身|[0.10387352854013...|
|    breaks|[0.01752104423940...|
|      marr|[-0.0695709511637...|
|       可靠性|[-0.2670255005359...|
|      测试代码|[0.20592857897281...|
|       pys|[0.00484006293118...|
|       dns|[-0.0224121324717...|
|   frmongo|[0.00727079855278...|
|       ROW|[-0.1352243125438...|
+----------+--------------------+
only showing top 20 rows



In [13]:
article_profile = w2v.spark.sql("select * from tfidf_keywords_values")
article_profile.show()

+----------+----------+-------+--------+
|article_id|channel_id|keyword|   tfidf|
+----------+----------+-------+--------+
|     98319|        17|    var| 20.6079|
|     98323|        17|    var|  7.4938|
|     98326|        17|    var|104.9128|
|     98344|        17|    var|  5.6203|
|     98359|        17|    var| 69.3174|
|     98360|        17|    var|  9.3672|
|     98392|        17|    var| 14.9875|
|     98393|        17|    var|155.4958|
|     98406|        17|    var| 11.2407|
|     98419|        17|    var| 59.9502|
|     98442|        17|    var| 18.7344|
|     98445|        17|    var| 37.4689|
|     98512|        17|    var| 29.9751|
|     98544|        17|    var|  5.6203|
|     98545|        17|    var| 22.4813|
|     98548|        17|    var| 71.1909|
|     98599|        17|    var| 11.2407|
|     98609|        17|    var| 18.7344|
|     98642|        17|    var|  67.444|
|     98648|        15|    var| 20.6079|
+----------+----------+-------+--------+
only showing top

In [48]:
article_profile.registerTempTable("incremental")
# keyword_weight = w2v.spark.sql("select article_id,channel_id, keyword, weight from incremental LATERAL VIEW explode(keywords) AS keyword, weight")

In [49]:
_article_profile = article_profile.join(vectors,vectors.word == article_profile.keyword,"inner")

In [50]:
_article_profile.show()

+----------+----------+-------+--------+----+--------------------+
|article_id|channel_id|keyword|   tfidf|word|              vector|
+----------+----------+-------+--------+----+--------------------+
|     98319|        17|    var| 20.6079| var|[0.07167968153953...|
|     98323|        17|    var|  7.4938| var|[0.07167968153953...|
|     98326|        17|    var|104.9128| var|[0.07167968153953...|
|     98344|        17|    var|  5.6203| var|[0.07167968153953...|
|     98359|        17|    var| 69.3174| var|[0.07167968153953...|
|     98360|        17|    var|  9.3672| var|[0.07167968153953...|
|     98392|        17|    var| 14.9875| var|[0.07167968153953...|
|     98393|        17|    var|155.4958| var|[0.07167968153953...|
|     98406|        17|    var| 11.2407| var|[0.07167968153953...|
|     98419|        17|    var| 59.9502| var|[0.07167968153953...|
|     98442|        17|    var| 18.7344| var|[0.07167968153953...|
|     98445|        17|    var| 37.4689| var|[0.07167968153953

In [51]:
_article_profile = _article_profile.select(['article_id', 'channel_id', 'vector'])
_article_profile.show()

+----------+----------+--------------------+
|article_id|channel_id|              vector|
+----------+----------+--------------------+
|     98319|        17|[0.07167968153953...|
|     98323|        17|[0.07167968153953...|
|     98326|        17|[0.07167968153953...|
|     98344|        17|[0.07167968153953...|
|     98359|        17|[0.07167968153953...|
|     98360|        17|[0.07167968153953...|
|     98392|        17|[0.07167968153953...|
|     98393|        17|[0.07167968153953...|
|     98406|        17|[0.07167968153953...|
|     98419|        17|[0.07167968153953...|
|     98442|        17|[0.07167968153953...|
|     98445|        17|[0.07167968153953...|
|     98512|        17|[0.07167968153953...|
|     98544|        17|[0.07167968153953...|
|     98545|        17|[0.07167968153953...|
|     98548|        17|[0.07167968153953...|
|     98599|        17|[0.07167968153953...|
|     98609|        17|[0.07167968153953...|
|     98642|        17|[0.07167968153953...|
|     9864

In [52]:
_article_profile.registerTempTable("temptable")
articlevector = w2v.spark.sql("select article_id, max(channel_id) channel_id, collect_set(vector) articlevecter from temptable group by article_id")

In [53]:
articlevector.show()

+----------+----------+--------------------+
|article_id|channel_id|       articlevecter|
+----------+----------+--------------------+
|       148|        17|[[-0.465461105108...|
|       463|         1|[[-0.395194649696...|
|       471|        17|[[-0.040295206010...|
|       496|        11|[[-0.090842999517...|
|       833|         1|[[-0.235992595553...|
|      1088|         1|[[-0.013495572842...|
|      1238|        11|[[0.4386835694313...|
|      1342|         6|[[-0.036753281950...|
|      1580|         6|[[-0.066306084394...|
|      1591|        17|[[0.2542469501495...|
|      1645|         6|[[-0.390882134437...|
|      1829|        11|[[0.1422660946846...|
|      1959|        11|[[0.0840188488364...|
|      2122|        17|[[-0.013495572842...|
|      2142|        11|[[0.0012916581472...|
|      2366|        13|[[-0.047671638429...|
|      2659|        17|[[-0.040295206010...|
|      2866|        17|[[0.1863891184329...|
|      3175|         6|[[-0.327978760004...|
|      374

In [54]:
def avg_vector(row):
    x = 0
    for v in row.articlevecter:
        x += v
    return row.article_id, row.channel_id, x/len(row.articlevecter)
articlevector = articlevector.rdd.map(avg_vector).toDF(['article_id', 'channel_id', 'articlevector'])