In [1]:
# Import necessary libraries
from pyspark.sql import SparkSession
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Create a Spark session
spark = SparkSession.builder \
    .appName("Item-Item Recommender System") \
    .getOrCreate()

# Load the data as a Spark DataFrame
csv_file_path = "data/ratings_small.csv"
data = spark.read.csv(csv_file_path, header=True, inferSchema=True)

# Split the data into a training set and a testing set
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)

# Define the ALS algorithm for collaborative filtering
als = ALS(userCol="userId", itemCol="movieId", ratingCol="rating", coldStartStrategy="drop")

# Set up the hyperparameter grid
param_grid = ParamGridBuilder() \
    .addGrid(als.maxIter, [5, 10, 15]) \
    .addGrid(als.regParam, [0.01, 0.1, 0.5]) \
    .addGrid(als.rank, [10, 20, 50]) \
    .build()

# Define the evaluator for RMSE
evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")

# Set up cross-validation
cross_validator = CrossValidator(
    estimator=als,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=5
)

# Perform cross-validation and get the best model
cv_model = cross_validator.fit(train_data)
best_model = cv_model.bestModel

# Make predictions on the test set
predictions = best_model.transform(test_data)

# Evaluate the model by calculating the RMSE (Root Mean Squared Error)
rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = {:.4f}".format(rmse))

# Generate top 10 movie recommendations for each user
user_recs = best_model.recommendForAllUsers(10)
user_recs.show()

# After fitting the cross-validator and obtaining the best model
print("Best hyperparameters:")
print("  maxIter: {}".format(best_model._java_obj.parent().getMaxIter()))
print("  regParam: {:.2f}".format(best_model._java_obj.parent().getRegParam()))
print("  rank: {}".format(best_model.rank))

# Best hyperparameters:
#   maxIter: 15
#   regParam: 0.10
#   rank: 50

# Stop the Spark session
spark.stop()

23/03/23 22:13:15 WARN Utils: Your hostname, Martin-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.1.53 instead (on interface en0)
23/03/23 22:13:15 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/03/23 22:13:16 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


                                                                                

23/03/23 22:13:30 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
23/03/23 22:13:30 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.ForeignLinkerBLAS
23/03/23 22:13:30 WARN InstanceBuilder$JavaBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS


[Stage 12:>                                                        (0 + 8) / 10]

23/03/23 22:13:31 WARN InstanceBuilder$NativeLAPACK: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK


                                                                                

Root Mean Squared Error (RMSE) on test data = 0.9111


                                                                                

+------+--------------------+
|userId|     recommendations|
+------+--------------------+
|     1|[{1172, 3.4955099...|
|     3|[{27773, 4.371879...|
|     5|[{1948, 4.533952}...|
|     6|[{83318, 4.673204...|
|    12|[{3879, 4.9009957...|
|    13|[{88125, 4.13369}...|
|    15|[{4302, 4.893448}...|
|    16|[{318, 4.9723654}...|
|    19|[{83411, 5.161693...|
|    20|[{51471, 4.917141...|
|    22|[{67504, 4.292516...|
|    26|[{65514, 4.481471...|
|    27|[{318, 4.6969137}...|
|    28|[{1172, 5.132921}...|
|    31|[{83318, 4.783527...|
|    34|[{83318, 4.953843...|
|    37|[{68073, 5.136165...|
|    40|[{67504, 5.185649...|
|    41|[{83359, 4.917856...|
|    43|[{54328, 4.326411...|
+------+--------------------+
only showing top 20 rows

Best hyperparameters:
  maxIter: 15
  regParam: 0.10
  rank: 50
