In [1]:
import glob
import csv
import sys, os.path
import pandas as pd
import sklearn
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision import transforms
#from torch.utils.tensorboard import SummaryWriter
#from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
# NOTE(review): this suppresses every warning category for the whole kernel,
# including deprecation and numerical warnings raised by the libraries used
# below; consider narrowing the filter to the specific warning being hidden.
warnings.filterwarnings('ignore')

import matplotlib as mpl
# Default resolution for every figure rendered in this notebook.
# NOTE(review): 1000 dpi is unusually high for inline figures -- confirm that
# the resulting image size / render time is intended.
mpl.rcParams['figure.dpi'] = 1000

In [2]:
# Project-local modules: put the notebook's directory and its parent on the
# import path before importing the `vae` package below.
import sys

sys.path.extend(['.', '..'])

from vae import configs, train, plot_utils, models
from vae.data import build_dataloader
from vae.latent_spaces import dimensionality_reduction, plot_spaces
from vae.reconstructions import plot_reconstructions
from vae.models import model_utils

In [3]:
def multi_acc(y_pred, y_test):
    """Return multi-class accuracy as a rounded percentage.

    Args:
        y_pred: Tensor of per-class scores; the highest-scoring entry along
            dim 1 is taken as the predicted class of each row.
        y_test: Tensor of target class indices, compared element-wise with
            the predicted classes.

    Returns:
        A 0-dim tensor: the share of matching predictions multiplied by 100
        and passed through ``torch.round``.
    """
    log_probs = torch.log_softmax(y_pred, dim=1)
    predicted_classes = torch.max(log_probs, dim=1).indices

    # 1.0 where the prediction matches the target, 0.0 elsewhere.
    hits = (predicted_classes == y_test).float()
    fraction_correct = hits.sum() / len(hits)

    return torch.round(fraction_correct * 100)

In [4]:
# Target instruments in label order: a name's position in this list is the
# integer label assigned to it when the lookup dict is built from it.
classes = [
    'violin', 'viola', 'cello', 'double-bass',
    'clarinet', 'bass-clarinet', 'saxophone', 'flute',
    'oboe', 'bassoon', 'contrabassoon',
    'french-horn', 'trombone', 'trumpet', 'tuba', 'english-horn',
    'guitar', 'mandolin', 'banjo',
    'chromatic-percussion',
]

# NOTE(review): presumably the individual instruments that are folded into the
# single 'chromatic-percussion' class -- confirm against the dataset builder.
chromatic_perc = [
    'agogo-bells', 'banana-shaker', 'bass-drum', 'bell-tree', 'cabasa',
    'Chinese-hand-cymbals', 'castanets', 'Chinese-cymbal', 'clash-cymbals',
    'cowbell', 'djembe', 'djundjun', 'flexatone', 'guiro',
    'lemon-shaker', 'motor-horn', 'ratchet', 'sheeps-toenails',
    'sizzle-cymbal', 'sleigh-bells', 'snare-drum',
    'spring-coil', 'squeaker', 'strawberry-shaker', 'surdo',
    'suspended-cymbal', 'swanee-whistle',
    'tambourine', 'tam-tam', 'tenor-drum', 'Thai-gong', 'tom-toms',
    'train-whistle', 'triangle',
    'vibraslap', 'washboard', 'whip', 'wind-chimes', 'woodblock',
    'cor-anglais',
]

# Integer labels 0 .. len(classes) - 1.
labels_list = list(range(len(classes)))

In [5]:
# Lookup table: instrument name -> integer label (its position in `classes`).
classes_dict = {name: classes.index(name) for name in classes}

# One row per instrument, indexed by name, with the label in a 'class' column.
df = pd.DataFrame.from_dict(
    classes_dict,
    orient='index',
    columns=['class'],
)

In [6]:
# Display the instrument-name -> label table.
df

Unnamed: 0,class
violin,0
viola,1
cello,2
double-bass,3
clarinet,4
bass-clarinet,5
saxophone,6
flute,7
oboe,8
bassoon,9


In [20]:
# Evaluation configuration.
model_name = 'supervised_timbre'
# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# because the model-loading cell reads it.
input = 'mel_cut'
# Epoch count of the checkpoint to evaluate; it is part of the checkpoint's
# filename.
trained_epochs = 320

# Build the held-out split and its loader through the project helper.
test_dataset, test_dataloader = build_dataloader.build_testset(input, model_name)
# The message previously said "training dataset" although the value reported
# is the size of the test split.
print('Number of files in the test dataset:', len(test_dataset))

Number of files in the training dataset: 1357


In [21]:
model = model_utils.import_model(model_name, input)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# map_location lets a checkpoint that was saved from a GPU be restored on a
# CPU-only machine instead of failing while deserialising CUDA tensors.
checkpoint_path = os.path.join(
    configs.ParamsConfig.TRAINED_MODELS_PATH,
    'saved_model_' + str(trained_epochs) + 'epochs.pth',
)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()  # switch the module to evaluation mode before predicting

y_pred_list = []
y_test_list = []
with torch.no_grad():
    for sample_batch, file, y in test_dataloader:
        sample_batch = sample_batch.to(device, dtype=torch.float)
        # The model returns two values; only the class scores are used in
        # this cell (the second output, `w`, is not).
        y_test_pred, w = model(sample_batch)
        _, y_pred_tags = torch.max(y_test_pred, dim=1)
        # Store one flat 1-D tensor per batch so the result is a flat list of
        # labels whatever the loader's batch size is. Squeezing each batch
        # separately only gave scalars when the batch size was 1.
        y_pred_list.append(y_pred_tags.cpu().reshape(-1))
        y_test_list.append(torch.as_tensor(y).cpu().reshape(-1))

# Flat Python lists with one entry per test sample; an empty loader yields
# empty lists rather than an error from torch.cat.
y_pred_list = torch.cat(y_pred_list).tolist() if y_pred_list else []
y_test = torch.cat(y_test_list).tolist() if y_test_list else []

In [22]:
# Report the loaded model's parameter count through the project helper
# (see the saved output of this cell).
model_utils.show_total_params(model)

Number of parameters: 122388
AttentionTimbreEncoder(
  (multihead_attn): MultiheadAttention(
    (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
  )
  (fc_1): Linear(in_features=2816, out_features=20, bias=True)
)


In [23]:
# Sanity check before scoring: predictions and targets are compared
# position by position, so the two lists must have the same length.
# Asserting that the lists are *unequal* is not a useful invariant, because it
# only fails when every single prediction is correct.
assert len(y_pred_list) == len(y_test), (
    f'{len(y_pred_list)} predictions vs {len(y_test)} targets'
)

In [24]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 and support for the test predictions,
# rendered as plain text.
report_text = classification_report(y_test, y_pred_list)
print(report_text)

              precision    recall  f1-score   support

           0       0.43      0.65      0.51       150
           1       0.78      0.70      0.74        97
           2       0.56      0.46      0.51        89
           3       0.79      0.59      0.68        85
           4       0.63      0.56      0.60        85
           5       0.81      0.49      0.61        94
           6       0.46      0.33      0.38        73
           7       0.66      0.76      0.71        88
           8       0.68      0.87      0.76        60
           9       0.88      0.53      0.66        72
          10       0.90      0.51      0.65        71
          11       0.62      0.69      0.65        65
          12       0.35      0.64      0.45        83
          13       0.82      0.77      0.80        48
          14       0.76      0.81      0.79        97
          15       0.69      0.59      0.64        69
          16       1.00      0.73      0.84        11
          17       0.67    

In [12]:
# Same report as a nested dict instead of text, so individual metrics can be
# read programmatically. Per-class rows are keyed by the label's string form
# (e.g. '0'), as the lookup in the next cell shows.
h8_report = classification_report(y_test, y_pred_list, output_dict=True)

In [13]:
# F1-score of class 0; note that the class key is the string '0', not the int 0.
h8_report['0']['f1-score']

0.5132275132275133

In [14]:
import seaborn as sns

# The report dict keys per-class rows by the label's string form ('0', '1',
# ...) and also holds aggregate rows such as 'accuracy' and 'macro avg'.
# Looking a row up with the int-converted key raises KeyError, and a bare
# `except` around that lookup hides the error and leaves the cell with no plot.
# So: use the int conversion only to recognise class rows and to label the
# axis, and take the scores from the row that the string key refers to.
class_ids = []
f1_scores = []
for key, row in h8_report.items():
    try:
        class_id = int(key)
    except ValueError:
        continue  # aggregate row, not a class
    class_ids.append(class_id)
    f1_scores.append(row['f1-score'])

# One figure for all classes instead of one single-point figure per class.
fig, ax = plt.subplots()
sns.pointplot(x=class_ids, y=f1_scores, ax=ax)
ax.set(
    xlabel='class index',
    ylabel='F1-score',
    title='Per-class F1-score on the test set',
)
plt.show()

## References

* https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab