In [5]:
## Import Libraries
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, DoubleType
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator

## Set seed
seed = 42

In [6]:
## Create Spark Session
spark = SparkSession.builder.appName('recSystem').getOrCreate()

In [8]:
## Setup Schema
schema = StructType(fields=[StructField('movie_id', IntegerType(), True),
                            StructField('rating', DoubleType(), True),
                            StructField('user_id', IntegerType(), True)])

In [9]:
## Load Data
df = spark.read.csv('gs://spark-training-data/datasets/movielens_ratings.csv', header=True,
                    inferSchema=False, schema=schema)
df.show(5)
df.printSchema() ## Confirm proper schema

+--------+------+-------+
|movie_id|rating|user_id|
+--------+------+-------+
|       2|   3.0|      0|
|       3|   1.0|      0|
|       5|   2.0|      0|
|       9|   4.0|      0|
|      11|   1.0|      0|
+--------+------+-------+
only showing top 5 rows

root
 |-- movie_id: integer (nullable = true)
 |-- rating: double (nullable = true)
 |-- user_id: integer (nullable = true)



In [11]:
df.describe().show()

+-------+------------------+------------------+------------------+
|summary|          movie_id|            rating|           user_id|
+-------+------------------+------------------+------------------+
|  count|              1501|              1501|              1501|
|   mean| 49.40572951365756|1.7741505662891406|14.383744170552964|
| stddev|28.937034065088994| 1.187276166124803| 8.591040424293272|
|    min|                 0|               1.0|                 0|
|    max|                99|               5.0|                29|
+-------+------------------+------------------+------------------+



In [12]:
## Split into training & test
train_data, test_data = df.randomSplit([0.8,0.2], seed=seed)

In [13]:
## Build Alternating Least Squares
als = ALS(maxIter=5, regParam=0.01, userCol='user_id', itemCol='movie_id', ratingCol='rating')

In [14]:
## Fit Model
als_model = als.fit(train_data)

                                                                                

In [15]:
## Make Predictions
predictions = als_model.transform(test_data)
predictions.show()

+--------+------+-------+-------------+
|movie_id|rating|user_id|   prediction|
+--------+------+-------+-------------+
|      85|   3.0|      1|   -0.1890984|
|      85|   3.0|      6|    1.8067293|
|      85|   1.0|     13|    1.2966782|
|      85|   1.0|     15|    1.6449969|
|      53|   1.0|      6|     4.619213|
|      53|   1.0|     12|  -0.11467552|
|      53|   3.0|     20|  -0.11618537|
|      78|   1.0|     19|    1.0745367|
|      78|   1.0|     28|    0.2962752|
|      34|   1.0|     16|    2.9352565|
|      34|   1.0|     28|    1.6606547|
|      81|   5.0|     28|    2.3649263|
|      76|   1.0|      1|     3.218676|
|      76|   1.0|     26|   0.39942393|
|      26|   1.0|      3|    -0.532228|
|      26|   1.0|     19|     1.634851|
|      27|   3.0|     27|    2.6601024|
|      44|   1.0|      6|-0.0074782073|
|      44|   1.0|     28|    1.0622034|
|      12|   1.0|     19|     3.092336|
+--------+------+-------+-------------+
only showing top 20 rows



In [16]:
## Model Evaluation
evaluator = RegressionEvaluator(metricName='rmse', labelCol='rating',
                                predictionCol='prediction')
rmse = evaluator.evaluate(predictions)
print(f'RMSE: {rmse}')

RMSE: 1.6874738409854153


In [22]:
## Sample "deployment" - Get single user's viewing history
single_user = test_data.filter(test_data['user_id'] == 11).select(['user_id','movie_id'])
single_user.show()

+-------+--------+
|user_id|movie_id|
+-------+--------+
|     11|       0|
|     11|      23|
|     11|      30|
|     11|      36|
|     11|      43|
|     11|      45|
|     11|      69|
|     11|      71|
|     11|      75|
|     11|      80|
+-------+--------+



In [24]:
## Predict single_user
reccomendations = als_model.transform(single_user)
reccomendations.show()

+-------+--------+----------+
|user_id|movie_id|prediction|
+-------+--------+----------+
|     11|      43|0.21923077|
|     11|      23| 1.6361773|
|     11|      69| 2.3371408|
|     11|      45|0.17520235|
|     11|      80| 0.8080733|
|     11|      75| -1.262606|
|     11|      71|  2.540303|
|     11|      30| 2.5853858|
|     11|       0| 1.3063326|
|     11|      36| 2.0547783|
+-------+--------+----------+

