In [1]:
# Third-party and project imports for evaluating a trained GenreClassifier checkpoint.
import wandb
import os
# Ask wandb to suppress its console log messages in this notebook.
os.environ["WANDB_SILENT"] = "true"

import numpy as np
import pandas as pd

# import sys
# sys.path.insert(0, "../..")

# Change the working directory to the parent folder. The relative paths used
# later ("data/dataset_json_settings/...", "artifacts/...") resolve against it,
# and presumably the project modules imported below (helpers, model, data.src)
# are found from there too -- an alternative to the sys.path insert commented above.
# NOTE(review): %cd is not idempotent; re-running this cell moves up another level.
%cd ..

from helpers import load_model
from model import GenreClassifier
from data.src.dataLoaders import Groove2Drum2BarDataset

import torch


# NOTE(review): the bokeh imports below are not used by any cell of this notebook as shown.
from bokeh.palettes import inferno, Category20b
from bokeh.core.enums import MarkerType
from bokeh.plotting import figure, show, save
from bokeh.io import output_notebook, reset_output
# output_notebook()

In [2]:

# NOTE(review): not referenced by any other cell shown; kept so the name stays defined.
down_sampled_ratio = None

# Constructor settings for the held-out "test" split, passed to the loader by keyword.
# Augmentation is off; force_regenerate=False presumably reuses previously
# generated data instead of rebuilding it -- confirm in Groove2Drum2BarDataset.
dataset_settings = dict(
    dataset_setting_json_path="data/dataset_json_settings/Balanced_6000_performed.json",
    subset_tag="test",
    max_len=32,
    tapped_voice_idx=2,
    collapse_tapped_sequence=True,
    num_voice_density_bins=3,
    num_tempo_bins=6,
    num_global_density_bins=7,
    augment_dataset=False,
    force_regenerate=False,
)

# Load the split as a torch.utils.data.Dataset.
dataset = Groove2Drum2BarDataset(**dataset_settings)


In [3]:
# NOTE(review): Counter is not used by any cell of this notebook as shown.
from collections import Counter
# Show the genre tag names of the dataset; the confusion-matrix cell derives its
# tick labels from this list, in this order.
dataset.genre_tags

# Download, load, and serialize the model

In [4]:
# Which trained checkpoint to evaluate: the wandb artifact
# "model_epoch_<epoch>:v<version>" of the GenreClassifier project, kept locally
# as "<run_name>.pth".
epoch = "45"  # alternative checkpoint: "605"
version = 3
run_name = "driven-frost-24"

artifact_path = f"behzadhaki/GenreClassifier/model_epoch_{epoch}:v{version}"

# Where the renamed checkpoint ends up after a previous download. This assumes
# artifact.download() uses its default ./artifacts/<name>:<version> directory.
local_path = f"artifacts/model_epoch_{epoch}:v{version}/{run_name}.pth"
if not os.path.exists(local_path):
    print("Downloading artifact")
    # Using the run as a context manager finishes it once the download is done
    # (also when it raises) instead of leaving it open for the whole session.
    with wandb.init() as run:
        artifact = run.use_artifact(artifact_path, type='model')
        artifact_dir = artifact.download()
    # The artifact stores the weights as "<epoch>.pth"; rename to "<run_name>.pth"
    # so the file matches local_path and the download is skipped next time.
    os.rename(os.path.join(artifact_dir, f"{epoch}.pth"), os.path.join(artifact_dir, f"{run_name}.pth"))
    print("Artifact downloaded to: ", artifact_dir)
else:
    print("Artifact already downloaded")
    artifact_dir = os.path.dirname(local_path)

print(os.path.join(artifact_dir, f"{run_name}.pth"))

In [5]:
import datetime

# Checkpoint file located (and, if needed, downloaded) by the previous cell.
checkpoint_path = os.path.join(artifact_dir, f"{run_name}.pth")
print(checkpoint_path)
model = load_model(checkpoint_path, model_class=GenreClassifier)

# Optional smoke test, kept disabled: model.predict(torch.randn(1, 32, 3))

# Save a serialized copy under the run name, with a second-resolution timestamp
# in the filename so successive runs produce distinct files.
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
model.serialize(save_folder=f"{run_name}", filename=f"Gen_{run_name}_{epoch}_serialized__{timestamp}.pt")


# Run inference on the test set

In [6]:
# Classifier inputs and the ground-truth genre labels for the same test examples
# (attributes read in the same order as before: grooves first, then targets).
drum_patterns, target_genres = dataset.output_grooves, dataset.genre_targets
target_genres

In [7]:
# Predict genres for all test patterns in a single call. predict() returns a
# pair; only the first element (the predicted genres) is kept.
# NOTE(review): confirm whether predict() already switches to eval mode /
# torch.no_grad(), and what the discarded second value is.
predicted_genres, _ = model.predict(drum_patterns)

In [8]:
# Display the raw predictions for a visual comparison with target_genres.
predicted_genres

In [11]:
# Plot the row-normalised confusion matrix: cell (i, j) is the fraction of test
# examples whose true genre is i that the model labelled as genre j.
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Short display names: the part of each genre tag before the first "/".
labels = [label.split("/")[0] for label in dataset.genre_tags]

# Pin the class list to the dataset's genres, so the matrix is always
# len(labels) x len(labels) and its rows/columns line up with the tick labels,
# even when a genre never occurs in the targets or the predictions.
# NOTE(review): assumes genres are encoded as integer indices into
# dataset.genre_tags (the tick labels already relied on that) -- confirm.
# normalize="true" divides each row by its total and yields 0 (not NaN) for a
# genre with no true examples.
cm = confusion_matrix(
    target_genres,
    predicted_genres,
    labels=list(range(len(labels))),
    normalize="true",
)

sns.set(font_scale=0.8)  # global seaborn theme: scale all text to 80%
fig, ax = plt.subplots(figsize=(6, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt=".2f",
    cmap="Blues",
    xticklabels=labels,
    yticklabels=labels,
    cbar=False,
    annot_kws={"size": 7},  # smaller font for the per-cell values
    ax=ax,
)
ax.set(xlabel="Predicted genre", ylabel="True genre", title="Normalised confusion matrix (test set)")
plt.show()

In [12]:
# Overall test-set scores as percentages. Precision, recall and F1 use
# average="weighted": per-genre scores averaged with each genre's number of
# true examples as its weight.
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# zero_division=0 makes explicit the value sklearn already substitutes for an
# undefined per-genre score (e.g. precision of a genre that is never predicted).
# The results are unchanged, but no UndefinedMetricWarning is raised per such genre.
weighted_kwargs = dict(average="weighted", zero_division=0)

accuracy = accuracy_score(target_genres, predicted_genres) * 100
precision = precision_score(target_genres, predicted_genres, **weighted_kwargs) * 100
recall = recall_score(target_genres, predicted_genres, **weighted_kwargs) * 100
f1 = f1_score(target_genres, predicted_genres, **weighted_kwargs) * 100

print(f"Accuracy: {accuracy:.2f}, Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}")