In [1]:
import wandb
import os
os.environ["WANDB_SILENT"] = "true"

import numpy as np
import pandas as pd

# import sys
# sys.path.insert(0, "../..")

%cd ..

from helpers import load_model
from model import BaseVAE, GenreClassifier, MuteVAE
from data.src.dataLoaders import Groove2Drum2BarDataset

import torch

from bokeh.palettes import inferno, Category20b
from bokeh.core.enums import MarkerType
from bokeh.plotting import figure, show, save
from bokeh.io import output_notebook, reset_output
# output_notebook()


# Description

In this notebook, we evaluate the GenreClassifier on samples that were not part of its training set. In fact, we are not using GrooveMIDI at all here.

In [2]:

down_sampled_ratio = None

# Settings for the "test" subset described by the Balanced_6000_performed JSON file.
# They are gathered in one mapping so the constructor call below stays a one-liner.
test_set_options = {
    "dataset_setting_json_path": "data/dataset_json_settings/Balanced_6000_performed.json",
    "subset_tag": "test",
    "max_len": 32,
    "tapped_voice_idx": 2,
    "collapse_tapped_sequence": True,
    "num_voice_density_bins": 3,
    "num_tempo_bins": 6,
    "num_global_density_bins": 7,
    "augment_dataset": False,
    "force_regenerate": False,
}
dataset = Groove2Drum2BarDataset(**test_set_options)

# # no GMD
# dataset = Groove2Drum2BarDataset(
#     dataset_setting_json_path="data/dataset_json_settings/Balanced_1000_performed_no_GMD.json",
#     subset_tag="test",
#     max_len=32,
#     tapped_voice_idx=2,
#     collapse_tapped_sequence=True,
#     num_voice_density_bins=3,
#     num_tempo_bins=6,
#     num_global_density_bins=7,
#     augment_dataset=False,
#     force_regenerate=False
# )

# Access the GenreClassifier, BaseVAE and MuteVAE models

In [3]:
from helpers import download_model_from_wandb, predict_using_model, load_model
    
# download_model_from_wandb("45", 3, "driven-frost-24", GenreClassifier, new_path="./trained_models/genre_classifier.pth")
# download_model_from_wandb("155", 1, "lively-pond-9", BaseVAE, new_path="./trained_models/base_vae_beta_0_2.pth")
# download_model_from_wandb("405", 0, "polished-pyramid-1", MuteVAE, new_path="./trained_models/mute_vae_beta_0_2.pth")

# Load the three checkpoints from ./trained_models; each call pairs a .pth file with
# the model class handed to load_model.
# NOTE(review): the BaseVAE checkpoint loaded here is the beta_0_5 file, while the
# commented-out download above and the trailing alternative on the same line refer to
# beta_0_2 -- confirm which checkpoint is intended for this evaluation.
model_classifier = load_model("./trained_models/genre_classifier.pth", GenreClassifier)
model_BaseVAE = load_model("./trained_models/base_vae_beta_0_5.pth", BaseVAE)# load_model("./trained_models/base_vae_beta_0_2.pth", BaseVAE)
model_MuteVAE = load_model("./trained_models/mute_vae_beta_0_2.pth", MuteVAE)

# model_MuteVAE


# model.serialize(save_folder=f"{run_name}", filename=f"Gen_{run_name}_{epoch}_serialized__{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")

# Step 1 - Confusion Matrix of Classifier on the evaluation set

In [4]:
from helpers import classifier_confusion_matrix, plot_confusion_matrix

In [5]:
# Run the helper on the test dataset with the loaded classifier; its three return
# values are bound to cm, scores and labels (cm and labels are plotted in the next cell).
cm, scores, labels = classifier_confusion_matrix(dataset, model_classifier)

In [6]:
# Plot the confusion matrix and labels produced by classifier_confusion_matrix in the previous cell.
plot_confusion_matrix(cm, labels)

# Generate Patterns using BaseVAE and Re-evaluate the Confusion Matrix

groove -> BaseVAE -> drum_pattern -> GenreClassifier -> genre_prediction <---- target_genre

In [7]:

def predict_using_batch_data_base_vae(dataset_, model_=None):
    """Generate drum patterns for all input grooves of a dataset with a BaseVAE.

    Args:
        dataset_: Object exposing ``input_grooves``; that attribute is handed to the
            model unchanged.
        model_: Model exposing ``eval()`` and ``predict(flat_hvo_groove=...)``.
            When ``None`` (the default), the notebook-level ``model_BaseVAE`` is
            looked up at call time, so re-running the model-loading cell is picked
            up without having to re-run this definition.

    Returns:
        The ``(hvo, latent_z)`` pair returned by ``model_.predict``.

    Side effects:
        Calls ``model_.eval()``; the model's mode is not restored afterwards.
    """
    if model_ is None:
        # Resolve the default lazily: a default of `model_BaseVAE` in the signature
        # would be frozen to whatever object that name held when this cell ran.
        model_ = model_BaseVAE
    model_.eval()
    flat_hvo_groove = dataset_.input_grooves
    # Inference only -- no autograd graph is needed.
    with torch.no_grad():
        hvo, latent_z = model_.predict(flat_hvo_groove=flat_hvo_groove)
    return hvo, latent_z



In [None]:
# Run the BaseVAE over every input groove of the dataset to obtain generated
# patterns (hvo) together with the corresponding latent vectors.
hvo, latent_z = predict_using_batch_data_base_vae(dataset_=dataset, model_=model_BaseVAE)

In [None]:
# Ground-truth genres come from the dataset; predictions come from running the
# GenreClassifier on the BaseVAE-generated patterns (second return value is discarded).
target_genres = dataset.genre_targets
drum_patterns = hvo
predicted_genres, _ = model_classifier.predict(drum_patterns)


In [None]:
# Plot the row-normalized confusion matrix (target genre vs. predicted genre)
# for the BaseVAE-generated patterns.
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

cm_counts = confusion_matrix(target_genres, predicted_genres)
# Normalize each row by its number of target samples. A class that only shows up in
# the predictions has an all-zero row; keep that row at 0 instead of dividing by zero
# (which would fill it with NaN and emit a RuntimeWarning).
row_totals = cm_counts.sum(axis=1, keepdims=True)
cm = np.divide(
    cm_counts,
    row_totals,
    out=np.zeros(cm_counts.shape, dtype=float),
    where=row_totals != 0,
)
plt.figure(figsize=(6, 6))

labels = dataset.genre_tags
# Keep only the part of each tag before the first "/" to get short tick labels.
labels = [label.split("/")[0] for label in labels]
# NOTE(review): confusion_matrix infers its classes from the data, so the tick labels
# only line up if every entry of genre_tags occurs in the targets/predictions -- confirm.
sns.set(font_scale=0.8)
ax = sns.heatmap(cm, annot=True, fmt=".2f", cmap="Blues", xticklabels=labels, yticklabels=labels, cbar=False, annot_kws={"size": 9})
ax.set_xlabel("Predicted genre")
ax.set_ylabel("Target genre")
ax.set_title("GenreClassifier on BaseVAE-generated patterns (row-normalized)")
plt.show()

In [None]:
# Summary metrics for the BaseVAE-generated patterns: plain accuracy plus
# weighted-average precision / recall / F1.
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Small helper so the three weighted metrics share one call shape.
weighted = lambda score_fn: score_fn(target_genres, predicted_genres, average="weighted")

accuracy = accuracy_score(target_genres, predicted_genres)
precision = weighted(precision_score)
recall = weighted(recall_score)
f1 = weighted(f1_score)

print(f"Accuracy: {accuracy:.2f}, Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}")

# Generate Patterns using MuteVAE and Re-evaluate the Confusion Matrix

groove -> MuteVAE -> drum_pattern -> GenreClassifier -> genre_prediction <---- target_genre

In [None]:
def predict_using_batch_data_mute_vae(dataset_, model_=None):
    """Generate drum patterns for all input grooves of a dataset with a MuteVAE.

    Args:
        dataset_: Object exposing ``input_grooves`` and the five mute controls
            ``kick_is_muted``, ``snare_is_muted``, ``hat_is_muted``,
            ``tom_is_muted`` and ``cymbal_is_muted``; all are handed to the model
            unchanged.
        model_: Model exposing ``eval()`` and a ``predict`` method accepting the
            keyword arguments listed above plus ``flat_hvo_groove``. When ``None``
            (the default), the notebook-level ``model_MuteVAE`` is looked up at
            call time, so re-running the model-loading cell is picked up without
            having to re-run this definition.

    Returns:
        The ``(hvo, latent_z)`` pair returned by ``model_.predict``.

    Side effects:
        Calls ``model_.eval()``; the model's mode is not restored afterwards.
    """
    if model_ is None:
        # Resolve the default lazily: a default of `model_MuteVAE` in the signature
        # would be frozen to whatever object that name held when this cell ran.
        model_ = model_MuteVAE
    model_.eval()
    # Collect the model inputs before entering the no-grad context.
    predict_inputs = {
        "flat_hvo_groove": dataset_.input_grooves,
        "kick_is_muted": dataset_.kick_is_muted,
        "snare_is_muted": dataset_.snare_is_muted,
        "hat_is_muted": dataset_.hat_is_muted,
        "tom_is_muted": dataset_.tom_is_muted,
        "cymbal_is_muted": dataset_.cymbal_is_muted,
    }
    # Inference only -- no autograd graph is needed.
    with torch.no_grad():
        hvo, latent_z = model_.predict(**predict_inputs)
    return hvo, latent_z



In [None]:
# Run the MuteVAE over every input groove of the dataset (with the dataset's mute
# controls) to obtain generated patterns (hvo) and the corresponding latent vectors.
hvo, latent_z = predict_using_batch_data_mute_vae(dataset_=dataset, model_=model_MuteVAE)


In [None]:
# Ground-truth genres come from the dataset; predictions come from running the
# GenreClassifier on the MuteVAE-generated patterns (second return value is discarded).
target_genres = dataset.genre_targets
drum_patterns = hvo
predicted_genres, _ = model_classifier.predict(drum_patterns)


In [None]:
# Plot the row-normalized confusion matrix (target genre vs. predicted genre)
# for the MuteVAE-generated patterns.
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

cm_counts = confusion_matrix(target_genres, predicted_genres)
# Normalize each row by its number of target samples. A class that only shows up in
# the predictions has an all-zero row; keep that row at 0 instead of dividing by zero
# (which would fill it with NaN and emit a RuntimeWarning).
row_totals = cm_counts.sum(axis=1, keepdims=True)
cm = np.divide(
    cm_counts,
    row_totals,
    out=np.zeros(cm_counts.shape, dtype=float),
    where=row_totals != 0,
)
plt.figure(figsize=(6, 6))

labels = dataset.genre_tags
# Keep only the part of each tag before the first "/" to get short tick labels.
labels = [label.split("/")[0] for label in labels]
# NOTE(review): confusion_matrix infers its classes from the data, so the tick labels
# only line up if every entry of genre_tags occurs in the targets/predictions -- confirm.
sns.set(font_scale=0.8)
ax = sns.heatmap(cm, annot=True, fmt=".2f", cmap="Blues", xticklabels=labels, yticklabels=labels, cbar=False, annot_kws={"size": 9})
ax.set_xlabel("Predicted genre")
ax.set_ylabel("Target genre")
ax.set_title("GenreClassifier on MuteVAE-generated patterns (row-normalized)")
plt.show()

In [None]:
# Summary metrics for the MuteVAE-generated patterns: plain accuracy plus
# weighted-average precision / recall / F1.
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Small helper so the three weighted metrics share one call shape.
weighted = lambda score_fn: score_fn(target_genres, predicted_genres, average="weighted")

accuracy = accuracy_score(target_genres, predicted_genres)
precision = weighted(precision_score)
recall = weighted(recall_score)
f1 = weighted(f1_score)

print(f"Accuracy: {accuracy:.2f}, Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}")