In [1]:
# Restrict Spark's log output to ERROR-level messages so cell output stays readable.
# `sc` is not created in this notebook; presumably the SparkContext injected by the kernel — TODO confirm.
sc.setLogLevel("ERROR")

In [2]:
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import StringIndexer

In [3]:
# read parquet files
# Loads the train and test Parquet datasets from the GCS bucket into Spark DataFrames.
# `spark` is not created in this notebook; presumably the SparkSession injected by the kernel — TODO confirm.
ca_train_10 = spark.read.parquet("gs://st446-final-zqh/CA/train_10.parquet")
ca_test_10 = spark.read.parquet("gs://st446-final-zqh/CA/test_10.parquet")

In [4]:
# Preview five rows of the training DataFrame to inspect its columns.
ca_train_10.show(5)

+--------------------+--------------------+------+----------+--------------+-----------+-------+-----------------+
|            business|                user|rating|avg_rating|num_of_reviews|business_id|user_id|__index_level_0__|
+--------------------+--------------------+------+----------+--------------+-----------+-------+-----------------+
|0x809ad0db45f1a8c...|11545947020589421...|   5.0| 3.6842105|            19|     141123|2856849|          1551225|
|0x80857cd919d7414...|11613779044853393...|   4.0|       4.5|             2|      31946|2982347|          2322900|
|0x80950dccbcf816d...|10790460684728616...|   5.0|  4.793103|            29|     128648|1462057|          2329232|
|0x80dcaa437800fe9...|11075637764241779...|   4.0| 4.1864405|            59|     324558|1987661|          5210014|
|0x80dcb13252a4122...|10472812636391648...|   5.0| 4.6666665|            12|     328658| 874058|          1026907|
+--------------------+--------------------+------+----------+--------------+----

In [5]:
# Number of rows in the training DataFrame (count() is an action, so this runs a Spark job).
ca_train_10.count()

6316254

In [6]:
# NOTE(review): this cell repeats the training-row count from the previous cell and
# triggers the same Spark job again; it looks safe to delete — confirm.
ca_train_10.count()

6316254

In [10]:
from pyspark.sql.functions import when, col

# Attach the binary target to both splits with one shared rule:
# rating_binary is 1 when the rating is strictly greater than 4, otherwise 0.
ca_train_10, ca_test_10 = (
    split.withColumn("rating_binary", when(col("rating") > 4, 1).otherwise(0))
    for split in (ca_train_10, ca_test_10)
)

In [11]:
# Preview five rows again to check the rating_binary column next to the original rating.
ca_train_10.show(5)

+--------------------+--------------------+------+----------+--------------+-----------+-------+-----------------+-------------+
|            business|                user|rating|avg_rating|num_of_reviews|business_id|user_id|__index_level_0__|rating_binary|
+--------------------+--------------------+------+----------+--------------+-----------+-------+-----------------+-------------+
|0x809ad0db45f1a8c...|11545947020589421...|   5.0| 3.6842105|            19|     141123|2856849|          1551225|            1|
|0x80857cd919d7414...|11613779044853393...|   4.0|       4.5|             2|      31946|2982347|          2322900|            0|
|0x80950dccbcf816d...|10790460684728616...|   5.0|  4.793103|            29|     128648|1462057|          2329232|            1|
|0x80dcaa437800fe9...|11075637764241779...|   4.0| 4.1864405|            59|     324558|1987661|          5210014|            0|
|0x80dcb13252a4122...|10472812636391648...|   5.0| 4.6666665|            12|     328658| 874058| 

In [12]:
def grid_search(train_data, test_data, ranks=(60,), reg_params=(0.1,), max_iters=(20,)):
    """Fit one ALS model per hyperparameter combination and keep the lowest-RMSE one.

    Every candidate is trained on ``train_data`` using ``user_id`` as the user
    column, ``business_id`` as the item column and ``rating_binary`` as the
    target, then scored by RMSE of ``prediction`` against ``rating_binary`` on
    ``test_data``.  ALS is configured with ``coldStartStrategy="drop"``.

    NOTE(review): the winning configuration is chosen on ``test_data`` itself,
    so the returned RMSE is an optimistic estimate once more than one
    combination is searched; consider passing a separate validation split.

    Args:
        train_data: DataFrame used to fit each ALS model.
        test_data: DataFrame each fitted model is evaluated on.
        ranks: iterable of ``rank`` values to try (default: the single value 60).
        reg_params: iterable of ``regParam`` values to try (default: 0.1).
        max_iters: iterable of ``maxIter`` values to try (default: 20).

    Returns:
        Tuple ``(best_model, best_rmse, best_predictions, best_rank,
        best_regParam, best_maxIter)``.  If no combination is evaluated (an
        empty grid) or no RMSE compares below infinity, the tuple is
        ``(None, inf, None, 0, 0, 0)``.
    """
    # Materialise the inner grids: they are iterated once per outer value, so a
    # one-shot iterator would be exhausted after the first pass.
    reg_params = list(reg_params)
    max_iters = list(max_iters)

    # The evaluator does not depend on the hyperparameters, so build it once.
    evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating_binary", predictionCol="prediction")

    best_rmse = float("inf")
    best_model = None
    best_predictions = None
    best_rank = 0
    best_regParam = 0
    best_maxIter = 0
    for rank in ranks:
        for regParam in reg_params:
            for maxIter in max_iters:
                als = ALS(rank=rank, maxIter=maxIter, regParam=regParam, userCol="user_id", itemCol="business_id", ratingCol="rating_binary", coldStartStrategy="drop")
                model = als.fit(train_data)
                predictions = model.transform(test_data)
                rmse = evaluator.evaluate(predictions)
                # Strict '<' keeps the first-seen configuration on ties.
                if rmse < best_rmse:
                    best_rmse = rmse
                    best_model = model
                    best_predictions = predictions
                    best_rank = rank
                    best_regParam = regParam
                    best_maxIter = maxIter
                print(f"rank = {rank}, regParam = {regParam}, maxIter = {maxIter}, RMSE = {rmse}")
    return best_model,best_rmse,best_predictions,best_rank,best_regParam,best_maxIter


In [None]:
# Run the search on the train/test DataFrames and report the selected hyperparameters and RMSE.
# grid_search also prints one line per evaluated combination before these summary lines.
best_model,best_rmse,best_predictions,best_rank,best_regParam,best_maxIter = grid_search(ca_train_10, ca_test_10)
print("The best rank is: ", best_rank)
print("The best regParam is: ", best_regParam)
print("The best maxIter is: ", best_maxIter)
print("The best rmse is: ", best_rmse)

rank = 60, regParam = 0.1, maxIter = 20, RMSE = 0.5370825174187231
The best rank is:  60
The best regParam is:  0.1
The best maxIter is:  20
The best rmse is:  0.5370825174187231


In [16]:
# Preview five rows of the predictions DataFrame returned by grid_search.
best_predictions.show(5)

+--------------------+--------------------+------+----------+--------------+-----------+-------+-----------------+-------------+------------+
|            business|                user|rating|avg_rating|num_of_reviews|business_id|user_id|__index_level_0__|rating_binary|  prediction|
+--------------------+--------------------+------+----------+--------------+-----------+-------+-----------------+-------------+------------+
|0x4cb41c74dc86650...|11005594223594790...|   5.0| 4.1641793|            67|        496|1858836|          4052680|            1|         0.0|
|0x4cb41c74dc86650...|11201278429100995...|   4.0| 4.1641793|            67|        496|2220093|          4052657|            0|         0.0|
|0x4cb41c74dc86650...|10200679145240549...|   4.0| 4.1641793|            67|        496| 370422|          4052699|            0|         0.0|
|0x4cb41c74dc86650...|10109045912522730...|   5.0| 4.1641793|            67|        496| 201320|          4052698|            1|   0.6440561|
|0x54c

In [17]:
# define a function to evaluate the model
def evaluate_model(predictions):
    """Compute binary-classification metrics from a predictions DataFrame.

    Args:
        predictions: DataFrame with a ``prediction_binary`` column (predicted
            class) and a ``rating_binary`` column (true class).  Only rows
            where both columns are 0 or 1 contribute; any other combination
            (including nulls) is ignored.

    Returns:
        Tuple ``(accuracy, precision, recall, f1_score, specificity)``.  Any
        metric whose denominator is zero is reported as 0, so an empty
        DataFrame yields all zeros instead of raising ZeroDivisionError.
    """
    # A single grouped count yields every confusion-matrix cell in one pass
    # over the data, rather than one filter+count job per cell.
    cell_counts = {
        (predicted, actual): n
        for predicted, actual, n in predictions.groupBy('prediction_binary', 'rating_binary').count().collect()
    }

    # calculate TP, FP, TN, FN (a cell with no rows is absent from the result)
    TP = cell_counts.get((1, 1), 0)
    FP = cell_counts.get((1, 0), 0)
    TN = cell_counts.get((0, 0), 0)
    FN = cell_counts.get((0, 1), 0)

    # calculate Precision, Recall, F1-Score
    precision = TP / (TP + FP) if (TP + FP) else 0
    recall = TP / (TP + FN) if (TP + FN) else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) else 0

    # calculate accuracy (guarded like the other metrics)
    total = TP + TN + FP + FN
    accuracy = (TP + TN) / total if total else 0

    # calculate Specificity
    specificity = TN / (TN + FP) if (TN + FP) else 0

    return accuracy, precision, recall, f1_score, specificity

In [18]:
# Turn the continuous ALS score into a class label, then report the classification metrics.
# NOTE(review): the cut-off is 0, so any strictly positive score is labelled 1. With a 0/1
# target a mid-range cut-off (e.g. 0.5) may have been intended — confirm before relying on
# the precision/recall/specificity figures below.
best_predictions = best_predictions.withColumn('prediction_binary', when(col('prediction') > 0, 1).otherwise(0))
accuracy, precision, recall, f1_score, specificity = evaluate_model(best_predictions)
print("The accuracy of the ALS model is:", accuracy)
print("The precision of the ALS model is:", precision)
print("The recall of the ALS model is:", recall)
print("The f1_score of the ALS model is:", f1_score)
print("The specificity of the ALS model is:", specificity)

The accuracy of the ALS model is: 0.6325648004175649
The precision of the ALS model is: 0.6408726300113754
The recall of the ALS model is: 0.8802199772412103
The f1_score of the ALS model is: 0.7417153815795965
The specificity of the ALS model is: 0.2620486544483129


In [19]:
# normalize the 'prediction' into 0-1

# Min-max scaling. Both extremes come from a single aggregation, so the
# predictions are computed and scanned once rather than once per statistic.
min_prediction, max_prediction = best_predictions.selectExpr(
    "min(prediction)", "max(prediction)"
).first()

# The scaling is undefined when there are no non-null predictions (both values
# are None) or when every prediction is identical (zero range); fail with a
# clear message instead of a bare TypeError / ZeroDivisionError.
if min_prediction is None or max_prediction == min_prediction:
    raise ValueError(
        "Cannot min-max scale 'prediction': "
        f"min={min_prediction!r}, max={max_prediction!r}"
    )

# NOTE: 'probability' is only the score rescaled to [0, 1]; it is not a calibrated probability.
best_predictions = best_predictions.withColumn("probability", (best_predictions["prediction"] - min_prediction) / (max_prediction - min_prediction))

best_predictions.show(5)

+--------------------+--------------------+------+----------+--------------+-----------+-------+-----------------+-------------+------------+-----------------+-------------------+
|            business|                user|rating|avg_rating|num_of_reviews|business_id|user_id|__index_level_0__|rating_binary|  prediction|prediction_binary|        probability|
+--------------------+--------------------+------+----------+--------------+-----------+-------+-----------------+-------------+------------+-----------------+-------------------+
|0x4cb41c74dc86650...|11005594223594790...|   5.0| 4.1641793|            67|        496|1858836|          4052680|            1|         0.0|                0|0.29390214307067797|
|0x4cb41c74dc86650...|11201278429100995...|   4.0| 4.1641793|            67|        496|2220093|          4052657|            0|         0.0|                0|0.29390214307067797|
|0x4cb41c74dc86650...|10200679145240549...|   4.0| 4.1641793|            67|        496| 370422|    

In [None]:
# Persist the predictions (including the prediction_binary and probability columns) as Parquet on GCS.
# NOTE(review): no save mode is set, so the writer's default applies; per the Spark docs that
# default errors if the path already exists, which would make a re-run of this notebook fail
# here — confirm whether .mode("overwrite") is wanted.
best_predictions.write.format("parquet").save("gs://st446-final-zqh/best_predictions_ca_10.parquet")

In [None]:
# Completion marker: appears only after every preceding cell (including the write) has finished.
print("done")

done
