In [118]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.inspection import permutation_importance
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.neighbors import KNeighborsClassifier

In [119]:
# importing data
data = pd.read_csv('csv_outputs/grouped_cleaned_spotify.csv')
data = data[~data['track_genre'].isin(['misc', 'world'])]

In [120]:
string_columns = ['track_id', 'artists', 'album_name', 'track_name']
categorical_columns = ['key', 'mode', 'time_signature']
response_column = 'track_genre'

Y = data[response_column]
X = data.drop(columns=[response_column, *string_columns])
X = pd.get_dummies(X, columns=[*categorical_columns])

In [121]:
X_train, X_test, y_train, y_test = train_test_split(
    X, Y,
    test_size=0.2,
    random_state=0,
)

In [122]:
dt_classifier = DecisionTreeClassifier(
    max_depth=50,
    random_state=42
)
dt_classifier.fit(X_train, y_train)

y_pred = dt_classifier.predict(X_test)

In [123]:
dt_accuracy = accuracy_score(y_test, y_pred)
dt_class_report = classification_report(y_test, y_pred, zero_division=1)
dt_conf_matrix = confusion_matrix(y_test, y_pred)

print(f'Accuracy: {dt_accuracy}')
print(f'Classification Report: {dt_class_report}')
print(f'Confusion Matrix: {dt_conf_matrix}')

Accuracy: 0.5762903225806452
Classification Report:               precision    recall  f1-score   support

   classical       0.60      0.60      0.60       589
  electronic       0.72      0.74      0.73      4001
        folk       0.58      0.66      0.62      1221
     hip-hop       0.24      0.26      0.25       581
        jazz       0.49      0.49      0.49       818
       metal       0.63      0.63      0.63      1187
         pop       0.45      0.42      0.43      1811
        rock       0.48      0.45      0.47      2192

    accuracy                           0.58     12400
   macro avg       0.52      0.53      0.53     12400
weighted avg       0.57      0.58      0.57     12400

Confusion Matrix: [[ 351   40   38    7   21   15   49   68]
 [  40 2946   92  181  109  130  252  251]
 [  24   80  810   31   30    9  111  126]
 [   5  177   24  151   39   12   93   80]
 [  15  126   46   45  400    9   81   96]
 [  14  116   11   19   10  744   70  203]
 [  47  331  172  106

In [124]:
# Get feature imporances
feature_importances = dt_classifier.feature_importances_
for i, importance in enumerate(feature_importances):
    print(f'{X.columns[i]}: {importance}')

popularity: 0.10723641472190588
duration_ms: 0.08033518832971319
explicit: 0.006726827600486028
danceability: 0.11867398677822003
energy: 0.0600976198522973
loudness: 0.06002875629254791
speechiness: 0.07295971433485196
acousticness: 0.14997048050153683
instrumentalness: 0.08419760526957365
liveness: 0.04790738436269943
valence: 0.0764604039858912
tempo: 0.06902511260195417
key_0: 0.00456147310421049
key_1: 0.0036282700546081486
key_2: 0.0044840873036299566
key_3: 0.0023910510913560186
key_4: 0.003272004883692366
key_5: 0.0045459056769435605
key_6: 0.0037749076682581702
key_7: 0.005204499236742661
key_8: 0.0036658762639377143
key_9: 0.005118755250564711
key_10: 0.0032319153945651853
key_11: 0.0038587108485179625
mode_0: 0.00636324132095331
mode_1: 0.005453388055446032
time_signature_0: 0.0
time_signature_1: 0.0008002945271944898
time_signature_3: 0.00276651766271166
time_signature_4: 0.002434162158207523
time_signature_5: 0.0008254448667823585


In [125]:
# permutation importance
perm_importance = permutation_importance(
    dt_classifier,
    X_test,
    y_test,
    n_repeats=20,
    random_state=42
)

perm_importance_df = pd.DataFrame({
    'Feature': X.columns,
    'Importance Mean': perm_importance.importances_mean,
    'Importance Std': perm_importance.importances_std
})

perm_importance_df

Unnamed: 0,Feature,Importance Mean,Importance Std
0,popularity,0.14129,0.00323
1,duration_ms,0.100565,0.002958
2,explicit,0.006984,0.000731
3,danceability,0.157681,0.003009
4,energy,0.069915,0.002218
5,loudness,0.063504,0.002068
6,speechiness,0.083835,0.002224
7,acousticness,0.192262,0.00274
8,instrumentalness,0.093395,0.00274
9,liveness,0.019577,0.00156


In [126]:
# Random forest
rf_classifier = RandomForestClassifier(
    n_estimators=100,
    max_depth=20,
    random_state=42
)
rf_classifier.fit(X_train, y_train)

In [127]:
# Make predictions
y_pred_rf = rf_classifier.predict(X_test)

In [128]:
# Evaluate
rf_accuracy = accuracy_score(y_test, y_pred_rf)
rf_class_report = classification_report(y_test, y_pred_rf, zero_division=1)
rf_conf_matrix = confusion_matrix(y_test, y_pred_rf)

print(f'Accuracy: {rf_accuracy}')
print(f'Classification Report: {rf_class_report}')
print(f'Confusion Matrix:')
rf_conf_matrix


Accuracy: 0.6670967741935484
Classification Report:               precision    recall  f1-score   support

   classical       0.77      0.69      0.73       589
  electronic       0.71      0.87      0.78      4001
        folk       0.73      0.69      0.71      1221
     hip-hop       0.45      0.13      0.20       581
        jazz       0.73      0.47      0.57       818
       metal       0.74      0.71      0.72      1187
         pop       0.56      0.52      0.54      1811
        rock       0.57      0.59      0.58      2192

    accuracy                           0.67     12400
   macro avg       0.66      0.58      0.60     12400
weighted avg       0.66      0.67      0.65     12400

Confusion Matrix:


array([[ 405,   33,   27,    0,   11,    6,   51,   56],
       [  16, 3497,   21,   24,   40,   73,  149,  181],
       [  15,   89,  838,    3,    7,    1,  140,  128],
       [   0,  302,   11,   74,   13,    6,   90,   85],
       [  12,  151,   26,   11,  383,    6,  109,  120],
       [   6,   99,    3,    0,    4,  841,   23,  211],
       [  13,  410,  103,   30,   43,   60,  943,  209],
       [  59,  342,  126,   22,   25,  144,  183, 1291]])

In [129]:
# permutation importance
perm_importance_rf = permutation_importance(
    rf_classifier,
    X_test,
    y_test,
    n_repeats=20,
    random_state=42,
)

perm_importance_rf_df = pd.DataFrame({
    'Feature': X.columns,
    'Importance Mean': perm_importance.importances_mean,
    'Importance Std': perm_importance.importances_std
})

perm_importance_rf_df

Unnamed: 0,Feature,Importance Mean,Importance Std
0,popularity,0.14129,0.00323
1,duration_ms,0.100565,0.002958
2,explicit,0.006984,0.000731
3,danceability,0.157681,0.003009
4,energy,0.069915,0.002218
5,loudness,0.063504,0.002068
6,speechiness,0.083835,0.002224
7,acousticness,0.192262,0.00274
8,instrumentalness,0.093395,0.00274
9,liveness,0.019577,0.00156


In [130]:
# KNN
knn_classifier = KNeighborsClassifier(n_neighbors=5)
knn_classifier.fit(X_train, y_train)

# Make predictions
y_pred_knn = knn_classifier.predict(X_test)

# Evaluate
knn_accuracy = accuracy_score(y_test, y_pred_knn)
knn_class_report = classification_report(y_test, y_pred_knn, zero_division=1)
knn_conf_matrix = confusion_matrix(y_test, y_pred_knn)

print(f'Accuracy: {knn_accuracy}')
print(f'Classification Report: {knn_class_report}')
print(f'Confusion Matrix:')
knn_conf_matrix

Accuracy: 0.3710483870967742
Classification Report:               precision    recall  f1-score   support

   classical       0.15      0.15      0.15       589
  electronic       0.44      0.67      0.53      4001
        folk       0.34      0.29      0.31      1221
     hip-hop       0.21      0.09      0.12       581
        jazz       0.52      0.34      0.41       818
       metal       0.20      0.14      0.16      1187
         pop       0.28      0.21      0.24      1811
        rock       0.36      0.28      0.31      2192

    accuracy                           0.37     12400
   macro avg       0.31      0.27      0.28     12400
weighted avg       0.35      0.37      0.35     12400

Confusion Matrix:


array([[  89,  252,   49,    9,   17,   42,   47,   84],
       [ 138, 2682,  172,   64,   76,  215,  325,  329],
       [  61,  467,  354,   13,   28,   50,  104,  144],
       [  25,  275,   45,   51,   12,   39,   73,   61],
       [  32,  261,   47,   10,  277,   30,   74,   87],
       [  96,  523,   74,   17,   15,  165,  124,  173],
       [  88,  787,  124,   34,   59,  128,  374,  217],
       [  78,  871,  171,   41,   49,  147,  226,  609]])