In [1]:
# coding: utf-8
"""
pysparkで類似度を図るためのサンプル
"""
import numpy as np
from pyspark import SQLContext, sql
import pyspark
from pyspark.sql import functions, Row
from pyspark.mllib.linalg import DenseVector

# 1. 初期化

In [2]:
sc = pyspark.SparkContext('local[*]')
sqlContext = sql.SQLContext(sc)

# 2. サンプルの作成
* label1~label5はバイナリ、label6~label10は連続値

In [3]:
samples = [['aaa', 'a', 30, 1,2,3,4,5]  + np.random.randn(5).tolist(),
    ['aaa', 'b', 30,2,1,3,4,1] + np.random.randn(5).tolist(),
    ['bbb', 'a', 30,4,5,3,2,4] + np.random.randn(5).tolist(),
    ['bbb', 'b', 30,1,2,4,3,1] + np.random.randn(5).tolist(),
    ['ccc', 'a', 30,4,5,2,1,2] + np.random.randn(5).tolist(),
    ['ccc', 'b', 30,1,2,5,4,1] + np.random.randn(5).tolist(),]
colnames = [
    'mc', 'mtc', 'area_cd',
    'label1', 'label2', 'label3', 'label4', 'label5',
    'label6', 'label7', 'label8', 'label9', 'label10'
]
colnames1 = [col + '_1' for col in colnames]
colnames2 = [col + '_2' for col in colnames]

# 3. Pyspark Dataframeの作成

In [4]:
df1 = sqlContext.createDataFrame(sc.parallelize(samples), colnames1)
df2 = sqlContext.createDataFrame(sc.parallelize(samples), colnames2)

# 4. 類似度を図るために、組み合わせを作成

In [5]:
joined_df = df1.join(df2, df1.area_cd_1 == df2.area_cd_2).filter(functions.concat(df1.mc_1, df1.mtc_1) != functions.concat(df2.mc_2, df2.mtc_2))

* 結果の出力

In [6]:
joined_df.show(5)

+----+-----+---------+--------+--------+--------+--------+--------+------------------+------------------+-------------------+--------------------+------------------+----+-----+---------+--------+--------+--------+--------+--------+--------------------+--------------------+--------------------+-------------------+-------------------+
|mc_1|mtc_1|area_cd_1|label1_1|label2_1|label3_1|label4_1|label5_1|          label6_1|          label7_1|           label8_1|            label9_1|         label10_1|mc_2|mtc_2|area_cd_2|label1_2|label2_2|label3_2|label4_2|label5_2|            label6_2|            label7_2|            label8_2|           label9_2|          label10_2|
+----+-----+---------+--------+--------+--------+--------+--------+------------------+------------------+-------------------+--------------------+------------------+----+-----+---------+--------+--------+--------+--------+--------+--------------------+--------------------+--------------------+-------------------+----------------

# 5.類似度の計算 

In [7]:
def match_sim(row1 ,row2):
    keys = row1.asDict().keys()
    total = len(keys)
    count = 0
    for key in keys:
        if row1[key] == row2[key]:
            count += 1
    return float(count)/total

def cosine_sim(vec1 ,vec2):
    dot = abs(vec1.dot(vec2))
    n1 = vec1.norm(None)
    n2 = vec1.norm(None)
    return float(dot/(n1*n2))

joined_rdd = joined_df.rdd.map(lambda x: (
    Row(mc_1=x.mc_1, mtc_1=x.mtc_1, mc_2=x.mc_2, mtc_2=x.mtc_2),
    Row(label1=x.label1_1, label2=x.label2_1, label3=x.label3_1, label4=x.label4_1, label5=x.label5_1),
    DenseVector([x.label6_1,x.label7_1,x.label8_1,x.label9_1,x.label10_1]),
    Row(label1=x.label1_2, label2=x.label2_2, label3=x.label3_2, label4=x.label4_2, label5=x.label5_2),
    DenseVector([x.label6_2,x.label7_2,x.label8_2,x.label9_2,x.label10_2])
                                         )) \
.map(lambda x: (x[0], match_sim(x[1], x[3]), cosine_sim(x[2], x[4]))) \
.map(lambda x: (x[0].mc_1, x[0].mtc_1, x[0].mc_2, x[0].mtc_2, x[1], x[2]))

# 5. 結果の出力

In [8]:
sqlContext.createDataFrame(joined_rdd, ['tar_mc', 'tar_mtc', 'res_mc', 'res_mtc', 'match_sim', 'cosine_sim']).show()

+------+-------+------+-------+---------+--------------------+
|tar_mc|tar_mtc|res_mc|res_mtc|match_sim|          cosine_sim|
+------+-------+------+-------+---------+--------------------+
|   aaa|      a|   aaa|      b|      0.4|  0.2979433262317515|
|   aaa|      a|   bbb|      a|      0.2|  0.2161103600613806|
|   aaa|      a|   bbb|      b|      0.4|  0.6933162039799152|
|   aaa|      a|   ccc|      a|      0.0| 0.34941331375143353|
|   aaa|      a|   ccc|      b|      0.6|  0.5354750033557132|
|   aaa|      b|   aaa|      a|      0.4| 0.19428899651078324|
|   aaa|      b|   bbb|      a|      0.2| 0.10702152405150611|
|   aaa|      b|   bbb|      b|      0.2|  0.4033681950723296|
|   aaa|      b|   ccc|      a|      0.0| 0.20097172584128625|
|   aaa|      b|   ccc|      b|      0.4|  0.6861144738544892|
|   bbb|      a|   aaa|      a|      0.2|  0.3590385377694502|
|   bbb|      a|   aaa|      b|      0.2| 0.27266040008605663|
|   bbb|      a|   bbb|      b|      0.0|  1.1313716028