In [44]:
from pyspark.ml.recommendation import ALS
from pyspark.sql import Row, SparkSession

spark = SparkSession.builder \
    .master('local') \
    .appName('recommendation') \
    .getOrCreate()

In [63]:
data_path = '/Users/hyunseokjung/data/spark_guide'

ratings = spark.read.text(f"{data_path}/sample_movielens_ratings.txt") \
    .rdd.toDF() \
    .selectExpr("split(value, '::') as col") \
    .selectExpr(
        "cast(col[0] as int) as userId",
        "cast(col[1] as int) as movieId",
        "cast(col[2] as float) as rating",
        "cast(col[3] as long) as timestamp"
    )

In [46]:
train,test = ratings.randomSplit([0.8, 0.2])

In [47]:
als = ALS() \
    .setMaxIter(5) \
    .setRegParam(0.01) \
    .setUserCol("userId") \
    .setItemCol("movieId") \
    .setRatingCol("rating")

In [48]:
print(als.explainParams())

alpha: alpha for implicit preference (default: 1.0)
blockSize: block size for stacking input data in matrices. Data is stacked within partitions. If block size is more than remaining data in a partition then it is adjusted to the size of this data. (default: 4096)
checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext. (default: 10)
coldStartStrategy: strategy for dealing with unknown or new users/items at prediction time. This may be useful in cross-validation or production scenarios, for handling user/item ids the model has not seen in the training data. Supported values: 'nan', 'drop'. (default: nan)
finalStorageLevel: StorageLevel for ALS model factors. (default: MEMORY_AND_DISK)
implicitPrefs: whether to use implicit preference (default: False)
intermediateStorageLevel: StorageLevel for interme

In [64]:
alsModel = als.fit(train)
predictions = alsModel.transform(test)

In [66]:
predictions.show(3)

+------+-------+------+----------+----------+
|userId|movieId|rating| timestamp|prediction|
+------+-------+------+----------+----------+
|     0|      3|   1.0|1424380312| 0.6339553|
|     0|      5|   2.0|1424380312| 1.4654416|
|     0|     21|   1.0|1424380312| 1.6357095|
+------+-------+------+----------+----------+
only showing top 3 rows



In [74]:
alsModel.recommendForAllUsers(3) \
    .selectExpr("userId", "explode(recommendations)") \
    .show(6)
alsModel.recommendForAllItems(3) \
    .selectExpr("movieId", "explode(recommendations)") \
    .show(6)

                                                                                

+------+---------------+
|userId|            col|
+------+---------------+
|    20|{22, 4.6268606}|
|    20| {75, 4.272371}|
|    20|{54, 4.1329565}|
|    10| {38, 4.204379}|
|    10|{46, 3.8890886}|
|    10| {74, 3.536286}|
+------+---------------+
only showing top 6 rows





+-------+---------------+
|movieId|            col|
+-------+---------------+
|     20|{26, 4.9450316}|
|     20| {17, 4.641791}|
|     20| {2, 4.1928797}|
|     40|  {6, 4.387648}|
|     40| {2, 3.8655288}|
|     40| {7, 3.4234028}|
+-------+---------------+
only showing top 6 rows



                                                                                

In [77]:
from pyspark.ml.evaluation import RegressionEvaluator

evaluator = RegressionEvaluator() \
    .setMetricName("rmse") \
    .setLabelCol("rating") \
    .setPredictionCol("prediction")

rmse = evaluator.evaluate(predictions)
print(f"Root-mean-square error = {rmse}")

Root-mean-square error = 2.019469736413673


In [78]:
from pyspark.mllib.evaluation import RegressionMetrics

regComparison = predictions.select("rating", "prediction") \
    .rdd.map(lambda x: (x(0), x(1)))
metrics = RegressionMetrics(regComparison)



In [86]:
metrics.rootMeanSquaredError

22/12/02 13:54:15 ERROR Executor: Exception in task 0.0 in stage 715.0 (TID 1842)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/hyunseokjung/spark-3.3.0/python/lib/pyspark.zip/pyspark/worker.py", line 686, in main
    process()
  File "/Users/hyunseokjung/spark-3.3.0/python/lib/pyspark.zip/pyspark/worker.py", line 678, in process
    serializer.dump_stream(out_iter, outfile)
  File "/Users/hyunseokjung/spark-3.3.0/python/lib/pyspark.zip/pyspark/serializers.py", line 273, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/Users/hyunseokjung/spark-3.3.0/python/lib/pyspark.zip/pyspark/util.py", line 81, in wrapper
    return f(*args, **kwargs)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/session.py", line 910, in prepare
    verify_func(obj)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/types.py", line 1722, in verify
   

Py4JJavaError: An error occurred while calling o1333.rootMeanSquaredError.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 715.0 failed 1 times, most recent failure: Lost task 0.0 in stage 715.0 (TID 1842) (ingu627 executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/hyunseokjung/spark-3.3.0/python/lib/pyspark.zip/pyspark/worker.py", line 686, in main
    process()
  File "/Users/hyunseokjung/spark-3.3.0/python/lib/pyspark.zip/pyspark/worker.py", line 678, in process
    serializer.dump_stream(out_iter, outfile)
  File "/Users/hyunseokjung/spark-3.3.0/python/lib/pyspark.zip/pyspark/serializers.py", line 273, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/Users/hyunseokjung/spark-3.3.0/python/lib/pyspark.zip/pyspark/util.py", line 81, in wrapper
    return f(*args, **kwargs)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/session.py", line 910, in prepare
    verify_func(obj)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/types.py", line 1722, in verify
    verify_value(obj)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/types.py", line 1700, in verify_struct
    verifier(v)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/types.py", line 1722, in verify
    verify_value(obj)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/types.py", line 1716, in verify_default
    verify_acceptable_types(obj)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/types.py", line 1592, in verify_acceptable_types
    raise TypeError(
TypeError: field prediction: DoubleType() can not accept object Row(1.0=0) in type <class 'pyspark.sql.types.Row'>

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:559)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:765)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:747)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:512)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
	at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
	at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
	at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
	at scala.collection.AbstractIterator.aggregate(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$4(RDD.scala:1236)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$6(RDD.scala:1237)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:855)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:855)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	at java.base/java.lang.Thread.run(Thread.java:1589)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2228)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2323)
	at org.apache.spark.rdd.RDD.$anonfun$fold$1(RDD.scala:1174)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.fold(RDD.scala:1168)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$2(RDD.scala:1267)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.treeAggregate(RDD.scala:1228)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$1(RDD.scala:1214)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.treeAggregate(RDD.scala:1214)
	at org.apache.spark.mllib.stat.Statistics$.colStats(Statistics.scala:58)
	at org.apache.spark.mllib.evaluation.RegressionMetrics.summary$lzycompute(RegressionMetrics.scala:70)
	at org.apache.spark.mllib.evaluation.RegressionMetrics.summary(RegressionMetrics.scala:62)
	at org.apache.spark.mllib.evaluation.RegressionMetrics.SSerr$lzycompute(RegressionMetrics.scala:74)
	at org.apache.spark.mllib.evaluation.RegressionMetrics.SSerr(RegressionMetrics.scala:74)
	at org.apache.spark.mllib.evaluation.RegressionMetrics.meanSquaredError(RegressionMetrics.scala:106)
	at org.apache.spark.mllib.evaluation.RegressionMetrics.rootMeanSquaredError(RegressionMetrics.scala:115)
	at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:104)
	at java.base/java.lang.reflect.Method.invoke(Method.java:578)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:1589)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/hyunseokjung/spark-3.3.0/python/lib/pyspark.zip/pyspark/worker.py", line 686, in main
    process()
  File "/Users/hyunseokjung/spark-3.3.0/python/lib/pyspark.zip/pyspark/worker.py", line 678, in process
    serializer.dump_stream(out_iter, outfile)
  File "/Users/hyunseokjung/spark-3.3.0/python/lib/pyspark.zip/pyspark/serializers.py", line 273, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/Users/hyunseokjung/spark-3.3.0/python/lib/pyspark.zip/pyspark/util.py", line 81, in wrapper
    return f(*args, **kwargs)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/session.py", line 910, in prepare
    verify_func(obj)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/types.py", line 1722, in verify
    verify_value(obj)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/types.py", line 1700, in verify_struct
    verifier(v)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/types.py", line 1722, in verify
    verify_value(obj)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/types.py", line 1716, in verify_default
    verify_acceptable_types(obj)
  File "/Users/hyunseokjung/opt/anaconda3/envs/pyspark/lib/python3.9/site-packages/pyspark/sql/types.py", line 1592, in verify_acceptable_types
    raise TypeError(
TypeError: field prediction: DoubleType() can not accept object Row(1.0=0) in type <class 'pyspark.sql.types.Row'>

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:559)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:765)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:747)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:512)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
	at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
	at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
	at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
	at scala.collection.AbstractIterator.aggregate(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$4(RDD.scala:1236)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$6(RDD.scala:1237)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:855)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:855)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	... 1 more


In [79]:
from pyspark.mllib.evaluation import RankingMetrics, RegressionMetrics
from pyspark.sql.functions import col, expr

perUserActual = predictions \
    .where("rating > 2.5") \
    .groupBy("userId") \
    .agg(expr("collect_set(movieId) as movies"))

In [55]:
perUserPredictions = predictions \
    .orderBy(col("userId"), expr("prediction DESC")) \
    .groupBy("userId") \
    .agg(expr("collect_list(movieId) as movies"))

In [56]:
perUserActualvPred = perUserActual.join(perUserPredictions, ["userId"]).rdd \
    .map(lambda row: (row[1], row[2][:15]))
ranks = RankingMetrics(perUserActualvPred)



In [57]:
ranks.meanAveragePrecision
ranks.precisionAt(5)

0.5076923076923078