### ML genre classification on FMA 

* use the 'medium' dataset to work with more training examplars
* note medium also only contains tracks where all tagged genres roll up to the same root genre

In [20]:
%matplotlib inline

import pandas as pd
import numpy as np
import scipy as sp
import IPython.display as ipd

import matplotlib.pyplot as plt

from sklearn.svm import SVC
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.model_selection import (train_test_split, GridSearchCV, RandomizedSearchCV)
from sklearn.metrics import (classification_report, confusion_matrix, ConfusionMatrixDisplay, f1_score)
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import (RandomForestClassifier, AdaBoostClassifier)
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier

from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import HalvingRandomSearchCV, HalvingGridSearchCV

import utils

RANDOM_STATE = 53

In [21]:
(features, tracks) = utils.load_features()
features.shape, tracks.shape

((106574, 518), (106574, 52))

In [22]:
small = tracks[('set', 'subset')] == 'small'
medium = tracks[('set', 'subset')].isin(['small','medium'])
X = features[medium]
y = tracks[medium][('track','genre_top')]

print(X.shape, y.shape)


(25000, 518) (25000,)


#### examine genre_top breakdown

* dataset is very unbalanced

In [23]:
y.value_counts()

Rock                   7103
Electronic             6314
Experimental           2251
Hip-Hop                2201
Folk                   1519
Instrumental           1350
Pop                    1186
International          1018
Classical               619
Old-Time / Historic     510
Jazz                    384
Country                 178
Soul-RnB                154
Spoken                  118
Blues                    74
Easy Listening           21
Name: (track, genre_top), dtype: int64

#### arbitrarily eliminate bottom 3 genres due to lack of exemplars


In [24]:
prune = ~y.isin(['Spoken','Blues','Easy Listening'])
y = y[prune]
X = X[prune]

#### build training/test sets

In [25]:
X_train, X_test, y_train, y_test = train_test_split(X, y, 
                                                    test_size=0.2,
                                                    random_state=RANDOM_STATE,
                                                    shuffle=True,
                                                    stratify=y)
scaler = MinMaxScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

#### Naively try basic ML classifiers

* all features
* default settings

NOTE: pretty slow!!!

In [26]:
classifiers = {'SVC': SVC(kernel='linear', random_state=RANDOM_STATE),
               'SVC-RBF': SVC(kernel='rbf', random_state=RANDOM_STATE),
               'LR' : LogisticRegression(random_state=RANDOM_STATE),
               'KNN' :KNeighborsClassifier()
              }

for (name, cl) in classifiers.items():
    cl.fit(X_train_scaled, y_train)
    y_pred = cl.predict(X_test_scaled)
    print(classification_report(y_test, y_pred))

                     precision    recall  f1-score   support

          Classical       0.75      0.83      0.79       124
            Country       0.60      0.09      0.15        35
         Electronic       0.65      0.82      0.72      1263
       Experimental       0.49      0.40      0.44       450
               Folk       0.64      0.63      0.63       304
            Hip-Hop       0.72      0.64      0.68       440
       Instrumental       0.57      0.44      0.50       270
      International       0.61      0.44      0.51       204
               Jazz       0.60      0.31      0.41        77
Old-Time / Historic       0.98      0.93      0.95       102
                Pop       0.44      0.07      0.12       237
               Rock       0.76      0.87      0.81      1421
           Soul-RnB       1.00      0.03      0.06        31

           accuracy                           0.68      4958
          macro avg       0.68      0.50      0.52      4958
       weighted avg   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

          Classical       0.74      0.80      0.77       124
            Country       0.00      0.00      0.00        35
         Electronic       0.63      0.85      0.72      1263
       Experimental       0.51      0.36      0.42       450
               Folk       0.61      0.59      0.60       304
            Hip-Hop       0.77      0.56      0.65       440
       Instrumental       0.57      0.42      0.49       270
      International       0.63      0.25      0.35       204
               Jazz       0.00      0.00      0.00        77
Old-Time / Historic       0.97      0.96      0.97       102
                Pop       0.00      0.00      0.00       237
               Rock       0.70      0.89      0.78      1421
           Soul-RnB       0.00      0.00      0.00        31

           accuracy                           0.66      4958
          macro avg       0.47      0.44      0.44      4958
       weighted avg   

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

          Classical       0.71      0.73      0.72       124
            Country       0.00      0.00      0.00        35
         Electronic       0.65      0.80      0.71      1263
       Experimental       0.47      0.33      0.39       450
               Folk       0.58      0.57      0.58       304
            Hip-Hop       0.65      0.62      0.63       440
       Instrumental       0.48      0.33      0.39       270
      International       0.47      0.23      0.31       204
               Jazz       0.70      0.09      0.16        77
Old-Time / Historic       0.87      0.95      0.91       102
                Pop       0.27      0.03      0.06       237
               Rock       0.69      0.88      0.77      1421
           Soul-RnB       0.00      0.00      0.00        31

           accuracy                           0.64      4958
          macro avg       0.50      0.43      0.43      4958
       weighted avg   

#### Repeat with class_weight=balanced to mitigate imbalance

In [8]:
classifiers = {'SVC': SVC(kernel='linear', class_weight='balanced', random_state=RANDOM_STATE),
               'SVC-RBF': SVC(kernel='rbf', class_weight='balanced', random_state=RANDOM_STATE),
               'LR' : LogisticRegression(class_weight='balanced', random_state=RANDOM_STATE)
              }

for (name, cl) in classifiers.items():
    cl.fit(X_train_scaled, y_train)
    y_pred = cl.predict(X_test_scaled)
    print(classification_report(y_test, y_pred))

                     precision    recall  f1-score   support

          Classical       0.67      0.83      0.74       124
            Country       0.16      0.60      0.26        35
         Electronic       0.77      0.57      0.66      1263
       Experimental       0.43      0.45      0.44       450
               Folk       0.60      0.64      0.62       304
            Hip-Hop       0.58      0.72      0.64       440
       Instrumental       0.39      0.59      0.47       270
      International       0.42      0.56      0.48       204
               Jazz       0.36      0.61      0.46        77
Old-Time / Historic       0.97      0.92      0.94       102
                Pop       0.21      0.29      0.24       237
               Rock       0.90      0.68      0.78      1421
           Soul-RnB       0.17      0.48      0.26        31

           accuracy                           0.61      4958
          macro avg       0.51      0.61      0.54      4958
       weighted avg   

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


#### Can hyperparameter tuning improve performance?

#### SVC-linear (slow!!!)

In [9]:
param_dist = {'C': sp.stats.loguniform(1e-2, 1e2)}

rsh = HalvingRandomSearchCV(SVC(kernel='linear',class_weight='balanced'), 
                            param_dist, 
                            scoring='f1_macro', 
                            random_state=RANDOM_STATE,
                            n_jobs=4 )

rsh.fit(X_train_scaled, y_train)
cl = rsh.best_estimator_
print(cl)

y_pred = cl.predict(X_test_scaled)
print(classification_report(y_test, y_pred))


SVC(C=2.693519333248338, class_weight='balanced', kernel='linear')
                     precision    recall  f1-score   support

          Classical       0.64      0.83      0.72       124
            Country       0.18      0.54      0.27        35
         Electronic       0.76      0.58      0.66      1263
       Experimental       0.41      0.43      0.42       450
               Folk       0.58      0.66      0.62       304
            Hip-Hop       0.56      0.71      0.63       440
       Instrumental       0.40      0.56      0.47       270
      International       0.40      0.54      0.46       204
               Jazz       0.38      0.60      0.46        77
Old-Time / Historic       0.98      0.91      0.94       102
                Pop       0.20      0.27      0.23       237
               Rock       0.90      0.69      0.78      1421
           Soul-RnB       0.19      0.35      0.24        31

           accuracy                           0.61      4958
          macro 

#### SVC-RBF with hyperparameter tuning (slow!!!)

In [10]:
param_dist = {'C': sp.stats.loguniform(1e-1, 1e1), 
              'gamma': sp.stats.loguniform(1e-4, 1e0)
}

rsh = HalvingRandomSearchCV(SVC(kernel='rbf',class_weight='balanced'), 
                            param_dist, 
                            scoring='f1_macro', 
                            random_state=RANDOM_STATE,
                            n_jobs=4 )
rsh.fit(X_train_scaled, y_train)
cl = rsh.best_estimator_
print(cl)

y_pred = cl.predict(X_test_scaled)
print(classification_report(y_test, y_pred))


SVC(C=2.412957842508776, class_weight='balanced', gamma=0.21965247640481667)
                     precision    recall  f1-score   support

          Classical       0.80      0.86      0.83       124
            Country       0.56      0.54      0.55        35
         Electronic       0.76      0.71      0.73      1263
       Experimental       0.48      0.58      0.53       450
               Folk       0.63      0.72      0.67       304
            Hip-Hop       0.62      0.74      0.67       440
       Instrumental       0.50      0.60      0.54       270
      International       0.59      0.62      0.61       204
               Jazz       0.65      0.51      0.57        77
Old-Time / Historic       1.00      0.98      0.99       102
                Pop       0.28      0.30      0.29       237
               Rock       0.88      0.76      0.82      1421
           Soul-RnB       0.64      0.29      0.40        31

           accuracy                           0.69      4958
      

#### Logistic regression - with hyperparameter tuning (slow!!!)

In [11]:
param_dist = {'max_iter': [5000],
              'C': sp.stats.loguniform(1e-1, 1e1) 
}

X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

rsh = HalvingRandomSearchCV(LogisticRegression(class_weight='balanced'),
                            param_dist, 
                            scoring='f1_macro', 
                            random_state=RANDOM_STATE,
                            n_jobs=4)

rsh.fit(X_train_scaled, y_train)
cl = rsh.best_estimator_
print(cl)

y_pred = cl.predict(X_test_scaled)
print(classification_report(y_test, y_pred))


LogisticRegression(C=2.641600609506842, class_weight='balanced', max_iter=5000)
                     precision    recall  f1-score   support

          Classical       0.67      0.84      0.75       124
            Country       0.14      0.66      0.23        35
         Electronic       0.80      0.51      0.62      1263
       Experimental       0.45      0.40      0.42       450
               Folk       0.57      0.62      0.59       304
            Hip-Hop       0.57      0.72      0.64       440
       Instrumental       0.38      0.50      0.43       270
      International       0.39      0.57      0.46       204
               Jazz       0.24      0.60      0.35        77
Old-Time / Historic       0.92      0.96      0.94       102
                Pop       0.22      0.31      0.26       237
               Rock       0.89      0.69      0.78      1421
           Soul-RnB       0.13      0.68      0.21        31

           accuracy                           0.59      4958
   

#### KNN - tuned

In [12]:
param_dist = {'n_neighbors': sp.stats.randint(2,50)}

rsh = HalvingRandomSearchCV(KNeighborsClassifier(),
                            param_dist, 
                            scoring='f1_macro', 
                            random_state=RANDOM_STATE,
                            n_jobs=4)

rsh.fit(X_train_scaled, y_train)
cl = rsh.best_estimator_
print(cl)

y_pred = cl.predict(X_test_scaled)
print(classification_report(y_test, y_pred))


KNeighborsClassifier(n_neighbors=7)
                     precision    recall  f1-score   support

          Classical       0.58      0.81      0.68       124
            Country       0.17      0.34      0.23        35
         Electronic       0.69      0.60      0.64      1263
       Experimental       0.48      0.23      0.31       450
               Folk       0.46      0.52      0.49       304
            Hip-Hop       0.49      0.57      0.53       440
       Instrumental       0.62      0.23      0.34       270
      International       0.43      0.47      0.45       204
               Jazz       0.43      0.38      0.40        77
Old-Time / Historic       0.89      0.98      0.93       102
                Pop       0.18      0.07      0.10       237
               Rock       0.63      0.86      0.73      1421
           Soul-RnB       0.57      0.13      0.21        31

           accuracy                           0.59      4958
          macro avg       0.51      0.48      0

#### Try all feature sets independently with SVC

In [13]:
feature_sets = features.columns.get_level_values(0).unique()

for fs in feature_sets:
    X_train_scaled = scaler.fit_transform(X_train[fs])
    X_test_scaled = scaler.transform(X_test[fs])
    classifier = SVC(kernel='rbf', class_weight='balanced', random_state=RANDOM_STATE)
    classifier.fit(X_train_scaled, y_train)
    y_pred = classifier.predict(X_test_scaled)
    print(fs)
    print(classification_report(y_test, y_pred))

chroma_cens
                     precision    recall  f1-score   support

          Classical       0.16      0.56      0.25       124
            Country       0.03      0.46      0.05        35
         Electronic       0.49      0.15      0.23      1263
       Experimental       0.25      0.21      0.23       450
               Folk       0.32      0.31      0.31       304
            Hip-Hop       0.22      0.35      0.27       440
       Instrumental       0.23      0.33      0.27       270
      International       0.15      0.07      0.09       204
               Jazz       0.10      0.42      0.16        77
Old-Time / Historic       0.11      0.45      0.17       102
                Pop       0.09      0.05      0.07       237
               Rock       0.59      0.22      0.32      1421
           Soul-RnB       0.03      0.29      0.05        31

           accuracy                           0.23      4958
          macro avg       0.21      0.30      0.19      4958
       wei

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


spectral_bandwidth
                     precision    recall  f1-score   support

          Classical       0.18      0.50      0.26       124
            Country       0.02      0.34      0.03        35
         Electronic       0.68      0.21      0.33      1263
       Experimental       0.26      0.03      0.06       450
               Folk       0.18      0.32      0.23       304
            Hip-Hop       0.22      0.37      0.28       440
       Instrumental       0.22      0.20      0.21       270
      International       0.19      0.15      0.16       204
               Jazz       0.03      0.09      0.04        77
Old-Time / Historic       0.31      0.66      0.42       102
                Pop       0.13      0.08      0.10       237
               Rock       0.67      0.18      0.29      1421
           Soul-RnB       0.02      0.48      0.04        31

           accuracy                           0.22      4958
          macro avg       0.24      0.28      0.19      4958
   