In [1]:
# read data using pyspark
import pandas as pd
import plotly.express as px
from pyspark.sql import SparkSession
spark= SparkSession.builder.appName("BD").getOrCreate()

# read data
dev = spark.read.csv("../data/dev.csv", header=True, inferSchema=True)
test = spark.read.csv("../data/test.csv", header=True, inferSchema=True)
label = ["music_genre"]
categorical_features = ['key', 'mode']
numerical_features = ['popularity','acousticness', 'danceability', 'duration_ms', 'energy', 'instrumentalness', 'liveness', 'loudness', 'speechiness', 'tempo', 'valence']

In [9]:
# naive approach to clasification (baseline)
from pyspark.ml.feature import OneHotEncoder, VectorAssembler, StringIndexer
from pyspark.ml import Pipeline

from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.classification import RandomForestClassifier

# preprocessing
indexer = StringIndexer(inputCols=categorical_features + label, outputCols=[col + "_index" for col in categorical_features + label])
ohe = OneHotEncoder(inputCols=[col + "_index" for col in categorical_features], outputCols=[col + "_ohe" for col in categorical_features])
assembler = VectorAssembler(inputCols=[col + "_ohe" for col in categorical_features] + numerical_features, outputCol="features")

# cross validation
estimator = RandomForestClassifier(labelCol="music_genre_index", featuresCol="features")
estimator_params = ParamGridBuilder().addGrid(RandomForestClassifier.numTrees, [10, 20, 30]).build()
evaluator = MulticlassClassificationEvaluator(labelCol="music_genre_index", predictionCol="prediction", metricName="f1")
cross = CrossValidator(estimator=estimator, estimatorParamMaps=estimator_params, evaluator=evaluator, numFolds=10)

# build and train the pipeline
pipeline = Pipeline(stages=[indexer, ohe, assembler, cross])
model = pipeline.fit(dev)

In [10]:

# evaluate the model
predictions = model.transform(test)
f1 = evaluator.evaluate(predictions)
print("F1 score: ", f1)


F1 score:  0.4731085611692445
