In [1]:
# Re-import edited modules (e.g. the project's `source` package) before each
# cell runs, so changes to the library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [8]:
from source.datasets.fast_datasets import *
from source.datasets.sound_transforms import *
from source.datasets.util_transforms import *
from source.models_base.mb_vggish import MusicBertVGGish

# --- Feature extractor -------------------------------------------------------
# Pretrained MusicBERT (VGGish front end), put in eval mode and wrapped by
# BERT_Features in the transform below.
# An inline note in an earlier revision suggests num_encoder_layers=6 was also
# tried (presumably with a matching checkpoint — TODO confirm).
NUM_ENCODER_LAYERS = 12
CHECKPOINT_PATH = "models/music_bert_audioset_12layers.pth"

BERT = MusicBertVGGish(name="test", num_encoder_layers=NUM_ENCODER_LAYERS).cuda()
BERT.load_model(CHECKPOINT_PATH)
BERT.eval()

# --- Feature pipeline shared by every dataset below ----------------------------
# Alternatives that were tried instead of BERT_Pooling():
#   Average_Pooling()
#   Compose([BERT_Features(BERT), Average_Pooling()])
#   Compose([toVggishTorch(preprocess=False), Average_Pooling()])
transform = Compose([BERT_Features(BERT), BERT_Pooling()])

# --- Downstream datasets ---------------------------------------------------------
# `length` caps the Deezer / MTAT datasets; the evaluation progress bars
# report exactly this many items for those two tasks.
DEEZER_LENGTH = 2000
MTAT_LENGTH = 5000

genre_dataset = GTZANFastDataset(transform=transform)
emo_dataset = EmoMusicFastDataset(transform=transform)
ms_dataset = MusicSpeechFastDataset(transform=transform)
deezer_dataset = DeezerFastDataset(transform=transform, length=DEEZER_LENGTH)
mtat_dataset = MTATFastDataset(transform=transform, length=MTAT_LENGTH)

In [9]:
from sklearn.svm import SVR, SVC
from sklearn.multioutput import MultiOutputRegressor
from sklearn.multiclass import OneVsRestClassifier
from source.evaluation import r2_score_raw
from sklearn.neural_network import MLPClassifier, MLPRegressor

def _make_task(task_name, model, dataset, metric):
    """Bundle one downstream evaluation task as the dict evaluate_on_task's caller expects.

    Args:
        task_name: label printed above the task's scores.
        model: unfitted scikit-learn estimator trained on the extracted features.
        dataset: feature dataset built in the dataset cell.
        metric: a scikit-learn scoring string, a list of them, or a scorer object.

    Returns:
        dict with keys "task_name", "model", "dataset", "metric" (in that order).
    """
    return {
        "task_name": task_name,
        "model": model,
        "dataset": dataset,
        "metric": metric,
    }


# Every task uses an SVM head. An MLP head was also tried, e.g.
#   MLPClassifier / MLPRegressor(hidden_layer_sizes=(128, 128), activation="relu", max_iter=500)
# NOTE(review): ms_dataset (Music/Speech) is built in the dataset cell but is
# not evaluated here.
tasks_db = [
    # Genre classification, single-label.
    _make_task("GTZAN", SVC(), genre_dataset, "accuracy"),
    # Multi-output regression, scored with the project's r2_score_raw scorer.
    _make_task("EmoMusic", MultiOutputRegressor(SVR()), emo_dataset, r2_score_raw()),
    # Multi-output regression, scored with scikit-learn's built-in "r2".
    _make_task("Deezer", MultiOutputRegressor(SVR()), deezer_dataset, "r2"),
    # Multi-label tagging: one binary SVC per tag.
    _make_task("MagnaTagATune", OneVsRestClassifier(SVC()), mtat_dataset,
               ["roc_auc", "recall_samples"]),
]

In [10]:
from source.evaluation import evaluate_on_task


# numpy is imported explicitly rather than relying on one of the
# `from ... import *` lines in the dataset cell to re-export `np`.
import numpy as np

# Cross-validate every task with the same protocol. The raw scores are kept
# in `all_scores` (task name -> scores mapping) so they can be inspected or
# tabulated after the loop instead of only being printed.
K_FOLD = 5
all_scores = {}

for task in tasks_db:
    scores = evaluate_on_task(task['model'],
                              task['dataset'],
                              task['metric'],
                              k_fold=K_FOLD)
    all_scores[task['task_name']] = scores
    print(task['task_name'])
    # Each value presumably holds one score per fold; report mean ± population std in percent.
    for metric_name, fold_values in scores.items():
        print("%s %.2f ± %.2f %%" % (metric_name,
                                     np.mean(fold_values) * 100,
                                     np.std(fold_values) * 100))



HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

GTZAN
test_score 79.50 ± 2.12 %


HBox(children=(IntProgress(value=0, max=744), HTML(value='')))

EmoMusic
test_arousal 67.05 ± 3.41 %
test_valence 49.68 ± 4.69 %
test_overall 58.37 ± 2.82 %


HBox(children=(IntProgress(value=0, max=2000), HTML(value='')))

Deezer
test_score 15.11 ± 2.60 %


HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

MagnaTagATune
test_roc_auc 83.41 ± 0.51 %
test_recall_samples 22.62 ± 1.01 %


In [None]:
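# Results from an earlier run of the evaluation loop, kept for comparison.
# The model / transform configuration that produced them is not recorded here.
# GTZAN
# test_score 78.90 ± 2.06 %
# EmoMusic
# test_arousal 65.90 ± 2.38 %
# test_valence 47.02 ± 4.12 %
# test_overall 56.46 ± 1.62 %
# Deezer
# test_score 15.57 ± 3.17 %
# MagnaTagATune
# test_roc_auc 83.56 ± 0.45 %
# test_recall_samples 23.94 ± 0.41 %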

In [None]:
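# Earlier results, kept for comparison: an MLP vs SVR head comparison
# (presumably on EmoMusic, given the valence/arousal metrics), followed by a
# full run that also included Music/Speech. Configuration not recorded here.
# MLP
# test_valence 72.68 ± 3.97 %
# test_arousal 59.30 ± 5.49 %
#
# SVR
# test_valence 70.44 ± 4.59 %
# test_arousal 57.10 ± 2.98 %
#
# EmoMusic
# test_arousal 67.47 ± 1.48 %
# test_valence 56.09 ± 6.56 %
# test_overall 61.78 ± 2.68 %
# Music/Speech
# test_score 99.23 ± 1.54 %
# Deezer
# test_score 19.38 ± 2.83 %
# MagnaTagATune
# test_roc_auc 83.96 ± 0.29 %
# test_recall_samples 26.11 ± 0.81 %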

In [None]:
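# Results from another earlier run of the evaluation loop, kept for comparison.
# The model / transform configuration that produced them is not recorded here.
# GTZAN
# test_score 84.10 ± 1.24 %
# EmoMusic
# test_arousal 69.70 ± 1.78 %
# test_valence 51.28 ± 2.75 %
# test_overall 60.49 ± 1.35 %
# Deezer
# test_score 16.15 ± 3.54 %
# MagnaTagATune
# test_roc_auc 85.72 ± 0.43 %
# test_recall_samples 24.04 ± 0.69 %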