In [1]:
import os
import sys
# Put the project root (two directory levels above the current working
# directory) at the front of sys.path so that project imports do not fail
# when this notebook is run directly.
BASE_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.insert(0, BASE_DIR)

# Pin the interpreter used by PySpark for both workers and driver; when several
# Python versions are installed, leaving this unset is likely to cause errors.
PYSPARK_PYTHON = "/miniconda2/envs/reco_sys/bin/python"
os.environ.update({
    "PYSPARK_PYTHON": PYSPARK_PYTHON,
    "PYSPARK_DRIVER_PYTHON": PYSPARK_PYTHON,
})

from offline import SparkSessionBase

class UpdateRecall(SparkSessionBase):
    """Spark session holder for the offline recall update job.

    Creating an instance builds a session through the inherited
    ``_create_spark_session`` helper and exposes it as ``self.spark``.
    """

    # Class-level settings; presumably read by SparkSessionBase (defined in the
    # project's `offline` package, not shown here) when it builds the session
    # -- confirm against that base class.
    ENABLE_HIVE_SUPPORT = True
    SPARK_APP_NAME = "updateRecall"

    def __init__(self):
        self.spark = self._create_spark_session()


# Notebook-wide handle; the cells below issue their queries through `ur.spark`.
ur = UpdateRecall()

In [2]:
# Make `profile` the current database for the queries that follow.
ur.spark.sql("use profile")

DataFrame[]

In [3]:
# Load the user behaviour table and keep only the three columns needed to
# build the (user, article, clicked) interaction data.
user_article_click = (
    ur.spark.sql("select * from user_article_basic")
    .select('user_id', 'article_id', 'clicked')
)

In [4]:
user_article_click.show()

+-------------------+----------+-------+
|            user_id|article_id|clicked|
+-------------------+----------+-------+
|1105045287866466304|     14225|  false|
|1106476833370537984|     14208|  false|
|1111189494544990208|     19322|  false|
|1111524501104885760|     44161|  false|
|1112727762809913344|     18172|   true|
|                  1|     44386|   true|
|                  1|     44696|  false|
|                 10|     43907|  false|
|1106473203766657024|     16005|  false|
|1108264901190615040|     15196|  false|
|                 23|     44739|   true|
|                 33|     13570|  false|
|                  1|     17632|  false|
|1106473203766657024|     17665|  false|
|1111189494544990208|     44368|  false|
|                 10|     44368|  false|
|1105093883106164736|     15750|  false|
|1106396183141548032|     19476|  false|
|1111524501104885760|     19233|  false|
|                  2|     44371|   true|
+-------------------+----------+-------+
only showing top 20 rows

In [5]:
# Convert column types
def change_types(row):
    """Return ``(user_id, article_id, clicked)`` with ``clicked`` cast to int.

    ``row`` is any object exposing ``user_id``, ``article_id`` and ``clicked``
    attributes; a boolean ``clicked`` becomes 0 or 1.
    """
    uid, aid = row.user_id, row.article_id
    clicked_as_int = int(row.clicked)
    return uid, aid, clicked_as_int

# Run every row through change_types and rebuild a DataFrame from the resulting
# tuples, re-attaching the original column names.
user_article_click = (
    user_article_click.rdd
    .map(change_types)
    .toDF(['user_id', 'article_id', 'clicked'])
)


In [6]:
user_article_click.show()

+-------------------+----------+-------+
|            user_id|article_id|clicked|
+-------------------+----------+-------+
|1105045287866466304|     14225|      0|
|1106476833370537984|     14208|      0|
|1111189494544990208|     19322|      0|
|1111524501104885760|     44161|      0|
|1112727762809913344|     18172|      1|
|                  1|     44386|      1|
|                  1|     44696|      0|
|                 10|     43907|      0|
|1106473203766657024|     16005|      0|
|1108264901190615040|     15196|      0|
|                 23|     44739|      1|
|                 33|     13570|      0|
|                  1|     17632|      0|
|1106473203766657024|     17665|      0|
|1111189494544990208|     44368|      0|
|                 10|     44368|      0|
|1105093883106164736|     15750|      0|
|1106396183141548032|     19476|      0|
|1111524501104885760|     19233|      0|
|                  2|     44371|      1|
+-------------------+----------+-------+
only showing top 20 rows

In [7]:
user_article_click.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- article_id: long (nullable = true)
 |-- clicked: long (nullable = true)



In [8]:
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline
# User and article IDs exceed the maximum integer value ALS accepts, so each
# one is converted to an index with StringIndexer before training.
user_id_indexer = StringIndexer(
    outputCol='als_user_id',
    inputCol='user_id',
)
article_id_indexer = StringIndexer(
    outputCol='als_article_id',
    inputCol='article_id',
)

# Fit both indexers in one pipeline, then append the two index columns
# to the click data.
pip = Pipeline(stages=[user_id_indexer, article_id_indexer])
pip_fit = pip.fit(user_article_click)
als_user_article_click = pip_fit.transform(user_article_click)


In [9]:
als_user_article_click.show()

+-------------------+----------+-------+-----------+--------------+
|            user_id|article_id|clicked|als_user_id|als_article_id|
+-------------------+----------+-------+-----------+--------------+
|1105045287866466304|     14225|      0|        4.0|          15.0|
|1106476833370537984|     14208|      0|        2.0|           2.0|
|1111189494544990208|     19322|      0|        1.0|         133.0|
|1111524501104885760|     44161|      0|        9.0|          37.0|
|1112727762809913344|     18172|      1|       12.0|          54.0|
|                  1|     44386|      1|       10.0|          11.0|
|                  1|     44696|      0|       10.0|          97.0|
|                 10|     43907|      0|        3.0|           1.0|
|1106473203766657024|     16005|      0|        5.0|          32.0|
|1108264901190615040|     15196|      0|        6.0|           7.0|
|                 23|     44739|      1|       17.0|           4.0|
|                 33|     13570|      0|       1

In [10]:
from pyspark.ml.recommendation import ALS
# Train the model, then recommend a fixed number of articles for every user.
# NOTE(review): per the Spark docs, checkpointInterval is ignored unless a
# checkpoint directory is set on the SparkContext -- confirm that
# SparkSessionBase configures one.
als = ALS(
    ratingCol='clicked',
    userCol='als_user_id',
    itemCol='als_article_id',
    checkpointInterval=1,
)
model = als.fit(als_user_article_click)

# 100 recommendations per user.
recall_res = model.recommendForAllUsers(100)

In [11]:
recall_res.rdd.take(5)

[Row(als_user_id=12, recommendations=[Row(als_article_id=206, rating=0.24584342539310455), Row(als_article_id=217, rating=0.24584342539310455), Row(als_article_id=89, rating=0.20355388522148132), Row(als_article_id=50, rating=0.1987779140472412), Row(als_article_id=75, rating=0.16880828142166138), Row(als_article_id=49, rating=0.16649414598941803), Row(als_article_id=55, rating=0.15852127969264984), Row(als_article_id=229, rating=0.15697592496871948), Row(als_article_id=209, rating=0.15697592496871948), Row(als_article_id=261, rating=0.15697592496871948), Row(als_article_id=153, rating=0.15697592496871948), Row(als_article_id=233, rating=0.15697592496871948), Row(als_article_id=223, rating=0.15697592496871948), Row(als_article_id=160, rating=0.15697592496871948), Row(als_article_id=174, rating=0.15697592496871948), Row(als_article_id=194, rating=0.15697592496871948), Row(als_article_id=204, rating=0.15697592496871948), Row(als_article_id=202, rating=0.15697592496871948), Row(als_articl

In [12]:
# Save the mapping between each original ID and its StringIndexer index, so
# that the indices in the ALS output can be translated back to real IDs.
refection_user = (
    als_user_article_click
    .groupBy('user_id')
    .agg({'als_user_id': 'max'})
    .withColumnRenamed('max(als_user_id)', 'als_user_id')
)
refection_article = (
    als_user_article_click
    .groupBy('article_id')
    .agg({'als_article_id': 'max'})
    .withColumnRenamed('max(als_article_id)', 'als_article_id')
)

In [13]:
refection_user.show()

+-------------------+-----------+
|            user_id|als_user_id|
+-------------------+-----------+
|1106473203766657024|        5.0|
|1103195673450250240|        7.0|
|1105045287866466304|        4.0|
|1111524501104885760|        9.0|
|1105105185656537088|        8.0|
|1113316420155867136|       18.0|
|                 33|       13.0|
|                  1|       10.0|
|1113244157343694848|       15.0|
|                 10|        3.0|
|1113053603926376448|       20.0|
|1112727762809913344|       12.0|
|                  2|       11.0|
|                  4|       14.0|
|1106476833370537984|        2.0|
|1106396183141548032|        0.0|
|                 38|       16.0|
|                 23|       17.0|
|1108264901190615040|        6.0|
|1111189494544990208|        1.0|
+-------------------+-----------+
only showing top 20 rows



In [14]:
refection_article.show()

+-------------------+--------------+
|         article_id|als_article_id|
+-------------------+--------------+
|              13401|         146.0|
|              14805|           0.0|
|1112593324574769152|         119.0|
|              44013|         130.0|
|              16158|         118.0|
|              17454|         216.0|
|             134736|         215.0|
|              14839|         258.0|
|              14883|         138.0|
|              14972|          78.0|
|             134730|         170.0|
|              14442|         203.0|
|              17802|         190.0|
|              19259|          85.0|
|              17605|          89.0|
|              17748|         123.0|
|              18038|          45.0|
|              17866|         228.0|
|              13357|         173.0|
|              44693|         109.0|
+-------------------+--------------+
only showing top 20 rows



In [15]:
# Attach the original user_id to every recommendation row via the index mapping.
recall_res = (
    recall_res
    .join(refection_user, how='left', on='als_user_id')
    .select('als_user_id', 'recommendations', 'user_id')
)

In [16]:
recall_res.show()

+-----------+--------------------+-------------------+
|als_user_id|     recommendations|            user_id|
+-----------+--------------------+-------------------+
|          8|[[263,0.36433065]...|1105105185656537088|
|          0|[[251,0.64322466]...|1106396183141548032|
|          7|[[120,0.06348746]...|1103195673450250240|
|         18|[[0,0.0], [10,0.0...|1113316420155867136|
|          1|[[93,0.42809626],...|1111189494544990208|
|          4|[[263,0.01827759]...|1105045287866466304|
|         11|[[206,0.9502387],...|                  2|
|         14|[[206,0.6877553],...|                  4|
|          3|[[261,0.041544966...|                 10|
|         19|[[25,0.24205461],...|1105093883106164736|
|          2|[[263,0.63030016]...|1106476833370537984|
|         17|[[120,0.3623774],...|                 23|
|         10|[[206,0.7823634],...|                  1|
|         13|[[86,0.2458632], ...|                 33|
|          6|[[263,0.16696511]...|1108264901190615040|
|         

In [17]:
import pyspark.sql.functions as F
# Produce one row per recommended item. At this stage the `als_article_id`
# column holds each whole element of `recommendations` (index and rating
# together); the bare index is extracted in a later cell.
recall_res = (
    recall_res
    .withColumn('als_article_id', F.explode('recommendations'))
    .drop('recommendations')
)


In [18]:
recall_res.show()

+-----------+-------------------+----------------+
|als_user_id|            user_id|  als_article_id|
+-----------+-------------------+----------------+
|          8|1105105185656537088|[263,0.36433065]|
|          8|1105105185656537088|[105,0.25173277]|
|          8|1105105185656537088|[115,0.25173277]|
|          8|1105105185656537088|[251,0.17713536]|
|          8|1105105185656537088|[235,0.17713536]|
|          8|1105105185656537088|[222,0.17713536]|
|          8|1105105185656537088|[176,0.17713536]|
|          8|1105105185656537088|[236,0.17713536]|
|          8|1105105185656537088| [15,0.15716684]|
|          8|1105105185656537088|[100,0.13071595]|
|          8|1105105185656537088| [36,0.11427968]|
|          8|1105105185656537088| [12,0.10651863]|
|          8|1105105185656537088|  [5,0.06923568]|
|          8|1105105185656537088| [6,0.041910112]|
|          8|1105105185656537088|[20,0.039821483]|
|          8|1105105185656537088| [50,0.03866965]|
|          8|110510518565653708

In [20]:
def _article_id(row):
    return row.als_user_id, row.user_id, row.als_article_id[0]

In [21]:
# Reduce every exploded recommendation to its article index and rebuild a
# DataFrame with the same three column names.
als_recall = (
    recall_res.rdd
    .map(_article_id)
    .toDF(['als_user_id', 'user_id', 'als_article_id'])
)


In [22]:
als_recall.show()

+-----------+-------------------+--------------+
|als_user_id|            user_id|als_article_id|
+-----------+-------------------+--------------+
|          8|1105105185656537088|           263|
|          8|1105105185656537088|           105|
|          8|1105105185656537088|           115|
|          8|1105105185656537088|           251|
|          8|1105105185656537088|           235|
|          8|1105105185656537088|           222|
|          8|1105105185656537088|           176|
|          8|1105105185656537088|           236|
|          8|1105105185656537088|            15|
|          8|1105105185656537088|           100|
|          8|1105105185656537088|            36|
|          8|1105105185656537088|            12|
|          8|1105105185656537088|             5|
|          8|1105105185656537088|             6|
|          8|1105105185656537088|            20|
|          8|1105105185656537088|            50|
|          8|1105105185656537088|            40|
|          8|1105105

In [23]:
# Translate the article index back to the original article_id and keep only
# the (user_id, article_id) pairs.
als_recall = (
    als_recall
    .join(refection_article, how='left', on='als_article_id')
    .select('user_id', 'article_id')
)

In [24]:
als_recall.show()

+-------------------+----------+
|            user_id|article_id|
+-------------------+----------+
|1113316420155867136|    134730|
|                 33|    134730|
|1113053603926376448|    134730|
|1111524501104885760|    134730|
|                 38|    134730|
|1105105185656537088|     18127|
|1106396183141548032|     18127|
|1103195673450250240|     18127|
|1113316420155867136|     18127|
|1111189494544990208|     18127|
|1105045287866466304|     18127|
|                  2|     18127|
|                  4|     18127|
|                 10|     18127|
|1106476833370537984|     18127|
|                 23|     18127|
|                  1|     18127|
|1108264901190615040|     18127|
|1113053603926376448|     18127|
|1106473203766657024|     18127|
+-------------------+----------+
only showing top 20 rows



In [25]:
# Switch to the `toutiao` database and read each article's channel from it.
ur.spark.sql("use toutiao")
news_article_basic = ur.spark.sql("select article_id, channel_id from news_article_basic")

# Attach the channel to every recommended article. This is a left join, so an
# article with no match in news_article_basic keeps a null channel_id.
als_recall = als_recall.join(news_article_basic, how='left', on='article_id')

In [26]:
als_recall.show()

+----------+-------------------+----------+
|article_id|            user_id|channel_id|
+----------+-------------------+----------+
|    134730|1113316420155867136|        18|
|    134730|                 33|        18|
|    134730|1113053603926376448|        18|
|    134730|1111524501104885760|        18|
|    134730|                 38|        18|
|     18127|1105105185656537088|        18|
|     18127|1106396183141548032|        18|
|     18127|1103195673450250240|        18|
|     18127|1113316420155867136|        18|
|     18127|1111189494544990208|        18|
|     18127|1105045287866466304|        18|
|     18127|                  2|        18|
|     18127|                  4|        18|
|     18127|                 10|        18|
|     18127|1106476833370537984|        18|
|     18127|                 23|        18|
|     18127|                  1|        18|
|     18127|1108264901190615040|        18|
|     18127|1113053603926376448|        18|
|     18127|1106473203766657024|

In [27]:
# Collapse the pairs into one row per (user, channel), gathering the
# recommended article ids into a single list column named `article_list`.
als_recall = (
    als_recall
    .groupBy('user_id', 'channel_id')
    .agg(F.collect_list('article_id'))
    .withColumnRenamed('collect_list(article_id)', 'article_list')
)

In [28]:
als_recall.show()

+-------------------+----------+--------------------+
|            user_id|channel_id|        article_list|
+-------------------+----------+--------------------+
|1113244157343694848|         7|    [141437, 141469]|
|1108264901190615040|         7|    [141437, 141469]|
|1106396183141548032|         7|    [141437, 141469]|
|1103195673450250240|         5|            [141440]|
|1108264901190615040|        18|[18127, 16421, 18...|
|1106473203766657024|        18|[18127, 16421, 18...|
|                  4|         7|    [141437, 141469]|
|                  2|         5|            [141440]|
|                 23|         7|    [141437, 141469]|
|1106396183141548032|        13|            [141431]|
|1113053603926376448|        13|            [141431]|
|1103195673450250240|      null|[1112608068731928...|
|1113053603926376448|         7|            [141437]|
|1111524501104885760|      null|[1112592065390182...|
|                 10|         7|    [141437, 141469]|
|1106476833370537984|       

### 离线基于内容相似召回

In [30]:
# Switch back to the `profile` database and keep only the behaviour rows
# where the article was clicked.
ur.spark.sql("use profile")
user_article_basic = (
    ur.spark.sql("select * from user_article_basic")
    .filter('clicked=True')
)

In [31]:
user_article_basic.show()

+-------------------+-------------------+----------+----------+------+-------+---------+--------+---------+
|            user_id|        action_time|article_id|channel_id|shared|clicked|collected|exposure|read_time|
+-------------------+-------------------+----------+----------+------+-------+---------+--------+---------+
|1112727762809913344|2019-04-03 12:51:57|     18172|        18| false|   true|     true|    true|    19413|
|                  1|2019-03-07 16:57:34|     44386|        18| false|   true|    false|    true|    17850|
|                 23|2019-04-03 08:10:23|     44739|        18| false|   true|    false|    true|    14216|
|                  2|2019-03-05 10:19:54|     44371|        18| false|   true|    false|    true|      938|
|                  2|2019-03-07 10:06:20|     18103|        18| false|   true|    false|    true|      648|
|1111189494544990208|2019-03-28 16:56:55|     44737|        18| false|   true|    false|    true|     4138|
|                  2|2019-03