In [1]:
# Train a recommender with Spark ML's ALS (alternating least squares)
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.sql import Row
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.sql.types import *

# Obtain the Spark session through the public builder API. getOrCreate()
# returns the session that is already running (or starts one), so this cell
# no longer depends on an implicit `sc` global that the notebook never defines.
spark = SparkSession.builder.getOrCreate()
# Display the session object as the cell output.
spark

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
1,application_1603600007911_0006,pyspark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

<pyspark.sql.session.SparkSession object at 0x7fde2c1a2b50>

In [4]:
# Load the interaction CSV into a Spark DataFrame, treating the first line of
# the file as the header row. No schema or inferSchema is given, so the
# columns are read as strings and are cast in a later cell.
# NOTE(review): the file is presumably an export from pandas -- confirm its origin.
als_data = (
    spark.read
    .option("header", True)
    .csv("./sort_als_data.csv")
)
# Preview the first 20 rows.
als_data.show(20)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------+-----------+----------+
|als_anchor_id|als_item_id|if_choosed|
+-------------+-----------+----------+
|            0|          0|         1|
|            0|     306844|         1|
|            0|     306845|         1|
|            0|     306846|         1|
|            0|     306847|         1|
|            0|     306848|         1|
|            0|     306849|         1|
|            0|       3034|         1|
|            0|     306850|         1|
|            0|     306851|         1|
|            0|     306852|         1|
|            0|     306853|         1|
|            0|     306854|         1|
|            0|     306855|         1|
|            0|     306856|         1|
|            0|     306857|         1|
|            0|     306858|         1|
|            0|     306859|         1|
|            0|     306860|         1|
|            0|     306861|         1|
+-------------+-----------+----------+
only showing top 20 rows

In [5]:
# Convert the id and label columns from string to integer so they can be used
# as the ALS user / item / rating columns.
# A single select() builds one projection instead of chaining three
# withColumn() calls, and every column is taken from the same source frame.
# NOTE(review): only these three columns are kept -- the preview of als_data
# shows no others; extend the tuple if the file gains columns.
# NOTE(review): values that cannot be parsed as integers may become null after
# the cast (default non-ANSI behaviour) -- verify the input is clean.
int_columns = ("als_anchor_id", "als_item_id", "if_choosed")
als_data_num = als_data.select(
    *[als_data[name].cast(IntegerType()).alias(name) for name in int_columns]
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [6]:
# Split the data set 80/20 into training and test partitions.
# A fixed seed keeps the split repeatable across runs, so metrics computed
# further down can be compared between executions of the notebook.
(train, test) = als_data_num.randomSplit([0.8, 0.2], seed=42)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [7]:
# Confirm the column data types of the training split (displayed as a list of
# (column name, type) pairs)
train.dtypes

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

[('als_anchor_id', 'int'), ('als_item_id', 'int'), ('if_choosed', 'int')]

In [8]:
# Build the ALS estimator and set its basic hyperparameters.
# - maxIter / regParam: number of ALS iterations and regularisation strength.
# - userCol / itemCol / ratingCol: names of the columns in the training frame.
# - coldStartStrategy="drop": rows whose prediction would be NaN (users or
#   items unseen during fitting) are dropped from transform() output, so the
#   evaluation metric stays defined.
# - seed: pins the random initialisation so repeated fits are comparable.
als = ALS(maxIter=5, regParam=0.01, userCol="als_anchor_id", itemCol="als_item_id", ratingCol="if_choosed",
          coldStartStrategy="drop", seed=42)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

- 训练过程

In [9]:
# Model training: fit the ALS estimator on the training split
als_model = als.fit(train)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

- 推理过程

In [10]:
# Run inference on the test split; the result carries an extra `prediction`
# column next to the original columns
predictions = als_model.transform(test)
# Preview the predictions
predictions.show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------+-----------+----------+-------------+
|als_anchor_id|als_item_id|if_choosed|   prediction|
+-------------+-----------+----------+-------------+
|         1005|         12|         0|  0.002290467|
|          633|         12|         0| 0.0026946266|
|          876|         12|         0| 7.7563425E-5|
|          874|         12|         0|-0.0023558827|
|          406|         12|         0|-0.0012933537|
|           76|         12|         0| 0.0025947276|
|          811|         12|         0| 0.0020250394|
|           12|         12|         0|  0.001838097|
|          984|         12|         0|  0.002460972|
|          727|         12|         0|  9.310976E-4|
|         1061|         12|         0| 4.4720626E-4|
|          444|         12|         0|-0.0016402653|
|          774|         12|         0|-0.0019777166|
|          912|         12|         0| 4.1634834E-5|
|          992|         12|         0| 0.0012556601|
|          442|         12|         0|  7.3783

In [11]:
# Evaluate the predictions: root-mean-square error between the `if_choosed`
# label and the model's `prediction` column.
evaluator = RegressionEvaluator(
    predictionCol="prediction",
    labelCol="if_choosed",
    metricName="rmse",
)
rmse = evaluator.evaluate(predictions)
# Report the metric.
print("Root-mean-square error = " + str(rmse))

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Root-mean-square error = 0.4848365898428372

In [12]:
# Recall (candidate retrieval) step: ask the model for the top 100 item
# recommendations for every user, i.e. every `als_anchor_id` it knows
anchorRecs = als_model.recommendForAllUsers(100)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [13]:
# Peek at the recommendations of five users. take(5) fetches only five rows,
# whereas collect()[:5] would first pull every user's recommendation list to
# the driver and then discard all but five of them.
anchorRecs.take(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

[Row(als_anchor_id=540, recommendations=[Row(als_item_id=2401, rating=3.116373062133789), Row(als_item_id=59210, rating=2.567249059677124), Row(als_item_id=84123, rating=2.449601650238037), Row(als_item_id=92272, rating=2.3006582260131836), Row(als_item_id=68190, rating=2.2953109741210938), Row(als_item_id=92544, rating=2.2640693187713623), Row(als_item_id=10381, rating=2.2574245929718018), Row(als_item_id=166506, rating=2.230597972869873), Row(als_item_id=59156, rating=2.2146072387695312), Row(als_item_id=150684, rating=2.187387228012085), Row(als_item_id=113709, rating=2.1488614082336426), Row(als_item_id=85566, rating=2.13942289352417), Row(als_item_id=175524, rating=2.138134002685547), Row(als_item_id=40458, rating=2.122217893600464), Row(als_item_id=7360, rating=2.102689027786255), Row(als_item_id=2344, rating=2.0953972339630127), Row(als_item_id=81818, rating=2.0919759273529053), Row(als_item_id=55750, rating=2.088259696960449), Row(als_item_id=66827, rating=2.085789203643799), R