In [12]:
import os
import sys
# Put the directory two levels above the current working directory at the
# front of sys.path, so that the project imports further down resolve when this
# notebook is run on its own.
BASE_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.insert(0, BASE_DIR)

# When several Python versions are installed, leaving the interpreter
# unspecified is likely to cause errors, so both PySpark interpreter variables
# are pinned to the same executable.
PYSPARK_PYTHON = "/root/miniconda3/envs/test/bin/python"
os.environ.update(
    PYSPARK_PYTHON=PYSPARK_PYTHON,
    PYSPARK_DRIVER_PYTHON=PYSPARK_PYTHON,
)

from offline import SparkSessionBase
import pyhdfs
import time


class UpdateRecall(SparkSessionBase):
    """Offline processing job that owns the Spark session used by this notebook.

    The class only carries configuration; the session is built by
    ``_create_spark_session``, which is not defined here (presumably inherited
    from ``SparkSessionBase`` -- confirm against that class).
    """
    # Configuration constants, presumably read by the base class when it
    # builds the session -- verify against SparkSessionBase.
    ENABLE_HIVE_SUPPORT = True
    SPARK_EXECUTOR_MEMORY = "4g"
    SPARK_APP_NAME = "updateRecall"

    def __init__(self):
        session = self._create_spark_session()
        self.spark = session

ur = UpdateRecall()

#### 读取用户对文章的行为表

In [17]:
ur.spark.sql("use profile")
user_article_basic = ur.spark.sql("select user_id,article_id,clicked from user_article_basic")

In [18]:
# 将clicked的类型转为int
def convert_boolean_int(row):
    """Return ``(user_id, article_id, clicked)`` with ``clicked`` cast to ``int``."""
    user, article = row.user_id, row.article_id
    clicked_as_int = int(row.clicked)
    return (user, article, clicked_as_int)
        
user_article_basic = user_article_basic.rdd.map(convert_boolean_int).toDF(['user_id', 'article_id', 'clicked'])

user_article_basic.show()

+-------------------+----------+-------+
|            user_id|article_id|clicked|
+-------------------+----------+-------+
|1105045287866466304|     14225|      0|
|1106476833370537984|     14208|      0|
|                  1|     44386|      1|
|                  1|     44696|      0|
|                 10|     43907|      0|
|1105093883106164736|    140357|      0|
|1106473203766657024|     16005|      0|
|                 33|     13570|      0|
|                  1|     17632|      0|
|1106473203766657024|     17665|      0|
|                 10|     44368|      0|
|1105093883106164736|     15750|      0|
|                  2|     44371|      1|
|1105105185656537088|     44180|      0|
|1106396183141548032|     43885|      0|
|                  4|     15196|      0|
|                  4|     18701|      0|
|1105045287866466304|     14668|      0|
|1105045287866466304|     14805|      0|
|1106473203766657024|     44664|      0|
+-------------------+----------+-------+
only showing top

#### 用户ID与文章ID处理

In [20]:
# User and article IDs exceed the largest integer value ALS accepts, so they
# are converted with StringIndexer first.
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# One indexer per ID column; each writes its index into a new als_* column.
user_indexer = StringIndexer(inputCol='user_id',outputCol='als_user_id')
article_indexer = StringIndexer(inputCol='article_id',outputCol='als_article_id')
# Both indexers are fitted and applied together on the behaviour table.
pip = Pipeline(stages=[user_indexer,article_indexer])
pip_model = pip.fit(user_article_basic)
als_user_article = pip_model.transform(user_article_basic)
als_user_article.show()

+-------------------+----------+-------+-----------+--------------+
|            user_id|article_id|clicked|als_user_id|als_article_id|
+-------------------+----------+-------+-----------+--------------+
|1105045287866466304|     14225|      0|        2.0|          29.0|
|1106476833370537984|     14208|      0|        0.0|           8.0|
|                  1|     44386|      1|        9.0|          13.0|
|                  1|     44696|      0|        9.0|          58.0|
|                 10|     43907|      0|        1.0|           6.0|
|1105093883106164736|    140357|      0|        3.0|          23.0|
|1106473203766657024|     16005|      0|        5.0|          39.0|
|                 33|     13570|      0|       11.0|          31.0|
|                  1|     17632|      0|        9.0|          90.0|
|1106473203766657024|     17665|      0|        5.0|          36.0|
|                 10|     44368|      0|        1.0|           2.0|
|1105093883106164736|     15750|      0|        

#### 模型训练与推荐

In [21]:
from pyspark.ml.recommendation import ALS

In [23]:
# Fit ALS on the indexed user/article columns, using `clicked` as the rating.
# NOTE(review): no seed is passed, so repeated fits may not produce identical
# factors -- confirm whether reproducibility matters for this job.
als = ALS(userCol='als_user_id',itemCol='als_article_id',ratingCol='clicked')
als_model = als.fit(als_user_article)

In [76]:
recall_res = als_model.recommendForAllUsers(100)
recall_res.show()

+-----------+--------------------+
|als_user_id|     recommendations|
+-----------+--------------------+
|         12|[[0, 0.0], [10, 0...|
|          1|[[112, 0.15105331...|
|         13|[[0, 0.0], [10, 0...|
|          6|[[71, 8.8845376E-...|
|          3|[[64, 0.24442917]...|
|          5|[[189, 7.170044E-...|
|          9|[[112, 0.5766499]...|
|          4|[[134, 0.46485025...|
|          8|[[189, 0.74185854...|
|          7|[[189, 0.28287688...|
|         10|[[189, 1.0544213]...|
|         11|[[64, 0.30284104]...|
|          2|[[189, 3.600872E-...|
|          0|[[134, 0.6459866]...|
+-----------+--------------------+



#### 建立真实用户id与索引的映射，真实文章id与索引的映射

In [72]:
# Build the (user_id, als_user_id) lookup: group by the real id and take the
# max index per group, then rename the aggregate column back to `als_user_id`.
# Presumably every user_id has a single index, so max() only collapses
# duplicate rows -- verify if the indexing step changes.
user_real_index = als_user_article.groupBy(['user_id']).max('als_user_id').withColumnRenamed('max(als_user_id)','als_user_id')
user_real_index.show(5)

+-------------------+-----------+
|            user_id|als_user_id|
+-------------------+-----------+
|1106473203766657024|        5.0|
|1103195673450250240|        6.0|
|1105045287866466304|        2.0|
|1105105185656537088|        4.0|
|                 33|       11.0|
+-------------------+-----------+
only showing top 5 rows



In [73]:
# Build the (article_id, als_article_id) lookup the same way: group by the real
# id, take the max index per group, and rename the aggregate column back to
# `als_article_id`.
article_real_index = als_user_article.groupBy(['article_id']).max('als_article_id').withColumnRenamed('max(als_article_id)','als_article_id')
article_real_index.show(5)

+----------+--------------+
|article_id|als_article_id|
+----------+--------------+
|     13401|         217.0|
|     14805|           1.0|
|     44013|         190.0|
|     16158|          53.0|
|     17454|         164.0|
+----------+--------------+
only showing top 5 rows



In [77]:
# User id: attach the real user_id to each recommendation row by left-joining
# on the ALS user index.
recall_res = recall_res.join(user_real_index, on=['als_user_id'], how='left').select(['als_user_id', 'recommendations', 'user_id'])
recall_res.show(5)

+-----------+--------------------+-------------------+
|als_user_id|     recommendations|            user_id|
+-----------+--------------------+-------------------+
|          8|[[189, 0.74185854...|                  4|
|          0|[[134, 0.6459866]...|1106476833370537984|
|          7|[[189, 0.28287688...|1106396183141548032|
|          1|[[112, 0.15105331...|                 10|
|          4|[[134, 0.46485025...|1105105185656537088|
+-----------+--------------------+-------------------+
only showing top 5 rows



In [78]:
# Article id: explode the `recommendations` array so each recommended entry
# becomes its own row, stored in `als_article_id` next to the user_id.
import pyspark.sql.functions as F
recall_res = recall_res.withColumn('als_article_id',F.explode('recommendations')).drop('recommendations').select(["user_id","als_article_id"])
recall_res.show(5)

+-------+-----------------+
|user_id|   als_article_id|
+-------+-----------------+
|      4|[189, 0.74185854]|
|      4| [65, 0.62069404]|
|      4| [76, 0.62069404]|
|      4| [56, 0.48931167]|
|      4| [50, 0.42007074]|
+-------+-----------------+
only showing top 5 rows



In [79]:
def get_article_index(row):
    """Return ``(user_id, first element of row.als_article_id)``."""
    user = row.user_id
    recommendation = row.als_article_id
    return (user, recommendation[0])

recall_res = recall_res.rdd.map(get_article_index).toDF(["user_id","als_article_id"])
recall_res.show(5)

+-------+--------------+
|user_id|als_article_id|
+-------+--------------+
|      4|           189|
|      4|            65|
|      4|            76|
|      4|            56|
|      4|            50|
+-------+--------------+
only showing top 5 rows



In [80]:
recall_res = recall_res.join(article_real_index,on=['als_article_id'],how='left').select('user_id','article_id')
recall_res.show(20)

+-------------------+----------+
|            user_id|article_id|
+-------------------+----------+
|                  4|     44412|
|1106476833370537984|     44412|
|1106396183141548032|     44412|
|                 10|     44412|
|1105105185656537088|     44412|
|                  2|     44412|
|                  1|     44412|
|                  4|     44761|
|1106476833370537984|     44761|
|1106396183141548032|     44761|
|                 10|     44761|
|1105105185656537088|     44761|
|                  2|     44761|
|                 23|     44761|
|                  1|     44761|
|                 38|     44761|
|                 33|     13090|
|1105093883106164736|     13090|
|1105045287866466304|     13090|
|1103195673450250240|     13090|
+-------------------+----------+
only showing top 20 rows



In [81]:
def change_article_id(row):
    """Remap ``row.article_id`` onto the range 1-10.

    Nine specific article IDs map to 1..9; any other ID maps to 10.

    Returns:
        ``(user_id, remapped_article_id)``.
    """
    # Lookup table replacing the chain of equality checks.
    remap = {
        14225: 1,
        14208: 2,
        44386: 3,
        44696: 4,
        43907: 5,
        140357: 6,
        16005: 7,
        13570: 8,
        17632: 9,
    }
    mapped_id = remap.get(row.article_id, 10)
    return row.user_id, mapped_id

recall_change = recall_res.rdd.map(change_article_id).toDF(["user_id","article_id"])
recall_change.show()

+-------------------+----------+
|            user_id|article_id|
+-------------------+----------+
|                  4|        10|
|1106476833370537984|        10|
|1106396183141548032|        10|
|                 10|        10|
|1105105185656537088|        10|
|                  2|        10|
|                  1|        10|
|                  4|        10|
|1106476833370537984|        10|
|1106396183141548032|        10|
|                 10|        10|
|1105105185656537088|        10|
|                  2|        10|
|                 23|        10|
|                  1|        10|
|                 38|        10|
|                 33|        10|
|1105093883106164736|        10|
|1105045287866466304|        10|
|1103195673450250240|        10|
+-------------------+----------+
only showing top 20 rows



#### 获取每个文章的频道，按频道分组

In [38]:
ur.spark.sql("use article")
article_data = ur.spark.sql("select article_id,channel_id from article_data limit 10")
article_data.show()

+----------+----------+
|article_id|channel_id|
+----------+----------+
|         1|        17|
|         2|        17|
|         3|        17|
|         4|        17|
|         5|        17|
|         6|        17|
|         7|        17|
|         8|        17|
|         9|        17|
|        10|        17|
+----------+----------+



In [90]:
recall_channel = recall_change.join(article_data,on=['article_id'],how='left')
recall_channel.show()

+----------+-------------------+----------+
|article_id|            user_id|channel_id|
+----------+-------------------+----------+
|        10|                  4|        17|
|        10|1106476833370537984|        17|
|        10|1106396183141548032|        17|
|        10|                 10|        17|
|        10|1105105185656537088|        17|
|        10|                  2|        17|
|        10|                  1|        17|
|        10|                  4|        17|
|        10|1106476833370537984|        17|
|        10|1106396183141548032|        17|
|        10|                 10|        17|
|        10|1105105185656537088|        17|
|        10|                  2|        17|
|        10|                 23|        17|
|        10|                  1|        17|
|        10|                 38|        17|
|        10|                 33|        17|
|        10|1105093883106164736|        17|
|        10|1105045287866466304|        17|
|        10|1103195673450250240|

In [92]:
als_recall = recall_channel.groupBy(["user_id","channel_id"]).agg(F.collect_set("article_id")).withColumnRenamed("collect_set(article_id)","article_list")
als_recall.show()

+-------------------+----------+-------------------+
|            user_id|channel_id|       article_list|
+-------------------+----------+-------------------+
|                  1|        17|   [9, 2, 3, 10, 8]|
|1103195673450250240|        17|   [9, 1, 2, 10, 8]|
|                  4|        17|[9, 1, 6, 3, 10, 8]|
|                  2|        17|      [1, 3, 10, 8]|
|1106476833370537984|        17|   [1, 6, 3, 10, 8]|
|1106396183141548032|        17|   [1, 6, 3, 10, 8]|
|                 10|        17|   [1, 2, 3, 10, 8]|
|1105093883106164736|        17|   [9, 2, 6, 3, 10]|
|                 33|        17|      [9, 2, 10, 8]|
|                 23|        17|   [9, 6, 3, 10, 8]|
|1105105185656537088|        17|   [1, 6, 3, 10, 8]|
|1105045287866466304|        17|[9, 1, 2, 6, 10, 8]|
|1106473203766657024|        17|[9, 1, 2, 6, 10, 8]|
|                 38|        17|   [9, 6, 3, 10, 8]|
+-------------------+----------+-------------------+



#### als召回结果存储

In [94]:
def save_offline_recall_hbase(partition):
    """Write one partition of ALS recall results to HBase, skipping already-recalled articles.

    For every row, the articles previously recalled for that user and channel
    are read from ``history_recall1`` (row ``reco:his:<user_id>``, column
    ``channel:<channel_id>``) and removed from ``row.article_list``.  If
    anything is left, the remaining list is written to ``cb_recall1`` (row
    ``recall:user:<user_id>``, column ``als:<channel_id>``) and recorded in
    ``history_recall1`` under the same history row and column.

    Args:
        partition: iterable of rows exposing ``user_id``, ``channel_id`` and
            ``article_list`` (an iterable of article IDs).

    Returns:
        None.  The function is called for its HBase side effects.

    Raises:
        ValueError, SyntaxError: if a stored history cell is not a Python
            literal list as written by ``str(list)``.
    """
    import ast

    import happybase

    pool = happybase.ConnectionPool(size=10, host='hadoop1')

    # A single pooled connection and one handle per table serve the whole
    # partition, instead of re-acquiring and closing a connection for every row.
    with pool.connection() as conn:
        try:
            history_table = conn.table('history_recall1')
            table = conn.table('cb_recall1')

            for row in partition:
                history_key = 'reco:his:{}'.format(row.user_id).encode()
                history_column = 'channel:{}'.format(row.channel_id).encode()

                # Every stored version holds the text of a list (see the put
                # below), so each cell is parsed back into article IDs.
                # Extending with the raw bytes would add byte values, not IDs.
                # A single stored version counts as history too.
                history = set()
                for cell in history_table.cells(history_key, history_column):
                    text = cell.decode() if isinstance(cell, bytes) else cell
                    history.update(ast.literal_eval(text))

                # Drop everything that was already recalled for this channel.
                reco_res = list(set(row.article_list) - history)
                if not reco_res:
                    continue

                payload = str(reco_res).encode()

                # Store the filtered articles in the recall table ...
                table.put('recall:user:{}'.format(row.user_id).encode(),
                          {'als:{}'.format(row.channel_id).encode(): payload})

                # ... and record them in the history table, using the same
                # encoded column name that was used for the lookup above.
                history_table.put(history_key, {history_column: payload})
        finally:
            # The pool is local to this call, so release the socket explicitly.
            conn.close()
            
als_recall.rdd.foreachPartition(save_offline_recall_hbase)

### 内容召回

In [100]:
# Find articles similar to the ones users clicked: reload the full behaviour
# table from the `profile` database and keep only the clicked rows.
ur.spark.sql("use profile")
user_article_basic = ur.spark.sql("select * from user_article_basic")
user_article_basic = user_article_basic.filter("clicked=True")
user_article_basic.show()

+-------------------+-------------------+----------+----------+------+-------+---------+--------+---------+
|            user_id|        action_time|article_id|channel_id|shared|clicked|collected|exposure|read_time|
+-------------------+-------------------+----------+----------+------+-------+---------+--------+---------+
|                  1|2019-03-07 16:57:34|     44386|        18| false|   true|    false|    true|    17850|
|                  2|2019-03-05 10:19:54|     44371|        18| false|   true|    false|    true|      938|
|                  2|2019-03-07 10:06:20|     18103|        18| false|   true|    false|    true|      648|
|                  2|2019-03-15 14:51:12|     43894|        18| false|   true|    false|    true|      928|
|                  2|2019-03-07 10:05:29|     18836|        18| false|   true|    false|    true|      835|
|                  2|2019-03-07 10:06:57|     14961|        18| false|   true|    false|    true|     5248|
|                  2|2019-03

In [103]:
def change_user_article(row):
    """Remap ``row.article_id`` onto the range 1-10 and keep the other fields.

    Nine specific article IDs map to 1..9; any other ID maps to 10.

    Returns:
        ``(user_id, channel_id, remapped_article_id, clicked)``.
    """
    # Lookup table replacing the chain of equality checks.
    remap = {
        44386: 1,
        44371: 2,
        18103: 3,
        43894: 4,
        18836: 5,
        14961: 6,
        18609: 7,
        18353: 8,
        16062: 9,
    }
    mapped_id = remap.get(row.article_id, 10)
    return row.user_id, row.channel_id, mapped_id, row.clicked

user_article_index = user_article_basic.rdd.map(change_user_article).toDF(["user_id","channel_id","article_id","clicked"])
user_article_index.show()

+-------------------+----------+----------+-------+
|            user_id|channel_id|article_id|clicked|
+-------------------+----------+----------+-------+
|                  1|        18|         1|   true|
|                  2|        18|         2|   true|
|                  2|        18|         3|   true|
|                  2|        18|         4|   true|
|                  2|        18|         5|   true|
|                  2|        18|         6|   true|
|                  2|        18|        10|   true|
|                 33|        18|        10|   true|
|1105093883106164736|        18|        10|   true|
|                  2|        18|        10|   true|
|                  4|        18|         7|   true|
|                  2|        18|        10|   true|
|                  2|        18|        10|   true|
|                  4|        18|        10|   true|
|1106476833370537984|        18|         9|   true|
|                  2|        18|        10|   true|
|           

In [106]:
def save_content_filter_history_recall(partition):
    """Write content-based recall results for one partition of clicked rows to HBase.

    For every row, the entries stored for ``row.article_id`` in
    ``article_similar`` (``columns=[b'similar']``) are sorted by cell value in
    descending order and the first ten are kept.  The integer after the ``:``
    in each column name is taken as the article ID.  Articles already
    present in ``history_recall1`` (row ``reco:his:<user_id>``, column
    ``channel:<channel_id>``) are removed.  If anything is left it is written
    to ``cb_recall1`` (row ``recall_user:<user_id>``, column
    ``content:<channel_id>``) and recorded in ``history_recall1``.

    Args:
        partition: iterable of rows exposing ``user_id``, ``channel_id`` and
            ``article_id``.

    Returns:
        None.  The function is called for its HBase side effects.

    Raises:
        ValueError, SyntaxError: if a stored history cell is not a Python
            literal list as written by ``str(list)``.
    """
    import ast

    import happybase

    pool = happybase.ConnectionPool(size=10, host="hadoop1")

    with pool.connection() as conn:
        try:
            # Table handles are created once instead of once per row.
            similar_table = conn.table("article_similar")
            history_table = conn.table("history_recall1")
            content_table = conn.table("cb_recall1")

            for row in partition:
                # Look up the articles similar to the clicked one.
                similar_article = similar_table.row(str(row.article_id).encode(),
                                                    columns=[b'similar'])

                # Rank the similar articles and keep only the first few.
                # NOTE(review): the values are compared as stored (raw bytes,
                # i.e. lexicographically); if they are numeric scores this is
                # not a numeric ordering -- confirm against the table writer.
                ranked = sorted(similar_article.items(),
                                key=lambda item: item[1], reverse=True)
                if not ranked:
                    continue
                reco_article = [int(column.split(b':')[1]) for column, _ in ranked[:10]]

                history_key = 'reco:his:{}'.format(row.user_id).encode()
                history_column = 'channel:{}'.format(row.channel_id).encode()

                # Every stored version holds the text of a list (see the put
                # below), so each cell is parsed back into article IDs.
                # Extending with the raw bytes would add byte values, not IDs.
                # A single stored version counts as history too.
                history = set()
                for cell in history_table.cells(history_key, history_column):
                    text = cell.decode() if isinstance(cell, bytes) else cell
                    history.update(ast.literal_eval(text))

                # Drop the articles that were already recalled.
                reco_res = list(set(reco_article) - history)
                if not reco_res:
                    # Nothing new: do not overwrite the recall column with "[]".
                    continue

                payload = str(reco_res).encode()

                # Store the result in the content recall table.
                # NOTE(review): this row key is "recall_user:<id>", while the
                # ALS writer in this notebook uses "recall:user:<id>" -- confirm
                # which key the readers of cb_recall1 expect.
                content_table.put("recall_user:{}".format(row.user_id).encode(),
                                  {'content:{}'.format(row.channel_id).encode(): payload})

                # Record the same articles in the history recall table.
                history_table.put(history_key, {history_column: payload})
        finally:
            # The pool is local to this call, so release the socket explicitly.
            conn.close()
        
user_article_index.foreachPartition(save_content_filter_history_recall)       