In [1]:
import os

import findspark
# Locate the Spark installation: honour $SPARK_HOME when it is set, otherwise
# fall back to the Spark 2.1.1 path this notebook was originally run against.
findspark.init(os.environ.get('SPARK_HOME', '/home/ubuntu/spark-2.1.1-bin-hadoop2.7'))
from pyspark.sql import SparkSession

  if obj.__module__ is "__builtin__":


In [2]:
# Reuse the active Spark session if there is one, otherwise start one named 'rec'.
spark = (
    SparkSession.builder
    .appName('rec')
    .getOrCreate()
)

In [3]:
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator

In [4]:
# MovieLens ratings sample (movieId, rating, userId). The CSV has a header row,
# and Spark infers the column types from the data.
data = spark.read.csv(
    'movielens_ratings.csv',
    header=True,
    inferSchema=True,
)

In [5]:
# Peek at the first 20 rows.
data.show(n=20)

+-------+------+------+
|movieId|rating|userId|
+-------+------+------+
|      2|   3.0|     0|
|      3|   1.0|     0|
|      5|   2.0|     0|
|      9|   4.0|     0|
|     11|   1.0|     0|
|     12|   2.0|     0|
|     15|   1.0|     0|
|     17|   1.0|     0|
|     19|   1.0|     0|
|     21|   1.0|     0|
|     23|   1.0|     0|
|     26|   3.0|     0|
|     27|   1.0|     0|
|     28|   1.0|     0|
|     29|   1.0|     0|
|     30|   1.0|     0|
|     31|   1.0|     0|
|     34|   1.0|     0|
|     37|   1.0|     0|
|     41|   2.0|     0|
+-------+------+------+
only showing top 20 rows



In [6]:
# count / mean / stddev / min / max for each described column.
data_summary = data.describe()
data_summary.show()

+-------+------------------+------------------+------------------+
|summary|           movieId|            rating|            userId|
+-------+------------------+------------------+------------------+
|  count|              1501|              1501|              1501|
|   mean| 49.40572951365756|1.7741505662891406|14.383744170552964|
| stddev|28.937034065088994| 1.187276166124803| 8.591040424293272|
|    min|                 0|               1.0|                 0|
|    max|                99|               5.0|                29|
+-------+------------------+------------------+------------------+



In [7]:
# 80/20 row-level split. The seed is fixed so the partition (and every metric
# and table printed below) is reproducible after a kernel restart.
train, test = data.randomSplit([0.8, 0.2], seed=42)

In [8]:
# Explicit-feedback ALS on the (userId, movieId, rating) triples.
# The seed is pinned explicitly rather than relying on PySpark's default, so
# factor initialisation (and therefore the predictions below) is reproducible
# across kernel restarts.
als = ALS(
    maxIter=5,
    regParam=0.01,
    userCol='userId',
    itemCol='movieId',
    ratingCol='rating',
    seed=42,
)

In [9]:
# Fit the factorisation on the training split only.
model = als.fit(dataset=train)

In [10]:
# Score the held-out rows; the result keeps movieId/rating/userId and adds `prediction`.
predictions = model.transform(dataset=test)

In [11]:
# First 20 scored test rows: true rating next to the model's prediction.
predictions.show(n=20)

+-------+------+------+-----------+
|movieId|rating|userId| prediction|
+-------+------+------+-----------+
|     31|   1.0|    26| 0.90474665|
|     31|   1.0|    27| -1.4534044|
|     31|   4.0|    12|   0.605444|
|     31|   1.0|    13|-0.18196931|
|     31|   1.0|     4|     1.4291|
|     31|   3.0|     8| 0.18936466|
|     31|   2.0|    25|  1.6470556|
|     31|   1.0|    29| 0.14376839|
|     31|   3.0|    14|  0.7963653|
|     31|   1.0|     0| 0.05768317|
|     31|   1.0|    18|-0.19969986|
|     85|   1.0|    28|  2.4639535|
|     85|   1.0|    26|   4.464607|
|     85|   1.0|    13| 0.68408024|
|     85|   1.0|     5|  2.2834117|
|     85|   1.0|    15| 0.69987035|
|     85|   5.0|     8|  3.7173345|
|     85|   1.0|     2| -0.9393606|
|     65|   1.0|    28| 0.77409315|
|     65|   2.0|     3|   2.559153|
+-------+------+------+-----------+
only showing top 20 rows



In [12]:
# Root-mean-squared error between the true `rating` and the model's `prediction`.
evaluator = RegressionEvaluator(
    labelCol='rating',
    predictionCol='prediction',
    metricName='rmse',
)

In [13]:
# ALS yields NaN predictions for users/movies that never appeared in `train`.
# This can happen because randomSplit splits by row, not by user/movie.
# A single NaN would make the RMSE NaN, so evaluate only the rows the model
# could actually score.
rmse = evaluator.evaluate(predictions.na.drop(subset=['prediction']))

In [14]:
# Label on one line, value on the next.
print('RMSE', rmse, sep='\n')

RMSE
1.956665464516587


In [15]:
# Take one user's held-out (test-split) ratings, keeping only the id columns.
# NOTE: this is NOT a "fresh" user. randomSplit splits by row, and ALS can only
# score users it saw during fit(), so any user that gets non-NaN predictions
# below also has ratings in `train`.
SINGLE_USER_ID = 11
single_user = test.filter(test['userId'] == SINGLE_USER_ID).select(['movieId', 'userId'])

In [16]:
# First 20 of this user's held-out movies.
single_user.show(n=20)

+-------+------+
|movieId|userId|
+-------+------+
|      6|    11|
|      9|    11|
|     10|    11|
|     11|    11|
|     12|    11|
|     27|    11|
|     35|    11|
|     36|    11|
|     37|    11|
|     39|    11|
|     40|    11|
|     59|    11|
|     61|    11|
|     66|    11|
|     69|    11|
|     72|    11|
|     78|    11|
|     79|    11|
|     81|    11|
|     94|    11|
+-------+------+
only showing top 20 rows



In [17]:
# Predict ratings for this user's test-split movies only. This ranks movies
# the user already rated in the dataset, not the full catalogue of unseen movies.
recommendations = model.transform(dataset=single_user)

In [18]:
# Highest predicted rating first.
recommendations.orderBy(recommendations['prediction'].desc()).show()

+-------+------+-----------+
|movieId|userId| prediction|
+-------+------+-----------+
|      6|    11|  2.6089587|
|     37|    11|  2.3616257|
|     12|    11|  2.1962087|
|     79|    11|  2.0509467|
|     39|    11|  1.8967505|
|     66|    11|  1.8317456|
|     61|    11|  1.7909635|
|     10|    11|  1.7628655|
|     69|    11|  1.5043404|
|     78|    11|  1.4511297|
|     72|    11|  1.3774076|
|     81|    11|  0.9184571|
|     40|    11| 0.90454495|
|     27|    11|  0.7996415|
|     36|    11| 0.76037085|
|     99|    11| 0.39643043|
|     11|    11| 0.36239907|
|     59|    11|-0.03539577|
|     35|    11|-0.24836501|
|     94|    11| -1.3076303|
+-------+------+-----------+
only showing top 20 rows

