In [2]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

In [3]:
spark = SparkSession.builder.appName("Recommendation").master("local[*]").getOrCreate()

In [4]:
ratings_file = "data/ml-latest-small/ratings.csv"
movies_file = "data/ml-latest-small/movies.csv"

In [5]:
ratings = (
    spark.read.csv(
        path=ratings_file,
        sep=",",
        header=True,
        quote='"',
        schema="userId INT, movieId INT, rating DOUBLE, timestamp INT"
    )
    .select("userId", "movieId", "rating")
)

In [6]:
movies = (
    spark.read.csv(
        path=movies_file,
        sep=",",
        header=True,
        quote='"',
        schema="movieId INT, title STRING, genres STRING"
    )
    .withColumn("release_year", f.regexp_extract(f.col("title"), "\s?\((\d{4})\)", 1))
    .withColumn("title", f.regexp_replace(f.col("title"), "\s?\((\d{4})\)", ""))
    .withColumn("genres", f.split(f.col("genres"), "\|"))
)

In [7]:
movies.show()

+-------+--------------------+--------------------+------------+
|movieId|               title|              genres|release_year|
+-------+--------------------+--------------------+------------+
|      1|           Toy Story|[Adventure, Anima...|        1995|
|      2|             Jumanji|[Adventure, Child...|        1995|
|      3|    Grumpier Old Men|   [Comedy, Romance]|        1995|
|      4|   Waiting to Exhale|[Comedy, Drama, R...|        1995|
|      5|Father of the Bri...|            [Comedy]|        1995|
|      6|                Heat|[Action, Crime, T...|        1995|
|      7|             Sabrina|   [Comedy, Romance]|        1995|
|      8|        Tom and Huck|[Adventure, Child...|        1995|
|      9|        Sudden Death|            [Action]|        1995|
|     10|           GoldenEye|[Action, Adventur...|        1995|
|     11|American Presiden...|[Comedy, Drama, R...|        1995|
|     12|Dracula: Dead and...|    [Comedy, Horror]|        1995|
|     13|               B

In [8]:
als = ALS(
    userCol="userId",
    itemCol="movieId",
    ratingCol="rating",
)

evaluator = RegressionEvaluator(
    metricName="rmse", labelCol="rating", predictionCol="prediction"
)

parameter_grid = (
    ParamGridBuilder()
    .addGrid(als.rank, [1,5,10])
    .addGrid(als.maxIter, [20])
    .addGrid(als.regParam, [0.05, 0.1])
    .build()
)

crossValidator = CrossValidator(
    estimator=als,
    estimatorParamMaps=parameter_grid,
    evaluator=evaluator,
    numFolds=2,
)

(training_data, validation_data) = ratings.randomSplit([8.0, 2.0])

training_data.show()

+------+-------+------+
|userId|movieId|rating|
+------+-------+------+
|     1|      1|   4.0|
|     1|      3|   4.0|
|     1|      6|   4.0|
|     1|     47|   5.0|
|     1|     50|   5.0|
|     1|     70|   3.0|
|     1|    101|   5.0|
|     1|    110|   4.0|
|     1|    151|   5.0|
|     1|    163|   5.0|
|     1|    216|   5.0|
|     1|    223|   3.0|
|     1|    231|   5.0|
|     1|    235|   4.0|
|     1|    260|   5.0|
|     1|    296|   3.0|
|     1|    316|   3.0|
|     1|    333|   5.0|
|     1|    362|   5.0|
|     1|    367|   4.0|
+------+-------+------+
only showing top 20 rows



In [9]:
crossval_model = crossValidator.fit(training_data)

In [10]:
model = crossval_model.bestModel
predictions = model.transform(validation_data).na.drop()
print(f"rmse for best model ({model}): {evaluator.evaluate(predictions)}")

rmse for best model (ALSModel: uid=ALS_5b67321727af, rank=1): 0.8865769511675279


In [11]:
predictions.toPandas()

Unnamed: 0,userId,movieId,rating,prediction
0,436,471,3.0,3.559944
1,409,471,3.0,3.797771
2,217,471,2.0,3.002758
3,171,471,3.0,4.406048
4,448,471,4.0,3.302723
...,...,...,...,...
19591,448,84374,2.0,2.998096
19592,298,84374,0.5,2.255881
19593,448,145839,2.5,3.124542
19594,380,147378,3.0,3.610469


In [12]:
USER_ID = [150, 160, 170, 180]

rec_all_users = model.recommendForAllUsers(5)
rec_all_users.show(1, False)

+------+---------------------------------------------------------------------------------------------+
|userId|recommendations                                                                              |
+------+---------------------------------------------------------------------------------------------+
|471   |[[6835, 7.996435], [5181, 7.996435], [5746, 7.996435], [136850, 7.3697414], [5764, 7.196791]]|
+------+---------------------------------------------------------------------------------------------+
only showing top 1 row



In [13]:
def recommender(rec_all_users, movies, userids):
    return_val = (rec_all_users.filter(f.col("userId").isin(userids))
    .withColumn("rec", f.explode("recommendations"))
    .select(
        "userId",
        f.col("rec").movieId.alias("movieId"),
        f.col("rec").rating.alias("rating")
    )
    .join(movies, "movieId")
    .orderBy("rating", ascending=False)
    .select("userId", "movieId", "title", "release_year"))
    return return_val

In [14]:
foo = recommender(rec_all_users, movies, USER_ID)

In [15]:
foo.show()

+------+-------+--------------------+------------+
|userId|movieId|               title|release_year|
+------+-------+--------------------+------------+
|   150|   5746|Galaxy of Terror ...|        1981|
|   150|   5181|           Hangar 18|        1980|
|   150|   6835| Alien Contamination|        1980|
|   170|   6835| Alien Contamination|        1980|
|   170|   5746|Galaxy of Terror ...|        1981|
|   170|   5181|           Hangar 18|        1980|
|   180|   5746|Galaxy of Terror ...|        1981|
|   180|   5181|           Hangar 18|        1980|
|   180|   6835| Alien Contamination|        1980|
|   150| 136850|             Villain|        1971|
|   170| 136850|             Villain|        1971|
|   150|   5764|              Looker|        1981|
|   170|   5764|              Looker|        1981|
|   180| 136850|             Villain|        1971|
|   180|   5764|              Looker|        1981|
|   160|   5746|Galaxy of Terror ...|        1981|
|   160|   5181|           Hang