In [2]:
import os
from pyspark.sql import SparkSession

os.environ['SPARK_HOME'] = '/usr/local/spark'

# Start Spark
spark = SparkSession.builder.appName("Million Song Mining").master("local[*]").getOrCreate()

# Check Spark version
print(spark.version)

3.5.2


In [3]:
path_to_dataset = "./../data/raw/kaggle_visible_evaluation_triplets.txt"

df = spark.read.csv(path_to_dataset, sep='\t', inferSchema=True, header=False)

df = df.withColumnRenamed("_c0", "user_id").withColumnRenamed("_c1", "song_id").withColumnRenamed("_c2", "play_count")

df.show(5)

[Stage 1:>                                                          (0 + 8) / 8]

+--------------------+------------------+----------+
|             user_id|           song_id|play_count|
+--------------------+------------------+----------+
|fd50c4007b68a3737...|SOBONKR12A58A7A7E0|         1|
|fd50c4007b68a3737...|SOEGIYH12A6D4FC0E3|         1|
|fd50c4007b68a3737...|SOFLJQZ12A6D4FADA6|         1|
|fd50c4007b68a3737...|SOHTKMO12AB01843B0|         1|
|fd50c4007b68a3737...|SODQZCY12A6D4F9D11|         1|
+--------------------+------------------+----------+
only showing top 5 rows



                                                                                

In [4]:
df.describe().show()



+-------+--------------------+------------------+------------------+
|summary|             user_id|           song_id|        play_count|
+-------+--------------------+------------------+------------------+
|  count|             1450933|           1450933|           1450933|
|   mean|                NULL|              NULL|3.1871492343202616|
| stddev|                NULL|              NULL| 7.051663619572663|
|    min|00007a02388c208ea...|SOAAAFI12A6D4F9C66|                 1|
|    max|ffff07d7d9bb187aa...|SOZZZWN12AF72A1E29|               923|
+-------+--------------------+------------------+------------------+



                                                                                

In [6]:
# Get small subset of users
subset_users = df.select('user_id').distinct().limit(1000)
df_subset = df.join(subset_users, on='user_id', how='inner')

subset_path = "./../data/subsets/small_subset.csv"
df_subset.write.csv(subset_path, header=True)


                                                                                

In [7]:
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

# Split into train, test sets (80, 20)
train_data, test_data = df_subset.randomSplit([0.8, 0.2])

# Initialize StringIndexers for user_id and song_id since they are both string type and need to be in int type later on
user_indexer = StringIndexer(inputCol="user_id", outputCol="user_index")
song_indexer = StringIndexer(inputCol="song_id", outputCol="song_index")

pipeline = Pipeline(stages=[user_indexer, song_indexer])
train_data = pipeline.fit(train_data).transform(train_data)
test_data = pipeline.fit(test_data).transform(test_data)
train_data.show(5)

als = ALS(userCol="user_index", itemCol="song_index", ratingCol="play_count", coldStartStrategy="drop")

model = als.fit(train_data)
predictions = model.transform(test_data)
predictions.show(5)

evaluator = RegressionEvaluator(metricName="rmse", labelCol="play_count", predictionCol="prediction")
rmse = evaluator.evaluate(predictions)
print(f"Root-mean-square error: {rmse}")

                                                                                

+--------------------+------------------+----------+----------+----------+
|             user_id|           song_id|play_count|user_index|song_index|
+--------------------+------------------+----------+----------+----------+
|0020258bdaf943abc...|SOHVNHU12A58A7C802|         5|     741.0|    3391.0|
|0020258bdaf943abc...|SOLOZDT12A58A78374|         3|     741.0|    4352.0|
|0020258bdaf943abc...|SOMMYNX12A58A7D9A6|         1|     741.0|     347.0|
|0020258bdaf943abc...|SOQUARI12A67ADA92C|         2|     741.0|    1015.0|
|0020258bdaf943abc...|SOTYJZR12A6D4F6A3E|         1|     741.0|    1119.0|
+--------------------+------------------+----------+----------+----------+
only showing top 5 rows



24/08/29 20:54:53 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
24/08/29 20:54:53 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS
24/08/29 20:54:53 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK
                                                                                

+--------------------+------------------+----------+----------+----------+----------+
|             user_id|           song_id|play_count|user_index|song_index|prediction|
+--------------------+------------------+----------+----------+----------+----------+
|0020258bdaf943abc...|SOMOKND12A8C137D89|         2|     646.0|    1297.0|-1.0173073|
|005b1fab38cdeb9d5...|SOLUWYQ12A6D4FD0F8|         1|     647.0|    1237.0| -1.438634|
|0083526744e75be75...|SOHEMBR12B0B8075DF|         3|     447.0|     821.0|0.54257476|
|0083526744e75be75...|SOKRFOK12A67020A02|         5|     447.0|    1136.0|-0.1741548|
|0107197bd5e2521cf...|SOGZJZU12AF72A0D45|         1|     270.0|     809.0|-1.2573574|
+--------------------+------------------+----------+----------+----------+----------+
only showing top 5 rows



                                                                                

Root-mean-square error: 8.788313909700209
