In [2]:
import pyspark
from pyspark.sql.types import *
import pyspark.ml.recommendation
from pyspark.ml.recommendation import *
import pyspark.ml.evaluation
from pyspark.ml.evaluation import *
import numpy as np
from pyspark.ml.evaluation import RegressionEvaluator

In [3]:
sc, spark

(<SparkContext master=local[4] appName=PySparkShell>,
 <pyspark.sql.session.SparkSession at 0x110a38dd8>)

In [4]:
ratings_schema = StructType( [
    StructField('userId', IntegerType(), True),
    StructField('movieId', IntegerType(), True),
    StructField('rating', FloatType(), True),
    StructField('timestamp', IntegerType(), True)
        ])

raw_ratings_df = spark.read.csv("./data/movies/ratings.csv",
                               schema=ratings_schema)

raw_ratings_df.printSchema()

root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- rating: float (nullable = true)
 |-- timestamp: integer (nullable = true)



In [5]:
movies_schema = StructType( [
    StructField('movieId', IntegerType(), True),
    StructField('title', StringType(), True),
    StructField('genre', StringType(), True)
        ])

raw_movies_df = spark.read.csv("./data/movies/movies.csv",
                               schema=movies_schema)

raw_movies_df.printSchema()

root
 |-- movieId: integer (nullable = true)
 |-- title: string (nullable = true)
 |-- genre: string (nullable = true)



In [6]:
raw_movies_df.registerTempTable('movies')
raw_ratings_df.registerTempTable('ratings')

# MVP For ALS on Rating Data Only

In [7]:
train_base = spark.sql("""
SELECT userId, 
movieId, timestamp,
CASE WHEN rating is null then 0.0 ELSE rating END as rating
FROM ratings
WHERE userId is not null
AND timestamp < 1296192000
""")

train_base.show()

+------+-------+----------+------+
|userId|movieId| timestamp|rating|
+------+-------+----------+------+
|     1|     31|1260759144|   2.5|
|     1|   1029|1260759179|   3.0|
|     1|   1061|1260759182|   3.0|
|     1|   1129|1260759185|   2.0|
|     1|   1172|1260759205|   4.0|
|     1|   1263|1260759151|   2.0|
|     1|   1287|1260759187|   2.0|
|     1|   1293|1260759148|   2.0|
|     1|   1339|1260759125|   3.5|
|     1|   1343|1260759131|   2.0|
|     1|   1371|1260759135|   2.5|
|     1|   1405|1260759203|   1.0|
|     1|   1953|1260759191|   4.0|
|     1|   2105|1260759139|   4.0|
|     1|   2150|1260759194|   3.0|
|     1|   2193|1260759198|   2.0|
|     1|   2294|1260759108|   2.0|
|     1|   2455|1260759113|   2.5|
|     1|   2968|1260759200|   1.0|
|     1|   3671|1260759117|   3.0|
+------+-------+----------+------+
only showing top 20 rows



In [8]:
test_base = spark.sql("""
SELECT userId, 
movieId, timestamp,
CASE WHEN rating is null then 0.0 ELSE rating END as rating
FROM ratings
WHERE userId is not null
AND timestamp >= 1296192000
""")

test_base.show()

+------+-------+----------+------+
|userId|movieId| timestamp|rating|
+------+-------+----------+------+
|     3|     60|1298861675|   3.0|
|     3|    110|1298922049|   4.0|
|     3|    247|1298861637|   3.5|
|     3|    267|1298861761|   3.0|
|     3|    296|1298862418|   4.5|
|     3|    318|1298862121|   5.0|
|     3|    355|1298861589|   2.5|
|     3|    356|1298862167|   5.0|
|     3|    377|1298923242|   2.5|
|     3|    527|1298862528|   3.0|
|     3|    588|1298922100|   3.0|
|     3|    592|1298923247|   3.0|
|     3|    593|1298921840|   3.0|
|     3|    595|1298923260|   2.0|
|     3|    736|1298932787|   3.5|
|     3|    778|1298863157|   4.0|
|     3|    866|1298861687|   3.0|
|     3|   1197|1298932770|   5.0|
|     3|   1210|1298921795|   3.0|
|     3|   1235|1298861628|   4.0|
+------+-------+----------+------+
only showing top 20 rows



In [9]:
als = ALS(rank=10, maxIter=5, seed=0, regParam=0.1, implicitPrefs=False,
          userCol="userId", itemCol="movieId", ratingCol="rating", nonnegative=True)
model = als.fit(train_base)

In [10]:
predictions = model.transform(test_base).persist()
predictions = predictions.na.drop()

In [11]:
predictions.registerTempTable("predictions")

In [14]:
predictions_base = spark.sql("""
SELECT * FROM predictions
WHERE NOT ISNAN(prediction)
ORDER BY prediction DESC
""").show()

+------+-------+----------+------+----------+
|userId|movieId| timestamp|rating|prediction|
+------+-------+----------+------+----------+
|    78|  79132|1327062958|   4.5| 5.2240825|
|   480|  37731|1339455993|   1.5| 5.1346283|
|   480|  79132|1339283594|   5.0| 5.0135756|
|   480|   1680|1339455518|   3.0|   4.87072|
|   380|  78499|1304471106|   4.0|  4.745358|
|   480|  80463|1339455165|   4.5|  4.725688|
|   426|   8132|1310375708|   4.0| 4.7219677|
|   480|   1207|1339455039|   4.5| 4.7203236|
|   480|  79702|1339285851|   3.0|  4.695591|
|   480|  68954|1339285893|   2.5| 4.6603575|
|   426|  53956|1320778912|   2.0|  4.609357|
|    73|   7063|1411451412|   4.0| 4.6051664|
|   480|    858|1339456454|   4.0|  4.601046|
|   501|   1704|1307129555|   5.0|  4.591006|
|   480|   1240|1339285265|   4.5|  4.581721|
|   380|  69951|1330910373|   3.0| 4.5737853|
|   480|  56782|1339455242|   3.5| 4.5698547|
|   501|   1136|1309492528|   3.5| 4.5658817|
|   480|  73017|1339284321|   5.0|

In [15]:
evaluator = RegressionEvaluator(metricName='rmse', labelCol='rating',
                                predictionCol='prediction')
rmse = evaluator.evaluate(predictions)
print('Root-mean-square error = ' + str(rmse))

Root-mean-square error = 0.8590750999735148


# Using Dummy Variables for Genre

In [16]:
train_genre = spark.sql("""
SELECT DISTINCT userId, 
ratings.movieId,
CASE WHEN rating is null then 0.0 ELSE rating END as rating,
CASE WHEN genre like '%Drama%' then 1 else 0 END as drama,
CASE WHEN genre like '%Comedy%' then 1 else 0 END as comedy,
CASE WHEN genre like '%Romance%' then 1 else 0 END as romance,
CASE WHEN genre like '%Action%' then 1 else 0 END as action,
CASE WHEN genre like '%Crime%' then 1 else 0 END as crime,
CASE WHEN genre like '%Mystery%' then 1 else 0 END as mystery,
CASE WHEN genre like '%War%' then 1 else 0 END as war,
CASE WHEN genre like '%West%' then 1 else 0 END as west,
CASE WHEN genre like '%Horror%' then 1 else 0 END as horror,
CASE WHEN genre like '%Thriller%' then 1 else 0 END as thriller,
CASE WHEN genre like '%Adventure%' then 1 else 0 END as adventure,
CASE WHEN genre like '%Documentary%' then 1 else 0 END as documentary,
CASE WHEN genre like '%Child%' then 1 else 0 END as childrens,
CASE WHEN genre like '%Animation%' then 1 else 0 END as animation,
CASE WHEN genre like '%Sci%' then 1 else 0 END as sci_fi,
CASE WHEN genre like '%Musical%' then 1 else 0 END as musical,
CASE WHEN genre like '%Fantasy%' then 1 else 0 END as fantasy,
CASE WHEN genre like '%Film-Noir%' then 1 else 0 END as film_noir,
CASE WHEN genre like '%IMAX%' then 1 else 0 END as imax

FROM ratings
LEFT JOIN movies on ratings.movieID = movies.movieID
WHERE userId is not null
AND timestamp < 1296192000
""")

In [17]:
test_genre = spark.sql("""
SELECT DISTINCT userId, 
ratings.movieId,
CASE WHEN rating is null then 0.0 ELSE rating END as rating,
CASE WHEN genre like '%Drama%' then 1 else 0 END as drama,
CASE WHEN genre like '%Comedy%' then 1 else 0 END as comedy,
CASE WHEN genre like '%Romance%' then 1 else 0 END as romance,
CASE WHEN genre like '%Action%' then 1 else 0 END as action,
CASE WHEN genre like '%Crime%' then 1 else 0 END as crime,
CASE WHEN genre like '%Mystery%' then 1 else 0 END as mystery,
CASE WHEN genre like '%War%' then 1 else 0 END as war,
CASE WHEN genre like '%West%' then 1 else 0 END as west,
CASE WHEN genre like '%Horror%' then 1 else 0 END as horror,
CASE WHEN genre like '%Thriller%' then 1 else 0 END as thriller,
CASE WHEN genre like '%Adventure%' then 1 else 0 END as adventure,
CASE WHEN genre like '%Documentary%' then 1 else 0 END as documentary,
CASE WHEN genre like '%Child%' then 1 else 0 END as childrens,
CASE WHEN genre like '%Animation%' then 1 else 0 END as animation,
CASE WHEN genre like '%Sci%' then 1 else 0 END as sci_fi,
CASE WHEN genre like '%Musical%' then 1 else 0 END as musical,
CASE WHEN genre like '%Fantasy%' then 1 else 0 END as fantasy,
CASE WHEN genre like '%Film-Noir%' then 1 else 0 END as film_noir,
CASE WHEN genre like '%IMAX%' then 1 else 0 END as imax

FROM ratings
LEFT JOIN movies on ratings.movieID = movies.movieID
WHERE userId is not null
AND timestamp >= 1296192000
""")

In [18]:
als_genre = ALS(rank=10, maxIter=5, seed=0, regParam=0.1, implicitPrefs=False,
          userCol="userId", itemCol="movieId", ratingCol="rating", nonnegative=True)
model = als_genre.fit(train_genre)

In [19]:
predictions_genre = model.transform(test_genre).persist()
predictions_genre = predictions.na.drop()

In [20]:
evaluator = RegressionEvaluator(metricName='rmse', labelCol='rating',
                                predictionCol='prediction')
rmse = evaluator.evaluate(predictions_genre)
print('Root-mean-square error = ' + str(rmse))

Root-mean-square error = 0.8590750999735149


# Testing Hyperparameters on the OG ALS Model

In [54]:
als_playground = ALS(rank=10, maxIter=5, seed=0, regParam=.2, implicitPrefs=False,
          userCol="userId", itemCol="movieId", ratingCol="rating", nonnegative=True)
model = als_playground.fit(train_base)

In [55]:
predictions_playground = model.transform(test_base).persist()
predictions_playground = predictions_playground.na.drop()

In [56]:
evaluator = RegressionEvaluator(metricName='rmse', labelCol='rating',
                                predictionCol='prediction')
rmse = evaluator.evaluate(predictions_playground)
print('Root-mean-square error = ' + str(rmse))

Root-mean-square error = 0.8538004128931168


In [57]:
userRecs = model.recommendForAllUsers(10)

In [59]:
userRecs.show()

+------+--------------------+
|userId|     recommendations|
+------+--------------------+
|   471|[[4302,5.173933],...|
|   463|[[4302,5.0394797]...|
|   496|[[5765,5.811225],...|
|   148|[[4731,5.581875],...|
|   540|[[5765,5.725072],...|
|   392|[[8208,4.56997], ...|
|   243|[[5071,5.039308],...|
|   623|[[5071,5.5812473]...|
|    31|[[5071,5.5674973]...|
|   516|[[5765,5.1803555]...|
|   580|[[4731,4.8906956]...|
|   451|[[8208,5.080391],...|
|    85|[[5765,4.967734],...|
|   137|[[5765,5.197393],...|
|    65|[[5071,5.633008],...|
|    53|[[31116,4.228965]...|
|   255|[[4731,5.6881833]...|
|   588|[[4302,5.2902207]...|
|   472|[[5071,5.499962],...|
|   322|[[4731,5.2431684]...|
+------+--------------------+
only showing top 20 rows

