In [1]:
import pyspark.ml.base
import pyspark.sql.functions as F
from pyspark.sql.functions import col
from pyspark.sql.functions import countDistinct
from pyspark.sql.types import LongType

In [2]:
import metaspore as ms

author='sunkai'
app_name='Amazon Collaborative Filtering Demo/%s'%author
local=False
worker_count = 4
server_count = 4
batch_size=256
worker_memory='5G'
server_memory='5G'
coordinator_memory='5G'

spark_confs={
    "spark.network.timeout":"500",
    "spark.submit.pyFiles":"python.zip",
    "spark.sql.shuffle.partitions":"1000" # default 200
}

spark = ms.spark.get_session(local=local,
                             app_name=app_name,
                             batch_size=batch_size,
                             worker_count=worker_count,
                             server_count=server_count,
                             worker_memory=worker_memory,
                             server_memory=server_memory,
                             coordinator_memory=coordinator_memory,
                             spark_confs=spark_confs)
sc = spark.sparkContext
print(sc.version)
print(sc.applicationId)
print(sc.uiWebUrl)

22/02/24 04:06:52 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


3.1.2
spark-application-1645675613486
http://jupyter.my.nginx.test/hub/user-redirect/proxy/4040/jobs/


In [3]:
train_path='s3://dmetasoul-bucket/demo/movielens/match/num_negs_100/train.parquet'
test_path='s3://dmetasoul-bucket/demo/movielens/match/num_negs_100/test.parquet'
item_path='s3://dmetasoul-bucket/demo/movielens/match/num_negs_100/item.parquet'

train_dataset = spark.read.parquet(train_path)
train_dataset=train_dataset[train_dataset['label']==1]
test_dataset = spark.read.parquet(test_path)

                                                                                

In [4]:
train_dataset.show(10)

[Stage 2:>                                                          (0 + 1) / 1]

+-----+-------+------+---+----------+-----+--------+--------------------+--------------------+----------+--------------------+
|label|user_id|gender|age|occupation|  zip|movie_id|    recent_movie_ids|               genre|last_movie|          last_genre|
+-----+-------+------+---+----------+-----+--------+--------------------+--------------------+----------+--------------------+
|    1|   3163|     M| 18|        15|95616|     356|32531078150033...|  ComedyRomanceWar|       101|              Comedy|
|    1|   4507|     M| 18|         4|61820|     110|29492028370345...|    ActionDramaWar|      1954|        ActionDrama|
|    1|    190|     M| 25|        17|55125|    2762|28049122651225...|            Thriller|      1089|      CrimeThriller|
|    1|   3626|     M| 25|        17|75075|      93|35442306371517...|      ComedyRomance|      2195|              Comedy|
|    1|   2015|     M| 18|         4|01003|    1917|2409951479786...|ActionAdventure...|      1037|Action

                                                                                

In [5]:
import sys
sys.path.append('../../') 
from python.item_cf_retrieval import ItemCFModel, ItemCFEstimator

In [6]:
## Using Swing implementation in MindAlpha
estimator=ItemCFEstimator(user_id_column_name='user_id',
                        item_id_column_name='movie_id',
                        behavior_column_name='label',
                        behavior_filter_value='1')

model = estimator.fit(train_dataset)

In [7]:
model.df.show(20)

                                                                                

+----+--------------------+
| key|               value|
+----+--------------------+
|1090|[{1222, 0.3960944...|
|1572|[{659, 0.39790767...|
|2088|[{2471, 0.3545055...|
|2294|[{1907, 0.3906826...|
| 296|[{608, 0.55562270...|
|3414|[{937, 0.19069249...|
|1394|[{1079, 0.4741696...|
|3826|[{3793, 0.2857708...|
| 451|[{3044, 0.2110436...|
|1280|[{3260, 0.2460103...|
|1870|[{2029, 0.2246403...|
|2700|[{223, 0.43360152...|
|2917|[{1084, 0.2770283...|
|3526|[{2918, 0.3925131...|
|3836|[{3846, 0.3019790...|
|   7|[{339, 0.39856439...|
| 307|[{308, 0.57122097...|
|3491|[{3577, 0.1859107...|
|3895|[{3896, 0.2840043...|
| 475|[{300, 0.29518867...|
+----+--------------------+
only showing top 20 rows



In [8]:
test_df = test_dataset.select('user_id', 'last_movie', 'movie_id')\
            .groupBy('user_id', 'last_movie')\
            .agg(F.collect_set('movie_id').alias('label_items'))
test_df = test_df.withColumnRenamed('last_movie', 'movie_id')
prediction_df = model.transform(test_df)
prediction_df = prediction_df.withColumnRenamed('value', 'rec_info')
prediction_df.show(10)



+-------+--------+-----------+----+--------------------+
|user_id|movie_id|label_items| key|            rec_info|
+-------+--------+-----------+----+--------------------+
|     57|    1090|     [1258]|1090|[{1222, 0.3960944...|
|    770|    1090|     [3359]|1090|[{1222, 0.3960944...|
|   1895|    1090|     [1193]|1090|[{1222, 0.3960944...|
|   5163|    1090|     [1222]|1090|[{1222, 0.3960944...|
|   4526|    1090|     [1300]|1090|[{1222, 0.3960944...|
|   4525|    1090|     [2890]|1090|[{1222, 0.3960944...|
|   3784|    1090|     [1358]|1090|[{1222, 0.3960944...|
|    769|    1090|     [1304]|1090|[{1222, 0.3960944...|
|    601|    1090|     [2106]|1090|[{1222, 0.3960944...|
|   4650|    1090|     [2359]|1090|[{1222, 0.3960944...|
+-------+--------+-----------+----+--------------------+
only showing top 10 rows



                                                                                

In [9]:
prediction_label_rdd = prediction_df.rdd.map(lambda x:(\
                                    [xx._1 for xx in x.rec_info] if x.rec_info is not None else [], \
                                     x.label_items))

In [11]:
prediction_label_rdd.take(10)

                                                                                

[(['1222',
   '1208',
   '1242',
   '1299',
   '1263',
   '1272',
   '608',
   '1259',
   '1196',
   '1960',
   '3448',
   '2000',
   '2194',
   '1200',
   '1231',
   '527',
   '2529',
   '110',
   '1250',
   '1293'],
  ['1304']),
 (['1222',
   '1208',
   '1242',
   '1299',
   '1263',
   '1272',
   '608',
   '1259',
   '1196',
   '1960',
   '3448',
   '2000',
   '2194',
   '1200',
   '1231',
   '527',
   '2529',
   '110',
   '1250',
   '1293'],
  ['2359']),
 (['1222',
   '1208',
   '1242',
   '1299',
   '1263',
   '1272',
   '608',
   '1259',
   '1196',
   '1960',
   '3448',
   '2000',
   '2194',
   '1200',
   '1231',
   '527',
   '2529',
   '110',
   '1250',
   '1293'],
  ['3359']),
 (['1222',
   '1208',
   '1242',
   '1299',
   '1263',
   '1272',
   '608',
   '1259',
   '1196',
   '1960',
   '3448',
   '2000',
   '2194',
   '1200',
   '1231',
   '527',
   '2529',
   '110',
   '1250',
   '1293'],
  ['1193']),
 (['1222',
   '1208',
   '1242',
   '1299',
   '1263',
   '1272',
   '608',




In [13]:
from pyspark.mllib.evaluation import RankingMetrics
metrics = RankingMetrics(prediction_label_rdd)



In [14]:
print("Debug -- Precision@20: ", metrics.precisionAt(20))
print("Debug -- Recall@20: ", metrics.recallAt(20))
print("Debug -- MAP@20: ", metrics.meanAveragePrecisionAt(20))

                                                                                

Debug -- Precision@20:  0.008584437086092715


                                                                                

Debug -- Recall@20:  0.17168874172185417




Debug -- MAP@20:  0.043670630377466144




In [7]:
dataset = train_dataset

dataset = dataset.withColumnRenamed('movie_id', 'item_id')
user_bhv_count = (dataset
              .groupBy(F.col('user_id'))
              .agg(F.count(F.col('item_id')).alias('item_count'))
              .withColumn("item_count",F.col("item_count").cast(LongType()))
              .filter(F.col("user_id").isNotNull() & F.col('item_count').isNotNull()))

user_bhv_count.show(20)

                                                                                

+-------+----------+
|user_id|item_count|
+-------+----------+
|   2088|       440|
|   2294|        61|
|   3414|       296|
|   5325|        47|
|    296|        96|
|   1572|        44|
|   1512|        33|
|   1090|        72|
|   1394|        26|
|   3826|        27|
|    451|       208|
|   2700|       293|
|   3526|       961|
|   3836|       509|
|   4850|        36|
|   1870|       157|
|   2917|        69|
|   1280|        49|
|      7|        30|
|   4988|        25|
+-------+----------+
only showing top 20 rows



In [8]:
import pyspark.sql.functions as functions

t1 = dataset.withColumnRenamed('item_id', 'item_id_i')
t2 = dataset.withColumnRenamed('item_id', 'item_id_j')

crossing = t1.alias('t1').join(t2.alias('t2'), on=(F.col('t1.user_id')==F.col('t2.user_id')), how='leftouter')\
                         .filter(F.col('t1.user_id').isNotNull() & F.col('t2.user_id').isNotNull() \
                                                                & (F.col('t1.item_id_i')!=F.col('t2.item_id_j'))) \
                         .groupby('t1.user_id', 't1.item_id_i', 't2.item_id_j') \
                         .agg(functions.count(functions.lit(1)).alias('crossing_count')) 

crossing_count = crossing.alias('t1').join(user_bhv_count.alias('t2'), on=(F.col('t1.user_id')==F.col('t2.user_id'))) \
                                     .filter(F.col('t2.item_count')>0) \
                                     .select('t1.user_id', 't1.item_id_i', 't1.item_id_j', 't2.item_count') 

crossing_weight = crossing_count.groupby('user_id', 'item_id_i', 'item_id_j') \
                                .agg(functions.sum(functions.lit(1)/functions.sqrt(F.col('item_count'))).alias('weight'))

In [9]:
crossing.show(20)
crossing_count.show(20)
crossing_weight.show(20)

                                                                                

+-------+---------+---------+--------------+
|user_id|item_id_i|item_id_j|crossing_count|
+-------+---------+---------+--------------+
|   1090|     2324|     3623|             1|
|   1090|     2324|     1393|             1|
|   1090|     2324|      589|             1|
|   1090|     2324|     3307|             1|
|   1090|     2324|      953|             1|
|   1090|     2324|       47|             1|
|   1090|     2324|      593|             1|
|   1090|     2324|      912|             1|
|   1090|     2324|      480|             1|
|   1090|     2324|     2428|             1|
|   1090|     2324|     1093|             1|
|   1090|     2324|     1894|             1|
|   1090|     2324|     3095|             1|
|   1090|     2324|      318|             1|
|   1090|     2324|     2688|             1|
|   1090|     2324|     1221|             1|
|   1090|     2324|     3198|             1|
|   1090|     2324|      908|             1|
|   1090|     2324|     3219|             1|
|   1090| 

                                                                                

+-------+---------+---------+----------+
|user_id|item_id_i|item_id_j|item_count|
+-------+---------+---------+----------+
|   1090|     1213|     1221|        72|
|   1090|     1213|     3198|        72|
|   1090|     1213|      908|        72|
|   1090|     1213|     3219|        72|
|   1090|     1213|     1704|        72|
|   1090|     1213|     1580|        72|
|   1090|     1213|     2396|        72|
|   1090|     1213|      100|        72|
|   1090|     1213|      913|        72|
|   1090|     1213|     1573|        72|
|   1090|     1213|      924|        72|
|   1090|     1213|     1208|        72|
|   1090|     1213|       50|        72|
|   1090|     1213|     1961|        72|
|   1090|     1213|     2028|        72|
|   1090|     1213|      296|        72|
|   1090|     1213|      923|        72|
|   1090|     1213|      110|        72|
|   1090|     1213|     2710|        72|
|   1090|     1213|     3623|        72|
+-------+---------+---------+----------+
only showing top



+-------+---------+---------+-------------------+
|user_id|item_id_i|item_id_j|             weight|
+-------+---------+---------+-------------------+
|   1090|     3623|     1213|0.11785113019775793|
|   1090|     3623|     1287|0.11785113019775793|
|   1090|     3623|      954|0.11785113019775793|
|   1090|     3623|       36|0.11785113019775793|
|   1090|     3623|     1569|0.11785113019775793|
|   1090|     3623|      805|0.11785113019775793|
|   1090|     3623|      648|0.11785113019775793|
|   1090|     3623|     1246|0.11785113019775793|
|   1090|     3623|     1221|0.11785113019775793|
|   1090|     3623|     3198|0.11785113019775793|
|   1090|     3623|      908|0.11785113019775793|
|   1090|     3623|     3219|0.11785113019775793|
|   1090|     3623|     3386|0.11785113019775793|
|   1090|     3623|     3176|0.11785113019775793|
|   1090|     3623|     1293|0.11785113019775793|
|   1090|     3623|     2763|0.11785113019775793|
|   1090|     3623|     2683|0.11785113019775793|


                                                                                

In [11]:
item_l2_norm = dataset.alias('t1').join(user_bhv_count.alias('t2'), on=(F.col('t1.user_id')==F.col('t2.user_id'))) \
                                  .filter(F.col('t2.item_count')>0) \
                                  .select('t1.user_id', 't1.item_id', 't2.item_count') \
                                  .groupby('t1.item_id') \
                                  .agg(functions.sum(functions.lit(1)/functions.sqrt(F.col('item_count'))).alias('weight'))

In [12]:
item_l2_norm.show(20)

                                                                                

+-------+------------------+
|item_id|            weight|
+-------+------------------+
|   1090| 81.99269299502775|
|    296|174.25336247246761|
|   2294| 42.01543536395503|
|   2088|26.035601014825353|
|   1572|1.2492930379630705|
|   3414| 3.223254147310252|
|   1394|102.81365520142187|
|   3826|27.942823340986447|
|    451|6.4775412310601865|
|   3836|26.890522813416045|
|   2917|31.864419762116942|
|   3526|49.971812645136666|
|   2700| 96.69238961824358|
|   1280| 21.12303163242884|
|   1870| 4.840285370258332|
|      7| 27.93225014000092|
|   3491|3.5809008237929256|
|    475|18.721305453867377|
|   3895| 6.235870463745235|
|    307|15.275806722100004|
+-------+------------------+
only showing top 20 rows



In [21]:
t2 = item_l2_norm.withColumnRenamed('weight', 'normal_weight_i')
t3 = item_l2_norm.withColumnRenamed('weight', 'normal_weight_j')

inner_product = crossing_weight.groupby('item_id_i', 'item_id_j') \
                               .agg(functions.sum(F.col('weight') * F.col('weight')).alias('weight_sum')) \

cossine_similarity = inner_product.alias('t1')\
                                  .join(t2.alias('t2'), on=(F.col('t1.item_id_i')==F.col('t2.item_id'))) \
                                  .join(t3.alias('t3'), on=(F.col('t1.item_id_j')==F.col('t3.item_id'))) \
                                  .withColumn('weight', F.col('t1.weight_sum')/(F.col('t2.normal_weight_i') * F.col('t3.normal_weight_j')))

In [22]:
inner_product.show(20)
cossine_similarity.show(20)

                                                                                

+---------+---------+--------------------+
|item_id_i|item_id_j|          weight_sum|
+---------+---------+--------------------+
|      231|     1242|  0.5405926168909746|
|      541|     2518| 0.38116727450646914|
|      247|     2255| 0.03216623241149229|
|     1960|     3683|  0.5995237500413857|
|      720|     2739| 0.13933215787579392|
|     1307|     3235| 0.10449778192682883|
|     2255|     3591| 0.12084086576764883|
|     1678|     2241|0.040134625712340666|
|     1678|       62|  0.5827342637255613|
|     2659|     1265| 0.08211371411527571|
|     2659|     2263| 0.01599965420524179|
|     1617|     3210|  1.7503995061990834|
|     2890|     1974| 0.21214047071668113|
|     1179|     1147| 0.46145562339740176|
|     1179|     2735| 0.29074619182029676|
|      671|      306| 0.08555086538523371|
|     2021|     1297| 0.40495465331033476|
|     1220|     2797|    3.04807671509008|
|     1961|     1694|  0.6891071215710826|
|     2908|     1242|  0.6802398620994448|
+---------+



+---------+---------+--------------------+-------+------------------+-------+-----------------+--------------------+
|item_id_i|item_id_j|          weight_sum|item_id|   normal_weight_i|item_id|  normal_weight_j|              weight|
+---------+---------+--------------------+-------+------------------+-------+-----------------+--------------------+
|     1008|     1090| 0.13778961966513997|   1008| 5.795555147675379|   1090|81.99269299502774|2.899654860719855E-4|
|     1686|     1090|  0.0907326705073908|   1686|5.4281240409019835|   1090|81.99269299502774|2.038631729673413E-4|
|     3089|     1090| 0.24975705575484006|   3089|17.978697599688754|   1090|81.99269299502774|1.694276892231603...|
|      383|     1090| 0.43932295301338553|    383| 16.74195140303235|   1090|81.99269299502774|3.200388243887443...|
|      700|     1090| 0.11229013089042485|    700|   8.7376439771071|   1090|81.99269299502774|1.567372029400510...|
|     3074|     1090| 0.36914324915815355|   3074|13.75797251578

                                                                                

In [29]:
from pyspark.sql import Window
max_recommend_count=10

## https://stackoverflow.com/questions/64274160/pyspark-collect-list-but-limit-to-max-n-results
result = cossine_similarity \
            .withColumn("rn", F.row_number().over(
                Window.partitionBy('item_id_i').orderBy(F.desc('weight'))
                )) \
            .filter(f"rn <= 10")  \
            .groupBy('item_id_i') \
            .agg(F.collect_list(F.col('item_id_j')).alias('rec_info'))

In [30]:
result.show(20)

                                                                                

+---------+--------------------+
|item_id_i|            rec_info|
+---------+--------------------+
|     1090|[402, 3172, 2777,...|
|     1572|[1364, 666, 1360,...|
|     2088|[572, 1142, 703, ...|
|     2294|[655, 530, 1316, ...|
|      296|[2343, 624, 2438,...|
|     3414|[3687, 1070, 3315...|
|     1394|[2631, 2251, 1574...|
|     3826|[3322, 790, 3800,...|
|      451|[2619, 821, 712, ...|
|     1280|[3530, 545, 167, ...|
|     1870|[2909, 641, 396, ...|
|     2700|[1886, 3381, 977,...|
|     2917|[2308, 310, 1724,...|
|     3526|[3607, 2242, 3881...|
|     3836|[868, 3762, 3803,...|
|        7|[2244, 3904, 3601...|
|      307|[2705, 679, 1165,...|
|     3491|[3493, 3119, 1470...|
|     3895|[192, 3887, 3880,...|
|      475|[2619, 679, 139, ...|
+---------+--------------------+
only showing top 20 rows



In [None]:
sc.stop()