In [1]:
import os
import sys
# Point PySpark at the Python interpreter, the Spark installation and the
# submit options (executor/driver sizing) this notebook should use.
os.environ.update({
    "PYSPARK_PYTHON": '/opt/anaconda/envs/bd9/bin/python',
    "SPARK_HOME": '/usr/hdp/current/spark2-client',
    "PYSPARK_SUBMIT_ARGS": '--num-executors 5 --executor-memory 4g --executor-cores 1 --driver-memory 2g pyspark-shell',
})

spark_home = os.environ.get('SPARK_HOME', None)

# Prepend Spark's bundled Python sources and py4j archive to the import path.
# The py4j archive ends up first and the 'python' directory second, which is
# the order two successive insert(0, ...) calls would leave them in.
sys.path[:0] = [
    os.path.join(spark_home, 'python/lib/py4j-0.10.7-src.zip'),
    os.path.join(spark_home, 'python'),
]

In [2]:
from pyspark import SparkConf
from pyspark.sql import SparkSession


# Configuration carrying only the application name; SparkConf.set returns the
# conf itself, so it can be built in a single expression.
conf = SparkConf().set("spark.app.name", "Konstantin Diakvnishvili lab 2 app")

# Reuse an already running session if there is one, otherwise start a new one.
spark = (
    SparkSession.builder
    .config(conf=conf)
    .getOrCreate()
)

In [3]:
# Display the session object as the cell output to confirm it was created.
spark

In [46]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import HashingTF, IDF, StopWordsRemover
from pyspark.sql.functions import lower, col, udf, pandas_udf
from pyspark.sql.types import FloatType, ArrayType, StringType
import pandas as pd
import re
import json

In [5]:
# Load the source records from JSON into a Spark DataFrame.
# NOTE(review): the file name suggests one JSON record per line, which is what
# spark.read.json expects by default — confirm against the data.
dataset = spark.read.json("/labs/slaba02/DO_record_per_line.json")

In [6]:
@pandas_udf(ArrayType(StringType()))
def tokenizer_udf(series):
    """Split every string of a pandas Series into word tokens.

    A token is a run of two or more Unicode word characters (letters, digits
    or underscore); shorter runs and all other characters are discarded.

    Args:
        series: pandas Series of strings handed over by Spark for one batch.

    Returns:
        pandas Series holding, for each input string, the list of tokens found.
    """
    # Raw string literal: '\w' and '\d' are not valid *string* escapes, so the
    # non-raw form triggers an invalid-escape-sequence warning. The compiled
    # pattern text is unchanged.
    regex = re.compile(r'[\w\d]{2,}', re.U)
    return series.str.findall(regex)

In [7]:
# Add a lower-cased copy of the description and its token list.
dataset = (
    dataset
    .withColumn("desc_l", lower(col('desc')))
    .withColumn("words", tokenizer_udf('desc_l'))
)

In [8]:
# Combined default stop-word lists, concatenated in the order
# English, Spanish, Russian.
stop_words = [
    word
    for language in ("english", "spanish", "russian")
    for word in StopWordsRemover.loadDefaultStopWords(language)
]
# Filters the "words" token lists into "words_filtered".
swr = StopWordsRemover(
    inputCol="words",
    outputCol="words_filtered",
    stopWords=stop_words,
)

In [9]:
# Term-frequency vectors with 10000 hashed features (raw counts, not binary),
# built from the stop-word-filtered tokens.
hasher = HashingTF(
    inputCol=swr.getOutputCol(),
    outputCol="tf",
    numFeatures=10000,
    binary=False,
)
# IDF stage that turns the "tf" vectors into "tfidf".
idf = IDF(
    inputCol=hasher.getOutputCol(),
    outputCol="tfidf",
)

In [10]:
# Stage order matters: stop-word removal -> hashed term frequencies -> IDF.
preprocessing_stages = [swr, hasher, idf]
preprocessing = Pipeline(stages=preprocessing_stages)

In [11]:
# Fit the preprocessing pipeline on the tokenized dataset.
preprocessing_model = preprocessing.fit(dataset)

In [12]:
# Apply the fitted pipeline; the result gains the pipeline's output columns
# ("words_filtered", "tf", "tfidf").
dataset2 = preprocessing_model.transform(dataset)

In [13]:
# Drop the raw text and intermediate columns, leaving only what the similarity
# search needs, then cache the result because it is queried repeatedly.
unused_columns = ("cat", "desc", "desc_l", "name", "provider", "words", "words_filtered", "tf")
dataset2 = dataset2.drop(*unused_columns)
dataset2 = dataset2.cache()
dataset2.show(2)

+---+----+--------------------+
| id|lang|               tfidf|
+---+----+--------------------+
|  4|  en|(10000,[36,63,138...|
|  5|  en|(10000,[32,222,36...|
+---+----+--------------------+
only showing top 2 rows



In [14]:
@udf(FloatType())
def cosin_sim(v, u):
    """Cosine similarity between two vectors.

    Args:
        v: first vector; must provide ``norm(p)``.
        u: second vector; must provide ``dot(other)`` and ``norm(p)``.

    Returns:
        ``u.dot(v) / (|v| * |u|)`` as a float, or None when either input is
        missing or has zero length (the similarity is undefined there).
    """
    if v is None or u is None:
        return None
    denominator = v.norm(2) * u.norm(2)
    # A zero-length vector would otherwise yield NaN via a 0/0 division;
    # return null explicitly so callers can filter undefined scores.
    if denominator == 0:
        return None
    return float(u.dot(v) / denominator)

In [63]:
# (record id, language) pairs to compute recommendations for.
test = (
    (23126, "en"),
    (21617, "en"),
    (16627, "es"),
    (11556, "es"),
    (16704, "ru"),
    (13702, "ru"),
)
# Filled by the recommendation loop: target id -> list of similar record ids.
result = dict()

In [69]:
# For every (id, language) target, rank all other records of the same language
# by cosine similarity of their TF-IDF vectors and keep the ten best ids.
for target_id, target_lang in test:
    # Single-record frame for the target; columns are prefixed with "c_" so
    # they do not clash with the candidate columns after the join.
    # No per-iteration cache() here or below: each filtered frame feeds exactly
    # one action and the cached copies were never unpersisted, so caching only
    # accumulated blocks in executor storage.
    target = (
        dataset2.filter(dataset2.id == target_id)
        .withColumnRenamed("id", "c_id")
        .withColumnRenamed("lang", "c_lang")
        .withColumnRenamed("tfidf", "c_tfidf")
    )
    candidates = dataset2.filter(dataset2.lang == target_lang)

    # Pair every candidate (except the target itself) with the target, score
    # it, drop undefined scores and keep the ten highest.
    # NOTE(review): ordering uses the score only, so ties at the cut-off are
    # returned in no guaranteed order.
    top_similar = (
        candidates.join(target, candidates.id != target.c_id)
        .withColumn("cosin_sim", cosin_sim("tfidf", "c_tfidf"))
        .dropna(subset=["cosin_sim"])
        .orderBy("cosin_sim", ascending=False)
        .limit(10)
    )

    result[target_id] = [row["id"] for row in top_similar.select("id").collect()]

In [70]:
# Display the collected recommendations (target id -> list of similar ids).
result

{23126: [13665, 14760, 13782, 20638, 24419, 15909, 2724, 25782, 17499, 13348],
 21617: [21609, 21616, 22298, 21608, 21628, 21630, 21081, 21623, 19417, 21508],
 16627: [11431, 5687, 17964, 12660, 12247, 17961, 16694, 5558, 11575, 13551],
 11556: [16488, 468, 19330, 10447, 23357, 21707, 22710, 13461, 10384, 13776],
 16704: [1236, 1247, 1365, 1164, 1273, 20288, 8186, 1233, 8203, 18331],
 13702: [864, 28074, 1041, 21079, 8300, 13057, 8313, 21025, 1033, 1111]}

In [71]:
import json
# Save the recommendations as JSON indented by four spaces; the integer
# dictionary keys are written out as strings, as JSON requires.
serialized = json.dumps(result, indent=4)
with open("lab02.json", "w") as out:
    out.write(serialized)

In [74]:
# Shut down the Spark session now that the results are collected and saved.
spark.stop()

In [73]:
# Print the saved file to check what was written.
!cat lab02.json

{
    "23126": [
        13665,
        14760,
        13782,
        20638,
        24419,
        15909,
        2724,
        25782,
        17499,
        13348
    ],
    "21617": [
        21609,
        21616,
        22298,
        21608,
        21628,
        21630,
        21081,
        21623,
        19417,
        21508
    ],
    "16627": [
        11431,
        5687,
        17964,
        12660,
        12247,
        17961,
        16694,
        5558,
        11575,
        13551
    ],
    "11556": [
        16488,
        468,
        19330,
        10447,
        23357,
        21707,
        22710,
        13461,
        10384,
        13776
    ],
    "16704": [
        1236,
        1247,
        1365,
        1164,
        1273,
        20288,
        8186,
        1233,
        8203,
        18331
    ],
    "13702": [
        864,
        28074,
        1041,
        21079,
        8300,
   