In [36]:
# Setup cell: every import the notebook needs, in one place.
import json
import logging

import pandas as pd

# findspark.init() has to run before any `pyspark` import below
# (presumably it puts the local Spark installation's pyspark on sys.path).
import findspark
findspark.init()

from pyspark.conf import SparkConf
from pyspark.sql import SparkSession

from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder

def create_spark_configuration(app_name="ElasticsearchSparkIntegration", packages=None):
    """Create (or reuse) the notebook's SparkSession.

    Args:
        app_name: Application name passed to ``SparkSession.builder.appName``.
        packages: Optional comma-separated Maven coordinates for
            ``spark.jars.packages``, e.g.
            ``"org.elasticsearch:elasticsearch-spark-20_2.12:7.17.14,"
            "org.apache.spark:spark-sql-kafka-0-10_2.12:3.2.4"``.
            ``None`` (or empty) leaves the setting untouched.

    Returns:
        The session returned by ``builder.getOrCreate()``.

    Raises:
        Exception: whatever building the session raises. The error is logged
            with its traceback and re-raised. Returning ``None`` instead would
            only make the next cell fail with an unrelated ``AttributeError``.
    """
    try:
        builder = SparkSession.builder.appName(app_name)
        if packages:
            builder = builder.config("spark.jars.packages", packages)
        session = builder.getOrCreate()
    except Exception as e:
        logging.exception("Couldn't create the spark session due to exception %s", e)
        raise

    logging.info("Spark connection created successfully!")
    return session

# Session handle used by later cells (e.g. `spark.read` when loading the ratings CSV).
spark = create_spark_configuration()

In [60]:
# Load the ratings CSV (first row is the header, column types inferred)
# and print summary statistics for every column.
df = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv('data/csv/movie_ratings.csv')
)
df.describe().show()

# ALS only needs user, item and rating; split 70/30 with a fixed seed (42).
ratings_cols = ('userId', 'movieId', 'rating')
train, test = df.select(*ratings_cols).randomSplit([0.7, 0.3], seed=42)

+-------+------------------+------------------+------------------+-----------------+
|summary|            userId|           movieId|            rating|        timestamp|
+-------+------------------+------------------+------------------+-----------------+
|  count|            100000|            100000|            100000|           100000|
|   mean|         462.48475|         425.53013|           3.52986|8.8352885148862E8|
| stddev|266.61442012750905|330.79835632558473|1.1256735991443214|5343856.189502848|
|    min|                 1|                 1|                 1|        874724710|
|    max|               943|              1682|                 5|        893286638|
+-------+------------------+------------------+------------------+-----------------+



In [72]:
# Collaborative-filtering estimator over (userId, movieId, rating).
# coldStartStrategy="drop" drops prediction rows for users/movies unseen in
# training, so the RMSE evaluator never sees NaN; nonnegative=True constrains
# the learned factors to be non-negative.
als = ALS(
    userCol="userId",
    itemCol="movieId",
    ratingCol="rating",
    maxIter=5,
    regParam=0.01,
    nonnegative=True,
    coldStartStrategy="drop",
)

In [73]:
# Fit the ALS model on the 70% training split.
model = als.fit(train)

In [104]:
# Score the held-out 30% split; the result gains a `prediction` column.
# NOTE(review): predictions are not clipped to the 1-5 rating range seen in
# describe() (a later cell shows 6.758924 for user 943). Consider clamping
# before evaluation if that matters for the reported RMSE.
prediction = model.transform(test)

In [105]:
# Preview the scored test rows (show() prints the first 20).
prediction.show()

+------+-------+------+----------+
|userId|movieId|rating|prediction|
+------+-------+------+----------+
|   897|    496|     5|  4.455332|
|   251|    148|     2| 2.9859333|
|   580|    148|     4| 3.7465968|
|   580|    471|     3|   3.25167|
|    65|    471|     4| 3.4136634|
|   883|   1591|     3| 3.7099378|
|   588|    463|     4| 4.3964195|
|   588|    496|     3| 3.9017167|
|   472|    496|     4|  4.330444|
|   321|    496|     4|  4.161474|
|   593|    471|     3| 3.6864114|
|   642|    148|     5| 3.7612953|
|   731|    496|     5| 3.3621073|
|   332|    148|     5| 4.1136007|
|   332|    471|     4| 4.0084143|
|   271|    496|     5| 4.3179636|
|   844|    471|     3|  3.427908|
|   806|    496|     5| 3.8139317|
|   103|    471|     4| 3.5926201|
|   236|    496|     3| 3.9977946|
+------+-------+------+----------+
only showing top 20 rows



In [106]:
# Root-mean-squared error between the true `rating` and the model's `prediction`.
evaluator = RegressionEvaluator(
    labelCol="rating",
    predictionCol="prediction",
    metricName="rmse",
)

In [107]:
# RMSE on the test split, in rating units (ratings span 1-5 per describe() above).
rmse = evaluator.evaluate(prediction)
print(rmse)

1.0242864089964292


In [109]:
# Test-split (movieId, userId) pairs for every user with userId > 940.
# Despite the name, `user_1` is not "user 1": the largest userId is 943
# (see describe() above), so this covers users 941-943.
user_1 = test.where(test.userId > 940).select('movieId', 'userId')

In [110]:
# Preview the selected (movieId, userId) pairs (first 20 rows).
user_1.show()

+-------+------+
|movieId|userId|
+-------+------+
|      7|   941|
|    147|   941|
|    181|   941|
|    257|   941|
|    258|   941|
|    273|   941|
|     95|   942|
|    117|   942|
|    210|   942|
|    234|   942|
|    272|   942|
|    300|   942|
|    304|   942|
|    328|   942|
|    347|   942|
|    357|   942|
|    498|   942|
|    511|   942|
|    539|   942|
|    607|   942|
+-------+------+
only showing top 20 rows



In [111]:
# Predict ratings for the selected users' held-out (movie, user) pairs.
# NOTE(review): this only re-scores movies these users already rated in the
# test split. For genuine top-N recommendations over the whole catalogue,
# consider model.recommendForUserSubset(...) instead.
rec = model.transform(user_1)

In [112]:
# Highest predicted ratings first, pooled across all selected users.
rec.sort(rec.prediction.desc()).show()

+-------+------+----------+
|movieId|userId|prediction|
+-------+------+----------+
|    219|   943|  6.758924|
|    763|   943|  5.272481|
|     68|   943|  5.150526|
|    541|   943|   5.10025|
|    188|   943| 5.0391307|
|    498|   942| 4.9882317|
|    210|   942| 4.9415627|
|     11|   943|  4.926779|
|     56|   943| 4.8678093|
|    272|   942| 4.8562126|
|    373|   943|  4.786627|
|    201|   943|  4.746359|
|     12|   943| 4.7364955|
|    357|   942| 4.6934023|
|      7|   941|  4.654271|
|    705|   942| 4.6477613|
|    117|   942| 4.6376038|
|    945|   942|  4.546354|
|     95|   942| 4.5394583|
|    187|   943|  4.535033|
+-------+------+----------+
only showing top 20 rows

