In [64]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql import functions as F
from pyspark.sql.window import Window

In [None]:
spark = SparkSession.builder \
    .appName("MovieDataFrame") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

In [66]:
df = spark \
    .read \
    .option('header', 'true') \
    .csv('movie/ratings.csv', inferSchema=True)

df = df.select(['userId', 'movieId', 'rating'])

In [67]:
df = df.limit(124289)

In [68]:
normalize = True
if normalize:
        avg = df.groupby('movieId').mean('rating').alias('avg_rating')
        df = df.join(avg, on='movieId') \
                .withColumn('rating', col('rating') - col('avg(rating)')) \
                .select('userId', 'movieId', 'rating')

In [69]:
MIN_RATINGS = 10
movie_counts = df.groupBy("movieId").agg(F.count("rating").alias("count"))
df_filtered = df.join(
    F.broadcast(movie_counts), "movieId"
).filter(F.col("count") >= MIN_RATINGS)

In [70]:
df_filtered = df_filtered.select("userId", "movieId", "rating")

In [71]:
movie_norms = df_filtered.groupBy("movieId").agg(
    F.sum(F.col("rating") * F.col("rating")).alias("sq_sum")
).withColumn("norm", F.sqrt(F.col("sq_sum")))

In [72]:
pairs = df_filtered.alias("a").join(
    df_filtered.alias("b"),
    (F.col("a.userId") == F.col("b.userId")) & 
    (F.col("a.movieId") < F.col("b.movieId"))
).select(
    F.col("a.movieId").alias("movie1"),
    F.col("b.movieId").alias("movie2"),
    (F.col("a.rating") * F.col("b.rating")).alias("product")
)

In [73]:
dot_products = pairs.groupBy("movie1", "movie2").agg(
    F.sum("product").alias("dot_product"),
    F.count("product").alias("common_users")
).filter(F.col("common_users") >= 5)

In [74]:
similarities = dot_products.join(
    movie_norms.alias("n1"),
    F.col("movie1") == F.col("n1.movieId")
).join(
    movie_norms.alias("n2"),
    F.col("movie2") == F.col("n2.movieId")
).select(
    "movie1",
    "movie2",
    (F.col("dot_product") / (F.col("n1.norm") * F.col("n2.norm"))).alias("similarity"),
    "common_users"
)

In [75]:
k = 20
window_spec = Window \
    .partitionBy("userId", "movieId") \
    .orderBy(F.desc("similarity"))

In [76]:
user_rated_movies = df.select(
    F.col("userId").alias("check_userId"),
    F.col("movieId").alias("check_movieId"),
    F.col("rating").alias("check_rating")
).distinct()

predictions = (
    df.alias("rated")
    .join(
        similarities.alias("sim"),
        F.col("rated.movieId") == F.col("sim.movie1")
    )
    .join(
        user_rated_movies.alias("check"),
        (F.col("rated.userId") == F.col("check.check_userId")) &
        (F.col("sim.movie2") == F.col("check.check_movieId"))
    )
    .select(
        F.col("rated.userId").alias("userId"),
        F.col("rated.movieId").alias("movieId"),
        F.col("rated.rating").alias("rating1"),
        F.col("sim.movie2").alias("movie2"),
        F.col("sim.similarity").alias("similarity"),
        F.col("check.check_rating").alias("rating2")
    )
    .withColumn("rank",F.row_number().over(window_spec))
    .filter(F.col("rank") <= 20)
    .drop("rank")
)

In [77]:
predictions = predictions.groupBy("userId", "movieId").agg(
    F.first("rating1").alias("actual_rating"),
    F.sum(F.col("similarity") * F.col("rating2")).alias("numerator"),
    F.sum("similarity").alias("denominator")
)

In [78]:
predictions = predictions.withColumn(
    "predicted_rating",
    F.when(F.col("denominator") != 0, 
          F.col("numerator") / F.col("denominator"))
    .otherwise(None)
).select(
    "userId",
    "movieId",
    "actual_rating",
    "predicted_rating",
    "numerator",
    "denominator"
)

In [None]:
predictions.show() # with normalization

+------+-------+--------------------+--------------------+--------------------+--------------------+
|userId|movieId|       actual_rating|    predicted_rating|           numerator|         denominator|
+------+-------+--------------------+--------------------+--------------------+--------------------+
|     1|    110| -2.9682539682539684| 0.06525509445368251|0.055278673486096115|  0.8471165960127823|
|     1|    147|   1.074074074074074|  0.6963676266263935| 0.09589854742944068|  0.1377125296505074|
|     1|    858|  0.5391459074733094|  0.6620630863667751|  0.9359363469756963|  1.4136664107220722|
|     1|   1221|  0.7722513089005236|   0.561825507154906|  0.5586766366468569|  0.9943952873838088|
|     1|   1246|  1.1691729323308269|  0.6696351763210588|  0.9220014855626385|  1.3768713445253387|
|     1|   1968| 0.17279411764705888|  0.5769007613879342|   0.645035871978795|  1.1181054267062123|
|     1|   2762|                 0.5|  0.7948736709821944|  0.6818162154655183|  0.85776676

In [None]:
predictions.show() # without normalization

+------+-------+-------------+------------------+------------------+-------------------+
|userId|movieId|actual_rating|  predicted_rating|         numerator|        denominator|
+------+-------+-------------+------------------+------------------+-------------------+
|     1|    110|          1.0| 4.419925124802842|26.964949593217078|  6.100770676385586|
|     1|    147|          4.5| 4.663748348382494|3.5654541950625065| 0.7645039845040301|
|     1|    858|          5.0| 4.413742066493125| 29.54962338025821|  6.694913960782588|
|     1|   1221|          5.0| 4.287753968038814|23.049364672009936|  5.375626690295512|
|     1|   1246|          5.0|4.2883038330971806|22.225976007628546|  5.182929398819224|
|     1|   1968|          4.0| 4.360929025959088|19.916061866425174|  4.566930979126652|
|     1|   2762|          4.5| 4.293049071076338| 24.46407064645143|  5.698530401451452|
|     1|   2918|          5.0| 4.278330366188078| 17.59239899193086|  4.111977684323944|
|     1|   2959|     

In [60]:
valid_predictions = predictions.filter(F.col("predicted_rating").isNotNull())

In [61]:
metrics = valid_predictions.agg(
    F.count("*").alias("count"),
    F.avg(F.abs(F.col("actual_rating") - F.col("predicted_rating"))).alias("mae"),
    F.sqrt(F.avg(F.pow(F.col("actual_rating") - F.col("predicted_rating"), 2))).alias("rmse")
)

In [None]:
metrics.show() # with normalization

+------+------------------+-----------------+
| count|               mae|             rmse|
+------+------------------+-----------------+
|101707|0.7203858483575515|19.53000856692622|
+------+------------------+-----------------+



In [None]:
metrics.show() # without normalization

+------+------------------+------------------+
| count|               mae|              rmse|
+------+------------------+------------------+
|103837|0.6746410067500092|0.8909160830954588|
+------+------------------+------------------+

