In [1]:
import os
import sys
# Put the project root (two directory levels above the notebook's working
# directory) at the front of sys.path so the project packages imported in
# this cell resolve when the notebook is run from its own folder.
BASE_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.insert(0, BASE_DIR)

# Pin the Python interpreter PySpark uses (PYSPARK_PYTHON) and the one used
# for its driver (PYSPARK_DRIVER_PYTHON). With several Python installs on the
# machine, leaving these unset is likely to cause errors.
PYSPARK_PYTHON = "/miniconda2/envs/reco_sys/bin/python"
os.environ.update(
    PYSPARK_PYTHON=PYSPARK_PYTHON,
    PYSPARK_DRIVER_PYTHON=PYSPARK_PYTHON,
)

from offline import SparkSessionBase
from settings.default import CHANNEL_INFO
from pyspark.ml.feature import Word2Vec


In [2]:
class TrainWord2VecModel(SparkSessionBase):
    """Spark session holder for training and inspecting Word2Vec models.

    Hive support is enabled so tables such as those in the ``article``
    database can be queried through ``self.spark.sql``. ``SPARK_URL`` is not
    overridden here (set it to ``"yarn"`` to run on the cluster); presumably
    ``SparkSessionBase`` supplies a default — confirm.
    """

    ENABLE_HIVE_SUPPORT = True
    SPARK_APP_NAME = "Word2Vec"

    def __init__(self):
        # The session factory comes from SparkSessionBase and reads the class
        # attributes declared above.
        session = self._create_spark_session()
        self.spark = session

In [3]:
# Create the Spark session wrapper (Hive support enabled) used by all later cells.
w2v = TrainWord2VecModel()

In [4]:
def segmentation(partition, words_dir="/root/words"):
    """Tokenise article text, one Spark partition at a time, for Word2Vec.

    Intended for ``rdd.mapPartitions``. The jieba user dictionary and the
    stopword list are loaded once per partition, not once per row.

    Args:
        partition: iterable of rows exposing ``article_id``, ``channel_id``
            and ``sentence`` attributes (e.g. ``pyspark.sql.Row``).
        words_dir: directory containing ``ITKeywords.txt`` (jieba user
            dictionary) and ``stopwords.txt`` (one stopword per line).
            Defaults to the original hard-coded ``/root/words``.

    Yields:
        ``(article_id, channel_id, words)``, where ``words`` is the filtered
        token list for the row's sentence.
    """
    import os
    import re

    import jieba
    import jieba.posseg as pseg

    # Load the user dictionary so that domain terms stay single tokens.
    jieba.load_userdict(os.path.join(words_dir, "ITKeywords.txt"))

    # Read the stopwords once and close the file. Use a set so each
    # membership test is O(1) instead of a linear list scan.
    # NOTE(review): assumes stopwords.txt is UTF-8 encoded — confirm.
    stopwords_path = os.path.join(words_dir, "stopwords.txt")
    with open(stopwords_path, encoding="utf-8") as stopwords_file:
        stopwords = {line.strip() for line in stopwords_file}

    def cut_sentence(sentence):
        """Segment ``sentence`` with POS tags and keep only useful tokens.

        A token is dropped if it is a stopword or a single character.
        English tokens (flag ``"eng"``) are kept only when longer than two
        characters. Noun-like tokens (flag starting with ``"n"``) and tokens
        flagged ``"x"`` are kept. Everything else is dropped.
        """
        filtered_words = []
        for seg in pseg.lcut(sentence):
            word, flag = seg.word, seg.flag
            if word in stopwords or len(word) <= 1:
                continue
            if flag == "eng":
                if len(word) > 2:
                    filtered_words.append(word)
            elif flag.startswith("n") or flag == "x":
                filtered_words.append(word)
        return filtered_words

    for row in partition:
        # Strip HTML tags. A NULL sentence is treated as empty text, so it does
        # not fail the whole Spark task with a TypeError from re.sub.
        sentence = re.sub("<.*?>", "", row.sentence or "")
        yield row.article_id, row.channel_id, cut_sentence(sentence)

In [7]:
# Switch the Hive session to the `article` database for the queries below.
w2v.spark.sql('use article')

DataFrame[]

In [9]:
# Tiny sample (two channel-18 articles) for a quick Word2Vec smoke test.
# NOTE(review): with 2 articles and minCount=3 the vocabulary is very small —
# fine for a demo, not for a real model.
article = w2v.spark.sql('select * from article_data where channel_id=18 limit 2')

In [10]:
# Tokenise each partition of sampled articles into (article_id, channel_id, words) rows.
words_df = (
    article.rdd
    .mapPartitions(segmentation)
    .toDF(['article_id', 'channel_id', 'words'])
)

In [12]:
# 100-dimensional Word2Vec over the tokenised `words` column. Tokens seen
# fewer than 3 times are left out of the vocabulary.
# NOTE(review): outputCol "model" is where transform() would write each row's
# vector. The name is misleading, but transform() is not used in this notebook.
new_word2vec = Word2Vec(
    inputCol="words",
    outputCol="model",
    vectorSize=100,
    minCount=3,
)

In [13]:
# Train the Word2Vec model on the tokenised sample.
new_model = new_word2vec.fit(words_df)

In [15]:
# Show the learned (word, vector) table.
new_model.getVectors().show()

+-----------+--------------------+
|       word|              vector|
+-----------+--------------------+
|       函数参数|[0.00343297771178...|
|  recipient|[0.01848481036722...|
|   register|[0.01116304472088...|
|        fib|[0.04160280898213...|
|         函数|[0.02662521973252...|
|     encode|[-0.0152690997347...|
|    network|[0.01590051129460...|
|         __|[-0.0548076145350...|
|     xrange|[0.08385339379310...|
|      Proof|[-0.0322913937270...|
|        表达式|[-0.0057290056720...|
|         时间|[-0.0049802274443...|
|Transaction|[-0.0103443199768...|
|    request|[-0.0376213230192...|
|       HTTP|[0.00963569525629...|
|        ram|[-0.0141640864312...|
|  timestamp|[-0.0133891822770...|
|         交易|[0.01113164052367...|
|       join|[0.04276565462350...|
| identifier|[-0.0140685094520...|
+-----------+--------------------+
only showing top 20 rows



In [16]:
# Inspect the first five word vectors as Rows holding DenseVector values.
new_model.getVectors().rdd.take(5)

[Row(word='函数参数', vector=DenseVector([0.0034, 0.0326, -0.0283, 0.0064, 0.0449, 0.0189, 0.0449, 0.0466, -0.0078, -0.0199, 0.0116, -0.0082, 0.0009, -0.0074, -0.0204, 0.0375, -0.0054, 0.0126, -0.0307, -0.0104, -0.0286, 0.0193, -0.0121, 0.0144, 0.0045, -0.0144, 0.0212, 0.0249, -0.0235, 0.0024, 0.0339, -0.0027, 0.0351, -0.0383, -0.0086, 0.003, -0.0169, 0.0247, 0.0281, -0.0105, -0.0006, 0.017, 0.0226, 0.0251, 0.0434, -0.0254, -0.0021, -0.01, -0.0097, -0.0008, -0.0223, -0.018, -0.0084, 0.0101, -0.0217, 0.0192, 0.0474, -0.0272, 0.0127, -0.0148, 0.0011, -0.0178, -0.0148, -0.0284, -0.0237, -0.0052, -0.0357, 0.0189, 0.0296, 0.0276, -0.0301, -0.0073, -0.0172, -0.0257, 0.0197, -0.0234, -0.0208, -0.0417, -0.0325, -0.0309, 0.0016, 0.0173, 0.0152, 0.0114, 0.0238, -0.0319, -0.0321, -0.0309, 0.0104, 0.0361, 0.0213, 0.0117, 0.0004, 0.0013, 0.0393, -0.0289, 0.052, -0.0416, -0.0049, -0.0189])),
 Row(word='recipient', vector=DenseVector([0.0185, -0.0973, 0.102, -0.0404, -0.1631, -0.0154, -0.1834, -0.1838, 0

In [5]:
from pyspark.ml.feature import Word2VecModel


In [6]:
# Channel whose pre-trained Word2Vec model is loaded in the next cell.
channel_id, channel = 18, "python"

In [7]:
# Load this channel's pre-trained Word2Vec model from the local backup directory.
wv_model = Word2VecModel.load(
    'file:///root/bak/modelsbak/word2vec_model/channel_%d_%s.word2vec'
    % (channel_id, channel)
)

In [19]:
# Keep the loaded model's (word, vector) table; it is joined with article keywords below.
vectors = wv_model.getVectors()

In [8]:
# Peek at one word vector from the loaded channel model.
wv_model.getVectors().rdd.take(1)

[Row(word='广义', vector=DenseVector([0.2891, -0.1201, 0.2581, 0.0197, 0.0078, 0.0804, -0.1394, 0.0727, -0.0221, 0.284, -0.044, -0.0103, 0.0481, -0.3808, 0.1075, -0.2184, -0.2475, 0.1366, -0.2287, -0.2782, 0.2139, 0.0117, -0.09, -0.0209, -0.0138, -0.0915, 0.0368, -0.1199, 0.0512, 0.2744, -0.0565, 0.3728, 0.3951, -0.005, 0.0649, -0.3247, -0.1027, 0.2947, -0.0416, -0.017, 0.3251, 0.011, 0.1175, 0.2295, -0.0771, -0.253, -0.4119, 0.0251, 0.086, -0.0341, -0.0386, 0.4387, -0.108, -0.0404, -0.2936, -0.0651, -0.1807, -0.0766, -0.0907, -0.2319, -0.0095, -0.0735, -0.4099, 0.2474, 0.1552, -0.3484, 0.0463, -0.256, -0.1195, -0.0489, -0.1076, -0.1659, -0.0391, 0.16, -0.2375, -0.2284, 0.0499, 0.3306, -0.1903, -0.0508, -0.0994, 0.0779, -0.0313, -0.415, 0.5217, -0.2161, 0.046, -0.1052, 0.5156, -0.2391, -0.1148, -0.2006, -0.2139, -0.01, -0.3468, -0.1115, 0.1438, -0.0014, 0.2076, -0.2563]))]

In [12]:
# Re-select the `article` database (repeats the earlier `use article` cell).
w2v.spark.sql('use article')

DataFrame[]

In [13]:
# Sample ten channel-18 article profiles (keyword->weight map plus topics).
profile = w2v.spark.sql("select * from article_profile where channel_id=18 limit 10")


In [14]:
# Display the sampled article profiles.
profile.show()

+----------+----------+--------------------+--------------------+
|article_id|channel_id|            keywords|              topics|
+----------+----------+--------------------+--------------------+
|     13098|        18|Map(pre -> 0.6040...|[__, object, 属性, ...|
|     13248|        18|Map(有限元 -> 5.2929...|[有限元, 代码分析, 案例, z...|
|     13401|        18|Map(pre -> 0.2100...|[补码, 字符串, 李白, typ...|
|     13723|        18|Map(pre -> 2.1094...|[acc, bstr, 原地, l...|
|     14719|        18|Map(pre -> 0.8814...|[__, ctime, cons,...|
|     14846|        18|Map(__ -> 2.54674...|[files, __, folde...|
|     15173|        18|Map(人人 -> 0.74986...|[cookie, Python爬虫...|
|     15194|        18|Map(dif -> 0.7567...|[display, 课程, lis...|
|     15237|        18|Map(pre -> 0.5349...|[__, send, sel, c...|
|     15322|        18|Map(pre -> 0.5762...|[Pclass, replace,...|
+----------+----------+--------------------+--------------------+



In [16]:
# Expose the sampled profiles to Spark SQL as `incremental`, so the keyword map
# can be exploded with LATERAL VIEW in the next cell. createOrReplaceTempView
# replaces registerTempTable, which has been deprecated since Spark 2.0.
profile.createOrReplaceTempView("incremental")


In [17]:
# Explode each article's keyword->weight map into one (keyword, weight) row per keyword.
articleKeywordsWeights = w2v.spark.sql(
                "select article_id, channel_id, keyword, weight from incremental LATERAL VIEW explode(keywords) AS keyword, weight")


In [18]:
# Display the exploded keyword weights.
articleKeywordsWeights.show()

+----------+----------+--------+-------------------+
|article_id|channel_id| keyword|             weight|
+----------+----------+--------+-------------------+
|     13098|        18|    repr| 0.6326590117716192|
|     13098|        18|      __| 2.5401122038114203|
|     13098|        18|      属性|0.23645924932468856|
|     13098|        18|     pre| 0.6040062287555379|
|     13098|        18|    code| 0.9531379029975557|
|     13098|        18|     def| 0.5063435861497416|
|     13098|        18|   color| 1.1337936117177925|
|     13098|        18|      定义| 0.1554380122061322|
|     13098|        18| Student| 0.5033771372284416|
|     13098|        18|getPrice| 0.7404427038950527|
|     13098|        18|      方法|0.08080845613717194|
|     13098|        18|     div| 0.3434819820586186|
|     13098|        18|     str|0.35999033790156054|
|     13098|        18|      pa| 0.6651385256756351|
|     13098|        18|   slots| 0.6992789472129189|
|     13098|        18| cnblogs|0.339265861020

In [20]:
# Attach each keyword's Word2Vec vector. The inner join drops keywords that are
# missing from the model's vocabulary.
_article_profile = articleKeywordsWeights.join(
    vectors,
    on=articleKeywordsWeights.keyword == vectors.word,
    how="inner",
)


In [22]:
# Display keyword rows together with their matched word vectors.
_article_profile.show(40)

+----------+----------+----------+-------------------+----------+--------------------+
|article_id|channel_id|   keyword|             weight|      word|              vector|
+----------+----------+----------+-------------------+----------+--------------------+
|     13098|        18|      repr| 0.6326590117716192|      repr|[0.21036092936992...|
|     13098|        18|        __| 2.5401122038114203|        __|[0.01188501343131...|
|     13098|        18|        属性|0.23645924932468856|        属性|[-0.0808837264776...|
|     13098|        18|       pre| 0.6040062287555379|       pre|[0.57123136520385...|
|     13098|        18|      code| 0.9531379029975557|      code|[0.36302515864372...|
|     13098|        18|       def| 0.5063435861497416|       def|[0.15328533947467...|
|     13098|        18|     color| 1.1337936117177925|     color|[0.54077166318893...|
|     13098|        18|        定义| 0.1554380122061322|        定义|[-0.0069237630814...|
|     13098|        18|   Student| 0.503377

In [23]:
# Scale each keyword's word vector by that keyword's weight in the article.
articleKeywordVectors = (
    _article_profile.rdd
    .map(lambda r: (r.article_id, r.channel_id, r.keyword, r.weight * r.vector))
    .toDF(["article_id", "channel_id", "keyword", "weightingVector"])
)


In [24]:
# Display the weighted keyword vectors.
articleKeywordVectors.show()

+----------+----------+--------+--------------------+
|article_id|channel_id| keyword|     weightingVector|
+----------+----------+--------+--------------------+
|     13098|        18|    repr|[0.13308673769053...|
|     13098|        18|      __|[0.03018926765933...|
|     13098|        18|      属性|[-0.0191257052454...|
|     13098|        18|     pre|[0.34502730264365...|
|     13098|        18|    code|[0.34601303844503...|
|     13098|        18|     def|[0.07761504849378...|
|     13098|        18|   color|[0.61312345712161...|
|     13098|        18|      定义|[-0.0010762159703...|
|     13098|        18| Student|[0.09441257176805...|
|     13098|        18|getPrice|[-0.0847735848446...|
|     13098|        18|      方法|[-0.0048283701284...|
|     13098|        18|     div|[0.04037546778136...|
|     13098|        18|     str|[-0.0134091549740...|
|     13098|        18|      pa|[0.08286002598213...|
|     13098|        18|   slots|[-0.2270685226558...|
|     13098|        18| cnbl

#### 计算文章向量

In [26]:
def avg(row):
    """Return ``(article_id, channel_id, mean_vector)`` for one grouped row.

    ``row.vectors`` holds the article's weighted keyword vectors. Their
    element-wise mean is used as the article's vector. An empty
    ``row.vectors`` raises ``ZeroDivisionError``.
    """
    vectors = row.vectors
    total = 0
    for vector in vectors:
        total = total + vector
    count = len(vectors)
    return row.article_id, row.channel_id, total / count



In [27]:
# Expose the weighted keyword vectors to Spark SQL as `tempTable` for the
# per-article aggregation. createOrReplaceTempView replaces registerTempTable,
# which has been deprecated since Spark 2.0.
articleKeywordVectors.createOrReplaceTempView('tempTable')

In [28]:
# Average each article's weighted keyword vectors into a single article vector.
# collect_list (not collect_set) keeps every keyword's vector; a set would
# silently drop identical vectors and skew the mean.
# min(channel_id) picks the article's channel for the GROUP BY
# (presumably one channel per article — confirm).
articleVector = w2v.spark.sql(
    'select article_id, min(channel_id) channel_id, collect_list(weightingVector) vectors '
    'from tempTable group by article_id'
).rdd.map(avg).toDF(['article_id', 'channel_id', 'articleVector'])


In [29]:
# Display one averaged vector per article.
articleVector.show()

+----------+----------+--------------------+
|article_id|channel_id|       articleVector|
+----------+----------+--------------------+
|     13098|        18|[0.10339950907039...|
|     13248|        18|[0.84907054580879...|
|     13401|        18|[0.06157120217893...|
|     13723|        18|[0.20708073724961...|
|     14719|        18|[-0.0405607722081...|
|     14846|        18|[0.17945355257543...|
|     15173|        18|[-0.2399774663757...|
|     15194|        18|[0.08605245220126...|
|     15237|        18|[0.02019666206037...|
|     15322|        18|[0.11985676790665...|
+----------+----------+--------------------+



In [30]:
# Inspect one full article vector.
articleVector.rdd.take(1)

[Row(article_id=13098, channel_id=18, articleVector=DenseVector([0.1034, 0.0785, 0.0041, 0.0603, -0.0141, -0.029, 0.0496, 0.1423, -0.0637, 0.01, -0.1751, 0.0785, 0.0409, 0.1243, -0.0123, 0.0188, -0.058, -0.0584, -0.0081, 0.0583, -0.0289, 0.0535, 0.0049, 0.0763, -0.1058, -0.1615, 0.0606, 0.0193, 0.022, -0.1981, -0.0388, 0.0176, 0.2267, 0.0398, -0.1301, 0.0255, -0.1784, -0.0603, 0.0013, -0.0518, 0.1123, -0.0474, -0.1108, 0.0204, 0.0429, 0.0315, 0.0184, 0.2174, -0.0671, 0.0786, -0.0197, -0.0552, 0.1391, 0.0997, 0.098, -0.0035, 0.1513, -0.0479, 0.0127, -0.0037, 0.1215, -0.0365, 0.0724, -0.1241, 0.0419, 0.0564, -0.0454, -0.0452, 0.0281, -0.1254, -0.1335, -0.0174, -0.0413, -0.0709, 0.0937, -0.0294, 0.0517, 0.0954, 0.0809, 0.0077, 0.1265, -0.1395, 0.0913, 0.0066, -0.0008, 0.0983, 0.0209, 0.0487, -0.1136, 0.1168, 0.0356, 0.087, 0.0318, 0.0095, -0.0157, 0.0318, -0.0545, -0.0834, 0.0832, 0.2326]))]

### 计算相似度

In [34]:
from pyspark.ml.linalg import Vectors
# Select a small subset of stored article vectors as test data.
# NOTE(review): this reads the Hive table `article_vector`, not the
# `articleVector` DataFrame computed above (which is not saved in the cells
# shown) — confirm the table is populated before running.
article_vector = w2v.spark.sql("select article_id, articlevector from article_vector where channel_id=18 limit 10")


In [35]:
# Keep only the id and the raw vector column for LSH training.
train = article_vector.select('article_id', 'articlevector')


In [38]:
def _array_to_vector(row):
    """Map an ``(article_id, articlevector)`` row to ``(article_id, DenseVector)``.

    The LSH estimator needs an ML ``Vector`` column. The stored column is
    presumably a plain array — confirm against the Hive schema.
    """
    article_id, values = row.article_id, row.articlevector
    return article_id, Vectors.dense(values)

In [39]:
# Convert the stored array column into ML DenseVectors, as the LSH estimator
# requires. Build from `article_vector` instead of overwriting `train` with a
# transform of itself: otherwise re-running this cell fails, because the
# converted column is named `articleVector`, not `articlevector`.
train = article_vector.rdd.map(_array_to_vector).toDF(['article_id', 'articleVector'])


In [40]:
# Display the LSH training data (article id + DenseVector).
train.show()

+----------+--------------------+
|article_id|       articleVector|
+----------+--------------------+
|     13098|[0.10339950907039...|
|     13248|[0.84907054580879...|
|     13401|[0.06157120217893...|
|     13723|[0.20708073724961...|
|     14719|[-0.0405607722081...|
|     14846|[0.17945355257543...|
|     15173|[-0.2399774663757...|
|     15194|[0.08605245220126...|
|     15237|[0.02019666206037...|
|     15322|[0.11985676790665...|
+----------+--------------------+



In [41]:
from pyspark.ml.feature import BucketedRandomProjectionLSH

# These are hand-picked values, not Spark's defaults: Spark defaults to
# numHashTables=1, and bucketLength has no default, so it must be set.
# More hash tables lower the false-negative rate at the cost of runtime.
# A larger bucketLength also lowers false negatives.
# numHashTables is an integer parameter, so pass an int.
brp = BucketedRandomProjectionLSH(
    inputCol='articleVector',
    outputCol='hashes',
    numHashTables=4,
    bucketLength=10.0,
)
model = brp.fit(train)

In [42]:
# Approximate self-join keeping pairs within Euclidean distance 2.0. Every
# article also matches itself at distance 0, and each pair appears in both orders.
similar = model.approxSimilarityJoin(
    datasetA=train,
    datasetB=train,
    threshold=2.0,
    distCol='EuclideanDistance',
)

In [43]:
# Display candidate pairs (unordered).
similar.show()

+--------------------+--------------------+------------------+
|            datasetA|            datasetB| EuclideanDistance|
+--------------------+--------------------+------------------+
|[15237,[0.0201966...|[15237,[0.0201966...|               0.0|
|[15194,[0.0860524...|[15237,[0.0201966...|0.8329097179846192|
|[13098,[0.1033995...|[13098,[0.1033995...|               0.0|
|[15322,[0.1198567...|[15237,[0.0201966...| 0.898816446815378|
|[13401,[0.0615712...|[15173,[-0.239977...|1.4193381175720314|
|[14719,[-0.040560...|[13098,[0.1033995...|1.4624429314093754|
|[15237,[0.0201966...|[13723,[0.2070807...|0.9242292591649734|
|[13401,[0.0615712...|[13723,[0.2070807...|0.7932578595105071|
|[13401,[0.0615712...|[14719,[-0.040560...|1.4501024819771564|
|[15237,[0.0201966...|[15194,[0.0860524...|0.8329097179846192|
|[14719,[-0.040560...|[14846,[0.1794535...|1.5597513614699037|
|[13098,[0.1033995...|[13723,[0.2070807...| 0.765768247253341|
|[14719,[-0.040560...|[13401,[0.0615712...|1.4501024819

In [44]:
# Sort pairs by distance; self-matches (distance 0.0) come first.
similar.sort(['EuclideanDistance']).show()


+--------------------+--------------------+-------------------+
|            datasetA|            datasetB|  EuclideanDistance|
+--------------------+--------------------+-------------------+
|[13723,[0.2070807...|[13723,[0.2070807...|                0.0|
|[15322,[0.1198567...|[15322,[0.1198567...|                0.0|
|[13401,[0.0615712...|[13401,[0.0615712...|                0.0|
|[14719,[-0.040560...|[14719,[-0.040560...|                0.0|
|[15194,[0.0860524...|[15194,[0.0860524...|                0.0|
|[13098,[0.1033995...|[13098,[0.1033995...|                0.0|
|[14846,[0.1794535...|[14846,[0.1794535...|                0.0|
|[13248,[0.8490705...|[13248,[0.8490705...|                0.0|
|[15173,[-0.239977...|[15173,[-0.239977...|                0.0|
|[15237,[0.0201966...|[15237,[0.0201966...|                0.0|
|[15237,[0.0201966...|[13401,[0.0615712...|0.42729625714112773|
|[13401,[0.0615712...|[15237,[0.0201966...|0.42729625714112773|
|[14846,[0.1794535...|[13098,[0.1033995.

In [45]:
# Inspect one joined pair, including each side's LSH hash values.
similar.rdd.take(1)

[Row(datasetA=Row(article_id=15237, articleVector=DenseVector([0.0202, 0.0611, 0.0567, 0.0019, 0.0277, -0.0515, 0.0181, 0.0424, -0.0214, -0.0072, -0.0413, 0.0591, 0.076, 0.0359, 0.0016, 0.0108, 0.0451, -0.0207, 0.0049, 0.0635, -0.0529, 0.0907, 0.0126, 0.0108, -0.0111, -0.0129, -0.019, -0.005, 0.0292, -0.0707, -0.0297, -0.0005, 0.033, -0.0064, -0.0198, 0.0333, -0.0161, 0.0024, -0.0176, -0.0089, -0.0203, -0.0167, -0.0107, -0.0153, -0.0143, 0.0538, 0.0619, -0.0342, -0.0104, -0.0136, 0.0035, -0.0202, 0.0269, 0.0077, 0.0316, -0.0169, 0.0326, -0.0306, 0.0606, -0.0024, 0.0378, -0.0242, 0.0083, -0.0109, 0.0434, 0.0041, 0.0328, -0.0408, 0.0248, -0.0225, 0.0366, 0.0037, -0.0087, -0.0183, -0.0349, 0.0802, -0.0003, -0.0192, 0.0562, -0.0312, 0.043, -0.0167, -0.0184, 0.0033, -0.0167, 0.038, 0.0632, 0.0453, -0.0499, 0.056, 0.0097, 0.0576, 0.0362, -0.0154, 0.0354, 0.0665, -0.0622, -0.0416, -0.0229, 0.0795]), hashes=[DenseVector([-1.0]), DenseVector([-1.0]), DenseVector([-1.0]), DenseVector([0.0])]), d