In [1]:
# Enable IPython's autoreload extension (mode 2) so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Step up one directory so the `src.*` imports in the next cell resolve;
# the recorded output shows the resulting working directory.
# NOTE(review): the move is relative to the current directory, so running this
# cell a second time in the same kernel climbs one more level.
%cd ..

/home/khai/malist_project/piano-transcribe


In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchaudio
from pathlib import Path
import numpy as np

from src.data.datasets import MAPSDataset
from src.data.audio import MadmomSpectrogram
from src.data.data_modules import MAPSDataModule
from sklearn.linear_model import SGDClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.datasets import make_multilabel_classification



In [3]:
# Audio settings shared by the spectrogram transform and the data module;
# naming them once keeps the two constructors in agreement.
SAMPLE_RATE = 16000
HOP_LENGTH = 441 * 4

audio_transform = MadmomSpectrogram(hop_length=HOP_LENGTH, sample_rate=SAMPLE_RATE)

# Data module serving batches of 4 examples with max_steps=5, loaded lazily.
mapsDataModule = MAPSDataModule(
    batch_size=4,
    max_steps=5,
    sample_rate=SAMPLE_RATE,
    audio_transform=audio_transform,
    lazy_loading=True,
    hop_length=HOP_LENGTH,
)

In [4]:
# Run the data module's setup step; the dataloaders are only requested from it afterwards.
# NOTE(review): presumably this builds the underlying train/val/test datasets —
# confirm in MAPSDataModule.setup.
mapsDataModule.setup()

In [5]:
# Fetch the three dataloaders from the data module in one statement
# (evaluated left to right: train, validation, test).
train_loader, validate_loader, test_loader = (
    mapsDataModule.train_dataloader(),
    mapsDataModule.val_dataloader(),
    mapsDataModule.test_dataloader(),
)

In [6]:
# Look at the first training example and display the shape of its 'frames' tensor.
first_example = train_loader.dataset[0]
first_example['frames'].shape

torch.Size([5, 88])

In [7]:
# Fit one logistic-regression SGD classifier per label column, one mini-batch at a time.
# NOTE: 'log_loss' is the current name of the logistic loss (scikit-learn >= 1.1);
# the former alias 'log' was removed in scikit-learn 1.3.
clf = MultiOutputClassifier(SGDClassifier(loss='log_loss'))

# partial_fit needs the complete set of class labels for each of the 88 outputs.
# The list does not change between batches, so it is built once outside the loop.
label_classes = [np.array([0, 1]) for _ in range(88)]

for batch in train_loader:
    # Flatten every example into a single 5*294 feature vector. The batch dimension is
    # read from the tensor instead of being hard-coded to 4, so a smaller final batch
    # does not break the reshape (this matches the validation cells).
    # NOTE(review): 5 presumably mirrors max_steps=5 and 294 the per-step feature count — confirm.
    batch_input = torch.reshape(batch['audio'], [batch['audio'].shape[0], 5*294])
    # Target: the labels at index 2 of the step axis, i.e. the middle of the 5 steps.
    batch_output = batch['frames'][:, 2, :]
    # The builtin int replaces np.int, which was removed in NumPy 1.24.
    clf.partial_fit(batch_input.numpy(), batch_output.numpy().astype(int), classes=label_classes)

In [9]:
# Element-wise accuracy of the logistic-regression model over the whole validation set.
# Correct label entries are counted across all batches and divided by the total number
# of entries, so a smaller final batch is weighted by its size rather than counting as
# much as a full batch (which is what averaging per-batch accuracies would do).
n_correct = 0
n_labels = 0
for batch in validate_loader:
    # Flatten every example into a single 5*294 feature vector.
    batch_input = torch.reshape(batch['audio'], [batch['audio'].shape[0], 5*294])
    # Ground truth: the labels at index 2 of the step axis, as integers.
    batch_output = batch['frames'][:, 2, :].numpy().astype(int)
    batch_pred = clf.predict(batch_input.numpy())
    n_correct += int(np.sum(batch_output == batch_pred))
    n_labels += batch_output.size

# Guard against an empty validation loader instead of dividing by zero.
accuracy = n_correct / n_labels if n_labels else float('nan')
print('Average accurary: ', accuracy)

Average accurary:  0.9094460227272728


In [10]:
# Fit one linear SVM (hinge-loss SGD classifier) per label column, one mini-batch at a time.
svm_clf = MultiOutputClassifier(SGDClassifier(loss='hinge'))

# partial_fit needs the complete set of class labels for each of the 88 outputs.
# The list does not change between batches, so it is built once outside the loop.
label_classes = [np.array([0, 1]) for _ in range(88)]

for batch in train_loader:
    # Flatten every example into a single 5*294 feature vector. The batch dimension is
    # read from the tensor instead of being hard-coded to 4, so a smaller final batch
    # does not break the reshape (this matches the validation cells).
    # NOTE(review): 5 presumably mirrors max_steps=5 and 294 the per-step feature count — confirm.
    batch_input = torch.reshape(batch['audio'], [batch['audio'].shape[0], 5*294])
    # Target: the labels at index 2 of the step axis, i.e. the middle of the 5 steps.
    batch_output = batch['frames'][:, 2, :]
    # The builtin int replaces np.int, which was removed in NumPy 1.24.
    svm_clf.partial_fit(batch_input.numpy(), batch_output.numpy().astype(int), classes=label_classes)

In [11]:
# Element-wise accuracy of the hinge-loss (SVM) model over the whole validation set.
# Correct label entries are counted across all batches and divided by the total number
# of entries, so a smaller final batch is weighted by its size rather than counting as
# much as a full batch (which is what averaging per-batch accuracies would do).
n_correct = 0
n_labels = 0
for batch in validate_loader:
    # Flatten every example into a single 5*294 feature vector.
    batch_input = torch.reshape(batch['audio'], [batch['audio'].shape[0], 5*294])
    # Ground truth: the labels at index 2 of the step axis, as integers.
    batch_output = batch['frames'][:, 2, :].numpy().astype(int)
    batch_pred = svm_clf.predict(batch_input.numpy())
    n_correct += int(np.sum(batch_output == batch_pred))
    n_labels += batch_output.size

# Guard against an empty validation loader instead of dividing by zero.
accuracy = n_correct / n_labels if n_labels else float('nan')
print('Average accurary: ', accuracy)

Average accurary:  0.9019886363636365
