In [28]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, avg, max, min, mean, count, udf, first, explode
# Create the Spark session (or reuse one that is already running); every
# later cell reads its data through this `spark` handle.
spark = (
    SparkSession.builder
    .appName("Collaborative filtering")
    .getOrCreate()
)

In [2]:
# Both CSV files are read with the same reader settings: the first row is
# the header and column types are inferred from the data.
csv_reader_options = {"inferSchema": True, "header": True}

movies = spark.read.options(**csv_reader_options).csv("../data/movies.csv")
ratings = spark.read.options(**csv_reader_options).csv("../data/ratings.csv")

# Print the inferred schemas so the column types can be checked by eye.
for loaded_frame in (movies, ratings):
    loaded_frame.printSchema()

                                                                                

root
 |-- movieId: integer (nullable = true)
 |-- title: string (nullable = true)
 |-- genres: string (nullable = true)

root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- rating: double (nullable = true)
 |-- timestamp: integer (nullable = true)



In [3]:
# Preview the first 20 rating rows (long cell values are truncated).
ratings.show(n=20, truncate=True)

+------+-------+------+---------+
|userId|movieId|rating|timestamp|
+------+-------+------+---------+
|     1|      1|   4.0|964982703|
|     1|      3|   4.0|964981247|
|     1|      6|   4.0|964982224|
|     1|     47|   5.0|964983815|
|     1|     50|   5.0|964982931|
|     1|     70|   3.0|964982400|
|     1|    101|   5.0|964980868|
|     1|    110|   4.0|964982176|
|     1|    151|   5.0|964984041|
|     1|    157|   5.0|964984100|
|     1|    163|   5.0|964983650|
|     1|    216|   5.0|964981208|
|     1|    223|   3.0|964980985|
|     1|    231|   5.0|964981179|
|     1|    235|   4.0|964980908|
|     1|    260|   5.0|964981680|
|     1|    296|   3.0|964982967|
|     1|    316|   3.0|964982310|
|     1|    333|   5.0|964981179|
|     1|    349|   4.0|964982563|
+------+-------+------+---------+
only showing top 20 rows



In [4]:
# Preview the first 20 movie rows (long titles/genres are truncated).
movies.show(n=20, truncate=True)

+-------+--------------------+--------------------+
|movieId|               title|              genres|
+-------+--------------------+--------------------+
|      1|    Toy Story (1995)|Adventure|Animati...|
|      2|      Jumanji (1995)|Adventure|Childre...|
|      3|Grumpier Old Men ...|      Comedy|Romance|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|
|      5|Father of the Bri...|              Comedy|
|      6|         Heat (1995)|Action|Crime|Thri...|
|      7|      Sabrina (1995)|      Comedy|Romance|
|      8| Tom and Huck (1995)|  Adventure|Children|
|      9| Sudden Death (1995)|              Action|
|     10|    GoldenEye (1995)|Action|Adventure|...|
|     11|American Presiden...|Comedy|Drama|Romance|
|     12|Dracula: Dead and...|       Comedy|Horror|
|     13|        Balto (1995)|Adventure|Animati...|
|     14|        Nixon (1995)|               Drama|
|     15|Cutthroat Island ...|Action|Adventure|...|
|     16|       Casino (1995)|         Crime|Drama|
|     17|Sen

In [5]:
# Attach the movie columns to every rating row. A left join keeps each
# rating even when its movieId has no match in `movies`.
df = ratings.join(other=movies, on="movieId", how='left')

# Left unexecuted: pivot to one row per user with one column per movie.
# df.groupby("userId").pivot('movieId').mean("rating").show()

In [6]:
# 80/20 train/test split. The seed is fixed so that a Restart & Run All
# draws the same partition and the reported metrics are comparable
# between runs (the original call was unseeded).
# NOTE(review): a fixed seed only reproduces the split if `df` yields its
# rows in the same partitioning/order each run — confirm, or cache/sort
# `df` before splitting.
train, test = df.randomSplit([0.8, 0.2], seed=42)

In [7]:
# Count the rows that landed in the training split; the bare name on the
# last line makes the notebook display the value.
train_row_count = train.count()
train_row_count

                                                                                

80539

In [9]:
import pandas as pd 

# Collect the whole Spark training split into a local pandas DataFrame.
# NOTE(review): `pd_df` is not referenced by any other cell visible here,
# and toPandas() pulls every row onto the driver — confirm it is still
# needed before keeping this cell.
pd_df = train.toPandas()

                                                                                

In [8]:
# df.groupby("userId").pivot("movieId").agg({"rating": "first"}).show()

In [13]:
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import RegressionEvaluator

# Explicit-feedback ALS recommender on (userId, movieId, rating).
# - nonnegative=True constrains the learned factors to be non-negative.
# - coldStartStrategy="drop" removes rows whose prediction is NaN (users or
#   movies unseen during fitting) so the RMSE below is not NaN.
# - seed fixes the random factor initialisation so runs are repeatable.
als = ALS(
    userCol="userId",
    itemCol="movieId",
    ratingCol="rating",
    nonnegative=True,
    implicitPrefs=False,
    coldStartStrategy="drop",
    seed=42,
)

# 4 ranks x 4 regularisation strengths = 16 candidate settings.
param_grid = (
    ParamGridBuilder()
    .addGrid(als.rank, [10, 50, 100, 150])
    .addGrid(als.regParam, [.01, .05, .1, .15])
    .build()
)

# Candidates are compared by RMSE between "rating" and "prediction".
evaluator = RegressionEvaluator(
    metricName="rmse",
    labelCol="rating",
    predictionCol="prediction"
)

# 5-fold cross-validation over the grid: 16 settings x 5 folds = 80 ALS
# fits, so this cell is slow. The seed makes the fold assignment
# repeatable (the original left both ALS and the folds unseeded).
cv = CrossValidator(
    estimator=als,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=5,
    seed=42,
)

# `model` is the fitted CrossValidatorModel; the best ALS model is
# extracted from it in the next cell.
model = cv.fit(train)

23/12/03 15:51:00 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
                                                                                

In [17]:
# Replace the cross-validation wrapper with the best ALS model it found.
# getattr with a fallback keeps this cell re-runnable: on a second run
# `model` is already the ALS model, which has no `bestModel` attribute, and
# the original `model = model.bestModel` raised AttributeError.
model = getattr(model, "bestModel", model)

# Score the held-out split with the same RMSE evaluator used for tuning.
preds = model.transform(test)

RMSE = evaluator.evaluate(preds)

print(RMSE)

                                                                                

0.8656056522621457


In [31]:
# Ask the fitted model for the 5 highest-scoring movies per user; the
# result holds one row per user with an array column "recommendations".
top5_per_user = model.recommendForAllUsers(5)

# Unpack that array into one row per (user, recommended movie) struct.
train_preds = top5_per_user.withColumn(
    "movieid_rating", explode(col("recommendations"))
)

# Flatten the struct fields into plain columns for display.
train_preds.select(
    "userId",
    "movieid_rating.movieId",
    "movieid_rating.rating",
).show()



+------+-------+---------+
|userId|movieId|   rating|
+------+-------+---------+
|     1|   3379| 5.753362|
|     1|  33649|5.6195025|
|     1| 171495|5.5365224|
|     1| 184245|5.5041466|
|     1| 179135|5.5041466|
|     2| 131724| 4.817793|
|     2| 184245| 4.740913|
|     2| 179135| 4.740913|
|     2| 134796| 4.740913|
|     2|  86237| 4.740913|
|     3|   6835| 4.837939|
|     3|   5746| 4.837939|
|     3|  70946|4.7974906|
|     3|   5181|4.7396092|
|     3|   7991|4.6180015|
|     4|   3851| 4.842227|
|     4|   4765| 4.768887|
|     4|   1046| 4.731052|
|     4|   1733|4.6588025|
|     4|   2204| 4.629056|
+------+-------+---------+
only showing top 20 rows



                                                                                

23/12/03 17:47:22 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 1091006 ms exceeds timeout 120000 ms
23/12/03 17:47:22 WARN SparkContext: Killing executors is not supported by current scheduler.
23/12/03 17:47:23 ERROR Inbox: Ignoring error
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:56)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:310)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRefByURI(RpcEnv.scala:102)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRef(RpcEnv.scala:110)
	at org.apache.spark.util.RpcUtils$.makeDriverRef(RpcUtils.scala:36)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.driverEndpoint$lzycompute(BlockManagerMasterEndpoint.scala:124)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.org$apache$spark$storage$BlockManagerMasterEndpoint$