In [1]:
# Imports
import os # filenames
import pandas as pd  # dataframes

In [2]:
# Spark Imports
import findspark
findspark.init("/Users/elliot/spark")  # get spark here
from pyspark.sql import SparkSession  # session to run spark
from pyspark.sql.functions import udf, when  # user defined function, when
from pyspark.sql.types import *  # work with various types in the rdd
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import RegressionEvaluator

In [3]:
# Clean Data Files
data_dir = "data"
data_files = os.listdir(data_dir)

# Open the artist alias file and convert it to a csv
# 587 lines of artist aliases
# All of them only have two items per line
# Artist_ID Artist_Alias_ID : Int Int
with open(os.path.join(data_dir,"artist_alias_small.txt"), 'r') as artist_alias_txt, open(os.path.join(data_dir,"alias_data.csv"), 'w') as alias_data_csv:
    alias_data_csv.write("artist_alias_id,artist_id\n")
    
    for line in artist_alias_txt:
        artist_alias_id, artist_id = line.split()
        alias_data_csv.write("{},{}\n".format(artist_alias_id, artist_id))
        
# Open the user artist data file and convert it to a csv   
# 49481 lines of user data
# All of them only have three items per line
# User_ID Artist_ID Song_Plays : Int Int Int
with open(os.path.join(data_dir,"user_artist_data_small.txt"), 'r') as user_artist_data_txt, open(os.path.join(data_dir,"user_data.csv"), 'w') as user_data_csv:
    user_data_csv.write("user_id,dirty_artist_id,artist_plays\n")
    
    for line in user_artist_data_txt:
        user_id, dirty_artist_id, artist_plays = line.split()
        user_data_csv.write("{},{},{}\n".format(user_id, dirty_artist_id, artist_plays))
        
        
# Open the artist data file and convert it to a csv 
# 30537 lines of artist data
# Any line could have an arbitrary number of spaced items
# Artist_ID Artist_Name : Int String
with open(os.path.join(data_dir,"artist_data_small.txt"), 'r') as artist_data_txt, open(os.path.join(data_dir,"artist_data.csv"), 'w') as artist_data_csv:
    artist_data_csv.write("artist_id,artist_name\n")
    
    for line in artist_data_txt:
        artist_data_line = line.split()
        artist_id = artist_data_line[0]
        
        artist_name_list = artist_data_line[1:]
        artist_name = ""
        for word in artist_name_list:
            artist_name += word + " "
            
        artist_name = artist_name[:-1]
        
        artist_data_csv.write("{},{}\n".format(artist_id, artist_name))

In [4]:
# Spark

# start spark session
spark = SparkSession.builder.appName("MusicRecommender").getOrCreate()

In [5]:
# Read data into the spark session
user_data = spark.read.csv(os.path.join(data_dir, "user_data.csv"), header=True) 
user_data.show()
user_data.describe().show()
user_data.select('user_id').distinct().describe().show()


artist_data = spark.read.csv(os.path.join(data_dir, "artist_data.csv"), header=True) 
artist_data.show()
artist_data.describe().show()

alias_data = spark.read.csv(os.path.join(data_dir, "alias_data.csv"), header=True)
alias_data.show()
alias_data.describe().show()

+-------+---------------+------------+
|user_id|dirty_artist_id|artist_plays|
+-------+---------------+------------+
|1059637|        1000010|         238|
|1059637|        1000049|           1|
|1059637|        1000056|           1|
|1059637|        1000062|          11|
|1059637|        1000094|           1|
|1059637|        1000112|         423|
|1059637|        1000113|           5|
|1059637|        1000114|           2|
|1059637|        1000123|           2|
|1059637|        1000130|       19129|
|1059637|        1000139|           4|
|1059637|        1000241|         188|
|1059637|        1000263|         180|
|1059637|        1000289|           2|
|1059637|        1000305|           1|
|1059637|        1000320|          21|
|1059637|        1000340|           1|
|1059637|        1000427|          20|
|1059637|        1000428|          12|
|1059637|        1000433|          10|
+-------+---------------+------------+
only showing top 20 rows

+-------+------------------+----------

In [6]:
# Replace the aliases with the real ids
user_data = user_data.join(alias_data, user_data.dirty_artist_id == alias_data.artist_alias_id, "left_outer")

user_data = user_data.withColumn("clean_artist_id", 
                                       when(user_data.artist_id.isNotNull(), user_data.artist_id)
                                       .otherwise(user_data.dirty_artist_id))

user_data = user_data['user_id', 'clean_artist_id', 'artist_plays']
user_data.show()

+-------+---------------+------------+
|user_id|clean_artist_id|artist_plays|
+-------+---------------+------------+
|1059637|        1000010|         238|
|1059637|        1000049|           1|
|1059637|        1000056|           1|
|1059637|        1000062|          11|
|1059637|        1000094|           1|
|1059637|        1000112|         423|
|1059637|        1000113|           5|
|1059637|        1000114|           2|
|1059637|        1000123|           2|
|1059637|        1000130|       19129|
|1059637|        1000139|           4|
|1059637|        1000241|         188|
|1059637|        1000263|         180|
|1059637|        1000289|           2|
|1059637|        1000305|           1|
|1059637|        1000320|          21|
|1059637|        1000340|           1|
|1059637|        1000427|          20|
|1059637|        1000428|          12|
|1059637|        1000433|          10|
+-------+---------------+------------+
only showing top 20 rows



In [7]:
# Add a column of the real names matching the artist ids
user_data = user_data.join(artist_data, user_data.clean_artist_id == artist_data.artist_id, "left_outer")

user_data = user_data['user_id', 'clean_artist_id', 'artist_name', 'artist_plays']
user_data.show()

+-------+---------------+--------------------+------------+
|user_id|clean_artist_id|         artist_name|artist_plays|
+-------+---------------+--------------------+------------+
|1059637|        1000010|           Aerosmith|         238|
|1059637|        1000049|     Edna's Goldfish|           1|
|1059637|        1000056|The Mighty Mighty...|           1|
|1059637|        1000062|        Foo Fighters|          11|
|1059637|        1000094|  The Bouncing Souls|           1|
|1059637|        1000112|       Alkaline Trio|         423|
|1059637|        1000113|         The Beatles|           5|
|1059637|        1000114|           Pennywise|           2|
|1059637|        1000123|             Incubus|           2|
|1059637|        1000130|         Bright Eyes|       19129|
|1059637|        1000139|                Muse|           4|
|1059637|        1000241|          Jason Mraz|         188|
|1059637|        1000263|     Jimmy Eat World|         180|
|1059637|        1000289|           Meat

In [8]:
# Clip unreasonably large song play values above
# Clip at 4800, which is 10 days exactly of listening to that song

user_data = user_data.withColumn("clipped_artist_plays", 
                                 when(user_data.artist_plays > 4800, 4800)
                                 .otherwise(user_data.artist_plays))

user_data = user_data['user_id', 'clean_artist_id', 'artist_name', 'clipped_artist_plays']
user_data.show()

+-------+---------------+--------------------+--------------------+
|user_id|clean_artist_id|         artist_name|clipped_artist_plays|
+-------+---------------+--------------------+--------------------+
|1059637|        1000010|           Aerosmith|                 238|
|1059637|        1000049|     Edna's Goldfish|                   1|
|1059637|        1000056|The Mighty Mighty...|                   1|
|1059637|        1000062|        Foo Fighters|                  11|
|1059637|        1000094|  The Bouncing Souls|                   1|
|1059637|        1000112|       Alkaline Trio|                 423|
|1059637|        1000113|         The Beatles|                   5|
|1059637|        1000114|           Pennywise|                   2|
|1059637|        1000123|             Incubus|                   2|
|1059637|        1000130|         Bright Eyes|                4800|
|1059637|        1000139|                Muse|                   4|
|1059637|        1000241|          Jason Mraz|  

In [9]:
# Train the als model 
# Rank = 10
# Iterations = 10
# Lambda = 0.01
# Alpha = 1.0

user_data=user_data.withColumn("user_id", user_data["user_id"].cast(IntegerType()))
user_data=user_data.withColumn("clean_artist_id", user_data["clean_artist_id"].cast(IntegerType()))
user_data=user_data.withColumn("clipped_artist_plays", user_data["clipped_artist_plays"].cast(IntegerType()))

als = ALS(rank=10, maxIter=10, regParam=0.01, alpha=1.0, userCol="user_id", itemCol="clean_artist_id", ratingCol="clipped_artist_plays")
model = als.fit(user_data)

In [10]:
# Spot check the als model
predictions = model.transform(user_data.where(user_data.user_id == 2064012))
predictions.show()

predictions = model.transform(user_data.where(user_data.user_id == 1000647))
predictions.show()

predictions = model.transform(user_data.where(user_data.user_id == 2023686))
predictions.show()

+-------+---------------+--------------------+--------------------+----------+
|user_id|clean_artist_id|         artist_name|clipped_artist_plays|prediction|
+-------+---------------+--------------------+--------------------+----------+
|2064012|        1058104|        Gwen Stefani|                  37| 20.766594|
|2064012|        1053277|         State Radio|                1093| 1138.0634|
|2064012|           4531| "Weird Al" Yankovic|                2847|  2941.663|
|2064012|        1281902|           Trenthian|                  75|  78.17719|
|2064012|            976|             Nirvana|                4800| 4572.5654|
|2064012|        1007993|         Denis Leary|                1655| 1709.3997|
|2064012|        1001909|            Interpol|                 988|  928.7781|
|2064012|           1307|   The White Stripes|                 120| 194.09766|
|2064012|           4267|           Green Day|                 215| 200.68958|
|2064012|        1000183|           Disturbed|      

In [None]:
als = ALS(userCol="user_id", itemCol="clean_artist_id", ratingCol="clipped_artist_plays", coldStartStrategy="drop")

paramGrid = ParamGridBuilder()\
            .addGrid(als.rank, [10, 20, 30, 40, 50, 60, 70])\
            .addGrid(als.maxIter, [10])\
            .addGrid(als.regParam, [1, 0.1, 0.01, 0.001, 0.0001, 0.00001])\
            .addGrid(als.alpha, [0.1, 0.4, 1.0, 4.0, 10.0, 40.0, 100.0])\
            .build()
            
# [10, 20, 30, 40, 50, 60, 70] [10, 50]
# [10]
# [1, 0.1, 0.01, 0.001, 0.0001, 0.00001] [1, 0.01, 0.0001]
# [0.1, 0.4, 1.0, 4.0, 10.0, 40.0, 100.0] [1.0, 40.0]
    
mse = RegressionEvaluator(metricName="mse", labelCol="clipped_artist_plays", predictionCol="prediction", )

cross_validator = CrossValidator(estimator=als, estimatorParamMaps=paramGrid, evaluator=mse, numFolds=10)

models = cross_validator.fit(user_data)
print(models.avgMetrics)
print(models.extractParamMap())

In [None]:
spark.stop()