In [1]:
import findspark
findspark.init()

In [2]:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    StringType,
    ArrayType,
    FloatType
)
from pyspark.sql.functions import (
    udf,
    col
)
from pyspark.ml.feature import Word2Vec
# !pip install jieba
import jieba

In [3]:
spark = (SparkSession
    .builder
    .appName("pyspark-word2vec-cosine-similarity")
    .getOrCreate())

In [4]:
df = (spark.
      read.
      csv('/home/eric/Sync/datasets/misc/tech-posts.csv', header=True)
      .select('id', 'title')).cache()

In [5]:
def jieba_seg(x):
    return [w for w in jieba.cut(x) if len(w)>1]

In [6]:
jieba_seg_udf = udf(jieba_seg, ArrayType(StringType()))

In [7]:
df = df.withColumn('words', jieba_seg_udf(df['title']))

In [8]:
df.printSchema()

root
 |-- id: string (nullable = true)
 |-- title: string (nullable = true)
 |-- words: array (nullable = true)
 |    |-- element: string (containsNull = true)



In [9]:
model = Word2Vec(numPartitions=10, inputCol='words', outputCol='vecs').fit(df)

In [10]:
model.getVectors().count()

2703

In [11]:
df_transformed = model.transform(df)

In [12]:
df_cross = df_transformed.select(
    col('id').alias('id1'),
    col('vecs').alias('vecs1')).crossJoin(df_transformed.select(
        col('id').alias('id2'),
        col('vecs').alias('vecs2'))
)

In [13]:
df_cross.show()

+---+--------------------+---+--------------------+
|id1|               vecs1|id2|               vecs2|
+---+--------------------+---+--------------------+
|  9|[0.00491762642050...|  9|[0.00491762642050...|
|  9|[0.00491762642050...| 16|[0.01621099077165...|
|  9|[0.00491762642050...|  8|[-0.0403248976605...|
|  9|[0.00491762642050...|  1|[-0.0259283815510...|
|  9|[0.00491762642050...|  5|[0.02066894519763...|
|  9|[0.00491762642050...|  6|[0.03616197061880...|
|  9|[0.00491762642050...|  7|[0.01177302086725...|
|  9|[0.00491762642050...|  2|[-0.0077832460713...|
|  9|[0.00491762642050...| 72|[0.00162717599887...|
|  9|[0.00491762642050...|  4|[0.00955085083842...|
|  9|[0.00491762642050...| 73|[0.03020236175507...|
|  9|[0.00491762642050...| 10|[-0.0089490925893...|
|  9|[0.00491762642050...| 75|[8.69367737323045...|
|  9|[0.00491762642050...| 11|[-0.0056437781546...|
|  9|[0.00491762642050...| 76|[-0.0014370117569...|
|  9|[0.00491762642050...| 12|[0.01493579955838...|
|  9|[0.0049

In [14]:
from scipy import spatial

@udf(returnType=FloatType())
def sim(x, y):
    return float(1 - spatial.distance.cosine(x, y))

In [15]:
df_cross = df_cross.withColumn('sim', sim(df_cross['vecs1'], df_cross['vecs2']))

In [16]:
test_id = 7445

In [17]:
pdf1 = df_cross.filter(col('id1')==test_id).toPandas()

In [22]:
sim_top10 = pdf1[pdf1.sim<1].sort_values('sim', ascending=False).head(10)
sim_top10

Unnamed: 0,id1,vecs1,id2,vecs2,sim
14297,7445,"[-0.003779336577281356, -0.1523637380450964, -...",6150,"[0.010635668052903686, -0.1240268386900425, -0...",0.986141
12974,7445,"[-0.003779336577281356, -0.1523637380450964, -...",4558,"[-0.015396797796711326, -0.17607890628278255, ...",0.984285
5155,7445,"[-0.003779336577281356, -0.1523637380450964, -...",13264,"[-0.012317438237369062, -0.14086312502622605, ...",0.984285
2203,7445,"[-0.003779336577281356, -0.1523637380450964, -...",2219,"[0.006141198022911945, -0.07564452496202041, -...",0.980139
11891,7445,"[-0.003779336577281356, -0.1523637380450964, -...",11813,"[-0.008300341665744781, -0.14514850452542305, ...",0.979465
13691,7445,"[-0.003779336577281356, -0.1523637380450964, -...",5431,"[0.012424760265275836, -0.12904690578579903, -...",0.978694
14950,7445,"[-0.003779336577281356, -0.1523637380450964, -...",6892,"[-0.00781940243073872, -0.07842672137277466, -...",0.978664
4264,7445,"[-0.003779336577281356, -0.1523637380450964, -...",12563,"[0.0004664606281689235, -0.09541171576295579, ...",0.977603
13695,7445,"[-0.003779336577281356, -0.1523637380450964, -...",5429,"[-0.010000718850642444, -0.09775994010269642, ...",0.977081
1781,7445,"[-0.003779336577281356, -0.1523637380450964, -...",1804,"[-0.011628177358943503, -0.0931693238671869, -...",0.97645


In [19]:
df.filter(col('id')==test_id).toPandas()

Unnamed: 0,id,title,words
0,7445,iOS 性能提升总结,"[iOS, 性能, 提升, 总结]"


In [20]:
df.filter(df.id.isin(sim_top10['id2'].to_list())).toPandas()

Unnamed: 0,id,title,words
0,1804,[译] Android 性能优化课程（一）：渲染性能 (@刘智勇同学),"[Android, 性能, 优化, 课程, 渲染, 性能, 智勇, 同学]"
1,2219,iOS 图形性能进阶与测试 (@方秋枋),"[iOS, 图形, 性能, 进阶, 测试, 方秋枋]"
2,12563,iOS 开发中的 11 种锁以及性能对比,"[iOS, 开发, 11, 种锁, 以及, 性能, 对比]"
3,13264,Android 性能优化来龙去脉总结,"[Android, 性能, 优化, 来龙去脉, 总结]"
4,11813,Android App 性能优化的一些思考,"[Android, App, 性能, 优化, 一些, 思考]"
5,4558,[译] Android 性能优化总结,"[Android, 性能, 优化, 总结]"
6,5431,iOS 开发性能提高,"[iOS, 开发, 性能, 提高]"
7,5429,[译] Android 界面性能调优手册,"[Android, 界面, 性能, 调优, 手册]"
8,6150,微信读书 iOS 性能优化总结,"[微信, 读书, iOS, 性能, 优化, 总结]"
9,6892,[译] iOS 性能优化：Instruments 工具的救命三招,"[iOS, 性能, 优化, Instruments, 工具, 救命, 三招]"


In [21]:
spark.stop()