In [1]:
# Restrict Spark logging to ERROR-level messages (presumably to keep cell
# output uncluttered). `sc` is not defined in this notebook -- it is assumed
# to be the SparkContext pre-created by the PySpark kernel.
sc.setLogLevel("ERROR")

In [2]:
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator

In [3]:
# Load the train and test tables from their parquet files on GCS.
# `spark` is not defined in this notebook -- it is assumed to be the
# SparkSession pre-created by the PySpark kernel.
train_path = "gs://st446-final-zqh/train_100.parquet"
test_path = "gs://st446-final-zqh/test_100.parquet"

hi_train_100 = spark.read.parquet(train_path)
hi_test_100 = spark.read.parquet(test_path)

In [4]:
# Preview the first 5 rows of the training data.
hi_train_100.show(5)

+--------------------+--------------------+------+-----------------+--------------+-------------+-----------+--------+
|            business|                user|rating|       avg_rating|num_of_reviews|rating_binary|business_id| user_id|
+--------------------+--------------------+------+-----------------+--------------+-------------+-----------+--------+
|0x0:0x9edcb14b0cf...|10000821970394167...|   5.0|4.199999809265137|           278|            1|      980.0| 21195.0|
|0x0:0x9edcb14b0cf...|10001835724049761...|   5.0|4.199999809265137|           278|            1|      980.0|176419.0|
|0x0:0x9edcb14b0cf...|10016550797840624...|   5.0|4.199999809265137|           278|            1|      980.0|179121.0|
|0x0:0x9edcb14b0cf...|10017741322643727...|   5.0|4.199999809265137|           278|            1|      980.0|179337.0|
|0x0:0x9edcb14b0cf...|10020832873317346...|   1.0|4.199999809265137|           278|            0|      980.0|179910.0|
+--------------------+--------------------+-----

In [5]:
# Number of rows in the training data.
hi_train_100.count()

1121988

In [6]:
# Number of rows in the test data.
hi_test_100.count()

124829

In [7]:
def grid_search(train_data, test_data, ranks=(60,), reg_params=(0.1,), max_iters=(20,), seed=None):
    """Fit one ALS model per hyper-parameter combination and keep the lowest-RMSE one.

    Every combination of ``ranks`` x ``reg_params`` x ``max_iters`` is fitted on
    ``train_data`` and scored by RMSE against the ``rating_binary`` column of
    ``test_data``. One progress line is printed per combination.

    Note that the winning combination is chosen using ``test_data`` itself, so
    the returned RMSE is the score the selection was optimised for. Pass a
    separate validation split here if an unbiased test score is needed.

    Args:
        train_data: DataFrame used to fit each model. It must contain the
            ``user_id``, ``business_id`` and ``rating_binary`` columns.
        test_data: DataFrame with the same columns, used to score each model.
        ranks: Iterable of ALS ``rank`` values to try. Defaults to ``(60,)``.
        reg_params: Iterable of ALS ``regParam`` values to try. Defaults to
            ``(0.1,)``.
        max_iters: Iterable of ALS ``maxIter`` values to try. Defaults to
            ``(20,)``.
        seed: Optional ALS seed for reproducible factor initialisation. When
            ``None`` (the default) no seed is passed and ALS uses its own
            default.

    Returns:
        A tuple ``(best_model, best_rmse, best_predictions, best_rank,
        best_regParam, best_maxIter)``. If no combination produces an RMSE
        that compares lower than infinity (for example every RMSE is NaN),
        the model and predictions are ``None``, the RMSE is ``inf`` and the
        three parameters are ``0``.

    Raises:
        ValueError: If any of the three grids is empty.
    """
    # Materialise the grids once so that one-shot iterables (generators) can be
    # re-iterated safely by the nested loops below.
    ranks = list(ranks)
    reg_params = list(reg_params)
    max_iters = list(max_iters)
    if not (ranks and reg_params and max_iters):
        raise ValueError("ranks, reg_params and max_iters must each contain at least one value")

    # Settings shared by every candidate model.
    # coldStartStrategy="drop" removes test rows whose prediction is NaN
    # (users/items unseen during fitting) so they do not turn the RMSE into NaN.
    als_kwargs = dict(
        userCol="user_id",
        itemCol="business_id",
        ratingCol="rating_binary",
        coldStartStrategy="drop",
    )
    # Only forward the seed when one was requested, so the default call builds
    # ALS with exactly the same arguments as before.
    if seed is not None:
        als_kwargs["seed"] = seed

    # The evaluator does not depend on the hyper-parameters: build it once.
    evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating_binary", predictionCol="prediction")

    best_rmse = float("inf")
    best_model = None
    best_predictions = None
    best_rank = 0
    best_regParam = 0
    best_maxIter = 0
    for rank in ranks:
        for regParam in reg_params:
            for maxIter in max_iters:
                als = ALS(rank=rank, maxIter=maxIter, regParam=regParam, **als_kwargs)
                model = als.fit(train_data)
                predictions = model.transform(test_data)
                rmse = evaluator.evaluate(predictions)
                # Strict "<" keeps the first combination when several tie.
                if rmse < best_rmse:
                    best_rmse = rmse
                    best_model = model
                    best_predictions = predictions
                    best_rank = rank
                    best_regParam = regParam
                    best_maxIter = maxIter
                print(f"rank = {rank}, regParam = {regParam}, maxIter = {maxIter}, RMSE = {rmse}")
    return best_model,best_rmse,best_predictions,best_rank,best_regParam,best_maxIter

In [None]:
# Run the search on the train/test tables, then report the selected
# hyper-parameters and the RMSE they achieved.
(
    best_model,
    best_rmse,
    best_predictions,
    best_rank,
    best_regParam,
    best_maxIter,
) = grid_search(hi_train_100, hi_test_100)

for label, value in (
    ("The best rank is: ", best_rank),
    ("The best regParam is: ", best_regParam),
    ("The best maxIter is: ", best_maxIter),
    ("The best rmse is: ", best_rmse),
):
    print(label, value)

rank = 60, regParam = 0.1, maxIter = 20, RMSE = 0.4645232982027052
The best rank is:  60
The best regParam is:  0.1
The best maxIter is:  20
The best rmse is:  0.4645232982027052


In [None]:
# Preview 5 rows of the selected model's output: the test columns plus the
# `prediction` column added by the model.
best_predictions.show(5)

+--------------------+--------------------+------+-----------------+--------------+-------------+-----------+--------+----------+
|            business|                user|rating|       avg_rating|num_of_reviews|rating_binary|business_id| user_id|prediction|
+--------------------+--------------------+------+-----------------+--------------+-------------+-----------+--------+----------+
|0x7c006af9104d777...|11521214871484946...|   5.0|4.900000095367432|           811|            1|      148.0| 65867.0|0.81767297|
|0x7c006af9104d777...|11204325608548318...|   5.0|4.900000095367432|           811|            1|      148.0|150164.0| 0.6109169|
|0x7c006af9104d777...|11822647253417566...|   5.0|4.900000095367432|           811|            1|      148.0|101202.0|0.72297376|
|0x7c006af9104d777...|10587568845938533...|   5.0|4.900000095367432|           811|            1|      148.0|125107.0| 0.6094526|
|0x7c006af9104d777...|11086027296978629...|   5.0|4.900000095367432|           811|       

In [None]:
# Persist the selected model's predictions to GCS as parquet.
# mode("overwrite") keeps this cell re-runnable: with the default save mode the
# write raises as soon as the target path already exists, which breaks any
# second run of the notebook. The output is derived data that this cell
# regenerates, so replacing a previous copy is the intended behaviour.
best_predictions.write.format("parquet").mode("overwrite").save("gs://st446-final-zqh/best_predictions_100.parquet")

In [None]:
# Completion marker, printed when this final cell executes.
print("done")

done
