In [1]:
import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, collect_list, when, desc, rank, mean
from pyspark.sql.window import Window
from pyspark.ml.recommendation import ALS
from pyspark.mllib.evaluation import RankingMetrics

In [2]:
def _rank_items(scored_items, limit):
    """Return item ids ordered by descending score, truncated to ``limit``.

    Args:
        scored_items: iterable of ``(score, item_id)`` pairs.
        limit: maximum number of item ids to return.

    Returns:
        A list of item ids, best score first. Ties on score are broken by
        ascending item id so the result is deterministic.
    """
    ordered = sorted(scored_items, key=lambda pair: (-pair[0], pair[1]))
    return [item_id for _, item_id in ordered[:limit]]


def main(spark, userID, data_dir='/scratch/sjm643/sp24_bigd/rec',
         reg_params=(.01, .05, .1, .15), ranks=(10, 50, 100, 150), top_k=100):
    """Grid-search ALS hyperparameters and report the best validation MAP.

    For every (regParam, rank) combination an ALS model is fit on the training
    split, the validation rows are scored, and each user's highest-scored
    validation movies are compared against the movies that user rated above
    their own validation mean, using RankingMetrics.meanAveragePrecision.

    Args:
        spark: active SparkSession used to read the parquet splits.
        userID: not used by this function; kept so existing callers keep working.
        data_dir: directory holding ``train.parquet`` and ``val.parquet``.
            Pass '/scratch/sjm643/sp24_bigd/rec_small' for the small dataset.
        reg_params: iterable of ALS ``regParam`` values to try.
        ranks: iterable of ALS ``rank`` (latent factor count) values to try.
        top_k: maximum length of each user's ranked prediction list.

    Returns:
        None. Per-combination and best results are printed.

    NOTE(review): ``model.transform(val)`` only scores (user, movie) pairs that
    already appear in ``val``, so the ranking is over each user's validation
    movies rather than over the full catalogue. Confirm this is the intended
    evaluation protocol.
    """
    train = spark.read.parquet(f'{data_dir}/train.parquet')
    val = spark.read.parquet(f'{data_dir}/val.parquet')

    # Ground truth: per user, the validation movies rated strictly above that
    # user's mean validation rating. It does not depend on the hyperparameters,
    # so it is built once and cached instead of being recomputed for every fit.
    mean_ratings_per_user = val.groupBy('userId').agg(mean("rating").alias("mean_rating"))
    movies_with_mean = val.join(mean_ratings_per_user, 'userId', 'inner')
    movies_above_mean = movies_with_mean.filter(col('rating') > col('mean_rating'))
    movies_above_mean_rdd = movies_above_mean.rdd.map(
        lambda row: (row['userId'], row['movieId'])).groupByKey().mapValues(list).cache()

    # The window definition is also loop-invariant.
    window_spec = Window.partitionBy('userId').orderBy(desc('prediction'))

    # None sentinels so the summary cannot hit an unbound local when no
    # combination is evaluated or every MAP is 0.
    best_MAP = None
    best_reg = None
    best_rank = None

    try:
        for reg in reg_params:
            for rk in ranks:

                als = ALS(rank=rk, maxIter=10, regParam=reg,
                          userCol='userId', itemCol='movieId', ratingCol='rating',
                          coldStartStrategy='drop')

                model = als.fit(train)
                predictions = model.transform(val)

                # Pre-filter in Spark SQL to keep the shuffle small. rank() gives
                # tied predictions the same rank, so this can let through slightly
                # more than top_k rows; _rank_items enforces the exact cap.
                ranked_pred = predictions.withColumn('rank', rank().over(window_spec))
                top_k_per_user = ranked_pred.filter(col('rank') <= top_k)

                # groupByKey does not guarantee the order of values within a group,
                # and MAP depends on that order. Carry the score along and sort
                # explicitly after grouping.
                top_k_per_user_rdd = top_k_per_user.rdd.map(
                    lambda row: (row['userId'], (row['prediction'], row['movieId']))
                ).groupByKey().mapValues(lambda scored: _rank_items(scored, top_k))

                # Inner join: users missing from either side (e.g. no above-mean
                # movie, or dropped by coldStartStrategy) are not evaluated.
                # The RDD is handed to RankingMetrics directly; there is no need
                # to collect it to the driver and parallelize it again.
                preds_and_labels = top_k_per_user_rdd.join(movies_above_mean_rdd).values()

                metrics = RankingMetrics(preds_and_labels)

                MAP = metrics.meanAveragePrecision
                print(f'hyperparameters: reg={reg}, rank={rk}, MAP={MAP}')
                if best_MAP is None or MAP > best_MAP:
                    best_MAP = MAP
                    best_reg = reg
                    best_rank = rk
    finally:
        # Release the cached ground truth even if a fit fails part-way.
        movies_above_mean_rdd.unpersist()

    if best_MAP is None:
        print('no hyperparameter combinations were evaluated')
        return

    print(f'best hyperparameters: reg={best_reg}, rank={best_rank}')
    print(f'highest MAP = {best_MAP}')

In [3]:
if __name__ == "__main__":

    # Session settings, applied to the builder in the order listed:
    # executor/driver heap sizes, Kryo serialization, a longer broadcast
    # timeout, and the G1 garbage collector on both driver and executors.
    session_conf = {
        "spark.executor.memory": "16g",
        "spark.driver.memory": "16g",
        "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
        "spark.sql.broadcastTimeout": "7200",
        "spark.driver.extraJavaOptions": "-XX:+UseG1GC",
        "spark.executor.extraJavaOptions": "-XX:+UseG1GC",
    }

    builder = SparkSession.builder.appName("Spark Application")
    for key, value in session_conf.items():
        builder = builder.config(key, value)

    # Reuses an already-running session if one exists, otherwise creates one.
    spark = builder.getOrCreate()

    # Read from the USER environment variable; raises KeyError if it is unset.
    userID = os.environ['USER']

    main(spark, userID)

24/05/11 23:09:03 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/11 23:09:19 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
24/05/11 23:09:19 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
24/05/11 23:09:20 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeSystemLAPACK
24/05/11 23:09:20 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeRefLAPACK
                                                                                

hyperparameters: reg=0.01, rank=10, MAP=0.7147487979969479


                                                                                

hyperparameters: reg=0.01, rank=50, MAP=0.7053154561101421


                                                                                

hyperparameters: reg=0.01, rank=100, MAP=0.7088555511471046


                                                                                

hyperparameters: reg=0.01, rank=150, MAP=0.7129826663978723


                                                                                

hyperparameters: reg=0.05, rank=10, MAP=0.7356838368145181


                                                                                

hyperparameters: reg=0.05, rank=50, MAP=0.7460291996017364


                                                                                

hyperparameters: reg=0.05, rank=100, MAP=0.7500132912595657


                                                                                

hyperparameters: reg=0.05, rank=150, MAP=0.7513355097652926


                                                                                

hyperparameters: reg=0.1, rank=10, MAP=0.7407226551956696


                                                                                

hyperparameters: reg=0.1, rank=50, MAP=0.7458735936417


                                                                                

hyperparameters: reg=0.1, rank=100, MAP=0.7461316559023654


                                                                                ]

hyperparameters: reg=0.1, rank=150, MAP=0.7463411538469942


                                                                                

hyperparameters: reg=0.15, rank=10, MAP=0.7311171033319837


                                                                                ]

hyperparameters: reg=0.15, rank=50, MAP=0.7319180080857891


                                                                                ]

hyperparameters: reg=0.15, rank=100, MAP=0.7320140046446268


                                                                                

hyperparameters: reg=0.15, rank=150, MAP=0.7320569481930816
best hyperparameters: reg=0.05, rank=150
highest MAP = 0.7513355097652926
