In [None]:
from pyspark.sql import (
    functions as f,
    SparkSession,
    types as t
)

# https://www.kaggle.com/code/varian97/item-based-collaborative-filtering/notebook

spark = SparkSession.builder.appName("item_base_cf").getOrCreate()

# Load the raw user/anime ratings; only three columns are used downstream.
rating_df = spark.read.csv(
    "file:///home/jovyan/work/sample/rating.csv",
    header=True,
    inferSchema=True,
)
ratings = rating_df.select("user_id", "anime_id", "rating")
# NOTE(review): if `rating` uses -1 as a "watched but not rated" sentinel
# (as in the Kaggle anime dataset linked above), those rows probably need to
# be filtered out before pairing -- confirm against the data.

# Self-join on user_id so that every row is one user who rated two anime.
# The strict `<` on anime_id keeps each unordered pair exactly once and
# excludes pairing an anime with itself.
anime_pairs = (
    ratings.alias("rating_a")
    .join(
        ratings.alias("rating_b"),
        (f.col("rating_a.user_id") == f.col("rating_b.user_id"))
        & (f.col("rating_a.anime_id") < f.col("rating_b.anime_id")),
    )
    .select(
        f.col("rating_a.anime_id").alias("anime_a"),
        f.col("rating_b.anime_id").alias("anime_b"),
        f.col("rating_a.rating").alias("rating_a"),
        f.col("rating_b.rating").alias("rating_b"),
    )
)

def get_similarity(spark, anime_pairs):
    """Compute the cosine similarity for every (anime_a, anime_b) pair.

    Args:
        spark: Active SparkSession. Not used by the computation; kept so
            existing callers continue to work.
        anime_pairs: DataFrame with columns ``anime_a``, ``anime_b``,
            ``rating_a`` and ``rating_b``. Each row is one co-occurrence of
            the pair (presumably one user who rated both anime).

    Returns:
        DataFrame with columns ``anime_a``, ``anime_b``, ``score`` and
        ``pair_count``. ``score`` is
        ``sum(a*b) / (sqrt(sum(a^2)) * sqrt(sum(b^2)))`` over the rows where
        both ratings are present, or 0 when that denominator is 0.
        ``pair_count`` is the number of such rows.
    """
    # Only rows with both ratings present describe a co-rating. A row with a
    # null on one side would still add to the other side's norm
    # (rating_aa / rating_bb) while adding nothing to the dot product
    # (rating_ab), which makes numerator and denominator inconsistent and
    # drags the score toward 0.
    co_rated = anime_pairs.dropna(subset=["rating_a", "rating_b"])

    pair_score = (
        co_rated
        .withColumn("rating_aa", f.col("rating_a") * f.col("rating_a"))
        .withColumn("rating_bb", f.col("rating_b") * f.col("rating_b"))
        .withColumn("rating_ab", f.col("rating_a") * f.col("rating_b"))
    )

    calc_similarity = pair_score.groupBy("anime_a", "anime_b").agg(
        f.sum(f.col("rating_ab")).alias("numerator"),
        (
            f.sqrt(f.sum(f.col("rating_aa")))
            * f.sqrt(f.sum(f.col("rating_bb")))
        ).alias("denominator"),
        f.count(f.col("rating_ab")).alias("pair_count"),
    )

    # The denominator is 0 when every rating on one side of the pair is 0;
    # report a similarity of 0 instead of dividing by zero.
    return calc_similarity.withColumn(
        "score",
        f.when(
            f.col("denominator") != 0,
            f.col("numerator") / f.col("denominator"),
        ).otherwise(0),
    ).select("anime_a", "anime_b", "score", "pair_count")

# Score every co-rated anime pair and print a preview of the result.
anime_pair_similarity = get_similarity(spark=spark, anime_pairs=anime_pairs)
anime_pair_similarity.show()


