In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell runs, so edits to the project's `source.*`
# modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [105]:
from source.datasets.fast_datasets import *
from source.datasets.sound_transforms import *
from source.datasets.util_transforms import *
from source.models_base.mb_vggish import MusicBertVGGish

# Build the MusicBertVGGish model, move it to the GPU, restore the saved
# checkpoint and call eval() before it is used by the BERT_Features transform.
BERT = MusicBertVGGish(name="test", num_encoder_layers=4)  # , num_encoder_layers=6
BERT = BERT.cuda()
BERT.load_model("models/mb_batch_negatives_JSD.pth")
BERT.eval()

# Transform handed to every dataset below.
# Alternatives kept for reference (swap one in to compare against it):
#   transform = Average_Pooling()
#   transform = Compose([BERT_Features(BERT), Average_Pooling()])
#   transform = Compose([toVggishTorch(preprocess=False), Average_Pooling()])
transform = Compose([
    BERT_Features(BERT),
    BERT_Pooling(),
])

# Datasets constructed with the transform only.
genre_dataset = GTZANFastDataset(transform=transform)
emo_dataset = EmoMusicFastDataset(transform=transform)

# Datasets constructed with an explicit `length` argument as well.
deezer_dataset = DeezerFastDataset(transform=transform, length=2000)
mtat_dataset = MTATFastDataset(transform=transform, length=5000)

In [106]:
from sklearn.svm import SVR, SVC
from sklearn.multioutput import MultiOutputRegressor
from sklearn.multiclass import OneVsRestClassifier
from source.evaluation import r2_score_raw
from sklearn.neural_network import MLPClassifier, MLPRegressor

# One entry per downstream task: the display name, the scikit-learn estimator
# to fit, the dataset providing the samples and the metric(s) handed to the
# evaluation routine.
# An MLP alternative that was tried in place of the SVM estimators:
#   MLPClassifier / MLPRegressor(hidden_layer_sizes=(128,128),
#                                activation="relu", max_iter=500)
tasks_db = [
    dict(
        task_name="GTZAN",
        model=SVC(),
        dataset=genre_dataset,
        metric="accuracy",
    ),
    dict(
        task_name="EmoMusic",
        model=MultiOutputRegressor(SVR()),
        dataset=emo_dataset,
        metric=r2_score_raw(),
    ),
    dict(
        task_name="Deezer",
        model=MultiOutputRegressor(SVR()),
        dataset=deezer_dataset,
        metric="r2",
    ),
    dict(
        task_name="MagnaTagATune",
        model=OneVsRestClassifier(SVC()),
        dataset=mtat_dataset,
        metric=["roc_auc", "recall_samples", "average_precision"],
    ),
]

In [107]:
from source.evaluation import evaluate_on_task
from source.utils.generic_utils import allDone


# Run 5-fold evaluation for every configured task and print, per returned
# score key, the mean and standard deviation over folds scaled by 100.
for task in tasks_db:
    # Gradients are disabled while evaluate_on_task runs.
    with torch.no_grad():
        scores = evaluate_on_task(
            task['model'],
            task['dataset'],
            task['metric'],
            k_fold=5,
        )

    print(task['task_name'])
    for score_name, fold_values in scores.items():
        mean_pct = np.mean(fold_values) * 100
        std_pct = np.std(fold_values) * 100
        print("%s %.2f ± %.2f %%" % (score_name, mean_pct, std_pct))


# Call the project's completion helper once every task has been evaluated.
allDone()

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

GTZAN
test_score 83.00 ± 1.67 %


HBox(children=(FloatProgress(value=0.0, max=744.0), HTML(value='')))

EmoMusic
test_arousal 67.78 ± 2.88 %
test_valence 56.87 ± 3.14 %
test_overall 62.32 ± 0.49 %


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))

Deezer
test_score 19.08 ± 2.78 %


HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))

MagnaTagATune
test_roc_auc 83.83 ± 0.46 %
test_recall_samples 24.36 ± 1.12 %
test_average_precision 35.22 ± 0.97 %


In [None]:
DV

GTZAN
test_score 73.40 ± 2.27 %
EmoMusic
test_arousal 24.44 ± 2.14 %
test_valence 15.22 ± 2.09 %
test_overall 19.83 ± 1.35 %
Deezer
test_score 9.98 ± 2.46 %
MagnaTagATune
test_roc_auc 84.41 ± 0.32 %
test_recall_samples 4.77 ± 0.26 %
test_average_precision 34.57 ± 1.28 %


full DV

GTZAN
test_score 73.80 ± 1.94 %
EmoMusic
test_arousal 25.08 ± 1.68 %
test_valence 17.27 ± 1.74 %
test_overall 21.18 ± 1.51 %
Deezer
test_score 11.87 ± 1.43 %
MagnaTagATune
test_roc_auc 84.45 ± 0.43 %
test_recall_samples 4.67 ± 0.25 %
test_average_precision 34.55 ± 0.42 %

In [None]:
infoNCE

GTZAN
test_score 75.80 ± 3.01 %
EmoMusic
test_arousal 47.60 ± 2.00 %
test_valence 32.20 ± 1.14 %
test_overall 39.90 ± 1.41 %
Deezer
test_score 13.02 ± 2.48 %
MagnaTagATune
test_roc_auc 84.95 ± 0.41 %
test_recall_samples 8.18 ± 0.52 %
test_average_precision 35.63 ± 0.51 %


full infoNCE

GTZAN
test_score 71.60 ± 0.97 %
EmoMusic
test_arousal 51.14 ± 1.99 %
test_valence 37.92 ± 3.71 %
test_overall 44.53 ± 1.97 %
Deezer
test_score 17.33 ± 0.92 %
MagnaTagATune
test_roc_auc 85.23 ± 0.42 %
test_recall_samples 14.11 ± 0.68 %
test_average_precision 35.43 ± 0.35 %

In [None]:
JSD

GTZAN
test_score 80.20 ± 2.36 %
EmoMusic
test_arousal 67.93 ± 2.68 %
test_valence 56.16 ± 4.38 %
test_overall 62.04 ± 3.15 %
Deezer
test_score 15.85 ± 1.97 %
MagnaTagATune
test_roc_auc 84.03 ± 0.49 %
test_recall_samples 24.77 ± 0.69 %
test_average_precision 36.62 ± 0.86 %


full JSD

GTZAN
test_score 81.90 ± 2.91 %
EmoMusic
test_arousal 65.43 ± 4.20 %
test_valence 50.63 ± 4.41 %
test_overall 58.03 ± 3.29 %
Deezer
test_score 16.29 ± 2.46 %
MagnaTagATune
test_roc_auc 83.79 ± 0.53 %
test_recall_samples 24.95 ± 1.21 %
test_average_precision 35.88 ± 1.17 %

In [99]:
from tqdm.auto import tqdm


def dataset_to_matrix(dataset, progress=True):
    """Stack a dataset's per-sample features and labels into NumPy arrays.

    Parameters
    ----------
    dataset : iterable of dict-like samples
        Every sample must provide ``"song_features"`` and either a non-None
        ``"encoded_class"`` or, failing that, a ``"target"``.
    progress : bool, default True
        When true, iteration is wrapped in a tqdm progress bar that is
        cleared once the loop finishes. Pass False to iterate silently.

    Returns
    -------
    X : np.ndarray
        Features stacked along a leading sample axis.
    Y : np.ndarray
        Labels stacked along a leading sample axis, with every size-1 axis
        other than the sample axis removed.
    """
    samples = tqdm(dataset, leave=False) if progress else dataset

    X, Y = [], []
    for sample in samples:
        X.append(sample["song_features"])

        # Prefer the encoded class label; fall back to the raw target.
        label = sample.get("encoded_class")
        if label is None:
            label = sample["target"]
        Y.append(label)

    X, Y = np.array(X), np.array(Y)

    # Drop singleton axes but never axis 0: a bare squeeze() would also remove
    # the sample axis when the dataset holds exactly one sample.
    singleton_axes = tuple(ax for ax in range(1, Y.ndim) if Y.shape[ax] == 1)
    if singleton_axes:
        Y = Y.squeeze(axis=singleton_axes)

    return X, Y

# Materialise the GTZAN dataset as a feature matrix X and a label vector Y.
X, Y = dataset_to_matrix(genre_dataset)

# Display both shapes as the cell output.
tuple(arr.shape for arr in (X, Y))

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

((1000, 768), (1000,))

In [100]:
from sklearn.decomposition import PCA

# Project the feature matrix onto its first two principal components.
# random_state is pinned because scikit-learn's default 'auto' solver may pick
# the randomized SVD for a matrix of this size, which would otherwise make the
# projection (and the plot built from it) vary between runs.
pca = PCA(n_components=2, random_state=0)
principalComponents = pca.fit_transform(X)

# Display the shape of the projection as the cell output.
principalComponents.shape

(1000, 2)

In [101]:
%matplotlib widget

import matplotlib.pyplot as plt

# Scatter plot of the 2-component PCA projection, one series per genre.
fig, ax = plt.subplots(figsize=(8, 8))
ax.set_xlabel('x', fontsize = 15)
ax.set_ylabel('y', fontsize = 15)
ax.set_title('2 Component PCA', fontsize = 20)

# Genre names; position `idx` is matched against the label value `idx` in Y.
targets = genre_dataset.genres

# Iterate over the genre list itself rather than a hard-coded range(10), and
# attach each name to its own scatter via `label=` so the legend stays aligned
# even when some classes are skipped.
for idx, genre in enumerate(targets):
#     if idx in (0,1,2,3,5,7):
#         continue

    indicesToKeep = Y == idx
    ax.scatter(principalComponents[indicesToKeep, 0]
               , principalComponents[indicesToKeep, 1]
               , s = 40
               , label = genre)
ax.legend()
ax.grid()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …