In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.functions import col

In [2]:
# Create (or reuse) the Spark session that every later cell relies on
spark = (
    SparkSession.builder
    .appName('25m_movie_lens')
    .getOrCreate()
)

In [3]:
# Load the ratings and movies data into Spark DataFrames.
# Only the ratings are capped (at 50,000 rows) to keep training fast. The movies
# file is a lookup table and is loaded in full: truncating it would leave
# title/genres null in the later left join for any rated movie that was cut off.
ratings_df = spark.read.csv('./ml-25m/ratings.csv', header=True, inferSchema=True).limit(50000)
movies_df = spark.read.csv('./ml-25m/movies.csv', header=True, inferSchema=True)

ratings_df.show()
movies_df.show()

+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
|     1|    296|   5.0|1147880044|
|     1|    306|   3.5|1147868817|
|     1|    307|   5.0|1147868828|
|     1|    665|   5.0|1147878820|
|     1|    899|   3.5|1147868510|
|     1|   1088|   4.0|1147868495|
|     1|   1175|   3.5|1147868826|
|     1|   1217|   3.5|1147878326|
|     1|   1237|   5.0|1147868839|
|     1|   1250|   4.0|1147868414|
|     1|   1260|   3.5|1147877857|
|     1|   1653|   4.0|1147868097|
|     1|   2011|   2.5|1147868079|
|     1|   2012|   2.5|1147868068|
|     1|   2068|   2.5|1147869044|
|     1|   2161|   3.5|1147868609|
|     1|   2351|   4.5|1147877957|
|     1|   2573|   4.0|1147878923|
|     1|   2632|   5.0|1147878248|
|     1|   2692|   5.0|1147869100|
+------+-------+------+----------+
only showing top 20 rows

+-------+--------------------+--------------------+
|movieId|               title|              genres|
+-------+--------------------+

In [4]:
# Attach movie metadata to every rating (a left join keeps all rating rows),
# keep only the columns used downstream, and order the rows by user.
join_df = (
    ratings_df
    .join(movies_df, on='movieId', how='left')
    .select('userId', 'movieId', 'rating', 'title', 'genres')
    .sort('userId')
)

join_df.show()

+------+-------+------+--------------------+--------------------+
|userId|movieId|rating|               title|              genres|
+------+-------+------+--------------------+--------------------+
|     1|    296|   5.0| Pulp Fiction (1994)|Comedy|Crime|Dram...|
|     1|    306|   3.5|Three Colors: Red...|               Drama|
|     1|    307|   5.0|Three Colors: Blu...|               Drama|
|     1|    665|   5.0|  Underground (1995)|    Comedy|Drama|War|
|     1|    899|   3.5|Singin' in the Ra...|Comedy|Musical|Ro...|
|     1|   1088|   4.0|Dirty Dancing (1987)|Drama|Musical|Rom...|
|     1|   1175|   3.5| Delicatessen (1991)|Comedy|Drama|Romance|
|     1|   1217|   3.5|          Ran (1985)|           Drama|War|
|     1|   1237|   5.0|Seventh Seal, The...|               Drama|
|     1|   1250|   4.0|Bridge on the Riv...| Adventure|Drama|War|
|     1|   1260|   3.5|            M (1931)|Crime|Film-Noir|T...|
|     1|   1653|   4.0|      Gattaca (1997)|Drama|Sci-Fi|Thri...|
|     1|  

In [5]:
# 80/20 train/test split. The seed is fixed so that the split (and therefore the
# reported metrics) is the same on every Restart & Run All.
(train_df, test_df) = join_df.randomSplit([0.8, 0.2], seed=42)

In [6]:
# Define ALS model.
# - coldStartStrategy="drop": rows whose prediction would be NaN are dropped
#   from the output of transform(), so the evaluation metric stays defined.
# - nonnegative=True: constrain the learned factors to be non-negative.
# - seed: fixes the random initialisation so repeated runs give the same model.
als = ALS(maxIter=15, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating",
          coldStartStrategy="drop", nonnegative=True, seed=42)

In [7]:
# Fit the ALS estimator on the training split, then score the held-out test split.
# transform() appends a `prediction` column to the rows of test_df.
model = als.fit(train_df)
predictions = model.transform(test_df)

# Display the first rows of the scored test set (true rating next to prediction)
predictions.show()

+------+-------+------+--------------------+--------------------+----------+
|userId|movieId|rating|               title|              genres|prediction|
+------+-------+------+--------------------+--------------------+----------+
|     1|    307|   5.0|Three Colors: Blu...|               Drama|  4.213637|
|     1|    665|   5.0|  Underground (1995)|    Comedy|Drama|War|  3.503456|
|     1|   1217|   3.5|          Ran (1985)|           Drama|War| 4.2155585|
|     1|   2012|   2.5|Back to the Futur...|Adventure|Comedy|...|   2.66482|
|     1|   2351|   4.5|Nights of Cabiria...|               Drama| 3.6762354|
|     1|   2843|   4.5|Black Cat, White ...|      Comedy|Romance|  2.862689|
|     1|   3949|   5.0|Requiem for a Dre...|               Drama| 3.5860238|
|     1|   4308|   3.0| Moulin Rouge (2001)|Drama|Musical|Rom...| 2.1690135|
|     1|   5147|   4.0|Wild Strawberries...|               Drama|  2.790667|
|     1|   6954|   3.5|Barbarian Invasio...|Comedy|Crime|Dram...|  2.603771|

In [8]:
# Measure test-set error: RMSE between the true `rating` and the model's `prediction`
evaluator = RegressionEvaluator(
    labelCol="rating",
    predictionCol="prediction",
    metricName="rmse",
)
rmse = evaluator.evaluate(predictions)

print("Root Mean Squared Error = {:.4f}".format(rmse))

Root Mean Squared Error = 1.1476


In [9]:
# Predict ratings for a list of movies for a specific user.
# movie_ids and user_ids are paired element-wise, so they must have equal length.
movie_ids = [1, 11, 111]
user_ids = [10, 10, 10]

user_preds = spark.createDataFrame(zip(movie_ids, user_ids), schema=['movieId', 'userId'])

# transform() is lazy, so nothing is computed until an action runs; show() both
# triggers the computation and displays the predicted ratings.
# NOTE: the ALS estimator was configured with coldStartStrategy="drop", so any
# (user, movie) pair the model cannot score is omitted from the output.
preds = model.transform(user_preds)
preds.show()

In [10]:
# Stop the SparkSession now that the analysis is finished.
# Cells above that use `spark` cannot be re-run after this without re-creating the session.
spark.stop()